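Fix distribute_bn and model_ema call sites (pass train_state.model / state.model and dev_env), and pass default_cfg into setup_data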

pull/1239/head
Ross Wightman 3 years ago
parent 74d2829341
commit 6d90fcf282

@ -29,14 +29,13 @@ import torch.nn as nn
import torchvision.utils import torchvision.utils
from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, Logger, Tracker,\ from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, Logger, Tracker,\
TrainState, TrainServices, TrainCfg, AccuracyTopK, AvgTensor TrainState, TrainServices, TrainCfg, AccuracyTopK, AvgTensor, distribute_bn
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\ from timm.models import create_model, safe_model_name, convert_splitbn_model
convert_splitbn_model, model_parameters
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler from timm.scheduler import create_scheduler
from timm.utils import setup_default_logging, random_seed, get_outdir, CheckpointSaver
_logger = logging.getLogger('train') _logger = logging.getLogger('train')
@ -290,7 +289,7 @@ def main():
train_state, train_cfg = setup_train_task(args, dev_env, mixup_active) train_state, train_cfg = setup_train_task(args, dev_env, mixup_active)
data_config, loader_eval, loader_train = setup_data(args, dev_env, mixup_active) data_config, loader_eval, loader_train = setup_data(args, train_state.model.default_cfg, dev_env, mixup_active)
# setup checkpoint saver # setup checkpoint saver
eval_metric = args.eval_metric eval_metric = args.eval_metric
@ -347,7 +346,7 @@ def main():
if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'): if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
if dev_env.primary: if dev_env.primary:
_logger.info("Distributing BatchNorm running means and vars") _logger.info("Distributing BatchNorm running means and vars")
distribute_bn(model, dev_env.world_size, args.dist_bn == 'reduce') distribute_bn(train_state.model, args.dist_bn == 'reduce', dev_env)
eval_metrics = evaluate( eval_metrics = evaluate(
train_state.model, train_state.model,
@ -358,7 +357,7 @@ def main():
if train_state.model_ema is not None and not args.model_ema_force_cpu: if train_state.model_ema is not None and not args.model_ema_force_cpu:
if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'): if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
distribute_bn(train_state.model_ema, dev_env.world_size, args.dist_bn == 'reduce') distribute_bn(train_state.model_ema, args.dist_bn == 'reduce', dev_env)
ema_eval_metrics = evaluate( ema_eval_metrics = evaluate(
train_state.model_ema.module, train_state.model_ema.module,
@ -469,8 +468,8 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
return train_state, train_cfg return train_state, train_cfg
def setup_data(args, dev_env, mixup_active): def setup_data(args, default_cfg, dev_env, mixup_active):
data_config = resolve_data_config(vars(args), model=model, verbose=dev_env.primary) data_config = resolve_data_config(vars(args), default_cfg=default_cfg, verbose=dev_env.primary)
# create the train and eval datasets # create the train and eval datasets
dataset_train = create_dataset( dataset_train = create_dataset(
@ -606,7 +605,7 @@ def after_train_step(
loss_meter.update(loss, output.shape[0]) loss_meter.update(loss, output.shape[0])
if state.model_ema is not None: if state.model_ema is not None:
state.model_ema.update(model) state.model_ema.update(state.model)
state = replace(state, step_count_global=state.step_count_global + 1) state = replace(state, step_count_global=state.step_count_global + 1)

Loading…
Cancel
Save