pull/233/head
Ross Wightman 4 years ago
commit 1d34a0a851

@@ -25,8 +25,11 @@ try:
     from apex.parallel import convert_syncbn_model
     has_apex = True
 except ImportError:
+    from torch.cuda import amp
     from torch.nn.parallel import DistributedDataParallel as DDP
     has_apex = False

 from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
 from timm.models import create_model, resume_checkpoint, convert_splitbn_model
@@ -327,6 +330,10 @@ def main():
     if has_apex and args.amp:
         model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
         use_amp = True
+    elif args.amp:
+        _logger.info('Using torch AMP. Install NVIDIA Apex for Apex AMP.')
+        scaler = torch.cuda.amp.GradScaler()
+        use_amp = True
     if args.local_rank == 0:
         _logger.info('NVIDIA APEX {}. AMP {}.'.format(
             'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
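The hunk above adds a torch-native fallback next to the existing Apex path: Apex rewrites the model and optimizer via amp.initialize, while the native path only needs a GradScaler created up front. A minimal standalone sketch of that selection logic; setup_amp, model, optimizer and use_amp_flag are placeholder stand-ins for illustration, not names from train.py:

import torch

try:
    from apex import amp as apex_amp  # train.py imports Apex's amp in its own try block
    has_apex = True
except ImportError:
    has_apex = False


def setup_amp(model, optimizer, use_amp_flag):
    # Hypothetical helper mirroring the hunk: prefer Apex AMP, fall back to torch.cuda.amp.
    scaler = None
    use_amp = False
    if has_apex and use_amp_flag:
        # Apex patches the model/optimizer in place for O1 mixed precision.
        model, optimizer = apex_amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    elif use_amp_flag:
        # Native AMP leaves the model untouched; a GradScaler handles loss scaling at step time.
        scaler = torch.cuda.amp.GradScaler()
        use_amp = True
    return model, optimizer, scaler, use_amp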
@@ -506,7 +513,8 @@ def main():
             train_metrics = train_epoch(
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
-                use_amp=use_amp, model_ema=model_ema, mixup_fn=mixup_fn)
+                use_amp=use_amp, has_apex=has_apex, scaler = scaler,
+                model_ema=model_ema, mixup_fn=mixup_fn)

             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
@@ -536,7 +544,7 @@ def main():
                 save_metric = eval_metrics[eval_metric]
                 best_metric, best_epoch = saver.save_checkpoint(
                     model, optimizer, args,
-                    epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp)
+                    epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=has_apex&use_amp)

     except KeyboardInterrupt:
         pass
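Note that use_amp=has_apex&use_amp means the saver is only told about AMP state when Apex is the backend; the GradScaler introduced by this change is not handed to the saver. For reference only (not part of this diff, and the checkpoint path is a placeholder), native scaler state can be round-tripped through an ordinary checkpoint dict:

import torch

scaler = torch.cuda.amp.GradScaler()

# GradScaler exposes a regular state_dict, so it can ride along in any checkpoint.
torch.save({'scaler': scaler.state_dict()}, 'recovery_scaler.pth.tar')

# On resume, restore the loss-scale and growth counters before training continues.
scaler.load_state_dict(torch.load('recovery_scaler.pth.tar')['scaler'])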
@@ -546,7 +554,8 @@ def main():
 def train_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None, mixup_fn=None):
+        lr_scheduler=None, saver=None, output_dir='', use_amp=False,
+        has_apex=False, scaler = None, model_ema=None, mixup_fn=None):

     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
@@ -570,20 +579,32 @@ def train_epoch(
             input, target = input.cuda(), target.cuda()
             if mixup_fn is not None:
                 input, target = mixup_fn(input, target)

-        output = model(input)
-        loss = loss_fn(output, target)
+        if not has_apex and use_amp:
+            with torch.cuda.amp.autocast():
+                output = model(input)
+                loss = loss_fn(output, target)
+        else:
+            output = model(input)
+            loss = loss_fn(output, target)

         if not args.distributed:
             losses_m.update(loss.item(), input.size(0))

         optimizer.zero_grad()
         if use_amp:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
+            if has_apex:
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+            else:
+                scaler.scale(loss).backward()
         else:
             loss.backward()
-        optimizer.step()
+        if not has_apex and use_amp:
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            optimizer.step()

         torch.cuda.synchronize()
         if model_ema is not None:
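Stripped of the has_apex plumbing, the new branches above implement the standard torch.cuda.amp pattern: forward and loss under autocast, backward on the scaled loss, then scaler.step and scaler.update in place of a bare optimizer.step. A self-contained sketch of that pattern with a placeholder model and random batches (not taken from train.py):

import torch
import torch.nn as nn

model = nn.Linear(512, 10).cuda()      # placeholder model
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):                    # placeholder loader: random batches
    input = torch.randn(32, 512, device='cuda')
    target = torch.randint(0, 10, (32,), device='cuda')

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():    # run the forward pass and loss in mixed precision
        output = model(input)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()      # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)             # unscales gradients; skips the step if infs/NaNs appear
    scaler.update()                    # adjust the scale factor for the next iteration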
@@ -626,8 +647,9 @@ def train_epoch(
             if saver is not None and args.recovery_interval and (
                     last_batch or (batch_idx + 1) % args.recovery_interval == 0):
                 saver.save_recovery(
-                    model, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)
+                    model, optimizer, args, epoch, model_ema=model_ema, use_amp=has_apex&use_amp, batch_idx=batch_idx)

             if lr_scheduler is not None:
                 lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
