From a7ebe090291a218ad781f7239e5ca649ccd6f35b Mon Sep 17 00:00:00 2001 From: Vyacheslav Shults Date: Wed, 6 May 2020 09:54:03 +0300 Subject: [PATCH] Replace None with nn.Identity() in all models' reset_classifier when a falsy num_classes is given. Minor code refactoring. With a falsy num_classes, the classifier module becomes nn.Identity() rather than None, so forward() passes pooled features through unchanged instead of calling a None attribute (see the usage sketch after the diff). --- timm/models/densenet.py | 15 +++-- timm/models/dla.py | 13 ++-- timm/models/dpn.py | 18 ++--- timm/models/efficientnet.py | 34 +++++----- timm/models/gluon_resnet.py | 14 ++-- timm/models/gluon_xception.py | 20 +++--- timm/models/inception_resnet_v2.py | 11 +-- timm/models/inception_v3.py | 5 +- timm/models/inception_v4.py | 14 ++-- timm/models/mobilenetv3.py | 18 ++--- timm/models/nasnet.py | 105 +++++++++++++++-------------- timm/models/pnasnet.py | 9 +-- timm/models/res2net.py | 10 ++- timm/models/resnet.py | 12 ++-- timm/models/selecsls.py | 11 ++- timm/models/senet.py | 12 ++-- timm/models/tresnet.py | 14 ++-- timm/models/xception.py | 11 +-- 18 files changed, 179 insertions(+), 167 deletions(-) diff --git a/timm/models/densenet.py b/timm/models/densenet.py index 4235c0f7..c8be8683 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -2,17 +2,17 @@ This file is a copy of https://github.com/pytorch/vision 'densenet.py' (BSD-3-Clause) with fixed kwargs passthrough and addition of dynamic global avg/max pool. """ +import re from collections import OrderedDict import torch import torch.nn as nn import torch.nn.functional as F -from .registry import register_model +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import load_pretrained from .layers import SelectAdaptivePool2d -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -import re +from .registry import register_model __all__ = ['DenseNet'] @@ -85,6 +85,7 @@ class DenseNet(nn.Module): drop_rate (float) - dropout rate after each dense layer num_classes (int) - number of classification classes """ + def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, in_chans=3, global_pool='avg'): @@ -127,8 +128,11 @@ class DenseNet(nn.Module): def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.classifier = nn.Linear( - self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None + if num_classes: + num_features = self.num_features * self.global_pool.feat_mult() + self.classifier = nn.Linear(num_features, num_classes) + else: + self.classifier = nn.Identity() def forward_features(self, x): x = self.features(x) @@ -157,7 +161,6 @@ def _filter_pretrained(state_dict): return state_dict - @register_model def densenet121(pretrained=False, num_classes=1000, in_chans=3, **kwargs): r"""Densenet-121 model from diff --git a/timm/models/dla.py b/timm/models/dla.py index a9e81d16..f6820ab9 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -11,11 +11,10 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .registry import register_model +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import load_pretrained from .layers import SelectAdaptivePool2d -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD - +from .registry import register_model __all__ = ['DLA'] @@ -51,6 +50,7 @@ default_cfgs = { class DlaBasic(nn.Module): """DLA Basic""" + def __init__(self, inplanes, planes, stride=1, dilation=1, **_): super(DlaBasic, self).__init__()
self.conv1 = nn.Conv2d( @@ -170,7 +170,7 @@ class DlaBottle2neck(nn.Module): sp = bn(sp) sp = self.relu(sp) spo.append(sp) - if self.scale > 1 : + if self.scale > 1: spo.append(self.pool(spx[-1]) if self.is_first else spx[-1]) out = torch.cat(spo, 1) @@ -304,9 +304,10 @@ class DLA(nn.Module): self.num_classes = num_classes self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) if num_classes: - self.fc = nn.Conv2d(self.num_features * self.global_pool.feat_mult(), num_classes, 1, bias=True) + num_features = self.num_features * self.global_pool.feat_mult() + self.fc = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True) else: - self.fc = None + self.fc = nn.Identity() def forward_features(self, x): x = self.base_layer(x) diff --git a/timm/models/dpn.py b/timm/models/dpn.py index fd58e516..9c4fafc8 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -9,16 +9,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from collections import OrderedDict + import torch import torch.nn as nn import torch.nn.functional as F -from collections import OrderedDict -from .registry import register_model +from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD from .helpers import load_pretrained from .layers import SelectAdaptivePool2d -from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD - +from .registry import register_model __all__ = ['DPN'] @@ -218,8 +218,8 @@ class DPN(nn.Module): # Using 1x1 conv for the FC layer to allow the extra pooling scheme self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.classifier = nn.Conv2d( - self.num_features * self.global_pool.feat_mult(), num_classes, kernel_size=1, bias=True) + num_features = self.num_features * self.global_pool.feat_mult() + self.classifier = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True) def get_classifier(self): return self.classifier @@ -228,10 +228,10 @@ class DPN(nn.Module): self.num_classes = num_classes self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) if num_classes: - self.classifier = nn.Conv2d( - self.num_features * self.global_pool.feat_mult(), num_classes, kernel_size=1, bias=True) + num_features = self.num_features * self.global_pool.feat_mult() + self.classifier = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True) else: - self.classifier = None + self.classifier = nn.Identity() def forward_features(self, x): return self.features(x) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 92460438..21fbee19 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -24,14 +24,12 @@ An implementation of EfficienNet that covers variety of related models with effi Hacked together by Ross Wightman """ +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from .efficientnet_builder import * from .feature_hooks import FeatureHooks -from .registry import register_model from .helpers import load_pretrained, adapt_model_from_file from .layers import SelectAdaptivePool2d -from timm.models.layers import create_conv2d -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD - +from .registry import register_model __all__ = ['EfficientNet'] @@ -373,8 +371,11 @@ class EfficientNet(nn.Module): def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes self.global_pool = 
SelectAdaptivePool2d(pool_type=global_pool) - self.classifier = nn.Linear( - self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None + if num_classes: + num_features = self.num_features * self.global_pool.feat_mult() + self.classifier = nn.Linear(num_features, num_classes) + else: + self.classifier = nn.Identity() def forward_features(self, x): x = self.conv_stem(x) @@ -785,13 +786,13 @@ def _gen_efficientnet_condconv( Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv """ arch_def = [ - ['ds_r1_k3_s1_e1_c16_se0.25'], - ['ir_r2_k3_s2_e6_c24_se0.25'], - ['ir_r2_k5_s2_e6_c40_se0.25'], - ['ir_r3_k3_s2_e6_c80_se0.25'], - ['ir_r3_k5_s1_e6_c112_se0.25_cc4'], - ['ir_r4_k5_s2_e6_c192_se0.25_cc4'], - ['ir_r1_k3_s1_e6_c320_se0.25_cc4'], + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25_cc4'], + ['ir_r4_k5_s2_e6_c192_se0.25_cc4'], + ['ir_r1_k3_s1_e6_c320_se0.25_cc4'], ] # NOTE unlike official impl, this one uses `cc` option where x is the base number of experts for each stage and # the expert_multiplier increases that on a per-model basis as with depth/channel multipliers @@ -1187,6 +1188,7 @@ def efficientnet_cc_b0_8e(pretrained=False, **kwargs): pretrained=pretrained, **kwargs) return model + @register_model def efficientnet_cc_b1_8e(pretrained=False, **kwargs): """ EfficientNet-CondConv-B1 w/ 8 Experts """ @@ -1242,8 +1244,6 @@ def efficientnet_lite4(pretrained=False, **kwargs): return model - - @register_model def efficientnet_b1_pruned(pretrained=False, **kwargs): """ EfficientNet-B1 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ @@ -1275,8 +1275,6 @@ def efficientnet_b3_pruned(pretrained=False, **kwargs): return model - - @register_model def tf_efficientnet_b0(pretrained=False, **kwargs): """ EfficientNet-B0. Tensorflow compatible variant """ @@ -1619,6 +1617,7 @@ def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs): pretrained=pretrained, **kwargs) return model + @register_model def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs): """ EfficientNet-CondConv-B1 w/ 8 Experts. 
Tensorflow compatible variant """ @@ -1764,4 +1763,3 @@ def tf_mixnet_l(pretrained=False, **kwargs): model = _gen_mixnet_m( 'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) return model - diff --git a/timm/models/gluon_resnet.py b/timm/models/gluon_resnet.py index 6ccc4c53..a0bc4bb2 100644 --- a/timm/models/gluon_resnet.py +++ b/timm/models/gluon_resnet.py @@ -3,17 +3,11 @@ This file evolved from https://github.com/pytorch/vision 'resnet.py' with (SE)-R and ports of Gluon variations (https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/resnet.py) by Ross Wightman """ -import math -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .registry import register_model +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import load_pretrained from .layers import SEModule -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD - +from .registry import register_model from .resnet import ResNet, Bottleneck, BasicBlock @@ -202,8 +196,8 @@ def gluon_resnet50_v1e(pretrained=False, num_classes=1000, in_chans=3, **kwargs) model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, stem_width=64, stem_type='deep', avg_down=True, **kwargs) model.default_cfg = default_cfg - #if pretrained: - # load_pretrained(model, default_cfg, num_classes, in_chans) + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) return model diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py index 2fc8e699..0a536b5f 100644 --- a/timm/models/gluon_xception.py +++ b/timm/models/gluon_xception.py @@ -6,15 +6,15 @@ Original PyTorch DeepLab impl: https://github.com/jfzhang95/pytorch-deeplab-xcep Hacked together by Ross Wightman """ -import torch +from collections import OrderedDict + import torch.nn as nn import torch.nn.functional as F -from collections import OrderedDict -from .registry import register_model +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import load_pretrained from .layers import SelectAdaptivePool2d -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .registry import register_model __all__ = ['Xception65', 'Xception71'] @@ -47,7 +47,6 @@ default_cfgs = { } } - """ PADDING NOTES The original PyTorch and Gluon impl of these models dutifully reproduced the aligned padding added to Tensorflow models for Deeplab. 
This padding was compensating @@ -223,7 +222,7 @@ class Xception65(nn.Module): norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True, is_last=True) # Middle flow - self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block( + self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block( 728, 728, num_reps=3, stride=1, dilation=middle_block_dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True)) for i in range(4, 20)])) @@ -333,7 +332,7 @@ class Xception71(nn.Module): exit_block_dilations = (2, 4) else: raise NotImplementedError - + # Entry flow self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False) self.bn1 = norm_layer(num_features=32, **norm_kwargs) @@ -394,7 +393,11 @@ class Xception71(nn.Module): def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None + if num_classes: + num_features = self.num_features * self.global_pool.feat_mult() + self.fc = nn.Linear(num_features, num_classes) + else: + self.fc = nn.Identity() def forward_features(self, x): # Entry flow @@ -465,4 +468,3 @@ def gluon_xception71(pretrained=False, num_classes=1000, in_chans=3, **kwargs): if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) return model - diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index 13ad0e9d..34b14570 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -6,10 +6,10 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .registry import register_model +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from .helpers import load_pretrained from .layers import SelectAdaptivePool2d -from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .registry import register_model __all__ = ['InceptionResnetV2'] @@ -296,8 +296,11 @@ class InceptionResnetV2(nn.Module): def reset_classifier(self, num_classes, global_pool='avg'): self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.num_classes = num_classes - self.classif = nn.Linear( - self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None + if num_classes: + num_features = self.num_features * self.global_pool.feat_mult() + self.classif = nn.Linear(num_features, num_classes) + else: + self.classif = nn.Identity() def forward_features(self, x): x = self.conv2d_1a(x) diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index a0ea784f..64d6fe75 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -1,7 +1,8 @@ from torchvision.models import Inception3 -from .registry import register_model -from .helpers import load_pretrained + from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import load_pretrained +from .registry import register_model __all__ = [] diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index 16080554..52b5ef47 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -6,10 +6,10 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .registry import register_model +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from .helpers import 
load_pretrained from .layers import SelectAdaptivePool2d -from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .registry import register_model __all__ = ['InceptionV4'] @@ -280,8 +280,11 @@ class InceptionV4(nn.Module): def reset_classifier(self, num_classes, global_pool='avg'): self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.num_classes = num_classes - self.last_linear = nn.Linear( - self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None + if num_classes: + num_features = self.num_features * self.global_pool.feat_mult() + self.last_linear = nn.Linear(num_features, num_classes) + else: + self.last_linear = nn.Identity() def forward_features(self, x): return self.features(x) @@ -303,6 +306,3 @@ def inception_v4(pretrained=False, num_classes=1000, in_chans=3, **kwargs): if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) return model - - - diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 86ca9f7a..e38884b8 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -8,13 +8,13 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 Hacked together by Ross Wightman """ +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from .efficientnet_builder import * -from .registry import register_model +from .feature_hooks import FeatureHooks from .helpers import load_pretrained from .layers import SelectAdaptivePool2d, create_conv2d from .layers.activations import HardSwish, hard_sigmoid -from .feature_hooks import FeatureHooks -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .registry import register_model __all__ = ['MobileNetV3'] @@ -76,7 +76,7 @@ class MobileNetV3(nn.Module): channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'): super(MobileNetV3, self).__init__() - + self.num_classes = num_classes self.num_features = num_features self.drop_rate = drop_rate @@ -96,7 +96,7 @@ class MobileNetV3(nn.Module): self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) self.feature_info = builder.features self._in_chs = builder.in_chs - + # Head + Pooling self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.conv_head = create_conv2d(self._in_chs, self.num_features, 1, padding=pad_type, bias=head_bias) @@ -120,8 +120,11 @@ class MobileNetV3(nn.Module): def reset_classifier(self, num_classes, global_pool='avg'): self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.num_classes = num_classes - self.classifier = nn.Linear( - self.num_features * self.global_pool.feat_mult(), num_classes) if self.num_classes else None + if num_classes: + num_features = self.num_features * self.global_pool.feat_mult() + self.classifier = nn.Linear(num_features, num_classes) + else: + self.classifier = nn.Identity() def forward_features(self, x): x = self.conv_stem(x) @@ -397,7 +400,6 @@ def mobilenetv3_small_075(pretrained=False, **kwargs): @register_model def mobilenetv3_small_100(pretrained=False, **kwargs): - print(kwargs) """ MobileNet V3 """ model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) return model diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 8847b1de..21d20032 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ 
-2,10 +2,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .registry import register_model from .helpers import load_pretrained from .layers import SelectAdaptivePool2d - +from .registry import register_model __all__ = ['NASNetALarge'] @@ -187,17 +186,17 @@ class CellStem1(nn.Module): self.stem_size = stem_size self.conv_1x1 = nn.Sequential() self.conv_1x1.add_module('relu', nn.ReLU()) - self.conv_1x1.add_module('conv', nn.Conv2d(2*self.num_channels, self.num_channels, 1, stride=1, bias=False)) + self.conv_1x1.add_module('conv', nn.Conv2d(2 * self.num_channels, self.num_channels, 1, stride=1, bias=False)) self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True)) self.relu = nn.ReLU() self.path_1 = nn.Sequential() self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) - self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels//2, 1, stride=1, bias=False)) + self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False)) self.path_2 = nn.ModuleList() self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1))) self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) - self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels//2, 1, stride=1, bias=False)) + self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False)) self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True) @@ -507,50 +506,50 @@ class NASNetALarge(nn.Module): self.cell_stem_0 = CellStem0(self.stem_size, num_channels=channels // (channel_multiplier ** 2)) self.cell_stem_1 = CellStem1(self.stem_size, num_channels=channels // channel_multiplier) - self.cell_0 = FirstCell(in_channels_left=channels, out_channels_left=channels//2, - in_channels_right=2*channels, out_channels_right=channels) - self.cell_1 = NormalCell(in_channels_left=2*channels, out_channels_left=channels, - in_channels_right=6*channels, out_channels_right=channels) - self.cell_2 = NormalCell(in_channels_left=6*channels, out_channels_left=channels, - in_channels_right=6*channels, out_channels_right=channels) - self.cell_3 = NormalCell(in_channels_left=6*channels, out_channels_left=channels, - in_channels_right=6*channels, out_channels_right=channels) - self.cell_4 = NormalCell(in_channels_left=6*channels, out_channels_left=channels, - in_channels_right=6*channels, out_channels_right=channels) - self.cell_5 = NormalCell(in_channels_left=6*channels, out_channels_left=channels, - in_channels_right=6*channels, out_channels_right=channels) - - self.reduction_cell_0 = ReductionCell0(in_channels_left=6*channels, out_channels_left=2*channels, - in_channels_right=6*channels, out_channels_right=2*channels) - - self.cell_6 = FirstCell(in_channels_left=6*channels, out_channels_left=channels, - in_channels_right=8*channels, out_channels_right=2*channels) - self.cell_7 = NormalCell(in_channels_left=8*channels, out_channels_left=2*channels, - in_channels_right=12*channels, out_channels_right=2*channels) - self.cell_8 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels, - in_channels_right=12*channels, out_channels_right=2*channels) - self.cell_9 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels, - in_channels_right=12*channels, out_channels_right=2*channels) - self.cell_10 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels, - 
in_channels_right=12*channels, out_channels_right=2*channels) - self.cell_11 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels, - in_channels_right=12*channels, out_channels_right=2*channels) - - self.reduction_cell_1 = ReductionCell1(in_channels_left=12*channels, out_channels_left=4*channels, - in_channels_right=12*channels, out_channels_right=4*channels) - - self.cell_12 = FirstCell(in_channels_left=12*channels, out_channels_left=2*channels, - in_channels_right=16*channels, out_channels_right=4*channels) - self.cell_13 = NormalCell(in_channels_left=16*channels, out_channels_left=4*channels, - in_channels_right=24*channels, out_channels_right=4*channels) - self.cell_14 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels, - in_channels_right=24*channels, out_channels_right=4*channels) - self.cell_15 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels, - in_channels_right=24*channels, out_channels_right=4*channels) - self.cell_16 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels, - in_channels_right=24*channels, out_channels_right=4*channels) - self.cell_17 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels, - in_channels_right=24*channels, out_channels_right=4*channels) + self.cell_0 = FirstCell(in_channels_left=channels, out_channels_left=channels // 2, + in_channels_right=2 * channels, out_channels_right=channels) + self.cell_1 = NormalCell(in_channels_left=2 * channels, out_channels_left=channels, + in_channels_right=6 * channels, out_channels_right=channels) + self.cell_2 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels, + in_channels_right=6 * channels, out_channels_right=channels) + self.cell_3 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels, + in_channels_right=6 * channels, out_channels_right=channels) + self.cell_4 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels, + in_channels_right=6 * channels, out_channels_right=channels) + self.cell_5 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels, + in_channels_right=6 * channels, out_channels_right=channels) + + self.reduction_cell_0 = ReductionCell0(in_channels_left=6 * channels, out_channels_left=2 * channels, + in_channels_right=6 * channels, out_channels_right=2 * channels) + + self.cell_6 = FirstCell(in_channels_left=6 * channels, out_channels_left=channels, + in_channels_right=8 * channels, out_channels_right=2 * channels) + self.cell_7 = NormalCell(in_channels_left=8 * channels, out_channels_left=2 * channels, + in_channels_right=12 * channels, out_channels_right=2 * channels) + self.cell_8 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels, + in_channels_right=12 * channels, out_channels_right=2 * channels) + self.cell_9 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels, + in_channels_right=12 * channels, out_channels_right=2 * channels) + self.cell_10 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels, + in_channels_right=12 * channels, out_channels_right=2 * channels) + self.cell_11 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels, + in_channels_right=12 * channels, out_channels_right=2 * channels) + + self.reduction_cell_1 = ReductionCell1(in_channels_left=12 * channels, out_channels_left=4 * channels, + in_channels_right=12 * channels, out_channels_right=4 * channels) + + self.cell_12 = FirstCell(in_channels_left=12 * channels, 
out_channels_left=2 * channels, + in_channels_right=16 * channels, out_channels_right=4 * channels) + self.cell_13 = NormalCell(in_channels_left=16 * channels, out_channels_left=4 * channels, + in_channels_right=24 * channels, out_channels_right=4 * channels) + self.cell_14 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels, + in_channels_right=24 * channels, out_channels_right=4 * channels) + self.cell_15 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels, + in_channels_right=24 * channels, out_channels_right=4 * channels) + self.cell_16 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels, + in_channels_right=24 * channels, out_channels_right=4 * channels) + self.cell_17 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels, + in_channels_right=24 * channels, out_channels_right=4 * channels) self.relu = nn.ReLU() self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) @@ -562,9 +561,11 @@ class NASNetALarge(nn.Module): def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - del self.last_linear - self.last_linear = nn.Linear( - self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None + if num_classes: + num_features = self.num_features * self.global_pool.feat_mult() + self.last_linear = nn.Linear(num_features, num_classes) + else: + self.last_linear = nn.Identity() def forward_features(self, x): x_conv0 = self.conv0(x) diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 64d83e3c..97c2f86d 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -6,15 +6,16 @@ """ from __future__ import print_function, division, absolute_import + from collections import OrderedDict import torch import torch.nn as nn import torch.nn.functional as F -from .registry import register_model from .helpers import load_pretrained from .layers import SelectAdaptivePool2d +from .registry import register_model __all__ = ['PNASNet5Large'] @@ -349,11 +350,11 @@ class PNASNet5Large(nn.Module): def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - del self.last_linear if num_classes: - self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) + num_features = self.num_features * self.global_pool.feat_mult() + self.last_linear = nn.Linear(num_features, num_classes) else: - self.last_linear = None + self.last_linear = nn.Identity() def forward_features(self, x): x_conv_0 = self.conv_0(x) diff --git a/timm/models/res2net.py b/timm/models/res2net.py index 8655776c..3e3882fe 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -6,13 +6,11 @@ import math import torch import torch.nn as nn -import torch.nn.functional as F -from .resnet import ResNet -from .registry import register_model -from .helpers import load_pretrained -from .layers import SEModule from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import load_pretrained +from .registry import register_model +from .resnet import ResNet __all__ = [] @@ -105,7 +103,7 @@ class Bottle2neck(nn.Module): sp = bn(sp) sp = self.relu(sp) spo.append(sp) - if self.scale > 1 : + if self.scale > 1: spo.append(self.pool(spx[-1]) if self.is_first else spx[-1]) out = torch.cat(spo, 1) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 
4e865705..430bbb49 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -10,10 +10,10 @@ import math import torch.nn as nn import torch.nn.functional as F -from .registry import register_model +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import load_pretrained, adapt_model_from_file from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .registry import register_model __all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this @@ -377,6 +377,7 @@ class ResNet(nn.Module): global_pool : str, default 'avg' Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' """ + def __init__(self, block, layers, num_classes=1000, in_chans=3, cardinality=1, base_width=64, stem_width=64, stem_type='', block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32, @@ -482,8 +483,11 @@ class ResNet(nn.Module): def reset_classifier(self, num_classes, global_pool='avg'): self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.num_classes = num_classes - del self.fc - self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None + if num_classes: + num_features = self.num_features * self.global_pool.feat_mult() + self.fc = nn.Linear(num_features, num_classes) + else: + self.fc = nn.Identity() def forward_features(self, x): x = self.conv1(x) diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index 2f369e99..7b7de369 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -9,16 +9,15 @@ https://arxiv.org/abs/1907.00837 Based on ResNet implementation in https://github.com/rwightman/pytorch-image-models and SelecSLS Net implementation in https://github.com/mehtadushy/SelecSLS-Pytorch """ -import math import torch import torch.nn as nn import torch.nn.functional as F -from .registry import register_model +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import load_pretrained from .layers import SelectAdaptivePool2d -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .registry import register_model __all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this @@ -134,11 +133,11 @@ class SelecSLS(nn.Module): def reset_classifier(self, num_classes, global_pool='avg'): self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.num_classes = num_classes - del self.fc if num_classes: - self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) + num_features = self.num_features * self.global_pool.feat_mult() + self.fc = nn.Linear(num_features, num_classes) else: - self.fc = None + self.fc = nn.Identity() def forward_features(self, x): x = self.stem(x) diff --git a/timm/models/senet.py b/timm/models/senet.py index efbf4657..8594d14d 100644 --- a/timm/models/senet.py +++ b/timm/models/senet.py @@ -8,16 +8,16 @@ Original model: https://github.com/hujie-frank/SENet ResNet code gently borrowed from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py """ -from collections import OrderedDict import math +from collections import OrderedDict import torch.nn as nn import torch.nn.functional as F -from .registry import register_model +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import load_pretrained from .layers import SelectAdaptivePool2d -from timm.data import 
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .registry import register_model __all__ = ['SENet'] @@ -369,11 +369,11 @@ class SENet(nn.Module): def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes self.avg_pool = SelectAdaptivePool2d(pool_type=global_pool) - del self.last_linear if num_classes: - self.last_linear = nn.Linear(self.num_features * self.avg_pool.feat_mult(), num_classes) + num_features = self.num_features * self.avg_pool.feat_mult() + self.last_linear = nn.Linear(num_features, num_classes) else: - self.last_linear = None + self.last_linear = nn.Identity() def forward_features(self, x): x = self.layer0(x) diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 48b3e1de..a4a980b4 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -5,14 +5,16 @@ https://arxiv.org/pdf/2003.13630.pdf Original model: https://github.com/mrT23/TResNet """ +from collections import OrderedDict from functools import partial + import torch import torch.nn as nn import torch.nn.functional as F -from collections import OrderedDict + +from .helpers import load_pretrained from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, SelectAdaptivePool2d from .registry import register_model -from .helpers import load_pretrained try: from inplace_abn import InPlaceABN @@ -88,7 +90,7 @@ class FastSEModule(nn.Module): def IABN2Float(module: nn.Module) -> nn.Module: - "If `module` is IABN don't use half precision." + """If `module` is IABN don't use half precision.""" if isinstance(module, InPlaceABN): module.float() for child in module.children(): @@ -277,8 +279,10 @@ class TResNet(nn.Module): self.num_classes = num_classes self.head = None if num_classes: - self.head = nn.Sequential(OrderedDict([ - ('fc', nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes))])) + num_features = self.num_features * self.global_pool.feat_mult() + self.head = nn.Sequential(OrderedDict([('fc', nn.Linear(num_features, num_classes))])) + else: + self.head = nn.Sequential(OrderedDict([('fc', nn.Identity())])) def forward_features(self, x): return self.body(x) diff --git a/timm/models/xception.py b/timm/models/xception.py index cb98bbc9..467b42f6 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -21,15 +21,13 @@ normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 """ -import math -import torch import torch.nn as nn import torch.nn.functional as F -from .registry import register_model from .helpers import load_pretrained from .layers import SelectAdaptivePool2d +from .registry import register_model __all__ = ['Xception'] @@ -180,8 +178,11 @@ class Xception(nn.Module): def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - del self.fc - self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None + if num_classes: + num_features = self.num_features * self.global_pool.feat_mult() + self.fc = nn.Linear(num_features, num_classes) + else: + self.fc = nn.Identity() def forward_features(self, x): x = self.conv1(x)
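
Reviewer note: below is a minimal usage sketch, not part of the patch, illustrating the behavior this change enables. It assumes timm's create_model/reset_classifier API as in this tree; the 'resnet50' model name and the 2048-dim pooled feature size are illustrative assumptions.

    import torch
    import timm

    # Build any timm model, then drop the classifier by passing a falsy num_classes.
    model = timm.create_model('resnet50', pretrained=False)
    model.reset_classifier(0)  # head is now nn.Identity() instead of None

    x = torch.randn(1, 3, 224, 224)
    out = model(x)  # forward() runs end to end; before this patch, calling a None head raised TypeError
    print(out.shape)  # e.g. torch.Size([1, 2048]): global-pooled features pass through the Identity head

Because nn.Identity() is a real nn.Module, forward() needs no None checks and the module tree stays well-formed when the classifier is removed.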