From d98967ed5d4b295d0d1871a74428956673f5f7e5 Mon Sep 17 00:00:00 2001
From: datamining99
Date: Sat, 22 Aug 2020 09:44:23 +0900
Subject: [PATCH 1/2] add support for native torch AMP in torch 1.6

---
 train.py | 39 ++++++++++++++++++++++++++++++---------
 1 file changed, 30 insertions(+), 9 deletions(-)

diff --git a/train.py b/train.py
index c28bd266..cdb5a95b 100755
--- a/train.py
+++ b/train.py
@@ -25,8 +25,11 @@ try:
     from apex.parallel import convert_syncbn_model
     has_apex = True
 except ImportError:
+    from torch.cuda import amp
     from torch.nn.parallel import DistributedDataParallel as DDP
     has_apex = False
+
+
 
 from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
 from timm.models import create_model, resume_checkpoint, convert_splitbn_model
@@ -327,6 +330,10 @@ def main():
     if has_apex and args.amp:
         model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
         use_amp = True
+    elif args.amp:
+        _logger.info('Using torch AMP. Install NVIDIA Apex for Apex AMP.')
+        scaler = torch.cuda.amp.GradScaler()
+        use_amp = True
     if args.local_rank == 0:
         _logger.info('NVIDIA APEX {}. AMP {}.'.format(
             'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
@@ -506,7 +513,8 @@ def main():
             train_metrics = train_epoch(
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
-                use_amp=use_amp, model_ema=model_ema, mixup_fn=mixup_fn)
+                use_amp=use_amp, has_apex=has_apex, scaler=scaler,
+                model_ema=model_ema, mixup_fn=mixup_fn)
 
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
@@ -546,7 +554,8 @@
 
 def train_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None, mixup_fn=None):
+        lr_scheduler=None, saver=None, output_dir='', use_amp=False,
+        has_apex=False, scaler=None, model_ema=None, mixup_fn=None):
 
     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
@@ -570,20 +579,32 @@ def train_epoch(
             input, target = input.cuda(), target.cuda()
             if mixup_fn is not None:
                 input, target = mixup_fn(input, target)
-
-        output = model(input)
-
-        loss = loss_fn(output, target)
+        if not has_apex and use_amp:
+            with torch.cuda.amp.autocast():
+                output = model(input)
+                loss = loss_fn(output, target)
+        else:
+            output = model(input)
+            loss = loss_fn(output, target)
+
         if not args.distributed:
             losses_m.update(loss.item(), input.size(0))
 
         optimizer.zero_grad()
         if use_amp:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
+            if has_apex:
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+            else:
+                scaler.scale(loss).backward()
+
         else:
             loss.backward()
-        optimizer.step()
+        if not has_apex and use_amp:
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            optimizer.step()
 
         torch.cuda.synchronize()
         if model_ema is not None:

From 5f563ca4df0ed101ee0f5a7966e7a28238f4d79c Mon Sep 17 00:00:00 2001
From: datamining99
Date: Sat, 22 Aug 2020 11:31:50 +0900
Subject: [PATCH 2/2] fix save_checkpoint bug with native amp

---
 train.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/train.py b/train.py
index cdb5a95b..a99d5d36 100755
--- a/train.py
+++ b/train.py
@@ -544,7 +544,7 @@ def main():
                 save_metric = eval_metrics[eval_metric]
                 best_metric, best_epoch = saver.save_checkpoint(
                     model, optimizer, args,
-                    epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp)
+                    epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=has_apex&use_amp)
 
     except KeyboardInterrupt:
         pass
@@ -647,8 +647,9 @@ def train_epoch(
 
         if saver is not None and args.recovery_interval and (
                 last_batch or (batch_idx + 1) % args.recovery_interval == 0):
+
             saver.save_recovery(
-                model, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)
+                model, optimizer, args, epoch, model_ema=model_ema, use_amp=has_apex&use_amp, batch_idx=batch_idx)
 
         if lr_scheduler is not None:
             lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
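
Note: for reference, a minimal self-contained sketch of the native torch.cuda.amp update pattern that these patches wire into train_epoch. The names (model, images, target, optimizer, loss_fn, amp_train_step) are illustrative placeholders, not part of the diff; the GradScaler/autocast calls are the PyTorch >= 1.6 API the patch relies on.

    import torch

    scaler = torch.cuda.amp.GradScaler()  # keeps a running loss-scale factor

    def amp_train_step(model, images, target, optimizer, loss_fn):
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():    # forward pass runs in mixed precision
            output = model(images)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()      # scale the loss so fp16 grads don't underflow
        scaler.step(optimizer)             # unscales grads, skips the step if inf/nan found
        scaler.update()                    # adjust the scale factor for the next iteration
        return loss.detach()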