Replace all None by nn.Identity() in all models reset_classifier when False-values num_classes is given.

Make small code refactoring
pull/142/head
Vyacheslav Shults 5 years ago
parent 6cc11a8821
commit a7ebe09029

@ -2,17 +2,17 @@
This file is a copy of https://github.com/pytorch/vision 'densenet.py' (BSD-3-Clause) with This file is a copy of https://github.com/pytorch/vision 'densenet.py' (BSD-3-Clause) with
fixed kwargs passthrough and addition of dynamic global avg/max pool. fixed kwargs passthrough and addition of dynamic global avg/max pool.
""" """
import re
from collections import OrderedDict from collections import OrderedDict
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .registry import register_model from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .registry import register_model
import re
__all__ = ['DenseNet'] __all__ = ['DenseNet']
@ -85,6 +85,7 @@ class DenseNet(nn.Module):
drop_rate (float) - dropout rate after each dense layer drop_rate (float) - dropout rate after each dense layer
num_classes (int) - number of classification classes num_classes (int) - number of classification classes
""" """
def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
num_init_features=64, bn_size=4, drop_rate=0, num_init_features=64, bn_size=4, drop_rate=0,
num_classes=1000, in_chans=3, global_pool='avg'): num_classes=1000, in_chans=3, global_pool='avg'):
@ -127,8 +128,11 @@ class DenseNet(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'): def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.classifier = nn.Linear( if num_classes:
self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None num_features = self.num_features * self.global_pool.feat_mult()
self.classifier = nn.Linear(num_features, num_classes)
else:
self.classifier = nn.Identity()
def forward_features(self, x): def forward_features(self, x):
x = self.features(x) x = self.features(x)
@ -157,7 +161,6 @@ def _filter_pretrained(state_dict):
return state_dict return state_dict
@register_model @register_model
def densenet121(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def densenet121(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
r"""Densenet-121 model from r"""Densenet-121 model from

@ -11,11 +11,10 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .registry import register_model from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .registry import register_model
__all__ = ['DLA'] __all__ = ['DLA']
@ -51,6 +50,7 @@ default_cfgs = {
class DlaBasic(nn.Module): class DlaBasic(nn.Module):
"""DLA Basic""" """DLA Basic"""
def __init__(self, inplanes, planes, stride=1, dilation=1, **_): def __init__(self, inplanes, planes, stride=1, dilation=1, **_):
super(DlaBasic, self).__init__() super(DlaBasic, self).__init__()
self.conv1 = nn.Conv2d( self.conv1 = nn.Conv2d(
@ -170,7 +170,7 @@ class DlaBottle2neck(nn.Module):
sp = bn(sp) sp = bn(sp)
sp = self.relu(sp) sp = self.relu(sp)
spo.append(sp) spo.append(sp)
if self.scale > 1 : if self.scale > 1:
spo.append(self.pool(spx[-1]) if self.is_first else spx[-1]) spo.append(self.pool(spx[-1]) if self.is_first else spx[-1])
out = torch.cat(spo, 1) out = torch.cat(spo, 1)
@ -304,9 +304,10 @@ class DLA(nn.Module):
self.num_classes = num_classes self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
if num_classes: if num_classes:
self.fc = nn.Conv2d(self.num_features * self.global_pool.feat_mult(), num_classes, 1, bias=True) num_features = self.num_features * self.global_pool.feat_mult()
self.fc = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True)
else: else:
self.fc = None self.fc = nn.Identity()
def forward_features(self, x): def forward_features(self, x):
x = self.base_layer(x) x = self.base_layer(x)

@ -9,16 +9,16 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from collections import OrderedDict
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from collections import OrderedDict
from .registry import register_model from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
from .helpers import load_pretrained from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD from .registry import register_model
__all__ = ['DPN'] __all__ = ['DPN']
@ -218,8 +218,8 @@ class DPN(nn.Module):
# Using 1x1 conv for the FC layer to allow the extra pooling scheme # Using 1x1 conv for the FC layer to allow the extra pooling scheme
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.classifier = nn.Conv2d( num_features = self.num_features * self.global_pool.feat_mult()
self.num_features * self.global_pool.feat_mult(), num_classes, kernel_size=1, bias=True) self.classifier = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True)
def get_classifier(self): def get_classifier(self):
return self.classifier return self.classifier
@ -228,10 +228,10 @@ class DPN(nn.Module):
self.num_classes = num_classes self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
if num_classes: if num_classes:
self.classifier = nn.Conv2d( num_features = self.num_features * self.global_pool.feat_mult()
self.num_features * self.global_pool.feat_mult(), num_classes, kernel_size=1, bias=True) self.classifier = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True)
else: else:
self.classifier = None self.classifier = nn.Identity()
def forward_features(self, x): def forward_features(self, x):
return self.features(x) return self.features(x)

@ -24,14 +24,12 @@ An implementation of EfficienNet that covers variety of related models with effi
Hacked together by Ross Wightman Hacked together by Ross Wightman
""" """
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .efficientnet_builder import * from .efficientnet_builder import *
from .feature_hooks import FeatureHooks from .feature_hooks import FeatureHooks
from .registry import register_model
from .helpers import load_pretrained, adapt_model_from_file from .helpers import load_pretrained, adapt_model_from_file
from .layers import SelectAdaptivePool2d from .layers import SelectAdaptivePool2d
from timm.models.layers import create_conv2d from .registry import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
__all__ = ['EfficientNet'] __all__ = ['EfficientNet']
@ -373,8 +371,11 @@ class EfficientNet(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'): def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.classifier = nn.Linear( if num_classes:
self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None num_features = self.num_features * self.global_pool.feat_mult()
self.classifier = nn.Linear(num_features, num_classes)
else:
self.classifier = nn.Identity()
def forward_features(self, x): def forward_features(self, x):
x = self.conv_stem(x) x = self.conv_stem(x)
@ -1187,6 +1188,7 @@ def efficientnet_cc_b0_8e(pretrained=False, **kwargs):
pretrained=pretrained, **kwargs) pretrained=pretrained, **kwargs)
return model return model
@register_model @register_model
def efficientnet_cc_b1_8e(pretrained=False, **kwargs): def efficientnet_cc_b1_8e(pretrained=False, **kwargs):
""" EfficientNet-CondConv-B1 w/ 8 Experts """ """ EfficientNet-CondConv-B1 w/ 8 Experts """
@ -1242,8 +1244,6 @@ def efficientnet_lite4(pretrained=False, **kwargs):
return model return model
@register_model @register_model
def efficientnet_b1_pruned(pretrained=False, **kwargs): def efficientnet_b1_pruned(pretrained=False, **kwargs):
""" EfficientNet-B1 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ """ EfficientNet-B1 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
@ -1275,8 +1275,6 @@ def efficientnet_b3_pruned(pretrained=False, **kwargs):
return model return model
@register_model @register_model
def tf_efficientnet_b0(pretrained=False, **kwargs): def tf_efficientnet_b0(pretrained=False, **kwargs):
""" EfficientNet-B0. Tensorflow compatible variant """ """ EfficientNet-B0. Tensorflow compatible variant """
@ -1619,6 +1617,7 @@ def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
pretrained=pretrained, **kwargs) pretrained=pretrained, **kwargs)
return model return model
@register_model @register_model
def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs): def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
""" EfficientNet-CondConv-B1 w/ 8 Experts. Tensorflow compatible variant """ """ EfficientNet-CondConv-B1 w/ 8 Experts. Tensorflow compatible variant """
@ -1764,4 +1763,3 @@ def tf_mixnet_l(pretrained=False, **kwargs):
model = _gen_mixnet_m( model = _gen_mixnet_m(
'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) 'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
return model return model

@ -3,17 +3,11 @@ This file evolved from https://github.com/pytorch/vision 'resnet.py' with (SE)-R
and ports of Gluon variations (https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/resnet.py) and ports of Gluon variations (https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/resnet.py)
by Ross Wightman by Ross Wightman
""" """
import math
import torch from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import torch.nn as nn
import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained from .helpers import load_pretrained
from .layers import SEModule from .layers import SEModule
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .registry import register_model
from .resnet import ResNet, Bottleneck, BasicBlock from .resnet import ResNet, Bottleneck, BasicBlock
@ -202,8 +196,8 @@ def gluon_resnet50_v1e(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
stem_width=64, stem_type='deep', avg_down=True, **kwargs) stem_width=64, stem_type='deep', avg_down=True, **kwargs)
model.default_cfg = default_cfg model.default_cfg = default_cfg
#if pretrained: if pretrained:
# load_pretrained(model, default_cfg, num_classes, in_chans) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model

@ -6,15 +6,15 @@ Original PyTorch DeepLab impl: https://github.com/jfzhang95/pytorch-deeplab-xcep
Hacked together by Ross Wightman Hacked together by Ross Wightman
""" """
import torch from collections import OrderedDict
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from collections import OrderedDict
from .registry import register_model from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .registry import register_model
__all__ = ['Xception65', 'Xception71'] __all__ = ['Xception65', 'Xception71']
@ -47,7 +47,6 @@ default_cfgs = {
} }
} }
""" PADDING NOTES """ PADDING NOTES
The original PyTorch and Gluon impl of these models dutifully reproduced the The original PyTorch and Gluon impl of these models dutifully reproduced the
aligned padding added to Tensorflow models for Deeplab. This padding was compensating aligned padding added to Tensorflow models for Deeplab. This padding was compensating
@ -394,7 +393,11 @@ class Xception71(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'): def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.fc = nn.Linear(num_features, num_classes)
else:
self.fc = nn.Identity()
def forward_features(self, x): def forward_features(self, x):
# Entry flow # Entry flow
@ -465,4 +468,3 @@ def gluon_xception71(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
if pretrained: if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model

@ -6,10 +6,10 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .registry import register_model from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import load_pretrained from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from .registry import register_model
__all__ = ['InceptionResnetV2'] __all__ = ['InceptionResnetV2']
@ -296,8 +296,11 @@ class InceptionResnetV2(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'): def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes self.num_classes = num_classes
self.classif = nn.Linear( if num_classes:
self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None num_features = self.num_features * self.global_pool.feat_mult()
self.classif = nn.Linear(num_features, num_classes)
else:
self.classif = nn.Identity()
def forward_features(self, x): def forward_features(self, x):
x = self.conv2d_1a(x) x = self.conv2d_1a(x)

@ -1,7 +1,8 @@
from torchvision.models import Inception3 from torchvision.models import Inception3
from .registry import register_model
from .helpers import load_pretrained
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import load_pretrained
from .registry import register_model
__all__ = [] __all__ = []

@ -6,10 +6,10 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .registry import register_model from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import load_pretrained from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from .registry import register_model
__all__ = ['InceptionV4'] __all__ = ['InceptionV4']
@ -280,8 +280,11 @@ class InceptionV4(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'): def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes self.num_classes = num_classes
self.last_linear = nn.Linear( if num_classes:
self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None num_features = self.num_features * self.global_pool.feat_mult()
self.last_linear = nn.Linear(num_features, num_classes)
else:
self.last_linear = nn.Identity()
def forward_features(self, x): def forward_features(self, x):
return self.features(x) return self.features(x)
@ -303,6 +306,3 @@ def inception_v4(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
if pretrained: if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model

@ -8,13 +8,13 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
Hacked together by Ross Wightman Hacked together by Ross Wightman
""" """
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .efficientnet_builder import * from .efficientnet_builder import *
from .registry import register_model from .feature_hooks import FeatureHooks
from .helpers import load_pretrained from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d, create_conv2d from .layers import SelectAdaptivePool2d, create_conv2d
from .layers.activations import HardSwish, hard_sigmoid from .layers.activations import HardSwish, hard_sigmoid
from .feature_hooks import FeatureHooks from .registry import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
__all__ = ['MobileNetV3'] __all__ = ['MobileNetV3']
@ -120,8 +120,11 @@ class MobileNetV3(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'): def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes self.num_classes = num_classes
self.classifier = nn.Linear( if num_classes:
self.num_features * self.global_pool.feat_mult(), num_classes) if self.num_classes else None num_features = self.num_features * self.global_pool.feat_mult()
self.classifier = nn.Linear(num_features, num_classes)
else:
self.classifier = nn.Identity()
def forward_features(self, x): def forward_features(self, x):
x = self.conv_stem(x) x = self.conv_stem(x)
@ -397,7 +400,6 @@ def mobilenetv3_small_075(pretrained=False, **kwargs):
@register_model @register_model
def mobilenetv3_small_100(pretrained=False, **kwargs): def mobilenetv3_small_100(pretrained=False, **kwargs):
print(kwargs)
""" MobileNet V3 """ """ MobileNet V3 """
model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
return model return model

@ -2,10 +2,9 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d from .layers import SelectAdaptivePool2d
from .registry import register_model
__all__ = ['NASNetALarge'] __all__ = ['NASNetALarge']
@ -187,17 +186,17 @@ class CellStem1(nn.Module):
self.stem_size = stem_size self.stem_size = stem_size
self.conv_1x1 = nn.Sequential() self.conv_1x1 = nn.Sequential()
self.conv_1x1.add_module('relu', nn.ReLU()) self.conv_1x1.add_module('relu', nn.ReLU())
self.conv_1x1.add_module('conv', nn.Conv2d(2*self.num_channels, self.num_channels, 1, stride=1, bias=False)) self.conv_1x1.add_module('conv', nn.Conv2d(2 * self.num_channels, self.num_channels, 1, stride=1, bias=False))
self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True)) self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True))
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.path_1 = nn.Sequential() self.path_1 = nn.Sequential()
self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels//2, 1, stride=1, bias=False)) self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False))
self.path_2 = nn.ModuleList() self.path_2 = nn.ModuleList()
self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1))) self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels//2, 1, stride=1, bias=False)) self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False))
self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True) self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True)
@ -507,50 +506,50 @@ class NASNetALarge(nn.Module):
self.cell_stem_0 = CellStem0(self.stem_size, num_channels=channels // (channel_multiplier ** 2)) self.cell_stem_0 = CellStem0(self.stem_size, num_channels=channels // (channel_multiplier ** 2))
self.cell_stem_1 = CellStem1(self.stem_size, num_channels=channels // channel_multiplier) self.cell_stem_1 = CellStem1(self.stem_size, num_channels=channels // channel_multiplier)
self.cell_0 = FirstCell(in_channels_left=channels, out_channels_left=channels//2, self.cell_0 = FirstCell(in_channels_left=channels, out_channels_left=channels // 2,
in_channels_right=2*channels, out_channels_right=channels) in_channels_right=2 * channels, out_channels_right=channels)
self.cell_1 = NormalCell(in_channels_left=2*channels, out_channels_left=channels, self.cell_1 = NormalCell(in_channels_left=2 * channels, out_channels_left=channels,
in_channels_right=6*channels, out_channels_right=channels) in_channels_right=6 * channels, out_channels_right=channels)
self.cell_2 = NormalCell(in_channels_left=6*channels, out_channels_left=channels, self.cell_2 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
in_channels_right=6*channels, out_channels_right=channels) in_channels_right=6 * channels, out_channels_right=channels)
self.cell_3 = NormalCell(in_channels_left=6*channels, out_channels_left=channels, self.cell_3 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
in_channels_right=6*channels, out_channels_right=channels) in_channels_right=6 * channels, out_channels_right=channels)
self.cell_4 = NormalCell(in_channels_left=6*channels, out_channels_left=channels, self.cell_4 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
in_channels_right=6*channels, out_channels_right=channels) in_channels_right=6 * channels, out_channels_right=channels)
self.cell_5 = NormalCell(in_channels_left=6*channels, out_channels_left=channels, self.cell_5 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
in_channels_right=6*channels, out_channels_right=channels) in_channels_right=6 * channels, out_channels_right=channels)
self.reduction_cell_0 = ReductionCell0(in_channels_left=6*channels, out_channels_left=2*channels, self.reduction_cell_0 = ReductionCell0(in_channels_left=6 * channels, out_channels_left=2 * channels,
in_channels_right=6*channels, out_channels_right=2*channels) in_channels_right=6 * channels, out_channels_right=2 * channels)
self.cell_6 = FirstCell(in_channels_left=6*channels, out_channels_left=channels, self.cell_6 = FirstCell(in_channels_left=6 * channels, out_channels_left=channels,
in_channels_right=8*channels, out_channels_right=2*channels) in_channels_right=8 * channels, out_channels_right=2 * channels)
self.cell_7 = NormalCell(in_channels_left=8*channels, out_channels_left=2*channels, self.cell_7 = NormalCell(in_channels_left=8 * channels, out_channels_left=2 * channels,
in_channels_right=12*channels, out_channels_right=2*channels) in_channels_right=12 * channels, out_channels_right=2 * channels)
self.cell_8 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels, self.cell_8 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
in_channels_right=12*channels, out_channels_right=2*channels) in_channels_right=12 * channels, out_channels_right=2 * channels)
self.cell_9 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels, self.cell_9 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
in_channels_right=12*channels, out_channels_right=2*channels) in_channels_right=12 * channels, out_channels_right=2 * channels)
self.cell_10 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels, self.cell_10 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
in_channels_right=12*channels, out_channels_right=2*channels) in_channels_right=12 * channels, out_channels_right=2 * channels)
self.cell_11 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels, self.cell_11 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
in_channels_right=12*channels, out_channels_right=2*channels) in_channels_right=12 * channels, out_channels_right=2 * channels)
self.reduction_cell_1 = ReductionCell1(in_channels_left=12*channels, out_channels_left=4*channels, self.reduction_cell_1 = ReductionCell1(in_channels_left=12 * channels, out_channels_left=4 * channels,
in_channels_right=12*channels, out_channels_right=4*channels) in_channels_right=12 * channels, out_channels_right=4 * channels)
self.cell_12 = FirstCell(in_channels_left=12*channels, out_channels_left=2*channels, self.cell_12 = FirstCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
in_channels_right=16*channels, out_channels_right=4*channels) in_channels_right=16 * channels, out_channels_right=4 * channels)
self.cell_13 = NormalCell(in_channels_left=16*channels, out_channels_left=4*channels, self.cell_13 = NormalCell(in_channels_left=16 * channels, out_channels_left=4 * channels,
in_channels_right=24*channels, out_channels_right=4*channels) in_channels_right=24 * channels, out_channels_right=4 * channels)
self.cell_14 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels, self.cell_14 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
in_channels_right=24*channels, out_channels_right=4*channels) in_channels_right=24 * channels, out_channels_right=4 * channels)
self.cell_15 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels, self.cell_15 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
in_channels_right=24*channels, out_channels_right=4*channels) in_channels_right=24 * channels, out_channels_right=4 * channels)
self.cell_16 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels, self.cell_16 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
in_channels_right=24*channels, out_channels_right=4*channels) in_channels_right=24 * channels, out_channels_right=4 * channels)
self.cell_17 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels, self.cell_17 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
in_channels_right=24*channels, out_channels_right=4*channels) in_channels_right=24 * channels, out_channels_right=4 * channels)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@ -562,9 +561,11 @@ class NASNetALarge(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'): def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
del self.last_linear if num_classes:
self.last_linear = nn.Linear( num_features = self.num_features * self.global_pool.feat_mult()
self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None self.last_linear = nn.Linear(num_features, num_classes)
else:
self.last_linear = nn.Identity()
def forward_features(self, x): def forward_features(self, x):
x_conv0 = self.conv0(x) x_conv0 = self.conv0(x)

@ -6,15 +6,16 @@
""" """
from __future__ import print_function, division, absolute_import from __future__ import print_function, division, absolute_import
from collections import OrderedDict from collections import OrderedDict
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d from .layers import SelectAdaptivePool2d
from .registry import register_model
__all__ = ['PNASNet5Large'] __all__ = ['PNASNet5Large']
@ -349,11 +350,11 @@ class PNASNet5Large(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'): def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
del self.last_linear
if num_classes: if num_classes:
self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) num_features = self.num_features * self.global_pool.feat_mult()
self.last_linear = nn.Linear(num_features, num_classes)
else: else:
self.last_linear = None self.last_linear = nn.Identity()
def forward_features(self, x): def forward_features(self, x):
x_conv_0 = self.conv_0(x) x_conv_0 = self.conv_0(x)

@ -6,13 +6,11 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from .resnet import ResNet
from .registry import register_model
from .helpers import load_pretrained
from .layers import SEModule
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .registry import register_model
from .resnet import ResNet
__all__ = [] __all__ = []
@ -105,7 +103,7 @@ class Bottle2neck(nn.Module):
sp = bn(sp) sp = bn(sp)
sp = self.relu(sp) sp = self.relu(sp)
spo.append(sp) spo.append(sp)
if self.scale > 1 : if self.scale > 1:
spo.append(self.pool(spx[-1]) if self.is_first else spx[-1]) spo.append(self.pool(spx[-1]) if self.is_first else spx[-1])
out = torch.cat(spo, 1) out = torch.cat(spo, 1)

@ -10,10 +10,10 @@ import math
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .registry import register_model from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained, adapt_model_from_file from .helpers import load_pretrained, adapt_model_from_file
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .registry import register_model
__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this __all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
@ -377,6 +377,7 @@ class ResNet(nn.Module):
global_pool : str, default 'avg' global_pool : str, default 'avg'
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
""" """
def __init__(self, block, layers, num_classes=1000, in_chans=3, def __init__(self, block, layers, num_classes=1000, in_chans=3,
cardinality=1, base_width=64, stem_width=64, stem_type='', cardinality=1, base_width=64, stem_width=64, stem_type='',
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
@ -482,8 +483,11 @@ class ResNet(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'): def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes self.num_classes = num_classes
del self.fc if num_classes:
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None num_features = self.num_features * self.global_pool.feat_mult()
self.fc = nn.Linear(num_features, num_classes)
else:
self.fc = nn.Identity()
def forward_features(self, x): def forward_features(self, x):
x = self.conv1(x) x = self.conv1(x)

@ -9,16 +9,15 @@ https://arxiv.org/abs/1907.00837
Based on ResNet implementation in https://github.com/rwightman/pytorch-image-models Based on ResNet implementation in https://github.com/rwightman/pytorch-image-models
and SelecSLS Net implementation in https://github.com/mehtadushy/SelecSLS-Pytorch and SelecSLS Net implementation in https://github.com/mehtadushy/SelecSLS-Pytorch
""" """
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .registry import register_model from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .registry import register_model
__all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this __all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this
@ -134,11 +133,11 @@ class SelecSLS(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'): def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes self.num_classes = num_classes
del self.fc
if num_classes: if num_classes:
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) num_features = self.num_features * self.global_pool.feat_mult()
self.fc = nn.Linear(num_features, num_classes)
else: else:
self.fc = None self.fc = nn.Identity()
def forward_features(self, x): def forward_features(self, x):
x = self.stem(x) x = self.stem(x)

@ -8,16 +8,16 @@ Original model: https://github.com/hujie-frank/SENet
ResNet code gently borrowed from ResNet code gently borrowed from
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
""" """
from collections import OrderedDict
import math import math
from collections import OrderedDict
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .registry import register_model from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .registry import register_model
__all__ = ['SENet'] __all__ = ['SENet']
@ -369,11 +369,11 @@ class SENet(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'): def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes self.num_classes = num_classes
self.avg_pool = SelectAdaptivePool2d(pool_type=global_pool) self.avg_pool = SelectAdaptivePool2d(pool_type=global_pool)
del self.last_linear
if num_classes: if num_classes:
self.last_linear = nn.Linear(self.num_features * self.avg_pool.feat_mult(), num_classes) num_features = self.num_features * self.avg_pool.feat_mult()
self.last_linear = nn.Linear(num_features, num_classes)
else: else:
self.last_linear = None self.last_linear = nn.Identity()
def forward_features(self, x): def forward_features(self, x):
x = self.layer0(x) x = self.layer0(x)

@ -5,14 +5,16 @@ https://arxiv.org/pdf/2003.13630.pdf
Original model: https://github.com/mrT23/TResNet Original model: https://github.com/mrT23/TResNet
""" """
from collections import OrderedDict
from functools import partial from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from collections import OrderedDict
from .helpers import load_pretrained
from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, SelectAdaptivePool2d from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, SelectAdaptivePool2d
from .registry import register_model from .registry import register_model
from .helpers import load_pretrained
try: try:
from inplace_abn import InPlaceABN from inplace_abn import InPlaceABN
@ -88,7 +90,7 @@ class FastSEModule(nn.Module):
def IABN2Float(module: nn.Module) -> nn.Module: def IABN2Float(module: nn.Module) -> nn.Module:
"If `module` is IABN don't use half precision." """If `module` is IABN don't use half precision."""
if isinstance(module, InPlaceABN): if isinstance(module, InPlaceABN):
module.float() module.float()
for child in module.children(): for child in module.children():
@ -277,8 +279,10 @@ class TResNet(nn.Module):
self.num_classes = num_classes self.num_classes = num_classes
self.head = None self.head = None
if num_classes: if num_classes:
self.head = nn.Sequential(OrderedDict([ num_features = self.num_features * self.global_pool.feat_mult()
('fc', nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes))])) self.head = nn.Sequential(OrderedDict([('fc', nn.Linear(num_features, num_classes))]))
else:
self.head = nn.Sequential(OrderedDict([('fc', nn.Identity())]))
def forward_features(self, x): def forward_features(self, x):
return self.body(x) return self.body(x)

@ -21,15 +21,13 @@ normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
""" """
import math
import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d from .layers import SelectAdaptivePool2d
from .registry import register_model
__all__ = ['Xception'] __all__ = ['Xception']
@ -180,8 +178,11 @@ class Xception(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'): def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
del self.fc if num_classes:
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None num_features = self.num_features * self.global_pool.feat_mult()
self.fc = nn.Linear(num_features, num_classes)
else:
self.fc = nn.Identity()
def forward_features(self, x): def forward_features(self, x):
x = self.conv1(x) x = self.conv1(x)

Loading…
Cancel
Save