Merge pull request #233 from rwightman/torchamp

Native Torch AMP and channels_last support for train.py and validate.py
pull/237/head
Ross Wightman 4 years ago committed by GitHub
commit 5247eb37a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -120,6 +120,12 @@ if 'GITHUB_ACTIONS' not in os.environ:
in_chans = 3 if 'pruned' in model_name else 1 # pruning not currently supported with in_chans change
create_model(model_name, pretrained=True, in_chans=in_chans)
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(pretrained=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_features_pretrained(model_name, batch_size):
"""Create that pretrained weights load when features_only==True."""
create_model(model_name, pretrained=True, features_only=True)
EXCLUDE_JIT_FILTERS = [
'*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable

@ -106,20 +106,18 @@ class SqueezeExcite(nn.Module):
def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_):
super(SqueezeExcite, self).__init__()
self.gate_fn = gate_fn
reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
self.act1 = act_layer(inplace=True)
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
self.gate_fn = gate_fn
def forward(self, x):
x_se = self.avg_pool(x)
x_se = x.mean((2, 3), keepdim=True)
x_se = self.conv_reduce(x_se)
x_se = self.act1(x_se)
x_se = self.conv_expand(x_se)
x = x * self.gate_fn(x_se)
return x
return x * self.gate_fn(x_se)
class ConvBnAct(nn.Module):

@ -39,11 +39,6 @@ def create_model(
kwargs.pop('bn_momentum', None)
kwargs.pop('bn_eps', None)
# Parameters that aren't supported by all models should default to None in command line args,
# remove them if they are present and not set so that non-supporting models don't break.
if kwargs.get('drop_block_rate', None) is None:
kwargs.pop('drop_block_rate', None)
# handle backwards compat with drop_connect -> drop_path change
drop_connect_rate = kwargs.pop('drop_connect_rate', None)
if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None:
@ -51,8 +46,10 @@ def create_model(
" Setting drop_path to %f." % drop_connect_rate)
kwargs['drop_path_rate'] = drop_connect_rate
if kwargs.get('drop_path_rate', None) is None:
kwargs.pop('drop_path_rate', None)
# Parameters that aren't supported by all models or are intended to only override model defaults if set
# should default to None in command line args/cfg. Remove them if they are present and not set so that
# non-supporting models don't break and default args remain in effect.
kwargs = {k: v for k, v in kwargs.items() if v is not None}
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
if is_model(model_name):

@ -48,30 +48,41 @@ def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
model.load_state_dict(state_dict, strict=strict)
def resume_checkpoint(model, checkpoint_path):
other_state = {}
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
resume_epoch = None
if os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
if log_info:
_logger.info('Restoring model state from checkpoint...')
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
name = k[7:] if k.startswith('module') else k
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
if 'optimizer' in checkpoint:
other_state['optimizer'] = checkpoint['optimizer']
if 'amp' in checkpoint:
other_state['amp'] = checkpoint['amp']
if optimizer is not None and 'optimizer' in checkpoint:
if log_info:
_logger.info('Restoring optimizer state from checkpoint...')
optimizer.load_state_dict(checkpoint['optimizer'])
if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
if log_info:
_logger.info('Restoring AMP loss scaler state from checkpoint...')
loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
if 'epoch' in checkpoint:
resume_epoch = checkpoint['epoch']
if 'version' in checkpoint and checkpoint['version'] > 1:
resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
_logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
if log_info:
_logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
else:
model.load_state_dict(checkpoint)
_logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
return other_state, resume_epoch
if log_info:
_logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
return resume_epoch
else:
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()

@ -773,12 +773,14 @@ class HighResolutionNetFeatures(HighResolutionNet):
def _create_hrnet(variant, pretrained, **model_kwargs):
model_cls = HighResolutionNet
strict = True
if model_kwargs.pop('features_only', False):
model_cls = HighResolutionNetFeatures
strict = False
return build_model_with_cfg(
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
model_cfg=cfg_cls[variant], **model_kwargs)
model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs)
@register_model

@ -49,6 +49,15 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
return x
class FastAdaptiveAvgPool2d(nn.Module):
def __init__(self, flatten=False):
super(FastAdaptiveAvgPool2d, self).__init__()
self.flatten = flatten
def forward(self, x):
return x.mean((2, 3)) if self.flatten else x.mean((2, 3), keepdim=True)
class AdaptiveAvgMaxPool2d(nn.Module):
def __init__(self, output_size=1):
super(AdaptiveAvgMaxPool2d, self).__init__()
@ -70,12 +79,16 @@ class AdaptiveCatAvgMaxPool2d(nn.Module):
class SelectAdaptivePool2d(nn.Module):
"""Selectable global pooling layer with dynamic input kernel size
"""
def __init__(self, output_size=1, pool_type='avg', flatten=False):
def __init__(self, output_size=1, pool_type='fast', flatten=False):
super(SelectAdaptivePool2d, self).__init__()
self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
self.flatten = flatten
if pool_type == '':
self.pool = nn.Identity() # pass through
elif pool_type == 'fast':
assert output_size == 1
self.pool = FastAdaptiveAvgPool2d(self.flatten)
self.flatten = False
elif pool_type == 'avg':
self.pool = nn.AdaptiveAvgPool2d(output_size)
elif pool_type == 'avgmax':

@ -10,6 +10,7 @@ Hacked together by / Copyright 2020 Ross Wightman
import torch
from torch import nn as nn
import torch.nn.functional as F
from .conv_bn_act import ConvBnAct
@ -18,15 +19,13 @@ class ChannelAttn(nn.Module):
"""
def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
super(ChannelAttn, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False)
self.act = act_layer(inplace=True)
self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False)
def forward(self, x):
x_avg = self.avg_pool(x)
x_max = self.max_pool(x)
x_avg = x.mean((2, 3), keepdim=True)
x_max = F.adaptive_max_pool2d(x, 1)
x_avg = self.fc2(self.act(self.fc1(x_avg)))
x_max = self.fc2(self.act(self.fc1(x_max)))
x_attn = x_avg + x_max
@ -40,7 +39,7 @@ class LightChannelAttn(ChannelAttn):
super(LightChannelAttn, self).__init__(channels, reduction)
def forward(self, x):
x_pool = 0.5 * self.avg_pool(x) + 0.5 * self.max_pool(x)
x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * F.adaptive_max_pool2d(x, 1)
x_attn = self.fc2(self.act(self.fc1(x_pool)))
return x * x_attn.sigmoid()

@ -52,22 +52,15 @@ class EcaModule(nn.Module):
def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
super(EcaModule, self).__init__()
assert kernel_size % 2 == 1
if channels is not None:
t = int(abs(math.log(channels, 2) + beta) / gamma)
kernel_size = max(t if t % 2 else t + 1, 3)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
def forward(self, x):
# Feature descriptor on the global spatial information
y = self.avg_pool(x)
# Reshape for convolution
y = y.view(x.shape[0], 1, -1)
# Two different branches of ECA module
y = x.mean((2, 3)).view(x.shape[0], 1, -1) # view for 1d conv
y = self.conv(y)
# Multi-scale information fusion
y = y.view(x.shape[0], -1, 1, 1).sigmoid()
return x * y.expand_as(x)
@ -95,30 +88,20 @@ class CecaModule(nn.Module):
def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
super(CecaModule, self).__init__()
assert kernel_size % 2 == 1
if channels is not None:
t = int(abs(math.log(channels, 2) + beta) / gamma)
kernel_size = max(t if t % 2 else t + 1, 3)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
#pytorch circular padding mode is buggy as of pytorch 1.4
#see https://github.com/pytorch/pytorch/pull/17240
#implement manual circular padding
# PyTorch circular padding mode is buggy as of pytorch 1.4
# see https://github.com/pytorch/pytorch/pull/17240
# implement manual circular padding
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
self.padding = (kernel_size - 1) // 2
def forward(self, x):
# Feature descriptor on the global spatial information
y = self.avg_pool(x)
y = x.mean((2, 3)).view(x.shape[0], 1, -1)
# Manually implement circular padding, F.pad does not seemed to be bugged
y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular')
# Two different branches of ECA module
y = F.pad(y, (self.padding, self.padding), mode='circular')
y = self.conv(y)
# Multi-scale information fusion
y = y.view(x.shape[0], -1, 1, 1).sigmoid()
return x * y.expand_as(x)

@ -1,40 +1,36 @@
from torch import nn as nn
from .create_act import get_act_fn
from .create_act import create_act_layer
class SEModule(nn.Module):
def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None,
gate_fn='sigmoid'):
gate_layer='sigmoid'):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
reduction_channels = reduction_channels or max(channels // reduction, min_channels)
self.fc1 = nn.Conv2d(
channels, reduction_channels, kernel_size=1, padding=0, bias=True)
self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
self.act = act_layer(inplace=True)
self.fc2 = nn.Conv2d(
reduction_channels, channels, kernel_size=1, padding=0, bias=True)
self.gate_fn = get_act_fn(gate_fn)
self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_se = self.avg_pool(x)
x_se = x.mean((2, 3), keepdim=True)
x_se = self.fc1(x_se)
x_se = self.act(x_se)
x_se = self.fc2(x_se)
return x * self.gate_fn(x_se)
return x * self.gate(x_se)
class EffectiveSEModule(nn.Module):
""" 'Effective Squeeze-Excitation
From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
"""
def __init__(self, channels, gate_fn='hard_sigmoid'):
def __init__(self, channels, gate_layer='hard_sigmoid'):
super(EffectiveSEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
self.gate_fn = get_act_fn(gate_fn)
self.gate = create_act_layer(gate_layer, inplace=True)
def forward(self, x):
x_se = self.avg_pool(x)
x_se = x.mean((2, 3), keepdim=True)
x_se = self.fc(x_se)
return x * self.gate_fn(x_se, inplace=True)
return x * self.gate(x_se)

@ -27,7 +27,6 @@ class SelectiveKernelAttn(nn.Module):
"""
super(SelectiveKernelAttn, self).__init__()
self.num_paths = num_paths
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
self.bn = norm_layer(attn_channels)
self.act = act_layer(inplace=True)
@ -35,8 +34,7 @@ class SelectiveKernelAttn(nn.Module):
def forward(self, x):
assert x.shape[1] == self.num_paths
x = torch.sum(x, dim=1)
x = self.pool(x)
x = x.sum(1).mean((2, 3), keepdim=True)
x = self.fc_reduce(x)
x = self.bn(x)
x = self.act(x)

@ -59,18 +59,15 @@ class SEWithNorm(nn.Module):
def __init__(self, channels, reduction=16, act_layer=nn.ReLU, divisor=1, reduction_channels=None,
gate_layer='sigmoid'):
super(SEWithNorm, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
reduction_channels = reduction_channels or make_divisible(channels // reduction, divisor=divisor)
self.fc1 = nn.Conv2d(
channels, reduction_channels, kernel_size=1, padding=0, bias=True)
self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
self.bn = nn.BatchNorm2d(reduction_channels)
self.act = act_layer(inplace=True)
self.fc2 = nn.Conv2d(
reduction_channels, channels, kernel_size=1, padding=0, bias=True)
self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_se = self.avg_pool(x)
x_se = x.mean((2, 3), keepdim=True)
x_se = self.fc1(x_se)
x_se = self.bn(x_se)
x_se = self.act(x_se)

@ -71,17 +71,14 @@ class SEModule(nn.Module):
def __init__(self, channels, reduction):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Conv2d(
channels, channels // reduction, kernel_size=1, padding=0)
self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1)
self.relu = nn.ReLU(inplace=True)
self.fc2 = nn.Conv2d(
channels // reduction, channels, kernel_size=1, padding=0)
self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
module_input = x
x = self.avg_pool(x)
x = x.mean((2, 3), keepdim=True)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)

@ -14,7 +14,7 @@ import torch.nn as nn
import torch.nn.functional as F
from .helpers import build_model_with_cfg
from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead
from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead, SEModule
from .registry import register_model
__all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl']
@ -49,41 +49,6 @@ default_cfgs = {
}
class FastGlobalAvgPool2d(nn.Module):
def __init__(self, flatten=False):
super(FastGlobalAvgPool2d, self).__init__()
self.flatten = flatten
def forward(self, x):
if self.flatten:
in_size = x.size()
return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
else:
return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
def feat_mult(self):
return 1
class FastSEModule(nn.Module):
def __init__(self, channels, reduction_channels, inplace=True):
super(FastSEModule, self).__init__()
self.avg_pool = FastGlobalAvgPool2d()
self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, padding=0, bias=True)
self.relu = nn.ReLU(inplace=inplace)
self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, padding=0, bias=True)
self.activation = nn.Sigmoid()
def forward(self, x):
x_se = self.avg_pool(x)
x_se2 = self.fc1(x_se)
x_se2 = self.relu(x_se2)
x_se = self.fc2(x_se2)
x_se = self.activation(x_se)
return x * x_se
def IABN2Float(module: nn.Module) -> nn.Module:
"""If `module` is IABN don't use half precision."""
if isinstance(module, InplaceAbn):
@ -120,8 +85,8 @@ class BasicBlock(nn.Module):
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
reduce_layer_planes = max(planes * self.expansion // 4, 64)
self.se = FastSEModule(planes * self.expansion, reduce_layer_planes) if use_se else None
reduction_chs = max(planes * self.expansion // 4, 64)
self.se = SEModule(planes * self.expansion, reduction_channels=reduction_chs) if use_se else None
def forward(self, x):
if self.downsample is not None:
@ -160,8 +125,8 @@ class Bottleneck(nn.Module):
conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3),
aa_layer(channels=planes, filt_size=3, stride=2))
reduce_layer_planes = max(planes * self.expansion // 8, 64)
self.se = FastSEModule(planes, reduce_layer_planes) if use_se else None
reduction_chs = max(planes * self.expansion // 8, 64)
self.se = SEModule(planes, reduction_channels=reduction_chs) if use_se else None
self.conv3 = conv2d_iabn(
planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity")
@ -190,7 +155,7 @@ class Bottleneck(nn.Module):
class TResNet(nn.Module):
def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False,
global_pool='avg', drop_rate=0.):
global_pool='fast', drop_rate=0.):
self.num_classes = num_classes
self.drop_rate = drop_rate
super(TResNet, self).__init__()
@ -273,7 +238,7 @@ class TResNet(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes, global_pool='fast'):
self.head = ClassifierHead(
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

@ -37,19 +37,67 @@ def unwrap_model(model):
return model.module if hasattr(model, 'module') else model
def get_state_dict(model):
return unwrap_model(model).state_dict()
def get_state_dict(model, unwrap_fn=unwrap_model):
return unwrap_fn(model).state_dict()
class ApexScaler:
state_dict_key = "amp"
def __call__(self, loss, optimizer):
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
def state_dict(self):
if 'state_dict' in amp.__dict__:
return amp.state_dict()
def load_state_dict(self, state_dict):
if 'load_state_dict' in amp.__dict__:
amp.load_state_dict(state_dict)
class NativeScaler:
state_dict_key = "amp_scaler"
def __init__(self):
self._scaler = torch.cuda.amp.GradScaler()
def __call__(self, loss, optimizer):
self._scaler.scale(loss).backward()
self._scaler.step(optimizer)
self._scaler.update()
def state_dict(self):
return self._scaler.state_dict()
def load_state_dict(self, state_dict):
self._scaler.load_state_dict(state_dict)
class CheckpointSaver:
def __init__(
self,
model,
optimizer,
args=None,
model_ema=None,
amp_scaler=None,
checkpoint_prefix='checkpoint',
recovery_prefix='recovery',
checkpoint_dir='',
recovery_dir='',
decreasing=False,
max_history=10):
max_history=10,
unwrap_fn=unwrap_model):
# objects to save state_dicts of
self.model = model
self.optimizer = optimizer
self.args = args
self.model_ema = model_ema
self.amp_scaler = amp_scaler
# state
self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness
@ -67,13 +115,14 @@ class CheckpointSaver:
self.decreasing = decreasing # a lower metric is better if True
self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs
self.max_history = max_history
self.unwrap_fn = unwrap_fn
assert self.max_history >= 1
def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False):
def save_checkpoint(self, epoch, metric=None):
assert epoch >= 0
tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
self._save(tmp_save_path, model, optimizer, args, epoch, model_ema, metric, use_amp)
self._save(tmp_save_path, epoch, metric)
if os.path.exists(last_save_path):
os.unlink(last_save_path) # required for Windows support.
os.rename(tmp_save_path, last_save_path)
@ -105,19 +154,21 @@ class CheckpointSaver:
return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False):
def _save(self, save_path, epoch, metric=None):
save_state = {
'epoch': epoch,
'arch': args.model,
'state_dict': get_state_dict(model),
'optimizer': optimizer.state_dict(),
'args': args,
'arch': type(self.model).__name__.lower(),
'state_dict': get_state_dict(self.model, self.unwrap_fn),
'optimizer': self.optimizer.state_dict(),
'version': 2, # version < 2 increments epoch before save
}
if use_amp and 'state_dict' in amp.__dict__:
save_state['amp'] = amp.state_dict()
if model_ema is not None:
save_state['state_dict_ema'] = get_state_dict(model_ema)
if self.args is not None:
save_state['arch'] = self.args.model
save_state['args'] = self.args
if self.amp_scaler is not None:
save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict()
if self.model_ema is not None:
save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn)
if metric is not None:
save_state['metric'] = metric
torch.save(save_state, save_path)
@ -136,11 +187,11 @@ class CheckpointSaver:
_logger.error("Exception '{}' while deleting checkpoint".format(e))
self.checkpoint_files = self.checkpoint_files[:delete_index]
def save_recovery(self, model, optimizer, args, epoch, model_ema=None, use_amp=False, batch_idx=0):
def save_recovery(self, epoch, batch_idx=0):
assert epoch >= 0
filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
save_path = os.path.join(self.recovery_dir, filename)
self._save(save_path, model, optimizer, args, epoch, model_ema, use_amp=use_amp)
self._save(save_path, epoch)
if os.path.exists(self.last_recovery_file):
try:
_logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
@ -334,3 +385,16 @@ def add_bool_arg(parser, name, default=False, help=''):
group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
parser.set_defaults(**{dest_name: default})
def set_jit_legacy():
""" Set JIT executor to legacy w/ support for op fusion
This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
in the JIT exectutor. These API are not supported so could change.
"""
#
assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._jit_override_can_fuse_on_gpu(True)
#torch._C._jit_set_texpr_fuser_enabled(True)

