Fix distribute_bn and model_ema

pull/1239/head
Ross Wightman 3 years ago
parent 74d2829341
commit 6d90fcf282

@ -29,14 +29,13 @@ import torch.nn as nn
import torchvision.utils
from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, Logger, Tracker,\
TrainState, TrainServices, TrainCfg, AccuracyTopK, AvgTensor
TrainState, TrainServices, TrainCfg, AccuracyTopK, AvgTensor, distribute_bn
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
convert_splitbn_model, model_parameters
from timm.utils import *
from timm.models import create_model, safe_model_name, convert_splitbn_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler
from timm.utils import setup_default_logging, random_seed, get_outdir, CheckpointSaver
_logger = logging.getLogger('train')
@ -290,7 +289,7 @@ def main():
train_state, train_cfg = setup_train_task(args, dev_env, mixup_active)
data_config, loader_eval, loader_train = setup_data(args, dev_env, mixup_active)
data_config, loader_eval, loader_train = setup_data(args, train_state.model.default_cfg, dev_env, mixup_active)
# setup checkpoint saver
eval_metric = args.eval_metric
@ -347,7 +346,7 @@ def main():
if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
if dev_env.primary:
_logger.info("Distributing BatchNorm running means and vars")
distribute_bn(model, dev_env.world_size, args.dist_bn == 'reduce')
distribute_bn(train_state.model, args.dist_bn == 'reduce', dev_env)
eval_metrics = evaluate(
train_state.model,
@ -358,7 +357,7 @@ def main():
if train_state.model_ema is not None and not args.model_ema_force_cpu:
if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
distribute_bn(train_state.model_ema, dev_env.world_size, args.dist_bn == 'reduce')
distribute_bn(train_state.model_ema, args.dist_bn == 'reduce', dev_env)
ema_eval_metrics = evaluate(
train_state.model_ema.module,
@ -469,8 +468,8 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
return train_state, train_cfg
def setup_data(args, dev_env, mixup_active):
data_config = resolve_data_config(vars(args), model=model, verbose=dev_env.primary)
def setup_data(args, default_cfg, dev_env, mixup_active):
data_config = resolve_data_config(vars(args), default_cfg=default_cfg, verbose=dev_env.primary)
# create the train and eval datasets
dataset_train = create_dataset(
@ -606,7 +605,7 @@ def after_train_step(
loss_meter.update(loss, output.shape[0])
if state.model_ema is not None:
state.model_ema.update(model)
state.model_ema.update(state.model)
state = replace(state, step_count_global=state.step_count_global + 1)

Loading…
Cancel
Save