From 27bbc70d71d392a45e325c6064e35108aa984553 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 29 Nov 2020 16:12:41 -0800 Subject: [PATCH] Add back old ModelEma and rename new one to ModelEmaV2 to avoid compat breaks in dependant code. Shuffle train script, add a few comments, remove DataParallel support, support experimental torchscript training. --- timm/utils/__init__.py | 2 +- timm/utils/model.py | 5 +- timm/utils/model_ema.py | 86 ++++++++++++++++++++++--- train.py | 136 ++++++++++++++++++++-------------------- 4 files changed, 153 insertions(+), 76 deletions(-) diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index 6efc2115..0f7c4b05 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -6,5 +6,5 @@ from .log import setup_default_logging, FormatterNoInfo from .metrics import AverageMeter, accuracy from .misc import natural_key, add_bool_arg from .model import unwrap_model, get_state_dict -from .model_ema import ModelEma +from .model_ema import ModelEma, ModelEmaV2 from .summary import update_summary, get_outdir diff --git a/timm/utils/model.py b/timm/utils/model.py index 0d6700b7..cfd42806 100644 --- a/timm/utils/model.py +++ b/timm/utils/model.py @@ -6,7 +6,10 @@ from .model_ema import ModelEma def unwrap_model(model): - return model.module if hasattr(model, 'module') else model + if isinstance(model, ModelEma): + return unwrap_model(model.ema) + else: + return model.module if hasattr(model, 'module') else model def get_state_dict(model, unwrap_fn=unwrap_model): diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py index f146e471..a767eaa5 100644 --- a/timm/utils/model_ema.py +++ b/timm/utils/model_ema.py @@ -2,15 +2,89 @@ Hacked together by / Copyright 2020 Ross Wightman """ +import logging +from collections import OrderedDict from copy import deepcopy import torch import torch.nn as nn +_logger = logging.getLogger(__name__) + + +class ModelEma: + """ Model Exponential Moving Average (DEPRECATED) + + Keep a moving average of everything in the model state_dict (parameters and buffers). + This version is deprecated, it does not work with scripted models. Will be removed eventually. + + This is intended to allow functionality like + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + + A smoothed version of the weights is necessary for some training schemes to perform well. + E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use + RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA + smoothing of weights to match results. Pay attention to the decay constant you are using + relative to your update count per epoch. + + To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but + disable validation of the EMA weights. Validation will have to be done manually in a separate + process, or after the training stops converging. + + This class is sensitive where it is initialized in the sequence of model init, + GPU assignment and distributed training wrappers. + """ + def __init__(self, model, decay=0.9999, device='', resume=''): + # make a copy of the model for accumulating moving average of weights + self.ema = deepcopy(model) + self.ema.eval() + self.decay = decay + self.device = device # perform ema on different device from model if set + if device: + self.ema.to(device=device) + self.ema_has_module = hasattr(self.ema, 'module') + if resume: + self._load_checkpoint(resume) + for p in self.ema.parameters(): + p.requires_grad_(False) + + def _load_checkpoint(self, checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + assert isinstance(checkpoint, dict) + if 'state_dict_ema' in checkpoint: + new_state_dict = OrderedDict() + for k, v in checkpoint['state_dict_ema'].items(): + # ema model may have been wrapped by DataParallel, and need module prefix + if self.ema_has_module: + name = 'module.' + k if not k.startswith('module') else k + else: + name = k + new_state_dict[name] = v + self.ema.load_state_dict(new_state_dict) + _logger.info("Loaded state_dict_ema") + else: + _logger.warning("Failed to find state_dict_ema, starting from loaded model weights") + + def update(self, model): + # correct a mismatch in state dict keys + needs_module = hasattr(model, 'module') and not self.ema_has_module + with torch.no_grad(): + msd = model.state_dict() + for k, ema_v in self.ema.state_dict().items(): + if needs_module: + k = 'module.' + k + model_v = msd[k].detach() + if self.device: + model_v = model_v.to(device=self.device) + ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v) + + +class ModelEmaV2(nn.Module): + """ Model Exponential Moving Average V2 -class ModelEma(nn.Module): - """ Model Exponential Moving Average Keep a moving average of everything in the model state_dict (parameters and buffers). + V2 of this module is simpler, it does not match params/buffers based on name but simply + iterates in order. It works with torchscript (JIT of full model). This is intended to allow functionality like https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage @@ -27,22 +101,20 @@ class ModelEma(nn.Module): This class is sensitive where it is initialized in the sequence of model init, GPU assignment and distributed training wrappers. - I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU. """ def __init__(self, model, decay=0.9999, device=None): - super(ModelEma, self).__init__() + super(ModelEmaV2, self).__init__() # make a copy of the model for accumulating moving average of weights self.module = deepcopy(model) self.module.eval() self.decay = decay self.device = device # perform ema on different device from model if set - if device is not None: + if self.device is not None: self.module.to(device=device) def update(self, model): with torch.no_grad(): for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): - assert ema_v.shape == model_v.shape - if self.device: + if self.device is not None: model_v = model_v.to(device=self.device) ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v) diff --git a/train.py b/train.py index f56089e3..722c79e4 100755 --- a/train.py +++ b/train.py @@ -29,7 +29,7 @@ import torchvision.utils from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset -from timm.models import create_model, resume_checkpoint, convert_splitbn_model +from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.optim import create_optimizer @@ -230,8 +230,6 @@ parser.add_argument('--recovery-interval', type=int, default=0, metavar='N', help='how many batches to wait before writing recovery checkpoint') parser.add_argument('-j', '--workers', type=int, default=4, metavar='N', help='how many training processes to use (default: 1)') -parser.add_argument('--num-gpu', type=int, default=1, - help='Number of GPUS to use') parser.add_argument('--save-images', action='store_true', default=False, help='save images of input bathes every log interval for debugging') parser.add_argument('--amp', action='store_true', default=False, @@ -255,6 +253,8 @@ parser.add_argument('--tta', type=int, default=0, metavar='N', parser.add_argument("--local_rank", default=0, type=int) parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False, help='use the multi-epochs-loader to save time at the beginning of every epoch') +parser.add_argument('--torchscript', dest='torchscript', action='store_true', + help='convert model torchscript for inference') def _parse_args(): @@ -282,28 +282,36 @@ def main(): args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 - if args.distributed and args.num_gpu > 1: - _logger.warning( - 'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.') - args.num_gpu = 1 - args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: - args.num_gpu = 1 args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() - assert args.rank >= 0 - - if args.distributed: _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: - _logger.info('Training with a single process on %d GPUs.' % args.num_gpu) + _logger.info('Training with a single process on 1 GPUs.') + assert args.rank >= 0 + + # resolve AMP arguments based on PyTorch / Apex availability + use_amp = None + if args.amp: + # for backwards compat, `--amp` arg tries apex before native amp + if has_apex: + args.apex_amp = True + elif has_native_amp: + args.native_amp = True + if args.apex_amp and has_apex: + use_amp = 'apex' + elif args.native_amp and has_native_amp: + use_amp = 'native' + elif args.apex_amp or args.native_amp: + _logger.warning("Neither APEX or native Torch AMP is available, using float32. " + "Install NVIDA apex or upgrade to PyTorch 1.6") torch.manual_seed(args.seed + args.rank) @@ -327,44 +335,44 @@ def main(): data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) + # setup augmentation batch splits for contrastive loss or split bn num_aug_splits = 0 if args.aug_splits > 0: assert args.aug_splits > 1, 'A split of 1 makes no sense' num_aug_splits = args.aug_splits + # enable split bn (separate bn stats per batch-portion) if args.split_bn: assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) - use_amp = None - if args.amp: - # for backwards compat, `--amp` arg tries apex before native amp - if has_apex: - args.apex_amp = True - elif has_native_amp: - args.native_amp = True - if args.apex_amp and has_apex: - use_amp = 'apex' - elif args.native_amp and has_native_amp: - use_amp = 'native' - elif args.apex_amp or args.native_amp: - _logger.warning("Neither APEX or native Torch AMP is available, using float32. " - "Install NVIDA apex or upgrade to PyTorch 1.6") + # move model to GPU, enable channels last layout if set + model.cuda() + if args.channels_last: + model = model.to(memory_format=torch.channels_last) - if args.num_gpu > 1: - if use_amp == 'apex': - _logger.warning( - 'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.') - use_amp = None - model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() - assert not args.channels_last, "Channels last not supported with DP, use DDP." - else: - model.cuda() - if args.channels_last: - model = model.to(memory_format=torch.channels_last) + # setup synchronized BatchNorm for distributed training + if args.distributed and args.sync_bn: + assert not args.split_bn + if has_apex and use_amp != 'native': + # Apex SyncBN preferred unless native amp is activated + model = convert_syncbn_model(model) + else: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + if args.local_rank == 0: + _logger.info( + 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' + 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') + + if args.torchscript: + assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' + assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' + # FIXME I ran into a bug w/ AMP + torchscript + Linear layers + model = torch.jit.script(model) optimizer = create_optimizer(args, model) + # setup automatic mixed-precision (AMP) loss scaling and op casting amp_autocast = suppress # do nothing loss_scaler = None if use_amp == 'apex': @@ -390,30 +398,17 @@ def main(): loss_scaler=None if args.no_resume_opt else loss_scaler, log_info=args.local_rank == 0) + # setup exponential moving average of model weights, SWA could be used here too model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper - model_ema = ModelEma( - model, - decay=args.model_ema_decay, - device='cpu' if args.model_ema_force_cpu else '', - resume=args.resume) + model_ema = ModelEmaV2( + model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) + if args.resume: + load_checkpoint(model_ema.module, args.resume, use_ema=True) + # setup distributed training if args.distributed: - if args.sync_bn: - assert not args.split_bn - try: - if has_apex and use_amp != 'native': - # Apex SyncBN preferred unless native amp is activated - model = convert_syncbn_model(model) - else: - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - if args.local_rank == 0: - _logger.info( - 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' - 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') - except Exception as e: - _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1') if has_apex and use_amp != 'native': # Apex DDP preferred unless native amp is activated if args.local_rank == 0: @@ -425,6 +420,7 @@ def main(): model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP + # setup learning rate schedule and starting epoch lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: @@ -438,12 +434,22 @@ def main(): if args.local_rank == 0: _logger.info('Scheduled epochs: {}'.format(num_epochs)) + # create the train and eval datasets train_dir = os.path.join(args.data, 'train') if not os.path.exists(train_dir): _logger.error('Training folder does not exist at: {}'.format(train_dir)) exit(1) dataset_train = Dataset(train_dir) + eval_dir = os.path.join(args.data, 'val') + if not os.path.isdir(eval_dir): + eval_dir = os.path.join(args.data, 'validation') + if not os.path.isdir(eval_dir): + _logger.error('Validation folder does not exist at: {}'.format(eval_dir)) + exit(1) + dataset_eval = Dataset(eval_dir) + + # setup mixup / cutmix collate_fn = None mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None @@ -458,9 +464,11 @@ def main(): else: mixup_fn = Mixup(**mixup_args) + # wrap dataset in AugMix helper if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) + # create data loaders w/ augmentation pipeiine train_interpolation = args.train_interpolation if args.no_aug or not train_interpolation: train_interpolation = data_config['interpolation'] @@ -492,14 +500,6 @@ def main(): use_multi_epochs_loader=args.use_multi_epochs_loader ) - eval_dir = os.path.join(args.data, 'val') - if not os.path.isdir(eval_dir): - eval_dir = os.path.join(args.data, 'validation') - if not os.path.isdir(eval_dir): - _logger.error('Validation folder does not exist at: {}'.format(eval_dir)) - exit(1) - dataset_eval = Dataset(eval_dir) - loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], @@ -515,6 +515,7 @@ def main(): pin_memory=args.pin_mem, ) + # setup loss function if args.jsd: assert num_aug_splits > 1 # JSD only valid with aug splits set train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() @@ -527,6 +528,7 @@ def main(): train_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() + # setup checkpoint saver and eval metric tracking eval_metric = args.eval_metric best_metric = None best_epoch = None @@ -638,11 +640,11 @@ def train_epoch( torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad) optimizer.step() - torch.cuda.synchronize() if model_ema is not None: model_ema.update(model) - num_updates += 1 + torch.cuda.synchronize() + num_updates += 1 batch_time_m.update(time.time() - end) if last_batch or batch_idx % args.log_interval == 0: lrl = [param_group['lr'] for param_group in optimizer.param_groups]