@ -18,15 +18,11 @@ import argparse
import time
import yaml
from datetime import datetime
from contextlib import suppress
try:
from apex import amp
from apex.parallel import DistributedDataParallel as DDP
from apex.parallel import convert_syncbn_model
has_apex = True
except ImportError:
from torch.nn.parallel import DistributedDataParallel as DDP
has_apex = False
import torch.nn as nn
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, resume_checkpoint, convert_splitbn_model
@ -34,15 +30,26 @@ from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
import torch
import torch.nn as nn
import torchvision.utils
try:
from apex import amp
from apex.parallel import DistributedDataParallel as ApexDDP
from apex.parallel import convert_syncbn_model
has_apex = True
except ImportError:
has_apex = False
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')
# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
@ -67,8 +74,8 @@ parser.add_argument('--no-resume-opt', action='store_true', default=False,
help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
help='number of label classes (default: 1000)')
parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--img-size', type=int, default=None, metavar='N',
help='Image patch size (default: None => model default)')
parser.add_argument('--crop-pct', default=None, type=float,
@ -218,7 +225,13 @@ parser.add_argument('--num-gpu', type=int, default=1,
parser.add_argument('--save-images', action='store_true', default=False,
help='save images of input bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA amp for mixed precision training')
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
@ -260,7 +273,8 @@ def main():
if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
if args.distributed and args.num_gpu > 1:
_logger.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
_logger.warning(
'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
args.num_gpu = 1
args.device = 'cuda:0'
@ -312,40 +326,59 @@ def main():
assert num_aug_splits > 1 or args.resplit
model = convert_splitbn_model(model, max(num_aug_splits, 2))
use_amp = None
if args.amp:
# for backwards compat, `--amp` arg tries apex before native amp
if has_apex:
args.apex_amp = True
elif has_native_amp:
args.native_amp = True
if args.apex_amp and has_apex:
use_amp = 'apex'
elif args.native_amp and has_native_amp:
use_amp = 'native'
elif args.apex_amp or args.native_amp:
_logger.warning("Neither APEX or native Torch AMP is available, using float32. "
"Install NVIDA apex or upgrade to PyTorch 1.6")
if args.num_gpu > 1:
if args.amp:
if use_amp == 'apex':
_logger.warning(
'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.')
args.amp = False
'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
use_amp = None
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
assert not args.channels_last, "Channels last not supported with DP, use DDP."
else:
model.cuda()
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
optimizer = create_optimizer(args, model)
use_amp = False
if has_apex and args.amp:
amp_autocast = suppress # do nothing
loss_scaler = None
if use_amp == 'apex':
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
use_amp = True
if args.local_rank == 0:
_logger.info('NVIDIA APEX {}. AMP {}.'.format(
'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
loss_scaler = ApexScaler()
if args.local_rank == 0:
_logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
elif use_amp == 'native':
amp_autocast = torch.cuda.amp.autocast
loss_scaler = NativeScaler()
if args.local_rank == 0:
_logger.info('Using native Torch AMP. Training in mixed precision.')
else:
if args.local_rank == 0:
_logger.info('AMP not enabled. Training in float32.')
# optionally resume from a checkpoint
resume_state = {}
resume_epoch = None
if args.resume:
resume_state, resume_epoch = resume_checkpoint(model, args.resume)
if resume_state and not args.no_resume_opt:
if 'optimizer' in resume_state:
if args.local_rank == 0:
_logger.info('Restoring Optimizer state from checkpoint')
optimizer.load_state_dict(resume_state['optimizer'])
if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
if args.local_rank == 0:
_logger.info('Restoring NVIDIA AMP state from checkpoint')
amp.load_state_dict(resume_state['amp'])
del resume_state
resume_epoch = resume_checkpoint(
model, args.resume,
optimizer=None if args.no_resume_opt else optimizer,
loss_scaler=None if args.no_resume_opt else loss_scaler,
log_info=args.local_rank == 0)
model_ema = None
if args.model_ema:
@ -360,7 +393,8 @@ def main():
if args.sync_bn:
assert not args.split_bn
try:
if has_apex:
if has_apex and use_amp != 'native':
# Apex SyncBN preferred unless native amp is activated
model = convert_syncbn_model(model)
else:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@ -370,12 +404,15 @@ def main():
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
except Exception as e:
_logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
if has_apex:
model = DDP(model, delay_allreduce=True)
if has_apex and use_amp != 'native':
# Apex DDP preferred unless native amp is activated
if args.local_rank == 0:
_logger.info("Using NVIDIA APEX DistributedDataParallel.")
model = ApexDDP(model, delay_allreduce=True)
else:
if args.local_rank == 0:
_logger.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
model = DDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1
_logger.info("Using native Torch DistributedDataParallel.")
model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1
# NOTE: EMA model does not need to be wrapped by DDP
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
@ -494,7 +531,9 @@ def main():
])
output_dir = get_outdir(output_base, 'train', exp_name)
decreasing = True if eval_metric == 'loss' else False
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
saver = CheckpointSaver(
model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing)
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text)
@ -506,21 +545,20 @@ def main():
train_metrics = train_epoch(
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
use_amp=use_amp, model_ema=model_ema, mixup_fn=mixup_fn)
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if args.local_rank == 0:
_logger.info("Distributing BatchNorm running means and vars")
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
if model_ema is not None and not args.model_ema_force_cpu:
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = validate(
model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
eval_metrics = ema_eval_metrics
if lr_scheduler is not None:
@ -534,9 +572,7 @@ def main():
if saver is not None:
# save proper checkpoint with eval metric
save_metric = eval_metrics[eval_metric]
best_metric, best_epoch = saver.save_checkpoint(
model, optimizer, args,
epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp)
best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
except KeyboardInterrupt:
pass
@ -546,7 +582,8 @@ def main():
def train_epoch(
epoch, model, loader, optimizer, loss_fn, args,
lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None, mixup_fn=None):
lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
loss_scaler=None, model_ema=None, mixup_fn=None):
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.prefetcher and loader.mixup_enabled:
@ -570,20 +607,22 @@ def train_epoch(
input, target = input.cuda(), target.cuda()
if mixup_fn is not None:
input, target = mixup_fn(input, target)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
output = model(input)
with amp_autocast():
output = model(input)
loss = loss_fn(output, target)
loss = loss_fn(output, target)
if not args.distributed:
losses_m.update(loss.item(), input.size(0))
optimizer.zero_grad()
if use_amp:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
if loss_scaler is not None:
loss_scaler(loss, optimizer)
else:
loss.backward()
optimizer.step()
optimizer.step()
torch.cuda.synchronize()
if model_ema is not None:
@ -626,8 +665,7 @@ def train_epoch(
if saver is not None and args.recovery_interval and (
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
saver.save_recovery(
model, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)
saver.save_recovery(epoch, batch_idx=batch_idx)
if lr_scheduler is not None:
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
@ -641,7 +679,7 @@ def train_epoch(
return OrderedDict([('loss', losses_m.avg)])
def validate(model, loader, loss_fn, args, log_suffix=''):
def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
batch_time_m = AverageMeter()
losses_m = AverageMeter()
top1_m = AverageMeter()
@ -657,8 +695,11 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
if not args.prefetcher:
input = input.cuda()
target = target.cuda()
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
output = model(input)
with amp_autocast():
output = model(input)
if isinstance(output, (tuple, list)):
output = output[0]

@ -17,16 +17,25 @@ import torch
import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from contextlib import suppress
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
has_apex = False
try:
from apex import amp
has_apex = True
except ImportError:
has_apex = False
pass
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')
@ -55,6 +64,8 @@ parser.add_argument('--num-classes', type=int, default=1000,
help='Number classes in dataset')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--log-freq', default=10, type=int,
metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@ -69,8 +80,14 @@ parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--amp', action='store_true', default=False,
help='Use AMP mixed precision')
help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
@ -87,23 +104,22 @@ parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
help='Valid label indices txt file for validation of partial label space')
def set_jit_legacy():
""" Set JIT executor to legacy w/ support for op fusion
This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
in the JIT exectutor. These API are not supported so could change.
"""
#
assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._jit_override_can_fuse_on_gpu(True)
#torch._C._jit_set_texpr_fuser_enabled(True)
def validate(args):
# might as well try to validate something
args.pretrained = args.pretrained or not args.checkpoint
args.prefetcher = not args.no_prefetcher
amp_autocast = suppress # do nothing
if args.amp:
if has_apex:
args.apex_amp = True
elif has_native_amp:
args.native_amp = True
else:
_logger.warning("Neither APEX or Native Torch AMP is available, using FP32.")
assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
if args.native_amp:
amp_autocast = torch.cuda.amp.autocast
if args.legacy_jit:
set_jit_legacy()
@ -113,6 +129,7 @@ def validate(args):
pretrained=args.pretrained,
num_classes=args.num_classes,
in_chans=3,
global_pool=args.gp,
scriptable=args.torchscript)
if args.checkpoint:
@ -128,10 +145,12 @@ def validate(args):
torch.jit.optimized_execution(True)
model = torch.jit.script(model)
if args.amp:
model = amp.initialize(model.cuda(), opt_level='O1')
else:
model = model.cuda()
model = model.cuda()
if args.apex_amp:
model = amp.initialize(model, opt_level='O1')
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
@ -178,17 +197,21 @@ def validate(args):
with torch.no_grad():
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
model(input)
end = time.time()
for batch_idx, (input, target) in enumerate(loader):
if args.no_prefetcher:
target = target.cuda()
input = input.cuda()
if args.fp16:
input = input.half()
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
# compute output
output = model(input)
with amp_autocast():
output = model(input)
if valid_labels is not None:
output = output[:, valid_labels]
loss = criterion(output, target)
@ -197,7 +220,7 @@ def validate(args):
real_labels.add_result(output)
# measure accuracy and record loss
acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(acc1.item(), input.size(0))
top5.update(acc5.item(), input.size(0))

Loading…
Cancel
Save