Merge pull request #297 from rwightman/ema_simplify

Simplified JIT-compatible EMA module (ModelEmaV2). Fixes for SiLU ONNX export and for torchscript training w/ the Linear layer.
Ross Wightman 4 years ago committed by GitHub
commit 9a25fdf3ad

@@ -121,7 +121,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
     create_model(model_name, pretrained=True, in_chans=in_chans)
 
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(pretrained=True))
+@pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=['vit_*']))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_features_pretrained(model_name, batch_size):
     """Create that pretrained weights load when features_only==True."""

@@ -34,7 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
 from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
 from .features import FeatureInfo, FeatureHooks
-from .helpers import build_model_with_cfg
+from .helpers import build_model_with_cfg, default_cfg_for_features
 from .layers import create_conv2d, create_classifier
 from .registry import register_model
@@ -453,18 +453,20 @@ class EfficientNetFeatures(nn.Module):
 
 def _create_effnet(model_kwargs, variant, pretrained=False):
+    features_only = False
+    model_cls = EfficientNet
     if model_kwargs.pop('features_only', False):
-        load_strict = False
+        features_only = True
         model_kwargs.pop('num_classes', 0)
         model_kwargs.pop('num_features', 0)
         model_kwargs.pop('head_conv', None)
         model_cls = EfficientNetFeatures
-    else:
-        load_strict = True
-        model_cls = EfficientNet
-    return build_model_with_cfg(
+    model = build_model_with_cfg(
         model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
-        pretrained_strict=load_strict, **model_kwargs)
+        pretrained_strict=not features_only, **model_kwargs)
+    if features_only:
+        model.default_cfg = default_cfg_for_features(model.default_cfg)
+    return model
 
 
 def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
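For context, the features_only path above is what a caller reaches via timm's create_model. A minimal sketch of the resulting object (standard timm API assumed; printed values are illustrative):

    import timm

    # routes through _create_effnet with model_cls=EfficientNetFeatures and pretrained_strict=False
    model = timm.create_model('efficientnet_b0', features_only=True, pretrained=False)
    print(type(model).__name__)               # EfficientNetFeatures
    print(model.feature_info.channels())      # per-stage channel counts, e.g. [16, 24, 40, 112, 320]
    print('classifier' in model.default_cfg)  # False: stripped by default_cfg_for_features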

@@ -14,7 +14,7 @@ import torch.nn as nn
 import torch.utils.model_zoo as model_zoo
 
 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
-from .layers import Conv2dSame
+from .layers import Conv2dSame, Linear
 
 _logger = logging.getLogger(__name__)
@@ -234,7 +234,7 @@ def adapt_model_from_string(parent_module, model_string):
         if isinstance(old_module, nn.Linear):
             # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
             num_features = state_dict[n + '.weight'][1]
-            new_fc = nn.Linear(
+            new_fc = Linear(
                 in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
             set_layer(new_module, n, new_fc)
             if hasattr(new_module, 'num_features'):
@@ -251,6 +251,15 @@ def adapt_model_from_file(parent_module, model_variant):
         return adapt_model_from_string(parent_module, f.read().strip())
 
 
+def default_cfg_for_features(default_cfg):
+    default_cfg = deepcopy(default_cfg)
+    # remove default pretrained cfg fields that don't have much relevance for feature backbone
+    to_remove = ('num_classes', 'crop_pct', 'classifier')  # add default final pool size?
+    for tr in to_remove:
+        default_cfg.pop(tr, None)
+    return default_cfg
+
+
 def build_model_with_cfg(
         model_cls: Callable,
         variant: str,
@@ -296,5 +305,6 @@ def build_model_with_cfg(
         else:
             assert False, f'Unknown feature class {feature_cls}'
         model = feature_cls(model, **feature_cfg)
+        model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg
 
     return model
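A rough illustration of what the new helper does (the cfg values below are made up, not a real pretrained config):

    cfg = dict(url='', num_classes=1000, input_size=(3, 224, 224), crop_pct=0.875,
               classifier='classifier', first_conv='conv_stem')
    feat_cfg = default_cfg_for_features(cfg)
    # feat_cfg keeps 'url', 'input_size' and 'first_conv', but 'num_classes', 'crop_pct'
    # and 'classifier' are gone; the original cfg is untouched thanks to the deepcopy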

@@ -17,7 +17,7 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .features import FeatureInfo
-from .helpers import build_model_with_cfg
+from .helpers import build_model_with_cfg, default_cfg_for_features
 from .layers import create_classifier
 from .registry import register_model
 from .resnet import BasicBlock, Bottleneck  # leveraging ResNet blocks w/ additional features like SE
@@ -773,15 +773,17 @@ class HighResolutionNetFeatures(HighResolutionNet):
 
 def _create_hrnet(variant, pretrained, **model_kwargs):
     model_cls = HighResolutionNet
-    strict = True
+    features_only = False
     if model_kwargs.pop('features_only', False):
         model_cls = HighResolutionNetFeatures
         model_kwargs['num_classes'] = 0
-        strict = False
-
-    return build_model_with_cfg(
+        features_only = True
+    model = build_model_with_cfg(
         model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
-        model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs)
+        model_cfg=cfg_cls[variant], pretrained_strict=not features_only, **model_kwargs)
+    if features_only:
+        model.default_cfg = default_cfg_for_features(model.default_cfg)
+    return model
 
 
 @register_model

@@ -10,7 +10,7 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from .helpers import build_model_with_cfg
 from .registry import register_model
-from .layers import trunc_normal_, create_classifier
+from .layers import trunc_normal_, create_classifier, Linear
 
 
 def _cfg(url='', **kwargs):
@@ -250,7 +250,7 @@ class InceptionAux(nn.Module):
         self.conv0 = conv_block(in_channels, 128, kernel_size=1)
         self.conv1 = conv_block(128, 768, kernel_size=5)
         self.conv1.stddev = 0.01
-        self.fc = nn.Linear(768, num_classes)
+        self.fc = Linear(768, num_classes)
         self.fc.stddev = 0.001
 
     def forward(self, x):

@@ -18,6 +18,7 @@ from .eca import EcaModule, CecaModule
 from .evo_norm import EvoNormBatch2d, EvoNormSample2d
 from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple
 from .inplace_abn import InplaceAbn
+from .linear import Linear
 from .mixed_conv2d import MixedConv2d
 from .norm_act import BatchNormAct2d
 from .padding import get_padding

@@ -119,3 +119,27 @@ class HardMish(nn.Module):
 
     def forward(self, x):
         return hard_mish(x, self.inplace)
+
+
+class PReLU(nn.PReLU):
+    """Applies PReLU (w/ dummy inplace arg)
+    """
+    def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None:
+        super(PReLU, self).__init__(num_parameters=num_parameters, init=init)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return F.prelu(input, self.weight)
+
+
+def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
+    return F.gelu(x)
+
+
+class GELU(nn.Module):
+    """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
+    """
+    def __init__(self, inplace: bool = False):
+        super(GELU, self).__init__()
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return F.gelu(input)
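These wrappers exist so factory code can pass inplace= uniformly; nn.GELU and nn.PReLU do not accept that argument. A quick sketch of the difference, using timm's get_act_layer factory from create_act.py (patched further below), with PyTorch 1.7-era APIs assumed:

    act_layer = get_act_layer('gelu')   # resolves to the timm GELU wrapper after this change
    act = act_layer(inplace=True)       # OK: the dummy inplace arg is accepted and ignored
    # nn.GELU(inplace=True)             # TypeError: __init__() got an unexpected keyword argument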

@@ -6,6 +6,7 @@ from torch import nn as nn
 from torch.nn import functional as F
 
 from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from .linear import Linear
 
 
 def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
@@ -21,7 +22,8 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False
     elif use_conv:
         fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True)
     else:
-        fc = nn.Linear(num_pooled_features, num_classes, bias=True)
+        # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue
+        fc = Linear(num_pooled_features, num_classes, bias=True)
     return global_pool, fc

@@ -19,10 +19,9 @@ _ACT_FN_DEFAULT = dict(
     relu6=F.relu6,
     leaky_relu=F.leaky_relu,
     elu=F.elu,
-    prelu=F.prelu,
     celu=F.celu,
     selu=F.selu,
-    gelu=F.gelu,
+    gelu=gelu,
     sigmoid=sigmoid,
     tanh=tanh,
     hard_sigmoid=hard_sigmoid,
@@ -56,10 +55,10 @@ _ACT_LAYER_DEFAULT = dict(
     relu6=nn.ReLU6,
     leaky_relu=nn.LeakyReLU,
     elu=nn.ELU,
-    prelu=nn.PReLU,
+    prelu=PReLU,
     celu=nn.CELU,
     selu=nn.SELU,
-    gelu=nn.GELU,
+    gelu=GELU,
     sigmoid=Sigmoid,
     tanh=Tanh,
     hard_sigmoid=HardSigmoid,
@@ -98,7 +97,10 @@ def get_act_fn(name='relu'):
         # custom autograd, then fallback
         if name in _ACT_FN_ME:
             return _ACT_FN_ME[name]
-    if not is_no_jit():
+    if is_exportable() and name in ('silu', 'swish'):
+        # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
+        return swish
+    if not (is_no_jit() or is_exportable()):
         if name in _ACT_FN_JIT:
             return _ACT_FN_JIT[name]
     return _ACT_FN_DEFAULT[name]
@@ -114,7 +116,10 @@ def get_act_layer(name='relu'):
     if not (is_no_jit() or is_exportable() or is_scriptable()):
         if name in _ACT_LAYER_ME:
             return _ACT_LAYER_ME[name]
-    if not is_no_jit():
+    if is_exportable() and name in ('silu', 'swish'):
+        # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
+        return Swish
+    if not (is_no_jit() or is_exportable()):
         if name in _ACT_LAYER_JIT:
             return _ACT_LAYER_JIT[name]
     return _ACT_LAYER_DEFAULT[name]
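The net effect of the export hack: with export mode enabled, the swish/silu lookups fall back to the plain autograd implementation that traces cleanly to ONNX. A minimal sketch (assuming get_act_layer and set_exportable are exposed from timm.models.layers as in this version of the library):

    from timm.models.layers import get_act_layer, set_exportable

    act_layer = get_act_layer('swish')   # normally a memory-efficient / jit Swish variant
    set_exportable(True)
    act_layer = get_act_layer('swish')   # now the plain Swish module, safe for ONNX export
    set_exportable(False)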

@@ -0,0 +1,19 @@
+""" Linear layer (alternate definition)
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn as nn
+
+
+class Linear(nn.Linear):
+    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
+
+    Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
+    weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
+    """
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if torch.jit.is_scripting():
+            bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
+            return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
+        else:
+            return F.linear(input, self.weight, self.bias)
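The situation this wrapper addresses: under torch.cuda.amp autocast, a scripted classifier head can receive a float16 input while its parameters are still float32, and autocast cannot insert the cast inside scripted code, so torch.addmm raises a dtype mismatch. A minimal sketch of the failure mode and the fix (CUDA assumed; exact error text varies by PyTorch version):

    import torch
    from timm.models.layers import Linear  # the wrapper above, re-exported from timm.models.layers

    fc = Linear(8, 4).cuda()
    scripted = torch.jit.script(fc)
    x = torch.randn(2, 8, device='cuda').half()  # e.g. output of an earlier autocast-ed layer
    y = scripted(x)  # works: weight/bias are cast to float16 inside the scripted forward
    # torch.jit.script(torch.nn.Linear(8, 4).cuda())(x) would fail with a Half/Float addmm error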

@@ -17,8 +17,8 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
 from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
 from .features import FeatureInfo, FeatureHooks
-from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d, create_conv2d, get_act_fn, hard_sigmoid
+from .helpers import build_model_with_cfg, default_cfg_for_features
+from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid
 from .registry import register_model
 
 __all__ = ['MobileNetV3']
@@ -105,7 +105,7 @@ class MobileNetV3(nn.Module):
         num_pooled_chs = head_chs * self.global_pool.feat_mult()
         self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
         self.act2 = act_layer(inplace=True)
-        self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
         efficientnet_init_weights(self)
@@ -123,7 +123,7 @@ class MobileNetV3(nn.Module):
         self.num_classes = num_classes
         # cannot meaningfully change pooling of efficient head after creation
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_features(self, x):
         x = self.conv_stem(x)
@@ -201,19 +201,21 @@ class MobileNetV3Features(nn.Module):
 
 def _create_mnv3(model_kwargs, variant, pretrained=False):
+    features_only = False
+    model_cls = MobileNetV3
     if model_kwargs.pop('features_only', False):
-        load_strict = False
+        features_only = True
         model_kwargs.pop('num_classes', 0)
         model_kwargs.pop('num_features', 0)
        model_kwargs.pop('head_conv', None)
         model_kwargs.pop('head_bias', None)
         model_cls = MobileNetV3Features
-    else:
-        load_strict = True
-        model_cls = MobileNetV3
-    return build_model_with_cfg(
+    model = build_model_with_cfg(
         model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
-        pretrained_strict=load_strict, **model_kwargs)
+        pretrained_strict=not features_only, **model_kwargs)
+    if features_only:
+        model.default_cfg = default_cfg_for_features(model.default_cfg)
+    return model
 
 
 def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):

@@ -6,5 +6,5 @@ from .log import setup_default_logging, FormatterNoInfo
 from .metrics import AverageMeter, accuracy
 from .misc import natural_key, add_bool_arg
 from .model import unwrap_model, get_state_dict
-from .model_ema import ModelEma
+from .model_ema import ModelEma, ModelEmaV2
 from .summary import update_summary, get_outdir

@@ -7,13 +7,16 @@ from collections import OrderedDict
 from copy import deepcopy
 
 import torch
+import torch.nn as nn
 
 _logger = logging.getLogger(__name__)
 
 
 class ModelEma:
-    """ Model Exponential Moving Average
+    """ Model Exponential Moving Average (DEPRECATED)
+
     Keep a moving average of everything in the model state_dict (parameters and buffers).
+    This version is deprecated, it does not work with scripted models. Will be removed eventually.
 
     This is intended to allow functionality like
     https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
@@ -30,7 +33,6 @@ class ModelEma:
 
     This class is sensitive where it is initialized in the sequence of model init,
     GPU assignment and distributed training wrappers.
-    I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
     """
     def __init__(self, model, decay=0.9999, device='', resume=''):
         # make a copy of the model for accumulating moving average of weights
@@ -75,3 +77,50 @@ class ModelEma:
                 if self.device:
                     model_v = model_v.to(device=self.device)
                 ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
+
+
+class ModelEmaV2(nn.Module):
+    """ Model Exponential Moving Average V2
+
+    Keep a moving average of everything in the model state_dict (parameters and buffers).
+    V2 of this module is simpler, it does not match params/buffers based on name but simply
+    iterates in order. It works with torchscript (JIT of full model).
+
+    This is intended to allow functionality like
+    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+
+    A smoothed version of the weights is necessary for some training schemes to perform well.
+    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
+    RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
+    smoothing of weights to match results. Pay attention to the decay constant you are using
+    relative to your update count per epoch.
+
+    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
+    disable validation of the EMA weights. Validation will have to be done manually in a separate
+    process, or after the training stops converging.
+
+    This class is sensitive where it is initialized in the sequence of model init,
+    GPU assignment and distributed training wrappers.
+    """
+    def __init__(self, model, decay=0.9999, device=None):
+        super(ModelEmaV2, self).__init__()
+        # make a copy of the model for accumulating moving average of weights
+        self.module = deepcopy(model)
+        self.module.eval()
+        self.decay = decay
+        self.device = device  # perform ema on different device from model if set
+        if self.device is not None:
+            self.module.to(device=device)
+
+    def _update(self, model, update_fn):
+        with torch.no_grad():
+            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+                if self.device is not None:
+                    model_v = model_v.to(device=self.device)
+                ema_v.copy_(update_fn(ema_v, model_v))
+
+    def update(self, model):
+        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
+
+    def set(self, model):
+        self._update(model, update_fn=lambda e, m: m)
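A minimal usage sketch of the new module alongside a scripted model (the toy network and loop are illustrative; they mirror how train.py below wires it up):

    import torch
    import torch.nn as nn
    from timm.utils import ModelEmaV2

    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1)).cuda()
    model = torch.jit.script(model)           # EMA now works on a scripted (JIT) model
    model_ema = ModelEmaV2(model, decay=0.999, device='cpu')  # optionally keep the EMA copy on CPU

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(5):
        x = torch.randn(8, 16, device='cuda')
        loss = model(x).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        model_ema.update(model)               # EMA of params and buffers, matched by iteration order

    # evaluate or checkpoint model_ema.module (the old API exposed model_ema.ema instead)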

@@ -1 +1 @@
-__version__ = '0.3.1'
+__version__ = '0.3.2'

@@ -29,7 +29,7 @@ import torchvision.utils
 from torch.nn.parallel import DistributedDataParallel as NativeDDP
 
 from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
-from timm.models import create_model, resume_checkpoint, convert_splitbn_model
+from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
 from timm.optim import create_optimizer
@@ -230,8 +230,6 @@ parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                     help='how many batches to wait before writing recovery checkpoint')
 parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                     help='how many training processes to use (default: 1)')
-parser.add_argument('--num-gpu', type=int, default=1,
-                    help='Number of GPUS to use')
 parser.add_argument('--save-images', action='store_true', default=False,
                     help='save images of input bathes every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,
@@ -255,6 +253,8 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
 parser.add_argument("--local_rank", default=0, type=int)
 parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                     help='use the multi-epochs-loader to save time at the beginning of every epoch')
+parser.add_argument('--torchscript', dest='torchscript', action='store_true',
+                    help='convert model torchscript for inference')
 
 
 def _parse_args():
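With the new flag, a scripted training run with EMA might be launched roughly like this (illustrative command; dataset path and model name are placeholders):

    ./distributed_train.sh 2 /data/imagenet --model efficientnet_b0 --torchscript --model-ema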
@@ -282,28 +282,36 @@ def main():
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
         args.distributed = int(os.environ['WORLD_SIZE']) > 1
-    if args.distributed and args.num_gpu > 1:
-        _logger.warning(
-            'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
-        args.num_gpu = 1
 
     args.device = 'cuda:0'
     args.world_size = 1
     args.rank = 0  # global rank
     if args.distributed:
-        args.num_gpu = 1
         args.device = 'cuda:%d' % args.local_rank
         torch.cuda.set_device(args.local_rank)
         torch.distributed.init_process_group(backend='nccl', init_method='env://')
         args.world_size = torch.distributed.get_world_size()
         args.rank = torch.distributed.get_rank()
-    assert args.rank >= 0
-
-    if args.distributed:
         _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                      % (args.rank, args.world_size))
     else:
-        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)
+        _logger.info('Training with a single process on 1 GPUs.')
+    assert args.rank >= 0
+
+    # resolve AMP arguments based on PyTorch / Apex availability
+    use_amp = None
+    if args.amp:
+        # for backwards compat, `--amp` arg tries apex before native amp
+        if has_apex:
+            args.apex_amp = True
+        elif has_native_amp:
+            args.native_amp = True
+    if args.apex_amp and has_apex:
+        use_amp = 'apex'
+    elif args.native_amp and has_native_amp:
+        use_amp = 'native'
+    elif args.apex_amp or args.native_amp:
+        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
+                        "Install NVIDA apex or upgrade to PyTorch 1.6")
 
     torch.manual_seed(args.seed + args.rank)
@@ -319,6 +327,7 @@ def main():
         bn_tf=args.bn_tf,
         bn_momentum=args.bn_momentum,
         bn_eps=args.bn_eps,
+        scriptable=args.torchscript,
         checkpoint_path=args.initial_checkpoint)
 
     if args.local_rank == 0:
@@ -327,44 +336,43 @@ def main():
     data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
 
+    # setup augmentation batch splits for contrastive loss or split bn
     num_aug_splits = 0
     if args.aug_splits > 0:
         assert args.aug_splits > 1, 'A split of 1 makes no sense'
         num_aug_splits = args.aug_splits
 
+    # enable split bn (separate bn stats per batch-portion)
     if args.split_bn:
         assert num_aug_splits > 1 or args.resplit
         model = convert_splitbn_model(model, max(num_aug_splits, 2))
 
-    use_amp = None
-    if args.amp:
-        # for backwards compat, `--amp` arg tries apex before native amp
-        if has_apex:
-            args.apex_amp = True
-        elif has_native_amp:
-            args.native_amp = True
-    if args.apex_amp and has_apex:
-        use_amp = 'apex'
-    elif args.native_amp and has_native_amp:
-        use_amp = 'native'
-    elif args.apex_amp or args.native_amp:
-        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
-                        "Install NVIDA apex or upgrade to PyTorch 1.6")
-
-    if args.num_gpu > 1:
-        if use_amp == 'apex':
-            _logger.warning(
-                'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
-            use_amp = None
-        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
-        assert not args.channels_last, "Channels last not supported with DP, use DDP."
-    else:
-        model.cuda()
-        if args.channels_last:
-            model = model.to(memory_format=torch.channels_last)
+    # move model to GPU, enable channels last layout if set
+    model.cuda()
+    if args.channels_last:
+        model = model.to(memory_format=torch.channels_last)
+
+    # setup synchronized BatchNorm for distributed training
+    if args.distributed and args.sync_bn:
+        assert not args.split_bn
+        if has_apex and use_amp != 'native':
+            # Apex SyncBN preferred unless native amp is activated
+            model = convert_syncbn_model(model)
+        else:
+            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+        if args.local_rank == 0:
+            _logger.info(
+                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
+                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
+
+    if args.torchscript:
+        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
+        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
+        model = torch.jit.script(model)
 
     optimizer = create_optimizer(args, model)
 
+    # setup automatic mixed-precision (AMP) loss scaling and op casting
     amp_autocast = suppress  # do nothing
     loss_scaler = None
     if use_amp == 'apex':
@@ -390,30 +398,17 @@ def main():
         loss_scaler=None if args.no_resume_opt else loss_scaler,
         log_info=args.local_rank == 0)
 
+    # setup exponential moving average of model weights, SWA could be used here too
     model_ema = None
     if args.model_ema:
         # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
-        model_ema = ModelEma(
-            model,
-            decay=args.model_ema_decay,
-            device='cpu' if args.model_ema_force_cpu else '',
-            resume=args.resume)
+        model_ema = ModelEmaV2(
+            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
+        if args.resume:
+            load_checkpoint(model_ema.module, args.resume, use_ema=True)
 
+    # setup distributed training
     if args.distributed:
-        if args.sync_bn:
-            assert not args.split_bn
-            try:
-                if has_apex and use_amp != 'native':
-                    # Apex SyncBN preferred unless native amp is activated
-                    model = convert_syncbn_model(model)
-                else:
-                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
-                if args.local_rank == 0:
-                    _logger.info(
-                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
-                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
-            except Exception as e:
-                _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
         if has_apex and use_amp != 'native':
             # Apex DDP preferred unless native amp is activated
             if args.local_rank == 0:
@@ -425,6 +420,7 @@ def main():
             model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
         # NOTE: EMA model does not need to be wrapped by DDP
 
+    # setup learning rate schedule and starting epoch
     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
     start_epoch = 0
     if args.start_epoch is not None:
@@ -438,12 +434,22 @@ def main():
     if args.local_rank == 0:
         _logger.info('Scheduled epochs: {}'.format(num_epochs))
 
+    # create the train and eval datasets
     train_dir = os.path.join(args.data, 'train')
     if not os.path.exists(train_dir):
         _logger.error('Training folder does not exist at: {}'.format(train_dir))
         exit(1)
     dataset_train = Dataset(train_dir)
 
+    eval_dir = os.path.join(args.data, 'val')
+    if not os.path.isdir(eval_dir):
+        eval_dir = os.path.join(args.data, 'validation')
+        if not os.path.isdir(eval_dir):
+            _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
+            exit(1)
+    dataset_eval = Dataset(eval_dir)
+
+    # setup mixup / cutmix
     collate_fn = None
     mixup_fn = None
     mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
@@ -458,9 +464,11 @@ def main():
         else:
             mixup_fn = Mixup(**mixup_args)
 
+    # wrap dataset in AugMix helper
     if num_aug_splits > 1:
         dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
 
+    # create data loaders w/ augmentation pipeiine
     train_interpolation = args.train_interpolation
     if args.no_aug or not train_interpolation:
         train_interpolation = data_config['interpolation']
@@ -492,14 +500,6 @@ def main():
         use_multi_epochs_loader=args.use_multi_epochs_loader
     )
 
-    eval_dir = os.path.join(args.data, 'val')
-    if not os.path.isdir(eval_dir):
-        eval_dir = os.path.join(args.data, 'validation')
-        if not os.path.isdir(eval_dir):
-            _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
-            exit(1)
-    dataset_eval = Dataset(eval_dir)
-
     loader_eval = create_loader(
         dataset_eval,
         input_size=data_config['input_size'],
@@ -515,6 +515,7 @@ def main():
         pin_memory=args.pin_mem,
     )
 
+    # setup loss function
     if args.jsd:
         assert num_aug_splits > 1  # JSD only valid with aug splits set
         train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
@@ -527,6 +528,7 @@ def main():
         train_loss_fn = nn.CrossEntropyLoss().cuda()
     validate_loss_fn = nn.CrossEntropyLoss().cuda()
 
+    # setup checkpoint saver and eval metric tracking
     eval_metric = args.eval_metric
     best_metric = None
     best_epoch = None
@@ -568,7 +570,7 @@ def main():
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
             ema_eval_metrics = validate(
-                model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
+                model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
             eval_metrics = ema_eval_metrics
 
         if lr_scheduler is not None:
@@ -638,11 +640,11 @@ def train_epoch(
                 torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
             optimizer.step()
 
-        torch.cuda.synchronize()
         if model_ema is not None:
             model_ema.update(model)
-        num_updates += 1
 
+        torch.cuda.synchronize()
+        num_updates += 1
         batch_time_m.update(time.time() - end)
         if last_batch or batch_idx % args.log_interval == 0:
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
