Fix regression in models with 1001 class pretrained weights. Improve batchnorm arg and BatchNormAct layer handling in several models.

pull/419/head
Ross Wightman 4 years ago
parent aaa715b1e9
commit 9811e229f7

@ -83,7 +83,6 @@ def test_model_default_cfgs(model_name, batch_size):
cfg = model.default_cfg cfg = model.default_cfg
classifier = cfg['classifier'] classifier = cfg['classifier']
first_conv = cfg['first_conv']
pool_size = cfg['pool_size'] pool_size = cfg['pool_size']
input_size = model.default_cfg['input_size'] input_size = model.default_cfg['input_size']
@ -111,9 +110,16 @@ def test_model_default_cfgs(model_name, batch_size):
# FIXME mobilenetv3 forward_features vs removed pooling differ # FIXME mobilenetv3 forward_features vs removed pooling differ
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2] assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
# check classifier and first convolution names match those in default_cfg # check classifier name matches default_cfg
assert classifier + ".weight" in state_dict.keys(), f'{classifier} not in model params' assert classifier + ".weight" in state_dict.keys(), f'{classifier} not in model params'
assert first_conv + ".weight" in state_dict.keys(), f'{first_conv} not in model params'
# check first conv(s) names match default_cfg
first_conv = cfg['first_conv']
if isinstance(first_conv, str):
first_conv = (first_conv,)
assert isinstance(first_conv, (tuple, list))
for fc in first_conv:
assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params'
if 'GITHUB_ACTIONS' not in os.environ: if 'GITHUB_ACTIONS' not in os.environ:

@ -7,6 +7,7 @@ This implementation is compatible with the pretrained weights from cypw's MXNet
Hacked together by / Copyright 2020 Ross Wightman Hacked together by / Copyright 2020 Ross Wightman
""" """
from collections import OrderedDict from collections import OrderedDict
from functools import partial
from typing import Tuple from typing import Tuple
import torch import torch
@ -173,12 +174,14 @@ class DPN(nn.Module):
self.drop_rate = drop_rate self.drop_rate = drop_rate
self.b = b self.b = b
assert output_stride == 32 # FIXME look into dilation support assert output_stride == 32 # FIXME look into dilation support
norm_layer = partial(BatchNormAct2d, eps=.001)
fc_norm_layer = partial(BatchNormAct2d, eps=.001, act_layer=fc_act, inplace=False)
bw_factor = 1 if small else 4 bw_factor = 1 if small else 4
blocks = OrderedDict() blocks = OrderedDict()
# conv1 # conv1
blocks['conv1_1'] = ConvBnAct( blocks['conv1_1'] = ConvBnAct(
in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_kwargs=dict(eps=.001)) in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_layer=norm_layer)
blocks['conv1_pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) blocks['conv1_pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.feature_info = [dict(num_chs=num_init_features, reduction=2, module='features.conv1_1')] self.feature_info = [dict(num_chs=num_init_features, reduction=2, module='features.conv1_1')]
@ -226,8 +229,7 @@ class DPN(nn.Module):
in_chs += inc in_chs += inc
self.feature_info += [dict(num_chs=in_chs, reduction=32, module=f'features.conv5_{k_sec[3]}')] self.feature_info += [dict(num_chs=in_chs, reduction=32, module=f'features.conv5_{k_sec[3]}')]
def _fc_norm(f, eps): return BatchNormAct2d(f, eps=eps, act_layer=fc_act, inplace=False) blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=fc_norm_layer)
blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=_fc_norm)
self.num_features = in_chs self.num_features = in_chs
self.features = nn.Sequential(blocks) self.features = nn.Sequential(blocks)

