More uniform treatment of classifiers across all models, reduce code duplication.

pull/175/head
Ross Wightman 5 years ago
parent 9806f3e1ff
commit b1f1a54de9

@ -4,8 +4,14 @@ import platform
import os
import fnmatch
import timm
from timm import list_models, create_model, set_scriptable
if hasattr(torch._C, '_jit_set_profiling_executor'):
# legacy executor is too slow to compile large models for unit tests
# no need for the fusion performance here
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(False)
if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
@ -78,10 +84,28 @@ def test_model_default_cfgs(model_name, batch_size):
if all([x <= MAX_FWD_FEAT_SIZE for x in input_size]) and \
not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
# pool size only checked if default res <= 448 * 448 to keep resource down
# output sizes only checked if default res <= 448 * 448 to keep resource down
input_size = tuple([min(x, MAX_FWD_FEAT_SIZE) for x in input_size])
outputs = model.forward_features(torch.randn((batch_size, *input_size)))
input_tensor = torch.randn((batch_size, *input_size))
# test forward_features (always unpooled)
outputs = model.forward_features(input_tensor)
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
model.reset_classifier(0)
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 2
assert outputs.shape[-1] == model.num_features
# test model forward without pooling and classifier
if not isinstance(model, timm.models.MobileNetV3):
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 4
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
# check classifier and first convolution names match those in default_cfg
assert any([k.startswith(classifier) for k in state_dict.keys()]), f'{classifier} not in model params'
assert any([k.startswith(first_conv) for k in state_dict.keys()]), f'{first_conv} not in model params'

@ -14,7 +14,7 @@ from torch.jit.annotations import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act, BlurPool2d
from .layers import BatchNormAct2d, create_norm_act, BlurPool2d, create_classifier
from .registry import register_model
__all__ = ['DenseNet']
@ -236,8 +236,8 @@ class DenseNet(nn.Module):
self.num_features = num_features
# Linear layer
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
# Official init from torch repo.
for m in self.modules():
@ -254,19 +254,15 @@ class DenseNet(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.classifier = nn.Linear(num_features, num_classes)
else:
self.classifier = nn.Identity()
self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
def forward_features(self, x):
return self.features(x)
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x).flatten(1)
x = self.global_pool(x)
# both classifier and block drop?
# if self.drop_rate > 0.:
# x = F.dropout(x, p=self.drop_rate, training=self.training)

@ -13,7 +13,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d
from .layers import create_classifier
from .registry import register_model
__all__ = ['DLA']
@ -286,9 +286,8 @@ class DLA(nn.Module):
]
self.num_features = channels[-1]
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Conv2d(self.num_features * self.global_pool.feat_mult(), num_classes, 1, bias=True)
self.global_pool, self.fc = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
@ -313,12 +312,8 @@ class DLA(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.fc = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True)
else:
self.fc = nn.Identity()
self.global_pool, self.fc = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
def forward_features(self, x):
x = self.base_layer(x)
@ -336,7 +331,9 @@ class DLA(nn.Module):
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
return x.flatten(1)
if not self.global_pool.is_identity():
x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled)
return x
def _create_dla(variant, pretrained=False, **kwargs):

@ -19,7 +19,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_conv2d, ConvBnAct
from .layers import BatchNormAct2d, ConvBnAct, create_conv2d, create_classifier
from .registry import register_model
__all__ = ['DPN']
@ -237,21 +237,16 @@ class DPN(nn.Module):
self.features = nn.Sequential(blocks)
# Using 1x1 conv for the FC layer to allow the extra pooling scheme
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
num_features = self.num_features * self.global_pool.feat_mult()
self.classifier = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True)
self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
def get_classifier(self):
return self.classifier
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.classifier = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True)
else:
self.classifier = nn.Identity()
self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
def forward_features(self, x):
return self.features(x)
@ -261,8 +256,10 @@ class DPN(nn.Module):
x = self.global_pool(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
out = self.classifier(x)
return out.flatten(1)
x = self.classifier(x)
if not self.global_pool.is_identity():
x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled)
return x
def _create_dpn(variant, pretrained=False, **kwargs):

