From c2cd1a332eca0f109b7fb357aa98eb7a7bfabc11 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 31 Aug 2020 17:58:16 -0700 Subject: [PATCH] Improve torch amp support and add channels_last support for train/validate scripts --- timm/utils.py | 16 +++-- train.py | 177 +++++++++++++++++++++++++++++++------------------- validate.py | 59 +++++++++++++---- 3 files changed, 165 insertions(+), 87 deletions(-) diff --git a/timm/utils.py b/timm/utils.py index 65255b53..6c10d283 100644 --- a/timm/utils.py +++ b/timm/utils.py @@ -49,7 +49,8 @@ class CheckpointSaver: checkpoint_dir='', recovery_dir='', decreasing=False, - max_history=10): + max_history=10, + save_amp=False): # state self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness @@ -67,13 +68,14 @@ class CheckpointSaver: self.decreasing = decreasing # a lower metric is better if True self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs self.max_history = max_history + self.save_apex_amp = save_amp # save APEX amp state assert self.max_history >= 1 - def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False): + def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None): assert epoch >= 0 tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension) last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension) - self._save(tmp_save_path, model, optimizer, args, epoch, model_ema, metric, use_amp) + self._save(tmp_save_path, model, optimizer, args, epoch, model_ema, metric) if os.path.exists(last_save_path): os.unlink(last_save_path) # required for Windows support. os.rename(tmp_save_path, last_save_path) @@ -105,7 +107,7 @@ class CheckpointSaver: return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch) - def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False): + def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None): save_state = { 'epoch': epoch, 'arch': args.model, @@ -114,7 +116,7 @@ class CheckpointSaver: 'args': args, 'version': 2, # version < 2 increments epoch before save } - if use_amp and 'state_dict' in amp.__dict__: + if self.save_apex_amp and 'state_dict' in amp.__dict__: save_state['amp'] = amp.state_dict() if model_ema is not None: save_state['state_dict_ema'] = get_state_dict(model_ema) @@ -136,11 +138,11 @@ class CheckpointSaver: _logger.error("Exception '{}' while deleting checkpoint".format(e)) self.checkpoint_files = self.checkpoint_files[:delete_index] - def save_recovery(self, model, optimizer, args, epoch, model_ema=None, use_amp=False, batch_idx=0): + def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0): assert epoch >= 0 filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension save_path = os.path.join(self.recovery_dir, filename) - self._save(save_path, model, optimizer, args, epoch, model_ema, use_amp=use_amp) + self._save(save_path, model, optimizer, args, epoch, model_ema) if os.path.exists(self.last_recovery_file): try: _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file)) diff --git a/train.py b/train.py index a99d5d36..e017605d 100755 --- a/train.py +++ b/train.py @@ -18,18 +18,12 @@ import argparse import time import yaml from datetime import datetime +from contextlib import suppress -try: - from apex import amp - from apex.parallel import DistributedDataParallel as DDP - from apex.parallel import convert_syncbn_model - has_apex = True -except ImportError: - from torch.cuda import amp - from torch.nn.parallel import DistributedDataParallel as DDP - has_apex = False - - +import torch +import torch.nn as nn +import torchvision.utils +from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset from timm.models import create_model, resume_checkpoint, convert_splitbn_model @@ -38,14 +32,24 @@ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCro from timm.optim import create_optimizer from timm.scheduler import create_scheduler -import torch -import torch.nn as nn -import torchvision.utils +try: + from apex import amp + from apex.parallel import DistributedDataParallel as ApexDDP + from apex.parallel import convert_syncbn_model + has_apex = True +except ImportError: + has_apex = False + +has_native_amp = False +try: + if getattr(torch.cuda.amp, 'autocast') is not None: + has_native_amp = True +except AttributeError: + pass torch.backends.cudnn.benchmark = True _logger = logging.getLogger('train') - # The first arg parser parses out only the --config argument, this argument is used to # load a yaml file containing key-values that override the defaults for the main parser below config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) @@ -221,7 +225,13 @@ parser.add_argument('--num-gpu', type=int, default=1, parser.add_argument('--save-images', action='store_true', default=False, help='save images of input bathes every log interval for debugging') parser.add_argument('--amp', action='store_true', default=False, - help='use NVIDIA amp for mixed precision training') + help='use NVIDIA Apex AMP or Native AMP for mixed precision training') +parser.add_argument('--apex-amp', action='store_true', default=False, + help='Use NVIDIA Apex AMP mixed precision') +parser.add_argument('--native-amp', action='store_true', default=False, + help='Use Native Torch AMP mixed precision') +parser.add_argument('--channels-last', action='store_true', default=False, + help='Use channels_last memory layout') parser.add_argument('--pin-mem', action='store_true', default=False, help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') parser.add_argument('--no-prefetcher', action='store_true', default=False, @@ -254,6 +264,23 @@ def _parse_args(): return args, args_text +class ApexScaler: + def __call__(self, loss, optimizer): + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + optimizer.step() + + +class NativeScaler: + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__(self, loss, optimizer): + self._scaler.scale(loss).backward() + self._scaler.step(optimizer) + self._scaler.update() + + def main(): setup_default_logging() args, args_text = _parse_args() @@ -263,7 +290,8 @@ def main(): if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 if args.distributed and args.num_gpu > 1: - _logger.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.') + _logger.warning( + 'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.') args.num_gpu = 1 args.device = 'cuda:0' @@ -315,28 +343,50 @@ def main(): assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) + use_amp = None + if args.amp: + # for backwards compat, `--amp` arg tries apex before native amp + if has_apex: + args.apex_amp = True + elif has_native_amp: + args.native_amp = True + if args.apex_amp and has_apex: + use_amp = 'apex' + elif args.native_amp and has_native_amp: + use_amp = 'native' + elif args.apex_amp or args.native_amp: + _logger.warning("Neither APEX or native Torch AMP is available, using float32. " + "Install NVIDA apex or upgrade to PyTorch 1.6") + if args.num_gpu > 1: - if args.amp: + if use_amp == 'apex': _logger.warning( - 'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.') - args.amp = False + 'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.') + use_amp = None model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() + assert not args.channels_last, "Channels last not supported with DP, use DDP." else: model.cuda() + if args.channels_last: + model = model.to(memory_format=torch.channels_last) optimizer = create_optimizer(args, model) - use_amp = False - if has_apex and args.amp: + amp_autocast = suppress # do nothing + loss_scaler = None + if use_amp == 'apex': model, optimizer = amp.initialize(model, optimizer, opt_level='O1') - use_amp = True - elif args.amp: - _logger.info('Using torch AMP. Install NVIDIA Apex for Apex AMP.') - scaler = torch.cuda.amp.GradScaler() - use_amp = True - if args.local_rank == 0: - _logger.info('NVIDIA APEX {}. AMP {}.'.format( - 'installed' if has_apex else 'not installed', 'on' if use_amp else 'off')) + loss_scaler = ApexScaler() + if args.local_rank == 0: + _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') + elif use_amp == 'native': + amp_autocast = torch.cuda.amp.autocast + loss_scaler = NativeScaler() + if args.local_rank == 0: + _logger.info('Using native Torch AMP. Training in mixed precision.') + else: + if args.local_rank == 0: + _logger.info('AMP not enabled. Training in float32.') # optionally resume from a checkpoint resume_state = {} @@ -346,7 +396,7 @@ def main(): if resume_state and not args.no_resume_opt: if 'optimizer' in resume_state: if args.local_rank == 0: - _logger.info('Restoring Optimizer state from checkpoint') + _logger.info('Restoring optimizer state from checkpoint') optimizer.load_state_dict(resume_state['optimizer']) if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__: if args.local_rank == 0: @@ -367,7 +417,8 @@ def main(): if args.sync_bn: assert not args.split_bn try: - if has_apex: + if has_apex and use_amp != 'native': + # Apex SyncBN preferred unless native amp is activated model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -377,12 +428,15 @@ def main(): 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') except Exception as e: _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1') - if has_apex: - model = DDP(model, delay_allreduce=True) + if has_apex and use_amp != 'native': + # Apex DDP preferred unless native amp is activated + if args.local_rank == 0: + _logger.info("Using NVIDIA APEX DistributedDataParallel.") + model = ApexDDP(model, delay_allreduce=True) else: if args.local_rank == 0: - _logger.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.") - model = DDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 + _logger.info("Using native Torch DistributedDataParallel.") + model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP lr_scheduler, num_epochs = create_scheduler(args, optimizer) @@ -501,7 +555,7 @@ def main(): ]) output_dir = get_outdir(output_base, 'train', exp_name) decreasing = True if eval_metric == 'loss' else False - saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) + saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing, save_amp=use_amp == 'apex') with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) @@ -513,22 +567,20 @@ def main(): train_metrics = train_epoch( epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, - use_amp=use_amp, has_apex=has_apex, scaler = scaler, - model_ema=model_ema, mixup_fn=mixup_fn) + amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: _logger.info("Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') - eval_metrics = validate(model, loader_eval, validate_loss_fn, args) + eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') - ema_eval_metrics = validate( - model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)') + model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if lr_scheduler is not None: @@ -543,8 +595,7 @@ def main(): # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( - model, optimizer, args, - epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=has_apex&use_amp) + model, optimizer, args, epoch=epoch, model_ema=model_ema, metric=save_metric) except KeyboardInterrupt: pass @@ -554,8 +605,8 @@ def main(): def train_epoch( epoch, model, loader, optimizer, loss_fn, args, - lr_scheduler=None, saver=None, output_dir='', use_amp=False, - has_apex=False, scaler = None, model_ema=None, mixup_fn=None): + lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress, + loss_scaler=None, model_ema=None, mixup_fn=None): if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: if args.prefetcher and loader.mixup_enabled: @@ -579,31 +630,21 @@ def train_epoch( input, target = input.cuda(), target.cuda() if mixup_fn is not None: input, target = mixup_fn(input, target) - if not has_apex and use_amp: - with torch.cuda.amp.autocast(): - output = model(input) - loss = loss_fn(output, target) - else: + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) + + with amp_autocast(): output = model(input) loss = loss_fn(output, target) - + if not args.distributed: losses_m.update(loss.item(), input.size(0)) optimizer.zero_grad() - if use_amp: - if has_apex: - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - else: - scaler.scale(loss).backward() - + if loss_scaler is not None: + loss_scaler(loss, optimizer) else: loss.backward() - if not has_apex and use_amp: - scaler.step(optimizer) - scaler.update() - else: optimizer.step() torch.cuda.synchronize() @@ -648,8 +689,7 @@ def train_epoch( if saver is not None and args.recovery_interval and ( last_batch or (batch_idx + 1) % args.recovery_interval == 0): - saver.save_recovery( - model, optimizer, args, epoch, model_ema=model_ema, use_amp=has_apex&use_amp, batch_idx=batch_idx) + saver.save_recovery(model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx) if lr_scheduler is not None: lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) @@ -663,7 +703,7 @@ def train_epoch( return OrderedDict([('loss', losses_m.avg)]) -def validate(model, loader, loss_fn, args, log_suffix=''): +def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''): batch_time_m = AverageMeter() losses_m = AverageMeter() top1_m = AverageMeter() @@ -679,8 +719,11 @@ def validate(model, loader, loss_fn, args, log_suffix=''): if not args.prefetcher: input = input.cuda() target = target.cuda() + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) - output = model(input) + with amp_autocast(): + output = model(input) if isinstance(output, (tuple, list)): output = output[0] diff --git a/validate.py b/validate.py index b0d411c5..f20c35ba 100755 --- a/validate.py +++ b/validate.py @@ -17,16 +17,25 @@ import torch import torch.nn as nn import torch.nn.parallel from collections import OrderedDict +from contextlib import suppress +from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models +from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet +from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging + +has_apex = False try: from apex import amp has_apex = True except ImportError: - has_apex = False + pass -from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models -from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet -from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging +has_native_amp = False +try: + if getattr(torch.cuda.amp, 'autocast') is not None: + has_native_amp = True +except AttributeError: + pass torch.backends.cudnn.benchmark = True _logger = logging.getLogger('validate') @@ -69,8 +78,14 @@ parser.add_argument('--no-prefetcher', action='store_true', default=False, help='disable fast prefetcher') parser.add_argument('--pin-mem', action='store_true', default=False, help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') +parser.add_argument('--channels-last', action='store_true', default=False, + help='Use channels_last memory layout') parser.add_argument('--amp', action='store_true', default=False, - help='Use AMP mixed precision') + help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.') +parser.add_argument('--apex-amp', action='store_true', default=False, + help='Use NVIDIA Apex AMP mixed precision') +parser.add_argument('--native-amp', action='store_true', default=False, + help='Use Native Torch AMP mixed precision') parser.add_argument('--tf-preprocessing', action='store_true', default=False, help='Use Tensorflow preprocessing pipeline (require CPU TF installed') parser.add_argument('--use-ema', dest='use_ema', action='store_true', @@ -104,6 +119,18 @@ def validate(args): # might as well try to validate something args.pretrained = args.pretrained or not args.checkpoint args.prefetcher = not args.no_prefetcher + amp_autocast = suppress # do nothing + if args.amp: + if has_apex: + args.apex_amp = True + elif has_native_amp: + args.native_amp = True + else: + _logger.warning("Neither APEX or Native Torch AMP is available, using FP32.") + assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set." + if args.native_amp: + amp_autocast = torch.cuda.amp.autocast + if args.legacy_jit: set_jit_legacy() @@ -128,10 +155,12 @@ def validate(args): torch.jit.optimized_execution(True) model = torch.jit.script(model) - if args.amp: - model = amp.initialize(model.cuda(), opt_level='O1') - else: - model = model.cuda() + model = model.cuda() + if args.apex_amp: + model = amp.initialize(model, opt_level='O1') + + if args.channels_last: + model = model.to(memory_format=torch.channels_last) if args.num_gpu > 1: model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))) @@ -178,17 +207,21 @@ def validate(args): with torch.no_grad(): # warmup, reduce variability of first batch time, especially for comparing torchscript vs non input = torch.randn((args.batch_size,) + data_config['input_size']).cuda() + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) model(input) end = time.time() for batch_idx, (input, target) in enumerate(loader): if args.no_prefetcher: target = target.cuda() input = input.cuda() - if args.fp16: - input = input.half() + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) # compute output - output = model(input) + with amp_autocast(): + output = model(input) + if valid_labels is not None: output = output[:, valid_labels] loss = criterion(output, target) @@ -197,7 +230,7 @@ def validate(args): real_labels.add_result(output) # measure accuracy and record loss - acc1, acc5 = accuracy(output.data, target, topk=(1, 5)) + acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5)) losses.update(loss.item(), input.size(0)) top1.update(acc1.item(), input.size(0)) top5.update(acc5.item(), input.size(0))