Fix regression in models with 1001 class pretrained weights. Improve batchnorm arg and BatchNormAct layer handling in several models.

pull/419/head
Ross Wightman 4 years ago
parent aaa715b1e9
commit 9811e229f7

@ -83,7 +83,6 @@ def test_model_default_cfgs(model_name, batch_size):
cfg = model.default_cfg
classifier = cfg['classifier']
first_conv = cfg['first_conv']
pool_size = cfg['pool_size']
input_size = model.default_cfg['input_size']
@ -111,9 +110,16 @@ def test_model_default_cfgs(model_name, batch_size):
# FIXME mobilenetv3 forward_features vs removed pooling differ
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
# check classifier and first convolution names match those in default_cfg
# check classifier name matches default_cfg
assert classifier + ".weight" in state_dict.keys(), f'{classifier} not in model params'
assert first_conv + ".weight" in state_dict.keys(), f'{first_conv} not in model params'
# check first conv(s) names match default_cfg
first_conv = cfg['first_conv']
if isinstance(first_conv, str):
first_conv = (first_conv,)
assert isinstance(first_conv, (tuple, list))
for fc in first_conv:
assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params'
if 'GITHUB_ACTIONS' not in os.environ:

@ -7,6 +7,7 @@ This implementation is compatible with the pretrained weights from cypw's MXNet
Hacked together by / Copyright 2020 Ross Wightman
"""
from collections import OrderedDict
from functools import partial
from typing import Tuple
import torch
@ -173,12 +174,14 @@ class DPN(nn.Module):
self.drop_rate = drop_rate
self.b = b
assert output_stride == 32 # FIXME look into dilation support
norm_layer = partial(BatchNormAct2d, eps=.001)
fc_norm_layer = partial(BatchNormAct2d, eps=.001, act_layer=fc_act, inplace=False)
bw_factor = 1 if small else 4
blocks = OrderedDict()
# conv1
blocks['conv1_1'] = ConvBnAct(
in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_kwargs=dict(eps=.001))
in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_layer=norm_layer)
blocks['conv1_pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.feature_info = [dict(num_chs=num_init_features, reduction=2, module='features.conv1_1')]
@ -226,8 +229,7 @@ class DPN(nn.Module):
in_chs += inc
self.feature_info += [dict(num_chs=in_chs, reduction=32, module=f'features.conv5_{k_sec[3]}')]
def _fc_norm(f, eps): return BatchNormAct2d(f, eps=eps, act_layer=fc_act, inplace=False)
blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=_fc_norm)
blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=fc_norm_layer)
self.num_features = in_chs
self.features = nn.Sequential(blocks)

@ -42,10 +42,8 @@ for Tensorflow 'SAME' padding. PyTorch symmetric padding behaves the way we'd w
class SeparableConv2d(nn.Module):
def __init__(self, inplanes, planes, kernel_size=3, stride=1,
dilation=1, bias=False, norm_layer=None, norm_kwargs=None):
def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, norm_layer=None):
super(SeparableConv2d, self).__init__()
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
self.kernel_size = kernel_size
self.dilation = dilation
@ -54,7 +52,7 @@ class SeparableConv2d(nn.Module):
self.conv_dw = nn.Conv2d(
inplanes, inplanes, kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=inplanes, bias=bias)
self.bn = norm_layer(num_features=inplanes, **norm_kwargs)
self.bn = norm_layer(num_features=inplanes)
# pointwise convolution
self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias)
@ -66,10 +64,8 @@ class SeparableConv2d(nn.Module):
class Block(nn.Module):
def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True,
norm_layer=None, norm_kwargs=None, ):
def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True, norm_layer=None):
super(Block, self).__init__()
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
if isinstance(planes, (list, tuple)):
assert len(planes) == 3
else:
@ -80,7 +76,7 @@ class Block(nn.Module):
self.skip = nn.Sequential()
self.skip.add_module('conv1', nn.Conv2d(
inplanes, outplanes, 1, stride=stride, bias=False)),
self.skip.add_module('bn1', norm_layer(num_features=outplanes, **norm_kwargs))
self.skip.add_module('bn1', norm_layer(num_features=outplanes))
else:
self.skip = None
@ -88,9 +84,8 @@ class Block(nn.Module):
for i in range(3):
rep['act%d' % (i + 1)] = nn.ReLU(inplace=True)
rep['conv%d' % (i + 1)] = SeparableConv2d(
inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation,
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
rep['bn%d' % (i + 1)] = norm_layer(planes[i], **norm_kwargs)
inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation, norm_layer=norm_layer)
rep['bn%d' % (i + 1)] = norm_layer(planes[i])
inplanes = planes[i]
if not start_with_relu:
@ -115,74 +110,63 @@ class Xception65(nn.Module):
"""
def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
norm_kwargs=None, drop_rate=0., global_pool='avg'):
drop_rate=0., global_pool='avg'):
super(Xception65, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
if output_stride == 32:
entry_block3_stride = 2
exit_block20_stride = 2
middle_block_dilation = 1
exit_block_dilations = (1, 1)
middle_dilation = 1
exit_dilation = (1, 1)
elif output_stride == 16:
entry_block3_stride = 2
exit_block20_stride = 1
middle_block_dilation = 1
exit_block_dilations = (1, 2)
middle_dilation = 1
exit_dilation = (1, 2)
elif output_stride == 8:
entry_block3_stride = 1
exit_block20_stride = 1
middle_block_dilation = 2
exit_block_dilations = (2, 4)
middle_dilation = 2
exit_dilation = (2, 4)
else:
raise NotImplementedError
# Entry flow
self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = norm_layer(num_features=32, **norm_kwargs)
self.bn1 = norm_layer(num_features=32)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = norm_layer(num_features=64)
self.act2 = nn.ReLU(inplace=True)
self.block1 = Block(
64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.block1 = Block(64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer)
self.block1_act = nn.ReLU(inplace=True)
self.block2 = Block(
128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.block3 = Block(
256, 728, stride=entry_block3_stride, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.block2 = Block(128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer)
self.block3 = Block(256, 728, stride=entry_block3_stride, norm_layer=norm_layer)
# Middle flow
self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
728, 728, stride=1, dilation=middle_block_dilation,
norm_layer=norm_layer, norm_kwargs=norm_kwargs)) for i in range(4, 20)]))
728, 728, stride=1, dilation=middle_dilation, norm_layer=norm_layer)) for i in range(4, 20)]))
# Exit flow
self.block20 = Block(
728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_block_dilations[0],
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_dilation[0], norm_layer=norm_layer)
self.block20_act = nn.ReLU(inplace=True)
self.conv3 = SeparableConv2d(
1024, 1536, 3, stride=1, dilation=exit_block_dilations[1],
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.bn3 = norm_layer(num_features=1536, **norm_kwargs)
self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
self.bn3 = norm_layer(num_features=1536)
self.act3 = nn.ReLU(inplace=True)
self.conv4 = SeparableConv2d(
1536, 1536, 3, stride=1, dilation=exit_block_dilations[1],
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.bn4 = norm_layer(num_features=1536, **norm_kwargs)
self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
self.bn4 = norm_layer(num_features=1536)
self.act4 = nn.ReLU(inplace=True)
self.num_features = 2048
self.conv5 = SeparableConv2d(
1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1],
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs)
1536, self.num_features, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
self.bn5 = norm_layer(num_features=self.num_features)
self.act5 = nn.ReLU(inplace=True)
self.feature_info = [
dict(num_chs=64, reduction=2, module='act2'),

@ -148,6 +148,31 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
_logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
def adapt_input_conv(in_chans, conv_weight):
conv_type = conv_weight.dtype
conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
O, I, J, K = conv_weight.shape
if in_chans == 1:
if I > 3:
assert conv_weight.shape[1] % 3 == 0
# For models with space2depth stems
conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
conv_weight = conv_weight.sum(dim=2, keepdim=False)
else:
conv_weight = conv_weight.sum(dim=1, keepdim=True)
elif in_chans != 3:
if I != 3:
raise NotImplementedError('Weight format not supported by conversion.')
else:
# NOTE this strategy should be better than random init, but there could be other combinations of
# the original RGB input layer weights that'd work better for specific cases.
repeat = int(math.ceil(in_chans / 3))
conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
conv_weight *= (3 / float(in_chans))
conv_weight = conv_weight.to(conv_type)
return conv_weight
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
if cfg is None:
cfg = getattr(model, 'default_cfg')
@ -159,56 +184,35 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
if filter_fn is not None:
state_dict = filter_fn(state_dict)
if in_chans == 1:
conv1_name = cfg['first_conv']
_logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
conv1_weight = state_dict[conv1_name + '.weight']
# Some weights are in torch.half, ensure it's float for sum on CPU
conv1_type = conv1_weight.dtype
conv1_weight = conv1_weight.float()
O, I, J, K = conv1_weight.shape
if I > 3:
assert conv1_weight.shape[1] % 3 == 0
# For models with space2depth stems
conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
else:
conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
conv1_weight = conv1_weight.to(conv1_type)
state_dict[conv1_name + '.weight'] = conv1_weight
elif in_chans != 3:
conv1_name = cfg['first_conv']
conv1_weight = state_dict[conv1_name + '.weight']
conv1_type = conv1_weight.dtype
conv1_weight = conv1_weight.float()
O, I, J, K = conv1_weight.shape
if I != 3:
_logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
del state_dict[conv1_name + '.weight']
input_convs = cfg.get('first_conv', None)
if input_convs is not None:
if isinstance(input_convs, str):
input_convs = (input_convs,)
for input_conv_name in input_convs:
weight_name = input_conv_name + '.weight'
try:
state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
_logger.info(
f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
except NotImplementedError as e:
del state_dict[weight_name]
strict = False
else:
# NOTE this strategy should be better than random init, but there could be other combinations of
# the original RGB input layer weights that'd work better for specific cases.
_logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
repeat = int(math.ceil(in_chans / 3))
conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
conv1_weight *= (3 / float(in_chans))
conv1_weight = conv1_weight.to(conv1_type)
state_dict[conv1_name + '.weight'] = conv1_weight
_logger.warning(
f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
classifier_name = cfg['classifier']
if num_classes == 1000 and cfg['num_classes'] == 1001:
# FIXME this special case is problematic as number of pretrained weight sources increases
# special case for imagenet trained models with extra background class in pretrained weights
classifier_weight = state_dict[classifier_name + '.weight']
state_dict[classifier_name + '.weight'] = classifier_weight[1:]
classifier_bias = state_dict[classifier_name + '.bias']
state_dict[classifier_name + '.bias'] = classifier_bias[1:]
elif num_classes != cfg['num_classes']:
# completely discard fully connected for all other differences between pretrained and created model
label_offset = cfg.get('label_offset', 0)
if num_classes != cfg['num_classes']:
# completely discard fully connected if model num_classes doesn't match pretrained weights
del state_dict[classifier_name + '.weight']
del state_dict[classifier_name + '.bias']
strict = False
elif label_offset > 0:
# special case for pretrained weights with an extra background class in pretrained weights
classifier_weight = state_dict[classifier_name + '.weight']
state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
classifier_bias = state_dict[classifier_name + '.bias']
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
model.load_state_dict(state_dict, strict=strict)

@ -17,18 +17,20 @@ default_cfgs = {
# ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
'inception_resnet_v2': {
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth',
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
'crop_pct': 0.8975, 'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
'label_offset': 1, # 1001 classes in pretrained weights
},
# ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz
'ens_adv_inception_resnet_v2': {
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth',
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
'crop_pct': 0.8975, 'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
'label_offset': 1, # 1001 classes in pretrained weights
}
}
@ -222,7 +224,7 @@ class Block8(nn.Module):
class InceptionResnetV2(nn.Module):
def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'):
def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'):
super(InceptionResnetV2, self).__init__()
self.drop_rate = drop_rate
self.num_classes = num_classes

@ -16,10 +16,11 @@ __all__ = ['InceptionV4']
default_cfgs = {
'inception_v4': {
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth',
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'features.0.conv', 'classifier': 'last_linear',
'label_offset': 1, # 1001 classes in pretrained weights
}
}
@ -241,7 +242,7 @@ class InceptionC(nn.Module):
class InceptionV4(nn.Module):
def __init__(self, num_classes=1001, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'):
def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'):
super(InceptionV4, self).__init__()
assert output_stride == 32
self.drop_rate = drop_rate

@ -12,7 +12,7 @@ from .conv_bn_act import ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import get_attn, create_attn
from .create_conv2d import create_conv2d
from .create_norm_act import create_norm_act, get_norm_act_layer
from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule
from .evo_norm import EvoNormBatch2d, EvoNormSample2d

@ -5,23 +5,23 @@ Hacked together by / Copyright 2020 Ross Wightman
from torch import nn as nn
from .create_conv2d import create_conv2d
from .create_norm_act import convert_norm_act_type
from .create_norm_act import convert_norm_act
class ConvBnAct(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, act_layer=nn.ReLU, apply_act=True,
drop_block=None, aa_layer=None):
bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None,
drop_block=None):
super(ConvBnAct, self).__init__()
use_aa = aa_layer is not None
self.conv = create_conv2d(
in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
padding=padding, dilation=dilation, groups=groups, bias=False)
padding=padding, dilation=dilation, groups=groups, bias=bias)
# NOTE for backwards compatibility with models that use separate norm and act layer definitions
norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs)
self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args)
norm_act_layer = convert_norm_act(norm_layer, act_layer)
self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block)
self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None
@property

