Add back old ModelEma and rename new one to ModelEmaV2 to avoid compat breaks in dependant code. Shuffle train script, add a few comments, remove DataParallel support, support experimental torchscript training.

pull/297/head
Ross Wightman 4 years ago
parent 9214ca0716
commit 27bbc70d71

@ -6,5 +6,5 @@ from .log import setup_default_logging, FormatterNoInfo
from .metrics import AverageMeter, accuracy from .metrics import AverageMeter, accuracy
from .misc import natural_key, add_bool_arg from .misc import natural_key, add_bool_arg
from .model import unwrap_model, get_state_dict from .model import unwrap_model, get_state_dict
from .model_ema import ModelEma from .model_ema import ModelEma, ModelEmaV2
from .summary import update_summary, get_outdir from .summary import update_summary, get_outdir

@ -6,7 +6,10 @@ from .model_ema import ModelEma
def unwrap_model(model): def unwrap_model(model):
return model.module if hasattr(model, 'module') else model if isinstance(model, ModelEma):
return unwrap_model(model.ema)
else:
return model.module if hasattr(model, 'module') else model
def get_state_dict(model, unwrap_fn=unwrap_model): def get_state_dict(model, unwrap_fn=unwrap_model):

@ -2,15 +2,89 @@
Hacked together by / Copyright 2020 Ross Wightman Hacked together by / Copyright 2020 Ross Wightman
""" """
import logging
from collections import OrderedDict
from copy import deepcopy from copy import deepcopy
import torch import torch
import torch.nn as nn import torch.nn as nn
_logger = logging.getLogger(__name__)
class ModelEma:
""" Model Exponential Moving Average (DEPRECATED)
Keep a moving average of everything in the model state_dict (parameters and buffers).
This version is deprecated, it does not work with scripted models. Will be removed eventually.
This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
A smoothed version of the weights is necessary for some training schemes to perform well.
E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
smoothing of weights to match results. Pay attention to the decay constant you are using
relative to your update count per epoch.
To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
disable validation of the EMA weights. Validation will have to be done manually in a separate
process, or after the training stops converging.
This class is sensitive where it is initialized in the sequence of model init,
GPU assignment and distributed training wrappers.
"""
def __init__(self, model, decay=0.9999, device='', resume=''):
# make a copy of the model for accumulating moving average of weights
self.ema = deepcopy(model)
self.ema.eval()
self.decay = decay
self.device = device # perform ema on different device from model if set
if device:
self.ema.to(device=device)
self.ema_has_module = hasattr(self.ema, 'module')
if resume:
self._load_checkpoint(resume)
for p in self.ema.parameters():
p.requires_grad_(False)
def _load_checkpoint(self, checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
assert isinstance(checkpoint, dict)
if 'state_dict_ema' in checkpoint:
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict_ema'].items():
# ema model may have been wrapped by DataParallel, and need module prefix
if self.ema_has_module:
name = 'module.' + k if not k.startswith('module') else k
else:
name = k
new_state_dict[name] = v
self.ema.load_state_dict(new_state_dict)
_logger.info("Loaded state_dict_ema")
else:
_logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
def update(self, model):
# correct a mismatch in state dict keys
needs_module = hasattr(model, 'module') and not self.ema_has_module
with torch.no_grad():
msd = model.state_dict()
for k, ema_v in self.ema.state_dict().items():
if needs_module:
k = 'module.' + k
model_v = msd[k].detach()
if self.device:
model_v = model_v.to(device=self.device)
ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
class ModelEmaV2(nn.Module):
""" Model Exponential Moving Average V2
class ModelEma(nn.Module):
""" Model Exponential Moving Average
Keep a moving average of everything in the model state_dict (parameters and buffers). Keep a moving average of everything in the model state_dict (parameters and buffers).
V2 of this module is simpler, it does not match params/buffers based on name but simply
iterates in order. It works with torchscript (JIT of full model).
This is intended to allow functionality like This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
@ -27,22 +101,20 @@ class ModelEma(nn.Module):
This class is sensitive where it is initialized in the sequence of model init, This class is sensitive where it is initialized in the sequence of model init,
GPU assignment and distributed training wrappers. GPU assignment and distributed training wrappers.
I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
""" """
def __init__(self, model, decay=0.9999, device=None): def __init__(self, model, decay=0.9999, device=None):
super(ModelEma, self).__init__() super(ModelEmaV2, self).__init__()
# make a copy of the model for accumulating moving average of weights # make a copy of the model for accumulating moving average of weights
self.module = deepcopy(model) self.module = deepcopy(model)
self.module.eval() self.module.eval()
self.decay = decay self.decay = decay
self.device = device # perform ema on different device from model if set self.device = device # perform ema on different device from model if set
if device is not None: if self.device is not None:
self.module.to(device=device) self.module.to(device=device)
def update(self, model): def update(self, model):
with torch.no_grad(): with torch.no_grad():
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
assert ema_v.shape == model_v.shape if self.device is not None:
if self.device:
model_v = model_v.to(device=self.device) model_v = model_v.to(device=self.device)
ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v) ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