@ -42,10 +42,8 @@ for Tensorflow 'SAME' padding. PyTorch symmetric padding behaves the way we'd w
class SeparableConv2d(nn.Module): class SeparableConv2d(nn.Module):
def __init__(self, inplanes, planes, kernel_size=3, stride=1, def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, norm_layer=None):
dilation=1, bias=False, norm_layer=None, norm_kwargs=None):
super(SeparableConv2d, self).__init__() super(SeparableConv2d, self).__init__()
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.dilation = dilation self.dilation = dilation
@ -54,7 +52,7 @@ class SeparableConv2d(nn.Module):
self.conv_dw = nn.Conv2d( self.conv_dw = nn.Conv2d(
inplanes, inplanes, kernel_size, stride=stride, inplanes, inplanes, kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=inplanes, bias=bias) padding=padding, dilation=dilation, groups=inplanes, bias=bias)
self.bn = norm_layer(num_features=inplanes, **norm_kwargs) self.bn = norm_layer(num_features=inplanes)
# pointwise convolution # pointwise convolution
self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias) self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias)
@ -66,10 +64,8 @@ class SeparableConv2d(nn.Module):
class Block(nn.Module): class Block(nn.Module):
def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True, def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True, norm_layer=None):
norm_layer=None, norm_kwargs=None, ):
super(Block, self).__init__() super(Block, self).__init__()
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
if isinstance(planes, (list, tuple)): if isinstance(planes, (list, tuple)):
assert len(planes) == 3 assert len(planes) == 3
else: else:
@ -80,7 +76,7 @@ class Block(nn.Module):
self.skip = nn.Sequential() self.skip = nn.Sequential()
self.skip.add_module('conv1', nn.Conv2d( self.skip.add_module('conv1', nn.Conv2d(
inplanes, outplanes, 1, stride=stride, bias=False)), inplanes, outplanes, 1, stride=stride, bias=False)),
self.skip.add_module('bn1', norm_layer(num_features=outplanes, **norm_kwargs)) self.skip.add_module('bn1', norm_layer(num_features=outplanes))
else: else:
self.skip = None self.skip = None
@ -88,9 +84,8 @@ class Block(nn.Module):
for i in range(3): for i in range(3):
rep['act%d' % (i + 1)] = nn.ReLU(inplace=True) rep['act%d' % (i + 1)] = nn.ReLU(inplace=True)
rep['conv%d' % (i + 1)] = SeparableConv2d( rep['conv%d' % (i + 1)] = SeparableConv2d(
inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation, inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation, norm_layer=norm_layer)
norm_layer=norm_layer, norm_kwargs=norm_kwargs) rep['bn%d' % (i + 1)] = norm_layer(planes[i])
rep['bn%d' % (i + 1)] = norm_layer(planes[i], **norm_kwargs)
inplanes = planes[i] inplanes = planes[i]
if not start_with_relu: if not start_with_relu:
@ -115,74 +110,63 @@ class Xception65(nn.Module):
""" """
def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d, def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
norm_kwargs=None, drop_rate=0., global_pool='avg'): drop_rate=0., global_pool='avg'):
super(Xception65, self).__init__() super(Xception65, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.drop_rate = drop_rate self.drop_rate = drop_rate
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
if output_stride == 32: if output_stride == 32:
entry_block3_stride = 2 entry_block3_stride = 2
exit_block20_stride = 2 exit_block20_stride = 2
middle_block_dilation = 1 middle_dilation = 1
exit_block_dilations = (1, 1) exit_dilation = (1, 1)
elif output_stride == 16: elif output_stride == 16:
entry_block3_stride = 2 entry_block3_stride = 2
exit_block20_stride = 1 exit_block20_stride = 1
middle_block_dilation = 1 middle_dilation = 1
exit_block_dilations = (1, 2) exit_dilation = (1, 2)
elif output_stride == 8: elif output_stride == 8:
entry_block3_stride = 1 entry_block3_stride = 1
exit_block20_stride = 1 exit_block20_stride = 1
middle_block_dilation = 2 middle_dilation = 2
exit_block_dilations = (2, 4) exit_dilation = (2, 4)
else: else:
raise NotImplementedError raise NotImplementedError
# Entry flow # Entry flow
self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False) self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = norm_layer(num_features=32, **norm_kwargs) self.bn1 = norm_layer(num_features=32)
self.act1 = nn.ReLU(inplace=True) self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False) self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = norm_layer(num_features=64) self.bn2 = norm_layer(num_features=64)
self.act2 = nn.ReLU(inplace=True) self.act2 = nn.ReLU(inplace=True)
self.block1 = Block( self.block1 = Block(64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer)
64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.block1_act = nn.ReLU(inplace=True) self.block1_act = nn.ReLU(inplace=True)
self.block2 = Block( self.block2 = Block(128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer)
128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs) self.block3 = Block(256, 728, stride=entry_block3_stride, norm_layer=norm_layer)
self.block3 = Block(
256, 728, stride=entry_block3_stride, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
# Middle flow # Middle flow
self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block( self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
728, 728, stride=1, dilation=middle_block_dilation, 728, 728, stride=1, dilation=middle_dilation, norm_layer=norm_layer)) for i in range(4, 20)]))
norm_layer=norm_layer, norm_kwargs=norm_kwargs)) for i in range(4, 20)]))
# Exit flow # Exit flow
self.block20 = Block( self.block20 = Block(
728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_block_dilations[0], 728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_dilation[0], norm_layer=norm_layer)
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.block20_act = nn.ReLU(inplace=True) self.block20_act = nn.ReLU(inplace=True)
self.conv3 = SeparableConv2d( self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], self.bn3 = norm_layer(num_features=1536)
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.bn3 = norm_layer(num_features=1536, **norm_kwargs)
self.act3 = nn.ReLU(inplace=True) self.act3 = nn.ReLU(inplace=True)
self.conv4 = SeparableConv2d( self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], self.bn4 = norm_layer(num_features=1536)
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.bn4 = norm_layer(num_features=1536, **norm_kwargs)
self.act4 = nn.ReLU(inplace=True) self.act4 = nn.ReLU(inplace=True)
self.num_features = 2048 self.num_features = 2048
self.conv5 = SeparableConv2d( self.conv5 = SeparableConv2d(
1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1], 1536, self.num_features, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
norm_layer=norm_layer, norm_kwargs=norm_kwargs) self.bn5 = norm_layer(num_features=self.num_features)
self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs)
self.act5 = nn.ReLU(inplace=True) self.act5 = nn.ReLU(inplace=True)
self.feature_info = [ self.feature_info = [
dict(num_chs=64, reduction=2, module='act2'), dict(num_chs=64, reduction=2, module='act2'),

@ -148,6 +148,31 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
_logger.warning("Valid function to load pretrained weights is not available, using random initialization.") _logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
def adapt_input_conv(in_chans, conv_weight):
conv_type = conv_weight.dtype
conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
O, I, J, K = conv_weight.shape
if in_chans == 1:
if I > 3:
assert conv_weight.shape[1] % 3 == 0
# For models with space2depth stems
conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
conv_weight = conv_weight.sum(dim=2, keepdim=False)
else:
conv_weight = conv_weight.sum(dim=1, keepdim=True)
elif in_chans != 3:
if I != 3:
raise NotImplementedError('Weight format not supported by conversion.')
else:
# NOTE this strategy should be better than random init, but there could be other combinations of
# the original RGB input layer weights that'd work better for specific cases.
repeat = int(math.ceil(in_chans / 3))
conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
conv_weight *= (3 / float(in_chans))
conv_weight = conv_weight.to(conv_type)
return conv_weight
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False): def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
if cfg is None: if cfg is None:
cfg = getattr(model, 'default_cfg') cfg = getattr(model, 'default_cfg')
@ -159,56 +184,35 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
if filter_fn is not None: if filter_fn is not None:
state_dict = filter_fn(state_dict) state_dict = filter_fn(state_dict)
if in_chans == 1: input_convs = cfg.get('first_conv', None)
conv1_name = cfg['first_conv'] if input_convs is not None:
_logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name) if isinstance(input_convs, str):
conv1_weight = state_dict[conv1_name + '.weight'] input_convs = (input_convs,)
# Some weights are in torch.half, ensure it's float for sum on CPU for input_conv_name in input_convs:
conv1_type = conv1_weight.dtype weight_name = input_conv_name + '.weight'
conv1_weight = conv1_weight.float() try:
O, I, J, K = conv1_weight.shape state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
if I > 3: _logger.info(
assert conv1_weight.shape[1] % 3 == 0 f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
# For models with space2depth stems except NotImplementedError as e:
conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K) del state_dict[weight_name]
conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
else:
conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
conv1_weight = conv1_weight.to(conv1_type)
state_dict[conv1_name + '.weight'] = conv1_weight
elif in_chans != 3:
conv1_name = cfg['first_conv']
conv1_weight = state_dict[conv1_name + '.weight']
conv1_type = conv1_weight.dtype
conv1_weight = conv1_weight.float()
O, I, J, K = conv1_weight.shape
if I != 3:
_logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
del state_dict[conv1_name + '.weight']
strict = False strict = False
else: _logger.warning(
# NOTE this strategy should be better than random init, but there could be other combinations of f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
# the original RGB input layer weights that'd work better for specific cases.
_logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
repeat = int(math.ceil(in_chans / 3))
conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
conv1_weight *= (3 / float(in_chans))
conv1_weight = conv1_weight.to(conv1_type)
state_dict[conv1_name + '.weight'] = conv1_weight
classifier_name = cfg['classifier'] classifier_name = cfg['classifier']
if num_classes == 1000 and cfg['num_classes'] == 1001: label_offset = cfg.get('label_offset', 0)
# FIXME this special case is problematic as number of pretrained weight sources increases if num_classes != cfg['num_classes']:
# special case for imagenet trained models with extra background class in pretrained weights # completely discard fully connected if model num_classes doesn't match pretrained weights
classifier_weight = state_dict[classifier_name + '.weight']
state_dict[classifier_name + '.weight'] = classifier_weight[1:]
classifier_bias = state_dict[classifier_name + '.bias']
state_dict[classifier_name + '.bias'] = classifier_bias[1:]
elif num_classes != cfg['num_classes']:
# completely discard fully connected for all other differences between pretrained and created model
del state_dict[classifier_name + '.weight'] del state_dict[classifier_name + '.weight']
del state_dict[classifier_name + '.bias'] del state_dict[classifier_name + '.bias']
strict = False strict = False
elif label_offset > 0:
# special case for pretrained weights with an extra background class in pretrained weights
classifier_weight = state_dict[classifier_name + '.weight']
state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
classifier_bias = state_dict[classifier_name + '.bias']
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
model.load_state_dict(state_dict, strict=strict) model.load_state_dict(state_dict, strict=strict)

@ -17,18 +17,20 @@ default_cfgs = {
# ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
'inception_resnet_v2': { 'inception_resnet_v2': {
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth', 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth',
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8), 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
'crop_pct': 0.8975, 'interpolation': 'bicubic', 'crop_pct': 0.8975, 'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
'label_offset': 1, # 1001 classes in pretrained weights
}, },
# ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz # ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz
'ens_adv_inception_resnet_v2': { 'ens_adv_inception_resnet_v2': {
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth', 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth',
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8), 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
'crop_pct': 0.8975, 'interpolation': 'bicubic', 'crop_pct': 0.8975, 'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
'label_offset': 1, # 1001 classes in pretrained weights
} }
} }
@ -222,7 +224,7 @@ class Block8(nn.Module):
class InceptionResnetV2(nn.Module): class InceptionResnetV2(nn.Module):
def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'): def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'):
super(InceptionResnetV2, self).__init__() super(InceptionResnetV2, self).__init__()
self.drop_rate = drop_rate self.drop_rate = drop_rate
self.num_classes = num_classes self.num_classes = num_classes

