diff --git a/tests/test_models.py b/tests/test_models.py index dee4fbe7..3f1c4cda 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -83,7 +83,6 @@ def test_model_default_cfgs(model_name, batch_size): cfg = model.default_cfg classifier = cfg['classifier'] - first_conv = cfg['first_conv'] pool_size = cfg['pool_size'] input_size = model.default_cfg['input_size'] @@ -111,9 +110,16 @@ def test_model_default_cfgs(model_name, batch_size): # FIXME mobilenetv3 forward_features vs removed pooling differ assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2] - # check classifier and first convolution names match those in default_cfg + # check classifier name matches default_cfg assert classifier + ".weight" in state_dict.keys(), f'{classifier} not in model params' - assert first_conv + ".weight" in state_dict.keys(), f'{first_conv} not in model params' + + # check first conv(s) names match default_cfg + first_conv = cfg['first_conv'] + if isinstance(first_conv, str): + first_conv = (first_conv,) + assert isinstance(first_conv, (tuple, list)) + for fc in first_conv: + assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params' if 'GITHUB_ACTIONS' not in os.environ: diff --git a/timm/models/dpn.py b/timm/models/dpn.py index 045d634c..ac9c7755 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -7,6 +7,7 @@ This implementation is compatible with the pretrained weights from cypw's MXNet Hacked together by / Copyright 2020 Ross Wightman """ from collections import OrderedDict +from functools import partial from typing import Tuple import torch @@ -173,12 +174,14 @@ class DPN(nn.Module): self.drop_rate = drop_rate self.b = b assert output_stride == 32 # FIXME look into dilation support + norm_layer = partial(BatchNormAct2d, eps=.001) + fc_norm_layer = partial(BatchNormAct2d, eps=.001, act_layer=fc_act, inplace=False) bw_factor = 1 if small else 4 blocks = OrderedDict() # conv1 blocks['conv1_1'] = ConvBnAct( - in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_kwargs=dict(eps=.001)) + in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_layer=norm_layer) blocks['conv1_pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.feature_info = [dict(num_chs=num_init_features, reduction=2, module='features.conv1_1')] @@ -226,8 +229,7 @@ class DPN(nn.Module): in_chs += inc self.feature_info += [dict(num_chs=in_chs, reduction=32, module=f'features.conv5_{k_sec[3]}')] - def _fc_norm(f, eps): return BatchNormAct2d(f, eps=eps, act_layer=fc_act, inplace=False) - blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=_fc_norm) + blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=fc_norm_layer) self.num_features = in_chs self.features = nn.Sequential(blocks) diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py index 3782c500..8fc398d6 100644 --- a/timm/models/gluon_xception.py +++ b/timm/models/gluon_xception.py @@ -42,10 +42,8 @@ for Tensorflow 'SAME' padding. PyTorch symmetric padding behaves the way we'd w class SeparableConv2d(nn.Module): - def __init__(self, inplanes, planes, kernel_size=3, stride=1, - dilation=1, bias=False, norm_layer=None, norm_kwargs=None): + def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, norm_layer=None): super(SeparableConv2d, self).__init__() - norm_kwargs = norm_kwargs if norm_kwargs is not None else {} self.kernel_size = kernel_size self.dilation = dilation @@ -54,7 +52,7 @@ class SeparableConv2d(nn.Module): self.conv_dw = nn.Conv2d( inplanes, inplanes, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=inplanes, bias=bias) - self.bn = norm_layer(num_features=inplanes, **norm_kwargs) + self.bn = norm_layer(num_features=inplanes) # pointwise convolution self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias) @@ -66,10 +64,8 @@ class SeparableConv2d(nn.Module): class Block(nn.Module): - def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True, - norm_layer=None, norm_kwargs=None, ): + def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True, norm_layer=None): super(Block, self).__init__() - norm_kwargs = norm_kwargs if norm_kwargs is not None else {} if isinstance(planes, (list, tuple)): assert len(planes) == 3 else: @@ -80,7 +76,7 @@ class Block(nn.Module): self.skip = nn.Sequential() self.skip.add_module('conv1', nn.Conv2d( inplanes, outplanes, 1, stride=stride, bias=False)), - self.skip.add_module('bn1', norm_layer(num_features=outplanes, **norm_kwargs)) + self.skip.add_module('bn1', norm_layer(num_features=outplanes)) else: self.skip = None @@ -88,9 +84,8 @@ class Block(nn.Module): for i in range(3): rep['act%d' % (i + 1)] = nn.ReLU(inplace=True) rep['conv%d' % (i + 1)] = SeparableConv2d( - inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation, - norm_layer=norm_layer, norm_kwargs=norm_kwargs) - rep['bn%d' % (i + 1)] = norm_layer(planes[i], **norm_kwargs) + inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation, norm_layer=norm_layer) + rep['bn%d' % (i + 1)] = norm_layer(planes[i]) inplanes = planes[i] if not start_with_relu: @@ -115,74 +110,63 @@ class Xception65(nn.Module): """ def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d, - norm_kwargs=None, drop_rate=0., global_pool='avg'): + drop_rate=0., global_pool='avg'): super(Xception65, self).__init__() self.num_classes = num_classes self.drop_rate = drop_rate - norm_kwargs = norm_kwargs if norm_kwargs is not None else {} if output_stride == 32: entry_block3_stride = 2 exit_block20_stride = 2 - middle_block_dilation = 1 - exit_block_dilations = (1, 1) + middle_dilation = 1 + exit_dilation = (1, 1) elif output_stride == 16: entry_block3_stride = 2 exit_block20_stride = 1 - middle_block_dilation = 1 - exit_block_dilations = (1, 2) + middle_dilation = 1 + exit_dilation = (1, 2) elif output_stride == 8: entry_block3_stride = 1 exit_block20_stride = 1 - middle_block_dilation = 2 - exit_block_dilations = (2, 4) + middle_dilation = 2 + exit_dilation = (2, 4) else: raise NotImplementedError # Entry flow self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False) - self.bn1 = norm_layer(num_features=32, **norm_kwargs) + self.bn1 = norm_layer(num_features=32) self.act1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn2 = norm_layer(num_features=64) self.act2 = nn.ReLU(inplace=True) - self.block1 = Block( - 64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs) + self.block1 = Block(64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer) self.block1_act = nn.ReLU(inplace=True) - self.block2 = Block( - 128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs) - self.block3 = Block( - 256, 728, stride=entry_block3_stride, norm_layer=norm_layer, norm_kwargs=norm_kwargs) + self.block2 = Block(128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer) + self.block3 = Block(256, 728, stride=entry_block3_stride, norm_layer=norm_layer) # Middle flow self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block( - 728, 728, stride=1, dilation=middle_block_dilation, - norm_layer=norm_layer, norm_kwargs=norm_kwargs)) for i in range(4, 20)])) + 728, 728, stride=1, dilation=middle_dilation, norm_layer=norm_layer)) for i in range(4, 20)])) # Exit flow self.block20 = Block( - 728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_block_dilations[0], - norm_layer=norm_layer, norm_kwargs=norm_kwargs) + 728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_dilation[0], norm_layer=norm_layer) self.block20_act = nn.ReLU(inplace=True) - self.conv3 = SeparableConv2d( - 1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], - norm_layer=norm_layer, norm_kwargs=norm_kwargs) - self.bn3 = norm_layer(num_features=1536, **norm_kwargs) + self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer) + self.bn3 = norm_layer(num_features=1536) self.act3 = nn.ReLU(inplace=True) - self.conv4 = SeparableConv2d( - 1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], - norm_layer=norm_layer, norm_kwargs=norm_kwargs) - self.bn4 = norm_layer(num_features=1536, **norm_kwargs) + self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer) + self.bn4 = norm_layer(num_features=1536) self.act4 = nn.ReLU(inplace=True) self.num_features = 2048 self.conv5 = SeparableConv2d( - 1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1], - norm_layer=norm_layer, norm_kwargs=norm_kwargs) - self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs) + 1536, self.num_features, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer) + self.bn5 = norm_layer(num_features=self.num_features) self.act5 = nn.ReLU(inplace=True) self.feature_info = [ dict(num_chs=64, reduction=2, module='act2'), diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 562a01c5..d56cdc57 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -148,6 +148,31 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_ _logger.warning("Valid function to load pretrained weights is not available, using random initialization.") +def adapt_input_conv(in_chans, conv_weight): + conv_type = conv_weight.dtype + conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU + O, I, J, K = conv_weight.shape + if in_chans == 1: + if I > 3: + assert conv_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) + conv_weight = conv_weight.sum(dim=2, keepdim=False) + else: + conv_weight = conv_weight.sum(dim=1, keepdim=True) + elif in_chans != 3: + if I != 3: + raise NotImplementedError('Weight format not supported by conversion.') + else: + # NOTE this strategy should be better than random init, but there could be other combinations of + # the original RGB input layer weights that'd work better for specific cases. + repeat = int(math.ceil(in_chans / 3)) + conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv_weight *= (3 / float(in_chans)) + conv_weight = conv_weight.to(conv_type) + return conv_weight + + def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False): if cfg is None: cfg = getattr(model, 'default_cfg') @@ -159,56 +184,35 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non if filter_fn is not None: state_dict = filter_fn(state_dict) - if in_chans == 1: - conv1_name = cfg['first_conv'] - _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name) - conv1_weight = state_dict[conv1_name + '.weight'] - # Some weights are in torch.half, ensure it's float for sum on CPU - conv1_type = conv1_weight.dtype - conv1_weight = conv1_weight.float() - O, I, J, K = conv1_weight.shape - if I > 3: - assert conv1_weight.shape[1] % 3 == 0 - # For models with space2depth stems - conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K) - conv1_weight = conv1_weight.sum(dim=2, keepdim=False) - else: - conv1_weight = conv1_weight.sum(dim=1, keepdim=True) - conv1_weight = conv1_weight.to(conv1_type) - state_dict[conv1_name + '.weight'] = conv1_weight - elif in_chans != 3: - conv1_name = cfg['first_conv'] - conv1_weight = state_dict[conv1_name + '.weight'] - conv1_type = conv1_weight.dtype - conv1_weight = conv1_weight.float() - O, I, J, K = conv1_weight.shape - if I != 3: - _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name) - del state_dict[conv1_name + '.weight'] - strict = False - else: - # NOTE this strategy should be better than random init, but there could be other combinations of - # the original RGB input layer weights that'd work better for specific cases. - _logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name) - repeat = int(math.ceil(in_chans / 3)) - conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] - conv1_weight *= (3 / float(in_chans)) - conv1_weight = conv1_weight.to(conv1_type) - state_dict[conv1_name + '.weight'] = conv1_weight + input_convs = cfg.get('first_conv', None) + if input_convs is not None: + if isinstance(input_convs, str): + input_convs = (input_convs,) + for input_conv_name in input_convs: + weight_name = input_conv_name + '.weight' + try: + state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name]) + _logger.info( + f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)') + except NotImplementedError as e: + del state_dict[weight_name] + strict = False + _logger.warning( + f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') classifier_name = cfg['classifier'] - if num_classes == 1000 and cfg['num_classes'] == 1001: - # FIXME this special case is problematic as number of pretrained weight sources increases - # special case for imagenet trained models with extra background class in pretrained weights - classifier_weight = state_dict[classifier_name + '.weight'] - state_dict[classifier_name + '.weight'] = classifier_weight[1:] - classifier_bias = state_dict[classifier_name + '.bias'] - state_dict[classifier_name + '.bias'] = classifier_bias[1:] - elif num_classes != cfg['num_classes']: - # completely discard fully connected for all other differences between pretrained and created model + label_offset = cfg.get('label_offset', 0) + if num_classes != cfg['num_classes']: + # completely discard fully connected if model num_classes doesn't match pretrained weights del state_dict[classifier_name + '.weight'] del state_dict[classifier_name + '.bias'] strict = False + elif label_offset > 0: + # special case for pretrained weights with an extra background class in pretrained weights + classifier_weight = state_dict[classifier_name + '.weight'] + state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] + classifier_bias = state_dict[classifier_name + '.bias'] + state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] model.load_state_dict(state_dict, strict=strict) diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index a5efa330..adfe330e 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -17,18 +17,20 @@ default_cfgs = { # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz 'inception_resnet_v2': { 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth', - 'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), 'crop_pct': 0.8975, 'interpolation': 'bicubic', 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', + 'label_offset': 1, # 1001 classes in pretrained weights }, # ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz 'ens_adv_inception_resnet_v2': { 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth', - 'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), 'crop_pct': 0.8975, 'interpolation': 'bicubic', 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', + 'label_offset': 1, # 1001 classes in pretrained weights } } @@ -222,7 +224,7 @@ class Block8(nn.Module): class InceptionResnetV2(nn.Module): - def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'): + def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'): super(InceptionResnetV2, self).__init__() self.drop_rate = drop_rate self.num_classes = num_classes diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index 40a0f291..69f9ff5a 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -16,10 +16,11 @@ __all__ = ['InceptionV4'] default_cfgs = { 'inception_v4': { 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth', - 'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 'first_conv': 'features.0.conv', 'classifier': 'last_linear', + 'label_offset': 1, # 1001 classes in pretrained weights } } @@ -241,7 +242,7 @@ class InceptionC(nn.Module): class InceptionV4(nn.Module): - def __init__(self, num_classes=1001, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'): + def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'): super(InceptionV4, self).__init__() assert output_stride == 32 self.drop_rate = drop_rate diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 8f52099f..6eb9f8a1 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -12,7 +12,7 @@ from .conv_bn_act import ConvBnAct from .create_act import create_act_layer, get_act_layer, get_act_fn from .create_attn import get_attn, create_attn from .create_conv2d import create_conv2d -from .create_norm_act import create_norm_act, get_norm_act_layer +from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path from .eca import EcaModule, CecaModule from .evo_norm import EvoNormBatch2d, EvoNormSample2d diff --git a/timm/models/layers/conv_bn_act.py b/timm/models/layers/conv_bn_act.py index 90735357..33005c37 100644 --- a/timm/models/layers/conv_bn_act.py +++ b/timm/models/layers/conv_bn_act.py @@ -5,23 +5,23 @@ Hacked together by / Copyright 2020 Ross Wightman from torch import nn as nn from .create_conv2d import create_conv2d -from .create_norm_act import convert_norm_act_type +from .create_norm_act import convert_norm_act class ConvBnAct(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, - norm_layer=nn.BatchNorm2d, norm_kwargs=None, act_layer=nn.ReLU, apply_act=True, - drop_block=None, aa_layer=None): + bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, + drop_block=None): super(ConvBnAct, self).__init__() use_aa = aa_layer is not None self.conv = create_conv2d( in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, - padding=padding, dilation=dilation, groups=groups, bias=False) + padding=padding, dilation=dilation, groups=groups, bias=bias) # NOTE for backwards compatibility with models that use separate norm and act layer definitions - norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs) - self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args) + norm_act_layer = convert_norm_act(norm_layer, act_layer) + self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block) self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None @property diff --git a/timm/models/layers/create_attn.py b/timm/models/layers/create_attn.py index f4a4c2c9..ff20e5df 100644 --- a/timm/models/layers/create_attn.py +++ b/timm/models/layers/create_attn.py @@ -9,6 +9,8 @@ from .cbam import CbamModule, LightCbamModule def get_attn(attn_type): + if isinstance(attn_type, torch.nn.Module): + return attn_type module_cls = None if attn_type is not None: if isinstance(attn_type, str): diff --git a/timm/models/layers/create_norm_act.py b/timm/models/layers/create_norm_act.py index 9e7e529e..5b562945 100644 --- a/timm/models/layers/create_norm_act.py +++ b/timm/models/layers/create_norm_act.py @@ -19,6 +19,7 @@ from .inplace_abn import InplaceAbn _NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn} _NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn} # requires act_layer arg to define act type + def get_norm_act_layer(layer_class): layer_class = layer_class.replace('_', '').lower() if layer_class.startswith("batchnorm"): @@ -47,16 +48,22 @@ def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwarg return layer_instance -def convert_norm_act_type(norm_layer, act_layer, norm_kwargs=None): +def convert_norm_act(norm_layer, act_layer): assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) - norm_act_args = norm_kwargs.copy() if norm_kwargs else {} + norm_act_kwargs = {} + + # unbind partial fn, so args can be rebound later + if isinstance(norm_layer, functools.partial): + norm_act_kwargs.update(norm_layer.keywords) + norm_layer = norm_layer.func + if isinstance(norm_layer, str): norm_act_layer = get_norm_act_layer(norm_layer) elif norm_layer in _NORM_ACT_TYPES: norm_act_layer = norm_layer - elif isinstance(norm_layer, (types.FunctionType, functools.partial)): - # assuming this is a lambda/fn/bound partial that creates norm_act layer + elif isinstance(norm_layer, types.FunctionType): + # if function type, must be a lambda/fn that creates a norm_act layer norm_act_layer = norm_layer else: type_name = norm_layer.__name__.lower() @@ -66,9 +73,11 @@ def convert_norm_act_type(norm_layer, act_layer, norm_kwargs=None): norm_act_layer = GroupNormAct else: assert False, f"No equivalent norm_act layer for {type_name}" + if norm_act_layer in _NORM_ACT_REQUIRES_ARG: - # Must pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. + # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types - # It is intended that functions/partial does not trigger this, they should define act. - norm_act_args.update(dict(act_layer=act_layer)) - return norm_act_layer, norm_act_args + norm_act_kwargs.setdefault('act_layer', act_layer) + if norm_act_kwargs: + norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args + return norm_act_layer diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py index e3fe3940..02cabe88 100644 --- a/timm/models/layers/norm_act.py +++ b/timm/models/layers/norm_act.py @@ -24,7 +24,7 @@ class BatchNormAct2d(nn.BatchNorm2d): act_args = dict(inplace=True) if inplace else {} self.act = act_layer(**act_args) else: - self.act = None + self.act = nn.Identity() def _forward_jit(self, x): """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function @@ -62,8 +62,7 @@ class BatchNormAct2d(nn.BatchNorm2d): x = self._forward_jit(x) else: x = self._forward_python(x) - if self.act is not None: - x = self.act(x) + x = self.act(x) return x @@ -75,12 +74,12 @@ class GroupNormAct(nn.GroupNorm): if isinstance(act_layer, str): act_layer = get_act_layer(act_layer) if act_layer is not None and apply_act: - self.act = act_layer(inplace=inplace) + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) else: - self.act = None + self.act = nn.Identity() def forward(self, x): x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) - if self.act is not None: - x = self.act(x) + x = self.act(x) return x diff --git a/timm/models/layers/separable_conv.py b/timm/models/layers/separable_conv.py index e949ea43..1ddcb4e6 100644 --- a/timm/models/layers/separable_conv.py +++ b/timm/models/layers/separable_conv.py @@ -8,17 +8,16 @@ Hacked together by / Copyright 2020 Ross Wightman from torch import nn as nn from .create_conv2d import create_conv2d -from .create_norm_act import convert_norm_act_type +from .create_norm_act import convert_norm_act class SeparableConvBnAct(nn.Module): """ Separable Conv w/ trailing Norm and Activation """ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, - channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, norm_kwargs=None, - act_layer=nn.ReLU, apply_act=True, drop_block=None): + channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, + apply_act=True, drop_block=None): super(SeparableConvBnAct, self).__init__() - norm_kwargs = norm_kwargs or {} self.conv_dw = create_conv2d( in_channels, int(in_channels * channel_multiplier), kernel_size, @@ -27,8 +26,8 @@ class SeparableConvBnAct(nn.Module): self.conv_pw = create_conv2d( int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) - norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs) - self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args) + norm_act_layer = convert_norm_act(norm_layer, act_layer) + self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block) @property def in_channels(self): diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 60e1a276..1f1a3b75 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -1,6 +1,9 @@ +""" NasNet-A (Large) + nasnetalarge implementation grabbed from Cadene's pretrained models + https://github.com/Cadene/pretrained-models.pytorch """ +from functools import partial -""" import torch import torch.nn as nn import torch.nn.functional as F @@ -20,9 +23,10 @@ default_cfgs = { 'interpolation': 'bicubic', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), - 'num_classes': 1001, + 'num_classes': 1000, 'first_conv': 'conv0.conv', 'classifier': 'last_linear', + 'label_offset': 1, # 1001 classes in pretrained weights }, } @@ -418,7 +422,7 @@ class NASNetALarge(nn.Module): self.conv0 = ConvBnAct( in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2, - norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None) + norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False) self.cell_stem_0 = CellStem0( self.stem_size, num_channels=channels // (channel_multiplier ** 2), pad_type=pad_type) diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 5f1e177f..73073009 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -6,6 +6,7 @@ """ from collections import OrderedDict +from functools import partial import torch import torch.nn as nn @@ -26,9 +27,10 @@ default_cfgs = { 'interpolation': 'bicubic', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), - 'num_classes': 1001, + 'num_classes': 1000, 'first_conv': 'conv_0.conv', 'classifier': 'last_linear', + 'label_offset': 1, # 1001 classes in pretrained weights }, } @@ -234,7 +236,7 @@ class Cell(CellBase): class PNASNet5Large(nn.Module): - def __init__(self, num_classes=1001, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg', pad_type=''): + def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg', pad_type=''): super(PNASNet5Large, self).__init__() self.num_classes = num_classes self.drop_rate = drop_rate @@ -243,7 +245,7 @@ class PNASNet5Large(nn.Module): self.conv_0 = ConvBnAct( in_chans, 96, kernel_size=3, stride=2, padding=0, - norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None) + norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False) self.cell_stem_0 = CellStem0( in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, pad_type=pad_type) diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index e6b21576..dd7a7a86 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -5,7 +5,7 @@ https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zo Hacked together by / Copyright 2020 Ross Wightman """ -from collections import OrderedDict +from functools import partial import torch.nn as nn import torch.nn.functional as F @@ -43,9 +43,8 @@ default_cfgs = dict( class SeparableConv2d(nn.Module): def __init__( self, inplanes, planes, kernel_size=3, stride=1, dilation=1, padding='', - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None): + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(SeparableConv2d, self).__init__() - norm_kwargs = norm_kwargs if norm_kwargs is not None else {} self.kernel_size = kernel_size self.dilation = dilation @@ -53,7 +52,7 @@ class SeparableConv2d(nn.Module): self.conv_dw = create_conv2d( inplanes, inplanes, kernel_size, stride=stride, padding=padding, dilation=dilation, depthwise=True) - self.bn_dw = norm_layer(inplanes, **norm_kwargs) + self.bn_dw = norm_layer(inplanes) if act_layer is not None: self.act_dw = act_layer(inplace=True) else: @@ -61,7 +60,7 @@ class SeparableConv2d(nn.Module): # pointwise convolution self.conv_pw = create_conv2d(inplanes, planes, kernel_size=1) - self.bn_pw = norm_layer(planes, **norm_kwargs) + self.bn_pw = norm_layer(planes) if act_layer is not None: self.act_pw = act_layer(inplace=True) else: @@ -82,17 +81,15 @@ class SeparableConv2d(nn.Module): class XceptionModule(nn.Module): def __init__( self, in_chs, out_chs, stride=1, dilation=1, pad_type='', - start_with_relu=True, no_skip=False, act_layer=nn.ReLU, norm_layer=None, norm_kwargs=None): + start_with_relu=True, no_skip=False, act_layer=nn.ReLU, norm_layer=None): super(XceptionModule, self).__init__() - norm_kwargs = norm_kwargs if norm_kwargs is not None else {} out_chs = to_3tuple(out_chs) self.in_channels = in_chs self.out_channels = out_chs[-1] self.no_skip = no_skip if not no_skip and (self.out_channels != self.in_channels or stride != 1): self.shortcut = ConvBnAct( - in_chs, self.out_channels, 1, stride=stride, - norm_layer=norm_layer, norm_kwargs=norm_kwargs, act_layer=None) + in_chs, self.out_channels, 1, stride=stride, norm_layer=norm_layer, act_layer=None) else: self.shortcut = None @@ -103,7 +100,7 @@ class XceptionModule(nn.Module): self.stack.add_module(f'act{i + 1}', nn.ReLU(inplace=i > 0)) self.stack.add_module(f'conv{i + 1}', SeparableConv2d( in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type, - act_layer=separable_act_layer, norm_layer=norm_layer, norm_kwargs=norm_kwargs)) + act_layer=separable_act_layer, norm_layer=norm_layer)) in_chs = out_chs[i] def forward(self, x): @@ -121,14 +118,13 @@ class XceptionAligned(nn.Module): """ def __init__(self, block_cfg, num_classes=1000, in_chans=3, output_stride=32, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_rate=0., global_pool='avg'): + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0., global_pool='avg'): super(XceptionAligned, self).__init__() self.num_classes = num_classes self.drop_rate = drop_rate assert output_stride in (8, 16, 32) - norm_kwargs = norm_kwargs if norm_kwargs is not None else {} - layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, norm_kwargs=norm_kwargs) + layer_args = dict(act_layer=act_layer, norm_layer=norm_layer) self.stem = nn.Sequential(*[ ConvBnAct(in_chans, 32, kernel_size=3, stride=2, **layer_args), ConvBnAct(32, 64, kernel_size=3, stride=1, **layer_args) @@ -196,7 +192,7 @@ def xception41(pretrained=False, **kwargs): dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False), ] - model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs) + model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs) return _xception('xception41', pretrained=pretrained, **model_args) @@ -215,7 +211,7 @@ def xception65(pretrained=False, **kwargs): dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False), ] - model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs) + model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs) return _xception('xception65', pretrained=pretrained, **model_args) @@ -236,5 +232,5 @@ def xception71(pretrained=False, **kwargs): dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False), ] - model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs) + model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs) return _xception('xception71', pretrained=pretrained, **model_args)