|
|
|
@ -29,14 +29,13 @@ import torch.nn as nn
|
|
|
|
|
import torchvision.utils
|
|
|
|
|
|
|
|
|
|
from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, Logger, Tracker,\
|
|
|
|
|
TrainState, TrainServices, TrainCfg, AccuracyTopK, AvgTensor
|
|
|
|
|
TrainState, TrainServices, TrainCfg, AccuracyTopK, AvgTensor, distribute_bn
|
|
|
|
|
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
|
|
|
|
|
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
|
|
|
|
|
convert_splitbn_model, model_parameters
|
|
|
|
|
from timm.utils import *
|
|
|
|
|
from timm.models import create_model, safe_model_name, convert_splitbn_model
|
|
|
|
|
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
|
|
|
|
|
from timm.optim import create_optimizer_v2, optimizer_kwargs
|
|
|
|
|
from timm.scheduler import create_scheduler
|
|
|
|
|
from timm.utils import setup_default_logging, random_seed, get_outdir, CheckpointSaver
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_logger = logging.getLogger('train')
|
|
|
|
@ -290,7 +289,7 @@ def main():
|
|
|
|
|
|
|
|
|
|
train_state, train_cfg = setup_train_task(args, dev_env, mixup_active)
|
|
|
|
|
|
|
|
|
|
data_config, loader_eval, loader_train = setup_data(args, dev_env, mixup_active)
|
|
|
|
|
data_config, loader_eval, loader_train = setup_data(args, train_state.model.default_cfg, dev_env, mixup_active)
|
|
|
|
|
|
|
|
|
|
# setup checkpoint saver
|
|
|
|
|
eval_metric = args.eval_metric
|
|
|
|
@ -347,7 +346,7 @@ def main():
|
|
|
|
|
if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
|
|
|
|
if dev_env.primary:
|
|
|
|
|
_logger.info("Distributing BatchNorm running means and vars")
|
|
|
|
|
distribute_bn(model, dev_env.world_size, args.dist_bn == 'reduce')
|
|
|
|
|
distribute_bn(train_state.model, args.dist_bn == 'reduce', dev_env)
|
|
|
|
|
|
|
|
|
|
eval_metrics = evaluate(
|
|
|
|
|
train_state.model,
|
|
|
|
@ -358,7 +357,7 @@ def main():
|
|
|
|
|
|
|
|
|
|
if train_state.model_ema is not None and not args.model_ema_force_cpu:
|
|
|
|
|
if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
|
|
|
|
distribute_bn(train_state.model_ema, dev_env.world_size, args.dist_bn == 'reduce')
|
|
|
|
|
distribute_bn(train_state.model_ema, args.dist_bn == 'reduce', dev_env)
|
|
|
|
|
|
|
|
|
|
ema_eval_metrics = evaluate(
|
|
|
|
|
train_state.model_ema.module,
|
|
|
|
@ -469,8 +468,8 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
|
|
|
|
|
return train_state, train_cfg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def setup_data(args, dev_env, mixup_active):
|
|
|
|
|
data_config = resolve_data_config(vars(args), model=model, verbose=dev_env.primary)
|
|
|
|
|
def setup_data(args, default_cfg, dev_env, mixup_active):
|
|
|
|
|
data_config = resolve_data_config(vars(args), default_cfg=default_cfg, verbose=dev_env.primary)
|
|
|
|
|
|
|
|
|
|
# create the train and eval datasets
|
|
|
|
|
dataset_train = create_dataset(
|
|
|
|
@ -606,7 +605,7 @@ def after_train_step(
|
|
|
|
|
loss_meter.update(loss, output.shape[0])
|
|
|
|
|
|
|
|
|
|
if state.model_ema is not None:
|
|
|
|
|
state.model_ema.update(model)
|
|
|
|
|
state.model_ema.update(state.model)
|
|
|
|
|
|
|
|
|
|
state = replace(state, step_count_global=state.step_count_global + 1)
|
|
|
|
|
|
|
|
|
|