@ -35,7 +35,7 @@ from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_la
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
from .features import FeatureInfo, FeatureHooks
from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d, create_conv2d
from .layers import create_conv2d, create_classifier
from .registry import register_model
__all__ = ['EfficientNet']
@ -336,32 +336,28 @@ class EfficientNet(nn.Module):
self.num_classes = num_classes
self.num_features = num_features
self.drop_rate = drop_rate
self._in_chs = in_chans
# Stem
if not fix_stem:
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
self._in_chs = stem_size
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = builder.features
self._in_chs = builder.in_chs
head_chs = builder.in_chs
# Head + Pooling
self.conv_head = create_conv2d(self._in_chs, self.num_features, 1, padding=pad_type)
self.conv_head = create_conv2d(head_chs, self.num_features, 1, padding=pad_type)
self.bn2 = norm_layer(self.num_features, **norm_kwargs)
self.act2 = act_layer(inplace=True)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
# Classifier
self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes)
self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
efficientnet_init_weights(self)
@ -369,7 +365,7 @@ class EfficientNet(nn.Module):
layers = [self.conv_stem, self.bn1, self.act1]
layers.extend(self.blocks)
layers.extend([self.conv_head, self.bn2, self.act2, self.global_pool])
layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
layers.extend([nn.Dropout(self.drop_rate), self.classifier])
return nn.Sequential(*layers)
def get_classifier(self):
@ -377,12 +373,8 @@ class EfficientNet(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.classifier = nn.Linear(num_features, num_classes)
else:
self.classifier = nn.Identity()
self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
def forward_features(self, x):
x = self.conv_stem(x)
@ -397,7 +389,6 @@ class EfficientNet(nn.Module):
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x)
x = x.flatten(1)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
return self.classifier(x)
@ -417,24 +408,21 @@ class EfficientNetFeatures(nn.Module):
super(EfficientNetFeatures, self).__init__()
norm_kwargs = norm_kwargs or {}
self.drop_rate = drop_rate
self._in_chs = in_chans
# Stem
if not fix_stem:
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
self._in_chs = stem_size
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = FeatureInfo(builder.features, out_indices)
self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
self._in_chs = builder.in_chs
efficientnet_init_weights(self)

@ -13,7 +13,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d, get_padding
from .layers import create_classifier, get_padding
from .registry import register_model
__all__ = ['Xception65']
@ -192,16 +192,14 @@ class Xception65(nn.Module):
dict(num_chs=2048, reduction=32, module='act5'),
]
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
def get_classifier(self):
return self.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
def forward_features(self, x):
# Entry flow
@ -242,7 +240,7 @@ class Xception65(nn.Module):
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x).flatten(1)
x = self.global_pool(x)
if self.drop_rate:
F.dropout(x, self.drop_rate, training=self.training)
x = self.fc(x)

@ -187,10 +187,13 @@ def adapt_model_from_string(parent_module, model_string):
affine=old_module.affine, track_running_stats=True)
set_layer(new_module, n, new_bn)
if isinstance(old_module, nn.Linear):
# FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
num_features = state_dict[n + '.weight'][1]
new_fc = nn.Linear(
in_features=state_dict[n + '.weight'][1], out_features=old_module.out_features,
bias=old_module.bias is not None)
in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
set_layer(new_module, n, new_fc)
if hasattr(new_module, 'num_features'):
new_module.num_features = num_features
new_module.eval()
parent_module.eval()

@ -18,7 +18,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .features import FeatureInfo
from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d
from .layers import create_classifier
from .registry import register_model
from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE
@ -553,8 +553,8 @@ class HighResolutionNet(nn.Module):
# Classification Head
self.num_features = 2048
self.incre_modules, self.downsamp_modules, self.final_layer = self._make_head(pre_stage_channels)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
elif head == 'incre':
self.num_features = 2048
self.incre_modules, _, _ = self._make_head(pre_stage_channels, True)
@ -685,12 +685,8 @@ class HighResolutionNet(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
num_features = self.num_features * self.global_pool.feat_mult()
if num_classes:
self.classifier = nn.Linear(num_features, num_classes)
else:
self.classifier = nn.Identity()
self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
def stages(self, x) -> List[torch.Tensor]:
x = self.layer1(x)
@ -726,7 +722,7 @@ class HighResolutionNet(nn.Module):
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x).flatten(1)
x = self.global_pool(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.classifier(x)

@ -8,7 +8,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d
from .layers import create_classifier
from .registry import register_model
__all__ = ['InceptionResnetV2']
@ -296,21 +296,14 @@ class InceptionResnetV2(nn.Module):
self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
self.feature_info += [dict(num_chs=self.num_features, reduction=32, module='conv2d_7b')]
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
# NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC
self.classif = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
def get_classifier(self):
return self.classif
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.classif = nn.Linear(num_features, num_classes)
else:
self.classif = nn.Identity()
self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
def forward_features(self, x):
x = self.conv2d_1a(x)
@ -332,7 +325,7 @@ class InceptionResnetV2(nn.Module):
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x).flatten(1)
x = self.global_pool(x)
if self.drop_rate > 0:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.classif(x)

