Merge pull request #297 from rwightman/ema_simplify

Simplified JIT-compatible EMA module. Fixes for SiLU export and torchscript training w/ Linear layer.
pull/302/head
Ross Wightman 4 years ago committed by GitHub
commit 9a25fdf3ad

@@ -121,7 +121,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
create_model(model_name, pretrained=True, in_chans=in_chans)
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(pretrained=True))
@pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=['vit_*']))
@pytest.mark.parametrize('batch_size', [1])
def test_model_features_pretrained(model_name, batch_size):
"""Create that pretrained weights load when features_only==True."""

@@ -34,7 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
from .features import FeatureInfo, FeatureHooks
from .helpers import build_model_with_cfg
from .helpers import build_model_with_cfg, default_cfg_for_features
from .layers import create_conv2d, create_classifier
from .registry import register_model
@@ -453,18 +453,20 @@ class EfficientNetFeatures(nn.Module):
def _create_effnet(model_kwargs, variant, pretrained=False):
features_only = False
model_cls = EfficientNet
if model_kwargs.pop('features_only', False):
load_strict = False
features_only = True
model_kwargs.pop('num_classes', 0)
model_kwargs.pop('num_features', 0)
model_kwargs.pop('head_conv', None)
model_cls = EfficientNetFeatures
else:
load_strict = True
model_cls = EfficientNet
return build_model_with_cfg(
model = build_model_with_cfg(
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
pretrained_strict=load_strict, **model_kwargs)
pretrained_strict=not features_only, **model_kwargs)
if features_only:
model.default_cfg = default_cfg_for_features(model.default_cfg)
return model
def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
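
With the factory reworked this way, feature-extraction variants get a usable default_cfg instead of inheriting classifier fields. A minimal sketch of the resulting behavior through the public create_model API (efficientnet_b0 chosen arbitrarily, shapes illustrative):

import torch
from timm import create_model

# features_only=True routes creation through EfficientNetFeatures w/ pretrained_strict=False
model = create_model('efficientnet_b0', features_only=True)
model.eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))   # list of feature maps, one per stage
print(model.feature_info.channels())             # channel count per returned map
print(model.default_cfg)                         # classifier fields stripped by default_cfg_for_features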

@@ -14,7 +14,7 @@ import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .layers import Conv2dSame
from .layers import Conv2dSame, Linear
_logger = logging.getLogger(__name__)
@@ -234,7 +234,7 @@ def adapt_model_from_string(parent_module, model_string):
if isinstance(old_module, nn.Linear):
# FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
num_features = state_dict[n + '.weight'][1]
new_fc = nn.Linear(
new_fc = Linear(
in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
set_layer(new_module, n, new_fc)
if hasattr(new_module, 'num_features'):
@@ -251,6 +251,15 @@ def adapt_model_from_file(parent_module, model_variant):
return adapt_model_from_string(parent_module, f.read().strip())
def default_cfg_for_features(default_cfg):
default_cfg = deepcopy(default_cfg)
# remove default pretrained cfg fields that don't have much relevance for feature backbone
to_remove = ('num_classes', 'crop_pct', 'classifier') # add default final pool size?
for tr in to_remove:
default_cfg.pop(tr, None)
return default_cfg
def build_model_with_cfg(
model_cls: Callable,
variant: str,
@@ -296,5 +305,6 @@ def build_model_with_cfg(
else:
assert False, f'Unknown feature class {feature_cls}'
model = feature_cls(model, **feature_cfg)
model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg
return model
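
default_cfg_for_features is deliberately minimal: deep-copy the pretrained cfg, drop the keys that only describe a classification head. Roughly, with a made-up cfg for illustration (the URL and field values are hypothetical):

from timm.models.helpers import default_cfg_for_features

cfg = dict(
    url='https://example.com/model.pth',  # hypothetical URL
    num_classes=1000, input_size=(3, 224, 224), pool_size=(7, 7),
    crop_pct=0.875, classifier='fc',
)
print(default_cfg_for_features(cfg))
# url / input_size / pool_size survive; num_classes, crop_pct, classifier are dropped
# (pool_size is kept for now, per the 'add default final pool size?' note above)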

@@ -17,7 +17,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .features import FeatureInfo
from .helpers import build_model_with_cfg
from .helpers import build_model_with_cfg, default_cfg_for_features
from .layers import create_classifier
from .registry import register_model
from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE
@@ -773,15 +773,17 @@ class HighResolutionNetFeatures(HighResolutionNet):
def _create_hrnet(variant, pretrained, **model_kwargs):
model_cls = HighResolutionNet
strict = True
features_only = False
if model_kwargs.pop('features_only', False):
model_cls = HighResolutionNetFeatures
model_kwargs['num_classes'] = 0
strict = False
return build_model_with_cfg(
features_only = True
model = build_model_with_cfg(
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs)
model_cfg=cfg_cls[variant], pretrained_strict=not features_only, **model_kwargs)
if features_only:
model.default_cfg = default_cfg_for_features(model.default_cfg)
return model
@register_model

@@ -10,7 +10,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg
from .registry import register_model
from .layers import trunc_normal_, create_classifier
from .layers import trunc_normal_, create_classifier, Linear
def _cfg(url='', **kwargs):
@@ -250,7 +250,7 @@ class InceptionAux(nn.Module):
self.conv0 = conv_block(in_channels, 128, kernel_size=1)
self.conv1 = conv_block(128, 768, kernel_size=5)
self.conv1.stddev = 0.01
self.fc = nn.Linear(768, num_classes)
self.fc = Linear(768, num_classes)
self.fc.stddev = 0.001
def forward(self, x):

@@ -18,6 +18,7 @@ from .eca import EcaModule, CecaModule
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .norm_act import BatchNormAct2d
from .padding import get_padding

@@ -119,3 +119,27 @@ class HardMish(nn.Module):
def forward(self, x):
return hard_mish(x, self.inplace)
class PReLU(nn.PReLU):
"""Applies PReLU (w/ dummy inplace arg)
"""
def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None:
super(PReLU, self).__init__(num_parameters=num_parameters, init=init)
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.prelu(input, self.weight)
def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
return F.gelu(x)
class GELU(nn.Module):
"""Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
"""
def __init__(self, inplace: bool = False):
super(GELU, self).__init__()
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.gelu(input)
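
These wrappers exist because timm model code instantiates activations uniformly as act_layer(inplace=True); nn.GELU accepts no inplace argument and would raise a TypeError, and F.prelu cannot serve as a stateless activation fn at all since it needs a learned weight (presumably why prelu is dropped from _ACT_FN_DEFAULT in the next hunk). A quick sketch of the contract, assuming the wrappers are imported from this module:

import torch
from timm.models.layers.activations import GELU, PReLU

act = GELU(inplace=True)       # nn.GELU(inplace=True) would raise TypeError
y = act(torch.randn(2, 8))     # numerically identical to F.gelu
pact = PReLU(inplace=True)     # inplace accepted and ignored; weight is still learned
z = pact(torch.randn(2, 8))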

@@ -6,6 +6,7 @@ from torch import nn as nn
from torch.nn import functional as F
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .linear import Linear
def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
@@ -21,7 +22,8 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False
elif use_conv:
fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True)
else:
fc = nn.Linear(num_pooled_features, num_classes, bias=True)
# NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue
fc = Linear(num_pooled_features, num_classes, bias=True)
return global_pool, fc
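
Because nearly every classification head in the library is built through create_classifier, swapping in the wrapper here applies the AMP + torchscript fix globally. A rough usage sketch with illustrative sizes:

import torch
from timm.models.layers import create_classifier

global_pool, fc = create_classifier(num_features=1280, num_classes=1000, pool_type='avg')
x = torch.randn(2, 1280, 7, 7)
logits = fc(global_pool(x).flatten(1))  # -> (2, 1000); fc is the cast-safe Linear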

@@ -19,10 +19,9 @@ _ACT_FN_DEFAULT = dict(
relu6=F.relu6,
leaky_relu=F.leaky_relu,
elu=F.elu,
prelu=F.prelu,
celu=F.celu,
selu=F.selu,
gelu=F.gelu,
gelu=gelu,
sigmoid=sigmoid,
tanh=tanh,
hard_sigmoid=hard_sigmoid,
@@ -56,10 +55,10 @@ _ACT_LAYER_DEFAULT = dict(
relu6=nn.ReLU6,
leaky_relu=nn.LeakyReLU,
elu=nn.ELU,
prelu=nn.PReLU,
prelu=PReLU,
celu=nn.CELU,
selu=nn.SELU,
gelu=nn.GELU,
gelu=GELU,
sigmoid=Sigmoid,
tanh=Tanh,
hard_sigmoid=HardSigmoid,
@@ -98,7 +97,10 @@ def get_act_fn(name='relu'):
# custom autograd, then fallback
if name in _ACT_FN_ME:
return _ACT_FN_ME[name]
if not is_no_jit():
if is_exportable() and name in ('silu', 'swish'):
# FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
return swish
if not (is_no_jit() or is_exportable()):
if name in _ACT_FN_JIT:
return _ACT_FN_JIT[name]
return _ACT_FN_DEFAULT[name]
@@ -114,7 +116,10 @@ def get_act_layer(name='relu'):
if not (is_no_jit() or is_exportable() or is_scriptable()):
if name in _ACT_LAYER_ME:
return _ACT_LAYER_ME[name]
if not is_no_jit():
if is_exportable() and name in ('silu', 'swish'):
# FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
return Swish
if not (is_no_jit() or is_exportable()):
if name in _ACT_LAYER_JIT:
return _ACT_LAYER_JIT[name]
return _ACT_LAYER_DEFAULT[name]
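
After this change the lookup order is: memory-efficient custom-autograd variants (skipped when scripting or exporting), then, for silu/swish under export, the hand-written Swish (since, per the FIXME, PyTorch's SiLU didn't ONNX-export at the time), then JIT variants (skipped when exporting), then the defaults. A sketch, assuming set_exportable is importable from timm.models.layers at this version:

from timm.models.layers import get_act_layer, set_exportable

set_exportable(True)
silu_cls = get_act_layer('silu')   # -> Swish, the ONNX-friendly fallback
set_exportable(False)
silu_cls = get_act_layer('silu')   # -> ME / JIT / native variant as usual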

@@ -0,0 +1,19 @@
""" Linear layer (alternate definition)
"""
import torch
import torch.nn.functional as F
from torch import nn as nn
class Linear(nn.Linear):
r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
if torch.jit.is_scripting():
bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
else:
return F.linear(input, self.weight, self.bias)
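
Background for the workaround: inside a scripted module, autocast's Python-level casting doesn't apply, so F.linear can see a float16 input against float32 weight/bias and torch.addmm raises a dtype mismatch. Casting to input.dtype only on the torch.jit.is_scripting() path leaves eager execution untouched. A rough sketch of the fixed path (needs a CUDA device; illustrative only):

import torch
from timm.models.layers import Linear

head = torch.jit.script(Linear(16, 4)).cuda()
with torch.cuda.amp.autocast():
    x = torch.randn(8, 16, device='cuda').half()  # stand-in for an fp16 autocast activation
    y = head(x)                                   # ok: weight/bias cast to x.dtype in forward
print(y.dtype)  # torch.float16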

@@ -17,8 +17,8 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
from .features import FeatureInfo, FeatureHooks
from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d, create_conv2d, get_act_fn, hard_sigmoid
from .helpers import build_model_with_cfg, default_cfg_for_features
from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid
from .registry import register_model
__all__ = ['MobileNetV3']
@@ -105,7 +105,7 @@ class MobileNetV3(nn.Module):
num_pooled_chs = head_chs * self.global_pool.feat_mult()
self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
self.act2 = act_layer(inplace=True)
self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
efficientnet_init_weights(self)
@@ -123,7 +123,7 @@ class MobileNetV3(nn.Module):
self.num_classes = num_classes
# cannot meaningfully change pooling of efficient head after creation
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.conv_stem(x)
@@ -201,19 +201,21 @@ class MobileNetV3Features(nn.Module):
def _create_mnv3(model_kwargs, variant, pretrained=False):
features_only = False
model_cls = MobileNetV3
if model_kwargs.pop('features_only', False):
load_strict = False
features_only = True
model_kwargs.pop('num_classes', 0)
model_kwargs.pop('num_features', 0)
model_kwargs.pop('head_conv', None)
model_kwargs.pop('head_bias', None)
model_cls = MobileNetV3Features
else:
load_strict = True
model_cls = MobileNetV3
return build_model_with_cfg(
model = build_model_with_cfg(
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
pretrained_strict=load_strict, **model_kwargs)
pretrained_strict=not features_only, **model_kwargs)
if features_only:
model.default_cfg = default_cfg_for_features(model.default_cfg)
return model
def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):

@@ -6,5 +6,5 @@ from .log import setup_default_logging, FormatterNoInfo
from .metrics import AverageMeter, accuracy
from .misc import natural_key, add_bool_arg
from .model import unwrap_model, get_state_dict
from .model_ema import ModelEma
from .model_ema import ModelEma, ModelEmaV2
from .summary import update_summary, get_outdir

@@ -7,13 +7,16 @@ from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn
_logger = logging.getLogger(__name__)
class ModelEma:
""" Model Exponential Moving Average
""" Model Exponential Moving Average (DEPRECATED)
Keep a moving average of everything in the model state_dict (parameters and buffers).
This version is deprecated; it does not work with scripted models and will be removed eventually.
This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
@@ -30,7 +33,6 @@ class ModelEma:
This class is sensitive to where it is initialized in the sequence of model init,
GPU assignment and distributed training wrappers.
I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
"""
def __init__(self, model, decay=0.9999, device='', resume=''):
# make a copy of the model for accumulating moving average of weights
@@ -75,3 +77,50 @@ class ModelEma:
if self.device:
model_v = model_v.to(device=self.device)
ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
class ModelEmaV2(nn.Module):
""" Model Exponential Moving Average V2
Keep a moving average of everything in the model state_dict (parameters and buffers).
V2 of this module is simpler; it does not match params/buffers by name but simply
iterates in order. It works with torchscript (JIT of full model).
This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
A smoothed version of the weights is necessary for some training schemes to perform well.
E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc. that use
RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 require EMA
smoothing of weights to match results. Pay attention to the decay constant you are using
relative to your update count per epoch.
To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
disable validation of the EMA weights. Validation will have to be done manually in a separate
process, or after the training stops converging.
This class is sensitive to where it is initialized in the sequence of model init,
GPU assignment and distributed training wrappers.
"""
def __init__(self, model, decay=0.9999, device=None):
super(ModelEmaV2, self).__init__()
# make a copy of the model for accumulating moving average of weights
self.module = deepcopy(model)
self.module.eval()
self.decay = decay
self.device = device # perform ema on different device from model if set
if self.device is not None:
self.module.to(device=device)
def _update(self, model, update_fn):
with torch.no_grad():
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
if self.device is not None:
model_v = model_v.to(device=self.device)
ema_v.copy_(update_fn(ema_v, model_v))
def update(self, model):
self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
def set(self, model):
self._update(model, update_fn=lambda e, m: m)
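
Typical wiring, matching how train.py uses it below: construct the EMA copy after the model is on its device (and after AMP/DP wrapping, before DDP), call update() once per optimizer step, and validate or checkpoint through .module. A self-contained toy loop:

import torch
import torch.nn as nn
from timm.utils import ModelEmaV2

model = nn.Linear(10, 2)                                # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model_ema = ModelEmaV2(model, decay=0.999)              # device='cpu' would keep the copy off-GPU

for _ in range(10):                                     # illustrative training steps
    x, y = torch.randn(4, 10), torch.randint(0, 2, (4,))
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model_ema.update(model)                             # one EMA step per optimizer step

# validate / save the smoothed weights via model_ema.module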

@@ -1 +1 @@
__version__ = '0.3.1'
__version__ = '0.3.2'

@@ -29,7 +29,7 @@ import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, resume_checkpoint, convert_splitbn_model
from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
@@ -230,8 +230,6 @@ parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
help='how many training processes to use (default: 4)')
parser.add_argument('--num-gpu', type=int, default=1,
help='Number of GPUS to use')
parser.add_argument('--save-images', action='store_true', default=False,
help='save images of input batches every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
@@ -255,6 +253,8 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model to torchscript for inference')
def _parse_args():
@@ -282,28 +282,36 @@ def main():
args.distributed = False
if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
if args.distributed and args.num_gpu > 1:
_logger.warning(
'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
args.num_gpu = 1
args.device = 'cuda:0'
args.world_size = 1
args.rank = 0 # global rank
if args.distributed:
args.num_gpu = 1
args.device = 'cuda:%d' % args.local_rank
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
assert args.rank >= 0
if args.distributed:
_logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.rank, args.world_size))
else:
_logger.info('Training with a single process on %d GPUs.' % args.num_gpu)
_logger.info('Training with a single process on 1 GPU.')
assert args.rank >= 0
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
if args.amp:
# for backwards compat, `--amp` arg tries apex before native amp
if has_apex:
args.apex_amp = True
elif has_native_amp:
args.native_amp = True
if args.apex_amp and has_apex:
use_amp = 'apex'
elif args.native_amp and has_native_amp:
use_amp = 'native'
elif args.apex_amp or args.native_amp:
_logger.warning("Neither APEX or native Torch AMP is available, using float32. "
"Install NVIDA apex or upgrade to PyTorch 1.6")
torch.manual_seed(args.seed + args.rank)
@@ -319,6 +327,7 @@ def main():
bn_tf=args.bn_tf,
bn_momentum=args.bn_momentum,
bn_eps=args.bn_eps,
scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint)
if args.local_rank == 0:
@@ -327,44 +336,43 @@ def main():
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
# setup augmentation batch splits for contrastive loss or split bn
num_aug_splits = 0
if args.aug_splits > 0:
assert args.aug_splits > 1, 'A split of 1 makes no sense'
num_aug_splits = args.aug_splits
# enable split bn (separate bn stats per batch-portion)
if args.split_bn:
assert num_aug_splits > 1 or args.resplit
model = convert_splitbn_model(model, max(num_aug_splits, 2))
use_amp = None
if args.amp:
# for backwards compat, `--amp` arg tries apex before native amp
if has_apex:
args.apex_amp = True
elif has_native_amp:
args.native_amp = True
if args.apex_amp and has_apex:
use_amp = 'apex'
elif args.native_amp and has_native_amp:
use_amp = 'native'
elif args.apex_amp or args.native_amp:
_logger.warning("Neither APEX or native Torch AMP is available, using float32. "
"Install NVIDA apex or upgrade to PyTorch 1.6")
# move model to GPU, enable channels last layout if set
model.cuda()
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
if args.num_gpu > 1:
if use_amp == 'apex':
_logger.warning(
'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
use_amp = None
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
assert not args.channels_last, "Channels last not supported with DP, use DDP."
else:
model.cuda()
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
# setup synchronized BatchNorm for distributed training
if args.distributed and args.sync_bn:
assert not args.split_bn
if has_apex and use_amp != 'native':
# Apex SyncBN preferred unless native amp is activated
model = convert_syncbn_model(model)
else:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if args.local_rank == 0:
_logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn is enabled.')
if args.torchscript:
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
model = torch.jit.script(model)
optimizer = create_optimizer(args, model)
# setup automatic mixed-precision (AMP) loss scaling and op casting
amp_autocast = suppress # do nothing
loss_scaler = None
if use_amp == 'apex':
@@ -390,30 +398,17 @@ def main():
loss_scaler=None if args.no_resume_opt else loss_scaler,
log_info=args.local_rank == 0)
# setup exponential moving average of model weights, SWA could be used here too
model_ema = None
if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
model_ema = ModelEma(
model,
decay=args.model_ema_decay,
device='cpu' if args.model_ema_force_cpu else '',
resume=args.resume)
model_ema = ModelEmaV2(
model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
if args.resume:
load_checkpoint(model_ema.module, args.resume, use_ema=True)
# setup distributed training
if args.distributed:
if args.sync_bn:
assert not args.split_bn
try:
if has_apex and use_amp != 'native':
# Apex SyncBN preferred unless native amp is activated
model = convert_syncbn_model(model)
else:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if args.local_rank == 0:
_logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn is enabled.')
except Exception as e:
_logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
if has_apex and use_amp != 'native':
# Apex DDP preferred unless native amp is activated
if args.local_rank == 0:
@@ -425,6 +420,7 @@ def main():
model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1
# NOTE: EMA model does not need to be wrapped by DDP
# setup learning rate schedule and starting epoch
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
start_epoch = 0
if args.start_epoch is not None:
@@ -438,12 +434,22 @@ def main():
if args.local_rank == 0:
_logger.info('Scheduled epochs: {}'.format(num_epochs))
# create the train and eval datasets
train_dir = os.path.join(args.data, 'train')
if not os.path.exists(train_dir):
_logger.error('Training folder does not exist at: {}'.format(train_dir))
exit(1)
dataset_train = Dataset(train_dir)
eval_dir = os.path.join(args.data, 'val')
if not os.path.isdir(eval_dir):
eval_dir = os.path.join(args.data, 'validation')
if not os.path.isdir(eval_dir):
_logger.error('Validation folder does not exist at: {}'.format(eval_dir))
exit(1)
dataset_eval = Dataset(eval_dir)
# setup mixup / cutmix
collate_fn = None
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
@@ -458,9 +464,11 @@ def main():
else:
mixup_fn = Mixup(**mixup_args)
# wrap dataset in AugMix helper
if num_aug_splits > 1:
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
# create data loaders w/ augmentation pipeline
train_interpolation = args.train_interpolation
if args.no_aug or not train_interpolation:
train_interpolation = data_config['interpolation']
@@ -492,14 +500,6 @@ def main():
use_multi_epochs_loader=args.use_multi_epochs_loader
)
eval_dir = os.path.join(args.data, 'val')
if not os.path.isdir(eval_dir):
eval_dir = os.path.join(args.data, 'validation')
if not os.path.isdir(eval_dir):
_logger.error('Validation folder does not exist at: {}'.format(eval_dir))
exit(1)
dataset_eval = Dataset(eval_dir)
loader_eval = create_loader(
dataset_eval,
input_size=data_config['input_size'],
@@ -515,6 +515,7 @@ def main():
pin_memory=args.pin_mem,
)
# setup loss function
if args.jsd:
assert num_aug_splits > 1 # JSD only valid with aug splits set
train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
@@ -527,6 +528,7 @@ def main():
train_loss_fn = nn.CrossEntropyLoss().cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
# setup checkpoint saver and eval metric tracking
eval_metric = args.eval_metric
best_metric = None
best_epoch = None
@@ -568,7 +570,7 @@ def main():
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = validate(
model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
eval_metrics = ema_eval_metrics
if lr_scheduler is not None:
@@ -638,11 +640,11 @@ def train_epoch(
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
optimizer.step()
torch.cuda.synchronize()
if model_ema is not None:
model_ema.update(model)
num_updates += 1
torch.cuda.synchronize()
num_updates += 1
batch_time_m.update(time.time() - end)
if last_batch or batch_idx % args.log_interval == 0:
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
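
Taken together, the train.py changes pin down a setup order for scripted training: layer config is fixed at creation via scriptable=args.torchscript, the model is scripted after the CUDA move (with APEX AMP and SyncBN asserted off), and the optimizer and EMA are built from the scripted module. A condensed sketch of that flow, with argument plumbing elided:

import torch
from timm import create_model
from timm.utils import ModelEmaV2

model = create_model('resnet50', scriptable=True)   # layer config set before layers are built
model.cuda()
model = torch.jit.script(model)                     # must precede optimizer / EMA creation
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model_ema = ModelEmaV2(model, decay=0.9999)         # V2 copies state in order, so scripted models work
# train loop: loss.backward(); optimizer.step(); model_ema.update(model)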