@ -9,6 +9,8 @@ from .cbam import CbamModule, LightCbamModule
def get_attn(attn_type):
if isinstance(attn_type, torch.nn.Module):
return attn_type
module_cls = None
if attn_type is not None:
if isinstance(attn_type, str):

@ -19,6 +19,7 @@ from .inplace_abn import InplaceAbn
_NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn}
_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn} # requires act_layer arg to define act type
def get_norm_act_layer(layer_class):
layer_class = layer_class.replace('_', '').lower()
if layer_class.startswith("batchnorm"):
@ -47,16 +48,22 @@ def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwarg
return layer_instance
def convert_norm_act_type(norm_layer, act_layer, norm_kwargs=None):
def convert_norm_act(norm_layer, act_layer):
assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
norm_act_args = norm_kwargs.copy() if norm_kwargs else {}
norm_act_kwargs = {}
# unbind partial fn, so args can be rebound later
if isinstance(norm_layer, functools.partial):
norm_act_kwargs.update(norm_layer.keywords)
norm_layer = norm_layer.func
if isinstance(norm_layer, str):
norm_act_layer = get_norm_act_layer(norm_layer)
elif norm_layer in _NORM_ACT_TYPES:
norm_act_layer = norm_layer
elif isinstance(norm_layer, (types.FunctionType, functools.partial)):
# assuming this is a lambda/fn/bound partial that creates norm_act layer
elif isinstance(norm_layer, types.FunctionType):
# if function type, must be a lambda/fn that creates a norm_act layer
norm_act_layer = norm_layer
else:
type_name = norm_layer.__name__.lower()
@ -66,9 +73,11 @@ def convert_norm_act_type(norm_layer, act_layer, norm_kwargs=None):
norm_act_layer = GroupNormAct
else:
assert False, f"No equivalent norm_act layer for {type_name}"
if norm_act_layer in _NORM_ACT_REQUIRES_ARG:
# Must pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
# pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
# In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types
# It is intended that functions/partial does not trigger this, they should define act.
norm_act_args.update(dict(act_layer=act_layer))
return norm_act_layer, norm_act_args
norm_act_kwargs.setdefault('act_layer', act_layer)
if norm_act_kwargs:
norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args
return norm_act_layer

@ -24,7 +24,7 @@ class BatchNormAct2d(nn.BatchNorm2d):
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = None
self.act = nn.Identity()
def _forward_jit(self, x):
""" A cut & paste of the contents of the PyTorch BatchNorm2d forward function
@ -62,7 +62,6 @@ class BatchNormAct2d(nn.BatchNorm2d):
x = self._forward_jit(x)
else:
x = self._forward_python(x)
if self.act is not None:
x = self.act(x)
return x
@ -75,12 +74,12 @@ class GroupNormAct(nn.GroupNorm):
if isinstance(act_layer, str):
act_layer = get_act_layer(act_layer)
if act_layer is not None and apply_act:
self.act = act_layer(inplace=inplace)
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = None
self.act = nn.Identity()
def forward(self, x):
x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
if self.act is not None:
x = self.act(x)
return x

@ -8,17 +8,16 @@ Hacked together by / Copyright 2020 Ross Wightman
from torch import nn as nn
from .create_conv2d import create_conv2d
from .create_norm_act import convert_norm_act_type
from .create_norm_act import convert_norm_act
class SeparableConvBnAct(nn.Module):
""" Separable Conv w/ trailing Norm and Activation
"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
act_layer=nn.ReLU, apply_act=True, drop_block=None):
channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU,
apply_act=True, drop_block=None):
super(SeparableConvBnAct, self).__init__()
norm_kwargs = norm_kwargs or {}
self.conv_dw = create_conv2d(
in_channels, int(in_channels * channel_multiplier), kernel_size,
@ -27,8 +26,8 @@ class SeparableConvBnAct(nn.Module):
self.conv_pw = create_conv2d(
int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs)
self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args)
norm_act_layer = convert_norm_act(norm_layer, act_layer)
self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block)
@property
def in_channels(self):