@ -10,7 +10,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg
from .registry import register_model
from .layers import trunc_normal_, SelectAdaptivePool2d
from .layers import trunc_normal_, create_classifier
def _cfg(url='', **kwargs):
@ -326,8 +326,7 @@ class InceptionV3(nn.Module):
]
self.num_features = 2048
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(2048, num_classes)
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
@ -389,16 +388,12 @@ class InceptionV3(nn.Module):
return self.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes
if self.num_classes > 0:
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
else:
self.fc = nn.Identity()
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x).flatten(1)
x = self.global_pool(x)
if self.drop_rate > 0:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
@ -421,7 +416,7 @@ class InceptionV3Aux(InceptionV3):
def forward(self, x):
x, aux = self.forward_features(x)
x = self.global_pool(x).flatten(1)
x = self.global_pool(x)
if self.drop_rate > 0:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)

@ -8,7 +8,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d
from .layers import create_classifier
from .registry import register_model
__all__ = ['InceptionV4']
@ -279,27 +279,23 @@ class InceptionV4(nn.Module):
dict(num_chs=1024, reduction=16, module='features.17'),
dict(num_chs=1536, reduction=32, module='features.21'),
]
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
def get_classifier(self):
return self.last_linear
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.last_linear = nn.Linear(num_features, num_classes)
else:
self.last_linear = nn.Identity()
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
def forward_features(self, x):
return self.features(x)
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x).flatten(1)
x = self.global_pool(x)
if self.drop_rate > 0:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.last_linear(x)

@ -3,7 +3,7 @@ from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .anti_aliasing import AntiAliasDownsampleLayer
from .blur_pool import BlurPool2d
from .classifier import ClassifierHead
from .classifier import ClassifierHead, create_classifier
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
set_layer_config

