diff --git a/tests/test_models.py b/tests/test_models.py index db8efbf3..a62625d9 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -121,7 +121,7 @@ if 'GITHUB_ACTIONS' not in os.environ: create_model(model_name, pretrained=True, in_chans=in_chans) @pytest.mark.timeout(120) - @pytest.mark.parametrize('model_name', list_models(pretrained=True)) + @pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=['vit_*'])) @pytest.mark.parametrize('batch_size', [1]) def test_model_features_pretrained(model_name, batch_size): """Create that pretrained weights load when features_only==True.""" diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index a61a6f47..4a89590b 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -34,7 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights from .features import FeatureInfo, FeatureHooks -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, default_cfg_for_features from .layers import create_conv2d, create_classifier from .registry import register_model @@ -453,18 +453,20 @@ class EfficientNetFeatures(nn.Module): def _create_effnet(model_kwargs, variant, pretrained=False): + features_only = False + model_cls = EfficientNet if model_kwargs.pop('features_only', False): - load_strict = False + features_only = True model_kwargs.pop('num_classes', 0) model_kwargs.pop('num_features', 0) model_kwargs.pop('head_conv', None) model_cls = EfficientNetFeatures - else: - load_strict = True - model_cls = EfficientNet - return build_model_with_cfg( + model = build_model_with_cfg( model_cls, variant, pretrained, default_cfg=default_cfgs[variant], - pretrained_strict=load_strict, **model_kwargs) + pretrained_strict=not features_only, **model_kwargs) + if features_only: + model.default_cfg = default_cfg_for_features(model.default_cfg) + return model def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): diff --git a/timm/models/helpers.py b/timm/models/helpers.py index b90ce1db..77b98dc6 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -14,7 +14,7 @@ import torch.nn as nn import torch.utils.model_zoo as model_zoo from .features import FeatureListNet, FeatureDictNet, FeatureHookNet -from .layers import Conv2dSame +from .layers import Conv2dSame, Linear _logger = logging.getLogger(__name__) @@ -234,7 +234,7 @@ def adapt_model_from_string(parent_module, model_string): if isinstance(old_module, nn.Linear): # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? num_features = state_dict[n + '.weight'][1] - new_fc = nn.Linear( + new_fc = Linear( in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) set_layer(new_module, n, new_fc) if hasattr(new_module, 'num_features'): @@ -251,6 +251,15 @@ def adapt_model_from_file(parent_module, model_variant): return adapt_model_from_string(parent_module, f.read().strip()) +def default_cfg_for_features(default_cfg): + default_cfg = deepcopy(default_cfg) + # remove default pretrained cfg fields that don't have much relevance for feature backbone + to_remove = ('num_classes', 'crop_pct', 'classifier') # add default final pool size? + for tr in to_remove: + default_cfg.pop(tr, None) + return default_cfg + + def build_model_with_cfg( model_cls: Callable, variant: str, @@ -296,5 +305,6 @@ def build_model_with_cfg( else: assert False, f'Unknown feature class {feature_cls}' model = feature_cls(model, **feature_cfg) + model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg return model diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 2e8757b5..1c0bc9f0 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .features import FeatureInfo -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, default_cfg_for_features from .layers import create_classifier from .registry import register_model from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE @@ -773,15 +773,17 @@ class HighResolutionNetFeatures(HighResolutionNet): def _create_hrnet(variant, pretrained, **model_kwargs): model_cls = HighResolutionNet - strict = True + features_only = False if model_kwargs.pop('features_only', False): model_cls = HighResolutionNetFeatures model_kwargs['num_classes'] = 0 - strict = False - - return build_model_with_cfg( + features_only = True + model = build_model_with_cfg( model_cls, variant, pretrained, default_cfg=default_cfgs[variant], - model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs) + model_cfg=cfg_cls[variant], pretrained_strict=not features_only, **model_kwargs) + if features_only: + model.default_cfg = default_cfg_for_features(model.default_cfg) + return model @register_model diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index aee1cccc..9ae7105f 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from .helpers import build_model_with_cfg from .registry import register_model -from .layers import trunc_normal_, create_classifier +from .layers import trunc_normal_, create_classifier, Linear def _cfg(url='', **kwargs): @@ -250,7 +250,7 @@ class InceptionAux(nn.Module): self.conv0 = conv_block(in_channels, 128, kernel_size=1) self.conv1 = conv_block(128, 768, kernel_size=5) self.conv1.stddev = 0.01 - self.fc = nn.Linear(768, num_classes) + self.fc = Linear(768, num_classes) self.fc.stddev = 0.001 def forward(self, x): diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index a252b8c1..dac1beb8 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -18,6 +18,7 @@ from .eca import EcaModule, CecaModule from .evo_norm import EvoNormBatch2d, EvoNormSample2d from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple from .inplace_abn import InplaceAbn +from .linear import Linear from .mixed_conv2d import MixedConv2d from .norm_act import BatchNormAct2d from .padding import get_padding diff --git a/timm/models/layers/activations.py b/timm/models/layers/activations.py index edb2074f..e16b3bd3 100644 --- a/timm/models/layers/activations.py +++ b/timm/models/layers/activations.py @@ -119,3 +119,27 @@ class HardMish(nn.Module): def forward(self, x): return hard_mish(x, self.inplace) + + +class PReLU(nn.PReLU): + """Applies PReLU (w/ dummy inplace arg) + """ + def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None: + super(PReLU, self).__init__(num_parameters=num_parameters, init=init) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.prelu(input, self.weight) + + +def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: + return F.gelu(x) + + +class GELU(nn.Module): + """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg) + """ + def __init__(self, inplace: bool = False): + super(GELU, self).__init__() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.gelu(input) diff --git a/timm/models/layers/classifier.py b/timm/models/layers/classifier.py index e9194f05..89fe5458 100644 --- a/timm/models/layers/classifier.py +++ b/timm/models/layers/classifier.py @@ -6,6 +6,7 @@ from torch import nn as nn from torch.nn import functional as F from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .linear import Linear def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): @@ -21,7 +22,8 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False elif use_conv: fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True) else: - fc = nn.Linear(num_pooled_features, num_classes, bias=True) + # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue + fc = Linear(num_pooled_features, num_classes, bias=True) return global_pool, fc diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py index 6f2ab83e..426c3681 100644 --- a/timm/models/layers/create_act.py +++ b/timm/models/layers/create_act.py @@ -19,10 +19,9 @@ _ACT_FN_DEFAULT = dict( relu6=F.relu6, leaky_relu=F.leaky_relu, elu=F.elu, - prelu=F.prelu, celu=F.celu, selu=F.selu, - gelu=F.gelu, + gelu=gelu, sigmoid=sigmoid, tanh=tanh, hard_sigmoid=hard_sigmoid, @@ -56,10 +55,10 @@ _ACT_LAYER_DEFAULT = dict( relu6=nn.ReLU6, leaky_relu=nn.LeakyReLU, elu=nn.ELU, - prelu=nn.PReLU, + prelu=PReLU, celu=nn.CELU, selu=nn.SELU, - gelu=nn.GELU, + gelu=GELU, sigmoid=Sigmoid, tanh=Tanh, hard_sigmoid=HardSigmoid, @@ -98,7 +97,10 @@ def get_act_fn(name='relu'): # custom autograd, then fallback if name in _ACT_FN_ME: return _ACT_FN_ME[name] - if not is_no_jit(): + if is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return swish + if not (is_no_jit() or is_exportable()): if name in _ACT_FN_JIT: return _ACT_FN_JIT[name] return _ACT_FN_DEFAULT[name] @@ -114,7 +116,10 @@ def get_act_layer(name='relu'): if not (is_no_jit() or is_exportable() or is_scriptable()): if name in _ACT_LAYER_ME: return _ACT_LAYER_ME[name] - if not is_no_jit(): + if is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return Swish + if not (is_no_jit() or is_exportable()): if name in _ACT_LAYER_JIT: return _ACT_LAYER_JIT[name] return _ACT_LAYER_DEFAULT[name] diff --git a/timm/models/layers/linear.py b/timm/models/layers/linear.py new file mode 100644 index 00000000..38fe3380 --- /dev/null +++ b/timm/models/layers/linear.py @@ -0,0 +1,19 @@ +""" Linear layer (alternate definition) +""" +import torch +import torch.nn.functional as F +from torch import nn as nn + + +class Linear(nn.Linear): + r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` + + Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting + weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. + """ + def forward(self, input: torch.Tensor) -> torch.Tensor: + if torch.jit.is_scripting(): + bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None + return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) + else: + return F.linear(input, self.weight, self.bias) diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index e20b6d34..8a48ce72 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -17,8 +17,8 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights from .features import FeatureInfo, FeatureHooks -from .helpers import build_model_with_cfg -from .layers import SelectAdaptivePool2d, create_conv2d, get_act_fn, hard_sigmoid +from .helpers import build_model_with_cfg, default_cfg_for_features +from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid from .registry import register_model __all__ = ['MobileNetV3'] @@ -105,7 +105,7 @@ class MobileNetV3(nn.Module): num_pooled_chs = head_chs * self.global_pool.feat_mult() self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias) self.act2 = act_layer(inplace=True) - self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() efficientnet_init_weights(self) @@ -123,7 +123,7 @@ class MobileNetV3(nn.Module): self.num_classes = num_classes # cannot meaningfully change pooling of efficient head after creation self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.conv_stem(x) @@ -201,19 +201,21 @@ class MobileNetV3Features(nn.Module): def _create_mnv3(model_kwargs, variant, pretrained=False): + features_only = False + model_cls = MobileNetV3 if model_kwargs.pop('features_only', False): - load_strict = False + features_only = True model_kwargs.pop('num_classes', 0) model_kwargs.pop('num_features', 0) model_kwargs.pop('head_conv', None) model_kwargs.pop('head_bias', None) model_cls = MobileNetV3Features - else: - load_strict = True - model_cls = MobileNetV3 - return build_model_with_cfg( + model = build_model_with_cfg( model_cls, variant, pretrained, default_cfg=default_cfgs[variant], - pretrained_strict=load_strict, **model_kwargs) + pretrained_strict=not features_only, **model_kwargs) + if features_only: + model.default_cfg = default_cfg_for_features(model.default_cfg) + return model def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs): diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index 6efc2115..0f7c4b05 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -6,5 +6,5 @@ from .log import setup_default_logging, FormatterNoInfo from .metrics import AverageMeter, accuracy from .misc import natural_key, add_bool_arg from .model import unwrap_model, get_state_dict -from .model_ema import ModelEma +from .model_ema import ModelEma, ModelEmaV2 from .summary import update_summary, get_outdir diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py index b788b32e..073d5c5e 100644 --- a/timm/utils/model_ema.py +++ b/timm/utils/model_ema.py @@ -7,13 +7,16 @@ from collections import OrderedDict from copy import deepcopy import torch +import torch.nn as nn _logger = logging.getLogger(__name__) class ModelEma: - """ Model Exponential Moving Average + """ Model Exponential Moving Average (DEPRECATED) + Keep a moving average of everything in the model state_dict (parameters and buffers). + This version is deprecated, it does not work with scripted models. Will be removed eventually. This is intended to allow functionality like https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage @@ -30,7 +33,6 @@ class ModelEma: This class is sensitive where it is initialized in the sequence of model init, GPU assignment and distributed training wrappers. - I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU. """ def __init__(self, model, decay=0.9999, device='', resume=''): # make a copy of the model for accumulating moving average of weights @@ -75,3 +77,50 @@ class ModelEma: if self.device: model_v = model_v.to(device=self.device) ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v) + + +class ModelEmaV2(nn.Module): + """ Model Exponential Moving Average V2 + + Keep a moving average of everything in the model state_dict (parameters and buffers). + V2 of this module is simpler, it does not match params/buffers based on name but simply + iterates in order. It works with torchscript (JIT of full model). + + This is intended to allow functionality like + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + + A smoothed version of the weights is necessary for some training schemes to perform well. + E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use + RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA + smoothing of weights to match results. Pay attention to the decay constant you are using + relative to your update count per epoch. + + To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but + disable validation of the EMA weights. Validation will have to be done manually in a separate + process, or after the training stops converging. + + This class is sensitive where it is initialized in the sequence of model init, + GPU assignment and distributed training wrappers. + """ + def __init__(self, model, decay=0.9999, device=None): + super(ModelEmaV2, self).__init__() + # make a copy of the model for accumulating moving average of weights + self.module = deepcopy(model) + self.module.eval() + self.decay = decay + self.device = device # perform ema on different device from model if set + if self.device is not None: + self.module.to(device=device) + + def _update(self, model, update_fn): + with torch.no_grad(): + for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): + if self.device is not None: + model_v = model_v.to(device=self.device) + ema_v.copy_(update_fn(ema_v, model_v)) + + def update(self, model): + self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) + + def set(self, model): + self._update(model, update_fn=lambda e, m: m) diff --git a/timm/version.py b/timm/version.py index e1424ed0..73e3bb4f 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.3.1' +__version__ = '0.3.2' diff --git a/train.py b/train.py index ef3adf85..7a93a1b6 100755 --- a/train.py +++ b/train.py @@ -29,7 +29,7 @@ import torchvision.utils from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset -from timm.models import create_model, resume_checkpoint, convert_splitbn_model +from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.optim import create_optimizer @@ -230,8 +230,6 @@ parser.add_argument('--recovery-interval', type=int, default=0, metavar='N', help='how many batches to wait before writing recovery checkpoint') parser.add_argument('-j', '--workers', type=int, default=4, metavar='N', help='how many training processes to use (default: 1)') -parser.add_argument('--num-gpu', type=int, default=1, - help='Number of GPUS to use') parser.add_argument('--save-images', action='store_true', default=False, help='save images of input bathes every log interval for debugging') parser.add_argument('--amp', action='store_true', default=False, @@ -255,6 +253,8 @@ parser.add_argument('--tta', type=int, default=0, metavar='N', parser.add_argument("--local_rank", default=0, type=int) parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False, help='use the multi-epochs-loader to save time at the beginning of every epoch') +parser.add_argument('--torchscript', dest='torchscript', action='store_true', + help='convert model torchscript for inference') def _parse_args(): @@ -282,28 +282,36 @@ def main(): args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 - if args.distributed and args.num_gpu > 1: - _logger.warning( - 'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.') - args.num_gpu = 1 - args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: - args.num_gpu = 1 args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() - assert args.rank >= 0 - - if args.distributed: _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: - _logger.info('Training with a single process on %d GPUs.' % args.num_gpu) + _logger.info('Training with a single process on 1 GPUs.') + assert args.rank >= 0 + + # resolve AMP arguments based on PyTorch / Apex availability + use_amp = None + if args.amp: + # for backwards compat, `--amp` arg tries apex before native amp + if has_apex: + args.apex_amp = True + elif has_native_amp: + args.native_amp = True + if args.apex_amp and has_apex: + use_amp = 'apex' + elif args.native_amp and has_native_amp: + use_amp = 'native' + elif args.apex_amp or args.native_amp: + _logger.warning("Neither APEX or native Torch AMP is available, using float32. " + "Install NVIDA apex or upgrade to PyTorch 1.6") torch.manual_seed(args.seed + args.rank) @@ -319,6 +327,7 @@ def main(): bn_tf=args.bn_tf, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, + scriptable=args.torchscript, checkpoint_path=args.initial_checkpoint) if args.local_rank == 0: @@ -327,44 +336,43 @@ def main(): data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) + # setup augmentation batch splits for contrastive loss or split bn num_aug_splits = 0 if args.aug_splits > 0: assert args.aug_splits > 1, 'A split of 1 makes no sense' num_aug_splits = args.aug_splits + # enable split bn (separate bn stats per batch-portion) if args.split_bn: assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) - use_amp = None - if args.amp: - # for backwards compat, `--amp` arg tries apex before native amp - if has_apex: - args.apex_amp = True - elif has_native_amp: - args.native_amp = True - if args.apex_amp and has_apex: - use_amp = 'apex' - elif args.native_amp and has_native_amp: - use_amp = 'native' - elif args.apex_amp or args.native_amp: - _logger.warning("Neither APEX or native Torch AMP is available, using float32. " - "Install NVIDA apex or upgrade to PyTorch 1.6") + # move model to GPU, enable channels last layout if set + model.cuda() + if args.channels_last: + model = model.to(memory_format=torch.channels_last) - if args.num_gpu > 1: - if use_amp == 'apex': - _logger.warning( - 'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.') - use_amp = None - model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() - assert not args.channels_last, "Channels last not supported with DP, use DDP." - else: - model.cuda() - if args.channels_last: - model = model.to(memory_format=torch.channels_last) + # setup synchronized BatchNorm for distributed training + if args.distributed and args.sync_bn: + assert not args.split_bn + if has_apex and use_amp != 'native': + # Apex SyncBN preferred unless native amp is activated + model = convert_syncbn_model(model) + else: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + if args.local_rank == 0: + _logger.info( + 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' + 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') + + if args.torchscript: + assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' + assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' + model = torch.jit.script(model) optimizer = create_optimizer(args, model) + # setup automatic mixed-precision (AMP) loss scaling and op casting amp_autocast = suppress # do nothing loss_scaler = None if use_amp == 'apex': @@ -390,30 +398,17 @@ def main(): loss_scaler=None if args.no_resume_opt else loss_scaler, log_info=args.local_rank == 0) + # setup exponential moving average of model weights, SWA could be used here too model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper - model_ema = ModelEma( - model, - decay=args.model_ema_decay, - device='cpu' if args.model_ema_force_cpu else '', - resume=args.resume) + model_ema = ModelEmaV2( + model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) + if args.resume: + load_checkpoint(model_ema.module, args.resume, use_ema=True) + # setup distributed training if args.distributed: - if args.sync_bn: - assert not args.split_bn - try: - if has_apex and use_amp != 'native': - # Apex SyncBN preferred unless native amp is activated - model = convert_syncbn_model(model) - else: - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - if args.local_rank == 0: - _logger.info( - 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' - 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') - except Exception as e: - _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1') if has_apex and use_amp != 'native': # Apex DDP preferred unless native amp is activated if args.local_rank == 0: @@ -425,6 +420,7 @@ def main(): model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP + # setup learning rate schedule and starting epoch lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: @@ -438,12 +434,22 @@ def main(): if args.local_rank == 0: _logger.info('Scheduled epochs: {}'.format(num_epochs)) + # create the train and eval datasets train_dir = os.path.join(args.data, 'train') if not os.path.exists(train_dir): _logger.error('Training folder does not exist at: {}'.format(train_dir)) exit(1) dataset_train = Dataset(train_dir) + eval_dir = os.path.join(args.data, 'val') + if not os.path.isdir(eval_dir): + eval_dir = os.path.join(args.data, 'validation') + if not os.path.isdir(eval_dir): + _logger.error('Validation folder does not exist at: {}'.format(eval_dir)) + exit(1) + dataset_eval = Dataset(eval_dir) + + # setup mixup / cutmix collate_fn = None mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None @@ -458,9 +464,11 @@ def main(): else: mixup_fn = Mixup(**mixup_args) + # wrap dataset in AugMix helper if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) + # create data loaders w/ augmentation pipeiine train_interpolation = args.train_interpolation if args.no_aug or not train_interpolation: train_interpolation = data_config['interpolation'] @@ -492,14 +500,6 @@ def main(): use_multi_epochs_loader=args.use_multi_epochs_loader ) - eval_dir = os.path.join(args.data, 'val') - if not os.path.isdir(eval_dir): - eval_dir = os.path.join(args.data, 'validation') - if not os.path.isdir(eval_dir): - _logger.error('Validation folder does not exist at: {}'.format(eval_dir)) - exit(1) - dataset_eval = Dataset(eval_dir) - loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], @@ -515,6 +515,7 @@ def main(): pin_memory=args.pin_mem, ) + # setup loss function if args.jsd: assert num_aug_splits > 1 # JSD only valid with aug splits set train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() @@ -527,6 +528,7 @@ def main(): train_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() + # setup checkpoint saver and eval metric tracking eval_metric = args.eval_metric best_metric = None best_epoch = None @@ -568,7 +570,7 @@ def main(): if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate( - model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') + model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if lr_scheduler is not None: @@ -638,11 +640,11 @@ def train_epoch( torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad) optimizer.step() - torch.cuda.synchronize() if model_ema is not None: model_ema.update(model) - num_updates += 1 + torch.cuda.synchronize() + num_updates += 1 batch_time_m.update(time.time() - end) if last_batch or batch_idx % args.log_interval == 0: lrl = [param_group['lr'] for param_group in optimizer.param_groups]