@ -1,6 +1,9 @@
""" NasNet-A (Large)
nasnetalarge implementation grabbed from Cadene's pretrained models
https://github.com/Cadene/pretrained-models.pytorch
"""
from functools import partial
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -20,9 +23,10 @@ default_cfgs = {
'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5),
'std': (0.5, 0.5, 0.5),
'num_classes': 1001,
'num_classes': 1000,
'first_conv': 'conv0.conv',
'classifier': 'last_linear',
'label_offset': 1, # 1001 classes in pretrained weights
},
}
@ -418,7 +422,7 @@ class NASNetALarge(nn.Module):
self.conv0 = ConvBnAct(
in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2,
norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None)
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False)
self.cell_stem_0 = CellStem0(
self.stem_size, num_channels=channels // (channel_multiplier ** 2), pad_type=pad_type)

@ -6,6 +6,7 @@
"""
from collections import OrderedDict
from functools import partial
import torch
import torch.nn as nn
@ -26,9 +27,10 @@ default_cfgs = {
'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5),
'std': (0.5, 0.5, 0.5),
'num_classes': 1001,
'num_classes': 1000,
'first_conv': 'conv_0.conv',
'classifier': 'last_linear',
'label_offset': 1, # 1001 classes in pretrained weights
},
}
@ -234,7 +236,7 @@ class Cell(CellBase):
class PNASNet5Large(nn.Module):
def __init__(self, num_classes=1001, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg', pad_type=''):
def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg', pad_type=''):
super(PNASNet5Large, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
@ -243,7 +245,7 @@ class PNASNet5Large(nn.Module):
self.conv_0 = ConvBnAct(
in_chans, 96, kernel_size=3, stride=2, padding=0,
norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None)
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False)
self.cell_stem_0 = CellStem0(
in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, pad_type=pad_type)

@ -5,7 +5,7 @@ https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zo
Hacked together by / Copyright 2020 Ross Wightman
"""
from collections import OrderedDict
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
@ -43,9 +43,8 @@ default_cfgs = dict(
class SeparableConv2d(nn.Module):
def __init__(
self, inplanes, planes, kernel_size=3, stride=1, dilation=1, padding='',
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(SeparableConv2d, self).__init__()
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
self.kernel_size = kernel_size
self.dilation = dilation
@ -53,7 +52,7 @@ class SeparableConv2d(nn.Module):
self.conv_dw = create_conv2d(
inplanes, inplanes, kernel_size, stride=stride,
padding=padding, dilation=dilation, depthwise=True)
self.bn_dw = norm_layer(inplanes, **norm_kwargs)
self.bn_dw = norm_layer(inplanes)
if act_layer is not None:
self.act_dw = act_layer(inplace=True)
else:
@ -61,7 +60,7 @@ class SeparableConv2d(nn.Module):
# pointwise convolution
self.conv_pw = create_conv2d(inplanes, planes, kernel_size=1)
self.bn_pw = norm_layer(planes, **norm_kwargs)
self.bn_pw = norm_layer(planes)
if act_layer is not None:
self.act_pw = act_layer(inplace=True)
else:
@ -82,17 +81,15 @@ class SeparableConv2d(nn.Module):
class XceptionModule(nn.Module):
def __init__(
self, in_chs, out_chs, stride=1, dilation=1, pad_type='',
start_with_relu=True, no_skip=False, act_layer=nn.ReLU, norm_layer=None, norm_kwargs=None):
start_with_relu=True, no_skip=False, act_layer=nn.ReLU, norm_layer=None):
super(XceptionModule, self).__init__()
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
out_chs = to_3tuple(out_chs)
self.in_channels = in_chs
self.out_channels = out_chs[-1]
self.no_skip = no_skip
if not no_skip and (self.out_channels != self.in_channels or stride != 1):
self.shortcut = ConvBnAct(
in_chs, self.out_channels, 1, stride=stride,
norm_layer=norm_layer, norm_kwargs=norm_kwargs, act_layer=None)
in_chs, self.out_channels, 1, stride=stride, norm_layer=norm_layer, act_layer=None)
else:
self.shortcut = None
@ -103,7 +100,7 @@ class XceptionModule(nn.Module):
self.stack.add_module(f'act{i + 1}', nn.ReLU(inplace=i > 0))
self.stack.add_module(f'conv{i + 1}', SeparableConv2d(
in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type,
act_layer=separable_act_layer, norm_layer=norm_layer, norm_kwargs=norm_kwargs))
act_layer=separable_act_layer, norm_layer=norm_layer))
in_chs = out_chs[i]
def forward(self, x):
@ -121,14 +118,13 @@ class XceptionAligned(nn.Module):
"""
def __init__(self, block_cfg, num_classes=1000, in_chans=3, output_stride=32,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_rate=0., global_pool='avg'):
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0., global_pool='avg'):
super(XceptionAligned, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
assert output_stride in (8, 16, 32)
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
layer_args = dict(act_layer=act_layer, norm_layer=norm_layer)
self.stem = nn.Sequential(*[
ConvBnAct(in_chans, 32, kernel_size=3, stride=2, **layer_args),
ConvBnAct(32, 64, kernel_size=3, stride=1, **layer_args)
@ -196,7 +192,7 @@ def xception41(pretrained=False, **kwargs):
dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
]
model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs)
model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs)
return _xception('xception41', pretrained=pretrained, **model_args)
@ -215,7 +211,7 @@ def xception65(pretrained=False, **kwargs):
dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
]
model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs)
model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs)
return _xception('xception65', pretrained=pretrained, **model_args)
@ -236,5 +232,5 @@ def xception71(pretrained=False, **kwargs):
dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
]
model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs)
model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs)
return _xception('xception71', pretrained=pretrained, **model_args)

Loading…
Cancel
Save