From d98967ed5d4b295d0d1871a74428956673f5f7e5 Mon Sep 17 00:00:00 2001 From: datamining99 Date: Sat, 22 Aug 2020 09:44:23 +0900 Subject: [PATCH 1/8] add support for native torch AMP in torch 1.6 --- train.py | 39 ++++++++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/train.py b/train.py index c28bd266..cdb5a95b 100755 --- a/train.py +++ b/train.py @@ -25,8 +25,11 @@ try: from apex.parallel import convert_syncbn_model has_apex = True except ImportError: + from torch.cuda import amp from torch.nn.parallel import DistributedDataParallel as DDP has_apex = False + + from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset from timm.models import create_model, resume_checkpoint, convert_splitbn_model @@ -327,6 +330,10 @@ def main(): if has_apex and args.amp: model, optimizer = amp.initialize(model, optimizer, opt_level='O1') use_amp = True + elif args.amp: + _logger.info('Using torch AMP. Install NVIDIA Apex for Apex AMP.') + scaler = torch.cuda.amp.GradScaler() + use_amp = True if args.local_rank == 0: _logger.info('NVIDIA APEX {}. AMP {}.'.format( 'installed' if has_apex else 'not installed', 'on' if use_amp else 'off')) @@ -506,7 +513,8 @@ def main(): train_metrics = train_epoch( epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, - use_amp=use_amp, model_ema=model_ema, mixup_fn=mixup_fn) + use_amp=use_amp, has_apex=has_apex, scaler = scaler, + model_ema=model_ema, mixup_fn=mixup_fn) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: @@ -546,7 +554,8 @@ def main(): def train_epoch( epoch, model, loader, optimizer, loss_fn, args, - lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None, mixup_fn=None): + lr_scheduler=None, saver=None, output_dir='', use_amp=False, + has_apex=False, scaler = None, model_ema=None, mixup_fn=None): if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: if args.prefetcher and loader.mixup_enabled: @@ -570,20 +579,32 @@ def train_epoch( input, target = input.cuda(), target.cuda() if mixup_fn is not None: input, target = mixup_fn(input, target) - - output = model(input) - - loss = loss_fn(output, target) + if not has_apex and use_amp: + with torch.cuda.amp.autocast(): + output = model(input) + loss = loss_fn(output, target) + else: + output = model(input) + loss = loss_fn(output, target) + if not args.distributed: losses_m.update(loss.item(), input.size(0)) optimizer.zero_grad() if use_amp: - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() + if has_apex: + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + else: + scaler.scale(loss).backward() + else: loss.backward() - optimizer.step() + if not has_apex and use_amp: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() torch.cuda.synchronize() if model_ema is not None: From 5f563ca4df0ed101ee0f5a7966e7a28238f4d79c Mon Sep 17 00:00:00 2001 From: datamining99 Date: Sat, 22 Aug 2020 11:31:50 +0900 Subject: [PATCH 2/8] fix save_checkpoint bug with native amp --- train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index cdb5a95b..a99d5d36 100755 --- a/train.py +++ b/train.py @@ -544,7 +544,7 @@ def main(): save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( model, optimizer, args, - epoch=epoch, model_ema=model_ema, metric=save_metric, 
use_amp=use_amp) + epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=has_apex&use_amp) except KeyboardInterrupt: pass @@ -647,8 +647,9 @@ def train_epoch( if saver is not None and args.recovery_interval and ( last_batch or (batch_idx + 1) % args.recovery_interval == 0): + saver.save_recovery( - model, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx) + model, optimizer, args, epoch, model_ema=model_ema, use_amp=has_apex&use_amp, batch_idx=batch_idx) if lr_scheduler is not None: lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) From c2cd1a332eca0f109b7fb357aa98eb7a7bfabc11 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 31 Aug 2020 17:58:16 -0700 Subject: [PATCH 3/8] Improve torch amp support and add channels_last support for train/validate scripts --- timm/utils.py | 16 +++-- train.py | 177 +++++++++++++++++++++++++++++++------------------- validate.py | 59 +++++++++++++---- 3 files changed, 165 insertions(+), 87 deletions(-) diff --git a/timm/utils.py b/timm/utils.py index 65255b53..6c10d283 100644 --- a/timm/utils.py +++ b/timm/utils.py @@ -49,7 +49,8 @@ class CheckpointSaver: checkpoint_dir='', recovery_dir='', decreasing=False, - max_history=10): + max_history=10, + save_amp=False): # state self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness @@ -67,13 +68,14 @@ class CheckpointSaver: self.decreasing = decreasing # a lower metric is better if True self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs self.max_history = max_history + self.save_apex_amp = save_amp # save APEX amp state assert self.max_history >= 1 - def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False): + def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None): assert epoch >= 0 tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension) last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension) - self._save(tmp_save_path, model, optimizer, args, epoch, model_ema, metric, use_amp) + self._save(tmp_save_path, model, optimizer, args, epoch, model_ema, metric) if os.path.exists(last_save_path): os.unlink(last_save_path) # required for Windows support. 
os.rename(tmp_save_path, last_save_path) @@ -105,7 +107,7 @@ class CheckpointSaver: return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch) - def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False): + def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None): save_state = { 'epoch': epoch, 'arch': args.model, @@ -114,7 +116,7 @@ class CheckpointSaver: 'args': args, 'version': 2, # version < 2 increments epoch before save } - if use_amp and 'state_dict' in amp.__dict__: + if self.save_apex_amp and 'state_dict' in amp.__dict__: save_state['amp'] = amp.state_dict() if model_ema is not None: save_state['state_dict_ema'] = get_state_dict(model_ema) @@ -136,11 +138,11 @@ class CheckpointSaver: _logger.error("Exception '{}' while deleting checkpoint".format(e)) self.checkpoint_files = self.checkpoint_files[:delete_index] - def save_recovery(self, model, optimizer, args, epoch, model_ema=None, use_amp=False, batch_idx=0): + def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0): assert epoch >= 0 filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension save_path = os.path.join(self.recovery_dir, filename) - self._save(save_path, model, optimizer, args, epoch, model_ema, use_amp=use_amp) + self._save(save_path, model, optimizer, args, epoch, model_ema) if os.path.exists(self.last_recovery_file): try: _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file)) diff --git a/train.py b/train.py index a99d5d36..e017605d 100755 --- a/train.py +++ b/train.py @@ -18,18 +18,12 @@ import argparse import time import yaml from datetime import datetime +from contextlib import suppress -try: - from apex import amp - from apex.parallel import DistributedDataParallel as DDP - from apex.parallel import convert_syncbn_model - has_apex = True -except ImportError: - from torch.cuda import amp - from torch.nn.parallel import DistributedDataParallel as DDP - has_apex = False - - +import torch +import torch.nn as nn +import torchvision.utils +from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset from timm.models import create_model, resume_checkpoint, convert_splitbn_model @@ -38,14 +32,24 @@ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCro from timm.optim import create_optimizer from timm.scheduler import create_scheduler -import torch -import torch.nn as nn -import torchvision.utils +try: + from apex import amp + from apex.parallel import DistributedDataParallel as ApexDDP + from apex.parallel import convert_syncbn_model + has_apex = True +except ImportError: + has_apex = False + +has_native_amp = False +try: + if getattr(torch.cuda.amp, 'autocast') is not None: + has_native_amp = True +except AttributeError: + pass torch.backends.cudnn.benchmark = True _logger = logging.getLogger('train') - # The first arg parser parses out only the --config argument, this argument is used to # load a yaml file containing key-values that override the defaults for the main parser below config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) @@ -221,7 +225,13 @@ parser.add_argument('--num-gpu', type=int, default=1, parser.add_argument('--save-images', action='store_true', default=False, help='save images of input bathes every log interval for debugging') 
parser.add_argument('--amp', action='store_true', default=False, - help='use NVIDIA amp for mixed precision training') + help='use NVIDIA Apex AMP or Native AMP for mixed precision training') +parser.add_argument('--apex-amp', action='store_true', default=False, + help='Use NVIDIA Apex AMP mixed precision') +parser.add_argument('--native-amp', action='store_true', default=False, + help='Use Native Torch AMP mixed precision') +parser.add_argument('--channels-last', action='store_true', default=False, + help='Use channels_last memory layout') parser.add_argument('--pin-mem', action='store_true', default=False, help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') parser.add_argument('--no-prefetcher', action='store_true', default=False, @@ -254,6 +264,23 @@ def _parse_args(): return args, args_text +class ApexScaler: + def __call__(self, loss, optimizer): + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + optimizer.step() + + +class NativeScaler: + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__(self, loss, optimizer): + self._scaler.scale(loss).backward() + self._scaler.step(optimizer) + self._scaler.update() + + def main(): setup_default_logging() args, args_text = _parse_args() @@ -263,7 +290,8 @@ def main(): if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 if args.distributed and args.num_gpu > 1: - _logger.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.') + _logger.warning( + 'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.') args.num_gpu = 1 args.device = 'cuda:0' @@ -315,28 +343,50 @@ def main(): assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) + use_amp = None + if args.amp: + # for backwards compat, `--amp` arg tries apex before native amp + if has_apex: + args.apex_amp = True + elif has_native_amp: + args.native_amp = True + if args.apex_amp and has_apex: + use_amp = 'apex' + elif args.native_amp and has_native_amp: + use_amp = 'native' + elif args.apex_amp or args.native_amp: + _logger.warning("Neither APEX or native Torch AMP is available, using float32. " + "Install NVIDA apex or upgrade to PyTorch 1.6") + if args.num_gpu > 1: - if args.amp: + if use_amp == 'apex': _logger.warning( - 'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.') - args.amp = False + 'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.') + use_amp = None model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() + assert not args.channels_last, "Channels last not supported with DP, use DDP." else: model.cuda() + if args.channels_last: + model = model.to(memory_format=torch.channels_last) optimizer = create_optimizer(args, model) - use_amp = False - if has_apex and args.amp: + amp_autocast = suppress # do nothing + loss_scaler = None + if use_amp == 'apex': model, optimizer = amp.initialize(model, optimizer, opt_level='O1') - use_amp = True - elif args.amp: - _logger.info('Using torch AMP. Install NVIDIA Apex for Apex AMP.') - scaler = torch.cuda.amp.GradScaler() - use_amp = True - if args.local_rank == 0: - _logger.info('NVIDIA APEX {}. AMP {}.'.format( - 'installed' if has_apex else 'not installed', 'on' if use_amp else 'off')) + loss_scaler = ApexScaler() + if args.local_rank == 0: + _logger.info('Using NVIDIA APEX AMP. 
Training in mixed precision.') + elif use_amp == 'native': + amp_autocast = torch.cuda.amp.autocast + loss_scaler = NativeScaler() + if args.local_rank == 0: + _logger.info('Using native Torch AMP. Training in mixed precision.') + else: + if args.local_rank == 0: + _logger.info('AMP not enabled. Training in float32.') # optionally resume from a checkpoint resume_state = {} @@ -346,7 +396,7 @@ def main(): if resume_state and not args.no_resume_opt: if 'optimizer' in resume_state: if args.local_rank == 0: - _logger.info('Restoring Optimizer state from checkpoint') + _logger.info('Restoring optimizer state from checkpoint') optimizer.load_state_dict(resume_state['optimizer']) if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__: if args.local_rank == 0: @@ -367,7 +417,8 @@ def main(): if args.sync_bn: assert not args.split_bn try: - if has_apex: + if has_apex and use_amp != 'native': + # Apex SyncBN preferred unless native amp is activated model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -377,12 +428,15 @@ def main(): 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') except Exception as e: _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1') - if has_apex: - model = DDP(model, delay_allreduce=True) + if has_apex and use_amp != 'native': + # Apex DDP preferred unless native amp is activated + if args.local_rank == 0: + _logger.info("Using NVIDIA APEX DistributedDataParallel.") + model = ApexDDP(model, delay_allreduce=True) else: if args.local_rank == 0: - _logger.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.") - model = DDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 + _logger.info("Using native Torch DistributedDataParallel.") + model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP lr_scheduler, num_epochs = create_scheduler(args, optimizer) @@ -501,7 +555,7 @@ def main(): ]) output_dir = get_outdir(output_base, 'train', exp_name) decreasing = True if eval_metric == 'loss' else False - saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) + saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing, save_amp=use_amp == 'apex') with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) @@ -513,22 +567,20 @@ def main(): train_metrics = train_epoch( epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, - use_amp=use_amp, has_apex=has_apex, scaler = scaler, - model_ema=model_ema, mixup_fn=mixup_fn) + amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: _logger.info("Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') - eval_metrics = validate(model, loader_eval, validate_loss_fn, args) + eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') - ema_eval_metrics = validate( - model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)') + model_ema.ema, 
loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if lr_scheduler is not None: @@ -543,8 +595,7 @@ def main(): # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( - model, optimizer, args, - epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=has_apex&use_amp) + model, optimizer, args, epoch=epoch, model_ema=model_ema, metric=save_metric) except KeyboardInterrupt: pass @@ -554,8 +605,8 @@ def main(): def train_epoch( epoch, model, loader, optimizer, loss_fn, args, - lr_scheduler=None, saver=None, output_dir='', use_amp=False, - has_apex=False, scaler = None, model_ema=None, mixup_fn=None): + lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress, + loss_scaler=None, model_ema=None, mixup_fn=None): if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: if args.prefetcher and loader.mixup_enabled: @@ -579,31 +630,21 @@ def train_epoch( input, target = input.cuda(), target.cuda() if mixup_fn is not None: input, target = mixup_fn(input, target) - if not has_apex and use_amp: - with torch.cuda.amp.autocast(): - output = model(input) - loss = loss_fn(output, target) - else: + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) + + with amp_autocast(): output = model(input) loss = loss_fn(output, target) - + if not args.distributed: losses_m.update(loss.item(), input.size(0)) optimizer.zero_grad() - if use_amp: - if has_apex: - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - else: - scaler.scale(loss).backward() - + if loss_scaler is not None: + loss_scaler(loss, optimizer) else: loss.backward() - if not has_apex and use_amp: - scaler.step(optimizer) - scaler.update() - else: optimizer.step() torch.cuda.synchronize() @@ -648,8 +689,7 @@ def train_epoch( if saver is not None and args.recovery_interval and ( last_batch or (batch_idx + 1) % args.recovery_interval == 0): - saver.save_recovery( - model, optimizer, args, epoch, model_ema=model_ema, use_amp=has_apex&use_amp, batch_idx=batch_idx) + saver.save_recovery(model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx) if lr_scheduler is not None: lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) @@ -663,7 +703,7 @@ def train_epoch( return OrderedDict([('loss', losses_m.avg)]) -def validate(model, loader, loss_fn, args, log_suffix=''): +def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''): batch_time_m = AverageMeter() losses_m = AverageMeter() top1_m = AverageMeter() @@ -679,8 +719,11 @@ def validate(model, loader, loss_fn, args, log_suffix=''): if not args.prefetcher: input = input.cuda() target = target.cuda() + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) - output = model(input) + with amp_autocast(): + output = model(input) if isinstance(output, (tuple, list)): output = output[0] diff --git a/validate.py b/validate.py index b0d411c5..f20c35ba 100755 --- a/validate.py +++ b/validate.py @@ -17,16 +17,25 @@ import torch import torch.nn as nn import torch.nn.parallel from collections import OrderedDict +from contextlib import suppress +from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models +from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet +from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging 
+ +has_apex = False try: from apex import amp has_apex = True except ImportError: - has_apex = False + pass -from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models -from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet -from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging +has_native_amp = False +try: + if getattr(torch.cuda.amp, 'autocast') is not None: + has_native_amp = True +except AttributeError: + pass torch.backends.cudnn.benchmark = True _logger = logging.getLogger('validate') @@ -69,8 +78,14 @@ parser.add_argument('--no-prefetcher', action='store_true', default=False, help='disable fast prefetcher') parser.add_argument('--pin-mem', action='store_true', default=False, help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') +parser.add_argument('--channels-last', action='store_true', default=False, + help='Use channels_last memory layout') parser.add_argument('--amp', action='store_true', default=False, - help='Use AMP mixed precision') + help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.') +parser.add_argument('--apex-amp', action='store_true', default=False, + help='Use NVIDIA Apex AMP mixed precision') +parser.add_argument('--native-amp', action='store_true', default=False, + help='Use Native Torch AMP mixed precision') parser.add_argument('--tf-preprocessing', action='store_true', default=False, help='Use Tensorflow preprocessing pipeline (require CPU TF installed') parser.add_argument('--use-ema', dest='use_ema', action='store_true', @@ -104,6 +119,18 @@ def validate(args): # might as well try to validate something args.pretrained = args.pretrained or not args.checkpoint args.prefetcher = not args.no_prefetcher + amp_autocast = suppress # do nothing + if args.amp: + if has_apex: + args.apex_amp = True + elif has_native_amp: + args.native_amp = True + else: + _logger.warning("Neither APEX or Native Torch AMP is available, using FP32.") + assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set." 
+ if args.native_amp: + amp_autocast = torch.cuda.amp.autocast + if args.legacy_jit: set_jit_legacy() @@ -128,10 +155,12 @@ def validate(args): torch.jit.optimized_execution(True) model = torch.jit.script(model) - if args.amp: - model = amp.initialize(model.cuda(), opt_level='O1') - else: - model = model.cuda() + model = model.cuda() + if args.apex_amp: + model = amp.initialize(model, opt_level='O1') + + if args.channels_last: + model = model.to(memory_format=torch.channels_last) if args.num_gpu > 1: model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))) @@ -178,17 +207,21 @@ def validate(args): with torch.no_grad(): # warmup, reduce variability of first batch time, especially for comparing torchscript vs non input = torch.randn((args.batch_size,) + data_config['input_size']).cuda() + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) model(input) end = time.time() for batch_idx, (input, target) in enumerate(loader): if args.no_prefetcher: target = target.cuda() input = input.cuda() - if args.fp16: - input = input.half() + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) # compute output - output = model(input) + with amp_autocast(): + output = model(input) + if valid_labels is not None: output = output[:, valid_labels] loss = criterion(output, target) @@ -197,7 +230,7 @@ def validate(args): real_labels.add_result(output) # measure accuracy and record loss - acc1, acc5 = accuracy(output.data, target, topk=(1, 5)) + acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5)) losses.update(loss.item(), input.size(0)) top1.update(acc1.item(), input.size(0)) top5.update(acc5.item(), input.size(0)) From 110a7c4982db758a33aabd4c7bbea09c08202bc7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 1 Sep 2020 16:01:36 -0700 Subject: [PATCH 4/8] AdaptiveAvgPool2d -> mean((2,3)) for all SE/attn layers to avoid NaN with AMP + channels_last layout. 
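
For a 1x1 output the two forms are numerically equivalent: nn.AdaptiveAvgPool2d(1) and a plain mean over the spatial dims produce the same tensor, but the explicit reduction avoids the NaNs seen when autocast and channels_last are combined (upstream issue linked below). A minimal standalone sketch of the equivalence, illustrative only and not taken from this diff:

    import torch
    import torch.nn as nn

    x = torch.randn(2, 64, 7, 7)

    # global average pool via the adaptive pooling layer
    pooled_layer = nn.AdaptiveAvgPool2d(1)(x)    # -> (2, 64, 1, 1)
    # the replacement used throughout this patch: mean over H and W
    pooled_mean = x.mean((2, 3), keepdim=True)   # -> (2, 64, 1, 1)

    assert torch.allclose(pooled_layer, pooled_mean, atol=1e-6)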
See https://github.com/pytorch/pytorch/issues/43992 --- timm/models/efficientnet_blocks.py | 8 +++---- timm/models/layers/cbam.py | 9 ++++---- timm/models/layers/eca.py | 29 ++++++-------------------- timm/models/layers/se.py | 26 ++++++++++------------- timm/models/layers/selective_kernel.py | 4 +--- timm/models/rexnet.py | 9 +++----- timm/models/senet.py | 9 +++----- timm/models/tresnet.py | 5 ++--- 8 files changed, 33 insertions(+), 66 deletions(-) diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index d5fdce79..d7421ff4 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -106,20 +106,18 @@ class SqueezeExcite(nn.Module): def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_): super(SqueezeExcite, self).__init__() - self.gate_fn = gate_fn reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor) - self.avg_pool = nn.AdaptiveAvgPool2d(1) self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) self.act1 = act_layer(inplace=True) self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) + self.gate_fn = gate_fn def forward(self, x): - x_se = self.avg_pool(x) + x_se = x.mean((2, 3), keepdim=True) x_se = self.conv_reduce(x_se) x_se = self.act1(x_se) x_se = self.conv_expand(x_se) - x = x * self.gate_fn(x_se) - return x + return x * self.gate_fn(x_se) class ConvBnAct(nn.Module): diff --git a/timm/models/layers/cbam.py b/timm/models/layers/cbam.py index 600d51fa..44e2fe6d 100644 --- a/timm/models/layers/cbam.py +++ b/timm/models/layers/cbam.py @@ -10,6 +10,7 @@ Hacked together by / Copyright 2020 Ross Wightman import torch from torch import nn as nn +import torch.nn.functional as F from .conv_bn_act import ConvBnAct @@ -18,15 +19,13 @@ class ChannelAttn(nn.Module): """ def __init__(self, channels, reduction=16, act_layer=nn.ReLU): super(ChannelAttn, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.max_pool = nn.AdaptiveMaxPool2d(1) self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False) self.act = act_layer(inplace=True) self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False) def forward(self, x): - x_avg = self.avg_pool(x) - x_max = self.max_pool(x) + x_avg = x.mean((2, 3), keepdim=True) + x_max = F.adaptive_max_pool2d(x, 1) x_avg = self.fc2(self.act(self.fc1(x_avg))) x_max = self.fc2(self.act(self.fc1(x_max))) x_attn = x_avg + x_max @@ -40,7 +39,7 @@ class LightChannelAttn(ChannelAttn): super(LightChannelAttn, self).__init__(channels, reduction) def forward(self, x): - x_pool = 0.5 * self.avg_pool(x) + 0.5 * self.max_pool(x) + x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * F.adaptive_max_pool2d(x, 1) x_attn = self.fc2(self.act(self.fc1(x_pool))) return x * x_attn.sigmoid() diff --git a/timm/models/layers/eca.py b/timm/models/layers/eca.py index 497c5b04..3a7f8b82 100644 --- a/timm/models/layers/eca.py +++ b/timm/models/layers/eca.py @@ -52,22 +52,15 @@ class EcaModule(nn.Module): def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1): super(EcaModule, self).__init__() assert kernel_size % 2 == 1 - if channels is not None: t = int(abs(math.log(channels, 2) + beta) / gamma) kernel_size = max(t if t % 2 else t + 1, 3) - self.avg_pool = nn.AdaptiveAvgPool2d(1) self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False) def forward(self, x): - # Feature descriptor on the global spatial information - y = self.avg_pool(x) - # Reshape 
for convolution - y = y.view(x.shape[0], 1, -1) - # Two different branches of ECA module + y = x.mean((2, 3)).view(x.shape[0], 1, -1) # view for 1d conv y = self.conv(y) - # Multi-scale information fusion y = y.view(x.shape[0], -1, 1, 1).sigmoid() return x * y.expand_as(x) @@ -95,30 +88,20 @@ class CecaModule(nn.Module): def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1): super(CecaModule, self).__init__() assert kernel_size % 2 == 1 - if channels is not None: t = int(abs(math.log(channels, 2) + beta) / gamma) kernel_size = max(t if t % 2 else t + 1, 3) - self.avg_pool = nn.AdaptiveAvgPool2d(1) - #pytorch circular padding mode is buggy as of pytorch 1.4 - #see https://github.com/pytorch/pytorch/pull/17240 - - #implement manual circular padding + # PyTorch circular padding mode is buggy as of pytorch 1.4 + # see https://github.com/pytorch/pytorch/pull/17240 + # implement manual circular padding self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False) self.padding = (kernel_size - 1) // 2 def forward(self, x): - # Feature descriptor on the global spatial information - y = self.avg_pool(x) - + y = x.mean((2, 3)).view(x.shape[0], 1, -1) # Manually implement circular padding, F.pad does not seemed to be bugged - y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular') - - # Two different branches of ECA module + y = F.pad(y, (self.padding, self.padding), mode='circular') y = self.conv(y) - - # Multi-scale information fusion y = y.view(x.shape[0], -1, 1, 1).sigmoid() - return x * y.expand_as(x) diff --git a/timm/models/layers/se.py b/timm/models/layers/se.py index 578ebf08..a896fb71 100644 --- a/timm/models/layers/se.py +++ b/timm/models/layers/se.py @@ -1,40 +1,36 @@ from torch import nn as nn -from .create_act import get_act_fn +from .create_act import create_act_layer class SEModule(nn.Module): def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None, - gate_fn='sigmoid'): + gate_layer='sigmoid'): super(SEModule, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) reduction_channels = reduction_channels or max(channels // reduction, min_channels) - self.fc1 = nn.Conv2d( - channels, reduction_channels, kernel_size=1, padding=0, bias=True) + self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True) self.act = act_layer(inplace=True) - self.fc2 = nn.Conv2d( - reduction_channels, channels, kernel_size=1, padding=0, bias=True) - self.gate_fn = get_act_fn(gate_fn) + self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True) + self.gate = create_act_layer(gate_layer) def forward(self, x): - x_se = self.avg_pool(x) + x_se = x.mean((2, 3), keepdim=True) x_se = self.fc1(x_se) x_se = self.act(x_se) x_se = self.fc2(x_se) - return x * self.gate_fn(x_se) + return x * self.gate(x_se) class EffectiveSEModule(nn.Module): """ 'Effective Squeeze-Excitation From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 """ - def __init__(self, channels, gate_fn='hard_sigmoid'): + def __init__(self, channels, gate_layer='hard_sigmoid'): super(EffectiveSEModule, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) - self.gate_fn = get_act_fn(gate_fn) + self.gate = create_act_layer(gate_layer, inplace=True) def forward(self, x): - x_se = self.avg_pool(x) + x_se = x.mean((2, 3), keepdim=True) x_se = self.fc(x_se) - return x * self.gate_fn(x_se, inplace=True) + 
return x * self.gate(x_se) diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py index 2efaa487..10bfd0e0 100644 --- a/timm/models/layers/selective_kernel.py +++ b/timm/models/layers/selective_kernel.py @@ -27,7 +27,6 @@ class SelectiveKernelAttn(nn.Module): """ super(SelectiveKernelAttn, self).__init__() self.num_paths = num_paths - self.pool = nn.AdaptiveAvgPool2d(1) self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) self.bn = norm_layer(attn_channels) self.act = act_layer(inplace=True) @@ -35,8 +34,7 @@ class SelectiveKernelAttn(nn.Module): def forward(self, x): assert x.shape[1] == self.num_paths - x = torch.sum(x, dim=1) - x = self.pool(x) + x = x.sum(1).mean((2, 3), keepdim=True) x = self.fc_reduce(x) x = self.bn(x) x = self.act(x) diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index 03e6ee02..b7522a05 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -59,18 +59,15 @@ class SEWithNorm(nn.Module): def __init__(self, channels, reduction=16, act_layer=nn.ReLU, divisor=1, reduction_channels=None, gate_layer='sigmoid'): super(SEWithNorm, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) reduction_channels = reduction_channels or make_divisible(channels // reduction, divisor=divisor) - self.fc1 = nn.Conv2d( - channels, reduction_channels, kernel_size=1, padding=0, bias=True) + self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True) self.bn = nn.BatchNorm2d(reduction_channels) self.act = act_layer(inplace=True) - self.fc2 = nn.Conv2d( - reduction_channels, channels, kernel_size=1, padding=0, bias=True) + self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True) self.gate = create_act_layer(gate_layer) def forward(self, x): - x_se = self.avg_pool(x) + x_se = x.mean((2, 3), keepdim=True) x_se = self.fc1(x_se) x_se = self.bn(x_se) x_se = self.act(x_se) diff --git a/timm/models/senet.py b/timm/models/senet.py index 2155ec81..8073229a 100644 --- a/timm/models/senet.py +++ b/timm/models/senet.py @@ -71,17 +71,14 @@ class SEModule(nn.Module): def __init__(self, channels, reduction): super(SEModule, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.fc1 = nn.Conv2d( - channels, channels // reduction, kernel_size=1, padding=0) + self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1) self.relu = nn.ReLU(inplace=True) - self.fc2 = nn.Conv2d( - channels // reduction, channels, kernel_size=1, padding=0) + self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1) self.sigmoid = nn.Sigmoid() def forward(self, x): module_input = x - x = self.avg_pool(x) + x = x.mean((2, 3), keepdim=True) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index e060c20d..75b545e5 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -56,10 +56,9 @@ class FastGlobalAvgPool2d(nn.Module): def forward(self, x): if self.flatten: - in_size = x.size() - return x.view((in_size[0], in_size[1], -1)).mean(dim=2) + return x.mean((2, 3)) else: - return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) + return x.mean((2, 3), keepdim=True) def feat_mult(self): return 1 From 90a01f47d19838f52142f00805d56b8e94a6ea14 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 1 Sep 2020 17:37:55 -0700 Subject: [PATCH 5/8] hrnet features_only pretrained weight loading issue. Fix #232. 
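
The features variant (HighResolutionNetFeatures) does not keep the full classification head, so the pretrained state dict no longer matches the model key-for-key and a strict load fails; building the features model with pretrained_strict=False lets the backbone weights load while the leftover head keys are ignored. Rough usage sketch, not part of the diff (hrnet_w18 is just one example of a pretrained variant):

    import torch
    from timm import create_model

    # pretrained weights now load non-strictly when features_only=True
    model = create_model('hrnet_w18', pretrained=True, features_only=True)
    model.eval()

    with torch.no_grad():
        features = model(torch.randn(1, 3, 224, 224))
    print([f.shape for f in features])  # one feature map per output stage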
--- tests/test_models.py | 6 ++++++ timm/models/hrnet.py | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 71d643dd..d6fcaf79 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -120,6 +120,12 @@ if 'GITHUB_ACTIONS' not in os.environ: in_chans = 3 if 'pruned' in model_name else 1 # pruning not currently supported with in_chans change create_model(model_name, pretrained=True, in_chans=in_chans) + @pytest.mark.timeout(120) + @pytest.mark.parametrize('model_name', list_models(pretrained=True)) + @pytest.mark.parametrize('batch_size', [1]) + def test_model_features_pretrained(model_name, batch_size): + """Create that pretrained weights load when features_only==True.""" + create_model(model_name, pretrained=True, features_only=True) EXCLUDE_JIT_FILTERS = [ '*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index ad865887..1e867686 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -773,12 +773,14 @@ class HighResolutionNetFeatures(HighResolutionNet): def _create_hrnet(variant, pretrained, **model_kwargs): model_cls = HighResolutionNet + strict = True if model_kwargs.pop('features_only', False): model_cls = HighResolutionNetFeatures + strict = False return build_model_with_cfg( model_cls, variant, pretrained, default_cfg=default_cfgs[variant], - model_cfg=cfg_cls[variant], **model_kwargs) + model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs) @register_model From 80c9d9cc724f0bb00b4d96527db835179c6c02a9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 2 Sep 2020 09:11:48 -0700 Subject: [PATCH 6/8] Add 'fast' global pool option, remove redundant SEModule from tresnet, normal one is now 'fast' --- timm/models/layers/adaptive_avgmax_pool.py | 15 ++++++- timm/models/tresnet.py | 48 ++++------------------ 2 files changed, 21 insertions(+), 42 deletions(-) diff --git a/timm/models/layers/adaptive_avgmax_pool.py b/timm/models/layers/adaptive_avgmax_pool.py index 482c0c01..d2bb9f72 100644 --- a/timm/models/layers/adaptive_avgmax_pool.py +++ b/timm/models/layers/adaptive_avgmax_pool.py @@ -49,6 +49,15 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size=1): return x +class FastAdaptiveAvgPool2d(nn.Module): + def __init__(self, flatten=False): + super(FastAdaptiveAvgPool2d, self).__init__() + self.flatten = flatten + + def forward(self, x): + return x.mean((2, 3)) if self.flatten else x.mean((2, 3), keepdim=True) + + class AdaptiveAvgMaxPool2d(nn.Module): def __init__(self, output_size=1): super(AdaptiveAvgMaxPool2d, self).__init__() @@ -70,12 +79,16 @@ class AdaptiveCatAvgMaxPool2d(nn.Module): class SelectAdaptivePool2d(nn.Module): """Selectable global pooling layer with dynamic input kernel size """ - def __init__(self, output_size=1, pool_type='avg', flatten=False): + def __init__(self, output_size=1, pool_type='fast', flatten=False): super(SelectAdaptivePool2d, self).__init__() self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing self.flatten = flatten if pool_type == '': self.pool = nn.Identity() # pass through + elif pool_type == 'fast': + assert output_size == 1 + self.pool = FastAdaptiveAvgPool2d(self.flatten) + self.flatten = False elif pool_type == 'avg': self.pool = nn.AdaptiveAvgPool2d(output_size) elif pool_type == 'avgmax': diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 75b545e5..e371292f 100644 --- 
a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -14,7 +14,7 @@ import torch.nn as nn import torch.nn.functional as F from .helpers import build_model_with_cfg -from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead +from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead, SEModule from .registry import register_model __all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl'] @@ -49,40 +49,6 @@ default_cfgs = { } -class FastGlobalAvgPool2d(nn.Module): - def __init__(self, flatten=False): - super(FastGlobalAvgPool2d, self).__init__() - self.flatten = flatten - - def forward(self, x): - if self.flatten: - return x.mean((2, 3)) - else: - return x.mean((2, 3), keepdim=True) - - def feat_mult(self): - return 1 - - -class FastSEModule(nn.Module): - - def __init__(self, channels, reduction_channels, inplace=True): - super(FastSEModule, self).__init__() - self.avg_pool = FastGlobalAvgPool2d() - self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, padding=0, bias=True) - self.relu = nn.ReLU(inplace=inplace) - self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, padding=0, bias=True) - self.activation = nn.Sigmoid() - - def forward(self, x): - x_se = self.avg_pool(x) - x_se2 = self.fc1(x_se) - x_se2 = self.relu(x_se2) - x_se = self.fc2(x_se2) - x_se = self.activation(x_se) - return x * x_se - - def IABN2Float(module: nn.Module) -> nn.Module: """If `module` is IABN don't use half precision.""" if isinstance(module, InplaceAbn): @@ -119,8 +85,8 @@ class BasicBlock(nn.Module): self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride - reduce_layer_planes = max(planes * self.expansion // 4, 64) - self.se = FastSEModule(planes * self.expansion, reduce_layer_planes) if use_se else None + reduction_chs = max(planes * self.expansion // 4, 64) + self.se = SEModule(planes * self.expansion, reduction_channels=reduction_chs) if use_se else None def forward(self, x): if self.downsample is not None: @@ -159,8 +125,8 @@ class Bottleneck(nn.Module): conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3), aa_layer(channels=planes, filt_size=3, stride=2)) - reduce_layer_planes = max(planes * self.expansion // 8, 64) - self.se = FastSEModule(planes, reduce_layer_planes) if use_se else None + reduction_chs = max(planes * self.expansion // 8, 64) + self.se = SEModule(planes, reduction_channels=reduction_chs) if use_se else None self.conv3 = conv2d_iabn( planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity") @@ -189,7 +155,7 @@ class Bottleneck(nn.Module): class TResNet(nn.Module): def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False, - global_pool='avg', drop_rate=0.): + global_pool='fast', drop_rate=0.): self.num_classes = num_classes self.drop_rate = drop_rate super(TResNet, self).__init__() @@ -272,7 +238,7 @@ class TResNet(nn.Module): def get_classifier(self): return self.head.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes, global_pool='fast'): self.head = ClassifierHead( self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) From 9c297ec67d0be9d2ed0a1c998be6387159e3eb17 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 2 Sep 2020 15:12:59 -0700 Subject: [PATCH 7/8] Cleanup Apex vs native AMP scaler state save/load. Cleanup CheckpointSaver a bit. 
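
Both scaler wrappers now live in timm.utils and expose state_dict()/load_state_dict() plus a state_dict_key ('amp' for Apex, 'amp_scaler' for the native GradScaler), so CheckpointSaver and resume_checkpoint can treat either one generically. A simplified sketch of the round-trip this enables, assuming a CUDA setup where the native scaler is active; it is not the saver's actual code:

    import torch

    scaler = torch.cuda.amp.GradScaler()   # wrapped by NativeScaler in timm
    key = 'amp_scaler'                     # NativeScaler.state_dict_key

    # saving: the scaler state rides along in the checkpoint under its key
    torch.save({'state_dict': {}, key: scaler.state_dict()}, 'last.pth.tar')

    # resuming: restore the scaler only if its key is present
    checkpoint = torch.load('last.pth.tar', map_location='cpu')
    if key in checkpoint:
        scaler.load_state_dict(checkpoint[key])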
--- timm/models/helpers.py | 29 +++++++++---- timm/utils.py | 96 ++++++++++++++++++++++++++++++++++-------- train.py | 46 +++++--------------- validate.py | 15 +------ 4 files changed, 111 insertions(+), 75 deletions(-) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index f1702af0..ac119295 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -48,30 +48,41 @@ def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): model.load_state_dict(state_dict, strict=strict) -def resume_checkpoint(model, checkpoint_path): - other_state = {} +def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): resume_epoch = None if os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu') if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + if log_info: + _logger.info('Restoring model state from checkpoint...') new_state_dict = OrderedDict() for k, v in checkpoint['state_dict'].items(): name = k[7:] if k.startswith('module') else k new_state_dict[name] = v model.load_state_dict(new_state_dict) - if 'optimizer' in checkpoint: - other_state['optimizer'] = checkpoint['optimizer'] - if 'amp' in checkpoint: - other_state['amp'] = checkpoint['amp'] + + if optimizer is not None and 'optimizer' in checkpoint: + if log_info: + _logger.info('Restoring optimizer state from checkpoint...') + optimizer.load_state_dict(checkpoint['optimizer']) + + if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: + if log_info: + _logger.info('Restoring AMP loss scaler state from checkpoint...') + loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) + if 'epoch' in checkpoint: resume_epoch = checkpoint['epoch'] if 'version' in checkpoint and checkpoint['version'] > 1: resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save - _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) + + if log_info: + _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) else: model.load_state_dict(checkpoint) - _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) - return other_state, resume_epoch + if log_info: + _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) + return resume_epoch else: _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() diff --git a/timm/utils.py b/timm/utils.py index 6c10d283..94f85d84 100644 --- a/timm/utils.py +++ b/timm/utils.py @@ -37,20 +37,67 @@ def unwrap_model(model): return model.module if hasattr(model, 'module') else model -def get_state_dict(model): - return unwrap_model(model).state_dict() +def get_state_dict(model, unwrap_fn=unwrap_model): + return unwrap_fn(model).state_dict() + + +class ApexScaler: + state_dict_key = "amp" + + def __call__(self, loss, optimizer): + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + optimizer.step() + + def state_dict(self): + if 'state_dict' in amp.__dict__: + return amp.state_dict() + + def load_state_dict(self, state_dict): + if 'load_state_dict' in amp.__dict__: + amp.load_state_dict(state_dict) + + +class NativeScaler: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__(self, loss, optimizer): + self._scaler.scale(loss).backward() + self._scaler.step(optimizer) + self._scaler.update() + + def state_dict(self): + return self._scaler.state_dict() + + def 
load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) class CheckpointSaver: def __init__( self, + model, + optimizer, + args=None, + model_ema=None, + amp_scaler=None, checkpoint_prefix='checkpoint', recovery_prefix='recovery', checkpoint_dir='', recovery_dir='', decreasing=False, max_history=10, - save_amp=False): + unwrap_fn=unwrap_model): + + # objects to save state_dicts of + self.model = model + self.optimizer = optimizer + self.args = args + self.model_ema = model_ema + self.amp_scaler = amp_scaler # state self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness @@ -68,14 +115,14 @@ class CheckpointSaver: self.decreasing = decreasing # a lower metric is better if True self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs self.max_history = max_history - self.save_apex_amp = save_amp # save APEX amp state + self.unwrap_fn = unwrap_fn assert self.max_history >= 1 - def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None): + def save_checkpoint(self, epoch, metric=None): assert epoch >= 0 tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension) last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension) - self._save(tmp_save_path, model, optimizer, args, epoch, model_ema, metric) + self._save(tmp_save_path, epoch, metric) if os.path.exists(last_save_path): os.unlink(last_save_path) # required for Windows support. os.rename(tmp_save_path, last_save_path) @@ -107,19 +154,21 @@ class CheckpointSaver: return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch) - def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None): + def _save(self, save_path, epoch, metric=None): save_state = { 'epoch': epoch, - 'arch': args.model, - 'state_dict': get_state_dict(model), - 'optimizer': optimizer.state_dict(), - 'args': args, + 'arch': type(self.model).__name__.lower(), + 'state_dict': get_state_dict(self.model, self.unwrap_fn), + 'optimizer': self.optimizer.state_dict(), 'version': 2, # version < 2 increments epoch before save } - if self.save_apex_amp and 'state_dict' in amp.__dict__: - save_state['amp'] = amp.state_dict() - if model_ema is not None: - save_state['state_dict_ema'] = get_state_dict(model_ema) + if self.args is not None: + save_state['arch'] = self.args.model + save_state['args'] = self.args + if self.amp_scaler is not None: + save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict() + if self.model_ema is not None: + save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn) if metric is not None: save_state['metric'] = metric torch.save(save_state, save_path) @@ -138,11 +187,11 @@ class CheckpointSaver: _logger.error("Exception '{}' while deleting checkpoint".format(e)) self.checkpoint_files = self.checkpoint_files[:delete_index] - def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0): + def save_recovery(self, epoch, batch_idx=0): assert epoch >= 0 filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension save_path = os.path.join(self.recovery_dir, filename) - self._save(save_path, model, optimizer, args, epoch, model_ema) + self._save(save_path, epoch) if os.path.exists(self.last_recovery_file): try: _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file)) @@ -336,3 +385,16 @@ def add_bool_arg(parser, name, default=False, help=''): group.add_argument('--' + name, 
dest=dest_name, action='store_true', help=help) group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help) parser.set_defaults(**{dest_name: default}) + + +def set_jit_legacy(): + """ Set JIT executor to legacy w/ support for op fusion + This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes + in the JIT exectutor. These API are not supported so could change. + """ + # + assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!" + torch._C._jit_set_profiling_executor(False) + torch._C._jit_set_profiling_mode(False) + torch._C._jit_override_can_fuse_on_gpu(True) + #torch._C._jit_set_texpr_fuser_enabled(True) diff --git a/train.py b/train.py index e017605d..ff421d53 100755 --- a/train.py +++ b/train.py @@ -20,7 +20,6 @@ import yaml from datetime import datetime from contextlib import suppress -import torch import torch.nn as nn import torchvision.utils from torch.nn.parallel import DistributedDataParallel as NativeDDP @@ -31,6 +30,7 @@ from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.optim import create_optimizer from timm.scheduler import create_scheduler +from timm.utils import ApexScaler, NativeScaler try: from apex import amp @@ -264,23 +264,6 @@ def _parse_args(): return args, args_text -class ApexScaler: - def __call__(self, loss, optimizer): - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - optimizer.step() - - -class NativeScaler: - def __init__(self): - self._scaler = torch.cuda.amp.GradScaler() - - def __call__(self, loss, optimizer): - self._scaler.scale(loss).backward() - self._scaler.step(optimizer) - self._scaler.update() - - def main(): setup_default_logging() args, args_text = _parse_args() @@ -389,20 +372,13 @@ def main(): _logger.info('AMP not enabled. 
Training in float32.') # optionally resume from a checkpoint - resume_state = {} resume_epoch = None if args.resume: - resume_state, resume_epoch = resume_checkpoint(model, args.resume) - if resume_state and not args.no_resume_opt: - if 'optimizer' in resume_state: - if args.local_rank == 0: - _logger.info('Restoring optimizer state from checkpoint') - optimizer.load_state_dict(resume_state['optimizer']) - if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__: - if args.local_rank == 0: - _logger.info('Restoring NVIDIA AMP state from checkpoint') - amp.load_state_dict(resume_state['amp']) - del resume_state + resume_epoch = resume_checkpoint( + model, args.resume, + optimizer=None if args.no_resume_opt else optimizer, + loss_scaler=None if args.no_resume_opt else loss_scaler, + log_info=args.local_rank == 0) model_ema = None if args.model_ema: @@ -555,7 +531,9 @@ def main(): ]) output_dir = get_outdir(output_base, 'train', exp_name) decreasing = True if eval_metric == 'loss' else False - saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing, save_amp=use_amp == 'apex') + saver = CheckpointSaver( + model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, + checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) @@ -594,8 +572,7 @@ def main(): if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] - best_metric, best_epoch = saver.save_checkpoint( - model, optimizer, args, epoch=epoch, model_ema=model_ema, metric=save_metric) + best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) except KeyboardInterrupt: pass @@ -688,8 +665,7 @@ def train_epoch( if saver is not None and args.recovery_interval and ( last_batch or (batch_idx + 1) % args.recovery_interval == 0): - - saver.save_recovery(model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx) + saver.save_recovery(epoch, batch_idx=batch_idx) if lr_scheduler is not None: lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) diff --git a/validate.py b/validate.py index f20c35ba..587ac6a6 100755 --- a/validate.py +++ b/validate.py @@ -21,7 +21,7 @@ from contextlib import suppress from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet -from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging +from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy has_apex = False try: @@ -102,19 +102,6 @@ parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME', help='Valid label indices txt file for validation of partial label space') -def set_jit_legacy(): - """ Set JIT executor to legacy w/ support for op fusion - This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes - in the JIT exectutor. These API are not supported so could change. - """ - # - assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!" 
- torch._C._jit_set_profiling_executor(False) - torch._C._jit_set_profiling_mode(False) - torch._C._jit_override_can_fuse_on_gpu(True) - #torch._C._jit_set_texpr_fuser_enabled(True) - - def validate(args): # might as well try to validate something args.pretrained = args.pretrained or not args.checkpoint From 751b0bba9861fa2b8670d1b019c3fc4d9fb8381f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 2 Sep 2020 16:13:47 -0700 Subject: [PATCH 8/8] Add global_pool (--gp) arg changes to allow passing 'fast' easily for train/validate to avoid channels_last issue with AdaptiveAvgPool --- timm/models/factory.py | 11 ++++------- train.py | 4 ++-- validate.py | 3 +++ 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/timm/models/factory.py b/timm/models/factory.py index 03d8cc1f..70209c96 100644 --- a/timm/models/factory.py +++ b/timm/models/factory.py @@ -39,11 +39,6 @@ def create_model( kwargs.pop('bn_momentum', None) kwargs.pop('bn_eps', None) - # Parameters that aren't supported by all models should default to None in command line args, - # remove them if they are present and not set so that non-supporting models don't break. - if kwargs.get('drop_block_rate', None) is None: - kwargs.pop('drop_block_rate', None) - # handle backwards compat with drop_connect -> drop_path change drop_connect_rate = kwargs.pop('drop_connect_rate', None) if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None: @@ -51,8 +46,10 @@ def create_model( " Setting drop_path to %f." % drop_connect_rate) kwargs['drop_path_rate'] = drop_connect_rate - if kwargs.get('drop_path_rate', None) is None: - kwargs.pop('drop_path_rate', None) + # Parameters that aren't supported by all models or are intended to only override model defaults if set + # should default to None in command line args/cfg. Remove them if they are present and not set so that + # non-supporting models don't break and default args remain in effect. + kwargs = {k: v for k, v in kwargs.items() if v is not None} with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): if is_model(model_name): diff --git a/train.py b/train.py index ff421d53..260de18b 100755 --- a/train.py +++ b/train.py @@ -74,8 +74,8 @@ parser.add_argument('--no-resume-opt', action='store_true', default=False, help='prevent resume of optimizer state when resuming model') parser.add_argument('--num-classes', type=int, default=1000, metavar='N', help='number of label classes (default: 1000)') -parser.add_argument('--gp', default='avg', type=str, metavar='POOL', - help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")') +parser.add_argument('--gp', default=None, type=str, metavar='POOL', + help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') parser.add_argument('--img-size', type=int, default=None, metavar='N', help='Image patch size (default: None => model default)') parser.add_argument('--crop-pct', default=None, type=float, diff --git a/validate.py b/validate.py index 587ac6a6..ceca7014 100755 --- a/validate.py +++ b/validate.py @@ -64,6 +64,8 @@ parser.add_argument('--num-classes', type=int, default=1000, help='Number classes in dataset') parser.add_argument('--class-map', default='', type=str, metavar='FILENAME', help='path to class to idx mapping file (default: "")') +parser.add_argument('--gp', default=None, type=str, metavar='POOL', + help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). 
Model default if None.') parser.add_argument('--log-freq', default=10, type=int, metavar='N', help='batch logging frequency (default: 10)') parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', @@ -127,6 +129,7 @@ def validate(args): pretrained=args.pretrained, num_classes=args.num_classes, in_chans=3, + global_pool=args.gp, scriptable=args.torchscript) if args.checkpoint: