From 6d90fcf2821d3b948f82d6af22a7c351b8fd5787 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 18 May 2021 11:34:31 -0700
Subject: [PATCH] Fix distribute_bn and model_ema

---
 train.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/train.py b/train.py
index 05da82e2..3f18c8e5 100755
--- a/train.py
+++ b/train.py
@@ -29,14 +29,13 @@ import torch.nn as nn
 import torchvision.utils
 
 from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, Logger, Tracker,\
-    TrainState, TrainServices, TrainCfg, AccuracyTopK, AvgTensor
+    TrainState, TrainServices, TrainCfg, AccuracyTopK, AvgTensor, distribute_bn
 from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
-from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
-    convert_splitbn_model, model_parameters
-from timm.utils import *
+from timm.models import create_model, safe_model_name, convert_splitbn_model
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
 from timm.optim import create_optimizer_v2, optimizer_kwargs
 from timm.scheduler import create_scheduler
+from timm.utils import setup_default_logging, random_seed, get_outdir, CheckpointSaver
 
 _logger = logging.getLogger('train')
 
@@ -290,7 +289,7 @@ def main():
 
     train_state, train_cfg = setup_train_task(args, dev_env, mixup_active)
 
-    data_config, loader_eval, loader_train = setup_data(args, dev_env, mixup_active)
+    data_config, loader_eval, loader_train = setup_data(args, train_state.model.default_cfg, dev_env, mixup_active)
 
     # setup checkpoint saver
     eval_metric = args.eval_metric
@@ -347,7 +346,7 @@
             if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if dev_env.primary:
                     _logger.info("Distributing BatchNorm running means and vars")
-                distribute_bn(model, dev_env.world_size, args.dist_bn == 'reduce')
+                distribute_bn(train_state.model, args.dist_bn == 'reduce', dev_env)
 
             eval_metrics = evaluate(
                 train_state.model,
@@ -358,7 +357,7 @@
 
         if train_state.model_ema is not None and not args.model_ema_force_cpu:
             if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
-                distribute_bn(train_state.model_ema, dev_env.world_size, args.dist_bn == 'reduce')
+                distribute_bn(train_state.model_ema, args.dist_bn == 'reduce', dev_env)
 
             ema_eval_metrics = evaluate(
                 train_state.model_ema.module,
@@ -469,8 +468,8 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
     return train_state, train_cfg
 
 
-def setup_data(args, dev_env, mixup_active):
-    data_config = resolve_data_config(vars(args), model=model, verbose=dev_env.primary)
+def setup_data(args, default_cfg, dev_env, mixup_active):
+    data_config = resolve_data_config(vars(args), default_cfg=default_cfg, verbose=dev_env.primary)
 
     # create the train and eval datasets
     dataset_train = create_dataset(
@@ -606,7 +605,7 @@ def after_train_step(
         loss_meter.update(loss, output.shape[0])
 
        if state.model_ema is not None:
-            state.model_ema.update(model)
+            state.model_ema.update(state.model)
 
         state = replace(state, step_count_global=state.step_count_global + 1)
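
Note for reviewers: the calls being fixed previously matched the timm.utils
helper, whose signature is distribute_bn(model, world_size, reduce); the
timm.bits variant imported above is instead called as
distribute_bn(model, reduce, dev_env). Below is a minimal sketch of the
expected behavior, modeled on the existing timm.utils implementation. It is
not the actual timm.bits code; the standalone function name, the explicit
world_size parameter, and the use of an initialized torch.distributed process
group are assumptions for illustration only:

    import torch
    import torch.distributed as dist

    def distribute_bn_sketch(model: torch.nn.Module, reduce: bool, world_size: int):
        # Bring every worker's BatchNorm running statistics into agreement
        # before evaluation (assumes the default process group is initialized).
        for buf_name, buf in model.named_buffers(recurse=True):
            if 'running_mean' in buf_name or 'running_var' in buf_name:
                if reduce:
                    # 'reduce' mode: average the stats across the whole group
                    dist.all_reduce(buf, op=dist.ReduceOp.SUM)
                    buf /= float(world_size)
                else:
                    # 'broadcast' mode: copy rank 0's stats to every worker
                    dist.broadcast(buf, 0)

The model_ema hunks fix the same class of bug: after the refactor there is no
local variable named model in those scopes, so the EMA update and the BN
distribution for the EMA copy must go through train_state.model / state.model.
The EMA update itself blends parameters as
ema_p = decay * ema_p + (1 - decay) * model_p.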