Replace all None by nn.Identity() in all models reset_classifier when False-values num_classes is given.

Make small code refactoring
pull/142/head
Vyacheslav Shults 4 years ago
parent 6cc11a8821
commit a7ebe09029

@ -2,17 +2,17 @@
This file is a copy of https://github.com/pytorch/vision 'densenet.py' (BSD-3-Clause) with
fixed kwargs passthrough and addition of dynamic global avg/max pool.
"""
import re
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from .registry import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import re
from .registry import register_model
__all__ = ['DenseNet']
@ -85,6 +85,7 @@ class DenseNet(nn.Module):
drop_rate (float) - dropout rate after each dense layer
num_classes (int) - number of classification classes
"""
def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
num_init_features=64, bn_size=4, drop_rate=0,
num_classes=1000, in_chans=3, global_pool='avg'):
@ -127,8 +128,11 @@ class DenseNet(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.classifier = nn.Linear(
self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.classifier = nn.Linear(num_features, num_classes)
else:
self.classifier = nn.Identity()
def forward_features(self, x):
x = self.features(x)
@ -157,7 +161,6 @@ def _filter_pretrained(state_dict):
return state_dict
@register_model
def densenet121(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
r"""Densenet-121 model from

@ -11,11 +11,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .registry import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .registry import register_model
__all__ = ['DLA']
@ -51,6 +50,7 @@ default_cfgs = {
class DlaBasic(nn.Module):
"""DLA Basic"""
def __init__(self, inplanes, planes, stride=1, dilation=1, **_):
super(DlaBasic, self).__init__()
self.conv1 = nn.Conv2d(
@ -304,9 +304,10 @@ class DLA(nn.Module):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
if num_classes:
self.fc = nn.Conv2d(self.num_features * self.global_pool.feat_mult(), num_classes, 1, bias=True)
num_features = self.num_features * self.global_pool.feat_mult()
self.fc = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True)
else:
self.fc = None
self.fc = nn.Identity()
def forward_features(self, x):
x = self.base_layer(x)

@ -9,16 +9,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from .registry import register_model
from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
from .registry import register_model
__all__ = ['DPN']
@ -218,8 +218,8 @@ class DPN(nn.Module):
# Using 1x1 conv for the FC layer to allow the extra pooling scheme
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.classifier = nn.Conv2d(
self.num_features * self.global_pool.feat_mult(), num_classes, kernel_size=1, bias=True)
num_features = self.num_features * self.global_pool.feat_mult()
self.classifier = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True)
def get_classifier(self):
return self.classifier
@ -228,10 +228,10 @@ class DPN(nn.Module):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
if num_classes:
self.classifier = nn.Conv2d(
self.num_features * self.global_pool.feat_mult(), num_classes, kernel_size=1, bias=True)
num_features = self.num_features * self.global_pool.feat_mult()
self.classifier = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True)
else:
self.classifier = None
self.classifier = nn.Identity()
def forward_features(self, x):
return self.features(x)

@ -24,14 +24,12 @@ An implementation of EfficienNet that covers variety of related models with effi
Hacked together by Ross Wightman
"""
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .efficientnet_builder import *
from .feature_hooks import FeatureHooks
from .registry import register_model
from .helpers import load_pretrained, adapt_model_from_file
from .layers import SelectAdaptivePool2d
from timm.models.layers import create_conv2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .registry import register_model
__all__ = ['EfficientNet']
@ -373,8 +371,11 @@ class EfficientNet(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.classifier = nn.Linear(
self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.classifier = nn.Linear(num_features, num_classes)
else:
self.classifier = nn.Identity()
def forward_features(self, x):
x = self.conv_stem(x)
@ -1187,6 +1188,7 @@ def efficientnet_cc_b0_8e(pretrained=False, **kwargs):
pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_cc_b1_8e(pretrained=False, **kwargs):
""" EfficientNet-CondConv-B1 w/ 8 Experts """
@ -1242,8 +1244,6 @@ def efficientnet_lite4(pretrained=False, **kwargs):
return model
@register_model
def efficientnet_b1_pruned(pretrained=False, **kwargs):
""" EfficientNet-B1 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
@ -1275,8 +1275,6 @@ def efficientnet_b3_pruned(pretrained=False, **kwargs):
return model
@register_model
def tf_efficientnet_b0(pretrained=False, **kwargs):
""" EfficientNet-B0. Tensorflow compatible variant """
@ -1619,6 +1617,7 @@ def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
pretrained=pretrained, **kwargs)
return model
@register_model
def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
""" EfficientNet-CondConv-B1 w/ 8 Experts. Tensorflow compatible variant """
@ -1764,4 +1763,3 @@ def tf_mixnet_l(pretrained=False, **kwargs):
model = _gen_mixnet_m(
'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
return model

@ -3,17 +3,11 @@ This file evolved from https://github.com/pytorch/vision 'resnet.py' with (SE)-R
and ports of Gluon variations (https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/resnet.py)
by Ross Wightman
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .registry import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .layers import SEModule
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .registry import register_model
from .resnet import ResNet, Bottleneck, BasicBlock
@ -202,8 +196,8 @@ def gluon_resnet50_v1e(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
stem_width=64, stem_type='deep', avg_down=True, **kwargs)
model.default_cfg = default_cfg
#if pretrained:
# load_pretrained(model, default_cfg, num_classes, in_chans)
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

@ -6,15 +6,15 @@ Original PyTorch DeepLab impl: https://github.com/jfzhang95/pytorch-deeplab-xcep
Hacked together by Ross Wightman
"""
import torch
from collections import OrderedDict
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from .registry import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .registry import register_model
__all__ = ['Xception65', 'Xception71']
@ -47,7 +47,6 @@ default_cfgs = {
}
}
""" PADDING NOTES
The original PyTorch and Gluon impl of these models dutifully reproduced the
aligned padding added to Tensorflow models for Deeplab. This padding was compensating
@ -394,7 +393,11 @@ class Xception71(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.fc = nn.Linear(num_features, num_classes)
else:
self.fc = nn.Identity()
def forward_features(self, x):
# Entry flow
@ -465,4 +468,3 @@ def gluon_xception71(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

@ -6,10 +6,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .registry import register_model
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .registry import register_model
__all__ = ['InceptionResnetV2']
@ -296,8 +296,11 @@ class InceptionResnetV2(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes
self.classif = nn.Linear(
self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.classif = nn.Linear(num_features, num_classes)
else:
self.classif = nn.Identity()
def forward_features(self, x):
x = self.conv2d_1a(x)

@ -1,7 +1,8 @@
from torchvision.models import Inception3
from .registry import register_model
from .helpers import load_pretrained
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import load_pretrained
from .registry import register_model
__all__ = []

@ -6,10 +6,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .registry import register_model
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .registry import register_model
__all__ = ['InceptionV4']
@ -280,8 +280,11 @@ class InceptionV4(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes
self.last_linear = nn.Linear(
self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.last_linear = nn.Linear(num_features, num_classes)
else:
self.last_linear = nn.Identity()
def forward_features(self, x):
return self.features(x)
@ -303,6 +306,3 @@ def inception_v4(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

@ -8,13 +8,13 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
Hacked together by Ross Wightman
"""
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .efficientnet_builder import *
from .registry import register_model
from .feature_hooks import FeatureHooks
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d, create_conv2d
from .layers.activations import HardSwish, hard_sigmoid
from .feature_hooks import FeatureHooks
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .registry import register_model
__all__ = ['MobileNetV3']
@ -120,8 +120,11 @@ class MobileNetV3(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes
self.classifier = nn.Linear(
self.num_features * self.global_pool.feat_mult(), num_classes) if self.num_classes else None
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.classifier = nn.Linear(num_features, num_classes)
else:
self.classifier = nn.Identity()
def forward_features(self, x):
x = self.conv_stem(x)
@ -397,7 +400,6 @@ def mobilenetv3_small_075(pretrained=False, **kwargs):
@register_model
def mobilenetv3_small_100(pretrained=False, **kwargs):
print(kwargs)
""" MobileNet V3 """
model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
return model

@ -2,10 +2,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from .registry import register_model
__all__ = ['NASNetALarge']
@ -562,9 +561,11 @@ class NASNetALarge(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
del self.last_linear
self.last_linear = nn.Linear(
self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.last_linear = nn.Linear(num_features, num_classes)
else:
self.last_linear = nn.Identity()
def forward_features(self, x):
x_conv0 = self.conv0(x)

@ -6,15 +6,16 @@
"""
from __future__ import print_function, division, absolute_import
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from .registry import register_model
__all__ = ['PNASNet5Large']
@ -349,11 +350,11 @@ class PNASNet5Large(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
del self.last_linear
if num_classes:
self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
num_features = self.num_features * self.global_pool.feat_mult()
self.last_linear = nn.Linear(num_features, num_classes)
else:
self.last_linear = None
self.last_linear = nn.Identity()
def forward_features(self, x):
x_conv_0 = self.conv_0(x)

@ -6,13 +6,11 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .resnet import ResNet
from .registry import register_model
from .helpers import load_pretrained
from .layers import SEModule
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .registry import register_model
from .resnet import ResNet
__all__ = []

@ -10,10 +10,10 @@ import math
import torch.nn as nn
import torch.nn.functional as F
from .registry import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained, adapt_model_from_file
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .registry import register_model
__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
@ -377,6 +377,7 @@ class ResNet(nn.Module):
global_pool : str, default 'avg'
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
"""
def __init__(self, block, layers, num_classes=1000, in_chans=3,
cardinality=1, base_width=64, stem_width=64, stem_type='',
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
@ -482,8 +483,11 @@ class ResNet(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes
del self.fc
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.fc = nn.Linear(num_features, num_classes)
else:
self.fc = nn.Identity()
def forward_features(self, x):
x = self.conv1(x)

@ -9,16 +9,15 @@ https://arxiv.org/abs/1907.00837
Based on ResNet implementation in https://github.com/rwightman/pytorch-image-models
and SelecSLS Net implementation in https://github.com/mehtadushy/SelecSLS-Pytorch
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .registry import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .registry import register_model
__all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this
@ -134,11 +133,11 @@ class SelecSLS(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes
del self.fc
if num_classes:
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
num_features = self.num_features * self.global_pool.feat_mult()
self.fc = nn.Linear(num_features, num_classes)
else:
self.fc = None
self.fc = nn.Identity()
def forward_features(self, x):
x = self.stem(x)

@ -8,16 +8,16 @@ Original model: https://github.com/hujie-frank/SENet
ResNet code gently borrowed from
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
"""
from collections import OrderedDict
import math
from collections import OrderedDict
import torch.nn as nn
import torch.nn.functional as F
from .registry import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .registry import register_model
__all__ = ['SENet']
@ -369,11 +369,11 @@ class SENet(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.avg_pool = SelectAdaptivePool2d(pool_type=global_pool)
del self.last_linear
if num_classes:
self.last_linear = nn.Linear(self.num_features * self.avg_pool.feat_mult(), num_classes)
num_features = self.num_features * self.avg_pool.feat_mult()
self.last_linear = nn.Linear(num_features, num_classes)
else:
self.last_linear = None
self.last_linear = nn.Identity()
def forward_features(self, x):
x = self.layer0(x)

@ -5,14 +5,16 @@ https://arxiv.org/pdf/2003.13630.pdf
Original model: https://github.com/mrT23/TResNet
"""
from collections import OrderedDict
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from .helpers import load_pretrained
from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, SelectAdaptivePool2d
from .registry import register_model
from .helpers import load_pretrained
try:
from inplace_abn import InPlaceABN
@ -88,7 +90,7 @@ class FastSEModule(nn.Module):
def IABN2Float(module: nn.Module) -> nn.Module:
"If `module` is IABN don't use half precision."
"""If `module` is IABN don't use half precision."""
if isinstance(module, InPlaceABN):
module.float()
for child in module.children():
@ -277,8 +279,10 @@ class TResNet(nn.Module):
self.num_classes = num_classes
self.head = None
if num_classes:
self.head = nn.Sequential(OrderedDict([
('fc', nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes))]))
num_features = self.num_features * self.global_pool.feat_mult()
self.head = nn.Sequential(OrderedDict([('fc', nn.Linear(num_features, num_classes))]))
else:
self.head = nn.Sequential(OrderedDict([('fc', nn.Identity())]))
def forward_features(self, x):
return self.body(x)

@ -21,15 +21,13 @@ normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from .registry import register_model
__all__ = ['Xception']
@ -180,8 +178,11 @@ class Xception(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
del self.fc
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.fc = nn.Linear(num_features, num_classes)
else:
self.fc = nn.Identity()
def forward_features(self, x):
x = self.conv1(x)

Loading…
Cancel
Save