@ -29,7 +29,7 @@ import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, resume_checkpoint, convert_splitbn_model from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
from timm.utils import * from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer from timm.optim import create_optimizer
@ -230,8 +230,6 @@ parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
help='how many batches to wait before writing recovery checkpoint') help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N', parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
help='how many training processes to use (default: 1)') help='how many training processes to use (default: 1)')
parser.add_argument('--num-gpu', type=int, default=1,
help='Number of GPUS to use')
parser.add_argument('--save-images', action='store_true', default=False, parser.add_argument('--save-images', action='store_true', default=False,
help='save images of input bathes every log interval for debugging') help='save images of input bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False, parser.add_argument('--amp', action='store_true', default=False,
@ -255,6 +253,8 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
parser.add_argument("--local_rank", default=0, type=int) parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False, parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
help='use the multi-epochs-loader to save time at the beginning of every epoch') help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
def _parse_args(): def _parse_args():
@ -282,28 +282,36 @@ def main():
args.distributed = False args.distributed = False
if 'WORLD_SIZE' in os.environ: if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1 args.distributed = int(os.environ['WORLD_SIZE']) > 1
if args.distributed and args.num_gpu > 1:
_logger.warning(
'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
args.num_gpu = 1
args.device = 'cuda:0' args.device = 'cuda:0'
args.world_size = 1 args.world_size = 1
args.rank = 0 # global rank args.rank = 0 # global rank
if args.distributed: if args.distributed:
args.num_gpu = 1
args.device = 'cuda:%d' % args.local_rank args.device = 'cuda:%d' % args.local_rank
torch.cuda.set_device(args.local_rank) torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://') torch.distributed.init_process_group(backend='nccl', init_method='env://')
args.world_size = torch.distributed.get_world_size() args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank() args.rank = torch.distributed.get_rank()
assert args.rank >= 0
if args.distributed:
_logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.rank, args.world_size)) % (args.rank, args.world_size))
else: else:
_logger.info('Training with a single process on %d GPUs.' % args.num_gpu) _logger.info('Training with a single process on 1 GPUs.')
assert args.rank >= 0
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
if args.amp:
# for backwards compat, `--amp` arg tries apex before native amp
if has_apex:
args.apex_amp = True
elif has_native_amp:
args.native_amp = True
if args.apex_amp and has_apex:
use_amp = 'apex'
elif args.native_amp and has_native_amp:
use_amp = 'native'
elif args.apex_amp or args.native_amp:
_logger.warning("Neither APEX or native Torch AMP is available, using float32. "
"Install NVIDA apex or upgrade to PyTorch 1.6")
torch.manual_seed(args.seed + args.rank) torch.manual_seed(args.seed + args.rank)
@ -327,44 +335,44 @@ def main():
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
# setup augmentation batch splits for contrastive loss or split bn
num_aug_splits = 0 num_aug_splits = 0
if args.aug_splits > 0: if args.aug_splits > 0:
assert args.aug_splits > 1, 'A split of 1 makes no sense' assert args.aug_splits > 1, 'A split of 1 makes no sense'
num_aug_splits = args.aug_splits num_aug_splits = args.aug_splits
# enable split bn (separate bn stats per batch-portion)
if args.split_bn: if args.split_bn:
assert num_aug_splits > 1 or args.resplit assert num_aug_splits > 1 or args.resplit
model = convert_splitbn_model(model, max(num_aug_splits, 2)) model = convert_splitbn_model(model, max(num_aug_splits, 2))
use_amp = None # move model to GPU, enable channels last layout if set
if args.amp: model.cuda()
# for backwards compat, `--amp` arg tries apex before native amp if args.channels_last:
if has_apex: model = model.to(memory_format=torch.channels_last)
args.apex_amp = True
elif has_native_amp:
args.native_amp = True
if args.apex_amp and has_apex:
use_amp = 'apex'
elif args.native_amp and has_native_amp:
use_amp = 'native'
elif args.apex_amp or args.native_amp:
_logger.warning("Neither APEX or native Torch AMP is available, using float32. "
"Install NVIDA apex or upgrade to PyTorch 1.6")
if args.num_gpu > 1: # setup synchronized BatchNorm for distributed training
if use_amp == 'apex': if args.distributed and args.sync_bn:
_logger.warning( assert not args.split_bn
'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.') if has_apex and use_amp != 'native':
use_amp = None # Apex SyncBN preferred unless native amp is activated
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() model = convert_syncbn_model(model)
assert not args.channels_last, "Channels last not supported with DP, use DDP." else:
else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model.cuda() if args.local_rank == 0:
if args.channels_last: _logger.info(
model = model.to(memory_format=torch.channels_last) 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
if args.torchscript:
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
# FIXME I ran into a bug w/ AMP + torchscript + Linear layers
model = torch.jit.script(model)
optimizer = create_optimizer(args, model) optimizer = create_optimizer(args, model)
# setup automatic mixed-precision (AMP) loss scaling and op casting
amp_autocast = suppress # do nothing amp_autocast = suppress # do nothing
loss_scaler = None loss_scaler = None
if use_amp == 'apex': if use_amp == 'apex':
@ -390,30 +398,17 @@ def main():
loss_scaler=None if args.no_resume_opt else loss_scaler, loss_scaler=None if args.no_resume_opt else loss_scaler,
log_info=args.local_rank == 0) log_info=args.local_rank == 0)
# setup exponential moving average of model weights, SWA could be used here too
model_ema = None model_ema = None
if args.model_ema: if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
model_ema = ModelEma( model_ema = ModelEmaV2(
model, model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
decay=args.model_ema_decay, if args.resume:
device='cpu' if args.model_ema_force_cpu else '', load_checkpoint(model_ema.module, args.resume, use_ema=True)
resume=args.resume)
# setup distributed training
if args.distributed: if args.distributed:
if args.sync_bn:
assert not args.split_bn
try:
if has_apex and use_amp != 'native':
# Apex SyncBN preferred unless native amp is activated
model = convert_syncbn_model(model)
else:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if args.local_rank == 0:
_logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
except Exception as e:
_logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
if has_apex and use_amp != 'native': if has_apex and use_amp != 'native':
# Apex DDP preferred unless native amp is activated # Apex DDP preferred unless native amp is activated
if args.local_rank == 0: if args.local_rank == 0:
@ -425,6 +420,7 @@ def main():
model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1
# NOTE: EMA model does not need to be wrapped by DDP # NOTE: EMA model does not need to be wrapped by DDP
# setup learning rate schedule and starting epoch
lr_scheduler, num_epochs = create_scheduler(args, optimizer) lr_scheduler, num_epochs = create_scheduler(args, optimizer)
start_epoch = 0 start_epoch = 0
if args.start_epoch is not None: if args.start_epoch is not None:
@ -438,12 +434,22 @@ def main():
if args.local_rank == 0: if args.local_rank == 0:
_logger.info('Scheduled epochs: {}'.format(num_epochs)) _logger.info('Scheduled epochs: {}'.format(num_epochs))
# create the train and eval datasets
train_dir = os.path.join(args.data, 'train') train_dir = os.path.join(args.data, 'train')
if not os.path.exists(train_dir): if not os.path.exists(train_dir):
_logger.error('Training folder does not exist at: {}'.format(train_dir)) _logger.error('Training folder does not exist at: {}'.format(train_dir))
exit(1) exit(1)
dataset_train = Dataset(train_dir) dataset_train = Dataset(train_dir)
eval_dir = os.path.join(args.data, 'val')
if not os.path.isdir(eval_dir):
eval_dir = os.path.join(args.data, 'validation')
if not os.path.isdir(eval_dir):
_logger.error('Validation folder does not exist at: {}'.format(eval_dir))
exit(1)
dataset_eval = Dataset(eval_dir)
# setup mixup / cutmix
collate_fn = None collate_fn = None
mixup_fn = None mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
@ -458,9 +464,11 @@ def main():
else: else:
mixup_fn = Mixup(**mixup_args) mixup_fn = Mixup(**mixup_args)
# wrap dataset in AugMix helper
if num_aug_splits > 1: if num_aug_splits > 1:
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
# create data loaders w/ augmentation pipeiine
train_interpolation = args.train_interpolation train_interpolation = args.train_interpolation
if args.no_aug or not train_interpolation: if args.no_aug or not train_interpolation:
train_interpolation = data_config['interpolation'] train_interpolation = data_config['interpolation']
@ -492,14 +500,6 @@ def main():
use_multi_epochs_loader=args.use_multi_epochs_loader use_multi_epochs_loader=args.use_multi_epochs_loader
) )
eval_dir = os.path.join(args.data, 'val')
if not os.path.isdir(eval_dir):
eval_dir = os.path.join(args.data, 'validation')
if not os.path.isdir(eval_dir):
_logger.error('Validation folder does not exist at: {}'.format(eval_dir))
exit(1)
dataset_eval = Dataset(eval_dir)
loader_eval = create_loader( loader_eval = create_loader(
dataset_eval, dataset_eval,
input_size=data_config['input_size'], input_size=data_config['input_size'],
@ -515,6 +515,7 @@ def main():
pin_memory=args.pin_mem, pin_memory=args.pin_mem,
) )
# setup loss function
if args.jsd: if args.jsd:
assert num_aug_splits > 1 # JSD only valid with aug splits set assert num_aug_splits > 1 # JSD only valid with aug splits set
train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
@ -527,6 +528,7 @@ def main():
train_loss_fn = nn.CrossEntropyLoss().cuda() train_loss_fn = nn.CrossEntropyLoss().cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda()
# setup checkpoint saver and eval metric tracking
eval_metric = args.eval_metric eval_metric = args.eval_metric
best_metric = None best_metric = None
best_epoch = None best_epoch = None
@ -638,11 +640,11 @@ def train_epoch(
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad) torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
optimizer.step() optimizer.step()
torch.cuda.synchronize()
if model_ema is not None: if model_ema is not None:
model_ema.update(model) model_ema.update(model)
num_updates += 1
torch.cuda.synchronize()
num_updates += 1
batch_time_m.update(time.time() - end) batch_time_m.update(time.time() - end)
if last_batch or batch_idx % args.log_interval == 0: if last_batch or batch_idx % args.log_interval == 0:
lrl = [param_group['lr'] for param_group in optimizer.param_groups] lrl = [param_group['lr'] for param_group in optimizer.param_groups]

Loading…
Cancel
Save