@ -16,10 +16,11 @@ __all__ = ['InceptionV4']
default_cfgs = { default_cfgs = {
'inception_v4': { 'inception_v4': {
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth', 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth',
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8), 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
'crop_pct': 0.875, 'interpolation': 'bicubic', 'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'features.0.conv', 'classifier': 'last_linear', 'first_conv': 'features.0.conv', 'classifier': 'last_linear',
'label_offset': 1, # 1001 classes in pretrained weights
} }
} }
@ -241,7 +242,7 @@ class InceptionC(nn.Module):
class InceptionV4(nn.Module): class InceptionV4(nn.Module):
def __init__(self, num_classes=1001, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'): def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'):
super(InceptionV4, self).__init__() super(InceptionV4, self).__init__()
assert output_stride == 32 assert output_stride == 32
self.drop_rate = drop_rate self.drop_rate = drop_rate

@ -12,7 +12,7 @@ from .conv_bn_act import ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import get_attn, create_attn from .create_attn import get_attn, create_attn
from .create_conv2d import create_conv2d from .create_conv2d import create_conv2d
from .create_norm_act import create_norm_act, get_norm_act_layer from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule from .eca import EcaModule, CecaModule
from .evo_norm import EvoNormBatch2d, EvoNormSample2d from .evo_norm import EvoNormBatch2d, EvoNormSample2d

