diff --git a/tests/test_models.py b/tests/test_models.py index 63be6a6e..2babd74a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -37,8 +37,7 @@ def test_model_forward(model_name, batch_size): @pytest.mark.timeout(120) -# DLA models have an issue TBD, add them to exclusions -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + ['dla*'])) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) @pytest.mark.parametrize('batch_size', [2]) def test_model_backward(model_name, batch_size): """Run a single forward pass with each model""" diff --git a/timm/models/dla.py b/timm/models/dla.py index d9c5cbc0..212150e6 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -12,7 +12,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import load_pretrained +from .helpers import build_model_with_cfg from .layers import SelectAdaptivePool2d from .registry import register_model @@ -212,10 +212,19 @@ class DlaTree(nn.Module): root_dim = 2 * out_channels if level_root: root_dim += in_channels + self.downsample = nn.MaxPool2d(stride, stride=stride) if stride > 1 else nn.Identity() + self.project = nn.Identity() cargs = dict(dilation=dilation, cardinality=cardinality, base_width=base_width) if levels == 1: self.tree1 = block(in_channels, out_channels, stride, **cargs) self.tree2 = block(out_channels, out_channels, 1, **cargs) + if in_channels != out_channels: + # NOTE the official impl/weights have project layers in levels > 1 case that are never + # used, I've moved the project layer here to avoid wasted params but old checkpoints will + # need strict=False while loading. + self.project = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), + nn.BatchNorm2d(out_channels)) else: cargs.update(dict(root_kernel_size=root_kernel_size, root_residual=root_residual)) self.tree1 = DlaTree( @@ -226,22 +235,12 @@ class DlaTree(nn.Module): self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_residual) self.level_root = level_root self.root_dim = root_dim - self.downsample = nn.MaxPool2d(stride, stride=stride) if stride > 1 else None - self.project = None - if in_channels != out_channels: - self.project = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), - nn.BatchNorm2d(out_channels) - ) self.levels = levels def forward(self, x, residual=None, children=None): children = [] if children is None else children - # FIXME the way downsample / project are used here and residual is passed to next level up - # the tree, the residual is overridden and some project weights are thus never used and - # have no gradients. This appears to be an issue with the original model / weights. 
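[Editor's aside, not part of the patch] The DlaTree change above swaps the optional downsample/project attributes (previously left as None when unused) for always-present modules, so the None checks removed in forward() just below become unconditional calls. A minimal standalone sketch of that pattern, using a hypothetical `Shortcut` module in place of the DlaTree internals:

```python
import torch
import torch.nn as nn

# Illustrative only: the nn.Identity pattern the patch adopts in DlaTree.
# Assigning Identity/MaxPool2d unconditionally keeps forward() free of None checks.
class Shortcut(nn.Module):
    def __init__(self, in_chs, out_chs, stride=1):
        super().__init__()
        self.downsample = nn.MaxPool2d(stride, stride=stride) if stride > 1 else nn.Identity()
        self.project = nn.Sequential(
            nn.Conv2d(in_chs, out_chs, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_chs),
        ) if in_chs != out_chs else nn.Identity()

    def forward(self, x):
        # Mirrors the rewritten DlaTree.forward: no conditional module checks needed.
        return self.project(self.downsample(x))

print(Shortcut(64, 128, stride=2)(torch.randn(1, 64, 56, 56)).shape)  # [1, 128, 28, 28]
```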
- bottom = self.downsample(x) if self.downsample is not None else x - residual = self.project(bottom) if self.project is not None else bottom + bottom = self.downsample(x) + residual = self.project(bottom) if self.level_root: children.append(bottom) x1 = self.tree1(x, residual) @@ -255,8 +254,8 @@ class DlaTree(nn.Module): class DLA(nn.Module): - def __init__(self, levels, channels, num_classes=1000, in_chans=3, cardinality=1, base_width=64, - block=DlaBottle2neck, residual_root=False, linear_root=False, + def __init__(self, levels, channels, output_stride=32, num_classes=1000, in_chans=3, + cardinality=1, base_width=64, block=DlaBottle2neck, residual_root=False, drop_rate=0.0, global_pool='avg'): super(DLA, self).__init__() self.channels = channels @@ -264,6 +263,7 @@ class DLA(nn.Module): self.cardinality = cardinality self.base_width = base_width self.drop_rate = drop_rate + assert output_stride == 32 # FIXME support dilation self.base_layer = nn.Sequential( nn.Conv2d(in_chans, channels[0], kernel_size=7, stride=1, padding=3, bias=False), @@ -276,6 +276,14 @@ class DLA(nn.Module): self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs) self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True, **cargs) self.level5 = DlaTree(levels[5], block, channels[4], channels[5], 2, level_root=True, **cargs) + self.feature_info = [ + dict(num_chs=channels[0], reduction=1, module='level0'), # rare to have a meaningful stride 1 level + dict(num_chs=channels[1], reduction=2, module='level1'), + dict(num_chs=channels[2], reduction=4, module='level2'), + dict(num_chs=channels[3], reduction=8, module='level3'), + dict(num_chs=channels[4], reduction=16, module='level4'), + dict(num_chs=channels[5], reduction=32, module='level5'), + ] self.num_features = channels[-1] self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) @@ -331,142 +339,103 @@ class DLA(nn.Module): return x.flatten(1) +def _create_dla(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + DLA, variant, pretrained, default_cfg=default_cfgs[variant], + pretrained_strict=False, feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)), **kwargs) + + @register_model -def dla60_res2net(pretrained=None, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['dla60_res2net'] - model = DLA(levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024), - block=DlaBottle2neck, cardinality=1, base_width=28, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model +def dla60_res2net(pretrained=False, **kwargs): + model_kwargs = dict( + levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024), + block=DlaBottle2neck, cardinality=1, base_width=28, **kwargs) + return _create_dla('dla60_res2net', pretrained, **model_kwargs) @register_model -def dla60_res2next(pretrained=None, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['dla60_res2next'] - model = DLA(levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024), - block=DlaBottle2neck, cardinality=8, base_width=4, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model +def dla60_res2next(pretrained=False,**kwargs): + model_kwargs = dict( + levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024), + 
block=DlaBottle2neck, cardinality=8, base_width=4, **kwargs) + return _create_dla('dla60_res2next', pretrained, **model_kwargs) @register_model -def dla34(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-34 - default_cfg = default_cfgs['dla34'] - model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=DlaBasic, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model +def dla34(pretrained=False, **kwargs): # DLA-34 + model_kwargs = dict( + levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 128, 256, 512], + block=DlaBasic, **kwargs) + return _create_dla('dla34', pretrained, **model_kwargs) @register_model -def dla46_c(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-46-C - default_cfg = default_cfgs['dla46_c'] - model = DLA(levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256], - block=DlaBottleneck, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model +def dla46_c(pretrained=False, **kwargs): # DLA-46-C + model_kwargs = dict( + levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256], + block=DlaBottleneck, **kwargs) + return _create_dla('dla46_c', pretrained, **model_kwargs) @register_model -def dla46x_c(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-X-46-C - default_cfg = default_cfgs['dla46x_c'] - model = DLA(levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256], - block=DlaBottleneck, cardinality=32, base_width=4, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model +def dla46x_c(pretrained=False, **kwargs): # DLA-X-46-C + model_kwargs = dict( + levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256], + block=DlaBottleneck, cardinality=32, base_width=4, **kwargs) + return _create_dla('dla46x_c', pretrained, **model_kwargs) @register_model -def dla60x_c(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-X-60-C - default_cfg = default_cfgs['dla60x_c'] - model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 64, 64, 128, 256], - block=DlaBottleneck, cardinality=32, base_width=4, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model +def dla60x_c(pretrained=False, **kwargs): # DLA-X-60-C + model_kwargs = dict( + levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 64, 64, 128, 256], + block=DlaBottleneck, cardinality=32, base_width=4, **kwargs) + return _create_dla('dla60x_c', pretrained, **model_kwargs) @register_model -def dla60(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-60 - default_cfg = default_cfgs['dla60'] - model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], - block=DlaBottleneck, num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model +def dla60(pretrained=False, **kwargs): # DLA-60 + model_kwargs = dict( + levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, **kwargs) + return _create_dla('dla60', pretrained, **model_kwargs) @register_model -def dla60x(pretrained=None, num_classes=1000, in_chans=3, **kwargs): 
# DLA-X-60 - default_cfg = default_cfgs['dla60x'] - model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], - block=DlaBottleneck, cardinality=32, base_width=4, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model +def dla60x(pretrained=False, **kwargs): # DLA-X-60 + model_kwargs = dict( + levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, cardinality=32, base_width=4, **kwargs) + return _create_dla('dla60x', pretrained, **model_kwargs) @register_model -def dla102(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-102 - default_cfg = default_cfgs['dla102'] - model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], - block=DlaBottleneck, residual_root=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model +def dla102(pretrained=False, **kwargs): # DLA-102 + model_kwargs = dict( + levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, residual_root=True, **kwargs) + return _create_dla('dla102', pretrained, **model_kwargs) @register_model -def dla102x(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-X-102 - default_cfg = default_cfgs['dla102x'] - model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], - block=DlaBottleneck, cardinality=32, base_width=4, residual_root=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model +def dla102x(pretrained=False, **kwargs): # DLA-X-102 + model_kwargs = dict( + levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, cardinality=32, base_width=4, residual_root=True, **kwargs) + return _create_dla('dla102x', pretrained, **model_kwargs) @register_model -def dla102x2(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-X-102 64 - default_cfg = default_cfgs['dla102x2'] - model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], - block=DlaBottleneck, cardinality=64, base_width=4, residual_root=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model +def dla102x2(pretrained=False, **kwargs): # DLA-X-102 64 + model_kwargs = dict( + levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, cardinality=64, base_width=4, residual_root=True, **kwargs) + return _create_dla('dla102x2', pretrained, **model_kwargs) @register_model -def dla169(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-169 - default_cfg = default_cfgs['dla169'] - model = DLA([1, 1, 2, 3, 5, 1], [16, 32, 128, 256, 512, 1024], - block=DlaBottleneck, residual_root=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model +def dla169(pretrained=False, **kwargs): # DLA-169 + model_kwargs = dict( + levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, residual_root=True, **kwargs) + return _create_dla('dla169', pretrained, **model_kwargs) diff --git a/timm/models/gluon_xception.py 
b/timm/models/gluon_xception.py index 88a61944..da12bf64 100644 --- a/timm/models/gluon_xception.py +++ b/timm/models/gluon_xception.py @@ -12,7 +12,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import load_pretrained +from .helpers import build_model_with_cfg from .layers import SelectAdaptivePool2d, get_padding from .registry import register_model @@ -141,13 +141,15 @@ class Xception65(nn.Module): # Entry flow self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False) self.bn1 = norm_layer(num_features=32, **norm_kwargs) - self.relu = nn.ReLU(inplace=True) + self.act1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn2 = norm_layer(num_features=64) + self.act2 = nn.ReLU(inplace=True) self.block1 = Block( 64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs) + self.block1_act = nn.ReLU(inplace=True) self.block2 = Block( 128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs) self.block3 = Block( @@ -162,22 +164,34 @@ class Xception65(nn.Module): self.block20 = Block( 728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_block_dilations[0], norm_layer=norm_layer, norm_kwargs=norm_kwargs) + self.block20_act = nn.ReLU(inplace=True) self.conv3 = SeparableConv2d( 1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], norm_layer=norm_layer, norm_kwargs=norm_kwargs) self.bn3 = norm_layer(num_features=1536, **norm_kwargs) + self.act3 = nn.ReLU(inplace=True) self.conv4 = SeparableConv2d( 1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], norm_layer=norm_layer, norm_kwargs=norm_kwargs) self.bn4 = norm_layer(num_features=1536, **norm_kwargs) + self.act4 = nn.ReLU(inplace=True) self.num_features = 2048 self.conv5 = SeparableConv2d( 1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1], norm_layer=norm_layer, norm_kwargs=norm_kwargs) self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs) + self.act5 = nn.ReLU(inplace=True) + self.feature_info = [ + dict(num_chs=64, reduction=2, module='act2'), + dict(num_chs=128, reduction=4, module='block1_act'), + dict(num_chs=256, reduction=8, module='block3.rep.act1'), + dict(num_chs=728, reduction=16, module='block20.rep.act1'), + dict(num_chs=2048, reduction=32, module='act5'), + ] + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) @@ -193,15 +207,14 @@ class Xception65(nn.Module): # Entry flow x = self.conv1(x) x = self.bn1(x) - x = self.relu(x) + x = self.act1(x) x = self.conv2(x) x = self.bn2(x) - x = self.relu(x) + x = self.act2(x) x = self.block1(x) - # add relu here - x = self.relu(x) + x = self.block1_act(x) # c1 = x x = self.block2(x) # c2 = x @@ -213,18 +226,18 @@ class Xception65(nn.Module): # Exit flow x = self.block20(x) - x = self.relu(x) + x = self.block20_act(x) x = self.conv3(x) x = self.bn3(x) - x = self.relu(x) + x = self.act3(x) x = self.conv4(x) x = self.bn4(x) - x = self.relu(x) + x = self.act4(x) x = self.conv5(x) x = self.bn5(x) - x = self.relu(x) + x = self.act5(x) return x def forward(self, x): @@ -236,13 +249,14 @@ class Xception65(nn.Module): return x +def _create_gluon_xception(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + Xception65, variant, pretrained, default_cfg=default_cfgs[variant], + 
feature_cfg=dict(use_hooks=True), **kwargs) + + @register_model -def gluon_xception65(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def gluon_xception65(pretrained=False, **kwargs): """ Modified Aligned Xception-65 """ - default_cfg = default_cfgs['gluon_xception65'] - model = Xception65(num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + return _create_gluon_xception('gluon_xception65', pretrained, **kwargs) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index b27dceb6..593b7df5 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -8,7 +8,7 @@ import torch import torch.nn as nn import torch.utils.model_zoo as model_zoo -from .features import FeatureNet +from .features import FeatureNet, FeatureHookNet from .layers import Conv2dSame @@ -207,6 +207,7 @@ def build_model_with_cfg( default_cfg: dict, model_cfg: dict = None, feature_cfg: dict = None, + pretrained_strict: bool = True, pretrained_filter_fn: Callable = None, **kwargs): pruned = kwargs.pop('pruned', False) @@ -230,10 +231,18 @@ def build_model_with_cfg( model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), - filter_fn=pretrained_filter_fn) + filter_fn=pretrained_filter_fn, strict=pretrained_strict) if features: feature_cls = feature_cfg.pop('feature_cls', FeatureNet) + if isinstance(feature_cls, str): + feature_cls = feature_cls.lower() + if feature_cls == 'hook' or feature_cls == 'featurehooknet': + feature_cls = FeatureHookNet + else: + assert False, f'Unknown feature class {feature_cls}' + if feature_cls == FeatureHookNet and hasattr(model, 'reset_classifier'): + model.reset_classifier(0) model = feature_cls(model, **feature_cfg) return model diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 23836d3b..f4d47fc2 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -735,6 +735,7 @@ class HighResolutionNet(nn.Module): def _create_hrnet(variant, pretrained, **model_kwargs): + assert not model_kwargs.pop('features_only', False) # feature extraction not figured out yet return build_model_with_cfg( HighResolutionNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=cfg_cls[variant], **model_kwargs) diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index 8a425f4c..aa16cf06 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -3,7 +3,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import load_pretrained +from .helpers import build_model_with_cfg from .registry import register_model from .layers import trunc_normal_, SelectAdaptivePool2d @@ -44,231 +44,6 @@ default_cfgs = { } -class InceptionV3Aux(nn.Module): - """InceptionV3 with AuxLogits - """ - - def __init__(self, inception_blocks=None, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'): - super(InceptionV3Aux, self).__init__() - self.num_classes = num_classes - self.drop_rate = drop_rate - - if inception_blocks is None: - inception_blocks = [ - BasicConv2d, InceptionA, InceptionB, InceptionC, - InceptionD, InceptionE, InceptionAux - ] - assert len(inception_blocks) == 7 - conv_block = inception_blocks[0] - inception_a = inception_blocks[1] - inception_b = inception_blocks[2] - inception_c = inception_blocks[3] - inception_d = inception_blocks[4] - 
inception_e = inception_blocks[5] - inception_aux = inception_blocks[6] - - self.Conv2d_1a_3x3 = conv_block(in_chans, 32, kernel_size=3, stride=2) - self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3) - self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1) - self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1) - self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3) - self.Mixed_5b = inception_a(192, pool_features=32) - self.Mixed_5c = inception_a(256, pool_features=64) - self.Mixed_5d = inception_a(288, pool_features=64) - self.Mixed_6a = inception_b(288) - self.Mixed_6b = inception_c(768, channels_7x7=128) - self.Mixed_6c = inception_c(768, channels_7x7=160) - self.Mixed_6d = inception_c(768, channels_7x7=160) - self.Mixed_6e = inception_c(768, channels_7x7=192) - self.AuxLogits = inception_aux(768, num_classes) - self.Mixed_7a = inception_d(768) - self.Mixed_7b = inception_e(1280) - self.Mixed_7c = inception_e(2048) - - self.num_features = 2048 - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) - - for m in self.modules(): - if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): - stddev = m.stddev if hasattr(m, 'stddev') else 0.1 - trunc_normal_(m.weight, std=stddev) - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - def forward_features(self, x): - # N x 3 x 299 x 299 - x = self.Conv2d_1a_3x3(x) - # N x 32 x 149 x 149 - x = self.Conv2d_2a_3x3(x) - # N x 32 x 147 x 147 - x = self.Conv2d_2b_3x3(x) - # N x 64 x 147 x 147 - x = F.max_pool2d(x, kernel_size=3, stride=2) - # N x 64 x 73 x 73 - x = self.Conv2d_3b_1x1(x) - # N x 80 x 73 x 73 - x = self.Conv2d_4a_3x3(x) - # N x 192 x 71 x 71 - x = F.max_pool2d(x, kernel_size=3, stride=2) - # N x 192 x 35 x 35 - x = self.Mixed_5b(x) - # N x 256 x 35 x 35 - x = self.Mixed_5c(x) - # N x 288 x 35 x 35 - x = self.Mixed_5d(x) - # N x 288 x 35 x 35 - x = self.Mixed_6a(x) - # N x 768 x 17 x 17 - x = self.Mixed_6b(x) - # N x 768 x 17 x 17 - x = self.Mixed_6c(x) - # N x 768 x 17 x 17 - x = self.Mixed_6d(x) - # N x 768 x 17 x 17 - x = self.Mixed_6e(x) - # N x 768 x 17 x 17 - aux = self.AuxLogits(x) if self.training else None - # N x 768 x 17 x 17 - x = self.Mixed_7a(x) - # N x 1280 x 8 x 8 - x = self.Mixed_7b(x) - # N x 2048 x 8 x 8 - x = self.Mixed_7c(x) - # N x 2048 x 8 x 8 - return x, aux - - def get_classifier(self): - return self.fc - - def reset_classifier(self, num_classes, global_pool='avg'): - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.num_classes = num_classes - if self.num_classes > 0: - self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) - else: - self.fc = nn.Identity() - - def forward(self, x): - x, aux = self.forward_features(x) - x = self.global_pool(x).flatten(1) - if self.drop_rate > 0: - x = F.dropout(x, p=self.drop_rate, training=self.training) - x = self.fc(x) - return x, aux - - -class InceptionV3(nn.Module): - """Inception-V3 with no AuxLogits - FIXME two class defs are redundant, but less screwing around with torchsript fussyness and inconsistent returns - """ - - def __init__(self, inception_blocks=None, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'): - super(InceptionV3, self).__init__() - self.num_classes = num_classes - self.drop_rate = drop_rate - - if inception_blocks is None: - inception_blocks = [ - BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE] - assert 
len(inception_blocks) >= 6 - conv_block = inception_blocks[0] - inception_a = inception_blocks[1] - inception_b = inception_blocks[2] - inception_c = inception_blocks[3] - inception_d = inception_blocks[4] - inception_e = inception_blocks[5] - - self.Conv2d_1a_3x3 = conv_block(in_chans, 32, kernel_size=3, stride=2) - self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3) - self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1) - self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1) - self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3) - self.Mixed_5b = inception_a(192, pool_features=32) - self.Mixed_5c = inception_a(256, pool_features=64) - self.Mixed_5d = inception_a(288, pool_features=64) - self.Mixed_6a = inception_b(288) - self.Mixed_6b = inception_c(768, channels_7x7=128) - self.Mixed_6c = inception_c(768, channels_7x7=160) - self.Mixed_6d = inception_c(768, channels_7x7=160) - self.Mixed_6e = inception_c(768, channels_7x7=192) - self.Mixed_7a = inception_d(768) - self.Mixed_7b = inception_e(1280) - self.Mixed_7c = inception_e(2048) - - self.num_features = 2048 - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.fc = nn.Linear(2048, num_classes) - - for m in self.modules(): - if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): - stddev = m.stddev if hasattr(m, 'stddev') else 0.1 - trunc_normal_(m.weight, std=stddev) - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - def forward_features(self, x): - # N x 3 x 299 x 299 - x = self.Conv2d_1a_3x3(x) - # N x 32 x 149 x 149 - x = self.Conv2d_2a_3x3(x) - # N x 32 x 147 x 147 - x = self.Conv2d_2b_3x3(x) - # N x 64 x 147 x 147 - x = F.max_pool2d(x, kernel_size=3, stride=2) - # N x 64 x 73 x 73 - x = self.Conv2d_3b_1x1(x) - # N x 80 x 73 x 73 - x = self.Conv2d_4a_3x3(x) - # N x 192 x 71 x 71 - x = F.max_pool2d(x, kernel_size=3, stride=2) - # N x 192 x 35 x 35 - x = self.Mixed_5b(x) - # N x 256 x 35 x 35 - x = self.Mixed_5c(x) - # N x 288 x 35 x 35 - x = self.Mixed_5d(x) - # N x 288 x 35 x 35 - x = self.Mixed_6a(x) - # N x 768 x 17 x 17 - x = self.Mixed_6b(x) - # N x 768 x 17 x 17 - x = self.Mixed_6c(x) - # N x 768 x 17 x 17 - x = self.Mixed_6d(x) - # N x 768 x 17 x 17 - x = self.Mixed_6e(x) - # N x 768 x 17 x 17 - x = self.Mixed_7a(x) - # N x 1280 x 8 x 8 - x = self.Mixed_7b(x) - # N x 2048 x 8 x 8 - x = self.Mixed_7c(x) - # N x 2048 x 8 x 8 - return x - - def get_classifier(self): - return self.fc - - def reset_classifier(self, num_classes, global_pool='avg'): - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.num_classes = num_classes - if self.num_classes > 0: - self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) - else: - self.fc = nn.Identity() - - def forward(self, x): - x = self.forward_features(x) - x = self.global_pool(x).flatten(1) - if self.drop_rate > 0: - x = F.dropout(x, p=self.drop_rate, training=self.training) - x = self.fc(x) - return x - - class InceptionA(nn.Module): def __init__(self, in_channels, pool_features, conv_block=None): @@ -504,26 +279,163 @@ class BasicConv2d(nn.Module): return F.relu(x, inplace=True) +class InceptionV3(nn.Module): + """Inception-V3 with no AuxLogits + FIXME two class defs are redundant, but less screwing around with torchsript fussyness and inconsistent returns + """ + + def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=False): + super(InceptionV3, self).__init__() + self.num_classes = num_classes + 
self.drop_rate = drop_rate + self.aux_logits = aux_logits + + self.Conv2d_1a_3x3 = BasicConv2d(in_chans, 32, kernel_size=3, stride=2) + self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) + self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) + self.Pool1 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) + self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) + self.Pool2 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Mixed_5b = InceptionA(192, pool_features=32) + self.Mixed_5c = InceptionA(256, pool_features=64) + self.Mixed_5d = InceptionA(288, pool_features=64) + self.Mixed_6a = InceptionB(288) + self.Mixed_6b = InceptionC(768, channels_7x7=128) + self.Mixed_6c = InceptionC(768, channels_7x7=160) + self.Mixed_6d = InceptionC(768, channels_7x7=160) + self.Mixed_6e = InceptionC(768, channels_7x7=192) + if aux_logits: + self.AuxLogits = InceptionAux(768, num_classes) + else: + self.AuxLogits = None + self.Mixed_7a = InceptionD(768) + self.Mixed_7b = InceptionE(1280) + self.Mixed_7c = InceptionE(2048) + self.feature_info = [ + dict(num_chs=64, reduction=2, module='Conv2d_2b_3x3'), + dict(num_chs=192, reduction=4, module='Conv2d_4a_3x3'), + dict(num_chs=288, reduction=8, module='Mixed_5d'), + dict(num_chs=768, reduction=16, module='Mixed_6e'), + dict(num_chs=2048, reduction=32, module='Mixed_7c'), + ] + + self.num_features = 2048 + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.fc = nn.Linear(2048, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + stddev = m.stddev if hasattr(m, 'stddev') else 0.1 + trunc_normal_(m.weight, std=stddev) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward_preaux(self, x): + # N x 3 x 299 x 299 + x = self.Conv2d_1a_3x3(x) + # N x 32 x 149 x 149 + x = self.Conv2d_2a_3x3(x) + # N x 32 x 147 x 147 + x = self.Conv2d_2b_3x3(x) + # N x 64 x 147 x 147 + x = self.Pool1(x) + # N x 64 x 73 x 73 + x = self.Conv2d_3b_1x1(x) + # N x 80 x 73 x 73 + x = self.Conv2d_4a_3x3(x) + # N x 192 x 71 x 71 + x = self.Pool2(x) + # N x 192 x 35 x 35 + x = self.Mixed_5b(x) + # N x 256 x 35 x 35 + x = self.Mixed_5c(x) + # N x 288 x 35 x 35 + x = self.Mixed_5d(x) + # N x 288 x 35 x 35 + x = self.Mixed_6a(x) + # N x 768 x 17 x 17 + x = self.Mixed_6b(x) + # N x 768 x 17 x 17 + x = self.Mixed_6c(x) + # N x 768 x 17 x 17 + x = self.Mixed_6d(x) + # N x 768 x 17 x 17 + x = self.Mixed_6e(x) + # N x 768 x 17 x 17 + return x + + def forward_postaux(self, x): + x = self.Mixed_7a(x) + # N x 1280 x 8 x 8 + x = self.Mixed_7b(x) + # N x 2048 x 8 x 8 + x = self.Mixed_7c(x) + # N x 2048 x 8 x 8 + return x + + def forward_features(self, x): + x = self.forward_preaux(x) + x = self.forward_postaux(x) + return x + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.num_classes = num_classes + if self.num_classes > 0: + self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) + else: + self.fc = nn.Identity() + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x).flatten(1) + if self.drop_rate > 0: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.fc(x) + return x + + +class InceptionV3Aux(InceptionV3): + """InceptionV3 with AuxLogits + """ + + def __init__(self, num_classes=1000, in_chans=3, 
drop_rate=0., global_pool='avg', aux_logits=True): + super(InceptionV3Aux, self).__init__( + num_classes, in_chans, drop_rate, global_pool, aux_logits) + + def forward_features(self, x): + x = self.forward_preaux(x) + aux = self.AuxLogits(x) if self.training else None + x = self.forward_postaux(x) + return x, aux + + def forward(self, x): + x, aux = self.forward_features(x) + x = self.global_pool(x).flatten(1) + if self.drop_rate > 0: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.fc(x) + return x, aux + + def _create_inception_v3(variant, pretrained=False, **kwargs): - assert not kwargs.pop('features_only', False) default_cfg = default_cfgs[variant] aux_logits = kwargs.pop('aux_logits', False) if aux_logits: - model_class = InceptionV3Aux + assert not kwargs.pop('features_only', False) + model_cls = InceptionV3Aux load_strict = default_cfg['has_aux'] else: - model_class = InceptionV3 + model_cls = InceptionV3 load_strict = not default_cfg['has_aux'] - - model = model_class(**kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained( - model, - num_classes=kwargs.get('num_classes', 0), - in_chans=kwargs.get('in_chans', 3), - strict=load_strict) - return model + return build_model_with_cfg( + model_cls, variant, pretrained, default_cfg=default_cfgs[variant], + pretrained_strict=load_strict, **kwargs) @register_model diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index 52b5ef47..d74354bd 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -7,7 +7,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import load_pretrained +from .helpers import build_model_with_cfg from .layers import SelectAdaptivePool2d from .registry import register_model @@ -39,9 +39,9 @@ class BasicConv2d(nn.Module): return x -class Mixed_3a(nn.Module): +class Mixed3a(nn.Module): def __init__(self): - super(Mixed_3a, self).__init__() + super(Mixed3a, self).__init__() self.maxpool = nn.MaxPool2d(3, stride=2) self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2) @@ -52,9 +52,9 @@ class Mixed_3a(nn.Module): return out -class Mixed_4a(nn.Module): +class Mixed4a(nn.Module): def __init__(self): - super(Mixed_4a, self).__init__() + super(Mixed4a, self).__init__() self.branch0 = nn.Sequential( BasicConv2d(160, 64, kernel_size=1, stride=1), @@ -75,9 +75,9 @@ class Mixed_4a(nn.Module): return out -class Mixed_5a(nn.Module): +class Mixed5a(nn.Module): def __init__(self): - super(Mixed_5a, self).__init__() + super(Mixed5a, self).__init__() self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2) self.maxpool = nn.MaxPool2d(3, stride=2) @@ -88,9 +88,9 @@ class Mixed_5a(nn.Module): return out -class Inception_A(nn.Module): +class InceptionA(nn.Module): def __init__(self): - super(Inception_A, self).__init__() + super(InceptionA, self).__init__() self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1) self.branch1 = nn.Sequential( @@ -118,9 +118,9 @@ class Inception_A(nn.Module): return out -class Reduction_A(nn.Module): +class ReductionA(nn.Module): def __init__(self): - super(Reduction_A, self).__init__() + super(ReductionA, self).__init__() self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2) self.branch1 = nn.Sequential( @@ -139,9 +139,9 @@ class Reduction_A(nn.Module): return out -class Inception_B(nn.Module): +class InceptionB(nn.Module): def __init__(self): - super(Inception_B, self).__init__() + super(InceptionB, 
self).__init__() self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1) self.branch1 = nn.Sequential( @@ -172,9 +172,9 @@ class Inception_B(nn.Module): return out -class Reduction_B(nn.Module): +class ReductionB(nn.Module): def __init__(self): - super(Reduction_B, self).__init__() + super(ReductionB, self).__init__() self.branch0 = nn.Sequential( BasicConv2d(1024, 192, kernel_size=1, stride=1), @@ -198,9 +198,9 @@ class Reduction_B(nn.Module): return out -class Inception_C(nn.Module): +class InceptionC(nn.Module): def __init__(self): - super(Inception_C, self).__init__() + super(InceptionC, self).__init__() self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1) @@ -241,8 +241,9 @@ class Inception_C(nn.Module): class InceptionV4(nn.Module): - def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'): + def __init__(self, num_classes=1001, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'): super(InceptionV4, self).__init__() + assert output_stride == 32 self.drop_rate = drop_rate self.num_classes = num_classes self.num_features = 1536 @@ -251,26 +252,33 @@ class InceptionV4(nn.Module): BasicConv2d(in_chans, 32, kernel_size=3, stride=2), BasicConv2d(32, 32, kernel_size=3, stride=1), BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1), - Mixed_3a(), - Mixed_4a(), - Mixed_5a(), - Inception_A(), - Inception_A(), - Inception_A(), - Inception_A(), - Reduction_A(), # Mixed_6a - Inception_B(), - Inception_B(), - Inception_B(), - Inception_B(), - Inception_B(), - Inception_B(), - Inception_B(), - Reduction_B(), # Mixed_7a - Inception_C(), - Inception_C(), - Inception_C(), + Mixed3a(), + Mixed4a(), + Mixed5a(), + InceptionA(), + InceptionA(), + InceptionA(), + InceptionA(), + ReductionA(), # Mixed6a + InceptionB(), + InceptionB(), + InceptionB(), + InceptionB(), + InceptionB(), + InceptionB(), + InceptionB(), + ReductionB(), # Mixed7a + InceptionC(), + InceptionC(), + InceptionC(), ) + self.feature_info = [ + dict(num_chs=64, reduction=2, module='features.2'), + dict(num_chs=160, reduction=4, module='features.3'), + dict(num_chs=384, reduction=8, module='features.9'), + dict(num_chs=1024, reduction=16, module='features.17'), + dict(num_chs=1536, reduction=32, module='features.21'), + ] self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) @@ -298,11 +306,12 @@ class InceptionV4(nn.Module): return x +def _create_inception_v4(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + InceptionV4, variant, pretrained, default_cfg=default_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), **kwargs) + + @register_model -def inception_v4(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['inception_v4'] - model = InceptionV4(num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model +def inception_v4(pretrained=False, **kwargs): + return _create_inception_v4('inception_v4', pretrained, **kwargs) diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 24e3f2a8..27c59ecd 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -1,8 +1,11 @@ +""" + +""" import torch import torch.nn as nn import torch.nn.functional as F -from .helpers import load_pretrained +from .helpers import build_model_with_cfg from .layers import SelectAdaptivePool2d, ConvBnAct, 
create_conv2d, create_pool2d from .registry import register_model @@ -484,8 +487,15 @@ class NASNetALarge(nn.Module): self.cell_17 = NormalCell( in_chs_left=24 * channels, out_chs_left=4 * channels, in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) - self.act = nn.ReLU(inplace=True) + self.feature_info = [ + dict(num_chs=96, reduction=2, module='conv0'), + dict(num_chs=168, reduction=4, module='cell_stem_1.conv_1x1.act'), + dict(num_chs=1008, reduction=8, module='reduction_cell_0.conv_1x1.act'), + dict(num_chs=2016, reduction=16, module='reduction_cell_1.conv_1x1.act'), + dict(num_chs=4032, reduction=32, module='act'), + ] + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) @@ -503,11 +513,9 @@ class NASNetALarge(nn.Module): def forward_features(self, x): x_conv0 = self.conv0(x) - #0 x_stem_0 = self.cell_stem_0(x_conv0) x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0) - #1 x_cell_0 = self.cell_0(x_stem_1, x_stem_0) x_cell_1 = self.cell_1(x_cell_0, x_stem_1) @@ -515,7 +523,6 @@ class NASNetALarge(nn.Module): x_cell_3 = self.cell_3(x_cell_2, x_cell_1) x_cell_4 = self.cell_4(x_cell_3, x_cell_2) x_cell_5 = self.cell_5(x_cell_4, x_cell_3) - #2 x_reduction_cell_0 = self.reduction_cell_0(x_cell_5, x_cell_4) x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_4) @@ -524,7 +531,6 @@ class NASNetALarge(nn.Module): x_cell_9 = self.cell_9(x_cell_8, x_cell_7) x_cell_10 = self.cell_10(x_cell_9, x_cell_8) x_cell_11 = self.cell_11(x_cell_10, x_cell_9) - #3 x_reduction_cell_1 = self.reduction_cell_1(x_cell_11, x_cell_10) x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_10) @@ -534,8 +540,6 @@ class NASNetALarge(nn.Module): x_cell_16 = self.cell_16(x_cell_15, x_cell_14) x_cell_17 = self.cell_17(x_cell_16, x_cell_15) x = self.act(x_cell_17) - #4 - return x def forward(self, x): @@ -547,14 +551,16 @@ class NASNetALarge(nn.Module): return x +def _create_nasnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + NASNetALarge, variant, pretrained, default_cfg=default_cfgs[variant], + feature_cfg=dict(feature_cls='hook'), # not possible to re-write this model, must use FeatureHookNet + **kwargs) + + @register_model -def nasnetalarge(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def nasnetalarge(pretrained=False, **kwargs): """NASNet-A large model architecture. 
""" - default_cfg = default_cfgs['nasnetalarge'] - model = NASNetALarge(num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - - return model + model_kwargs = dict(pad_type='same', **kwargs) + return _create_nasnet('nasnetalarge', pretrained, **model_kwargs) diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index fb2eb0dd..e5f3b6d5 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -5,15 +5,13 @@ https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/pnasnet.py """ -from __future__ import print_function, division, absolute_import - from collections import OrderedDict import torch import torch.nn as nn import torch.nn.functional as F -from .helpers import load_pretrained +from .helpers import build_model_with_cfg from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d, create_pool2d from .registry import register_model @@ -147,35 +145,35 @@ class CellBase(nn.Module): class CellStem0(CellBase): - def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, padding=''): + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): super(CellStem0, self).__init__() - self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=padding) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type) self.comb_iter_0_left = BranchSeparables( - in_chs_left, out_chs_left, kernel_size=5, stride=2, stem_cell=True, padding=padding) + in_chs_left, out_chs_left, kernel_size=5, stride=2, stem_cell=True, padding=pad_type) self.comb_iter_0_right = nn.Sequential(OrderedDict([ - ('max_pool', create_pool2d('max', 3, stride=2, padding=padding)), - ('conv', create_conv2d(in_chs_left, out_chs_left, kernel_size=1, padding=padding)), + ('max_pool', create_pool2d('max', 3, stride=2, padding=pad_type)), + ('conv', create_conv2d(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type)), ('bn', nn.BatchNorm2d(out_chs_left, eps=0.001)), ])) self.comb_iter_1_left = BranchSeparables( - out_chs_right, out_chs_right, kernel_size=7, stride=2, padding=padding) - self.comb_iter_1_right = create_pool2d('max', 3, stride=2, padding=padding) + out_chs_right, out_chs_right, kernel_size=7, stride=2, padding=pad_type) + self.comb_iter_1_right = create_pool2d('max', 3, stride=2, padding=pad_type) self.comb_iter_2_left = BranchSeparables( - out_chs_right, out_chs_right, kernel_size=5, stride=2, padding=padding) + out_chs_right, out_chs_right, kernel_size=5, stride=2, padding=pad_type) self.comb_iter_2_right = BranchSeparables( - out_chs_right, out_chs_right, kernel_size=3, stride=2, padding=padding) + out_chs_right, out_chs_right, kernel_size=3, stride=2, padding=pad_type) self.comb_iter_3_left = BranchSeparables( - out_chs_right, out_chs_right, kernel_size=3, padding=padding) - self.comb_iter_3_right = create_pool2d('max', 3, stride=2, padding=padding) + out_chs_right, out_chs_right, kernel_size=3, padding=pad_type) + self.comb_iter_3_right = create_pool2d('max', 3, stride=2, padding=pad_type) self.comb_iter_4_left = BranchSeparables( - in_chs_right, out_chs_right, kernel_size=3, stride=2, stem_cell=True, padding=padding) + in_chs_right, out_chs_right, kernel_size=3, stride=2, stem_cell=True, padding=pad_type) self.comb_iter_4_right = ActConvBn( - out_chs_right, out_chs_right, kernel_size=1, stride=2, padding=padding) + out_chs_right, out_chs_right, kernel_size=1, 
stride=2, padding=pad_type) def forward(self, x_left): x_right = self.conv_1x1(x_left) @@ -185,12 +183,12 @@ class CellStem0(CellBase): class Cell(CellBase): - def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, padding='', + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type='', is_reduction=False, match_prev_layer_dims=False): super(Cell, self).__init__() # If `is_reduction` is set to `True` stride 2 is used for - # convolutional and pooling layers to reduce the spatial size of + # convolution and pooling layers to reduce the spatial size of # the output of a cell approximately by a factor of 2. stride = 2 if is_reduction else 1 @@ -199,32 +197,32 @@ class Cell(CellBase): # of the left input of a cell approximately by a factor of 2. self.match_prev_layer_dimensions = match_prev_layer_dims if match_prev_layer_dims: - self.conv_prev_1x1 = FactorizedReduction(in_chs_left, out_chs_left, padding=padding) + self.conv_prev_1x1 = FactorizedReduction(in_chs_left, out_chs_left, padding=pad_type) else: - self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, kernel_size=1, padding=padding) - self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=padding) + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type) self.comb_iter_0_left = BranchSeparables( - out_chs_left, out_chs_left, kernel_size=5, stride=stride, padding=padding) - self.comb_iter_0_right = create_pool2d('max', 3, stride=stride, padding=padding) + out_chs_left, out_chs_left, kernel_size=5, stride=stride, padding=pad_type) + self.comb_iter_0_right = create_pool2d('max', 3, stride=stride, padding=pad_type) self.comb_iter_1_left = BranchSeparables( - out_chs_right, out_chs_right, kernel_size=7, stride=stride, padding=padding) - self.comb_iter_1_right = create_pool2d('max', 3, stride=stride, padding=padding) + out_chs_right, out_chs_right, kernel_size=7, stride=stride, padding=pad_type) + self.comb_iter_1_right = create_pool2d('max', 3, stride=stride, padding=pad_type) self.comb_iter_2_left = BranchSeparables( - out_chs_right, out_chs_right, kernel_size=5, stride=stride, padding=padding) + out_chs_right, out_chs_right, kernel_size=5, stride=stride, padding=pad_type) self.comb_iter_2_right = BranchSeparables( - out_chs_right, out_chs_right, kernel_size=3, stride=stride, padding=padding) + out_chs_right, out_chs_right, kernel_size=3, stride=stride, padding=pad_type) self.comb_iter_3_left = BranchSeparables(out_chs_right, out_chs_right, kernel_size=3) - self.comb_iter_3_right = create_pool2d('max', 3, stride=stride, padding=padding) + self.comb_iter_3_right = create_pool2d('max', 3, stride=stride, padding=pad_type) self.comb_iter_4_left = BranchSeparables( - out_chs_left, out_chs_left, kernel_size=3, stride=stride, padding=padding) + out_chs_left, out_chs_left, kernel_size=3, stride=stride, padding=pad_type) if is_reduction: self.comb_iter_4_right = ActConvBn( - out_chs_right, out_chs_right, kernel_size=1, stride=stride, padding=padding) + out_chs_right, out_chs_right, kernel_size=1, stride=stride, padding=pad_type) else: self.comb_iter_4_right = None @@ -236,7 +234,7 @@ class Cell(CellBase): class PNASNet5Large(nn.Module): - def __init__(self, num_classes=1001, in_chans=3, output_stride=32, drop_rate=0.5, global_pool='avg', padding=''): + def __init__(self, num_classes=1001, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg', 
pad_type=''): super(PNASNet5Large, self).__init__() self.num_classes = num_classes self.drop_rate = drop_rate @@ -248,43 +246,51 @@ class PNASNet5Large(nn.Module): norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None) self.cell_stem_0 = CellStem0( - in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, padding=padding) + in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, pad_type=pad_type) self.cell_stem_1 = Cell( - in_chs_left=96, out_chs_left=108, in_chs_right=270, out_chs_right=108, padding=padding, + in_chs_left=96, out_chs_left=108, in_chs_right=270, out_chs_right=108, pad_type=pad_type, match_prev_layer_dims=True, is_reduction=True) self.cell_0 = Cell( - in_chs_left=270, out_chs_left=216, in_chs_right=540, out_chs_right=216, padding=padding, + in_chs_left=270, out_chs_left=216, in_chs_right=540, out_chs_right=216, pad_type=pad_type, match_prev_layer_dims=True) self.cell_1 = Cell( - in_chs_left=540, out_chs_left=216, in_chs_right=1080, out_chs_right=216, padding=padding) + in_chs_left=540, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type) self.cell_2 = Cell( - in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, padding=padding) + in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type) self.cell_3 = Cell( - in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, padding=padding) + in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type) self.cell_4 = Cell( - in_chs_left=1080, out_chs_left=432, in_chs_right=1080, out_chs_right=432, padding=padding, + in_chs_left=1080, out_chs_left=432, in_chs_right=1080, out_chs_right=432, pad_type=pad_type, is_reduction=True) self.cell_5 = Cell( - in_chs_left=1080, out_chs_left=432, in_chs_right=2160, out_chs_right=432, padding=padding, + in_chs_left=1080, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type, match_prev_layer_dims=True) self.cell_6 = Cell( - in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, padding=padding) + in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type) self.cell_7 = Cell( - in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, padding=padding) + in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type) self.cell_8 = Cell( - in_chs_left=2160, out_chs_left=864, in_chs_right=2160, out_chs_right=864, padding=padding, + in_chs_left=2160, out_chs_left=864, in_chs_right=2160, out_chs_right=864, pad_type=pad_type, is_reduction=True) self.cell_9 = Cell( - in_chs_left=2160, out_chs_left=864, in_chs_right=4320, out_chs_right=864, padding=padding, + in_chs_left=2160, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type, match_prev_layer_dims=True) self.cell_10 = Cell( - in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, padding=padding) + in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type) self.cell_11 = Cell( - in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, padding=padding) - self.relu = nn.ReLU() + in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type) + self.act = nn.ReLU() + self.feature_info = [ + dict(num_chs=96, reduction=2, module='conv_0'), + dict(num_chs=270, reduction=4, module='cell_stem_1.conv_1x1.act'), + dict(num_chs=1080, reduction=8, module='cell_4.conv_1x1.act'), + 
dict(num_chs=2160, reduction=16, module='cell_8.conv_1x1.act'), + dict(num_chs=4320, reduction=32, module='act'), + ] + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) @@ -316,7 +322,7 @@ class PNASNet5Large(nn.Module): x_cell_9 = self.cell_9(x_cell_7, x_cell_8) x_cell_10 = self.cell_10(x_cell_8, x_cell_9) x_cell_11 = self.cell_11(x_cell_9, x_cell_10) - x = self.relu(x_cell_11) + x = self.act(x_cell_11) return x def forward(self, x): @@ -328,16 +334,18 @@ class PNASNet5Large(nn.Module): return x +def _create_pnasnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + PNASNet5Large, variant, pretrained, default_cfg=default_cfgs[variant], + feature_cfg=dict(feature_cls='hook'), # not possible to re-write this model, must use FeatureHookNet + **kwargs) + + @register_model -def pnasnet5large(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def pnasnet5large(pretrained=False, **kwargs): r"""PNASNet-5 model architecture from the `"Progressive Neural Architecture Search" `_ paper. """ - default_cfg = default_cfgs['pnasnet5large'] - model = PNASNet5Large(num_classes=num_classes, in_chans=in_chans, padding='same', **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - - return model + model_kwargs = dict(pad_type='same', **kwargs) + return _create_pnasnet('pnasnet5large', pretrained, **model_kwargs) diff --git a/timm/models/xception.py b/timm/models/xception.py index 8bf62624..28a78344 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -154,6 +154,13 @@ class Xception(nn.Module): self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1) self.bn4 = nn.BatchNorm2d(self.num_features) self.act4 = nn.ReLU(inplace=True) + self.feature_info = [ + dict(num_chs=64, reduction=2, module='act2'), + dict(num_chs=128, reduction=4, module='block2.rep.0'), + dict(num_chs=256, reduction=8, module='block3.rep.0'), + dict(num_chs=728, reduction=16, module='block12.rep.0'), + dict(num_chs=2048, reduction=32, module='act4'), + ] self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) @@ -221,7 +228,7 @@ class Xception(nn.Module): def _xception(variant, pretrained=False, **kwargs): return build_model_with_cfg( Xception, variant, pretrained, default_cfg=default_cfgs[variant], - feature_cfg=dict(), **kwargs) + feature_cfg=dict(use_hooks=True), **kwargs) @register_model
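With every entrypoint now routed through build_model_with_cfg, the converted models can be constructed directly as feature backbones. A minimal usage sketch follows, assuming a timm checkout that contains this patch so that create_model forwards features_only (and the per-model feature_cfg shown above) into build_model_with_cfg:

```python
import torch
import timm

# Sketch only: exercise the feature-extraction path enabled by this patch.
# 'dla60' takes the module-rewrite path (FeatureNet, out_indices=(1, 2, 3, 4, 5));
# 'inception_v4' does the same with flatten_sequential=True.
for name in ('dla60', 'inception_v4'):
    model = timm.create_model(name, pretrained=False, features_only=True)
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    # One tensor per selected feature_info entry (num_chs / reduction / module).
    print(name, [tuple(f.shape) for f in feats])
```

The hook-based models (nasnetalarge, pnasnet5large, the Xceptions) follow the same call pattern but are wrapped by FeatureHookNet (feature_cls='hook' / use_hooks above), since their forward_features cannot be rewritten as a flat module sequence.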