Improve torch amp support and add channels_last support for train/validate scripts

pull/233/head
Ross Wightman 4 years ago
parent 1d34a0a851
commit c2cd1a332e

@ -49,7 +49,8 @@ class CheckpointSaver:
checkpoint_dir='', checkpoint_dir='',
recovery_dir='', recovery_dir='',
decreasing=False, decreasing=False,
max_history=10): max_history=10,
save_amp=False):
# state # state
self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness
@ -67,13 +68,14 @@ class CheckpointSaver:
self.decreasing = decreasing # a lower metric is better if True self.decreasing = decreasing # a lower metric is better if True
self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs
self.max_history = max_history self.max_history = max_history
self.save_apex_amp = save_amp # save APEX amp state
assert self.max_history >= 1 assert self.max_history >= 1
def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False): def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None):
assert epoch >= 0 assert epoch >= 0
tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension) tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension) last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
self._save(tmp_save_path, model, optimizer, args, epoch, model_ema, metric, use_amp) self._save(tmp_save_path, model, optimizer, args, epoch, model_ema, metric)
if os.path.exists(last_save_path): if os.path.exists(last_save_path):
os.unlink(last_save_path) # required for Windows support. os.unlink(last_save_path) # required for Windows support.
os.rename(tmp_save_path, last_save_path) os.rename(tmp_save_path, last_save_path)
@ -105,7 +107,7 @@ class CheckpointSaver:
return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch) return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False): def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None):
save_state = { save_state = {
'epoch': epoch, 'epoch': epoch,
'arch': args.model, 'arch': args.model,
@ -114,7 +116,7 @@ class CheckpointSaver:
'args': args, 'args': args,
'version': 2, # version < 2 increments epoch before save 'version': 2, # version < 2 increments epoch before save
} }
if use_amp and 'state_dict' in amp.__dict__: if self.save_apex_amp and 'state_dict' in amp.__dict__:
save_state['amp'] = amp.state_dict() save_state['amp'] = amp.state_dict()
if model_ema is not None: if model_ema is not None:
save_state['state_dict_ema'] = get_state_dict(model_ema) save_state['state_dict_ema'] = get_state_dict(model_ema)
@ -136,11 +138,11 @@ class CheckpointSaver:
_logger.error("Exception '{}' while deleting checkpoint".format(e)) _logger.error("Exception '{}' while deleting checkpoint".format(e))
self.checkpoint_files = self.checkpoint_files[:delete_index] self.checkpoint_files = self.checkpoint_files[:delete_index]
def save_recovery(self, model, optimizer, args, epoch, model_ema=None, use_amp=False, batch_idx=0): def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0):
assert epoch >= 0 assert epoch >= 0
filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
save_path = os.path.join(self.recovery_dir, filename) save_path = os.path.join(self.recovery_dir, filename)
self._save(save_path, model, optimizer, args, epoch, model_ema, use_amp=use_amp) self._save(save_path, model, optimizer, args, epoch, model_ema)
if os.path.exists(self.last_recovery_file): if os.path.exists(self.last_recovery_file):
try: try:
_logger.debug("Cleaning recovery: {}".format(self.last_recovery_file)) _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))