@ -5,23 +5,23 @@ Hacked together by / Copyright 2020 Ross Wightman
from torch import nn as nn from torch import nn as nn
from .create_conv2d import create_conv2d from .create_conv2d import create_conv2d
from .create_norm_act import convert_norm_act_type from .create_norm_act import convert_norm_act
class ConvBnAct(nn.Module): class ConvBnAct(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, act_layer=nn.ReLU, apply_act=True, bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None,
drop_block=None, aa_layer=None): drop_block=None):
super(ConvBnAct, self).__init__() super(ConvBnAct, self).__init__()
use_aa = aa_layer is not None use_aa = aa_layer is not None
self.conv = create_conv2d( self.conv = create_conv2d(
in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
padding=padding, dilation=dilation, groups=groups, bias=False) padding=padding, dilation=dilation, groups=groups, bias=bias)
# NOTE for backwards compatibility with models that use separate norm and act layer definitions # NOTE for backwards compatibility with models that use separate norm and act layer definitions
norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs) norm_act_layer = convert_norm_act(norm_layer, act_layer)
self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args) self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block)
self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None
@property @property

@ -9,6 +9,8 @@ from .cbam import CbamModule, LightCbamModule
def get_attn(attn_type): def get_attn(attn_type):
if isinstance(attn_type, torch.nn.Module):
return attn_type
module_cls = None module_cls = None
if attn_type is not None: if attn_type is not None:
if isinstance(attn_type, str): if isinstance(attn_type, str):