@ -72,19 +72,23 @@ class SelectAdaptivePool2d(nn.Module):
"""
def __init__(self, output_size=1, pool_type='avg', flatten=False):
super(SelectAdaptivePool2d, self).__init__()
self.output_size = output_size
self.pool_type = pool_type
self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
self.flatten = flatten
if pool_type == 'avgmax':
if pool_type == '':
self.pool = nn.Identity() # pass through
elif pool_type == 'avg':
self.pool = nn.AdaptiveAvgPool2d(output_size)
elif pool_type == 'avgmax':
self.pool = AdaptiveAvgMaxPool2d(output_size)
elif pool_type == 'catavgmax':
self.pool = AdaptiveCatAvgMaxPool2d(output_size)
elif pool_type == 'max':
self.pool = nn.AdaptiveMaxPool2d(output_size)
else:
if pool_type != 'avg':
assert False, 'Invalid pool type: %s' % pool_type
self.pool = nn.AdaptiveAvgPool2d(output_size)
assert False, 'Invalid pool type: %s' % pool_type
def is_identity(self):
return self.pool_type == ''
def forward(self, x):
x = self.pool(x)
@ -97,5 +101,6 @@ class SelectAdaptivePool2d(nn.Module):
def __repr__(self):
return self.__class__.__name__ + ' (' \
+ 'output_size=' + str(self.output_size) \
+ ', pool_type=' + self.pool_type + ')'
+ 'pool_type=' + self.pool_type \
+ ', flatten=' + str(self.flatten) + ')'

@ -1,23 +1,40 @@
""" Classifier head and layer factory
Hacked together by / Copyright 2020 Ross Wightman
"""
from torch import nn as nn
from torch.nn import functional as F
from .adaptive_avgmax_pool import SelectAdaptivePool2d
def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
flatten = not use_conv # flatten when we use a Linear layer after pooling
if not pool_type:
assert num_classes == 0 or use_conv,\
'Pooling can only be disabled if classifier is also removed or conv classifier is used'
flatten = False # disable flattening if pooling is pass-through (no pooling)
global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten)
num_pooled_features = num_features * global_pool.feat_mult()
if num_classes <= 0:
fc = nn.Identity() # pass-through (no classifier)
elif use_conv:
fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True)
else:
fc = nn.Linear(num_pooled_features, num_classes, bias=True)
return global_pool, fc
class ClassifierHead(nn.Module):
"""Classifier Head w/ configurable global pooling and dropout."""
"""Classifier head w/ configurable global pooling and dropout."""
def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
super(ClassifierHead, self).__init__()
self.drop_rate = drop_rate
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
if num_classes > 0:
self.fc = nn.Linear(in_chs * self.global_pool.feat_mult(), num_classes, bias=True)
else:
self.fc = nn.Identity()
self.global_pool, self.fc = create_classifier(in_chs, num_classes, pool_type=pool_type)
def forward(self, x):
x = self.global_pool(x).flatten(1)
x = self.global_pool(x)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
x = self.fc(x)

@ -85,30 +85,27 @@ class MobileNetV3(nn.Module):
self.num_classes = num_classes
self.num_features = num_features
self.drop_rate = drop_rate
self._in_chs = in_chans
# Stem
stem_size = round_channels(stem_size, channel_multiplier)
self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
self._in_chs = stem_size
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = builder.features
self._in_chs = builder.in_chs
head_chs = builder.in_chs
# Head + Pooling
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.conv_head = create_conv2d(self._in_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) if global_pool else nn.Identity()
num_pooled_chs = head_chs * self.global_pool.feat_mult()
self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
self.act2 = act_layer(inplace=True)
# Classifier
self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes)
self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
efficientnet_init_weights(self)
@ -123,13 +120,10 @@ class MobileNetV3(nn.Module):
return self.classifier
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.classifier = nn.Linear(num_features, num_classes)
else:
self.classifier = nn.Identity()
# cannot meaningfully change pooling of efficient head after creation
assert global_pool == self.global_pool.pool_type
self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.conv_stem(x)
@ -142,8 +136,7 @@ class MobileNetV3(nn.Module):
return x
def forward(self, x):
x = self.forward_features(x)
x = x.flatten(1)
x = self.forward_features(x).flatten(1)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
return self.classifier(x)
@ -163,23 +156,20 @@ class MobileNetV3Features(nn.Module):
super(MobileNetV3Features, self).__init__()
norm_kwargs = norm_kwargs or {}
self.drop_rate = drop_rate
self._in_chs = in_chans
# Stem
stem_size = round_channels(stem_size, channel_multiplier)
self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
self._in_chs = stem_size
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = FeatureInfo(builder.features, out_indices)
self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
self._in_chs = builder.in_chs
efficientnet_init_weights(self)

@ -6,7 +6,7 @@ import torch.nn as nn
import torch.nn.functional as F
from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d, create_pool2d
from .layers import ConvBnAct, create_conv2d, create_pool2d, create_classifier
from .registry import register_model
__all__ = ['NASNetALarge']
@ -496,20 +496,16 @@ class NASNetALarge(nn.Module):
dict(num_chs=4032, reduction=32, module='act'),
]
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
def get_classifier(self):
return self.last_linear
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.last_linear = nn.Linear(num_features, num_classes)
else:
self.last_linear = nn.Identity()
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
def forward_features(self, x):
x_conv0 = self.conv0(x)
@ -544,7 +540,7 @@ class NASNetALarge(nn.Module):
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x).flatten(1)
x = self.global_pool(x)
if self.drop_rate > 0:
x = F.dropout(x, self.drop_rate, training=self.training)
x = self.last_linear(x)

@ -12,7 +12,7 @@ import torch.nn as nn
import torch.nn.functional as F
from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d, create_pool2d
from .layers import ConvBnAct, create_conv2d, create_pool2d, create_classifier
from .registry import register_model
__all__ = ['PNASNet5Large']
@ -291,20 +291,16 @@ class PNASNet5Large(nn.Module):
dict(num_chs=4320, reduction=32, module='act'),
]
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
def get_classifier(self):
return self.last_linear
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.last_linear = nn.Linear(num_features, num_classes)
else:
self.last_linear = nn.Identity()
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
def forward_features(self, x):
x_conv_0 = self.conv_0(x)
@ -327,7 +323,7 @@ class PNASNet5Large(nn.Module):
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x).flatten(1)
x = self.global_pool(x)
if self.drop_rate > 0:
x = F.dropout(x, self.drop_rate, training=self.training)
x = self.last_linear(x)

@ -15,7 +15,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d
from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, create_attn, create_classifier
from .registry import register_model
__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
@ -542,9 +542,8 @@ class ResNet(nn.Module):
self.feature_info.extend(stage_feature_info)
# Head (Pooling and Classifier)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_features = 512 * block.expansion
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
for n, m in self.named_modules():
if isinstance(m, nn.Conv2d):
@ -561,13 +560,8 @@ class ResNet(nn.Module):
return self.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.fc = nn.Linear(num_features, num_classes)
else:
self.fc = nn.Identity()
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
def forward_features(self, x):
x = self.conv1(x)
@ -583,7 +577,7 @@ class ResNet(nn.Module):
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x).flatten(1)
x = self.global_pool(x)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
x = self.fc(x)

@ -168,6 +168,7 @@ class ReXNetV1(nn.Module):
initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, use_se=True,
se_rd=12, ch_div=1, drop_rate=0.2, feature_location='bottleneck'):
super(ReXNetV1, self).__init__()
self.drop_rate = drop_rate
assert output_stride == 32 # FIXME support dilation
stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32

@ -17,7 +17,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d
from .layers import create_classifier
from .registry import register_model
__all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this
@ -165,8 +165,7 @@ class SelecSLS(nn.Module):
self.num_features = cfg['num_features']
self.feature_info = cfg['feature_info']
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
for n, m in self.named_modules():
if isinstance(m, nn.Conv2d):
@ -179,13 +178,8 @@ class SelecSLS(nn.Module):
return self.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.fc = nn.Linear(num_features, num_classes)
else:
self.fc = nn.Identity()
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
def forward_features(self, x):
x = self.stem(x)
@ -195,7 +189,7 @@ class SelecSLS(nn.Module):
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x).flatten(1)
x = self.global_pool(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)

@ -19,7 +19,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d
from .layers import create_classifier
from .registry import register_model
__all__ = ['SENet']
@ -345,8 +345,8 @@ class SENet(nn.Module):
)
self.feature_info += [dict(num_chs=512 * block.expansion, reduction=32, module='layer4')]
self.num_features = 512 * block.expansion
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.last_linear = nn.Linear(self.num_features, num_classes)
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
for m in self.modules():
_weight_init(m)
@ -374,12 +374,8 @@ class SENet(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.last_linear = nn.Linear(num_features, num_classes)
else:
self.last_linear = nn.Identity()
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
def forward_features(self, x):
x = self.layer0(x)
@ -391,7 +387,7 @@ class SENet(nn.Module):
return x
def logits(self, x):
x = self.global_pool(x).flatten(1)
x = self.global_pool(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.last_linear(x)

@ -26,7 +26,7 @@ import torch.nn as nn
import torch.nn.functional as F
from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d
from .layers import create_classifier
from .registry import register_model
__all__ = ['Xception']
@ -162,8 +162,7 @@ class Xception(nn.Module):
dict(num_chs=2048, reduction=32, module='act4'),
]
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
# #------- init weights --------
for m in self.modules():
@ -178,12 +177,7 @@ class Xception(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.fc = nn.Linear(num_features, num_classes)
else:
self.fc = nn.Identity()
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
def forward_features(self, x):
x = self.conv1(x)
@ -218,7 +212,7 @@ class Xception(nn.Module):
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x).flatten(1)
x = self.global_pool(x)
if self.drop_rate:
F.dropout(x, self.drop_rate, training=self.training)
x = self.fc(x)

Loading…
Cancel
Save