From d23a2697d0a150f04fcd0671131cb1e99862877a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 29 Jun 2020 17:45:38 -0700 Subject: [PATCH] Working on feature extraction, interfaces refined, a number of models working, some in progress. --- timm/data/__init__.py | 1 + timm/data/real_labels.py | 36 ++ timm/models/densenet.py | 32 +- timm/models/efficientnet.py | 31 +- timm/models/efficientnet_builder.py | 13 +- timm/models/feature_hooks.py | 20 +- timm/models/features.py | 251 +++++++++++++ timm/models/gluon_xception.py | 300 ++------------- timm/models/inception_resnet_v2.py | 47 ++- timm/models/mobilenetv3.py | 26 +- timm/models/nasnet.py | 450 ++++++++++------------- timm/models/pnasnet.py | 315 +++++++--------- timm/models/res2net.py | 110 ++---- timm/models/resnest.py | 107 ++---- timm/models/resnet.py | 542 ++++++++++------------------ timm/models/selecsls.py | 50 ++- timm/models/sknet.py | 97 ++--- timm/models/vovnet.py | 27 +- timm/models/xception.py | 112 +++--- validate.py | 60 ++- 20 files changed, 1214 insertions(+), 1413 deletions(-) create mode 100644 timm/data/real_labels.py create mode 100644 timm/models/features.py diff --git a/timm/data/__init__.py b/timm/data/__init__.py index ee2240b4..37c3068a 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -7,3 +7,4 @@ from .transforms_factory import create_transform from .mixup import mixup_batch, FastCollateMixup from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ rand_augment_transform, auto_augment_transform +from .real_labels import RealLabelsImagenet diff --git a/timm/data/real_labels.py b/timm/data/real_labels.py new file mode 100644 index 00000000..be82e5e0 --- /dev/null +++ b/timm/data/real_labels.py @@ -0,0 +1,36 @@ +import os +import json +import numpy as np + + +class RealLabelsImagenet: + + def __init__(self, filenames, real_json='real.json', topk=(1, 5)): + with open(real_json) as real_labels: + real_labels = json.load(real_labels) + real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)} + self.real_labels = real_labels + self.filenames = filenames + assert len(self.filenames) == len(self.real_labels) + self.topk = topk + self.is_correct = {k: [] for k in topk} + self.sample_idx = 0 + + def add_result(self, output): + maxk = max(self.topk) + _, pred_batch = output.topk(maxk, 1, True, True) + pred_batch = pred_batch.cpu().numpy() + for pred in pred_batch: + filename = self.filenames[self.sample_idx] + filename = os.path.basename(filename) + if self.real_labels[filename]: + for k in self.topk: + self.is_correct[k].append( + any([p in self.real_labels[filename] for p in pred[:k]])) + self.sample_idx += 1 + + def get_accuracy(self, k=None): + if k is None: + return {k: float(np.mean(self.is_correct[k] for k in self.topk))} + else: + return float(np.mean(self.is_correct[k])) * 100 diff --git a/timm/models/densenet.py b/timm/models/densenet.py index 59a15a85..1eeaacee 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -13,6 +13,7 @@ import torch.utils.checkpoint as cp from torch.jit.annotations import List from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .features import FeatureNet from .helpers import load_pretrained from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act, BlurPool2d from .registry import register_model @@ -199,6 +200,9 @@ class DenseNet(nn.Module): ('norm0', norm_layer(num_init_features)), ('pool0', stem_pool), ])) + self.feature_info = [ + dict(num_chs=num_init_features, reduction=2, module=f'features.norm{2 if deep_stem else 0}')] + current_stride = 4 # DenseBlocks num_features = num_init_features @@ -212,21 +216,27 @@ class DenseNet(nn.Module): drop_rate=drop_rate, memory_efficient=memory_efficient ) - self.features.add_module('denseblock%d' % (i + 1), block) + module_name = f'denseblock{(i + 1)}' + self.features.add_module(module_name, block) num_features = num_features + num_layers * growth_rate transition_aa_layer = None if aa_stem_only else aa_layer if i != len(block_config) - 1: + self.feature_info += [ + dict(num_chs=num_features, reduction=current_stride, module='features.' + module_name)] + current_stride *= 2 trans = DenseTransition( num_input_features=num_features, num_output_features=num_features // 2, norm_layer=norm_layer, aa_layer=transition_aa_layer) - self.features.add_module('transition%d' % (i + 1), trans) + self.features.add_module(f'transition{i + 1}', trans) num_features = num_features // 2 # Final batch norm self.features.add_module('norm5', norm_layer(num_features)) - # Linear layer + self.feature_info += [dict(num_chs=num_features, reduction=current_stride, module='features.norm5')] self.num_features = num_features + + # Linear layer self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) @@ -279,16 +289,14 @@ def _filter_torchvision_pretrained(state_dict): def _densenet(variant, growth_rate, block_config, pretrained, **kwargs): + features = False + out_indices = None if kwargs.pop('features_only', False): - assert False, 'Not Implemented' # TODO - load_strict = False + features = True kwargs.pop('num_classes', 0) - model_class = DenseNet - else: - load_strict = True - model_class = DenseNet + out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4)) default_cfg = default_cfgs[variant] - model = model_class(growth_rate=growth_rate, block_config=block_config, **kwargs) + model = DenseNet(growth_rate=growth_rate, block_config=block_config, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained( @@ -296,7 +304,9 @@ def _densenet(variant, growth_rate, block_config, pretrained, **kwargs): num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), filter_fn=_filter_torchvision_pretrained, - strict=load_strict) + strict=not features) + if features: + model = FeatureNet(model, out_indices, flatten_sequential=True) return model diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 47cd0b9d..08b14cb0 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -34,6 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights from .feature_hooks import FeatureHooks +from .features import FeatureInfo from .helpers import load_pretrained, adapt_model_from_file from .layers import SelectAdaptivePool2d, create_conv2d from .registry import register_model @@ -438,42 +439,22 @@ class EfficientNetFeatures(nn.Module): channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG) self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) - self._feature_info = builder.features # builder provides info about feature channels for each block + self.feature_info = FeatureInfo(builder.features, out_indices) self._stage_to_feature_idx = { - v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices} + v['stage_idx']: fi for fi, v in enumerate(self.feature_info) if fi in self.out_indices} self._in_chs = builder.in_chs efficientnet_init_weights(self) if _DEBUG: - for k, v in self._feature_info.items(): - print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs'])) + for fi, v in enumerate(self.feature_info): + print('Feature idx: {}: Name: {}, Channels: {}'.format(fi, v['module'], v['num_chs'])) # Register feature extraction hooks with FeatureHooks helper self.feature_hooks = None if feature_location != 'bottleneck': - hooks = [dict( - name=self._feature_info[idx]['module'], - type=self._feature_info[idx]['hook_type']) for idx in out_indices] + hooks = self.feature_info.get_by_key(keys=('module', 'hook_type')) self.feature_hooks = FeatureHooks(hooks, self.named_modules()) - def feature_channels(self, idx=None): - """ Feature Channel Shortcut - Returns feature channel count for each output index if idx == None. If idx is an integer, will - return feature channel count for that feature block index (independent of out_indices setting). - """ - if isinstance(idx, int): - return self._feature_info[idx]['num_chs'] - return [self._feature_info[i]['num_chs'] for i in self.out_indices] - - def feature_info(self, idx=None): - """ Feature Channel Shortcut - Returns feature channel count for each output index if idx == None. If idx is an integer, will - return feature channel count for that feature block index (independent of out_indices setting). - """ - if isinstance(idx, int): - return self._feature_info[idx] - return [self._feature_info[i] for i in self.out_indices] - def forward(self, x) -> List[torch.Tensor]: x = self.conv_stem(x) x = self.bn1(x) diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index 1e06b4f3..68d39c86 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -225,7 +225,7 @@ class EfficientNetBuilder: # state updated during build, consumed by model self.in_chs = None - self.features = OrderedDict() + self.features = [] def _round_channels(self, chs): return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) @@ -291,7 +291,6 @@ class EfficientNetBuilder: total_block_idx = 0 current_stride = 2 current_dilation = 1 - feature_idx = 0 stages = [] # outer list of block_args defines the stacks ('stages' by some conventions) for stage_idx, stage_block_args in enumerate(model_block_args): @@ -351,13 +350,15 @@ class EfficientNetBuilder: # stash feature module name and channel info for model feature extraction if extract_features: feature_info = block.feature_info(extract_features) - if feature_info['module']: - feature_info['module'] = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_info['module'] + module_name = f'blocks.{stage_idx}.{block_idx}' + if 'module' in feature_info and feature_info['module']: + feature_info['module'] = '.'.join([module_name, feature_info['module']]) + else: + feature_info['module'] = module_name feature_info['stage_idx'] = stage_idx feature_info['block_idx'] = block_idx feature_info['reduction'] = current_stride - self.features[feature_idx] = feature_info - feature_idx += 1 + self.features.append(feature_info) total_block_idx += 1 # incr global block idx (across all stacks) stages.append(nn.Sequential(*blocks)) diff --git a/timm/models/feature_hooks.py b/timm/models/feature_hooks.py index 7b7b3da1..7c3f6f4b 100644 --- a/timm/models/feature_hooks.py +++ b/timm/models/feature_hooks.py @@ -1,3 +1,9 @@ +""" PyTorch Feature Hook Helper + +This class helps gather features from a network via hooks specified on the module name. + +Hacked together by Ross Wightman +""" import torch from collections import defaultdict, OrderedDict @@ -7,20 +13,21 @@ from typing import List class FeatureHooks: - def __init__(self, hooks, named_modules): + def __init__(self, hooks, named_modules, output_as_dict=False): # setup feature hooks modules = {k: v for k, v in named_modules} for h in hooks: - hook_name = h['name'] + hook_name = h['module'] m = modules[hook_name] hook_fn = partial(self._collect_output_hook, hook_name) - if h['type'] == 'forward_pre': + if h['hook_type'] == 'forward_pre': m.register_forward_pre_hook(hook_fn) - elif h['type'] == 'forward': + elif h['hook_type'] == 'forward': m.register_forward_hook(hook_fn) else: assert False, "Unsupported hook type" self._feature_outputs = defaultdict(OrderedDict) + self.output_as_dict = output_as_dict def _collect_output_hook(self, name, *args): x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre @@ -29,6 +36,9 @@ class FeatureHooks: self._feature_outputs[x.device][name] = x def get_output(self, device) -> List[torch.tensor]: - output = list(self._feature_outputs[device].values()) + if self.output_as_dict: + output = self._feature_outputs[device] + else: + output = list(self._feature_outputs[device].values()) self._feature_outputs[device] = OrderedDict() # clear after reading return output diff --git a/timm/models/features.py b/timm/models/features.py new file mode 100644 index 00000000..e4c19755 --- /dev/null +++ b/timm/models/features.py @@ -0,0 +1,251 @@ +""" PyTorch Feature Extraction Helpers + +A collection of classes, functions, modules to help extract features from models +and provide a common interface for describing them. + +Hacked together by Ross Wightman +""" +from collections import OrderedDict +from typing import Dict, List, Tuple, Any +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FeatureInfo: + + def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]): + prev_reduction = 1 + for fi in feature_info: + # sanity check the mandatory fields, there may be additional fields depending on the model + assert 'num_chs' in fi and fi['num_chs'] > 0 + assert 'reduction' in fi and fi['reduction'] >= prev_reduction + prev_reduction = fi['reduction'] + assert 'module' in fi + self._out_indices = out_indices + self._info = feature_info + + def from_other(self, out_indices: Tuple[int]): + return FeatureInfo(deepcopy(self._info), out_indices) + + def channels(self, idx=None): + """ feature channels accessor + if idx == None, returns feature channel count at each output index + if idx is an integer, return feature channel count for that feature module index + """ + if isinstance(idx, int): + return self._info[idx]['num_chs'] + return [self._info[i]['num_chs'] for i in self._out_indices] + + def reduction(self, idx=None): + """ feature reduction (output stride) accessor + if idx == None, returns feature reduction factor at each output index + if idx is an integer, return feature channel count at that feature module index + """ + if isinstance(idx, int): + return self._info[idx]['reduction'] + return [self._info[i]['reduction'] for i in self._out_indices] + + def module_name(self, idx=None): + """ feature module name accessor + if idx == None, returns feature module name at each output index + if idx is an integer, return feature module name at that feature module index + """ + if isinstance(idx, int): + return self._info[idx]['module'] + return [self._info[i]['module'] for i in self._out_indices] + + def get_by_key(self, idx=None, keys=None): + """ return info dicts for specified keys (or all if None) at specified idx (or out_indices if None) + """ + if isinstance(idx, int): + return self._info[idx] if keys is None else {k: self._info[idx][k] for k in keys} + if keys is None: + return [self._info[i] for i in self._out_indices] + else: + return [{k: self._info[i][k] for k in keys} for i in self._out_indices] + + def __getitem__(self, item): + return self._info[item] + + def __len__(self): + return len(self._info) + + +def _module_list(module, flatten_sequential=False): + # a yield/iter would be better for this but wouldn't be compatible with torchscript + ml = [] + for name, module in module.named_children(): + if flatten_sequential and isinstance(module, nn.Sequential): + # first level of Sequential containers is flattened into containing model + for child_name, child_module in module.named_children(): + ml.append(('_'.join([name, child_name]), child_module)) + else: + ml.append((name, module)) + return ml + + +def _check_return_layers(input_return_layers, modules): + return_layers = {} + for k, v in input_return_layers.items(): + ks = k.split('.') + assert 0 < len(ks) <= 2 + return_layers['_'.join(ks)] = v + return_set = set(return_layers.keys()) + sdiff = return_set - {name for name, _ in modules} + if sdiff: + raise ValueError(f'return_layers {sdiff} are not present in model') + return return_layers, return_set + + +class LayerGetterDict(nn.ModuleDict): + """ + Module wrapper that returns intermediate layers from a model as a dictionary + + Originally based on IntermediateLayerGetter at + https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py + + It has a strong assumption that the modules have been registered into the model in the same + order as they are used. This means that one should **not** reuse the same nn.Module twice + in the forward if you want this to work. + + Additionally, it is only able to query submodules that are directly assigned to the model + class (`model.feature1`) or at most one Sequential container deep (`model.features.1`, so + long as `features` is a sequential container assigned to the model). + + All Sequential containers that are directly assigned to the original model will have their + modules assigned to this module with the name `model.features.1` being changed to `model.features_1` + + Arguments: + model (nn.Module): model on which we will extract the features + return_layers (Dict[name, new_name]): a dict containing the names + of the modules for which the activations will be returned as + the key of the dict, and the value of the dict is the name + of the returned activation (which the user can specify). + concat (bool): whether to concatenate intermediate features that are lists or tuples + vs select element [0] + flatten_sequential (bool): whether to flatten sequential modules assigned to model + + """ + + def __init__(self, model, return_layers, concat=False, flatten_sequential=False): + modules = _module_list(model, flatten_sequential=flatten_sequential) + self.return_layers, remaining = _check_return_layers(return_layers, modules) + layers = OrderedDict() + self.concat = concat + for name, module in modules: + layers[name] = module + if name in remaining: + remaining.remove(name) + if not remaining: + break + super(LayerGetterDict, self).__init__(layers) + + def forward(self, x) -> Dict[Any, torch.Tensor]: + out = OrderedDict() + for name, module in self.items(): + x = module(x) + if name in self.return_layers: + out_id = self.return_layers[name] + if isinstance(x, (tuple, list)): + # If model tap is a tuple or list, concat or select first element + # FIXME this may need to be more generic / flexible for some nets + out[out_id] = torch.cat(x, 1) if self.concat else x[0] + else: + out[out_id] = x + return out + + +class LayerGetterList(nn.Sequential): + """ + Module wrapper that returns intermediate layers from a model as a list + + Originally based on IntermediateLayerGetter at + https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py + + It has a strong assumption that the modules have been registered into the model in the same + order as they are used. This means that one should **not** reuse the same nn.Module twice + in the forward if you want this to work. + + Additionally, it is only able to query submodules that are directly assigned to the model + class (`model.feature1`) or at most one Sequential container deep (`model.features.1`) so + long as `features` is a sequential container assigned to the model and flatten_sequent=True. + + All Sequential containers that are directly assigned to the original model will have their + modules assigned to this module with the name `model.features.1` being changed to `model.features_1` + + Arguments: + model (nn.Module): model on which we will extract the features + return_layers (Dict[name, new_name]): a dict containing the names + of the modules for which the activations will be returned as + the key of the dict, and the value of the dict is the name + of the returned activation (which the user can specify). + concat (bool): whether to concatenate intermediate features that are lists or tuples + vs select element [0] + flatten_sequential (bool): whether to flatten sequential modules assigned to model + + """ + + def __init__(self, model, return_layers, concat=False, flatten_sequential=False): + super(LayerGetterList, self).__init__() + modules = _module_list(model, flatten_sequential=flatten_sequential) + self.return_layers, remaining = _check_return_layers(return_layers, modules) + self.concat = concat + for name, module in modules: + self.add_module(name, module) + if name in remaining: + remaining.remove(name) + if not remaining: + break + + def forward(self, x) -> List[torch.Tensor]: + out = [] + for name, module in self.named_children(): + x = module(x) + if name in self.return_layers: + if isinstance(x, (tuple, list)): + # If model tap is a tuple or list, concat or select first element + # FIXME this may need to be more generic / flexible for some nets + out.append(torch.cat(x, 1) if self.concat else x[0]) + else: + out.append(x) + return out + + +def _resolve_feature_info(net, out_indices, feature_info=None): + if feature_info is None: + feature_info = getattr(net, 'feature_info') + if isinstance(feature_info, FeatureInfo): + return feature_info.from_other(out_indices) + elif isinstance(feature_info, (list, tuple)): + return FeatureInfo(net.feature_info, out_indices) + else: + assert False, "Provided feature_info is not valid" + + +class FeatureNet(nn.Module): + """ FeatureNet + + Wrap a model and extract features as specified by the out indices, the network + is partially re-built from contained modules using the LayerGetters. + + Please read the docstrings of the LayerGetter classes, they will not work on all models. + """ + def __init__( + self, net, + out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, + feature_info=None, feature_concat=False, flatten_sequential=False): + super(FeatureNet, self).__init__() + self.feature_info = _resolve_feature_info(net, out_indices, feature_info) + module_names = self.feature_info.module_name() + return_layers = {} + for i in range(len(out_indices)): + return_layers[module_names[i]] = out_map[i] if out_map is not None else out_indices[i] + lg_args = dict(return_layers=return_layers, concat=feature_concat, flatten_sequential=flatten_sequential) + self.body = LayerGetterDict(net, **lg_args) if out_as_dict else LayerGetterList(net, **lg_args) + + def forward(self, x): + output = self.body(x) + return output diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py index 59534007..88a61944 100644 --- a/timm/models/gluon_xception.py +++ b/timm/models/gluon_xception.py @@ -13,16 +13,16 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import load_pretrained -from .layers import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d, get_padding from .registry import register_model -__all__ = ['Xception65', 'Xception71'] +__all__ = ['Xception65'] default_cfgs = { 'gluon_xception65': { 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_xception-7015a15c.pth', 'input_size': (3, 299, 299), - 'crop_pct': 0.875, + 'crop_pct': 0.903, 'pool_size': (10, 10), 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, @@ -32,52 +32,13 @@ default_cfgs = { 'classifier': 'fc' # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 }, - 'gluon_xception71': { - 'url': '', - 'input_size': (3, 299, 299), - 'crop_pct': 0.875, - 'pool_size': (5, 5), - 'interpolation': 'bicubic', - 'mean': IMAGENET_DEFAULT_MEAN, - 'std': IMAGENET_DEFAULT_STD, - 'num_classes': 1000, - 'first_conv': 'conv1', - 'classifier': 'fc' - # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 - } } """ PADDING NOTES The original PyTorch and Gluon impl of these models dutifully reproduced the aligned padding added to Tensorflow models for Deeplab. This padding was compensating for Tensorflow 'SAME' padding. PyTorch symmetric padding behaves the way we'd want it to. - -So, I'm phasing out the 'fixed_padding' ported from TF and replacing with normal -PyTorch padding, some asserts to validate the equivalence for any scenario we'd -care about before removing altogether. """ -_USE_FIXED_PAD = False - - -def _pytorch_padding(kernel_size, stride=1, dilation=1, **_): - if _USE_FIXED_PAD: - return 0 # FIXME remove once verified - else: - padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 - - # FIXME remove once verified - fp = _fixed_padding(kernel_size, dilation) - assert all(padding == p for p in fp) - - return padding - - -def _fixed_padding(kernel_size, dilation): - kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) - pad_total = kernel_size_effective - 1 - pad_beg = pad_total // 2 - pad_end = pad_total - pad_beg - return [pad_beg, pad_end, pad_beg, pad_end] class SeparableConv2d(nn.Module): @@ -88,24 +49,16 @@ class SeparableConv2d(nn.Module): self.kernel_size = kernel_size self.dilation = dilation - padding = _fixed_padding(self.kernel_size, self.dilation) - if _USE_FIXED_PAD and any(p > 0 for p in padding): - self.fixed_padding = nn.ZeroPad2d(padding) - else: - self.fixed_padding = None - # depthwise convolution + padding = get_padding(kernel_size, stride, dilation) self.conv_dw = nn.Conv2d( inplanes, inplanes, kernel_size, stride=stride, - padding=_pytorch_padding(kernel_size, stride, dilation), dilation=dilation, groups=inplanes, bias=bias) + padding=padding, dilation=dilation, groups=inplanes, bias=bias) self.bn = norm_layer(num_features=inplanes, **norm_kwargs) # pointwise convolution self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias) def forward(self, x): - if self.fixed_padding is not None: - # FIXME remove once verified - x = self.fixed_padding(x) x = self.conv_dw(x) x = self.bn(x) x = self.conv_pw(x) @@ -113,58 +66,37 @@ class SeparableConv2d(nn.Module): class Block(nn.Module): - def __init__(self, inplanes, planes, num_reps, stride=1, dilation=1, norm_layer=None, - norm_kwargs=None, start_with_relu=True, grow_first=True, is_last=False): + def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True, + norm_layer=None, norm_kwargs=None, ): super(Block, self).__init__() norm_kwargs = norm_kwargs if norm_kwargs is not None else {} - if planes != inplanes or stride != 1: + if isinstance(planes, (list, tuple)): + assert len(planes) == 3 + else: + planes = (planes,) * 3 + outplanes = planes[-1] + + if outplanes != inplanes or stride != 1: self.skip = nn.Sequential() self.skip.add_module('conv1', nn.Conv2d( - inplanes, planes, 1, stride=stride, bias=False)), - self.skip.add_module('bn1', norm_layer(num_features=planes, **norm_kwargs)) + inplanes, outplanes, 1, stride=stride, bias=False)), + self.skip.add_module('bn1', norm_layer(num_features=outplanes, **norm_kwargs)) else: self.skip = None rep = OrderedDict() - l = 1 - filters = inplanes - if grow_first: - if start_with_relu: - rep['act%d' % l] = nn.ReLU(inplace=False) # NOTE: silent failure if inplace=True here - rep['conv%d' % l] = SeparableConv2d( - inplanes, planes, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs) - rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs) - filters = planes - l += 1 - - for _ in range(num_reps - 1): - if grow_first or start_with_relu: - # FIXME being conservative with inplace here, think it's fine to leave True? - rep['act%d' % l] = nn.ReLU(inplace=grow_first or not start_with_relu) - rep['conv%d' % l] = SeparableConv2d( - filters, filters, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs) - rep['bn%d' % l] = norm_layer(num_features=filters, **norm_kwargs) - l += 1 - - if not grow_first: - rep['act%d' % l] = nn.ReLU(inplace=True) - rep['conv%d' % l] = SeparableConv2d( - inplanes, planes, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs) - rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs) - l += 1 - - if stride != 1: - rep['act%d' % l] = nn.ReLU(inplace=True) - rep['conv%d' % l] = SeparableConv2d( - planes, planes, 3, stride, norm_layer=norm_layer, norm_kwargs=norm_kwargs) - rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs) - l += 1 - elif is_last: - rep['act%d' % l] = nn.ReLU(inplace=True) - rep['conv%d' % l] = SeparableConv2d( - planes, planes, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs) - rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs) - l += 1 + for i in range(3): + rep['act%d' % (i + 1)] = nn.ReLU(inplace=True) + rep['conv%d' % (i + 1)] = SeparableConv2d( + inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation, + norm_layer=norm_layer, norm_kwargs=norm_kwargs) + rep['bn%d' % (i + 1)] = norm_layer(planes[i], **norm_kwargs) + inplanes = planes[i] + + if not start_with_relu: + del rep['act1'] + else: + rep['act1'] = nn.ReLU(inplace=False) self.rep = nn.Sequential(rep) def forward(self, x): @@ -176,7 +108,10 @@ class Block(nn.Module): class Xception65(nn.Module): - """Modified Aligned Xception + """Modified Aligned Xception. + + NOTE: only the 65 layer version is included here, the 71 layer variant + was not correct and had no pretrained weights """ def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d, @@ -212,25 +147,21 @@ class Xception65(nn.Module): self.bn2 = norm_layer(num_features=64) self.block1 = Block( - 64, 128, num_reps=2, stride=2, - norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=False) + 64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs) self.block2 = Block( - 128, 256, num_reps=2, stride=2, - norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True) + 128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs) self.block3 = Block( - 256, 728, num_reps=2, stride=entry_block3_stride, - norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True, is_last=True) + 256, 728, stride=entry_block3_stride, norm_layer=norm_layer, norm_kwargs=norm_kwargs) # Middle flow self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block( - 728, 728, num_reps=3, stride=1, dilation=middle_block_dilation, - norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True)) - for i in range(4, 20)])) + 728, 728, stride=1, dilation=middle_block_dilation, + norm_layer=norm_layer, norm_kwargs=norm_kwargs)) for i in range(4, 20)])) # Exit flow self.block20 = Block( - 728, 1024, num_reps=2, stride=exit_block20_stride, dilation=exit_block_dilations[0], - norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=False, is_last=True) + 728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_block_dilations[0], + norm_layer=norm_layer, norm_kwargs=norm_kwargs) self.conv3 = SeparableConv2d( 1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], @@ -305,147 +236,6 @@ class Xception65(nn.Module): return x -class Xception71(nn.Module): - """Modified Aligned Xception - """ - - def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d, - norm_kwargs=None, drop_rate=0., global_pool='avg'): - super(Xception71, self).__init__() - self.num_classes = num_classes - self.drop_rate = drop_rate - norm_kwargs = norm_kwargs if norm_kwargs is not None else {} - if output_stride == 32: - entry_block3_stride = 2 - exit_block20_stride = 2 - middle_block_dilation = 1 - exit_block_dilations = (1, 1) - elif output_stride == 16: - entry_block3_stride = 2 - exit_block20_stride = 1 - middle_block_dilation = 1 - exit_block_dilations = (1, 2) - elif output_stride == 8: - entry_block3_stride = 1 - exit_block20_stride = 1 - middle_block_dilation = 2 - exit_block_dilations = (2, 4) - else: - raise NotImplementedError - - # Entry flow - self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False) - self.bn1 = norm_layer(num_features=32, **norm_kwargs) - self.relu = nn.ReLU(inplace=True) - - self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False) - self.bn2 = norm_layer(num_features=64) - - self.block1 = Block( - 64, 128, num_reps=2, stride=2, norm_layer=norm_layer, - norm_kwargs=norm_kwargs, start_with_relu=False) - self.block2 = nn.Sequential(*[ - Block( - 128, 256, num_reps=2, stride=1, norm_layer=norm_layer, - norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True), - Block( - 256, 256, num_reps=2, stride=2, norm_layer=norm_layer, - norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True), - Block( - 256, 728, num_reps=2, stride=2, norm_layer=norm_layer, - norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True)]) - self.block3 = Block( - 728, 728, num_reps=2, stride=entry_block3_stride, norm_layer=norm_layer, - norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True, is_last=True) - - # Middle flow - self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block( - 728, 728, num_reps=3, stride=1, dilation=middle_block_dilation, - norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True)) - for i in range(4, 20)])) - - # Exit flow - self.block20 = Block( - 728, 1024, num_reps=2, stride=exit_block20_stride, dilation=exit_block_dilations[0], - norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=False, is_last=True) - - self.conv3 = SeparableConv2d( - 1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], - norm_layer=norm_layer, norm_kwargs=norm_kwargs) - self.bn3 = norm_layer(num_features=1536, **norm_kwargs) - - self.conv4 = SeparableConv2d( - 1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], - norm_layer=norm_layer, norm_kwargs=norm_kwargs) - self.bn4 = norm_layer(num_features=1536, **norm_kwargs) - - self.num_features = 2048 - self.conv5 = SeparableConv2d( - 1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1], - norm_layer=norm_layer, norm_kwargs=norm_kwargs) - self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs) - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) - - def get_classifier(self): - return self.fc - - def reset_classifier(self, num_classes, global_pool='avg'): - self.num_classes = num_classes - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - if num_classes: - num_features = self.num_features * self.global_pool.feat_mult() - self.fc = nn.Linear(num_features, num_classes) - else: - self.fc = nn.Identity() - - def forward_features(self, x): - # Entry flow - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - - x = self.conv2(x) - x = self.bn2(x) - x = self.relu(x) - - x = self.block1(x) - # add relu here - x = self.relu(x) - # low_level_feat = x - x = self.block2(x) - # c2 = x - x = self.block3(x) - - # Middle flow - x = self.mid(x) - # c3 = x - - # Exit flow - x = self.block20(x) - x = self.relu(x) - x = self.conv3(x) - x = self.bn3(x) - x = self.relu(x) - - x = self.conv4(x) - x = self.bn4(x) - x = self.relu(x) - - x = self.conv5(x) - x = self.bn5(x) - x = self.relu(x) - return x - - def forward(self, x): - x = self.forward_features(x) - x = self.global_pool(x).flatten(1) - if self.drop_rate: - F.dropout(x, self.drop_rate, training=self.training) - x = self.fc(x) - return x - - @register_model def gluon_xception65(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ Modified Aligned Xception-65 @@ -456,15 +246,3 @@ def gluon_xception65(pretrained=False, num_classes=1000, in_chans=3, **kwargs): if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) return model - - -@register_model -def gluon_xception71(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ Modified Aligned Xception-71 - """ - default_cfg = default_cfgs['gluon_xception71'] - model = Xception71(num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index f8772cc8..951648c7 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -7,6 +7,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .features import FeatureNet from .helpers import load_pretrained from .layers import SelectAdaptivePool2d from .registry import register_model @@ -231,9 +232,13 @@ class InceptionResnetV2(nn.Module): self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2) self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1) self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1) + self.feature_info = [dict(num_chs=64, reduction=2, module='conv2d_2b')] + self.maxpool_3a = nn.MaxPool2d(3, stride=2) self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1) self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1) + self.feature_info += [dict(num_chs=192, reduction=4, module='conv2d_4a')] + self.maxpool_5a = nn.MaxPool2d(3, stride=2) self.mixed_5b = Mixed_5b() self.repeat = nn.Sequential( @@ -248,6 +253,8 @@ class InceptionResnetV2(nn.Module): Block35(scale=0.17), Block35(scale=0.17) ) + self.feature_info += [dict(num_chs=320, reduction=8, module='repeat')] + self.mixed_6a = Mixed_6a() self.repeat_1 = nn.Sequential( Block17(scale=0.10), @@ -271,6 +278,8 @@ class InceptionResnetV2(nn.Module): Block17(scale=0.10), Block17(scale=0.10) ) + self.feature_info += [dict(num_chs=1088, reduction=16, module='repeat_1')] + self.mixed_7a = Mixed_7a() self.repeat_2 = nn.Sequential( Block8(scale=0.20), @@ -285,6 +294,8 @@ class InceptionResnetV2(nn.Module): ) self.block8 = Block8(no_relu=True) self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1) + self.feature_info += [dict(num_chs=self.num_features, reduction=32, module='conv2d_7b')] + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) # NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC self.classif = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) @@ -328,30 +339,34 @@ class InceptionResnetV2(nn.Module): return x +def _inception_resnet_v2(variant, pretrained=False, **kwargs): + load_strict, features, out_indices = True, False, None + if kwargs.pop('features_only', False): + load_strict, features, out_indices = False, True, kwargs.pop('out_indices', (0, 1, 2, 3, 4)) + kwargs.pop('num_classes', 0) + model = InceptionResnetV2(**kwargs) + model.default_cfg = default_cfgs[variant] + if pretrained: + load_pretrained( + model, + num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict) + if features: + model = FeatureNet(model, out_indices) + return model + + @register_model -def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def inception_resnet_v2(pretrained=False, **kwargs): r"""InceptionResnetV2 model architecture from the `"InceptionV4, Inception-ResNet..." ` paper. """ - default_cfg = default_cfgs['inception_resnet_v2'] - model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - - return model + return _inception_resnet_v2('inception_resnet_v2', pretrained=pretrained, **kwargs) @register_model -def ens_adv_inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def ens_adv_inception_resnet_v2(pretrained=False, **kwargs): r""" Ensemble Adversarially trained InceptionResnetV2 model architecture As per https://arxiv.org/abs/1705.07204 and https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models. """ - default_cfg = default_cfgs['ens_adv_inception_resnet_v2'] - model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - - return model + return _inception_resnet_v2('ens_adv_inception_resnet_v2', pretrained=pretrained, **kwargs) diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 9c4a9af5..850f7120 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -17,6 +17,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights from .feature_hooks import FeatureHooks +from .features import FeatureInfo from .helpers import load_pretrained from .layers import SelectAdaptivePool2d, create_conv2d, get_act_fn, hard_sigmoid from .registry import register_model @@ -182,22 +183,20 @@ class MobileNetV3Features(nn.Module): channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG) self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) - self._feature_info = builder.features # builder provides info about feature channels for each block + self.feature_info = FeatureInfo(builder.features, out_indices) self._stage_to_feature_idx = { - v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices} + v['stage_idx']: fi for fi, v in enumerate(self.feature_info) if fi in self.out_indices} self._in_chs = builder.in_chs efficientnet_init_weights(self) if _DEBUG: - for k, v in self._feature_info.items(): - print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs'])) + for fi, v in enumerate(self.feature_info): + print('Feature idx: {}: Name: {}, Channels: {}'.format(fi, v['module'], v['num_chs'])) # Register feature extraction hooks with FeatureHooks helper self.feature_hooks = None if feature_location != 'bottleneck': - hooks = [dict( - name=self._feature_info[idx]['module'], - type=self._feature_info[idx]['hook_type']) for idx in out_indices] + hooks = self.feature_info.get_by_key(keys=('module', 'hook_type')) self.feature_hooks = FeatureHooks(hooks, self.named_modules()) def feature_channels(self, idx=None): @@ -206,17 +205,8 @@ class MobileNetV3Features(nn.Module): return feature channel count for that feature block index (independent of out_indices setting). """ if isinstance(idx, int): - return self._feature_info[idx]['num_chs'] - return [self._feature_info[i]['num_chs'] for i in self.out_indices] - - def feature_info(self, idx=None): - """ Feature Channel Shortcut - Returns feature channel count for each output index if idx == None. If idx is an integer, will - return feature channel count for that feature block index (independent of out_indices setting). - """ - if isinstance(idx, int): - return self._feature_info[idx] - return [self._feature_info[i] for i in self.out_indices] + return self.feature_info[idx]['num_chs'] + return [self.feature_info[i]['num_chs'] for i in self.out_indices] def forward(self, x) -> List[torch.Tensor]: x = self.conv_stem(x) diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index fea928c1..4e23eb99 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -3,7 +3,7 @@ import torch.nn as nn import torch.nn.functional as F from .helpers import load_pretrained -from .layers import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d, create_pool2d from .registry import register_model __all__ = ['NASNetALarge'] @@ -24,43 +24,31 @@ default_cfgs = { } -class MaxPoolPad(nn.Module): +class ActConvBn(nn.Module): - def __init__(self): - super(MaxPoolPad, self).__init__() - self.pad = nn.ZeroPad2d((1, 0, 1, 0)) - self.pool = nn.MaxPool2d(3, stride=2, padding=1) + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''): + super(ActConvBn, self).__init__() + self.act = nn.ReLU() + self.conv = create_conv2d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1) def forward(self, x): - x = self.pad(x) - x = self.pool(x) - x = x[:, :, 1:, 1:] - return x - - -class AvgPoolPad(nn.Module): - - def __init__(self, stride=2, padding=1): - super(AvgPoolPad, self).__init__() - self.pad = nn.ZeroPad2d((1, 0, 1, 0)) - self.pool = nn.AvgPool2d(3, stride=stride, padding=padding, count_include_pad=False) - - def forward(self, x): - x = self.pad(x) - x = self.pool(x) - x = x[:, :, 1:, 1:] + x = self.act(x) + x = self.conv(x) + x = self.bn(x) return x class SeparableConv2d(nn.Module): - def __init__(self, in_channels, out_channels, dw_kernel, dw_stride, dw_padding, bias=False): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''): super(SeparableConv2d, self).__init__() - self.depthwise_conv2d = nn.Conv2d( - in_channels, in_channels, dw_kernel, - stride=dw_stride, padding=dw_padding, - bias=bias, groups=in_channels) - self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=bias) + self.depthwise_conv2d = create_conv2d( + in_channels, in_channels, kernel_size=kernel_size, + stride=stride, padding=padding, groups=in_channels) + self.pointwise_conv2d = create_conv2d( + in_channels, out_channels, kernel_size=1, padding=0) def forward(self, x): x = self.depthwise_conv2d(x) @@ -70,87 +58,48 @@ class SeparableConv2d(nn.Module): class BranchSeparables(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_type='', stem_cell=False): super(BranchSeparables, self).__init__() - self.relu = nn.ReLU() - self.separable_1 = SeparableConv2d(in_channels, in_channels, kernel_size, stride, padding, bias=bias) - self.bn_sep_1 = nn.BatchNorm2d(in_channels, eps=0.001, momentum=0.1, affine=True) - self.relu1 = nn.ReLU() - self.separable_2 = SeparableConv2d(in_channels, out_channels, kernel_size, 1, padding, bias=bias) - self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True) - - def forward(self, x): - x = self.relu(x) - x = self.separable_1(x) - x = self.bn_sep_1(x) - x = self.relu1(x) - x = self.separable_2(x) - x = self.bn_sep_2(x) - return x - - -class BranchSeparablesStem(nn.Module): - - def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False): - super(BranchSeparablesStem, self).__init__() - self.relu = nn.ReLU() - self.separable_1 = SeparableConv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) - self.bn_sep_1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True) - self.relu1 = nn.ReLU() - self.separable_2 = SeparableConv2d(out_channels, out_channels, kernel_size, 1, padding, bias=bias) - self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True) - - def forward(self, x): - x = self.relu(x) - x = self.separable_1(x) - x = self.bn_sep_1(x) - x = self.relu1(x) - x = self.separable_2(x) - x = self.bn_sep_2(x) - return x - - -class BranchSeparablesReduction(BranchSeparables): - - def __init__(self, in_channels, out_channels, kernel_size, stride, padding, z_padding=1, bias=False): - BranchSeparables.__init__(self, in_channels, out_channels, kernel_size, stride, padding, bias) - self.padding = nn.ZeroPad2d((z_padding, 0, z_padding, 0)) + middle_channels = out_channels if stem_cell else in_channels + self.act_1 = nn.ReLU() + self.separable_1 = SeparableConv2d( + in_channels, middle_channels, kernel_size, stride=stride, padding=pad_type) + self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001, momentum=0.1) + self.act_2 = nn.ReLU(inplace=True) + self.separable_2 = SeparableConv2d( + middle_channels, out_channels, kernel_size, stride=1, padding=pad_type) + self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1) def forward(self, x): - x = self.relu(x) - x = self.padding(x) + x = self.act_1(x) x = self.separable_1(x) - x = x[:, :, 1:, 1:].contiguous() x = self.bn_sep_1(x) - x = self.relu1(x) + x = self.act_2(x) x = self.separable_2(x) x = self.bn_sep_2(x) return x class CellStem0(nn.Module): - def __init__(self, stem_size, num_channels=42): + def __init__(self, stem_size, num_channels=42, pad_type=''): super(CellStem0, self).__init__() self.num_channels = num_channels self.stem_size = stem_size - self.conv_1x1 = nn.Sequential() - self.conv_1x1.add_module('relu', nn.ReLU()) - self.conv_1x1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels, 1, stride=1, bias=False)) - self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True)) + self.conv_1x1 = ActConvBn(self.stem_size, self.num_channels, 1, stride=1) - self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, 2) - self.comb_iter_0_right = BranchSeparablesStem(self.stem_size, self.num_channels, 7, 2, 3, bias=False) + self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type) + self.comb_iter_0_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True) - self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1) - self.comb_iter_1_right = BranchSeparablesStem(self.stem_size, self.num_channels, 7, 2, 3, bias=False) + self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) + self.comb_iter_1_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True) - self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False) - self.comb_iter_2_right = BranchSeparablesStem(self.stem_size, self.num_channels, 5, 2, 2, bias=False) + self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type) + self.comb_iter_2_right = BranchSeparables(self.stem_size, self.num_channels, 5, 2, pad_type, stem_cell=True) - self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) - self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, 1, bias=False) - self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1) + self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type) + self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) def forward(self, x): x1 = self.conv_1x1(x) @@ -180,51 +129,46 @@ class CellStem0(nn.Module): class CellStem1(nn.Module): - def __init__(self, stem_size, num_channels): + def __init__(self, stem_size, num_channels, pad_type=''): super(CellStem1, self).__init__() self.num_channels = num_channels self.stem_size = stem_size - self.conv_1x1 = nn.Sequential() - self.conv_1x1.add_module('relu', nn.ReLU()) - self.conv_1x1.add_module('conv', nn.Conv2d(2 * self.num_channels, self.num_channels, 1, stride=1, bias=False)) - self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True)) + self.conv_1x1 = ActConvBn(2 * self.num_channels, self.num_channels, 1, stride=1) - self.relu = nn.ReLU() + self.act = nn.ReLU() self.path_1 = nn.Sequential() self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False)) - self.path_2 = nn.ModuleList() - self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1))) + + self.path_2 = nn.Sequential() + self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1))) self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False)) - self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True) + self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1) - self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, 2, bias=False) - self.comb_iter_0_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, 3, bias=False) + self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type) + self.comb_iter_0_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type) - self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1) - self.comb_iter_1_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, 3, bias=False) + self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) + self.comb_iter_1_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type) - self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False) - self.comb_iter_2_right = BranchSeparables(self.num_channels, self.num_channels, 5, 2, 2, bias=False) + self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type) + self.comb_iter_2_right = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type) - self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) - self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, 1, bias=False) - self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1) + self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type) + self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) def forward(self, x_conv0, x_stem_0): x_left = self.conv_1x1(x_stem_0) - x_relu = self.relu(x_conv0) + x_relu = self.act(x_conv0) # path 1 x_path1 = self.path_1(x_relu) # path 2 - x_path2 = self.path_2.pad(x_relu) - x_path2 = x_path2[:, :, 1:, 1:] - x_path2 = self.path_2.avgpool(x_path2) - x_path2 = self.path_2.conv(x_path2) + x_path2 = self.path_2(x_relu) # final path x_right = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) @@ -253,49 +197,40 @@ class CellStem1(nn.Module): class FirstCell(nn.Module): - def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): super(FirstCell, self).__init__() - self.conv_1x1 = nn.Sequential() - self.conv_1x1.add_module('relu', nn.ReLU()) - self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)) - self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True)) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1) - self.relu = nn.ReLU() + self.act = nn.ReLU() self.path_1 = nn.Sequential() self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) - self.path_1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) - self.path_2 = nn.ModuleList() - self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1))) + self.path_1.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False)) + + self.path_2 = nn.Sequential() + self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1))) self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) - self.path_2.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) + self.path_2.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False)) - self.final_path_bn = nn.BatchNorm2d(out_channels_left * 2, eps=0.001, momentum=0.1, affine=True) + self.final_path_bn = nn.BatchNorm2d(out_chs_left * 2, eps=0.001, momentum=0.1) - self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False) - self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type) + self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) - self.comb_iter_1_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False) - self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + self.comb_iter_1_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type) + self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) - self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) - self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) - self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) - self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) def forward(self, x, x_prev): - x_relu = self.relu(x_prev) - # path 1 + x_relu = self.act(x_prev) x_path1 = self.path_1(x_relu) - # path 2 - x_path2 = self.path_2.pad(x_relu) - x_path2 = x_path2[:, :, 1:, 1:] - x_path2 = self.path_2.avgpool(x_path2) - x_path2 = self.path_2.conv(x_path2) - # final path + x_path2 = self.path_2(x_relu) x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) - x_right = self.conv_1x1(x) x_comb_iter_0_left = self.comb_iter_0_left(x_right) @@ -322,30 +257,23 @@ class FirstCell(nn.Module): class NormalCell(nn.Module): - def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): super(NormalCell, self).__init__() - self.conv_prev_1x1 = nn.Sequential() - self.conv_prev_1x1.add_module('relu', nn.ReLU()) - self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) - self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True)) - - self.conv_1x1 = nn.Sequential() - self.conv_1x1.add_module('relu', nn.ReLU()) - self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)) - self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True)) + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type) - self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False) - self.comb_iter_0_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False) + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type) + self.comb_iter_0_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type) - self.comb_iter_1_left = BranchSeparables(out_channels_left, out_channels_left, 5, 1, 2, bias=False) - self.comb_iter_1_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False) + self.comb_iter_1_left = BranchSeparables(out_chs_left, out_chs_left, 5, 1, pad_type) + self.comb_iter_1_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type) - self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) - self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) - self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) - self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) def forward(self, x, x_prev): x_left = self.conv_prev_1x1(x_prev) @@ -375,31 +303,24 @@ class NormalCell(nn.Module): class ReductionCell0(nn.Module): - def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): super(ReductionCell0, self).__init__() - self.conv_prev_1x1 = nn.Sequential() - self.conv_prev_1x1.add_module('relu', nn.ReLU()) - self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) - self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True)) + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type) - self.conv_1x1 = nn.Sequential() - self.conv_1x1.add_module('relu', nn.ReLU()) - self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)) - self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True)) + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) + self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) - self.comb_iter_0_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False) - self.comb_iter_0_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False) + self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) + self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) - self.comb_iter_1_left = MaxPoolPad() - self.comb_iter_1_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False) + self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type) + self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) - self.comb_iter_2_left = AvgPoolPad() - self.comb_iter_2_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False) + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) - self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) - - self.comb_iter_4_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 3, 1, 1, bias=False) - self.comb_iter_4_right = MaxPoolPad() + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) def forward(self, x, x_prev): x_left = self.conv_prev_1x1(x_prev) @@ -430,31 +351,24 @@ class ReductionCell0(nn.Module): class ReductionCell1(nn.Module): - def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): super(ReductionCell1, self).__init__() - self.conv_prev_1x1 = nn.Sequential() - self.conv_prev_1x1.add_module('relu', nn.ReLU()) - self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) - self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True)) - - self.conv_1x1 = nn.Sequential() - self.conv_1x1.add_module('relu', nn.ReLU()) - self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)) - self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True)) + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type) - self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False) - self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False) + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) + self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) - self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1) - self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False) + self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) + self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) - self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False) - self.comb_iter_2_right = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False) + self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type) + self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) - self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) - self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False) - self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1) + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) def forward(self, x, x_prev): x_left = self.conv_prev_1x1(x_prev) @@ -487,7 +401,7 @@ class NASNetALarge(nn.Module): """NASNetALarge (6 @ 4032) """ def __init__(self, num_classes=1000, in_chans=1, stem_size=96, num_features=4032, channel_multiplier=2, - drop_rate=0., global_pool='avg'): + drop_rate=0., global_pool='avg', pad_type='same'): super(NASNetALarge, self).__init__() self.num_classes = num_classes self.stem_size = stem_size @@ -498,60 +412,79 @@ class NASNetALarge(nn.Module): channels = self.num_features // 24 # 24 is default value for the architecture - self.conv0 = nn.Sequential() - self.conv0.add_module('conv', nn.Conv2d( - in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2, bias=False)) - self.conv0.add_module('bn', nn.BatchNorm2d(self.stem_size, eps=0.001, momentum=0.1, affine=True)) - - self.cell_stem_0 = CellStem0(self.stem_size, num_channels=channels // (channel_multiplier ** 2)) - self.cell_stem_1 = CellStem1(self.stem_size, num_channels=channels // channel_multiplier) - - self.cell_0 = FirstCell(in_channels_left=channels, out_channels_left=channels // 2, - in_channels_right=2 * channels, out_channels_right=channels) - self.cell_1 = NormalCell(in_channels_left=2 * channels, out_channels_left=channels, - in_channels_right=6 * channels, out_channels_right=channels) - self.cell_2 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels, - in_channels_right=6 * channels, out_channels_right=channels) - self.cell_3 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels, - in_channels_right=6 * channels, out_channels_right=channels) - self.cell_4 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels, - in_channels_right=6 * channels, out_channels_right=channels) - self.cell_5 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels, - in_channels_right=6 * channels, out_channels_right=channels) - - self.reduction_cell_0 = ReductionCell0(in_channels_left=6 * channels, out_channels_left=2 * channels, - in_channels_right=6 * channels, out_channels_right=2 * channels) - - self.cell_6 = FirstCell(in_channels_left=6 * channels, out_channels_left=channels, - in_channels_right=8 * channels, out_channels_right=2 * channels) - self.cell_7 = NormalCell(in_channels_left=8 * channels, out_channels_left=2 * channels, - in_channels_right=12 * channels, out_channels_right=2 * channels) - self.cell_8 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels, - in_channels_right=12 * channels, out_channels_right=2 * channels) - self.cell_9 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels, - in_channels_right=12 * channels, out_channels_right=2 * channels) - self.cell_10 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels, - in_channels_right=12 * channels, out_channels_right=2 * channels) - self.cell_11 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels, - in_channels_right=12 * channels, out_channels_right=2 * channels) - - self.reduction_cell_1 = ReductionCell1(in_channels_left=12 * channels, out_channels_left=4 * channels, - in_channels_right=12 * channels, out_channels_right=4 * channels) - - self.cell_12 = FirstCell(in_channels_left=12 * channels, out_channels_left=2 * channels, - in_channels_right=16 * channels, out_channels_right=4 * channels) - self.cell_13 = NormalCell(in_channels_left=16 * channels, out_channels_left=4 * channels, - in_channels_right=24 * channels, out_channels_right=4 * channels) - self.cell_14 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels, - in_channels_right=24 * channels, out_channels_right=4 * channels) - self.cell_15 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels, - in_channels_right=24 * channels, out_channels_right=4 * channels) - self.cell_16 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels, - in_channels_right=24 * channels, out_channels_right=4 * channels) - self.cell_17 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels, - in_channels_right=24 * channels, out_channels_right=4 * channels) - - self.relu = nn.ReLU() + self.conv0 = ConvBnAct( + in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2, + norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None) + + self.cell_stem_0 = CellStem0( + self.stem_size, num_channels=channels // (channel_multiplier ** 2), pad_type=pad_type) + self.cell_stem_1 = CellStem1( + self.stem_size, num_channels=channels // channel_multiplier, pad_type=pad_type) + + self.cell_0 = FirstCell( + in_chs_left=channels, out_chs_left=channels // 2, + in_chs_right=2 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_1 = NormalCell( + in_chs_left=2 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_2 = NormalCell( + in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_3 = NormalCell( + in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_4 = NormalCell( + in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_5 = NormalCell( + in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + + self.reduction_cell_0 = ReductionCell0( + in_chs_left=6 * channels, out_chs_left=2 * channels, + in_chs_right=6 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_6 = FirstCell( + in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=8 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_7 = NormalCell( + in_chs_left=8 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_8 = NormalCell( + in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_9 = NormalCell( + in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_10 = NormalCell( + in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_11 = NormalCell( + in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + + self.reduction_cell_1 = ReductionCell1( + in_chs_left=12 * channels, out_chs_left=4 * channels, + in_chs_right=12 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_12 = FirstCell( + in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=16 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_13 = NormalCell( + in_chs_left=16 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_14 = NormalCell( + in_chs_left=24 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_15 = NormalCell( + in_chs_left=24 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_16 = NormalCell( + in_chs_left=24 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_17 = NormalCell( + in_chs_left=24 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + + self.act = nn.ReLU(inplace=True) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) @@ -569,8 +502,11 @@ class NASNetALarge(nn.Module): def forward_features(self, x): x_conv0 = self.conv0(x) + #0 + x_stem_0 = self.cell_stem_0(x_conv0) x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0) + #1 x_cell_0 = self.cell_0(x_stem_1, x_stem_0) x_cell_1 = self.cell_1(x_cell_0, x_stem_1) @@ -578,25 +514,27 @@ class NASNetALarge(nn.Module): x_cell_3 = self.cell_3(x_cell_2, x_cell_1) x_cell_4 = self.cell_4(x_cell_3, x_cell_2) x_cell_5 = self.cell_5(x_cell_4, x_cell_3) + #2 x_reduction_cell_0 = self.reduction_cell_0(x_cell_5, x_cell_4) - x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_4) x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0) x_cell_8 = self.cell_8(x_cell_7, x_cell_6) x_cell_9 = self.cell_9(x_cell_8, x_cell_7) x_cell_10 = self.cell_10(x_cell_9, x_cell_8) x_cell_11 = self.cell_11(x_cell_10, x_cell_9) + #3 x_reduction_cell_1 = self.reduction_cell_1(x_cell_11, x_cell_10) - x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_10) x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1) x_cell_14 = self.cell_14(x_cell_13, x_cell_12) x_cell_15 = self.cell_15(x_cell_14, x_cell_13) x_cell_16 = self.cell_16(x_cell_15, x_cell_14) x_cell_17 = self.cell_17(x_cell_16, x_cell_15) - x = self.relu(x_cell_17) + x = self.act(x_cell_17) + #4 + return x def forward(self, x): diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 56614bd6..db558401 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -14,7 +14,7 @@ import torch.nn as nn import torch.nn.functional as F from .helpers import load_pretrained -from .layers import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d, create_pool2d from .registry import register_model __all__ = ['PNASNet5Large'] @@ -35,34 +35,15 @@ default_cfgs = { } -class MaxPool(nn.Module): - - def __init__(self, kernel_size, stride=1, padding=1, zero_pad=False): - super(MaxPool, self).__init__() - self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0)) if zero_pad else None - self.pool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding) - - def forward(self, x): - if self.zero_pad is not None: - x = self.zero_pad(x) - x = self.pool(x) - x = x[:, :, 1:, 1:] - else: - x = self.pool(x) - return x - - class SeparableConv2d(nn.Module): - def __init__(self, in_channels, out_channels, dw_kernel_size, dw_stride, - dw_padding): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''): super(SeparableConv2d, self).__init__() - self.depthwise_conv2d = nn.Conv2d(in_channels, in_channels, - kernel_size=dw_kernel_size, - stride=dw_stride, padding=dw_padding, - groups=in_channels, bias=False) - self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels, - kernel_size=1, bias=False) + self.depthwise_conv2d = create_conv2d( + in_channels, in_channels, kernel_size=kernel_size, + stride=stride, padding=padding, groups=in_channels) + self.pointwise_conv2d = create_conv2d( + in_channels, out_channels, kernel_size=1, padding=padding) def forward(self, x): x = self.depthwise_conv2d(x) @@ -72,50 +53,39 @@ class SeparableConv2d(nn.Module): class BranchSeparables(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, - stem_cell=False, zero_pad=False): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, stem_cell=False, padding=''): super(BranchSeparables, self).__init__() - padding = kernel_size // 2 middle_channels = out_channels if stem_cell else in_channels - self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0)) if zero_pad else None - self.relu_1 = nn.ReLU() - self.separable_1 = SeparableConv2d(in_channels, middle_channels, - kernel_size, dw_stride=stride, - dw_padding=padding) + self.act_1 = nn.ReLU() + self.separable_1 = SeparableConv2d( + in_channels, middle_channels, kernel_size, stride=stride, padding=padding) self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001) - self.relu_2 = nn.ReLU() - self.separable_2 = SeparableConv2d(middle_channels, out_channels, - kernel_size, dw_stride=1, - dw_padding=padding) + self.act_2 = nn.ReLU() + self.separable_2 = SeparableConv2d( + middle_channels, out_channels, kernel_size, stride=1, padding=padding) self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001) def forward(self, x): - x = self.relu_1(x) - if self.zero_pad is not None: - x = self.zero_pad(x) - x = self.separable_1(x) - x = x[:, :, 1:, 1:].contiguous() - else: - x = self.separable_1(x) + x = self.act_1(x) + x = self.separable_1(x) x = self.bn_sep_1(x) - x = self.relu_2(x) + x = self.act_2(x) x = self.separable_2(x) x = self.bn_sep_2(x) return x -class ReluConvBn(nn.Module): +class ActConvBn(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride=1): - super(ReluConvBn, self).__init__() - self.relu = nn.ReLU() - self.conv = nn.Conv2d(in_channels, out_channels, - kernel_size=kernel_size, stride=stride, - bias=False) + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''): + super(ActConvBn, self).__init__() + self.act = nn.ReLU() + self.conv = create_conv2d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) self.bn = nn.BatchNorm2d(out_channels, eps=0.001) def forward(self, x): - x = self.relu(x) + x = self.act(x) x = self.conv(x) x = self.bn(x) return x @@ -123,32 +93,24 @@ class ReluConvBn(nn.Module): class FactorizedReduction(nn.Module): - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels, out_channels, padding=''): super(FactorizedReduction, self).__init__() - self.relu = nn.ReLU() + self.act = nn.ReLU() self.path_1 = nn.Sequential(OrderedDict([ ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)), - ('conv', nn.Conv2d(in_channels, out_channels // 2, - kernel_size=1, bias=False)), + ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)), ])) self.path_2 = nn.Sequential(OrderedDict([ - ('pad', nn.ZeroPad2d((0, 1, 0, 1))), + ('pad', nn.ZeroPad2d((-1, 1, -1, 1))), # shift ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)), - ('conv', nn.Conv2d(in_channels, out_channels // 2, - kernel_size=1, bias=False)), + ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)), ])) self.final_path_bn = nn.BatchNorm2d(out_channels, eps=0.001) def forward(self, x): - x = self.relu(x) - + x = self.act(x) x_path1 = self.path_1(x) - - x_path2 = self.path_2.pad(x) - x_path2 = x_path2[:, :, 1:, 1:] - x_path2 = self.path_2.avgpool(x_path2) - x_path2 = self.path_2.conv(x_path2) - + x_path2 = self.path_2(x) out = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) return out @@ -179,49 +141,41 @@ class CellBase(nn.Module): x_comb_iter_4_right = x_right x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right - x_out = torch.cat( - [x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + x_out = torch.cat([x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) return x_out class CellStem0(CellBase): - def __init__(self, in_channels_left, out_channels_left, in_channels_right, - out_channels_right): + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, padding=''): super(CellStem0, self).__init__() - self.conv_1x1 = ReluConvBn(in_channels_right, out_channels_right, - kernel_size=1) - self.comb_iter_0_left = BranchSeparables(in_channels_left, - out_channels_left, - kernel_size=5, stride=2, - stem_cell=True) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=padding) + + self.comb_iter_0_left = BranchSeparables( + in_chs_left, out_chs_left, kernel_size=5, stride=2, stem_cell=True, padding=padding) self.comb_iter_0_right = nn.Sequential(OrderedDict([ - ('max_pool', MaxPool(3, stride=2)), - ('conv', nn.Conv2d(in_channels_left, out_channels_left, - kernel_size=1, bias=False)), - ('bn', nn.BatchNorm2d(out_channels_left, eps=0.001)), + ('max_pool', create_pool2d('max', 3, stride=2, padding=padding)), + ('conv', create_conv2d(in_chs_left, out_chs_left, kernel_size=1, padding=padding)), + ('bn', nn.BatchNorm2d(out_chs_left, eps=0.001)), ])) - self.comb_iter_1_left = BranchSeparables(out_channels_right, - out_channels_right, - kernel_size=7, stride=2) - self.comb_iter_1_right = MaxPool(3, stride=2) - self.comb_iter_2_left = BranchSeparables(out_channels_right, - out_channels_right, - kernel_size=5, stride=2) - self.comb_iter_2_right = BranchSeparables(out_channels_right, - out_channels_right, - kernel_size=3, stride=2) - self.comb_iter_3_left = BranchSeparables(out_channels_right, - out_channels_right, - kernel_size=3) - self.comb_iter_3_right = MaxPool(3, stride=2) - self.comb_iter_4_left = BranchSeparables(in_channels_right, - out_channels_right, - kernel_size=3, stride=2, - stem_cell=True) - self.comb_iter_4_right = ReluConvBn(out_channels_right, - out_channels_right, - kernel_size=1, stride=2) + + self.comb_iter_1_left = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=7, stride=2, padding=padding) + self.comb_iter_1_right = create_pool2d('max', 3, stride=2, padding=padding) + + self.comb_iter_2_left = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=5, stride=2, padding=padding) + self.comb_iter_2_right = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=3, stride=2, padding=padding) + + self.comb_iter_3_left = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=3, padding=padding) + self.comb_iter_3_right = create_pool2d('max', 3, stride=2, padding=padding) + + self.comb_iter_4_left = BranchSeparables( + in_chs_right, out_chs_right, kernel_size=3, stride=2, stem_cell=True, padding=padding) + self.comb_iter_4_right = ActConvBn( + out_chs_right, out_chs_right, kernel_size=1, stride=2, padding=padding) def forward(self, x_left): x_right = self.conv_1x1(x_left) @@ -231,9 +185,8 @@ class CellStem0(CellBase): class Cell(CellBase): - def __init__(self, in_channels_left, out_channels_left, in_channels_right, - out_channels_right, is_reduction=False, zero_pad=False, - match_prev_layer_dimensions=False): + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, padding='', + is_reduction=False, match_prev_layer_dims=False): super(Cell, self).__init__() # If `is_reduction` is set to `True` stride 2 is used for @@ -244,45 +197,34 @@ class Cell(CellBase): # If `match_prev_layer_dimensions` is set to `True` # `FactorizedReduction` is used to reduce the spatial size # of the left input of a cell approximately by a factor of 2. - self.match_prev_layer_dimensions = match_prev_layer_dimensions - if match_prev_layer_dimensions: - self.conv_prev_1x1 = FactorizedReduction(in_channels_left, - out_channels_left) + self.match_prev_layer_dimensions = match_prev_layer_dims + if match_prev_layer_dims: + self.conv_prev_1x1 = FactorizedReduction(in_chs_left, out_chs_left, padding=padding) else: - self.conv_prev_1x1 = ReluConvBn(in_channels_left, - out_channels_left, kernel_size=1) - - self.conv_1x1 = ReluConvBn(in_channels_right, out_channels_right, - kernel_size=1) - self.comb_iter_0_left = BranchSeparables(out_channels_left, - out_channels_left, - kernel_size=5, stride=stride, - zero_pad=zero_pad) - self.comb_iter_0_right = MaxPool(3, stride=stride, zero_pad=zero_pad) - self.comb_iter_1_left = BranchSeparables(out_channels_right, - out_channels_right, - kernel_size=7, stride=stride, - zero_pad=zero_pad) - self.comb_iter_1_right = MaxPool(3, stride=stride, zero_pad=zero_pad) - self.comb_iter_2_left = BranchSeparables(out_channels_right, - out_channels_right, - kernel_size=5, stride=stride, - zero_pad=zero_pad) - self.comb_iter_2_right = BranchSeparables(out_channels_right, - out_channels_right, - kernel_size=3, stride=stride, - zero_pad=zero_pad) - self.comb_iter_3_left = BranchSeparables(out_channels_right, - out_channels_right, - kernel_size=3) - self.comb_iter_3_right = MaxPool(3, stride=stride, zero_pad=zero_pad) - self.comb_iter_4_left = BranchSeparables(out_channels_left, - out_channels_left, - kernel_size=3, stride=stride, - zero_pad=zero_pad) + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, kernel_size=1, padding=padding) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=padding) + + self.comb_iter_0_left = BranchSeparables( + out_chs_left, out_chs_left, kernel_size=5, stride=stride, padding=padding) + self.comb_iter_0_right = create_pool2d('max', 3, stride=stride, padding=padding) + + self.comb_iter_1_left = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=7, stride=stride, padding=padding) + self.comb_iter_1_right = create_pool2d('max', 3, stride=stride, padding=padding) + + self.comb_iter_2_left = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=5, stride=stride, padding=padding) + self.comb_iter_2_right = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=3, stride=stride, padding=padding) + + self.comb_iter_3_left = BranchSeparables(out_chs_right, out_chs_right, kernel_size=3) + self.comb_iter_3_right = create_pool2d('max', 3, stride=stride, padding=padding) + + self.comb_iter_4_left = BranchSeparables( + out_chs_left, out_chs_left, kernel_size=3, stride=stride, padding=padding) if is_reduction: - self.comb_iter_4_right = ReluConvBn( - out_channels_right, out_channels_right, kernel_size=1, stride=stride) + self.comb_iter_4_right = ActConvBn( + out_chs_right, out_chs_right, kernel_size=1, stride=stride, padding=padding) else: self.comb_iter_4_right = None @@ -294,52 +236,53 @@ class Cell(CellBase): class PNASNet5Large(nn.Module): - def __init__(self, num_classes=1001, in_chans=3, drop_rate=0.5, global_pool='avg'): + def __init__(self, num_classes=1001, in_chans=3, drop_rate=0.5, global_pool='avg', padding=''): super(PNASNet5Large, self).__init__() self.num_classes = num_classes self.num_features = 4320 self.drop_rate = drop_rate - self.conv_0 = nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d(in_chans, 96, kernel_size=3, stride=2, bias=False)), - ('bn', nn.BatchNorm2d(96, eps=0.001)) - ])) - self.cell_stem_0 = CellStem0(in_channels_left=96, out_channels_left=54, - in_channels_right=96, - out_channels_right=54) - self.cell_stem_1 = Cell(in_channels_left=96, out_channels_left=108, - in_channels_right=270, out_channels_right=108, - match_prev_layer_dimensions=True, - is_reduction=True) - self.cell_0 = Cell(in_channels_left=270, out_channels_left=216, - in_channels_right=540, out_channels_right=216, - match_prev_layer_dimensions=True) - self.cell_1 = Cell(in_channels_left=540, out_channels_left=216, - in_channels_right=1080, out_channels_right=216) - self.cell_2 = Cell(in_channels_left=1080, out_channels_left=216, - in_channels_right=1080, out_channels_right=216) - self.cell_3 = Cell(in_channels_left=1080, out_channels_left=216, - in_channels_right=1080, out_channels_right=216) - self.cell_4 = Cell(in_channels_left=1080, out_channels_left=432, - in_channels_right=1080, out_channels_right=432, - is_reduction=True, zero_pad=True) - self.cell_5 = Cell(in_channels_left=1080, out_channels_left=432, - in_channels_right=2160, out_channels_right=432, - match_prev_layer_dimensions=True) - self.cell_6 = Cell(in_channels_left=2160, out_channels_left=432, - in_channels_right=2160, out_channels_right=432) - self.cell_7 = Cell(in_channels_left=2160, out_channels_left=432, - in_channels_right=2160, out_channels_right=432) - self.cell_8 = Cell(in_channels_left=2160, out_channels_left=864, - in_channels_right=2160, out_channels_right=864, - is_reduction=True) - self.cell_9 = Cell(in_channels_left=2160, out_channels_left=864, - in_channels_right=4320, out_channels_right=864, - match_prev_layer_dimensions=True) - self.cell_10 = Cell(in_channels_left=4320, out_channels_left=864, - in_channels_right=4320, out_channels_right=864) - self.cell_11 = Cell(in_channels_left=4320, out_channels_left=864, - in_channels_right=4320, out_channels_right=864) + self.conv_0 = ConvBnAct( + in_chans, 96, kernel_size=3, stride=2, padding=0, + norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None) + + self.cell_stem_0 = CellStem0( + in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, padding=padding) + + self.cell_stem_1 = Cell( + in_chs_left=96, out_chs_left=108, in_chs_right=270, out_chs_right=108, padding=padding, + match_prev_layer_dims=True, is_reduction=True) + self.cell_0 = Cell( + in_chs_left=270, out_chs_left=216, in_chs_right=540, out_chs_right=216, padding=padding, + match_prev_layer_dims=True) + self.cell_1 = Cell( + in_chs_left=540, out_chs_left=216, in_chs_right=1080, out_chs_right=216, padding=padding) + self.cell_2 = Cell( + in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, padding=padding) + self.cell_3 = Cell( + in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, padding=padding) + + self.cell_4 = Cell( + in_chs_left=1080, out_chs_left=432, in_chs_right=1080, out_chs_right=432, padding=padding, + is_reduction=True) + self.cell_5 = Cell( + in_chs_left=1080, out_chs_left=432, in_chs_right=2160, out_chs_right=432, padding=padding, + match_prev_layer_dims=True) + self.cell_6 = Cell( + in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, padding=padding) + self.cell_7 = Cell( + in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, padding=padding) + + self.cell_8 = Cell( + in_chs_left=2160, out_chs_left=864, in_chs_right=2160, out_chs_right=864, padding=padding, + is_reduction=True) + self.cell_9 = Cell( + in_chs_left=2160, out_chs_left=864, in_chs_right=4320, out_chs_right=864, padding=padding, + match_prev_layer_dims=True) + self.cell_10 = Cell( + in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, padding=padding) + self.cell_11 = Cell( + in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, padding=padding) self.relu = nn.ReLU() self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) @@ -391,7 +334,7 @@ def pnasnet5large(pretrained=False, num_classes=1000, in_chans=3, **kwargs): `_ paper. """ default_cfg = default_cfgs['pnasnet5large'] - model = PNASNet5Large(num_classes=num_classes, in_chans=in_chans, **kwargs) + model = PNASNet5Large(num_classes=num_classes, in_chans=in_chans, padding='same', **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) diff --git a/timm/models/res2net.py b/timm/models/res2net.py index c3773dd5..536fd49a 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -10,7 +10,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import load_pretrained from .registry import register_model -from .resnet import ResNet +from .resnet import _create_resnet_with_cfg __all__ = [] @@ -132,113 +132,83 @@ class Bottle2neck(nn.Module): return out +def _create_res2net(variant, pretrained=False, **kwargs): + default_cfg = default_cfgs[variant] + return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs) + + @register_model -def res2net50_26w_4s(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a Res2Net-50_26w_4s model. +def res2net50_26w_4s(pretrained=False, **kwargs): + """Constructs a Res2Net-50 26w4s model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ - default_cfg = default_cfgs['res2net50_26w_4s'] - res2net_block_args = dict(scale=4) - model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=26, - num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4), **kwargs) + return _create_res2net('res2net50_26w_4s', pretrained, **model_args) @register_model -def res2net101_26w_4s(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a Res2Net-50_26w_4s model. +def res2net101_26w_4s(pretrained=False, **kwargs): + """Constructs a Res2Net-101 26w4s model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ - default_cfg = default_cfgs['res2net101_26w_4s'] - res2net_block_args = dict(scale=4) - model = ResNet(Bottle2neck, [3, 4, 23, 3], base_width=26, - num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict( + block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4), **kwargs) + return _create_res2net('res2net101_26w_4s', pretrained, **model_args) @register_model -def res2net50_26w_6s(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a Res2Net-50_26w_4s model. +def res2net50_26w_6s(pretrained=False, **kwargs): + """Constructs a Res2Net-50 26w6s model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ - default_cfg = default_cfgs['res2net50_26w_6s'] - res2net_block_args = dict(scale=6) - model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=26, - num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6), **kwargs) + return _create_res2net('res2net50_26w_6s', pretrained, **model_args) @register_model -def res2net50_26w_8s(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a Res2Net-50_26w_4s model. +def res2net50_26w_8s(pretrained=False, **kwargs): + """Constructs a Res2Net-50 26w8s model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ - default_cfg = default_cfgs['res2net50_26w_8s'] - res2net_block_args = dict(scale=8) - model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=26, - num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8), **kwargs) + return _create_res2net('res2net50_26w_8s', pretrained, **model_args) @register_model -def res2net50_48w_2s(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a Res2Net-50_48w_2s model. +def res2net50_48w_2s(pretrained=False, **kwargs): + """Constructs a Res2Net-50 48w2s model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ - default_cfg = default_cfgs['res2net50_48w_2s'] - res2net_block_args = dict(scale=2) - model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=48, - num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2), **kwargs) + return _create_res2net('res2net50_26w_8s', pretrained, **model_args) @register_model -def res2net50_14w_8s(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a Res2Net-50_14w_8s model. +def res2net50_14w_8s(pretrained=False, **kwargs): + """Constructs a Res2Net-50 14w8s model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ - default_cfg = default_cfgs['res2net50_14w_8s'] - res2net_block_args = dict(scale=8) - model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=14, num_classes=num_classes, in_chans=in_chans, - block_args=res2net_block_args, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8), **kwargs) + return _create_res2net('res2net50_26w_8s', pretrained, **model_args) @register_model -def res2next50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def res2next50(pretrained=False, **kwargs): """Construct Res2NeXt-50 4s Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ - default_cfg = default_cfgs['res2next50'] - res2net_block_args = dict(scale=4) - model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=4, cardinality=8, - num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4), **kwargs) + return _create_res2net('res2next50', pretrained, **model_args) diff --git a/timm/models/resnest.py b/timm/models/resnest.py index 2ebd125b..cf207faa 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -6,18 +6,14 @@ Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang198 Modified for torchscript compat, and consistency with timm by Ross Wightman """ -import math import torch -import torch.nn.functional as F from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.layers import DropBlock2d -from .helpers import load_pretrained -from .layers import SelectiveKernelConv, ConvBnAct, create_attn from .layers.split_attn import SplitAttnConv2d from .registry import register_model -from .resnet import ResNet +from .resnet import _create_resnet_with_cfg def _cfg(url='', **kwargs): @@ -143,125 +139,98 @@ class ResNestBottleneck(nn.Module): return out +def _create_resnest(variant, pretrained=False, **kwargs): + default_cfg = default_cfgs[variant] + return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs) + + @register_model -def resnest14d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnest14d(pretrained=False, **kwargs): """ ResNeSt-14d model. Weights ported from GluonCV. """ - default_cfg = default_cfgs['resnest14d'] - model = ResNet( - ResNestBottleneck, [1, 1, 1, 1], num_classes=num_classes, in_chans=in_chans, + model_kwargs = dict( + block=ResNestBottleneck, layers=[1, 1, 1, 1], stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + return _create_resnest('resnest14d', pretrained=pretrained, **model_kwargs) @register_model -def resnest26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnest26d(pretrained=False, **kwargs): """ ResNeSt-26d model. Weights ported from GluonCV. """ - default_cfg = default_cfgs['resnest26d'] - model = ResNet( - ResNestBottleneck, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, + model_kwargs = dict( + block=ResNestBottleneck, layers=[2, 2, 2, 2], stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + return _create_resnest('resnest26d', pretrained=pretrained, **model_kwargs) @register_model -def resnest50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnest50d(pretrained=False, **kwargs): """ ResNeSt-50d model. Matches paper ResNeSt-50 model, https://arxiv.org/abs/2004.08955 Since this codebase supports all possible variations, 'd' for deep stem, stem_width 32, avg in downsample. """ - default_cfg = default_cfgs['resnest50d'] - model = ResNet( - ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 6, 3], stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + return _create_resnest('resnest50d', pretrained=pretrained, **model_kwargs) @register_model -def resnest101e(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnest101e(pretrained=False, **kwargs): """ ResNeSt-101e model. Matches paper ResNeSt-101 model, https://arxiv.org/abs/2004.08955 Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. """ - default_cfg = default_cfgs['resnest101e'] - model = ResNet( - ResNestBottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 23, 3], stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + return _create_resnest('resnest101e', pretrained=pretrained, **model_kwargs) @register_model -def resnest200e(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnest200e(pretrained=False, **kwargs): """ ResNeSt-200e model. Matches paper ResNeSt-200 model, https://arxiv.org/abs/2004.08955 Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. """ - default_cfg = default_cfgs['resnest200e'] - model = ResNet( - ResNestBottleneck, [3, 24, 36, 3], num_classes=num_classes, in_chans=in_chans, + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 24, 36, 3], stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + return _create_resnest('resnest200e', pretrained=pretrained, **model_kwargs) @register_model -def resnest269e(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnest269e(pretrained=False, **kwargs): """ ResNeSt-269e model. Matches paper ResNeSt-269 model, https://arxiv.org/abs/2004.08955 Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. """ - default_cfg = default_cfgs['resnest269e'] - model = ResNet( - ResNestBottleneck, [3, 30, 48, 8], num_classes=num_classes, in_chans=in_chans, + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 30, 48, 8], stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + return _create_resnest('resnest269e', pretrained=pretrained, **model_kwargs) @register_model -def resnest50d_4s2x40d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnest50d_4s2x40d(pretrained=False, **kwargs): """ResNeSt-50 4s2x40d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md """ - default_cfg = default_cfgs['resnest50d_4s2x40d'] - model = ResNet( - ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 6, 3], stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2, block_args=dict(radix=4, avd=True, avd_first=True), **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **model_kwargs) @register_model -def resnest50d_1s4x24d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnest50d_1s4x24d(pretrained=False, **kwargs): """ResNeSt-50 1s4x24d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md """ - default_cfg = default_cfgs['resnest50d_1s4x24d'] - model = ResNet( - ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 6, 3], stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4, block_args=dict(radix=1, avd=True, avd_first=True), **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **model_kwargs) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 8750c5bd..e3b5f12f 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -6,11 +6,14 @@ additional dropout and dynamic global avg/max pool. ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman """ import math +import copy +import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .features import FeatureNet from .helpers import load_pretrained, adapt_model_from_file from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d from .registry import register_model @@ -390,6 +393,7 @@ class ResNet(nn.Module): self.base_width = base_width self.drop_rate = drop_rate self.expansion = block.expansion + self.feature_info = [dict(num_chs=self.inplanes, reduction=2, module='act1')] super(ResNet, self).__init__() # Stem @@ -420,9 +424,6 @@ class ResNet(nn.Module): self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Feature Blocks - dp = DropPath(drop_path_rate) if drop_path_rate else None - db_3 = DropBlock2d(drop_block_rate, 7, 0.25) if drop_block_rate else None - db_4 = DropBlock2d(drop_block_rate, 7, 1.00) if drop_block_rate else None channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4 if output_stride == 16: strides[3] = 1 @@ -432,14 +433,23 @@ class ResNet(nn.Module): dilations[2:4] = [2, 4] else: assert output_stride == 32 + dp = DropPath(drop_path_rate) if drop_path_rate else None + db = [ + None, None, + DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None, + DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None] layer_args = list(zip(channels, layers, strides, dilations)) layer_kwargs = dict( reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args) - self.layer1 = self._make_layer(block, *layer_args[0], **layer_kwargs) - self.layer2 = self._make_layer(block, *layer_args[1], **layer_kwargs) - self.layer3 = self._make_layer(block, drop_block=db_3, *layer_args[2], **layer_kwargs) - self.layer4 = self._make_layer(block, drop_block=db_4, *layer_args[3], **layer_kwargs) + current_stride = 4 + for i in range(4): + layer_name = f'layer{i + 1}' + self.add_module(layer_name, self._make_layer( + block, *layer_args[i], drop_block=db[i], **layer_kwargs)) + current_stride *= strides[i] + self.feature_info.append(dict( + num_chs=self.inplanes, reduction=current_stride, module=layer_name)) # Head (Pooling and Classifier) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) @@ -509,245 +519,185 @@ class ResNet(nn.Module): return x +def _create_resnet_with_cfg(variant, default_cfg, pretrained=False, **kwargs): + assert isinstance(default_cfg, dict) + load_strict, features = True, False + out_indices = None + if kwargs.pop('features_only', False): + load_strict, features = False, True + kwargs.pop('num_classes', 0) + out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4)) + model = ResNet(**kwargs) + model.default_cfg = copy.deepcopy(default_cfg) + if kwargs.pop('pruned', False): + model = adapt_model_from_file(model, variant) + if pretrained: + load_pretrained( + model, + num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict) + if features: + model = FeatureNet(model, out_indices=out_indices) + return model + + +def _create_resnet(variant, pretrained=False, **kwargs): + default_cfg = default_cfgs[variant] + return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs) + + @register_model -def resnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnet18(pretrained=False, **kwargs): """Constructs a ResNet-18 model. """ - default_cfg = default_cfgs['resnet18'] - model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs) + return _create_resnet('resnet18', pretrained, **model_args) @register_model -def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnet34(pretrained=False, **kwargs): """Constructs a ResNet-34 model. """ - default_cfg = default_cfgs['resnet34'] - model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('resnet34', pretrained, **model_args) @register_model -def resnet26(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnet26(pretrained=False, **kwargs): """Constructs a ResNet-26 model. """ - default_cfg = default_cfgs['resnet26'] - model = ResNet(Bottleneck, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], **kwargs) + return _create_resnet('resnet26', pretrained, **model_args) @register_model -def resnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnet26d(pretrained=False, **kwargs): """Constructs a ResNet-26 v1d model. This is technically a 28 layer ResNet, sticking with 'd' modifier from Gluon for now. """ - default_cfg = default_cfgs['resnet26d'] - model = ResNet( - Bottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet26d', pretrained, **model_args) @register_model -def resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnet50(pretrained=False, **kwargs): """Constructs a ResNet-50 model. """ - default_cfg = default_cfgs['resnet50'] - model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('resnet50', pretrained, **model_args) @register_model -def resnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnet50d(pretrained=False, **kwargs): """Constructs a ResNet-50-D model. """ - default_cfg = default_cfgs['resnet50d'] - model = ResNet( - Bottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet50d', pretrained, **model_args) @register_model -def resnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnet101(pretrained=False, **kwargs): """Constructs a ResNet-101 model. """ - default_cfg = default_cfgs['resnet101'] - model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs) + return _create_resnet('resnet101', pretrained, **model_args) @register_model -def resnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnet152(pretrained=False, **kwargs): """Constructs a ResNet-152 model. """ - default_cfg = default_cfgs['resnet152'] - model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs) + return _create_resnet('resnet152', pretrained, **model_args) @register_model -def tv_resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def tv_resnet34(pretrained=False, **kwargs): """Constructs a ResNet-34 model with original Torchvision weights. """ - model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfgs['tv_resnet34'] - if pretrained: - load_pretrained(model, model.default_cfg, num_classes, in_chans) - return model + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('tv_resnet34', pretrained, **model_args) @register_model -def tv_resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def tv_resnet50(pretrained=False, **kwargs): """Constructs a ResNet-50 model with original Torchvision weights. """ - model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfgs['tv_resnet50'] - if pretrained: - load_pretrained(model, model.default_cfg, num_classes, in_chans) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('tv_resnet50', pretrained, **model_args) @register_model -def wide_resnet50_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def wide_resnet50_2(pretrained=False, **kwargs): """Constructs a Wide ResNet-50-2 model. The model is the same as ResNet except for the bottleneck number of channels which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 channels, and in Wide ResNet-50-2 has 2048-1024-2048. """ - model = ResNet( - Bottleneck, [3, 4, 6, 3], base_width=128, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfgs['wide_resnet50_2'] - if pretrained: - load_pretrained(model, model.default_cfg, num_classes, in_chans) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128, **kwargs) + return _create_resnet('wide_resnet50_2', pretrained, **model_args) @register_model -def wide_resnet101_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def wide_resnet101_2(pretrained=False, **kwargs): """Constructs a Wide ResNet-101-2 model. The model is the same as ResNet except for the bottleneck number of channels which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same. """ - model = ResNet( - Bottleneck, [3, 4, 23, 3], base_width=128, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfgs['wide_resnet101_2'] - if pretrained: - load_pretrained(model, model.default_cfg, num_classes, in_chans) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], base_width=128, **kwargs) + return _create_resnet('wide_resnet101_2', pretrained, **model_args) @register_model -def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnext50_32x4d(pretrained=False, **kwargs): """Constructs a ResNeXt50-32x4d model. """ - default_cfg = default_cfgs['resnext50_32x4d'] - model = ResNet( - Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('resnext50_32x4d', pretrained, **model_args) @register_model -def resnext50d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnext50d_32x4d(pretrained=False, **kwargs): """Constructs a ResNeXt50d-32x4d model. ResNext50 w/ deep stem & avg pool downsample """ - default_cfg = default_cfgs['resnext50d_32x4d'] - model = ResNet( - Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, - stem_width=32, stem_type='deep', avg_down=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, + stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnext50d_32x4d', pretrained, **model_args) @register_model -def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnext101_32x4d(pretrained=False, **kwargs): """Constructs a ResNeXt-101 32x4d model. """ - default_cfg = default_cfgs['resnext101_32x4d'] - model = ResNet( - Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('resnext101_32x4d', pretrained, **model_args) @register_model -def resnext101_32x8d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnext101_32x8d(pretrained=False, **kwargs): """Constructs a ResNeXt-101 32x8d model. """ - default_cfg = default_cfgs['resnext101_32x8d'] - model = ResNet( - Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) + return _create_resnet('resnext101_32x8d', pretrained, **model_args) @register_model -def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnext101_64x4d(pretrained=False, **kwargs): """Constructs a ResNeXt101-64x4d model. """ - default_cfg = default_cfgs['resnext101_32x4d'] - model = ResNet( - Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, **kwargs) + return _create_resnet('resnext101_64x4d', pretrained, **model_args) @register_model -def tv_resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def tv_resnext50_32x4d(pretrained=False, **kwargs): """Constructs a ResNeXt50-32x4d model with original Torchvision weights. """ - default_cfg = default_cfgs['tv_resnext50_32x4d'] - model = ResNet( - Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('tv_resnext50_32x4d', pretrained, **model_args) @register_model @@ -757,11 +707,8 @@ def ig_resnext101_32x8d(pretrained=True, **kwargs): `"Exploring the Limits of Weakly Supervised Pretraining" `_ Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ """ - model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) - model.default_cfg = default_cfgs['ig_resnext101_32x8d'] - if pretrained: - load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) + return _create_resnet('ig_resnext101_32x8d', pretrained, **model_args) @register_model @@ -771,11 +718,8 @@ def ig_resnext101_32x16d(pretrained=True, **kwargs): `"Exploring the Limits of Weakly Supervised Pretraining" `_ Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ """ - model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16, **kwargs) - model.default_cfg = default_cfgs['ig_resnext101_32x16d'] - if pretrained: - load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs) + return _create_resnet('ig_resnext101_32x16d', pretrained, **model_args) @register_model @@ -785,11 +729,8 @@ def ig_resnext101_32x32d(pretrained=True, **kwargs): `"Exploring the Limits of Weakly Supervised Pretraining" `_ Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ """ - model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=32, **kwargs) - model.default_cfg = default_cfgs['ig_resnext101_32x32d'] - if pretrained: - load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32, **kwargs) + return _create_resnet('ig_resnext101_32x32d', pretrained, **model_args) @register_model @@ -799,11 +740,8 @@ def ig_resnext101_32x48d(pretrained=True, **kwargs): `"Exploring the Limits of Weakly Supervised Pretraining" `_ Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ """ - model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=48, **kwargs) - model.default_cfg = default_cfgs['ig_resnext101_32x48d'] - if pretrained: - load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=48, **kwargs) + return _create_resnet('ig_resnext101_32x48d', pretrained, **model_args) @register_model @@ -812,11 +750,8 @@ def ssl_resnet18(pretrained=True, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) - model.default_cfg = default_cfgs['ssl_resnet18'] - if pretrained: - load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) - return model + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs) + return _create_resnet('ssl_resnet18', pretrained, **model_args) @register_model @@ -825,11 +760,8 @@ def ssl_resnet50(pretrained=True, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) - model.default_cfg = default_cfgs['ssl_resnet50'] - if pretrained: - load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('ssl_resnet50', pretrained, **model_args) @register_model @@ -838,11 +770,8 @@ def ssl_resnext50_32x4d(pretrained=True, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model = ResNet(Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) - model.default_cfg = default_cfgs['ssl_resnext50_32x4d'] - if pretrained: - load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('ssl_resnext50_32x4d', pretrained, **model_args) @register_model @@ -851,11 +780,8 @@ def ssl_resnext101_32x4d(pretrained=True, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) - model.default_cfg = default_cfgs['ssl_resnext101_32x4d'] - if pretrained: - load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('ssl_resnext101_32x4d', pretrained, **model_args) @register_model @@ -864,11 +790,8 @@ def ssl_resnext101_32x8d(pretrained=True, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) - model.default_cfg = default_cfgs['ssl_resnext101_32x8d'] - if pretrained: - load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) + return _create_resnet('ssl_resnext101_32x8d', pretrained, **model_args) @register_model @@ -877,11 +800,8 @@ def ssl_resnext101_32x16d(pretrained=True, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16, **kwargs) - model.default_cfg = default_cfgs['ssl_resnext101_32x16d'] - if pretrained: - load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs) + return _create_resnet('ssl_resnext101_32x16d', pretrained, **model_args) @register_model @@ -891,11 +811,8 @@ def swsl_resnet18(pretrained=True, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) - model.default_cfg = default_cfgs['swsl_resnet18'] - if pretrained: - load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) - return model + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs) + return _create_resnet('swsl_resnet18', pretrained, **model_args) @register_model @@ -905,11 +822,8 @@ def swsl_resnet50(pretrained=True, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) - model.default_cfg = default_cfgs['swsl_resnet50'] - if pretrained: - load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('swsl_resnet50', pretrained, **model_args) @register_model @@ -919,11 +833,8 @@ def swsl_resnext50_32x4d(pretrained=True, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model = ResNet(Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) - model.default_cfg = default_cfgs['swsl_resnext50_32x4d'] - if pretrained: - load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('swsl_resnext50_32x4d', pretrained, **model_args) @register_model @@ -933,11 +844,8 @@ def swsl_resnext101_32x4d(pretrained=True, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) - model.default_cfg = default_cfgs['swsl_resnext101_32x4d'] - if pretrained: - load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('swsl_resnext101_32x4d', pretrained, **model_args) @register_model @@ -947,11 +855,8 @@ def swsl_resnext101_32x8d(pretrained=True, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) - model.default_cfg = default_cfgs['swsl_resnext101_32x8d'] - if pretrained: - load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) + return _create_resnet('swsl_resnext101_32x8d', pretrained, **model_args) @register_model @@ -961,61 +866,44 @@ def swsl_resnext101_32x16d(pretrained=True, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16, **kwargs) - model.default_cfg = default_cfgs['swsl_resnext101_32x16d'] - if pretrained: - load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs) + return _create_resnet('swsl_resnext101_32x16d', pretrained, **model_args) @register_model -def seresnext26d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def seresnext26d_32x4d(pretrained=False, **kwargs): """Constructs a SE-ResNeXt-26-D model. This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for combination of deep stem and avg_pool in downsample. """ - default_cfg = default_cfgs['seresnext26d_32x4d'] - model = ResNet( - Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, stem_type='deep', avg_down=True, - num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext26d_32x4d', pretrained, **model_args) @register_model -def seresnext26t_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def seresnext26t_32x4d(pretrained=False, **kwargs): """Constructs a SE-ResNet-26-T model. This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 48, 64 channels in the deep stem. """ - default_cfg = default_cfgs['seresnext26t_32x4d'] - model = ResNet( - Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, - stem_width=32, stem_type='deep_tiered', avg_down=True, - num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext26t_32x4d', pretrained, **model_args) @register_model -def seresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def seresnext26tn_32x4d(pretrained=False, **kwargs): """Constructs a SE-ResNeXt-26-TN model. This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant. """ - default_cfg = default_cfgs['seresnext26tn_32x4d'] - model = ResNet( - Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, - stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, - num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext26tn_32x4d', pretrained, **model_args) @register_model @@ -1025,145 +913,91 @@ def ecaresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwarg in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant. this model replaces SE module with the ECA module """ - default_cfg = default_cfgs['ecaresnext26tn_32x4d'] - block_args = dict(attn_layer='eca') - model = ResNet( - Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, - stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, - num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnext26tn_32x4d', pretrained, **model_args) @register_model -def ecaresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def ecaresnet18(pretrained=False, **kwargs): """ Constructs an ECA-ResNet-18 model. """ - default_cfg = default_cfgs['ecaresnet18'] - block_args = dict(attn_layer='eca') - model = ResNet( - BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet18', pretrained, **model_args) @register_model -def ecaresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def ecaresnet50(pretrained=False, **kwargs): """Constructs an ECA-ResNet-50 model. """ - default_cfg = default_cfgs['ecaresnet50'] - block_args = dict(attn_layer='eca') - model = ResNet( - Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet50', pretrained, **model_args) @register_model -def ecaresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def ecaresnet50d(pretrained=False, **kwargs): """Constructs a ResNet-50-D model with eca. """ - default_cfg = default_cfgs['ecaresnet50d'] - model = ResNet( - Bottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, - num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet50d', pretrained, **model_args) @register_model -def ecaresnet50d_pruned(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def ecaresnet50d_pruned(pretrained=False, **kwargs): """Constructs a ResNet-50-D model pruned with eca. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ - variant = 'ecaresnet50d_pruned' - default_cfg = default_cfgs[variant] - model = ResNet( - Bottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, - num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs) - model.default_cfg = default_cfg - model = adapt_model_from_file(model, variant) - - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **model_args) @register_model -def ecaresnetlight(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def ecaresnetlight(pretrained=False, **kwargs): """Constructs a ResNet-50-D light model with eca. """ - default_cfg = default_cfgs['ecaresnetlight'] - model = ResNet( - Bottleneck, [1, 1, 11, 3], stem_width=32, avg_down=True, - num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict( + block=Bottleneck, layers=[1, 1, 11, 3], stem_width=32, avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnetlight', pretrained, **model_args) @register_model -def ecaresnet101d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def ecaresnet101d(pretrained=False, **kwargs): """Constructs a ResNet-101-D model with eca. """ - default_cfg = default_cfgs['ecaresnet101d'] - model = ResNet( - Bottleneck, [3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, - num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet101d', pretrained, **model_args) @register_model -def ecaresnet101d_pruned(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def ecaresnet101d_pruned(pretrained=False, **kwargs): """Constructs a ResNet-101-D model pruned with eca. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ - variant = 'ecaresnet101d_pruned' - default_cfg = default_cfgs[variant] - model = ResNet( - Bottleneck, [3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, - num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs) - model.default_cfg = default_cfg - model = adapt_model_from_file(model, variant) - - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **model_args) @register_model -def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnetblur18(pretrained=False, **kwargs): """Constructs a ResNet-18 model with blur anti-aliasing """ - default_cfg = default_cfgs['resnetblur18'] - model = ResNet( - BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d, **kwargs) + return _create_resnet('resnetblur18', pretrained, **model_args) @register_model -def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def resnetblur50(pretrained=False, **kwargs): """Constructs a ResNet-50 model with blur anti-aliasing """ - default_cfg = default_cfgs['resnetblur50'] - model = ResNet( - Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs) + return _create_resnet('resnetblur50', pretrained, **model_args) diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index b7573086..5dddedb5 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -16,6 +16,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .features import FeatureNet from .helpers import load_pretrained from .layers import SelectAdaptivePool2d from .registry import register_model @@ -100,7 +101,8 @@ class SelecSLSBlock(nn.Module): self.conv6 = conv_bn(2 * mid_chs + (0 if is_first else skip_chs), out_chs, 1) def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: - assert isinstance(x, list) + if not isinstance(x, list): + x = [x] assert len(x) in [1, 2] d1 = self.conv1(x[0]) @@ -163,7 +165,7 @@ class SelecSLS(nn.Module): def forward_features(self, x): x = self.stem(x) - x = self.features([x]) + x = self.features(x) x = self.head(x[0]) return x @@ -178,6 +180,7 @@ class SelecSLS(nn.Module): def _create_model(variant, pretrained, model_kwargs): cfg = {} + feature_info = [dict(num_chs=32, reduction=2, module='stem.2')] if variant.startswith('selecsls42'): cfg['block'] = SelecSLSBlock # Define configuration of the network after the initial neck @@ -190,7 +193,13 @@ def _create_model(variant, pretrained, model_kwargs): (288, 0, 304, 304, True, 2), (304, 304, 304, 480, False, 1), ] + feature_info.extend([ + dict(num_chs=128, reduction=4, module='features.1'), + dict(num_chs=288, reduction=8, module='features.3'), + dict(num_chs=480, reduction=16, module='features.5'), + ]) # Head can be replaced with alternative configurations depending on the problem + feature_info.append(dict(num_chs=1024, reduction=32, module='head.1')) if variant == 'selecsls42b': cfg['head'] = [ (480, 960, 3, 2), @@ -198,6 +207,7 @@ def _create_model(variant, pretrained, model_kwargs): (1024, 1280, 3, 2), (1280, 1024, 1, 1), ] + feature_info.append(dict(num_chs=1024, reduction=64, module='head.3')) cfg['num_features'] = 1024 else: cfg['head'] = [ @@ -206,7 +216,9 @@ def _create_model(variant, pretrained, model_kwargs): (1024, 1024, 3, 2), (1024, 1280, 1, 1), ] + feature_info.append(dict(num_chs=1280, reduction=64, module='head.3')) cfg['num_features'] = 1280 + elif variant.startswith('selecsls60'): cfg['block'] = SelecSLSBlock # Define configuration of the network after the initial neck @@ -222,7 +234,13 @@ def _create_model(variant, pretrained, model_kwargs): (288, 288, 288, 288, False, 1), (288, 288, 288, 416, False, 1), ] + feature_info.extend([ + dict(num_chs=128, reduction=4, module='features.1'), + dict(num_chs=288, reduction=8, module='features.4'), + dict(num_chs=416, reduction=16, module='features.8'), + ]) # Head can be replaced with alternative configurations depending on the problem + feature_info.append(dict(num_chs=1024, reduction=32, module='head.1')) if variant == 'selecsls60b': cfg['head'] = [ (416, 756, 3, 2), @@ -230,6 +248,7 @@ def _create_model(variant, pretrained, model_kwargs): (1024, 1280, 3, 2), (1280, 1024, 1, 1), ] + feature_info.append(dict(num_chs=1024, reduction=64, module='head.3')) cfg['num_features'] = 1024 else: cfg['head'] = [ @@ -238,7 +257,9 @@ def _create_model(variant, pretrained, model_kwargs): (1024, 1024, 3, 2), (1024, 1280, 1, 1), ] + feature_info.append(dict(num_chs=1280, reduction=64, module='head.3')) cfg['num_features'] = 1280 + elif variant == 'selecsls84': cfg['block'] = SelecSLSBlock # Define configuration of the network after the initial neck @@ -258,6 +279,11 @@ def _create_model(variant, pretrained, model_kwargs): (304, 304, 304, 304, False, 1), (304, 304, 304, 512, False, 1), ] + feature_info.extend([ + dict(num_chs=144, reduction=4, module='features.1'), + dict(num_chs=304, reduction=8, module='features.6'), + dict(num_chs=512, reduction=16, module='features.12'), + ]) # Head can be replaced with alternative configurations depending on the problem cfg['head'] = [ (512, 960, 3, 2), @@ -266,17 +292,35 @@ def _create_model(variant, pretrained, model_kwargs): (1024, 1280, 3, 1), ] cfg['num_features'] = 1280 + feature_info.extend([ + dict(num_chs=1024, reduction=32, module='head.1'), + dict(num_chs=1280, reduction=64, module='head.3') + ]) else: raise ValueError('Invalid net configuration ' + variant + ' !!!') + load_strict = True + features = False + out_indices = None + if model_kwargs.pop('features_only', False): + load_strict = False + features = True + # this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises? + out_indices = model_kwargs.pop('out_indices', (0, 1, 2, 3, 4)) + model_kwargs.pop('num_classes', 0) + model = SelecSLS(cfg, **model_kwargs) model.default_cfg = default_cfgs[variant] + model.feature_info = feature_info if pretrained: load_pretrained( model, num_classes=model_kwargs.get('num_classes', 0), in_chans=model_kwargs.get('in_chans', 3), - strict=True) + strict=load_strict) + + if features: + model = FeatureNet(model, out_indices, flatten_sequential=True) return model diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 2ba1b772..2bbf9786 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -12,11 +12,11 @@ import math from torch import nn as nn -from .registry import register_model +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import load_pretrained from .layers import SelectiveKernelConv, ConvBnAct, create_attn -from .resnet import ResNet -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .registry import register_model +from .resnet import _create_resnet_with_cfg def _cfg(url='', **kwargs): @@ -138,101 +138,80 @@ class SelectiveKernelBottleneck(nn.Module): return x +def _create_skresnet(variant, pretrained=False, **kwargs): + default_cfg = default_cfgs[variant] + return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs) + + @register_model -def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def skresnet18(pretrained=False, **kwargs): """Constructs a Selective Kernel ResNet-18 model. Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this variation splits the input channels to the selective convolutions to keep param count down. """ - default_cfg = default_cfgs['skresnet18'] sk_kwargs = dict( min_attn_channels=16, attn_reduction=8, - split_input=True - ) - model = ResNet( - SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, - block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + split_input=True) + model_args = dict( + block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs), + zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnet18', pretrained, **model_args) @register_model -def skresnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def skresnet34(pretrained=False, **kwargs): """Constructs a Selective Kernel ResNet-34 model. Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this variation splits the input channels to the selective convolutions to keep param count down. """ - default_cfg = default_cfgs['skresnet34'] sk_kwargs = dict( min_attn_channels=16, attn_reduction=8, - split_input=True - ) - model = ResNet( - SelectiveKernelBasic, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, - block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + split_input=True) + model_args = dict( + block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs), + zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnet34', pretrained, **model_args) @register_model -def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def skresnet50(pretrained=False, **kwargs): """Constructs a Select Kernel ResNet-50 model. Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this variation splits the input channels to the selective convolutions to keep param count down. """ - sk_kwargs = dict( - split_input=True, - ) - default_cfg = default_cfgs['skresnet50'] - model = ResNet( - SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, - block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + sk_kwargs = dict(split_input=True) + model_args = dict( + block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs), + zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnet50', pretrained, **model_args) @register_model -def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def skresnet50d(pretrained=False, **kwargs): """Constructs a Select Kernel ResNet-50-D model. Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this variation splits the input channels to the selective convolutions to keep param count down. """ - sk_kwargs = dict( - split_input=True, - ) - default_cfg = default_cfgs['skresnet50d'] - model = ResNet( - SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, - num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), - zero_init_last_bn=False, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + sk_kwargs = dict(split_input=True) + model_args = dict( + block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnet50d', pretrained, **model_args) @register_model -def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def skresnext50_32x4d(pretrained=False, **kwargs): """Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to the SKNet-50 model in the Select Kernel Paper """ - default_cfg = default_cfgs['skresnext50_32x4d'] - model = ResNet( - SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, - num_classes=num_classes, in_chans=in_chans, zero_init_last_bn=False, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + model_args = dict( + block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, + zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnext50_32x4d', pretrained, **model_args) + diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index 3fd79a13..0793120e 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -20,6 +20,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .registry import register_model from .helpers import load_pretrained +from .features import FeatureNet from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, SelectAdaptivePool2d, \ create_attn, create_norm_act, get_norm_act_layer @@ -296,6 +297,9 @@ class VovNet(nn.Module): conv_type(stem_chs[0], stem_chs[1], 3, stride=1, norm_layer=norm_layer), conv_type(stem_chs[1], stem_chs[2], 3, stride=last_stem_stride, norm_layer=norm_layer), ]) + self.feature_info = [dict( + num_chs=stem_chs[1], reduction=2, module=f'stem.{1 if stem_stride == 4 else 2}')] + current_stride = stem_stride # OSA stages in_ch_list = stem_chs[-1:] + stage_out_chs[:-1] @@ -309,6 +313,9 @@ class VovNet(nn.Module): downsample=downsample, **stage_args) ] self.num_features = stage_out_chs[i] + current_stride *= 2 if downsample else 1 + self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')] + self.stages = nn.Sequential(*stages) self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) @@ -338,24 +345,24 @@ class VovNet(nn.Module): def _vovnet(variant, pretrained=False, **kwargs): - load_strict = True - model_class = VovNet + features = False + out_indices = None if kwargs.pop('features_only', False): - assert False, 'Not Implemented' # TODO - load_strict = False + features = True kwargs.pop('num_classes', 0) + out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4)) model_cfg = model_cfgs[variant] - default_cfg = default_cfgs[variant] - model = model_class(model_cfg, **kwargs) - model.default_cfg = default_cfg + model = VovNet(model_cfg, **kwargs) + model.default_cfg = default_cfgs[variant] if pretrained: load_pretrained( - model, default_cfg, - num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict) + model, + num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features) + if features: + model = FeatureNet(model, out_indices, flatten_sequential=True) return model - @register_model def vovnet39a(pretrained=False, **kwargs): return _vovnet('vovnet39a', pretrained=pretrained, **kwargs) diff --git a/timm/models/xception.py b/timm/models/xception.py index 8dea81b9..60241f29 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -26,6 +26,7 @@ import torch.nn as nn import torch.nn.functional as F from .helpers import load_pretrained +from .features import FeatureNet from .layers import SelectAdaptivePool2d from .registry import register_model @@ -49,12 +50,12 @@ default_cfgs = { class SeparableConv2d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1): super(SeparableConv2d, self).__init__() self.conv1 = nn.Conv2d( - in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=bias) - self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias) + in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=False) + self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=False) def forward(self, x): x = self.conv1(x) @@ -63,34 +64,26 @@ class SeparableConv2d(nn.Module): class Block(nn.Module): - def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True): + def __init__(self, in_channels, out_channels, reps, strides=1, start_with_relu=True, grow_first=True): super(Block, self).__init__() - if out_filters != in_filters or strides != 1: - self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False) - self.skipbn = nn.BatchNorm2d(out_filters) + if out_channels != in_channels or strides != 1: + self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=strides, bias=False) + self.skipbn = nn.BatchNorm2d(out_channels) else: self.skip = None - self.relu = nn.ReLU(inplace=True) rep = [] - - filters = in_filters - if grow_first: - rep.append(self.relu) - rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False)) - rep.append(nn.BatchNorm2d(out_filters)) - filters = out_filters - - for i in range(reps - 1): - rep.append(self.relu) - rep.append(SeparableConv2d(filters, filters, 3, stride=1, padding=1, bias=False)) - rep.append(nn.BatchNorm2d(filters)) - - if not grow_first: - rep.append(self.relu) - rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False)) - rep.append(nn.BatchNorm2d(out_filters)) + for i in range(reps): + if grow_first: + inc = in_channels if i == 0 else out_channels + outc = out_channels + else: + inc = in_channels + outc = in_channels if i < (reps - 1) else out_channels + rep.append(nn.ReLU(inplace=True)) + rep.append(SeparableConv2d(inc, outc, 3, stride=1, padding=1)) + rep.append(nn.BatchNorm2d(outc)) if not start_with_relu: rep = rep[1:] @@ -133,34 +126,35 @@ class Xception(nn.Module): self.conv1 = nn.Conv2d(in_chans, 32, 3, 2, 0, bias=False) self.bn1 = nn.BatchNorm2d(32) - self.relu = nn.ReLU(inplace=True) + self.act1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(32, 64, 3, bias=False) self.bn2 = nn.BatchNorm2d(64) - # do relu here + self.act2 = nn.ReLU(inplace=True) - self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True) - self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True) - self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True) + self.block1 = Block(64, 128, 2, 2, start_with_relu=False) + self.block2 = Block(128, 256, 2, 2) + self.block3 = Block(256, 728, 2, 2) - self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) - self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) - self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) - self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block4 = Block(728, 728, 3, 1) + self.block5 = Block(728, 728, 3, 1) + self.block6 = Block(728, 728, 3, 1) + self.block7 = Block(728, 728, 3, 1) - self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) - self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) - self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) - self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block8 = Block(728, 728, 3, 1) + self.block9 = Block(728, 728, 3, 1) + self.block10 = Block(728, 728, 3, 1) + self.block11 = Block(728, 728, 3, 1) - self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False) + self.block12 = Block(728, 1024, 2, 2, grow_first=False) self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1) self.bn3 = nn.BatchNorm2d(1536) + self.act3 = nn.ReLU(inplace=True) - # do relu here self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1) self.bn4 = nn.BatchNorm2d(self.num_features) + self.act4 = nn.ReLU(inplace=True) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) @@ -188,11 +182,11 @@ class Xception(nn.Module): def forward_features(self, x): x = self.conv1(x) x = self.bn1(x) - x = self.relu(x) + x = self.act1(x) x = self.conv2(x) x = self.bn2(x) - x = self.relu(x) + x = self.act2(x) x = self.block1(x) x = self.block2(x) @@ -209,11 +203,11 @@ class Xception(nn.Module): x = self.conv3(x) x = self.bn3(x) - x = self.relu(x) + x = self.act3(x) x = self.conv4(x) x = self.bn4(x) - x = self.relu(x) + x = self.act4(x) return x def forward(self, x): @@ -225,12 +219,28 @@ class Xception(nn.Module): return x -@register_model -def xception(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['xception'] - model = Xception(num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg +def _xception(variant, pretrained=False, **kwargs): + load_strict = True + features = False + out_indices = None + if kwargs.pop('features_only', False): + load_strict = False + features = True + kwargs.pop('num_classes', 0) + out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4)) + model = Xception(**kwargs) + model.default_cfg = default_cfgs[variant] if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - + load_pretrained( + model, + num_classes=kwargs.get('num_classes', 0), + in_chans=kwargs.get('in_chans', 3), + strict=load_strict) + if features: + model = FeatureNet(model, out_indices) return model + + +@register_model +def xception(pretrained=False, **kwargs): + return _xception('xception', pretrained=pretrained, **kwargs) diff --git a/validate.py b/validate.py index ebd4d849..576567bd 100755 --- a/validate.py +++ b/validate.py @@ -24,9 +24,8 @@ try: except ImportError: has_apex = False -from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models,\ - set_scriptable, set_no_jit -from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config +from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models +from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging torch.backends.cudnn.benchmark = True @@ -76,8 +75,25 @@ parser.add_argument('--use-ema', dest='use_ema', action='store_true', help='use ema version of weights if present') parser.add_argument('--torchscript', dest='torchscript', action='store_true', help='convert model torchscript for inference') +parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true', + help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance') parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', help='Output csv file for validation results (summary)') +parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME', + help='Real labels JSON file for imagenet evaluation') + + +def set_jit_legacy(): + """ Set JIT executor to legacy w/ support for op fusion + This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes + in the JIT exectutor. These API are not supported so could change. + """ + # + assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!" + torch._C._jit_set_profiling_executor(False) + torch._C._jit_set_profiling_mode(False) + torch._C._jit_override_can_fuse_on_gpu(True) + #torch._C._jit_set_texpr_fuser_enabled(True) def validate(args): @@ -103,6 +119,8 @@ def validate(args): model, test_time_pool = apply_test_time_pool(model, data_config, args) if args.torchscript: + if args.legacy_jit: + set_jit_legacy() torch.jit.optimized_execution(True) model = torch.jit.script(model) @@ -116,13 +134,16 @@ def validate(args): criterion = nn.CrossEntropyLoss().cuda() - #from torchvision.datasets import ImageNet - #dataset = ImageNet(args.data, split='val') if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data): dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map) else: dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map) + if args.real_labels: + real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels) + else: + real_labels = None + crop_pct = 1.0 if test_time_pool else data_config['crop_pct'] loader = create_loader( dataset, @@ -148,7 +169,7 @@ def validate(args): input = torch.randn((args.batch_size,) + data_config['input_size']).cuda() model(input) end = time.time() - for i, (input, target) in enumerate(loader): + for batch_idx, (input, target) in enumerate(loader): if args.no_prefetcher: target = target.cuda() input = input.cuda() @@ -159,6 +180,9 @@ def validate(args): output = model(input) loss = criterion(output, target) + if real_labels is not None: + real_labels.add_result(output) + # measure accuracy and record loss acc1, acc5 = accuracy(output.data, target, topk=(1, 5)) losses.update(loss.item(), input.size(0)) @@ -169,25 +193,35 @@ def validate(args): batch_time.update(time.time() - end) end = time.time() - if i % args.log_freq == 0: + if batch_idx % args.log_freq == 0: logging.info( 'Test: [{0:>4d}/{1}] ' 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) ' 'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format( - i, len(loader), batch_time=batch_time, + batch_idx, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg, loss=losses, top1=top1, top5=top5)) - results = OrderedDict( - top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4), - top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4), + if real_labels is not None: + real_top1 = real_labels.get_accuracy(k=1) + real_top5 = real_labels.get_accuracy(k=5) + results = OrderedDict( + top1=round(real_top1, 4), top1_err=round(100 - real_top1, 4), + top5=round(real_top5, 4), top5_err=round(100 - real_top5, 4), + top1_original=round(top1.avg, 4), + top5_original=round(top5.avg, 4)) + else: + results = OrderedDict( + top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4), + top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4)) + results.update(OrderedDict( param_count=round(param_count / 1e6, 2), img_size=data_config['input_size'][-1], cropt_pct=crop_pct, - interpolation=data_config['interpolation']) - + interpolation=data_config['interpolation'] + )) logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format( results['top1'], results['top1_err'], results['top5'], results['top5_err']))