@ -19,6 +19,7 @@ from .inplace_abn import InplaceAbn
_NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn} _NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn}
_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn} # requires act_layer arg to define act type _NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn} # requires act_layer arg to define act type
def get_norm_act_layer(layer_class): def get_norm_act_layer(layer_class):
layer_class = layer_class.replace('_', '').lower() layer_class = layer_class.replace('_', '').lower()
if layer_class.startswith("batchnorm"): if layer_class.startswith("batchnorm"):
@ -47,16 +48,22 @@ def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwarg
return layer_instance return layer_instance
def convert_norm_act_type(norm_layer, act_layer, norm_kwargs=None): def convert_norm_act(norm_layer, act_layer):
assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
norm_act_args = norm_kwargs.copy() if norm_kwargs else {} norm_act_kwargs = {}
# unbind partial fn, so args can be rebound later
if isinstance(norm_layer, functools.partial):
norm_act_kwargs.update(norm_layer.keywords)
norm_layer = norm_layer.func
if isinstance(norm_layer, str): if isinstance(norm_layer, str):
norm_act_layer = get_norm_act_layer(norm_layer) norm_act_layer = get_norm_act_layer(norm_layer)
elif norm_layer in _NORM_ACT_TYPES: elif norm_layer in _NORM_ACT_TYPES:
norm_act_layer = norm_layer norm_act_layer = norm_layer
elif isinstance(norm_layer, (types.FunctionType, functools.partial)): elif isinstance(norm_layer, types.FunctionType):
# assuming this is a lambda/fn/bound partial that creates norm_act layer # if function type, must be a lambda/fn that creates a norm_act layer
norm_act_layer = norm_layer norm_act_layer = norm_layer
else: else:
type_name = norm_layer.__name__.lower() type_name = norm_layer.__name__.lower()
@ -66,9 +73,11 @@ def convert_norm_act_type(norm_layer, act_layer, norm_kwargs=None):
norm_act_layer = GroupNormAct norm_act_layer = GroupNormAct
else: else:
assert False, f"No equivalent norm_act layer for {type_name}" assert False, f"No equivalent norm_act layer for {type_name}"
if norm_act_layer in _NORM_ACT_REQUIRES_ARG: if norm_act_layer in _NORM_ACT_REQUIRES_ARG:
# Must pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
# In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types
# It is intended that functions/partial does not trigger this, they should define act. norm_act_kwargs.setdefault('act_layer', act_layer)
norm_act_args.update(dict(act_layer=act_layer)) if norm_act_kwargs:
return norm_act_layer, norm_act_args norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args
return norm_act_layer