@ -18,18 +18,12 @@ import argparse
import time import time
import yaml import yaml
from datetime import datetime from datetime import datetime
from contextlib import suppress
try: import torch
from apex import amp import torch.nn as nn
from apex.parallel import DistributedDataParallel as DDP import torchvision.utils
from apex.parallel import convert_syncbn_model from torch.nn.parallel import DistributedDataParallel as NativeDDP
has_apex = True
except ImportError:
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
has_apex = False
from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, resume_checkpoint, convert_splitbn_model from timm.models import create_model, resume_checkpoint, convert_splitbn_model
@ -38,14 +32,24 @@ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCro
from timm.optim import create_optimizer from timm.optim import create_optimizer
from timm.scheduler import create_scheduler from timm.scheduler import create_scheduler
import torch try:
import torch.nn as nn from apex import amp
import torchvision.utils from apex.parallel import DistributedDataParallel as ApexDDP
from apex.parallel import convert_syncbn_model
has_apex = True
except ImportError:
has_apex = False
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train') _logger = logging.getLogger('train')
# The first arg parser parses out only the --config argument, this argument is used to # The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below # load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
@ -221,7 +225,13 @@ parser.add_argument('--num-gpu', type=int, default=1,
parser.add_argument('--save-images', action='store_true', default=False, parser.add_argument('--save-images', action='store_true', default=False,
help='save images of input bathes every log interval for debugging') help='save images of input bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False, parser.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA amp for mixed precision training') help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--pin-mem', action='store_true', default=False, parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-prefetcher', action='store_true', default=False, parser.add_argument('--no-prefetcher', action='store_true', default=False,
@ -254,6 +264,23 @@ def _parse_args():
return args, args_text return args, args_text
class ApexScaler:
def __call__(self, loss, optimizer):
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
class NativeScaler:
def __init__(self):
self._scaler = torch.cuda.amp.GradScaler()
def __call__(self, loss, optimizer):
self._scaler.scale(loss).backward()
self._scaler.step(optimizer)
self._scaler.update()
def main(): def main():
setup_default_logging() setup_default_logging()
args, args_text = _parse_args() args, args_text = _parse_args()
@ -263,7 +290,8 @@ def main():
if 'WORLD_SIZE' in os.environ: if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1 args.distributed = int(os.environ['WORLD_SIZE']) > 1
if args.distributed and args.num_gpu > 1: if args.distributed and args.num_gpu > 1:
_logger.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.') _logger.warning(
'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
args.num_gpu = 1 args.num_gpu = 1
args.device = 'cuda:0' args.device = 'cuda:0'
@ -315,28 +343,50 @@ def main():
assert num_aug_splits > 1 or args.resplit assert num_aug_splits > 1 or args.resplit
model = convert_splitbn_model(model, max(num_aug_splits, 2)) model = convert_splitbn_model(model, max(num_aug_splits, 2))
if args.num_gpu > 1: use_amp = None
if args.amp: if args.amp:
# for backwards compat, `--amp` arg tries apex before native amp
if has_apex:
args.apex_amp = True
elif has_native_amp:
args.native_amp = True
if args.apex_amp and has_apex:
use_amp = 'apex'
elif args.native_amp and has_native_amp:
use_amp = 'native'
elif args.apex_amp or args.native_amp:
_logger.warning("Neither APEX or native Torch AMP is available, using float32. "
"Install NVIDA apex or upgrade to PyTorch 1.6")
if args.num_gpu > 1:
if use_amp == 'apex':
_logger.warning( _logger.warning(
'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.') 'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
args.amp = False use_amp = None
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
assert not args.channels_last, "Channels last not supported with DP, use DDP."
else: else:
model.cuda() model.cuda()
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
optimizer = create_optimizer(args, model) optimizer = create_optimizer(args, model)
use_amp = False amp_autocast = suppress # do nothing
if has_apex and args.amp: loss_scaler = None
if use_amp == 'apex':
model, optimizer = amp.initialize(model, optimizer, opt_level='O1') model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
use_amp = True loss_scaler = ApexScaler()
elif args.amp: if args.local_rank == 0:
_logger.info('Using torch AMP. Install NVIDIA Apex for Apex AMP.') _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
scaler = torch.cuda.amp.GradScaler() elif use_amp == 'native':
use_amp = True amp_autocast = torch.cuda.amp.autocast
loss_scaler = NativeScaler()
if args.local_rank == 0: if args.local_rank == 0:
_logger.info('NVIDIA APEX {}. AMP {}.'.format( _logger.info('Using native Torch AMP. Training in mixed precision.')
'installed' if has_apex else 'not installed', 'on' if use_amp else 'off')) else:
if args.local_rank == 0:
_logger.info('AMP not enabled. Training in float32.')
# optionally resume from a checkpoint # optionally resume from a checkpoint
resume_state = {} resume_state = {}
@ -346,7 +396,7 @@ def main():
if resume_state and not args.no_resume_opt: if resume_state and not args.no_resume_opt:
if 'optimizer' in resume_state: if 'optimizer' in resume_state:
if args.local_rank == 0: if args.local_rank == 0:
_logger.info('Restoring Optimizer state from checkpoint') _logger.info('Restoring optimizer state from checkpoint')
optimizer.load_state_dict(resume_state['optimizer']) optimizer.load_state_dict(resume_state['optimizer'])
if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__: if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
if args.local_rank == 0: if args.local_rank == 0:
@ -367,7 +417,8 @@ def main():
if args.sync_bn: if args.sync_bn:
assert not args.split_bn assert not args.split_bn
try: try:
if has_apex: if has_apex and use_amp != 'native':
# Apex SyncBN preferred unless native amp is activated
model = convert_syncbn_model(model) model = convert_syncbn_model(model)
else: else:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@ -377,12 +428,15 @@ def main():
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
except Exception as e: except Exception as e:
_logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1') _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
if has_apex: if has_apex and use_amp != 'native':
model = DDP(model, delay_allreduce=True) # Apex DDP preferred unless native amp is activated
if args.local_rank == 0:
_logger.info("Using NVIDIA APEX DistributedDataParallel.")
model = ApexDDP(model, delay_allreduce=True)
else: else:
if args.local_rank == 0: if args.local_rank == 0:
_logger.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.") _logger.info("Using native Torch DistributedDataParallel.")
model = DDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1
# NOTE: EMA model does not need to be wrapped by DDP # NOTE: EMA model does not need to be wrapped by DDP
lr_scheduler, num_epochs = create_scheduler(args, optimizer) lr_scheduler, num_epochs = create_scheduler(args, optimizer)
@ -501,7 +555,7 @@ def main():
]) ])
output_dir = get_outdir(output_base, 'train', exp_name) output_dir = get_outdir(output_base, 'train', exp_name)
decreasing = True if eval_metric == 'loss' else False decreasing = True if eval_metric == 'loss' else False
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing, save_amp=use_amp == 'apex')
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text) f.write(args_text)
@ -513,22 +567,20 @@ def main():
train_metrics = train_epoch( train_metrics = train_epoch(
epoch, model, loader_train, optimizer, train_loss_fn, args, epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
use_amp=use_amp, has_apex=has_apex, scaler = scaler, amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
model_ema=model_ema, mixup_fn=mixup_fn)
if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if args.local_rank == 0: if args.local_rank == 0:
_logger.info("Distributing BatchNorm running means and vars") _logger.info("Distributing BatchNorm running means and vars")
distribute_bn(model, args.world_size, args.dist_bn == 'reduce') distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
eval_metrics = validate(model, loader_eval, validate_loss_fn, args) eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
if model_ema is not None and not args.model_ema_force_cpu: if model_ema is not None and not args.model_ema_force_cpu:
if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = validate( ema_eval_metrics = validate(
model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)') model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
eval_metrics = ema_eval_metrics eval_metrics = ema_eval_metrics
if lr_scheduler is not None: if lr_scheduler is not None:
@ -543,8 +595,7 @@ def main():
# save proper checkpoint with eval metric # save proper checkpoint with eval metric
save_metric = eval_metrics[eval_metric] save_metric = eval_metrics[eval_metric]
best_metric, best_epoch = saver.save_checkpoint( best_metric, best_epoch = saver.save_checkpoint(
model, optimizer, args, model, optimizer, args, epoch=epoch, model_ema=model_ema, metric=save_metric)
epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=has_apex&use_amp)
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
@ -554,8 +605,8 @@ def main():
def train_epoch( def train_epoch(
epoch, model, loader, optimizer, loss_fn, args, epoch, model, loader, optimizer, loss_fn, args,
lr_scheduler=None, saver=None, output_dir='', use_amp=False, lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
has_apex=False, scaler = None, model_ema=None, mixup_fn=None): loss_scaler=None, model_ema=None, mixup_fn=None):
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.prefetcher and loader.mixup_enabled: if args.prefetcher and loader.mixup_enabled:
@ -579,11 +630,10 @@ def train_epoch(
input, target = input.cuda(), target.cuda() input, target = input.cuda(), target.cuda()
if mixup_fn is not None: if mixup_fn is not None:
input, target = mixup_fn(input, target) input, target = mixup_fn(input, target)
if not has_apex and use_amp: if args.channels_last:
with torch.cuda.amp.autocast(): input = input.contiguous(memory_format=torch.channels_last)
output = model(input)
loss = loss_fn(output, target) with amp_autocast():
else:
output = model(input) output = model(input)
loss = loss_fn(output, target) loss = loss_fn(output, target)
@ -591,19 +641,10 @@ def train_epoch(
losses_m.update(loss.item(), input.size(0)) losses_m.update(loss.item(), input.size(0))
optimizer.zero_grad() optimizer.zero_grad()
if use_amp: if loss_scaler is not None:
if has_apex: loss_scaler(loss, optimizer)
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
scaler.scale(loss).backward()
else: else:
loss.backward() loss.backward()
if not has_apex and use_amp:
scaler.step(optimizer)
scaler.update()
else:
optimizer.step() optimizer.step()
torch.cuda.synchronize() torch.cuda.synchronize()
@ -648,8 +689,7 @@ def train_epoch(
if saver is not None and args.recovery_interval and ( if saver is not None and args.recovery_interval and (
last_batch or (batch_idx + 1) % args.recovery_interval == 0): last_batch or (batch_idx + 1) % args.recovery_interval == 0):
saver.save_recovery( saver.save_recovery(model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx)
model, optimizer, args, epoch, model_ema=model_ema, use_amp=has_apex&use_amp, batch_idx=batch_idx)
if lr_scheduler is not None: if lr_scheduler is not None:
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
@ -663,7 +703,7 @@ def train_epoch(
return OrderedDict([('loss', losses_m.avg)]) return OrderedDict([('loss', losses_m.avg)])
def validate(model, loader, loss_fn, args, log_suffix=''): def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
batch_time_m = AverageMeter() batch_time_m = AverageMeter()
losses_m = AverageMeter() losses_m = AverageMeter()
top1_m = AverageMeter() top1_m = AverageMeter()
@ -679,7 +719,10 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
if not args.prefetcher: if not args.prefetcher:
input = input.cuda() input = input.cuda()
target = target.cuda() target = target.cuda()
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast():
output = model(input) output = model(input)
if isinstance(output, (tuple, list)): if isinstance(output, (tuple, list)):
output = output[0] output = output[0]

@ -17,16 +17,25 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.parallel import torch.nn.parallel
from collections import OrderedDict from collections import OrderedDict
from contextlib import suppress
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
has_apex = False
try: try:
from apex import amp from apex import amp
has_apex = True has_apex = True
except ImportError: except ImportError:
has_apex = False pass
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models has_native_amp = False
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet try:
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate') _logger = logging.getLogger('validate')
@ -69,8 +78,14 @@ parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher') help='disable fast prefetcher')
parser.add_argument('--pin-mem', action='store_true', default=False, parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--amp', action='store_true', default=False, parser.add_argument('--amp', action='store_true', default=False,
help='Use AMP mixed precision') help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
parser.add_argument('--tf-preprocessing', action='store_true', default=False, parser.add_argument('--tf-preprocessing', action='store_true', default=False,
help='Use Tensorflow preprocessing pipeline (require CPU TF installed') help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
parser.add_argument('--use-ema', dest='use_ema', action='store_true', parser.add_argument('--use-ema', dest='use_ema', action='store_true',
@ -104,6 +119,18 @@ def validate(args):
# might as well try to validate something # might as well try to validate something
args.pretrained = args.pretrained or not args.checkpoint args.pretrained = args.pretrained or not args.checkpoint
args.prefetcher = not args.no_prefetcher args.prefetcher = not args.no_prefetcher
amp_autocast = suppress # do nothing
if args.amp:
if has_apex:
args.apex_amp = True
elif has_native_amp:
args.native_amp = True
else:
_logger.warning("Neither APEX or Native Torch AMP is available, using FP32.")
assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
if args.native_amp:
amp_autocast = torch.cuda.amp.autocast
if args.legacy_jit: if args.legacy_jit:
set_jit_legacy() set_jit_legacy()
@ -128,10 +155,12 @@ def validate(args):
torch.jit.optimized_execution(True) torch.jit.optimized_execution(True)
model = torch.jit.script(model) model = torch.jit.script(model)
if args.amp:
model = amp.initialize(model.cuda(), opt_level='O1')
else:
model = model.cuda() model = model.cuda()
if args.apex_amp:
model = amp.initialize(model, opt_level='O1')
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
if args.num_gpu > 1: if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))) model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
@ -178,17 +207,21 @@ def validate(args):
with torch.no_grad(): with torch.no_grad():
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
input = torch.randn((args.batch_size,) + data_config['input_size']).cuda() input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
model(input) model(input)
end = time.time() end = time.time()
for batch_idx, (input, target) in enumerate(loader): for batch_idx, (input, target) in enumerate(loader):
if args.no_prefetcher: if args.no_prefetcher:
target = target.cuda() target = target.cuda()
input = input.cuda() input = input.cuda()
if args.fp16: if args.channels_last:
input = input.half() input = input.contiguous(memory_format=torch.channels_last)
# compute output # compute output
with amp_autocast():
output = model(input) output = model(input)
if valid_labels is not None: if valid_labels is not None:
output = output[:, valid_labels] output = output[:, valid_labels]
loss = criterion(output, target) loss = criterion(output, target)
@ -197,7 +230,7 @@ def validate(args):
real_labels.add_result(output) real_labels.add_result(output)
# measure accuracy and record loss # measure accuracy and record loss
acc1, acc5 = accuracy(output.data, target, topk=(1, 5)) acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
losses.update(loss.item(), input.size(0)) losses.update(loss.item(), input.size(0))
top1.update(acc1.item(), input.size(0)) top1.update(acc1.item(), input.size(0))
top5.update(acc5.item(), input.size(0)) top5.update(acc5.item(), input.size(0))

Loading…
Cancel
Save