From 9214ca071674ce62b0eff36f0a1e3eaaba6ec2e3 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 16 Nov 2020 12:51:52 -0800 Subject: [PATCH 01/10] Simplifying EMA... --- timm/utils/model.py | 5 +---- timm/utils/model_ema.py | 49 +++++++++-------------------------------- train.py | 2 +- 3 files changed, 12 insertions(+), 44 deletions(-) diff --git a/timm/utils/model.py b/timm/utils/model.py index cfd42806..0d6700b7 100644 --- a/timm/utils/model.py +++ b/timm/utils/model.py @@ -6,10 +6,7 @@ from .model_ema import ModelEma def unwrap_model(model): - if isinstance(model, ModelEma): - return unwrap_model(model.ema) - else: - return model.module if hasattr(model, 'module') else model + return model.module if hasattr(model, 'module') else model def get_state_dict(model, unwrap_fn=unwrap_model): diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py index b788b32e..f146e471 100644 --- a/timm/utils/model_ema.py +++ b/timm/utils/model_ema.py @@ -2,16 +2,13 @@ Hacked together by / Copyright 2020 Ross Wightman """ -import logging -from collections import OrderedDict from copy import deepcopy import torch +import torch.nn as nn -_logger = logging.getLogger(__name__) - -class ModelEma: +class ModelEma(nn.Module): """ Model Exponential Moving Average Keep a moving average of everything in the model state_dict (parameters and buffers). @@ -32,46 +29,20 @@ class ModelEma: GPU assignment and distributed training wrappers. I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU. """ - def __init__(self, model, decay=0.9999, device='', resume=''): + def __init__(self, model, decay=0.9999, device=None): + super(ModelEma, self).__init__() # make a copy of the model for accumulating moving average of weights - self.ema = deepcopy(model) - self.ema.eval() + self.module = deepcopy(model) + self.module.eval() self.decay = decay self.device = device # perform ema on different device from model if set - if device: - self.ema.to(device=device) - self.ema_has_module = hasattr(self.ema, 'module') - if resume: - self._load_checkpoint(resume) - for p in self.ema.parameters(): - p.requires_grad_(False) - - def _load_checkpoint(self, checkpoint_path): - checkpoint = torch.load(checkpoint_path, map_location='cpu') - assert isinstance(checkpoint, dict) - if 'state_dict_ema' in checkpoint: - new_state_dict = OrderedDict() - for k, v in checkpoint['state_dict_ema'].items(): - # ema model may have been wrapped by DataParallel, and need module prefix - if self.ema_has_module: - name = 'module.' + k if not k.startswith('module') else k - else: - name = k - new_state_dict[name] = v - self.ema.load_state_dict(new_state_dict) - _logger.info("Loaded state_dict_ema") - else: - _logger.warning("Failed to find state_dict_ema, starting from loaded model weights") + if device is not None: + self.module.to(device=device) def update(self, model): - # correct a mismatch in state dict keys - needs_module = hasattr(model, 'module') and not self.ema_has_module with torch.no_grad(): - msd = model.state_dict() - for k, ema_v in self.ema.state_dict().items(): - if needs_module: - k = 'module.' + k - model_v = msd[k].detach() + for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): + assert ema_v.shape == model_v.shape if self.device: model_v = model_v.to(device=self.device) ema_v.copy_(ema_v * self.decay + (1. 
- self.decay) * model_v) diff --git a/train.py b/train.py index ef3adf85..f56089e3 100755 --- a/train.py +++ b/train.py @@ -568,7 +568,7 @@ def main(): if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate( - model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') + model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if lr_scheduler is not None: From 27bbc70d71d392a45e325c6064e35108aa984553 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 29 Nov 2020 16:12:41 -0800 Subject: [PATCH 02/10] Add back old ModelEma and rename new one to ModelEmaV2 to avoid compat breaks in dependant code. Shuffle train script, add a few comments, remove DataParallel support, support experimental torchscript training. --- timm/utils/__init__.py | 2 +- timm/utils/model.py | 5 +- timm/utils/model_ema.py | 86 ++++++++++++++++++++++--- train.py | 136 ++++++++++++++++++++-------------------- 4 files changed, 153 insertions(+), 76 deletions(-) diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index 6efc2115..0f7c4b05 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -6,5 +6,5 @@ from .log import setup_default_logging, FormatterNoInfo from .metrics import AverageMeter, accuracy from .misc import natural_key, add_bool_arg from .model import unwrap_model, get_state_dict -from .model_ema import ModelEma +from .model_ema import ModelEma, ModelEmaV2 from .summary import update_summary, get_outdir diff --git a/timm/utils/model.py b/timm/utils/model.py index 0d6700b7..cfd42806 100644 --- a/timm/utils/model.py +++ b/timm/utils/model.py @@ -6,7 +6,10 @@ from .model_ema import ModelEma def unwrap_model(model): - return model.module if hasattr(model, 'module') else model + if isinstance(model, ModelEma): + return unwrap_model(model.ema) + else: + return model.module if hasattr(model, 'module') else model def get_state_dict(model, unwrap_fn=unwrap_model): diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py index f146e471..a767eaa5 100644 --- a/timm/utils/model_ema.py +++ b/timm/utils/model_ema.py @@ -2,15 +2,89 @@ Hacked together by / Copyright 2020 Ross Wightman """ +import logging +from collections import OrderedDict from copy import deepcopy import torch import torch.nn as nn +_logger = logging.getLogger(__name__) + + +class ModelEma: + """ Model Exponential Moving Average (DEPRECATED) + + Keep a moving average of everything in the model state_dict (parameters and buffers). + This version is deprecated, it does not work with scripted models. Will be removed eventually. + + This is intended to allow functionality like + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + + A smoothed version of the weights is necessary for some training schemes to perform well. + E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use + RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA + smoothing of weights to match results. Pay attention to the decay constant you are using + relative to your update count per epoch. + + To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but + disable validation of the EMA weights. Validation will have to be done manually in a separate + process, or after the training stops converging. 
+ + This class is sensitive where it is initialized in the sequence of model init, + GPU assignment and distributed training wrappers. + """ + def __init__(self, model, decay=0.9999, device='', resume=''): + # make a copy of the model for accumulating moving average of weights + self.ema = deepcopy(model) + self.ema.eval() + self.decay = decay + self.device = device # perform ema on different device from model if set + if device: + self.ema.to(device=device) + self.ema_has_module = hasattr(self.ema, 'module') + if resume: + self._load_checkpoint(resume) + for p in self.ema.parameters(): + p.requires_grad_(False) + + def _load_checkpoint(self, checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + assert isinstance(checkpoint, dict) + if 'state_dict_ema' in checkpoint: + new_state_dict = OrderedDict() + for k, v in checkpoint['state_dict_ema'].items(): + # ema model may have been wrapped by DataParallel, and need module prefix + if self.ema_has_module: + name = 'module.' + k if not k.startswith('module') else k + else: + name = k + new_state_dict[name] = v + self.ema.load_state_dict(new_state_dict) + _logger.info("Loaded state_dict_ema") + else: + _logger.warning("Failed to find state_dict_ema, starting from loaded model weights") + + def update(self, model): + # correct a mismatch in state dict keys + needs_module = hasattr(model, 'module') and not self.ema_has_module + with torch.no_grad(): + msd = model.state_dict() + for k, ema_v in self.ema.state_dict().items(): + if needs_module: + k = 'module.' + k + model_v = msd[k].detach() + if self.device: + model_v = model_v.to(device=self.device) + ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v) + + +class ModelEmaV2(nn.Module): + """ Model Exponential Moving Average V2 -class ModelEma(nn.Module): - """ Model Exponential Moving Average Keep a moving average of everything in the model state_dict (parameters and buffers). + V2 of this module is simpler, it does not match params/buffers based on name but simply + iterates in order. It works with torchscript (JIT of full model). This is intended to allow functionality like https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage @@ -27,22 +101,20 @@ class ModelEma(nn.Module): This class is sensitive where it is initialized in the sequence of model init, GPU assignment and distributed training wrappers. - I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU. """ def __init__(self, model, decay=0.9999, device=None): - super(ModelEma, self).__init__() + super(ModelEmaV2, self).__init__() # make a copy of the model for accumulating moving average of weights self.module = deepcopy(model) self.module.eval() self.decay = decay self.device = device # perform ema on different device from model if set - if device is not None: + if self.device is not None: self.module.to(device=device) def update(self, model): with torch.no_grad(): for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): - assert ema_v.shape == model_v.shape - if self.device: + if self.device is not None: model_v = model_v.to(device=self.device) ema_v.copy_(ema_v * self.decay + (1. 
- self.decay) * model_v) diff --git a/train.py b/train.py index f56089e3..722c79e4 100755 --- a/train.py +++ b/train.py @@ -29,7 +29,7 @@ import torchvision.utils from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset -from timm.models import create_model, resume_checkpoint, convert_splitbn_model +from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.optim import create_optimizer @@ -230,8 +230,6 @@ parser.add_argument('--recovery-interval', type=int, default=0, metavar='N', help='how many batches to wait before writing recovery checkpoint') parser.add_argument('-j', '--workers', type=int, default=4, metavar='N', help='how many training processes to use (default: 1)') -parser.add_argument('--num-gpu', type=int, default=1, - help='Number of GPUS to use') parser.add_argument('--save-images', action='store_true', default=False, help='save images of input bathes every log interval for debugging') parser.add_argument('--amp', action='store_true', default=False, @@ -255,6 +253,8 @@ parser.add_argument('--tta', type=int, default=0, metavar='N', parser.add_argument("--local_rank", default=0, type=int) parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False, help='use the multi-epochs-loader to save time at the beginning of every epoch') +parser.add_argument('--torchscript', dest='torchscript', action='store_true', + help='convert model torchscript for inference') def _parse_args(): @@ -282,28 +282,36 @@ def main(): args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 - if args.distributed and args.num_gpu > 1: - _logger.warning( - 'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.') - args.num_gpu = 1 - args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: - args.num_gpu = 1 args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() - assert args.rank >= 0 - - if args.distributed: _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: - _logger.info('Training with a single process on %d GPUs.' % args.num_gpu) + _logger.info('Training with a single process on 1 GPUs.') + assert args.rank >= 0 + + # resolve AMP arguments based on PyTorch / Apex availability + use_amp = None + if args.amp: + # for backwards compat, `--amp` arg tries apex before native amp + if has_apex: + args.apex_amp = True + elif has_native_amp: + args.native_amp = True + if args.apex_amp and has_apex: + use_amp = 'apex' + elif args.native_amp and has_native_amp: + use_amp = 'native' + elif args.apex_amp or args.native_amp: + _logger.warning("Neither APEX or native Torch AMP is available, using float32. 
" + "Install NVIDA apex or upgrade to PyTorch 1.6") torch.manual_seed(args.seed + args.rank) @@ -327,44 +335,44 @@ def main(): data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) + # setup augmentation batch splits for contrastive loss or split bn num_aug_splits = 0 if args.aug_splits > 0: assert args.aug_splits > 1, 'A split of 1 makes no sense' num_aug_splits = args.aug_splits + # enable split bn (separate bn stats per batch-portion) if args.split_bn: assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) - use_amp = None - if args.amp: - # for backwards compat, `--amp` arg tries apex before native amp - if has_apex: - args.apex_amp = True - elif has_native_amp: - args.native_amp = True - if args.apex_amp and has_apex: - use_amp = 'apex' - elif args.native_amp and has_native_amp: - use_amp = 'native' - elif args.apex_amp or args.native_amp: - _logger.warning("Neither APEX or native Torch AMP is available, using float32. " - "Install NVIDA apex or upgrade to PyTorch 1.6") + # move model to GPU, enable channels last layout if set + model.cuda() + if args.channels_last: + model = model.to(memory_format=torch.channels_last) - if args.num_gpu > 1: - if use_amp == 'apex': - _logger.warning( - 'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.') - use_amp = None - model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() - assert not args.channels_last, "Channels last not supported with DP, use DDP." - else: - model.cuda() - if args.channels_last: - model = model.to(memory_format=torch.channels_last) + # setup synchronized BatchNorm for distributed training + if args.distributed and args.sync_bn: + assert not args.split_bn + if has_apex and use_amp != 'native': + # Apex SyncBN preferred unless native amp is activated + model = convert_syncbn_model(model) + else: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + if args.local_rank == 0: + _logger.info( + 'Converted model to use Synchronized BatchNorm. 
WARNING: You may have issues if using ' + 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') + + if args.torchscript: + assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' + assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' + # FIXME I ran into a bug w/ AMP + torchscript + Linear layers + model = torch.jit.script(model) optimizer = create_optimizer(args, model) + # setup automatic mixed-precision (AMP) loss scaling and op casting amp_autocast = suppress # do nothing loss_scaler = None if use_amp == 'apex': @@ -390,30 +398,17 @@ def main(): loss_scaler=None if args.no_resume_opt else loss_scaler, log_info=args.local_rank == 0) + # setup exponential moving average of model weights, SWA could be used here too model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper - model_ema = ModelEma( - model, - decay=args.model_ema_decay, - device='cpu' if args.model_ema_force_cpu else '', - resume=args.resume) + model_ema = ModelEmaV2( + model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) + if args.resume: + load_checkpoint(model_ema.module, args.resume, use_ema=True) + # setup distributed training if args.distributed: - if args.sync_bn: - assert not args.split_bn - try: - if has_apex and use_amp != 'native': - # Apex SyncBN preferred unless native amp is activated - model = convert_syncbn_model(model) - else: - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - if args.local_rank == 0: - _logger.info( - 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' - 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') - except Exception as e: - _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1') if has_apex and use_amp != 'native': # Apex DDP preferred unless native amp is activated if args.local_rank == 0: @@ -425,6 +420,7 @@ def main(): model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP + # setup learning rate schedule and starting epoch lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: @@ -438,12 +434,22 @@ def main(): if args.local_rank == 0: _logger.info('Scheduled epochs: {}'.format(num_epochs)) + # create the train and eval datasets train_dir = os.path.join(args.data, 'train') if not os.path.exists(train_dir): _logger.error('Training folder does not exist at: {}'.format(train_dir)) exit(1) dataset_train = Dataset(train_dir) + eval_dir = os.path.join(args.data, 'val') + if not os.path.isdir(eval_dir): + eval_dir = os.path.join(args.data, 'validation') + if not os.path.isdir(eval_dir): + _logger.error('Validation folder does not exist at: {}'.format(eval_dir)) + exit(1) + dataset_eval = Dataset(eval_dir) + + # setup mixup / cutmix collate_fn = None mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None @@ -458,9 +464,11 @@ def main(): else: mixup_fn = Mixup(**mixup_args) + # wrap dataset in AugMix helper if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) + # create data loaders w/ augmentation pipeiine train_interpolation = args.train_interpolation if args.no_aug or not train_interpolation: train_interpolation = data_config['interpolation'] @@ -492,14 +500,6 @@ def main(): use_multi_epochs_loader=args.use_multi_epochs_loader ) - eval_dir = os.path.join(args.data, 'val') - if not os.path.isdir(eval_dir): - eval_dir = os.path.join(args.data, 'validation') - if not os.path.isdir(eval_dir): - _logger.error('Validation folder does not exist at: {}'.format(eval_dir)) - exit(1) - dataset_eval = Dataset(eval_dir) - loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], @@ -515,6 +515,7 @@ def main(): pin_memory=args.pin_mem, ) + # setup loss function if args.jsd: assert num_aug_splits > 1 # JSD only valid with aug splits set train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() @@ -527,6 +528,7 @@ def main(): train_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() + # setup checkpoint saver and eval metric tracking eval_metric = args.eval_metric best_metric = None best_epoch = None @@ -638,11 +640,11 @@ def train_epoch( torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad) optimizer.step() - torch.cuda.synchronize() if model_ema is not None: model_ema.update(model) - num_updates += 1 + torch.cuda.synchronize() + num_updates += 1 batch_time_m.update(time.time() - end) if last_batch or batch_idx % args.log_interval == 0: lrl = [param_group['lr'] for param_group in optimizer.param_groups] From fd962c4b4a5214650a8678a2a987d1853933e1c0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 29 Nov 2020 21:56:55 -0800 Subject: [PATCH 03/10] Native SiLU (Swish) op doesn't export to ONNX --- timm/models/layers/create_act.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py index 6f2ab83e..3f39bcf4 100644 --- a/timm/models/layers/create_act.py +++ b/timm/models/layers/create_act.py @@ -98,7 +98,10 @@ def get_act_fn(name='relu'): # custom autograd, then fallback if name in _ACT_FN_ME: return _ACT_FN_ME[name] - if not is_no_jit(): + if is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return swish + if not (is_no_jit() or is_exportable()): if name in _ACT_FN_JIT: return _ACT_FN_JIT[name] return _ACT_FN_DEFAULT[name] @@ -114,7 +117,10 @@ def get_act_layer(name='relu'): if not (is_no_jit() or is_exportable() or is_scriptable()): if name in _ACT_LAYER_ME: return _ACT_LAYER_ME[name] - if not is_no_jit(): + if is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return Swish + if not (is_no_jit() or is_exportable()): if name in _ACT_LAYER_JIT: return _ACT_LAYER_JIT[name] return _ACT_LAYER_DEFAULT[name] From 5f4b6076d8351f9ccfa7f034f83a9e9c969fb14d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 30 Nov 2020 13:27:40 -0800 Subject: [PATCH 04/10] Fix inplace arg compat for GELU and PreLU via activation factory --- timm/models/layers/activations.py | 24 ++++++++++++++++++++++++ timm/models/layers/create_act.py | 7 +++---- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git 
a/timm/models/layers/activations.py b/timm/models/layers/activations.py index edb2074f..e16b3bd3 100644 --- a/timm/models/layers/activations.py +++ b/timm/models/layers/activations.py @@ -119,3 +119,27 @@ class HardMish(nn.Module): def forward(self, x): return hard_mish(x, self.inplace) + + +class PReLU(nn.PReLU): + """Applies PReLU (w/ dummy inplace arg) + """ + def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None: + super(PReLU, self).__init__(num_parameters=num_parameters, init=init) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.prelu(input, self.weight) + + +def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: + return F.gelu(x) + + +class GELU(nn.Module): + """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg) + """ + def __init__(self, inplace: bool = False): + super(GELU, self).__init__() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.gelu(input) diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py index 3f39bcf4..426c3681 100644 --- a/timm/models/layers/create_act.py +++ b/timm/models/layers/create_act.py @@ -19,10 +19,9 @@ _ACT_FN_DEFAULT = dict( relu6=F.relu6, leaky_relu=F.leaky_relu, elu=F.elu, - prelu=F.prelu, celu=F.celu, selu=F.selu, - gelu=F.gelu, + gelu=gelu, sigmoid=sigmoid, tanh=tanh, hard_sigmoid=hard_sigmoid, @@ -56,10 +55,10 @@ _ACT_LAYER_DEFAULT = dict( relu6=nn.ReLU6, leaky_relu=nn.LeakyReLU, elu=nn.ELU, - prelu=nn.PReLU, + prelu=PReLU, celu=nn.CELU, selu=nn.SELU, - gelu=nn.GELU, + gelu=GELU, sigmoid=Sigmoid, tanh=Tanh, hard_sigmoid=HardSigmoid, From 460eba7f24defba4f50898aae79e1f9d263b15e4 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 30 Nov 2020 13:29:33 -0800 Subject: [PATCH 05/10] Work around casting issue with combination of native torch AMP and torchscript for Linear layers --- timm/models/layers/classifier.py | 4 +++- timm/models/layers/linear.py | 18 ++++++++++++++++++ train.py | 1 - 3 files changed, 21 insertions(+), 2 deletions(-) create mode 100644 timm/models/layers/linear.py diff --git a/timm/models/layers/classifier.py b/timm/models/layers/classifier.py index e9194f05..89fe5458 100644 --- a/timm/models/layers/classifier.py +++ b/timm/models/layers/classifier.py @@ -6,6 +6,7 @@ from torch import nn as nn from torch.nn import functional as F from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .linear import Linear def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): @@ -21,7 +22,8 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False elif use_conv: fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True) else: - fc = nn.Linear(num_pooled_features, num_classes, bias=True) + # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue + fc = Linear(num_pooled_features, num_classes, bias=True) return global_pool, fc diff --git a/timm/models/layers/linear.py b/timm/models/layers/linear.py new file mode 100644 index 00000000..4607f284 --- /dev/null +++ b/timm/models/layers/linear.py @@ -0,0 +1,18 @@ +""" Linear layer (alternate definition) +""" +import torch +import torch.nn.functional as F +from torch import nn as nn + + +class Linear(nn.Linear): + r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` + + Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting + weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. 
+ """ + def forward(self, input: torch.Tensor) -> torch.Tensor: + if torch.jit.is_scripting(): + return F.linear(input, self.weight.to(dtype=input.dtype), self.bias.to(dtype=input.dtype)) + else: + return F.linear(input, self.weight, self.bias) \ No newline at end of file diff --git a/train.py b/train.py index 722c79e4..23a8e9b0 100755 --- a/train.py +++ b/train.py @@ -367,7 +367,6 @@ def main(): if args.torchscript: assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' - # FIXME I ran into a bug w/ AMP + torchscript + Linear layers model = torch.jit.script(model) optimizer = create_optimizer(args, model) From 6504a42832cb6687e3f595c37e003ea920268365 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 30 Nov 2020 13:39:08 -0800 Subject: [PATCH 06/10] Version 0.3.2 --- timm/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/version.py b/timm/version.py index e1424ed0..73e3bb4f 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.3.1' +__version__ = '0.3.2' From 2ed8f247154870be7acc1908fde0a7f457f67456 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 30 Nov 2020 16:19:52 -0800 Subject: [PATCH 07/10] A few more changes for 0.3.2 maint release. Linear layer change for mobilenetv3 and inception_v3, support no bias for linear wrapper. --- tests/test_models.py | 2 +- timm/models/helpers.py | 4 ++-- timm/models/inception_v3.py | 4 ++-- timm/models/layers/__init__.py | 1 + timm/models/layers/linear.py | 5 +++-- timm/models/mobilenetv3.py | 6 +++--- train.py | 1 + 7 files changed, 13 insertions(+), 10 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index db8efbf3..a62625d9 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -121,7 +121,7 @@ if 'GITHUB_ACTIONS' not in os.environ: create_model(model_name, pretrained=True, in_chans=in_chans) @pytest.mark.timeout(120) - @pytest.mark.parametrize('model_name', list_models(pretrained=True)) + @pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=['vit_*'])) @pytest.mark.parametrize('batch_size', [1]) def test_model_features_pretrained(model_name, batch_size): """Create that pretrained weights load when features_only==True.""" diff --git a/timm/models/helpers.py b/timm/models/helpers.py index b90ce1db..0bc6d2f7 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -14,7 +14,7 @@ import torch.nn as nn import torch.utils.model_zoo as model_zoo from .features import FeatureListNet, FeatureDictNet, FeatureHookNet -from .layers import Conv2dSame +from .layers import Conv2dSame, Linear _logger = logging.getLogger(__name__) @@ -234,7 +234,7 @@ def adapt_model_from_string(parent_module, model_string): if isinstance(old_module, nn.Linear): # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? 
num_features = state_dict[n + '.weight'][1] - new_fc = nn.Linear( + new_fc = Linear( in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) set_layer(new_module, n, new_fc) if hasattr(new_module, 'num_features'): diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index aee1cccc..9ae7105f 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from .helpers import build_model_with_cfg from .registry import register_model -from .layers import trunc_normal_, create_classifier +from .layers import trunc_normal_, create_classifier, Linear def _cfg(url='', **kwargs): @@ -250,7 +250,7 @@ class InceptionAux(nn.Module): self.conv0 = conv_block(in_channels, 128, kernel_size=1) self.conv1 = conv_block(128, 768, kernel_size=5) self.conv1.stddev = 0.01 - self.fc = nn.Linear(768, num_classes) + self.fc = Linear(768, num_classes) self.fc.stddev = 0.001 def forward(self, x): diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index a252b8c1..dac1beb8 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -18,6 +18,7 @@ from .eca import EcaModule, CecaModule from .evo_norm import EvoNormBatch2d, EvoNormSample2d from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple from .inplace_abn import InplaceAbn +from .linear import Linear from .mixed_conv2d import MixedConv2d from .norm_act import BatchNormAct2d from .padding import get_padding diff --git a/timm/models/layers/linear.py b/timm/models/layers/linear.py index 4607f284..38fe3380 100644 --- a/timm/models/layers/linear.py +++ b/timm/models/layers/linear.py @@ -13,6 +13,7 @@ class Linear(nn.Linear): """ def forward(self, input: torch.Tensor) -> torch.Tensor: if torch.jit.is_scripting(): - return F.linear(input, self.weight.to(dtype=input.dtype), self.bias.to(dtype=input.dtype)) + bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None + return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) else: - return F.linear(input, self.weight, self.bias) \ No newline at end of file + return F.linear(input, self.weight, self.bias) diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index e20b6d34..ea930308 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -18,7 +18,7 @@ from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_la from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights from .features import FeatureInfo, FeatureHooks from .helpers import build_model_with_cfg -from .layers import SelectAdaptivePool2d, create_conv2d, get_act_fn, hard_sigmoid +from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid from .registry import register_model __all__ = ['MobileNetV3'] @@ -105,7 +105,7 @@ class MobileNetV3(nn.Module): num_pooled_chs = head_chs * self.global_pool.feat_mult() self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias) self.act2 = act_layer(inplace=True) - self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() efficientnet_init_weights(self) @@ -123,7 +123,7 @@ class MobileNetV3(nn.Module): 
self.num_classes = num_classes # cannot meaningfully change pooling of efficient head after creation self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.conv_stem(x) diff --git a/train.py b/train.py index 23a8e9b0..7a93a1b6 100755 --- a/train.py +++ b/train.py @@ -327,6 +327,7 @@ def main(): bn_tf=args.bn_tf, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, + scriptable=args.torchscript, checkpoint_path=args.initial_checkpoint) if args.local_rank == 0: From 4ca52d73d8fb1ebfd5d272576295f03f7e34fc15 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 3 Dec 2020 10:05:09 -0800 Subject: [PATCH 08/10] Add separate set and update method to ModelEmaV2 --- timm/utils/model_ema.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py index a767eaa5..073d5c5e 100644 --- a/timm/utils/model_ema.py +++ b/timm/utils/model_ema.py @@ -112,9 +112,15 @@ class ModelEmaV2(nn.Module): if self.device is not None: self.module.to(device=device) - def update(self, model): + def _update(self, model, update_fn): with torch.no_grad(): for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): if self.device is not None: model_v = model_v.to(device=self.device) - ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v) + ema_v.copy_(update_fn(ema_v, model_v)) + + def update(self, model): + self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) + + def set(self, model): + self._update(model, update_fn=lambda e, m: m) From 867a0e5a049516b9597e05751799a2502fae0ec8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 3 Dec 2020 10:24:35 -0800 Subject: [PATCH 09/10] Add default_cfg back to models wrapped in feature extraction module as per discussion in #294. 
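As an aside on the EMA patches above (01, 02 and 08): they leave ModelEmaV2 with a small surface, namely a deepcopied .module, update() for the running average, and set() for a direct copy. Below is a minimal usage sketch assuming a timm install carrying these patches; the toy model, optimizer and data are illustrative stand-ins, not code from train.py.

    import torch
    import torch.nn as nn
    from timm.utils import ModelEmaV2  # exported from timm.utils as of patch 02

    # Toy stand-ins; train.py creates the EMA after cuda()/AMP setup but before the DDP wrapper.
    model = nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model_ema = ModelEmaV2(model, decay=0.9999, device=None)  # device='cpu' would keep the EMA copy off-GPU

    for step in range(10):
        x, y = torch.randn(8, 10), torch.randn(8, 2)
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        model_ema.update(model)  # ema <- decay * ema + (1 - decay) * model, tensor by tensor over the state_dict

    # set() (added in patch 08) copies the raw weights directly, e.g. to re-sync the EMA with the model.
    model_ema.set(model)

    # Evaluate with the smoothed weights via .module, as train.py's validate(model_ema.module, ...) call does.
    with torch.no_grad():
        _ = model_ema.module(torch.randn(1, 10))

Because update() pairs tensors purely by state_dict iteration order rather than by name, the same loop works when the model has been scripted with torch.jit.script(), which is what motivated the V2 rewrite.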
--- timm/models/efficientnet.py | 6 ++++-- timm/models/helpers.py | 10 ++++++++++ timm/models/hrnet.py | 6 ++++-- timm/models/mobilenetv3.py | 6 ++++-- 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index a61a6f47..7eeda3ca 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -34,7 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights from .features import FeatureInfo, FeatureHooks -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, default_cfg_for_features from .layers import create_conv2d, create_classifier from .registry import register_model @@ -462,9 +462,11 @@ def _create_effnet(model_kwargs, variant, pretrained=False): else: load_strict = True model_cls = EfficientNet - return build_model_with_cfg( + model = build_model_with_cfg( model_cls, variant, pretrained, default_cfg=default_cfgs[variant], pretrained_strict=load_strict, **model_kwargs) + model.default_cfg = default_cfg_for_features(model.default_cfg) + return model def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 0bc6d2f7..77b98dc6 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -251,6 +251,15 @@ def adapt_model_from_file(parent_module, model_variant): return adapt_model_from_string(parent_module, f.read().strip()) +def default_cfg_for_features(default_cfg): + default_cfg = deepcopy(default_cfg) + # remove default pretrained cfg fields that don't have much relevance for feature backbone + to_remove = ('num_classes', 'crop_pct', 'classifier') # add default final pool size? 
+ for tr in to_remove: + default_cfg.pop(tr, None) + return default_cfg + + def build_model_with_cfg( model_cls: Callable, variant: str, @@ -296,5 +305,6 @@ def build_model_with_cfg( else: assert False, f'Unknown feature class {feature_cls}' model = feature_cls(model, **feature_cfg) + model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg return model diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 2e8757b5..d246812e 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .features import FeatureInfo -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, default_cfg_for_features from .layers import create_classifier from .registry import register_model from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE @@ -779,9 +779,11 @@ def _create_hrnet(variant, pretrained, **model_kwargs): model_kwargs['num_classes'] = 0 strict = False - return build_model_with_cfg( + model = build_model_with_cfg( model_cls, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs) + model.default_cfg = default_cfg_for_features(model.default_cfg) + return model @register_model diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index ea930308..afded75f 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -17,7 +17,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights from .features import FeatureInfo, FeatureHooks -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, default_cfg_for_features from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid from .registry import register_model @@ -211,9 +211,11 @@ def _create_mnv3(model_kwargs, variant, pretrained=False): else: load_strict = True model_cls = MobileNetV3 - return build_model_with_cfg( + model = build_model_with_cfg( model_cls, variant, pretrained, default_cfg=default_cfgs[variant], pretrained_strict=load_strict, **model_kwargs) + model.default_cfg = default_cfg_for_features(model.default_cfg) + return model def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs): From cd72e66effd32e460a7f129d3c426b7151c89723 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 3 Dec 2020 12:33:01 -0800 Subject: [PATCH 10/10] Bug in last mod for features_only default_cfg --- timm/models/efficientnet.py | 12 ++++++------ timm/models/hrnet.py | 10 +++++----- timm/models/mobilenetv3.py | 12 ++++++------ 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 7eeda3ca..4a89590b 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -453,19 +453,19 @@ class EfficientNetFeatures(nn.Module): def _create_effnet(model_kwargs, variant, pretrained=False): + features_only = False + model_cls = EfficientNet if model_kwargs.pop('features_only', False): - load_strict = False + features_only = True model_kwargs.pop('num_classes', 0) model_kwargs.pop('num_features', 0) model_kwargs.pop('head_conv', None) model_cls = 
EfficientNetFeatures - else: - load_strict = True - model_cls = EfficientNet model = build_model_with_cfg( model_cls, variant, pretrained, default_cfg=default_cfgs[variant], - pretrained_strict=load_strict, **model_kwargs) - model.default_cfg = default_cfg_for_features(model.default_cfg) + pretrained_strict=not features_only, **model_kwargs) + if features_only: + model.default_cfg = default_cfg_for_features(model.default_cfg) return model diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index d246812e..1c0bc9f0 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -773,16 +773,16 @@ class HighResolutionNetFeatures(HighResolutionNet): def _create_hrnet(variant, pretrained, **model_kwargs): model_cls = HighResolutionNet - strict = True + features_only = False if model_kwargs.pop('features_only', False): model_cls = HighResolutionNetFeatures model_kwargs['num_classes'] = 0 - strict = False - + features_only = True model = build_model_with_cfg( model_cls, variant, pretrained, default_cfg=default_cfgs[variant], - model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs) - model.default_cfg = default_cfg_for_features(model.default_cfg) + model_cfg=cfg_cls[variant], pretrained_strict=not features_only, **model_kwargs) + if features_only: + model.default_cfg = default_cfg_for_features(model.default_cfg) return model diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index afded75f..8a48ce72 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -201,20 +201,20 @@ class MobileNetV3Features(nn.Module): def _create_mnv3(model_kwargs, variant, pretrained=False): + features_only = False + model_cls = MobileNetV3 if model_kwargs.pop('features_only', False): - load_strict = False + features_only = True model_kwargs.pop('num_classes', 0) model_kwargs.pop('num_features', 0) model_kwargs.pop('head_conv', None) model_kwargs.pop('head_bias', None) model_cls = MobileNetV3Features - else: - load_strict = True - model_cls = MobileNetV3 model = build_model_with_cfg( model_cls, variant, pretrained, default_cfg=default_cfgs[variant], - pretrained_strict=load_strict, **model_kwargs) - model.default_cfg = default_cfg_for_features(model.default_cfg) + pretrained_strict=not features_only, **model_kwargs) + if features_only: + model.default_cfg = default_cfg_for_features(model.default_cfg) return model
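As a rough check of the features_only / default_cfg behaviour established by the last two patches, the sketch below contrasts a classification model with a feature backbone. It assumes a timm install at 0.3.2 with these patches applied and uses mobilenetv3_large_100 (one of the model families touched above) as the example.

    import timm  # assumes timm 0.3.2, i.e. the tree these patches land in

    # Regular classification model: the pretrained cfg keeps num_classes, crop_pct and classifier.
    clf = timm.create_model('mobilenetv3_large_100', pretrained=False)
    print(clf.default_cfg['classifier'])                       # 'classifier'

    # Feature backbone: patch 09 re-attaches a default_cfg to the feature-extraction wrapper, and
    # patch 10 fixes it so default_cfg_for_features() (which drops 'num_classes', 'crop_pct' and
    # 'classifier') is only applied in the features_only branch, not to every model.
    feat = timm.create_model('mobilenetv3_large_100', pretrained=False, features_only=True)
    print('classifier' in feat.default_cfg)                    # expected: False
    print(feat.default_cfg['url'] == clf.default_cfg['url'])   # expected: True, weight URL preserved

Keeping the trimmed cfg on the wrapper is useful because fields like the weight URL and input mean/std still apply to a feature backbone, while the classifier-specific fields do not.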