@ -24,7 +24,7 @@ class BatchNormAct2d(nn.BatchNorm2d):
act_args = dict(inplace=True) if inplace else {} act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args) self.act = act_layer(**act_args)
else: else:
self.act = None self.act = nn.Identity()
def _forward_jit(self, x): def _forward_jit(self, x):
""" A cut & paste of the contents of the PyTorch BatchNorm2d forward function """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function
@ -62,7 +62,6 @@ class BatchNormAct2d(nn.BatchNorm2d):
x = self._forward_jit(x) x = self._forward_jit(x)
else: else:
x = self._forward_python(x) x = self._forward_python(x)
if self.act is not None:
x = self.act(x) x = self.act(x)
return x return x
@ -75,12 +74,12 @@ class GroupNormAct(nn.GroupNorm):
if isinstance(act_layer, str): if isinstance(act_layer, str):
act_layer = get_act_layer(act_layer) act_layer = get_act_layer(act_layer)
if act_layer is not None and apply_act: if act_layer is not None and apply_act:
self.act = act_layer(inplace=inplace) act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else: else:
self.act = None self.act = nn.Identity()
def forward(self, x): def forward(self, x):
x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
if self.act is not None:
x = self.act(x) x = self.act(x)
return x return x

@ -8,17 +8,16 @@ Hacked together by / Copyright 2020 Ross Wightman
from torch import nn as nn from torch import nn as nn
from .create_conv2d import create_conv2d from .create_conv2d import create_conv2d
from .create_norm_act import convert_norm_act_type from .create_norm_act import convert_norm_act
class SeparableConvBnAct(nn.Module): class SeparableConvBnAct(nn.Module):
""" Separable Conv w/ trailing Norm and Activation """ Separable Conv w/ trailing Norm and Activation
""" """
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, norm_kwargs=None, channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU,
act_layer=nn.ReLU, apply_act=True, drop_block=None): apply_act=True, drop_block=None):
super(SeparableConvBnAct, self).__init__() super(SeparableConvBnAct, self).__init__()
norm_kwargs = norm_kwargs or {}
self.conv_dw = create_conv2d( self.conv_dw = create_conv2d(
in_channels, int(in_channels * channel_multiplier), kernel_size, in_channels, int(in_channels * channel_multiplier), kernel_size,
@ -27,8 +26,8 @@ class SeparableConvBnAct(nn.Module):
self.conv_pw = create_conv2d( self.conv_pw = create_conv2d(
int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs) norm_act_layer = convert_norm_act(norm_layer, act_layer)
self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args) self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block)
@property @property
def in_channels(self): def in_channels(self):

@ -1,6 +1,9 @@
""" NasNet-A (Large)
nasnetalarge implementation grabbed from Cadene's pretrained models
https://github.com/Cadene/pretrained-models.pytorch
""" """
from functools import partial
"""
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -20,9 +23,10 @@ default_cfgs = {
'interpolation': 'bicubic', 'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5), 'mean': (0.5, 0.5, 0.5),
'std': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
'num_classes': 1001, 'num_classes': 1000,
'first_conv': 'conv0.conv', 'first_conv': 'conv0.conv',
'classifier': 'last_linear', 'classifier': 'last_linear',
'label_offset': 1, # 1001 classes in pretrained weights
}, },
} }
@ -418,7 +422,7 @@ class NASNetALarge(nn.Module):
self.conv0 = ConvBnAct( self.conv0 = ConvBnAct(
in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2, in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2,
norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None) norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False)
self.cell_stem_0 = CellStem0( self.cell_stem_0 = CellStem0(
self.stem_size, num_channels=channels // (channel_multiplier ** 2), pad_type=pad_type) self.stem_size, num_channels=channels // (channel_multiplier ** 2), pad_type=pad_type)

@ -6,6 +6,7 @@
""" """
from collections import OrderedDict from collections import OrderedDict
from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -26,9 +27,10 @@ default_cfgs = {
'interpolation': 'bicubic', 'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5), 'mean': (0.5, 0.5, 0.5),
'std': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
'num_classes': 1001, 'num_classes': 1000,
'first_conv': 'conv_0.conv', 'first_conv': 'conv_0.conv',
'classifier': 'last_linear', 'classifier': 'last_linear',
'label_offset': 1, # 1001 classes in pretrained weights
}, },
} }
@ -234,7 +236,7 @@ class Cell(CellBase):
class PNASNet5Large(nn.Module): class PNASNet5Large(nn.Module):
def __init__(self, num_classes=1001, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg', pad_type=''): def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg', pad_type=''):
super(PNASNet5Large, self).__init__() super(PNASNet5Large, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.drop_rate = drop_rate self.drop_rate = drop_rate
@ -243,7 +245,7 @@ class PNASNet5Large(nn.Module):
self.conv_0 = ConvBnAct( self.conv_0 = ConvBnAct(
in_chans, 96, kernel_size=3, stride=2, padding=0, in_chans, 96, kernel_size=3, stride=2, padding=0,
norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None) norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False)
self.cell_stem_0 = CellStem0( self.cell_stem_0 = CellStem0(
in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, pad_type=pad_type) in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, pad_type=pad_type)

@ -5,7 +5,7 @@ https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zo
Hacked together by / Copyright 2020 Ross Wightman Hacked together by / Copyright 2020 Ross Wightman
""" """
from collections import OrderedDict from functools import partial
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -43,9 +43,8 @@ default_cfgs = dict(
class SeparableConv2d(nn.Module): class SeparableConv2d(nn.Module):
def __init__( def __init__(
self, inplanes, planes, kernel_size=3, stride=1, dilation=1, padding='', self, inplanes, planes, kernel_size=3, stride=1, dilation=1, padding='',
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None): act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(SeparableConv2d, self).__init__() super(SeparableConv2d, self).__init__()
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.dilation = dilation self.dilation = dilation
@ -53,7 +52,7 @@ class SeparableConv2d(nn.Module):
self.conv_dw = create_conv2d( self.conv_dw = create_conv2d(
inplanes, inplanes, kernel_size, stride=stride, inplanes, inplanes, kernel_size, stride=stride,
padding=padding, dilation=dilation, depthwise=True) padding=padding, dilation=dilation, depthwise=True)
self.bn_dw = norm_layer(inplanes, **norm_kwargs) self.bn_dw = norm_layer(inplanes)
if act_layer is not None: if act_layer is not None:
self.act_dw = act_layer(inplace=True) self.act_dw = act_layer(inplace=True)
else: else:
@ -61,7 +60,7 @@ class SeparableConv2d(nn.Module):
# pointwise convolution # pointwise convolution
self.conv_pw = create_conv2d(inplanes, planes, kernel_size=1) self.conv_pw = create_conv2d(inplanes, planes, kernel_size=1)
self.bn_pw = norm_layer(planes, **norm_kwargs) self.bn_pw = norm_layer(planes)
if act_layer is not None: if act_layer is not None:
self.act_pw = act_layer(inplace=True) self.act_pw = act_layer(inplace=True)
else: else:
@ -82,17 +81,15 @@ class SeparableConv2d(nn.Module):
class XceptionModule(nn.Module): class XceptionModule(nn.Module):
def __init__( def __init__(
self, in_chs, out_chs, stride=1, dilation=1, pad_type='', self, in_chs, out_chs, stride=1, dilation=1, pad_type='',
start_with_relu=True, no_skip=False, act_layer=nn.ReLU, norm_layer=None, norm_kwargs=None): start_with_relu=True, no_skip=False, act_layer=nn.ReLU, norm_layer=None):
super(XceptionModule, self).__init__() super(XceptionModule, self).__init__()
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
out_chs = to_3tuple(out_chs) out_chs = to_3tuple(out_chs)
self.in_channels = in_chs self.in_channels = in_chs
self.out_channels = out_chs[-1] self.out_channels = out_chs[-1]
self.no_skip = no_skip self.no_skip = no_skip
if not no_skip and (self.out_channels != self.in_channels or stride != 1): if not no_skip and (self.out_channels != self.in_channels or stride != 1):
self.shortcut = ConvBnAct( self.shortcut = ConvBnAct(
in_chs, self.out_channels, 1, stride=stride, in_chs, self.out_channels, 1, stride=stride, norm_layer=norm_layer, act_layer=None)
norm_layer=norm_layer, norm_kwargs=norm_kwargs, act_layer=None)
else: else:
self.shortcut = None self.shortcut = None
@ -103,7 +100,7 @@ class XceptionModule(nn.Module):
self.stack.add_module(f'act{i + 1}', nn.ReLU(inplace=i > 0)) self.stack.add_module(f'act{i + 1}', nn.ReLU(inplace=i > 0))
self.stack.add_module(f'conv{i + 1}', SeparableConv2d( self.stack.add_module(f'conv{i + 1}', SeparableConv2d(
in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type, in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type,
act_layer=separable_act_layer, norm_layer=norm_layer, norm_kwargs=norm_kwargs)) act_layer=separable_act_layer, norm_layer=norm_layer))
in_chs = out_chs[i] in_chs = out_chs[i]
def forward(self, x): def forward(self, x):
@ -121,14 +118,13 @@ class XceptionAligned(nn.Module):
""" """
def __init__(self, block_cfg, num_classes=1000, in_chans=3, output_stride=32, def __init__(self, block_cfg, num_classes=1000, in_chans=3, output_stride=32,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_rate=0., global_pool='avg'): act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0., global_pool='avg'):
super(XceptionAligned, self).__init__() super(XceptionAligned, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.drop_rate = drop_rate self.drop_rate = drop_rate
assert output_stride in (8, 16, 32) assert output_stride in (8, 16, 32)
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, norm_kwargs=norm_kwargs) layer_args = dict(act_layer=act_layer, norm_layer=norm_layer)
self.stem = nn.Sequential(*[ self.stem = nn.Sequential(*[
ConvBnAct(in_chans, 32, kernel_size=3, stride=2, **layer_args), ConvBnAct(in_chans, 32, kernel_size=3, stride=2, **layer_args),
ConvBnAct(32, 64, kernel_size=3, stride=1, **layer_args) ConvBnAct(32, 64, kernel_size=3, stride=1, **layer_args)
@ -196,7 +192,7 @@ def xception41(pretrained=False, **kwargs):
dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False), dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
] ]
model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs) model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs)
return _xception('xception41', pretrained=pretrained, **model_args) return _xception('xception41', pretrained=pretrained, **model_args)
@ -215,7 +211,7 @@ def xception65(pretrained=False, **kwargs):
dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False), dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
] ]
model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs) model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs)
return _xception('xception65', pretrained=pretrained, **model_args) return _xception('xception65', pretrained=pretrained, **model_args)
@ -236,5 +232,5 @@ def xception71(pretrained=False, **kwargs):
dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False), dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
] ]
model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs) model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs)
return _xception('xception71', pretrained=pretrained, **model_args) return _xception('xception71', pretrained=pretrained, **model_args)

Loading…
Cancel
Save