Working on feature extraction, interfaces refined, a number of models working, some in progress.

pull/175/head
Ross Wightman 5 years ago
parent 24e7535278
commit d23a2697d0

@ -7,3 +7,4 @@ from .transforms_factory import create_transform
from .mixup import mixup_batch, FastCollateMixup
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
rand_augment_transform, auto_augment_transform
from .real_labels import RealLabelsImagenet

@ -0,0 +1,36 @@
import os
import json
import numpy as np
class RealLabelsImagenet:
def __init__(self, filenames, real_json='real.json', topk=(1, 5)):
with open(real_json) as real_labels:
real_labels = json.load(real_labels)
real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)}
self.real_labels = real_labels
self.filenames = filenames
assert len(self.filenames) == len(self.real_labels)
self.topk = topk
self.is_correct = {k: [] for k in topk}
self.sample_idx = 0
def add_result(self, output):
maxk = max(self.topk)
_, pred_batch = output.topk(maxk, 1, True, True)
pred_batch = pred_batch.cpu().numpy()
for pred in pred_batch:
filename = self.filenames[self.sample_idx]
filename = os.path.basename(filename)
if self.real_labels[filename]:
for k in self.topk:
self.is_correct[k].append(
any([p in self.real_labels[filename] for p in pred[:k]]))
self.sample_idx += 1
def get_accuracy(self, k=None):
if k is None:
return {k: float(np.mean(self.is_correct[k] for k in self.topk))}
else:
return float(np.mean(self.is_correct[k])) * 100

@ -13,6 +13,7 @@ import torch.utils.checkpoint as cp
from torch.jit.annotations import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .features import FeatureNet
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act, BlurPool2d
from .registry import register_model
@ -199,6 +200,9 @@ class DenseNet(nn.Module):
('norm0', norm_layer(num_init_features)),
('pool0', stem_pool),
]))
self.feature_info = [
dict(num_chs=num_init_features, reduction=2, module=f'features.norm{2 if deep_stem else 0}')]
current_stride = 4
# DenseBlocks
num_features = num_init_features
@ -212,21 +216,27 @@ class DenseNet(nn.Module):
drop_rate=drop_rate,
memory_efficient=memory_efficient
)
self.features.add_module('denseblock%d' % (i + 1), block)
module_name = f'denseblock{(i + 1)}'
self.features.add_module(module_name, block)
num_features = num_features + num_layers * growth_rate
transition_aa_layer = None if aa_stem_only else aa_layer
if i != len(block_config) - 1:
self.feature_info += [
dict(num_chs=num_features, reduction=current_stride, module='features.' + module_name)]
current_stride *= 2
trans = DenseTransition(
num_input_features=num_features, num_output_features=num_features // 2,
norm_layer=norm_layer, aa_layer=transition_aa_layer)
self.features.add_module('transition%d' % (i + 1), trans)
self.features.add_module(f'transition{i + 1}', trans)
num_features = num_features // 2
# Final batch norm
self.features.add_module('norm5', norm_layer(num_features))
# Linear layer
self.feature_info += [dict(num_chs=num_features, reduction=current_stride, module='features.norm5')]
self.num_features = num_features
# Linear layer
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@ -279,16 +289,14 @@ def _filter_torchvision_pretrained(state_dict):
def _densenet(variant, growth_rate, block_config, pretrained, **kwargs):
features = False
out_indices = None
if kwargs.pop('features_only', False):
assert False, 'Not Implemented' # TODO
load_strict = False
features = True
kwargs.pop('num_classes', 0)
model_class = DenseNet
else:
load_strict = True
model_class = DenseNet
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
default_cfg = default_cfgs[variant]
model = model_class(growth_rate=growth_rate, block_config=block_config, **kwargs)
model = DenseNet(growth_rate=growth_rate, block_config=block_config, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(
@ -296,7 +304,9 @@ def _densenet(variant, growth_rate, block_config, pretrained, **kwargs):
num_classes=kwargs.get('num_classes', 0),
in_chans=kwargs.get('in_chans', 3),
filter_fn=_filter_torchvision_pretrained,
strict=load_strict)
strict=not features)
if features:
model = FeatureNet(model, out_indices, flatten_sequential=True)
return model

@ -34,6 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
from .feature_hooks import FeatureHooks
from .features import FeatureInfo
from .helpers import load_pretrained, adapt_model_from_file
from .layers import SelectAdaptivePool2d, create_conv2d
from .registry import register_model
@ -438,42 +439,22 @@ class EfficientNetFeatures(nn.Module):
channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self._feature_info = builder.features # builder provides info about feature channels for each block
self.feature_info = FeatureInfo(builder.features, out_indices)
self._stage_to_feature_idx = {
v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices}
v['stage_idx']: fi for fi, v in enumerate(self.feature_info) if fi in self.out_indices}
self._in_chs = builder.in_chs
efficientnet_init_weights(self)
if _DEBUG:
for k, v in self._feature_info.items():
print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
for fi, v in enumerate(self.feature_info):
print('Feature idx: {}: Name: {}, Channels: {}'.format(fi, v['module'], v['num_chs']))
# Register feature extraction hooks with FeatureHooks helper
self.feature_hooks = None
if feature_location != 'bottleneck':
hooks = [dict(
name=self._feature_info[idx]['module'],
type=self._feature_info[idx]['hook_type']) for idx in out_indices]
hooks = self.feature_info.get_by_key(keys=('module', 'hook_type'))
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
def feature_channels(self, idx=None):
""" Feature Channel Shortcut
Returns feature channel count for each output index if idx == None. If idx is an integer, will
return feature channel count for that feature block index (independent of out_indices setting).
"""
if isinstance(idx, int):
return self._feature_info[idx]['num_chs']
return [self._feature_info[i]['num_chs'] for i in self.out_indices]
def feature_info(self, idx=None):
""" Feature Channel Shortcut
Returns feature channel count for each output index if idx == None. If idx is an integer, will
return feature channel count for that feature block index (independent of out_indices setting).
"""
if isinstance(idx, int):
return self._feature_info[idx]
return [self._feature_info[i] for i in self.out_indices]
def forward(self, x) -> List[torch.Tensor]:
x = self.conv_stem(x)
x = self.bn1(x)

@ -225,7 +225,7 @@ class EfficientNetBuilder:
# state updated during build, consumed by model
self.in_chs = None
self.features = OrderedDict()
self.features = []
def _round_channels(self, chs):
return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
@ -291,7 +291,6 @@ class EfficientNetBuilder:
total_block_idx = 0
current_stride = 2
current_dilation = 1
feature_idx = 0
stages = []
# outer list of block_args defines the stacks ('stages' by some conventions)
for stage_idx, stage_block_args in enumerate(model_block_args):
@ -351,13 +350,15 @@ class EfficientNetBuilder:
# stash feature module name and channel info for model feature extraction
if extract_features:
feature_info = block.feature_info(extract_features)
if feature_info['module']:
feature_info['module'] = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_info['module']
module_name = f'blocks.{stage_idx}.{block_idx}'
if 'module' in feature_info and feature_info['module']:
feature_info['module'] = '.'.join([module_name, feature_info['module']])
else:
feature_info['module'] = module_name
feature_info['stage_idx'] = stage_idx
feature_info['block_idx'] = block_idx
feature_info['reduction'] = current_stride
self.features[feature_idx] = feature_info
feature_idx += 1
self.features.append(feature_info)
total_block_idx += 1 # incr global block idx (across all stacks)
stages.append(nn.Sequential(*blocks))

@ -1,3 +1,9 @@
""" PyTorch Feature Hook Helper
This class helps gather features from a network via hooks specified on the module name.
Hacked together by Ross Wightman
"""
import torch
from collections import defaultdict, OrderedDict
@ -7,20 +13,21 @@ from typing import List
class FeatureHooks:
def __init__(self, hooks, named_modules):
def __init__(self, hooks, named_modules, output_as_dict=False):
# setup feature hooks
modules = {k: v for k, v in named_modules}
for h in hooks:
hook_name = h['name']
hook_name = h['module']
m = modules[hook_name]
hook_fn = partial(self._collect_output_hook, hook_name)
if h['type'] == 'forward_pre':
if h['hook_type'] == 'forward_pre':
m.register_forward_pre_hook(hook_fn)
elif h['type'] == 'forward':
elif h['hook_type'] == 'forward':
m.register_forward_hook(hook_fn)
else:
assert False, "Unsupported hook type"
self._feature_outputs = defaultdict(OrderedDict)
self.output_as_dict = output_as_dict
def _collect_output_hook(self, name, *args):
x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
@ -29,6 +36,9 @@ class FeatureHooks:
self._feature_outputs[x.device][name] = x
def get_output(self, device) -> List[torch.tensor]:
if self.output_as_dict:
output = self._feature_outputs[device]
else:
output = list(self._feature_outputs[device].values())
self._feature_outputs[device] = OrderedDict() # clear after reading
return output

@ -0,0 +1,251 @@
""" PyTorch Feature Extraction Helpers
A collection of classes, functions, modules to help extract features from models
and provide a common interface for describing them.
Hacked together by Ross Wightman
"""
from collections import OrderedDict
from typing import Dict, List, Tuple, Any
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
class FeatureInfo:
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
prev_reduction = 1
for fi in feature_info:
# sanity check the mandatory fields, there may be additional fields depending on the model
assert 'num_chs' in fi and fi['num_chs'] > 0
assert 'reduction' in fi and fi['reduction'] >= prev_reduction
prev_reduction = fi['reduction']
assert 'module' in fi
self._out_indices = out_indices
self._info = feature_info
def from_other(self, out_indices: Tuple[int]):
return FeatureInfo(deepcopy(self._info), out_indices)
def channels(self, idx=None):
""" feature channels accessor
if idx == None, returns feature channel count at each output index
if idx is an integer, return feature channel count for that feature module index
"""
if isinstance(idx, int):
return self._info[idx]['num_chs']
return [self._info[i]['num_chs'] for i in self._out_indices]
def reduction(self, idx=None):
""" feature reduction (output stride) accessor
if idx == None, returns feature reduction factor at each output index
if idx is an integer, return feature channel count at that feature module index
"""
if isinstance(idx, int):
return self._info[idx]['reduction']
return [self._info[i]['reduction'] for i in self._out_indices]
def module_name(self, idx=None):
""" feature module name accessor
if idx == None, returns feature module name at each output index
if idx is an integer, return feature module name at that feature module index
"""
if isinstance(idx, int):
return self._info[idx]['module']
return [self._info[i]['module'] for i in self._out_indices]
def get_by_key(self, idx=None, keys=None):
""" return info dicts for specified keys (or all if None) at specified idx (or out_indices if None)
"""
if isinstance(idx, int):
return self._info[idx] if keys is None else {k: self._info[idx][k] for k in keys}
if keys is None:
return [self._info[i] for i in self._out_indices]
else:
return [{k: self._info[i][k] for k in keys} for i in self._out_indices]
def __getitem__(self, item):
return self._info[item]
def __len__(self):
return len(self._info)
def _module_list(module, flatten_sequential=False):
# a yield/iter would be better for this but wouldn't be compatible with torchscript
ml = []
for name, module in module.named_children():
if flatten_sequential and isinstance(module, nn.Sequential):
# first level of Sequential containers is flattened into containing model
for child_name, child_module in module.named_children():
ml.append(('_'.join([name, child_name]), child_module))
else:
ml.append((name, module))
return ml
def _check_return_layers(input_return_layers, modules):
return_layers = {}
for k, v in input_return_layers.items():
ks = k.split('.')
assert 0 < len(ks) <= 2
return_layers['_'.join(ks)] = v
return_set = set(return_layers.keys())
sdiff = return_set - {name for name, _ in modules}
if sdiff:
raise ValueError(f'return_layers {sdiff} are not present in model')
return return_layers, return_set
class LayerGetterDict(nn.ModuleDict):
"""
Module wrapper that returns intermediate layers from a model as a dictionary
Originally based on IntermediateLayerGetter at
https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
It has a strong assumption that the modules have been registered into the model in the same
order as they are used. This means that one should **not** reuse the same nn.Module twice
in the forward if you want this to work.
Additionally, it is only able to query submodules that are directly assigned to the model
class (`model.feature1`) or at most one Sequential container deep (`model.features.1`, so
long as `features` is a sequential container assigned to the model).
All Sequential containers that are directly assigned to the original model will have their
modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
Arguments:
model (nn.Module): model on which we will extract the features
return_layers (Dict[name, new_name]): a dict containing the names
of the modules for which the activations will be returned as
the key of the dict, and the value of the dict is the name
of the returned activation (which the user can specify).
concat (bool): whether to concatenate intermediate features that are lists or tuples
vs select element [0]
flatten_sequential (bool): whether to flatten sequential modules assigned to model
"""
def __init__(self, model, return_layers, concat=False, flatten_sequential=False):
modules = _module_list(model, flatten_sequential=flatten_sequential)
self.return_layers, remaining = _check_return_layers(return_layers, modules)
layers = OrderedDict()
self.concat = concat
for name, module in modules:
layers[name] = module
if name in remaining:
remaining.remove(name)
if not remaining:
break
super(LayerGetterDict, self).__init__(layers)
def forward(self, x) -> Dict[Any, torch.Tensor]:
out = OrderedDict()
for name, module in self.items():
x = module(x)
if name in self.return_layers:
out_id = self.return_layers[name]
if isinstance(x, (tuple, list)):
# If model tap is a tuple or list, concat or select first element
# FIXME this may need to be more generic / flexible for some nets
out[out_id] = torch.cat(x, 1) if self.concat else x[0]
else:
out[out_id] = x
return out
class LayerGetterList(nn.Sequential):
"""
Module wrapper that returns intermediate layers from a model as a list
Originally based on IntermediateLayerGetter at
https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
It has a strong assumption that the modules have been registered into the model in the same
order as they are used. This means that one should **not** reuse the same nn.Module twice
in the forward if you want this to work.
Additionally, it is only able to query submodules that are directly assigned to the model
class (`model.feature1`) or at most one Sequential container deep (`model.features.1`) so
long as `features` is a sequential container assigned to the model and flatten_sequent=True.
All Sequential containers that are directly assigned to the original model will have their
modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
Arguments:
model (nn.Module): model on which we will extract the features
return_layers (Dict[name, new_name]): a dict containing the names
of the modules for which the activations will be returned as
the key of the dict, and the value of the dict is the name
of the returned activation (which the user can specify).
concat (bool): whether to concatenate intermediate features that are lists or tuples
vs select element [0]
flatten_sequential (bool): whether to flatten sequential modules assigned to model
"""
def __init__(self, model, return_layers, concat=False, flatten_sequential=False):
super(LayerGetterList, self).__init__()
modules = _module_list(model, flatten_sequential=flatten_sequential)
self.return_layers, remaining = _check_return_layers(return_layers, modules)
self.concat = concat
for name, module in modules:
self.add_module(name, module)
if name in remaining:
remaining.remove(name)
if not remaining:
break
def forward(self, x) -> List[torch.Tensor]:
out = []
for name, module in self.named_children():
x = module(x)
if name in self.return_layers:
if isinstance(x, (tuple, list)):
# If model tap is a tuple or list, concat or select first element
# FIXME this may need to be more generic / flexible for some nets
out.append(torch.cat(x, 1) if self.concat else x[0])
else:
out.append(x)
return out
def _resolve_feature_info(net, out_indices, feature_info=None):
if feature_info is None:
feature_info = getattr(net, 'feature_info')
if isinstance(feature_info, FeatureInfo):
return feature_info.from_other(out_indices)
elif isinstance(feature_info, (list, tuple)):
return FeatureInfo(net.feature_info, out_indices)
else:
assert False, "Provided feature_info is not valid"
class FeatureNet(nn.Module):
""" FeatureNet
Wrap a model and extract features as specified by the out indices, the network
is partially re-built from contained modules using the LayerGetters.
Please read the docstrings of the LayerGetter classes, they will not work on all models.
"""
def __init__(
self, net,
out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False,
feature_info=None, feature_concat=False, flatten_sequential=False):
super(FeatureNet, self).__init__()
self.feature_info = _resolve_feature_info(net, out_indices, feature_info)
module_names = self.feature_info.module_name()
return_layers = {}
for i in range(len(out_indices)):
return_layers[module_names[i]] = out_map[i] if out_map is not None else out_indices[i]
lg_args = dict(return_layers=return_layers, concat=feature_concat, flatten_sequential=flatten_sequential)
self.body = LayerGetterDict(net, **lg_args) if out_as_dict else LayerGetterList(net, **lg_args)
def forward(self, x):
output = self.body(x)
return output

@ -13,16 +13,16 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from .layers import SelectAdaptivePool2d, get_padding
from .registry import register_model
__all__ = ['Xception65', 'Xception71']
__all__ = ['Xception65']
default_cfgs = {
'gluon_xception65': {
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_xception-7015a15c.pth',
'input_size': (3, 299, 299),
'crop_pct': 0.875,
'crop_pct': 0.903,
'pool_size': (10, 10),
'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN,
@ -32,52 +32,13 @@ default_cfgs = {
'classifier': 'fc'
# The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
},
'gluon_xception71': {
'url': '',
'input_size': (3, 299, 299),
'crop_pct': 0.875,
'pool_size': (5, 5),
'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN,
'std': IMAGENET_DEFAULT_STD,
'num_classes': 1000,
'first_conv': 'conv1',
'classifier': 'fc'
# The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
}
}
""" PADDING NOTES
The original PyTorch and Gluon impl of these models dutifully reproduced the
aligned padding added to Tensorflow models for Deeplab. This padding was compensating
for Tensorflow 'SAME' padding. PyTorch symmetric padding behaves the way we'd want it to.
So, I'm phasing out the 'fixed_padding' ported from TF and replacing with normal
PyTorch padding, some asserts to validate the equivalence for any scenario we'd
care about before removing altogether.
"""
_USE_FIXED_PAD = False
def _pytorch_padding(kernel_size, stride=1, dilation=1, **_):
if _USE_FIXED_PAD:
return 0 # FIXME remove once verified
else:
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
# FIXME remove once verified
fp = _fixed_padding(kernel_size, dilation)
assert all(padding == p for p in fp)
return padding
def _fixed_padding(kernel_size, dilation):
kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
pad_total = kernel_size_effective - 1
pad_beg = pad_total // 2
pad_end = pad_total - pad_beg
return [pad_beg, pad_end, pad_beg, pad_end]
class SeparableConv2d(nn.Module):
@ -88,24 +49,16 @@ class SeparableConv2d(nn.Module):
self.kernel_size = kernel_size
self.dilation = dilation
padding = _fixed_padding(self.kernel_size, self.dilation)
if _USE_FIXED_PAD and any(p > 0 for p in padding):
self.fixed_padding = nn.ZeroPad2d(padding)
else:
self.fixed_padding = None
# depthwise convolution
padding = get_padding(kernel_size, stride, dilation)
self.conv_dw = nn.Conv2d(
inplanes, inplanes, kernel_size, stride=stride,
padding=_pytorch_padding(kernel_size, stride, dilation), dilation=dilation, groups=inplanes, bias=bias)
padding=padding, dilation=dilation, groups=inplanes, bias=bias)
self.bn = norm_layer(num_features=inplanes, **norm_kwargs)
# pointwise convolution
self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias)
def forward(self, x):
if self.fixed_padding is not None:
# FIXME remove once verified
x = self.fixed_padding(x)
x = self.conv_dw(x)
x = self.bn(x)
x = self.conv_pw(x)
@ -113,58 +66,37 @@ class SeparableConv2d(nn.Module):
class Block(nn.Module):
def __init__(self, inplanes, planes, num_reps, stride=1, dilation=1, norm_layer=None,
norm_kwargs=None, start_with_relu=True, grow_first=True, is_last=False):
def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True,
norm_layer=None, norm_kwargs=None, ):
super(Block, self).__init__()
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
if planes != inplanes or stride != 1:
if isinstance(planes, (list, tuple)):
assert len(planes) == 3
else:
planes = (planes,) * 3
outplanes = planes[-1]
if outplanes != inplanes or stride != 1:
self.skip = nn.Sequential()
self.skip.add_module('conv1', nn.Conv2d(
inplanes, planes, 1, stride=stride, bias=False)),
self.skip.add_module('bn1', norm_layer(num_features=planes, **norm_kwargs))
inplanes, outplanes, 1, stride=stride, bias=False)),
self.skip.add_module('bn1', norm_layer(num_features=outplanes, **norm_kwargs))
else:
self.skip = None
rep = OrderedDict()
l = 1
filters = inplanes
if grow_first:
if start_with_relu:
rep['act%d' % l] = nn.ReLU(inplace=False) # NOTE: silent failure if inplace=True here
rep['conv%d' % l] = SeparableConv2d(
inplanes, planes, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs)
filters = planes
l += 1
for _ in range(num_reps - 1):
if grow_first or start_with_relu:
# FIXME being conservative with inplace here, think it's fine to leave True?
rep['act%d' % l] = nn.ReLU(inplace=grow_first or not start_with_relu)
rep['conv%d' % l] = SeparableConv2d(
filters, filters, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
rep['bn%d' % l] = norm_layer(num_features=filters, **norm_kwargs)
l += 1
if not grow_first:
rep['act%d' % l] = nn.ReLU(inplace=True)
rep['conv%d' % l] = SeparableConv2d(
inplanes, planes, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs)
l += 1
if stride != 1:
rep['act%d' % l] = nn.ReLU(inplace=True)
rep['conv%d' % l] = SeparableConv2d(
planes, planes, 3, stride, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs)
l += 1
elif is_last:
rep['act%d' % l] = nn.ReLU(inplace=True)
rep['conv%d' % l] = SeparableConv2d(
planes, planes, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs)
l += 1
for i in range(3):
rep['act%d' % (i + 1)] = nn.ReLU(inplace=True)
rep['conv%d' % (i + 1)] = SeparableConv2d(
inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation,
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
rep['bn%d' % (i + 1)] = norm_layer(planes[i], **norm_kwargs)
inplanes = planes[i]
if not start_with_relu:
del rep['act1']
else:
rep['act1'] = nn.ReLU(inplace=False)
self.rep = nn.Sequential(rep)
def forward(self, x):
@ -176,7 +108,10 @@ class Block(nn.Module):
class Xception65(nn.Module):
"""Modified Aligned Xception
"""Modified Aligned Xception.
NOTE: only the 65 layer version is included here, the 71 layer variant
was not correct and had no pretrained weights
"""
def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
@ -212,25 +147,21 @@ class Xception65(nn.Module):
self.bn2 = norm_layer(num_features=64)
self.block1 = Block(
64, 128, num_reps=2, stride=2,
norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=False)
64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.block2 = Block(
128, 256, num_reps=2, stride=2,
norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True)
128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.block3 = Block(
256, 728, num_reps=2, stride=entry_block3_stride,
norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True, is_last=True)
256, 728, stride=entry_block3_stride, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
# Middle flow
self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
728, 728, num_reps=3, stride=1, dilation=middle_block_dilation,
norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True))
for i in range(4, 20)]))
728, 728, stride=1, dilation=middle_block_dilation,
norm_layer=norm_layer, norm_kwargs=norm_kwargs)) for i in range(4, 20)]))
# Exit flow
self.block20 = Block(
728, 1024, num_reps=2, stride=exit_block20_stride, dilation=exit_block_dilations[0],
norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=False, is_last=True)
728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_block_dilations[0],
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.conv3 = SeparableConv2d(
1024, 1536, 3, stride=1, dilation=exit_block_dilations[1],
@ -305,147 +236,6 @@ class Xception65(nn.Module):
return x
class Xception71(nn.Module):
"""Modified Aligned Xception
"""
def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
norm_kwargs=None, drop_rate=0., global_pool='avg'):
super(Xception71, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
if output_stride == 32:
entry_block3_stride = 2
exit_block20_stride = 2
middle_block_dilation = 1
exit_block_dilations = (1, 1)
elif output_stride == 16:
entry_block3_stride = 2
exit_block20_stride = 1
middle_block_dilation = 1
exit_block_dilations = (1, 2)
elif output_stride == 8:
entry_block3_stride = 1
exit_block20_stride = 1
middle_block_dilation = 2
exit_block_dilations = (2, 4)
else:
raise NotImplementedError
# Entry flow
self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = norm_layer(num_features=32, **norm_kwargs)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = norm_layer(num_features=64)
self.block1 = Block(
64, 128, num_reps=2, stride=2, norm_layer=norm_layer,
norm_kwargs=norm_kwargs, start_with_relu=False)
self.block2 = nn.Sequential(*[
Block(
128, 256, num_reps=2, stride=1, norm_layer=norm_layer,
norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True),
Block(
256, 256, num_reps=2, stride=2, norm_layer=norm_layer,
norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True),
Block(
256, 728, num_reps=2, stride=2, norm_layer=norm_layer,
norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True)])
self.block3 = Block(
728, 728, num_reps=2, stride=entry_block3_stride, norm_layer=norm_layer,
norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True, is_last=True)
# Middle flow
self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
728, 728, num_reps=3, stride=1, dilation=middle_block_dilation,
norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True))
for i in range(4, 20)]))
# Exit flow
self.block20 = Block(
728, 1024, num_reps=2, stride=exit_block20_stride, dilation=exit_block_dilations[0],
norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=False, is_last=True)
self.conv3 = SeparableConv2d(
1024, 1536, 3, stride=1, dilation=exit_block_dilations[1],
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.bn3 = norm_layer(num_features=1536, **norm_kwargs)
self.conv4 = SeparableConv2d(
1536, 1536, 3, stride=1, dilation=exit_block_dilations[1],
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.bn4 = norm_layer(num_features=1536, **norm_kwargs)
self.num_features = 2048
self.conv5 = SeparableConv2d(
1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1],
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
def get_classifier(self):
return self.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.fc = nn.Linear(num_features, num_classes)
else:
self.fc = nn.Identity()
def forward_features(self, x):
# Entry flow
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.block1(x)
# add relu here
x = self.relu(x)
# low_level_feat = x
x = self.block2(x)
# c2 = x
x = self.block3(x)
# Middle flow
x = self.mid(x)
# c3 = x
# Exit flow
x = self.block20(x)
x = self.relu(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu(x)
x = self.conv4(x)
x = self.bn4(x)
x = self.relu(x)
x = self.conv5(x)
x = self.bn5(x)
x = self.relu(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x).flatten(1)
if self.drop_rate:
F.dropout(x, self.drop_rate, training=self.training)
x = self.fc(x)
return x
@register_model
def gluon_xception65(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" Modified Aligned Xception-65
@ -456,15 +246,3 @@ def gluon_xception65(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def gluon_xception71(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" Modified Aligned Xception-71
"""
default_cfg = default_cfgs['gluon_xception71']
model = Xception71(num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

@ -7,6 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .features import FeatureNet
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from .registry import register_model
@ -231,9 +232,13 @@ class InceptionResnetV2(nn.Module):
self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.feature_info = [dict(num_chs=64, reduction=2, module='conv2d_2b')]
self.maxpool_3a = nn.MaxPool2d(3, stride=2)
self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
self.feature_info += [dict(num_chs=192, reduction=4, module='conv2d_4a')]
self.maxpool_5a = nn.MaxPool2d(3, stride=2)
self.mixed_5b = Mixed_5b()
self.repeat = nn.Sequential(
@ -248,6 +253,8 @@ class InceptionResnetV2(nn.Module):
Block35(scale=0.17),
Block35(scale=0.17)
)
self.feature_info += [dict(num_chs=320, reduction=8, module='repeat')]
self.mixed_6a = Mixed_6a()
self.repeat_1 = nn.Sequential(
Block17(scale=0.10),
@ -271,6 +278,8 @@ class InceptionResnetV2(nn.Module):
Block17(scale=0.10),
Block17(scale=0.10)
)
self.feature_info += [dict(num_chs=1088, reduction=16, module='repeat_1')]
self.mixed_7a = Mixed_7a()
self.repeat_2 = nn.Sequential(
Block8(scale=0.20),
@ -285,6 +294,8 @@ class InceptionResnetV2(nn.Module):
)
self.block8 = Block8(no_relu=True)
self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
self.feature_info += [dict(num_chs=self.num_features, reduction=32, module='conv2d_7b')]
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
# NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC
self.classif = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@ -328,30 +339,34 @@ class InceptionResnetV2(nn.Module):
return x
def _inception_resnet_v2(variant, pretrained=False, **kwargs):
load_strict, features, out_indices = True, False, None
if kwargs.pop('features_only', False):
load_strict, features, out_indices = False, True, kwargs.pop('out_indices', (0, 1, 2, 3, 4))
kwargs.pop('num_classes', 0)
model = InceptionResnetV2(**kwargs)
model.default_cfg = default_cfgs[variant]
if pretrained:
load_pretrained(
model,
num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict)
if features:
model = FeatureNet(model, out_indices)
return model
@register_model
def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def inception_resnet_v2(pretrained=False, **kwargs):
r"""InceptionResnetV2 model architecture from the
`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>` paper.
"""
default_cfg = default_cfgs['inception_resnet_v2']
model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _inception_resnet_v2('inception_resnet_v2', pretrained=pretrained, **kwargs)
@register_model
def ens_adv_inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def ens_adv_inception_resnet_v2(pretrained=False, **kwargs):
r""" Ensemble Adversarially trained InceptionResnetV2 model architecture
As per https://arxiv.org/abs/1705.07204 and
https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models.
"""
default_cfg = default_cfgs['ens_adv_inception_resnet_v2']
model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _inception_resnet_v2('ens_adv_inception_resnet_v2', pretrained=pretrained, **kwargs)

@ -17,6 +17,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
from .feature_hooks import FeatureHooks
from .features import FeatureInfo
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d, create_conv2d, get_act_fn, hard_sigmoid
from .registry import register_model
@ -182,22 +183,20 @@ class MobileNetV3Features(nn.Module):
channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self._feature_info = builder.features # builder provides info about feature channels for each block
self.feature_info = FeatureInfo(builder.features, out_indices)
self._stage_to_feature_idx = {
v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices}
v['stage_idx']: fi for fi, v in enumerate(self.feature_info) if fi in self.out_indices}
self._in_chs = builder.in_chs
efficientnet_init_weights(self)
if _DEBUG:
for k, v in self._feature_info.items():
print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
for fi, v in enumerate(self.feature_info):
print('Feature idx: {}: Name: {}, Channels: {}'.format(fi, v['module'], v['num_chs']))
# Register feature extraction hooks with FeatureHooks helper
self.feature_hooks = None
if feature_location != 'bottleneck':
hooks = [dict(
name=self._feature_info[idx]['module'],
type=self._feature_info[idx]['hook_type']) for idx in out_indices]
hooks = self.feature_info.get_by_key(keys=('module', 'hook_type'))
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
def feature_channels(self, idx=None):
@ -206,17 +205,8 @@ class MobileNetV3Features(nn.Module):
return feature channel count for that feature block index (independent of out_indices setting).
"""
if isinstance(idx, int):
return self._feature_info[idx]['num_chs']
return [self._feature_info[i]['num_chs'] for i in self.out_indices]
def feature_info(self, idx=None):
""" Feature Channel Shortcut
Returns feature channel count for each output index if idx == None. If idx is an integer, will
return feature channel count for that feature block index (independent of out_indices setting).
"""
if isinstance(idx, int):
return self._feature_info[idx]
return [self._feature_info[i] for i in self.out_indices]
return self.feature_info[idx]['num_chs']
return [self.feature_info[i]['num_chs'] for i in self.out_indices]
def forward(self, x) -> List[torch.Tensor]:
x = self.conv_stem(x)

@ -3,7 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d, create_pool2d
from .registry import register_model
__all__ = ['NASNetALarge']
@ -24,43 +24,31 @@ default_cfgs = {
}
class MaxPoolPad(nn.Module):
class ActConvBn(nn.Module):
def __init__(self):
super(MaxPoolPad, self).__init__()
self.pad = nn.ZeroPad2d((1, 0, 1, 0))
self.pool = nn.MaxPool2d(3, stride=2, padding=1)
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''):
super(ActConvBn, self).__init__()
self.act = nn.ReLU()
self.conv = create_conv2d(
in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1)
def forward(self, x):
x = self.pad(x)
x = self.pool(x)
x = x[:, :, 1:, 1:]
return x
class AvgPoolPad(nn.Module):
def __init__(self, stride=2, padding=1):
super(AvgPoolPad, self).__init__()
self.pad = nn.ZeroPad2d((1, 0, 1, 0))
self.pool = nn.AvgPool2d(3, stride=stride, padding=padding, count_include_pad=False)
def forward(self, x):
x = self.pad(x)
x = self.pool(x)
x = x[:, :, 1:, 1:]
x = self.act(x)
x = self.conv(x)
x = self.bn(x)
return x
class SeparableConv2d(nn.Module):
def __init__(self, in_channels, out_channels, dw_kernel, dw_stride, dw_padding, bias=False):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''):
super(SeparableConv2d, self).__init__()
self.depthwise_conv2d = nn.Conv2d(
in_channels, in_channels, dw_kernel,
stride=dw_stride, padding=dw_padding,
bias=bias, groups=in_channels)
self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=bias)
self.depthwise_conv2d = create_conv2d(
in_channels, in_channels, kernel_size=kernel_size,
stride=stride, padding=padding, groups=in_channels)
self.pointwise_conv2d = create_conv2d(
in_channels, out_channels, kernel_size=1, padding=0)
def forward(self, x):
x = self.depthwise_conv2d(x)
@ -70,87 +58,48 @@ class SeparableConv2d(nn.Module):
class BranchSeparables(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_type='', stem_cell=False):
super(BranchSeparables, self).__init__()
self.relu = nn.ReLU()
self.separable_1 = SeparableConv2d(in_channels, in_channels, kernel_size, stride, padding, bias=bias)
self.bn_sep_1 = nn.BatchNorm2d(in_channels, eps=0.001, momentum=0.1, affine=True)
self.relu1 = nn.ReLU()
self.separable_2 = SeparableConv2d(in_channels, out_channels, kernel_size, 1, padding, bias=bias)
self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)
def forward(self, x):
x = self.relu(x)
x = self.separable_1(x)
x = self.bn_sep_1(x)
x = self.relu1(x)
x = self.separable_2(x)
x = self.bn_sep_2(x)
return x
class BranchSeparablesStem(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False):
super(BranchSeparablesStem, self).__init__()
self.relu = nn.ReLU()
self.separable_1 = SeparableConv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
self.bn_sep_1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)
self.relu1 = nn.ReLU()
self.separable_2 = SeparableConv2d(out_channels, out_channels, kernel_size, 1, padding, bias=bias)
self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)
middle_channels = out_channels if stem_cell else in_channels
self.act_1 = nn.ReLU()
self.separable_1 = SeparableConv2d(
in_channels, middle_channels, kernel_size, stride=stride, padding=pad_type)
self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001, momentum=0.1)
self.act_2 = nn.ReLU(inplace=True)
self.separable_2 = SeparableConv2d(
middle_channels, out_channels, kernel_size, stride=1, padding=pad_type)
self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1)
def forward(self, x):
x = self.relu(x)
x = self.act_1(x)
x = self.separable_1(x)
x = self.bn_sep_1(x)
x = self.relu1(x)
x = self.separable_2(x)
x = self.bn_sep_2(x)
return x
class BranchSeparablesReduction(BranchSeparables):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, z_padding=1, bias=False):
BranchSeparables.__init__(self, in_channels, out_channels, kernel_size, stride, padding, bias)
self.padding = nn.ZeroPad2d((z_padding, 0, z_padding, 0))
def forward(self, x):
x = self.relu(x)
x = self.padding(x)
x = self.separable_1(x)
x = x[:, :, 1:, 1:].contiguous()
x = self.bn_sep_1(x)
x = self.relu1(x)
x = self.act_2(x)
x = self.separable_2(x)
x = self.bn_sep_2(x)
return x
class CellStem0(nn.Module):
def __init__(self, stem_size, num_channels=42):
def __init__(self, stem_size, num_channels=42, pad_type=''):
super(CellStem0, self).__init__()
self.num_channels = num_channels
self.stem_size = stem_size
self.conv_1x1 = nn.Sequential()
self.conv_1x1.add_module('relu', nn.ReLU())
self.conv_1x1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels, 1, stride=1, bias=False))
self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True))
self.conv_1x1 = ActConvBn(self.stem_size, self.num_channels, 1, stride=1)
self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, 2)
self.comb_iter_0_right = BranchSeparablesStem(self.stem_size, self.num_channels, 7, 2, 3, bias=False)
self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type)
self.comb_iter_0_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True)
self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
self.comb_iter_1_right = BranchSeparablesStem(self.stem_size, self.num_channels, 7, 2, 3, bias=False)
self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
self.comb_iter_1_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True)
self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
self.comb_iter_2_right = BranchSeparablesStem(self.stem_size, self.num_channels, 5, 2, 2, bias=False)
self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
self.comb_iter_2_right = BranchSeparables(self.stem_size, self.num_channels, 5, 2, pad_type, stem_cell=True)
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, 1, bias=False)
self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)
self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type)
self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)
def forward(self, x):
x1 = self.conv_1x1(x)
@ -180,51 +129,46 @@ class CellStem0(nn.Module):
class CellStem1(nn.Module):
def __init__(self, stem_size, num_channels):
def __init__(self, stem_size, num_channels, pad_type=''):
super(CellStem1, self).__init__()
self.num_channels = num_channels
self.stem_size = stem_size
self.conv_1x1 = nn.Sequential()
self.conv_1x1.add_module('relu', nn.ReLU())
self.conv_1x1.add_module('conv', nn.Conv2d(2 * self.num_channels, self.num_channels, 1, stride=1, bias=False))
self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True))
self.conv_1x1 = ActConvBn(2 * self.num_channels, self.num_channels, 1, stride=1)
self.relu = nn.ReLU()
self.act = nn.ReLU()
self.path_1 = nn.Sequential()
self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False))
self.path_2 = nn.ModuleList()
self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
self.path_2 = nn.Sequential()
self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1)))
self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False))
self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True)
self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1)
self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, 2, bias=False)
self.comb_iter_0_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, 3, bias=False)
self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type)
self.comb_iter_0_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type)
self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
self.comb_iter_1_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, 3, bias=False)
self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
self.comb_iter_1_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type)
self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
self.comb_iter_2_right = BranchSeparables(self.num_channels, self.num_channels, 5, 2, 2, bias=False)
self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
self.comb_iter_2_right = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type)
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, 1, bias=False)
self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)
self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type)
self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)
def forward(self, x_conv0, x_stem_0):
x_left = self.conv_1x1(x_stem_0)
x_relu = self.relu(x_conv0)
x_relu = self.act(x_conv0)
# path 1
x_path1 = self.path_1(x_relu)
# path 2
x_path2 = self.path_2.pad(x_relu)
x_path2 = x_path2[:, :, 1:, 1:]
x_path2 = self.path_2.avgpool(x_path2)
x_path2 = self.path_2.conv(x_path2)
x_path2 = self.path_2(x_relu)
# final path
x_right = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
@ -253,49 +197,40 @@ class CellStem1(nn.Module):
class FirstCell(nn.Module):
def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
super(FirstCell, self).__init__()
self.conv_1x1 = nn.Sequential()
self.conv_1x1.add_module('relu', nn.ReLU())
self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1)
self.relu = nn.ReLU()
self.act = nn.ReLU()
self.path_1 = nn.Sequential()
self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
self.path_1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
self.path_2 = nn.ModuleList()
self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
self.path_1.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False))
self.path_2 = nn.Sequential()
self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1)))
self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
self.path_2.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
self.path_2.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False))
self.final_path_bn = nn.BatchNorm2d(out_channels_left * 2, eps=0.001, momentum=0.1, affine=True)
self.final_path_bn = nn.BatchNorm2d(out_chs_left * 2, eps=0.001, momentum=0.1)
self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type)
self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
self.comb_iter_1_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
self.comb_iter_1_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type)
self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
def forward(self, x, x_prev):
x_relu = self.relu(x_prev)
# path 1
x_relu = self.act(x_prev)
x_path1 = self.path_1(x_relu)
# path 2
x_path2 = self.path_2.pad(x_relu)
x_path2 = x_path2[:, :, 1:, 1:]
x_path2 = self.path_2.avgpool(x_path2)
x_path2 = self.path_2.conv(x_path2)
# final path
x_path2 = self.path_2(x_relu)
x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
x_right = self.conv_1x1(x)
x_comb_iter_0_left = self.comb_iter_0_left(x_right)
@ -322,30 +257,23 @@ class FirstCell(nn.Module):
class NormalCell(nn.Module):
def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
super(NormalCell, self).__init__()
self.conv_prev_1x1 = nn.Sequential()
self.conv_prev_1x1.add_module('relu', nn.ReLU())
self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))
self.conv_1x1 = nn.Sequential()
self.conv_1x1.add_module('relu', nn.ReLU())
self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type)
self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type)
self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
self.comb_iter_0_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False)
self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type)
self.comb_iter_0_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type)
self.comb_iter_1_left = BranchSeparables(out_channels_left, out_channels_left, 5, 1, 2, bias=False)
self.comb_iter_1_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False)
self.comb_iter_1_left = BranchSeparables(out_chs_left, out_chs_left, 5, 1, pad_type)
self.comb_iter_1_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type)
self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
def forward(self, x, x_prev):
x_left = self.conv_prev_1x1(x_prev)
@ -375,31 +303,24 @@ class NormalCell(nn.Module):
class ReductionCell0(nn.Module):
def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
super(ReductionCell0, self).__init__()
self.conv_prev_1x1 = nn.Sequential()
self.conv_prev_1x1.add_module('relu', nn.ReLU())
self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))
self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type)
self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type)
self.conv_1x1 = nn.Sequential()
self.conv_1x1.add_module('relu', nn.ReLU())
self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)
self.comb_iter_0_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
self.comb_iter_0_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)
self.comb_iter_1_left = MaxPoolPad()
self.comb_iter_1_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
self.comb_iter_2_left = AvgPoolPad()
self.comb_iter_2_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
self.comb_iter_4_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
self.comb_iter_4_right = MaxPoolPad()
self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)
def forward(self, x, x_prev):
x_left = self.conv_prev_1x1(x_prev)
@ -430,31 +351,24 @@ class ReductionCell0(nn.Module):
class ReductionCell1(nn.Module):
def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
super(ReductionCell1, self).__init__()
self.conv_prev_1x1 = nn.Sequential()
self.conv_prev_1x1.add_module('relu', nn.ReLU())
self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))
self.conv_1x1 = nn.Sequential()
self.conv_1x1.add_module('relu', nn.ReLU())
self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type)
self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type)
self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)
self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)
self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
self.comb_iter_2_right = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)
self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)
def forward(self, x, x_prev):
x_left = self.conv_prev_1x1(x_prev)
@ -487,7 +401,7 @@ class NASNetALarge(nn.Module):
"""NASNetALarge (6 @ 4032) """
def __init__(self, num_classes=1000, in_chans=1, stem_size=96, num_features=4032, channel_multiplier=2,
drop_rate=0., global_pool='avg'):
drop_rate=0., global_pool='avg', pad_type='same'):
super(NASNetALarge, self).__init__()
self.num_classes = num_classes
self.stem_size = stem_size
@ -498,60 +412,79 @@ class NASNetALarge(nn.Module):
channels = self.num_features // 24
# 24 is default value for the architecture
self.conv0 = nn.Sequential()
self.conv0.add_module('conv', nn.Conv2d(
in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2, bias=False))
self.conv0.add_module('bn', nn.BatchNorm2d(self.stem_size, eps=0.001, momentum=0.1, affine=True))
self.cell_stem_0 = CellStem0(self.stem_size, num_channels=channels // (channel_multiplier ** 2))
self.cell_stem_1 = CellStem1(self.stem_size, num_channels=channels // channel_multiplier)
self.cell_0 = FirstCell(in_channels_left=channels, out_channels_left=channels // 2,
in_channels_right=2 * channels, out_channels_right=channels)
self.cell_1 = NormalCell(in_channels_left=2 * channels, out_channels_left=channels,
in_channels_right=6 * channels, out_channels_right=channels)
self.cell_2 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
in_channels_right=6 * channels, out_channels_right=channels)
self.cell_3 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
in_channels_right=6 * channels, out_channels_right=channels)
self.cell_4 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
in_channels_right=6 * channels, out_channels_right=channels)
self.cell_5 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
in_channels_right=6 * channels, out_channels_right=channels)
self.reduction_cell_0 = ReductionCell0(in_channels_left=6 * channels, out_channels_left=2 * channels,
in_channels_right=6 * channels, out_channels_right=2 * channels)
self.cell_6 = FirstCell(in_channels_left=6 * channels, out_channels_left=channels,
in_channels_right=8 * channels, out_channels_right=2 * channels)
self.cell_7 = NormalCell(in_channels_left=8 * channels, out_channels_left=2 * channels,
in_channels_right=12 * channels, out_channels_right=2 * channels)
self.cell_8 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
in_channels_right=12 * channels, out_channels_right=2 * channels)
self.cell_9 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
in_channels_right=12 * channels, out_channels_right=2 * channels)
self.cell_10 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
in_channels_right=12 * channels, out_channels_right=2 * channels)
self.cell_11 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
in_channels_right=12 * channels, out_channels_right=2 * channels)
self.reduction_cell_1 = ReductionCell1(in_channels_left=12 * channels, out_channels_left=4 * channels,
in_channels_right=12 * channels, out_channels_right=4 * channels)
self.cell_12 = FirstCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
in_channels_right=16 * channels, out_channels_right=4 * channels)
self.cell_13 = NormalCell(in_channels_left=16 * channels, out_channels_left=4 * channels,
in_channels_right=24 * channels, out_channels_right=4 * channels)
self.cell_14 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
in_channels_right=24 * channels, out_channels_right=4 * channels)
self.cell_15 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
in_channels_right=24 * channels, out_channels_right=4 * channels)
self.cell_16 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
in_channels_right=24 * channels, out_channels_right=4 * channels)
self.cell_17 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
in_channels_right=24 * channels, out_channels_right=4 * channels)
self.relu = nn.ReLU()
self.conv0 = ConvBnAct(
in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2,
norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None)
self.cell_stem_0 = CellStem0(
self.stem_size, num_channels=channels // (channel_multiplier ** 2), pad_type=pad_type)
self.cell_stem_1 = CellStem1(
self.stem_size, num_channels=channels // channel_multiplier, pad_type=pad_type)
self.cell_0 = FirstCell(
in_chs_left=channels, out_chs_left=channels // 2,
in_chs_right=2 * channels, out_chs_right=channels, pad_type=pad_type)
self.cell_1 = NormalCell(
in_chs_left=2 * channels, out_chs_left=channels,
in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
self.cell_2 = NormalCell(
in_chs_left=6 * channels, out_chs_left=channels,
in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
self.cell_3 = NormalCell(
in_chs_left=6 * channels, out_chs_left=channels,
in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
self.cell_4 = NormalCell(
in_chs_left=6 * channels, out_chs_left=channels,
in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
self.cell_5 = NormalCell(
in_chs_left=6 * channels, out_chs_left=channels,
in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
self.reduction_cell_0 = ReductionCell0(
in_chs_left=6 * channels, out_chs_left=2 * channels,
in_chs_right=6 * channels, out_chs_right=2 * channels, pad_type=pad_type)
self.cell_6 = FirstCell(
in_chs_left=6 * channels, out_chs_left=channels,
in_chs_right=8 * channels, out_chs_right=2 * channels, pad_type=pad_type)
self.cell_7 = NormalCell(
in_chs_left=8 * channels, out_chs_left=2 * channels,
in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
self.cell_8 = NormalCell(
in_chs_left=12 * channels, out_chs_left=2 * channels,
in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
self.cell_9 = NormalCell(
in_chs_left=12 * channels, out_chs_left=2 * channels,
in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
self.cell_10 = NormalCell(
in_chs_left=12 * channels, out_chs_left=2 * channels,
in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
self.cell_11 = NormalCell(
in_chs_left=12 * channels, out_chs_left=2 * channels,
in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
self.reduction_cell_1 = ReductionCell1(
in_chs_left=12 * channels, out_chs_left=4 * channels,
in_chs_right=12 * channels, out_chs_right=4 * channels, pad_type=pad_type)
self.cell_12 = FirstCell(
in_chs_left=12 * channels, out_chs_left=2 * channels,
in_chs_right=16 * channels, out_chs_right=4 * channels, pad_type=pad_type)
self.cell_13 = NormalCell(
in_chs_left=16 * channels, out_chs_left=4 * channels,
in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
self.cell_14 = NormalCell(
in_chs_left=24 * channels, out_chs_left=4 * channels,
in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
self.cell_15 = NormalCell(
in_chs_left=24 * channels, out_chs_left=4 * channels,
in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
self.cell_16 = NormalCell(
in_chs_left=24 * channels, out_chs_left=4 * channels,
in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
self.cell_17 = NormalCell(
in_chs_left=24 * channels, out_chs_left=4 * channels,
in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
self.act = nn.ReLU(inplace=True)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@ -569,8 +502,11 @@ class NASNetALarge(nn.Module):
def forward_features(self, x):
x_conv0 = self.conv0(x)
#0
x_stem_0 = self.cell_stem_0(x_conv0)
x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0)
#1
x_cell_0 = self.cell_0(x_stem_1, x_stem_0)
x_cell_1 = self.cell_1(x_cell_0, x_stem_1)
@ -578,25 +514,27 @@ class NASNetALarge(nn.Module):
x_cell_3 = self.cell_3(x_cell_2, x_cell_1)
x_cell_4 = self.cell_4(x_cell_3, x_cell_2)
x_cell_5 = self.cell_5(x_cell_4, x_cell_3)
#2
x_reduction_cell_0 = self.reduction_cell_0(x_cell_5, x_cell_4)
x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_4)
x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0)
x_cell_8 = self.cell_8(x_cell_7, x_cell_6)
x_cell_9 = self.cell_9(x_cell_8, x_cell_7)
x_cell_10 = self.cell_10(x_cell_9, x_cell_8)
x_cell_11 = self.cell_11(x_cell_10, x_cell_9)
#3
x_reduction_cell_1 = self.reduction_cell_1(x_cell_11, x_cell_10)
x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_10)
x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1)
x_cell_14 = self.cell_14(x_cell_13, x_cell_12)
x_cell_15 = self.cell_15(x_cell_14, x_cell_13)
x_cell_16 = self.cell_16(x_cell_15, x_cell_14)
x_cell_17 = self.cell_17(x_cell_16, x_cell_15)
x = self.relu(x_cell_17)
x = self.act(x_cell_17)
#4
return x
def forward(self, x):

@ -14,7 +14,7 @@ import torch.nn as nn
import torch.nn.functional as F
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d, create_pool2d
from .registry import register_model
__all__ = ['PNASNet5Large']
@ -35,34 +35,15 @@ default_cfgs = {
}
class MaxPool(nn.Module):
def __init__(self, kernel_size, stride=1, padding=1, zero_pad=False):
super(MaxPool, self).__init__()
self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0)) if zero_pad else None
self.pool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding)
def forward(self, x):
if self.zero_pad is not None:
x = self.zero_pad(x)
x = self.pool(x)
x = x[:, :, 1:, 1:]
else:
x = self.pool(x)
return x
class SeparableConv2d(nn.Module):
def __init__(self, in_channels, out_channels, dw_kernel_size, dw_stride,
dw_padding):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''):
super(SeparableConv2d, self).__init__()
self.depthwise_conv2d = nn.Conv2d(in_channels, in_channels,
kernel_size=dw_kernel_size,
stride=dw_stride, padding=dw_padding,
groups=in_channels, bias=False)
self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels,
kernel_size=1, bias=False)
self.depthwise_conv2d = create_conv2d(
in_channels, in_channels, kernel_size=kernel_size,
stride=stride, padding=padding, groups=in_channels)
self.pointwise_conv2d = create_conv2d(
in_channels, out_channels, kernel_size=1, padding=padding)
def forward(self, x):
x = self.depthwise_conv2d(x)
@ -72,50 +53,39 @@ class SeparableConv2d(nn.Module):
class BranchSeparables(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
stem_cell=False, zero_pad=False):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, stem_cell=False, padding=''):
super(BranchSeparables, self).__init__()
padding = kernel_size // 2
middle_channels = out_channels if stem_cell else in_channels
self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0)) if zero_pad else None
self.relu_1 = nn.ReLU()
self.separable_1 = SeparableConv2d(in_channels, middle_channels,
kernel_size, dw_stride=stride,
dw_padding=padding)
self.act_1 = nn.ReLU()
self.separable_1 = SeparableConv2d(
in_channels, middle_channels, kernel_size, stride=stride, padding=padding)
self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001)
self.relu_2 = nn.ReLU()
self.separable_2 = SeparableConv2d(middle_channels, out_channels,
kernel_size, dw_stride=1,
dw_padding=padding)
self.act_2 = nn.ReLU()
self.separable_2 = SeparableConv2d(
middle_channels, out_channels, kernel_size, stride=1, padding=padding)
self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001)
def forward(self, x):
x = self.relu_1(x)
if self.zero_pad is not None:
x = self.zero_pad(x)
x = self.separable_1(x)
x = x[:, :, 1:, 1:].contiguous()
else:
x = self.act_1(x)
x = self.separable_1(x)
x = self.bn_sep_1(x)
x = self.relu_2(x)
x = self.act_2(x)
x = self.separable_2(x)
x = self.bn_sep_2(x)
return x
class ReluConvBn(nn.Module):
class ActConvBn(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1):
super(ReluConvBn, self).__init__()
self.relu = nn.ReLU()
self.conv = nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride,
bias=False)
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''):
super(ActConvBn, self).__init__()
self.act = nn.ReLU()
self.conv = create_conv2d(
in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
def forward(self, x):
x = self.relu(x)
x = self.act(x)
x = self.conv(x)
x = self.bn(x)
return x
@ -123,32 +93,24 @@ class ReluConvBn(nn.Module):
class FactorizedReduction(nn.Module):
def __init__(self, in_channels, out_channels):
def __init__(self, in_channels, out_channels, padding=''):
super(FactorizedReduction, self).__init__()
self.relu = nn.ReLU()
self.act = nn.ReLU()
self.path_1 = nn.Sequential(OrderedDict([
('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)),
('conv', nn.Conv2d(in_channels, out_channels // 2,
kernel_size=1, bias=False)),
('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)),
]))
self.path_2 = nn.Sequential(OrderedDict([
('pad', nn.ZeroPad2d((0, 1, 0, 1))),
('pad', nn.ZeroPad2d((-1, 1, -1, 1))), # shift
('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)),
('conv', nn.Conv2d(in_channels, out_channels // 2,
kernel_size=1, bias=False)),
('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)),
]))
self.final_path_bn = nn.BatchNorm2d(out_channels, eps=0.001)
def forward(self, x):
x = self.relu(x)
x = self.act(x)
x_path1 = self.path_1(x)
x_path2 = self.path_2.pad(x)
x_path2 = x_path2[:, :, 1:, 1:]
x_path2 = self.path_2.avgpool(x_path2)
x_path2 = self.path_2.conv(x_path2)
x_path2 = self.path_2(x)
out = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
return out
@ -179,49 +141,41 @@ class CellBase(nn.Module):
x_comb_iter_4_right = x_right
x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
x_out = torch.cat(
[x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
x_out = torch.cat([x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
return x_out
class CellStem0(CellBase):
def __init__(self, in_channels_left, out_channels_left, in_channels_right,
out_channels_right):
def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, padding=''):
super(CellStem0, self).__init__()
self.conv_1x1 = ReluConvBn(in_channels_right, out_channels_right,
kernel_size=1)
self.comb_iter_0_left = BranchSeparables(in_channels_left,
out_channels_left,
kernel_size=5, stride=2,
stem_cell=True)
self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=padding)
self.comb_iter_0_left = BranchSeparables(
in_chs_left, out_chs_left, kernel_size=5, stride=2, stem_cell=True, padding=padding)
self.comb_iter_0_right = nn.Sequential(OrderedDict([
('max_pool', MaxPool(3, stride=2)),
('conv', nn.Conv2d(in_channels_left, out_channels_left,
kernel_size=1, bias=False)),
('bn', nn.BatchNorm2d(out_channels_left, eps=0.001)),
('max_pool', create_pool2d('max', 3, stride=2, padding=padding)),
('conv', create_conv2d(in_chs_left, out_chs_left, kernel_size=1, padding=padding)),
('bn', nn.BatchNorm2d(out_chs_left, eps=0.001)),
]))
self.comb_iter_1_left = BranchSeparables(out_channels_right,
out_channels_right,
kernel_size=7, stride=2)
self.comb_iter_1_right = MaxPool(3, stride=2)
self.comb_iter_2_left = BranchSeparables(out_channels_right,
out_channels_right,
kernel_size=5, stride=2)
self.comb_iter_2_right = BranchSeparables(out_channels_right,
out_channels_right,
kernel_size=3, stride=2)
self.comb_iter_3_left = BranchSeparables(out_channels_right,
out_channels_right,
kernel_size=3)
self.comb_iter_3_right = MaxPool(3, stride=2)
self.comb_iter_4_left = BranchSeparables(in_channels_right,
out_channels_right,
kernel_size=3, stride=2,
stem_cell=True)
self.comb_iter_4_right = ReluConvBn(out_channels_right,
out_channels_right,
kernel_size=1, stride=2)
self.comb_iter_1_left = BranchSeparables(
out_chs_right, out_chs_right, kernel_size=7, stride=2, padding=padding)
self.comb_iter_1_right = create_pool2d('max', 3, stride=2, padding=padding)
self.comb_iter_2_left = BranchSeparables(
out_chs_right, out_chs_right, kernel_size=5, stride=2, padding=padding)
self.comb_iter_2_right = BranchSeparables(
out_chs_right, out_chs_right, kernel_size=3, stride=2, padding=padding)
self.comb_iter_3_left = BranchSeparables(
out_chs_right, out_chs_right, kernel_size=3, padding=padding)
self.comb_iter_3_right = create_pool2d('max', 3, stride=2, padding=padding)
self.comb_iter_4_left = BranchSeparables(
in_chs_right, out_chs_right, kernel_size=3, stride=2, stem_cell=True, padding=padding)
self.comb_iter_4_right = ActConvBn(
out_chs_right, out_chs_right, kernel_size=1, stride=2, padding=padding)
def forward(self, x_left):
x_right = self.conv_1x1(x_left)
@ -231,9 +185,8 @@ class CellStem0(CellBase):
class Cell(CellBase):
def __init__(self, in_channels_left, out_channels_left, in_channels_right,
out_channels_right, is_reduction=False, zero_pad=False,
match_prev_layer_dimensions=False):
def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, padding='',
is_reduction=False, match_prev_layer_dims=False):
super(Cell, self).__init__()
# If `is_reduction` is set to `True` stride 2 is used for
@ -244,45 +197,34 @@ class Cell(CellBase):
# If `match_prev_layer_dimensions` is set to `True`
# `FactorizedReduction` is used to reduce the spatial size
# of the left input of a cell approximately by a factor of 2.
self.match_prev_layer_dimensions = match_prev_layer_dimensions
if match_prev_layer_dimensions:
self.conv_prev_1x1 = FactorizedReduction(in_channels_left,
out_channels_left)
self.match_prev_layer_dimensions = match_prev_layer_dims
if match_prev_layer_dims:
self.conv_prev_1x1 = FactorizedReduction(in_chs_left, out_chs_left, padding=padding)
else:
self.conv_prev_1x1 = ReluConvBn(in_channels_left,
out_channels_left, kernel_size=1)
self.conv_1x1 = ReluConvBn(in_channels_right, out_channels_right,
kernel_size=1)
self.comb_iter_0_left = BranchSeparables(out_channels_left,
out_channels_left,
kernel_size=5, stride=stride,
zero_pad=zero_pad)
self.comb_iter_0_right = MaxPool(3, stride=stride, zero_pad=zero_pad)
self.comb_iter_1_left = BranchSeparables(out_channels_right,
out_channels_right,
kernel_size=7, stride=stride,
zero_pad=zero_pad)
self.comb_iter_1_right = MaxPool(3, stride=stride, zero_pad=zero_pad)
self.comb_iter_2_left = BranchSeparables(out_channels_right,
out_channels_right,
kernel_size=5, stride=stride,
zero_pad=zero_pad)
self.comb_iter_2_right = BranchSeparables(out_channels_right,
out_channels_right,
kernel_size=3, stride=stride,
zero_pad=zero_pad)
self.comb_iter_3_left = BranchSeparables(out_channels_right,
out_channels_right,
kernel_size=3)
self.comb_iter_3_right = MaxPool(3, stride=stride, zero_pad=zero_pad)
self.comb_iter_4_left = BranchSeparables(out_channels_left,
out_channels_left,
kernel_size=3, stride=stride,
zero_pad=zero_pad)
self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, kernel_size=1, padding=padding)
self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=padding)
self.comb_iter_0_left = BranchSeparables(
out_chs_left, out_chs_left, kernel_size=5, stride=stride, padding=padding)
self.comb_iter_0_right = create_pool2d('max', 3, stride=stride, padding=padding)
self.comb_iter_1_left = BranchSeparables(
out_chs_right, out_chs_right, kernel_size=7, stride=stride, padding=padding)
self.comb_iter_1_right = create_pool2d('max', 3, stride=stride, padding=padding)
self.comb_iter_2_left = BranchSeparables(
out_chs_right, out_chs_right, kernel_size=5, stride=stride, padding=padding)
self.comb_iter_2_right = BranchSeparables(
out_chs_right, out_chs_right, kernel_size=3, stride=stride, padding=padding)
self.comb_iter_3_left = BranchSeparables(out_chs_right, out_chs_right, kernel_size=3)
self.comb_iter_3_right = create_pool2d('max', 3, stride=stride, padding=padding)
self.comb_iter_4_left = BranchSeparables(
out_chs_left, out_chs_left, kernel_size=3, stride=stride, padding=padding)
if is_reduction:
self.comb_iter_4_right = ReluConvBn(
out_channels_right, out_channels_right, kernel_size=1, stride=stride)
self.comb_iter_4_right = ActConvBn(
out_chs_right, out_chs_right, kernel_size=1, stride=stride, padding=padding)
else:
self.comb_iter_4_right = None
@ -294,52 +236,53 @@ class Cell(CellBase):
class PNASNet5Large(nn.Module):
def __init__(self, num_classes=1001, in_chans=3, drop_rate=0.5, global_pool='avg'):
def __init__(self, num_classes=1001, in_chans=3, drop_rate=0.5, global_pool='avg', padding=''):
super(PNASNet5Large, self).__init__()
self.num_classes = num_classes
self.num_features = 4320
self.drop_rate = drop_rate
self.conv_0 = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(in_chans, 96, kernel_size=3, stride=2, bias=False)),
('bn', nn.BatchNorm2d(96, eps=0.001))
]))
self.cell_stem_0 = CellStem0(in_channels_left=96, out_channels_left=54,
in_channels_right=96,
out_channels_right=54)
self.cell_stem_1 = Cell(in_channels_left=96, out_channels_left=108,
in_channels_right=270, out_channels_right=108,
match_prev_layer_dimensions=True,
self.conv_0 = ConvBnAct(
in_chans, 96, kernel_size=3, stride=2, padding=0,
norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None)
self.cell_stem_0 = CellStem0(
in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, padding=padding)
self.cell_stem_1 = Cell(
in_chs_left=96, out_chs_left=108, in_chs_right=270, out_chs_right=108, padding=padding,
match_prev_layer_dims=True, is_reduction=True)
self.cell_0 = Cell(
in_chs_left=270, out_chs_left=216, in_chs_right=540, out_chs_right=216, padding=padding,
match_prev_layer_dims=True)
self.cell_1 = Cell(
in_chs_left=540, out_chs_left=216, in_chs_right=1080, out_chs_right=216, padding=padding)
self.cell_2 = Cell(
in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, padding=padding)
self.cell_3 = Cell(
in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, padding=padding)
self.cell_4 = Cell(
in_chs_left=1080, out_chs_left=432, in_chs_right=1080, out_chs_right=432, padding=padding,
is_reduction=True)
self.cell_0 = Cell(in_channels_left=270, out_channels_left=216,
in_channels_right=540, out_channels_right=216,
match_prev_layer_dimensions=True)
self.cell_1 = Cell(in_channels_left=540, out_channels_left=216,
in_channels_right=1080, out_channels_right=216)
self.cell_2 = Cell(in_channels_left=1080, out_channels_left=216,
in_channels_right=1080, out_channels_right=216)
self.cell_3 = Cell(in_channels_left=1080, out_channels_left=216,
in_channels_right=1080, out_channels_right=216)
self.cell_4 = Cell(in_channels_left=1080, out_channels_left=432,
in_channels_right=1080, out_channels_right=432,
is_reduction=True, zero_pad=True)
self.cell_5 = Cell(in_channels_left=1080, out_channels_left=432,
in_channels_right=2160, out_channels_right=432,
match_prev_layer_dimensions=True)
self.cell_6 = Cell(in_channels_left=2160, out_channels_left=432,
in_channels_right=2160, out_channels_right=432)
self.cell_7 = Cell(in_channels_left=2160, out_channels_left=432,
in_channels_right=2160, out_channels_right=432)
self.cell_8 = Cell(in_channels_left=2160, out_channels_left=864,
in_channels_right=2160, out_channels_right=864,
self.cell_5 = Cell(
in_chs_left=1080, out_chs_left=432, in_chs_right=2160, out_chs_right=432, padding=padding,
match_prev_layer_dims=True)
self.cell_6 = Cell(
in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, padding=padding)
self.cell_7 = Cell(
in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, padding=padding)
self.cell_8 = Cell(
in_chs_left=2160, out_chs_left=864, in_chs_right=2160, out_chs_right=864, padding=padding,
is_reduction=True)
self.cell_9 = Cell(in_channels_left=2160, out_channels_left=864,
in_channels_right=4320, out_channels_right=864,
match_prev_layer_dimensions=True)
self.cell_10 = Cell(in_channels_left=4320, out_channels_left=864,
in_channels_right=4320, out_channels_right=864)
self.cell_11 = Cell(in_channels_left=4320, out_channels_left=864,
in_channels_right=4320, out_channels_right=864)
self.cell_9 = Cell(
in_chs_left=2160, out_chs_left=864, in_chs_right=4320, out_chs_right=864, padding=padding,
match_prev_layer_dims=True)
self.cell_10 = Cell(
in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, padding=padding)
self.cell_11 = Cell(
in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, padding=padding)
self.relu = nn.ReLU()
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@ -391,7 +334,7 @@ def pnasnet5large(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
<https://arxiv.org/abs/1712.00559>`_ paper.
"""
default_cfg = default_cfgs['pnasnet5large']
model = PNASNet5Large(num_classes=num_classes, in_chans=in_chans, **kwargs)
model = PNASNet5Large(num_classes=num_classes, in_chans=in_chans, padding='same', **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)

@ -10,7 +10,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .registry import register_model
from .resnet import ResNet
from .resnet import _create_resnet_with_cfg
__all__ = []
@ -132,113 +132,83 @@ class Bottle2neck(nn.Module):
return out
def _create_res2net(variant, pretrained=False, **kwargs):
default_cfg = default_cfgs[variant]
return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs)
@register_model
def res2net50_26w_4s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Res2Net-50_26w_4s model.
def res2net50_26w_4s(pretrained=False, **kwargs):
"""Constructs a Res2Net-50 26w4s model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
default_cfg = default_cfgs['res2net50_26w_4s']
res2net_block_args = dict(scale=4)
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=26,
num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4), **kwargs)
return _create_res2net('res2net50_26w_4s', pretrained, **model_args)
@register_model
def res2net101_26w_4s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Res2Net-50_26w_4s model.
def res2net101_26w_4s(pretrained=False, **kwargs):
"""Constructs a Res2Net-101 26w4s model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
default_cfg = default_cfgs['res2net101_26w_4s']
res2net_block_args = dict(scale=4)
model = ResNet(Bottle2neck, [3, 4, 23, 3], base_width=26,
num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4), **kwargs)
return _create_res2net('res2net101_26w_4s', pretrained, **model_args)
@register_model
def res2net50_26w_6s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Res2Net-50_26w_4s model.
def res2net50_26w_6s(pretrained=False, **kwargs):
"""Constructs a Res2Net-50 26w6s model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
default_cfg = default_cfgs['res2net50_26w_6s']
res2net_block_args = dict(scale=6)
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=26,
num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6), **kwargs)
return _create_res2net('res2net50_26w_6s', pretrained, **model_args)
@register_model
def res2net50_26w_8s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Res2Net-50_26w_4s model.
def res2net50_26w_8s(pretrained=False, **kwargs):
"""Constructs a Res2Net-50 26w8s model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
default_cfg = default_cfgs['res2net50_26w_8s']
res2net_block_args = dict(scale=8)
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=26,
num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8), **kwargs)
return _create_res2net('res2net50_26w_8s', pretrained, **model_args)
@register_model
def res2net50_48w_2s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Res2Net-50_48w_2s model.
def res2net50_48w_2s(pretrained=False, **kwargs):
"""Constructs a Res2Net-50 48w2s model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
default_cfg = default_cfgs['res2net50_48w_2s']
res2net_block_args = dict(scale=2)
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=48,
num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2), **kwargs)
return _create_res2net('res2net50_26w_8s', pretrained, **model_args)
@register_model
def res2net50_14w_8s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Res2Net-50_14w_8s model.
def res2net50_14w_8s(pretrained=False, **kwargs):
"""Constructs a Res2Net-50 14w8s model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
default_cfg = default_cfgs['res2net50_14w_8s']
res2net_block_args = dict(scale=8)
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=14, num_classes=num_classes, in_chans=in_chans,
block_args=res2net_block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8), **kwargs)
return _create_res2net('res2net50_26w_8s', pretrained, **model_args)
@register_model
def res2next50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def res2next50(pretrained=False, **kwargs):
"""Construct Res2NeXt-50 4s
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
default_cfg = default_cfgs['res2next50']
res2net_block_args = dict(scale=4)
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=4, cardinality=8,
num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4), **kwargs)
return _create_res2net('res2next50', pretrained, **model_args)

@ -6,18 +6,14 @@ Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang198
Modified for torchscript compat, and consistency with timm by Ross Wightman
"""
import math
import torch
import torch.nn.functional as F
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropBlock2d
from .helpers import load_pretrained
from .layers import SelectiveKernelConv, ConvBnAct, create_attn
from .layers.split_attn import SplitAttnConv2d
from .registry import register_model
from .resnet import ResNet
from .resnet import _create_resnet_with_cfg
def _cfg(url='', **kwargs):
@ -143,125 +139,98 @@ class ResNestBottleneck(nn.Module):
return out
def _create_resnest(variant, pretrained=False, **kwargs):
default_cfg = default_cfgs[variant]
return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs)
@register_model
def resnest14d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnest14d(pretrained=False, **kwargs):
""" ResNeSt-14d model. Weights ported from GluonCV.
"""
default_cfg = default_cfgs['resnest14d']
model = ResNet(
ResNestBottleneck, [1, 1, 1, 1], num_classes=num_classes, in_chans=in_chans,
model_kwargs = dict(
block=ResNestBottleneck, layers=[1, 1, 1, 1],
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _create_resnest('resnest14d', pretrained=pretrained, **model_kwargs)
@register_model
def resnest26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnest26d(pretrained=False, **kwargs):
""" ResNeSt-26d model. Weights ported from GluonCV.
"""
default_cfg = default_cfgs['resnest26d']
model = ResNet(
ResNestBottleneck, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
model_kwargs = dict(
block=ResNestBottleneck, layers=[2, 2, 2, 2],
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _create_resnest('resnest26d', pretrained=pretrained, **model_kwargs)
@register_model
def resnest50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnest50d(pretrained=False, **kwargs):
""" ResNeSt-50d model. Matches paper ResNeSt-50 model, https://arxiv.org/abs/2004.08955
Since this codebase supports all possible variations, 'd' for deep stem, stem_width 32, avg in downsample.
"""
default_cfg = default_cfgs['resnest50d']
model = ResNet(
ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 4, 6, 3],
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _create_resnest('resnest50d', pretrained=pretrained, **model_kwargs)
@register_model
def resnest101e(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnest101e(pretrained=False, **kwargs):
""" ResNeSt-101e model. Matches paper ResNeSt-101 model, https://arxiv.org/abs/2004.08955
Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
"""
default_cfg = default_cfgs['resnest101e']
model = ResNet(
ResNestBottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans,
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 4, 23, 3],
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _create_resnest('resnest101e', pretrained=pretrained, **model_kwargs)
@register_model
def resnest200e(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnest200e(pretrained=False, **kwargs):
""" ResNeSt-200e model. Matches paper ResNeSt-200 model, https://arxiv.org/abs/2004.08955
Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
"""
default_cfg = default_cfgs['resnest200e']
model = ResNet(
ResNestBottleneck, [3, 24, 36, 3], num_classes=num_classes, in_chans=in_chans,
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 24, 36, 3],
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _create_resnest('resnest200e', pretrained=pretrained, **model_kwargs)
@register_model
def resnest269e(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnest269e(pretrained=False, **kwargs):
""" ResNeSt-269e model. Matches paper ResNeSt-269 model, https://arxiv.org/abs/2004.08955
Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
"""
default_cfg = default_cfgs['resnest269e']
model = ResNet(
ResNestBottleneck, [3, 30, 48, 8], num_classes=num_classes, in_chans=in_chans,
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 30, 48, 8],
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _create_resnest('resnest269e', pretrained=pretrained, **model_kwargs)
@register_model
def resnest50d_4s2x40d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnest50d_4s2x40d(pretrained=False, **kwargs):
"""ResNeSt-50 4s2x40d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md
"""
default_cfg = default_cfgs['resnest50d_4s2x40d']
model = ResNet(
ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 4, 6, 3],
stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2,
block_args=dict(radix=4, avd=True, avd_first=True), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **model_kwargs)
@register_model
def resnest50d_1s4x24d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnest50d_1s4x24d(pretrained=False, **kwargs):
"""ResNeSt-50 1s4x24d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md
"""
default_cfg = default_cfgs['resnest50d_1s4x24d']
model = ResNet(
ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 4, 6, 3],
stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4,
block_args=dict(radix=1, avd=True, avd_first=True), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **model_kwargs)

@ -6,11 +6,14 @@ additional dropout and dynamic global avg/max pool.
ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman
"""
import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .features import FeatureNet
from .helpers import load_pretrained, adapt_model_from_file
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d
from .registry import register_model
@ -390,6 +393,7 @@ class ResNet(nn.Module):
self.base_width = base_width
self.drop_rate = drop_rate
self.expansion = block.expansion
self.feature_info = [dict(num_chs=self.inplanes, reduction=2, module='act1')]
super(ResNet, self).__init__()
# Stem
@ -420,9 +424,6 @@ class ResNet(nn.Module):
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# Feature Blocks
dp = DropPath(drop_path_rate) if drop_path_rate else None
db_3 = DropBlock2d(drop_block_rate, 7, 0.25) if drop_block_rate else None
db_4 = DropBlock2d(drop_block_rate, 7, 1.00) if drop_block_rate else None
channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4
if output_stride == 16:
strides[3] = 1
@ -432,14 +433,23 @@ class ResNet(nn.Module):
dilations[2:4] = [2, 4]
else:
assert output_stride == 32
dp = DropPath(drop_path_rate) if drop_path_rate else None
db = [
None, None,
DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None,
DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None]
layer_args = list(zip(channels, layers, strides, dilations))
layer_kwargs = dict(
reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
self.layer1 = self._make_layer(block, *layer_args[0], **layer_kwargs)
self.layer2 = self._make_layer(block, *layer_args[1], **layer_kwargs)
self.layer3 = self._make_layer(block, drop_block=db_3, *layer_args[2], **layer_kwargs)
self.layer4 = self._make_layer(block, drop_block=db_4, *layer_args[3], **layer_kwargs)
current_stride = 4
for i in range(4):
layer_name = f'layer{i + 1}'
self.add_module(layer_name, self._make_layer(
block, *layer_args[i], drop_block=db[i], **layer_kwargs))
current_stride *= strides[i]
self.feature_info.append(dict(
num_chs=self.inplanes, reduction=current_stride, module=layer_name))
# Head (Pooling and Classifier)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@ -509,245 +519,185 @@ class ResNet(nn.Module):
return x
def _create_resnet_with_cfg(variant, default_cfg, pretrained=False, **kwargs):
assert isinstance(default_cfg, dict)
load_strict, features = True, False
out_indices = None
if kwargs.pop('features_only', False):
load_strict, features = False, True
kwargs.pop('num_classes', 0)
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
model = ResNet(**kwargs)
model.default_cfg = copy.deepcopy(default_cfg)
if kwargs.pop('pruned', False):
model = adapt_model_from_file(model, variant)
if pretrained:
load_pretrained(
model,
num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict)
if features:
model = FeatureNet(model, out_indices=out_indices)
return model
def _create_resnet(variant, pretrained=False, **kwargs):
default_cfg = default_cfgs[variant]
return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs)
@register_model
def resnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnet18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model.
"""
default_cfg = default_cfgs['resnet18']
model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
return _create_resnet('resnet18', pretrained, **model_args)
@register_model
def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnet34(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model.
"""
default_cfg = default_cfgs['resnet34']
model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
return _create_resnet('resnet34', pretrained, **model_args)
@register_model
def resnet26(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnet26(pretrained=False, **kwargs):
"""Constructs a ResNet-26 model.
"""
default_cfg = default_cfgs['resnet26']
model = ResNet(Bottleneck, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], **kwargs)
return _create_resnet('resnet26', pretrained, **model_args)
@register_model
def resnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnet26d(pretrained=False, **kwargs):
"""Constructs a ResNet-26 v1d model.
This is technically a 28 layer ResNet, sticking with 'd' modifier from Gluon for now.
"""
default_cfg = default_cfgs['resnet26d']
model = ResNet(
Bottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnet26d', pretrained, **model_args)
@register_model
def resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnet50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model.
"""
default_cfg = default_cfgs['resnet50']
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
return _create_resnet('resnet50', pretrained, **model_args)
@register_model
def resnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnet50d(pretrained=False, **kwargs):
"""Constructs a ResNet-50-D model.
"""
default_cfg = default_cfgs['resnet50d']
model = ResNet(
Bottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnet50d', pretrained, **model_args)
@register_model
def resnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnet101(pretrained=False, **kwargs):
"""Constructs a ResNet-101 model.
"""
default_cfg = default_cfgs['resnet101']
model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs)
return _create_resnet('resnet101', pretrained, **model_args)
@register_model
def resnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnet152(pretrained=False, **kwargs):
"""Constructs a ResNet-152 model.
"""
default_cfg = default_cfgs['resnet152']
model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs)
return _create_resnet('resnet152', pretrained, **model_args)
@register_model
def tv_resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def tv_resnet34(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model with original Torchvision weights.
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfgs['tv_resnet34']
if pretrained:
load_pretrained(model, model.default_cfg, num_classes, in_chans)
return model
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
return _create_resnet('tv_resnet34', pretrained, **model_args)
@register_model
def tv_resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def tv_resnet50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model with original Torchvision weights.
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfgs['tv_resnet50']
if pretrained:
load_pretrained(model, model.default_cfg, num_classes, in_chans)
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
return _create_resnet('tv_resnet50', pretrained, **model_args)
@register_model
def wide_resnet50_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def wide_resnet50_2(pretrained=False, **kwargs):
"""Constructs a Wide ResNet-50-2 model.
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
"""
model = ResNet(
Bottleneck, [3, 4, 6, 3], base_width=128,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfgs['wide_resnet50_2']
if pretrained:
load_pretrained(model, model.default_cfg, num_classes, in_chans)
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128, **kwargs)
return _create_resnet('wide_resnet50_2', pretrained, **model_args)
@register_model
def wide_resnet101_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def wide_resnet101_2(pretrained=False, **kwargs):
"""Constructs a Wide ResNet-101-2 model.
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same.
"""
model = ResNet(
Bottleneck, [3, 4, 23, 3], base_width=128,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfgs['wide_resnet101_2']
if pretrained:
load_pretrained(model, model.default_cfg, num_classes, in_chans)
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], base_width=128, **kwargs)
return _create_resnet('wide_resnet101_2', pretrained, **model_args)
@register_model
def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnext50_32x4d(pretrained=False, **kwargs):
"""Constructs a ResNeXt50-32x4d model.
"""
default_cfg = default_cfgs['resnext50_32x4d']
model = ResNet(
Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
return _create_resnet('resnext50_32x4d', pretrained, **model_args)
@register_model
def resnext50d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnext50d_32x4d(pretrained=False, **kwargs):
"""Constructs a ResNeXt50d-32x4d model. ResNext50 w/ deep stem & avg pool downsample
"""
default_cfg = default_cfgs['resnext50d_32x4d']
model = ResNet(
Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnext50d_32x4d', pretrained, **model_args)
@register_model
def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnext101_32x4d(pretrained=False, **kwargs):
"""Constructs a ResNeXt-101 32x4d model.
"""
default_cfg = default_cfgs['resnext101_32x4d']
model = ResNet(
Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
return _create_resnet('resnext101_32x4d', pretrained, **model_args)
@register_model
def resnext101_32x8d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnext101_32x8d(pretrained=False, **kwargs):
"""Constructs a ResNeXt-101 32x8d model.
"""
default_cfg = default_cfgs['resnext101_32x8d']
model = ResNet(
Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
return _create_resnet('resnext101_32x8d', pretrained, **model_args)
@register_model
def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnext101_64x4d(pretrained=False, **kwargs):
"""Constructs a ResNeXt101-64x4d model.
"""
default_cfg = default_cfgs['resnext101_32x4d']
model = ResNet(
Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, **kwargs)
return _create_resnet('resnext101_64x4d', pretrained, **model_args)
@register_model
def tv_resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def tv_resnext50_32x4d(pretrained=False, **kwargs):
"""Constructs a ResNeXt50-32x4d model with original Torchvision weights.
"""
default_cfg = default_cfgs['tv_resnext50_32x4d']
model = ResNet(
Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
return _create_resnet('tv_resnext50_32x4d', pretrained, **model_args)
@register_model
@ -757,11 +707,8 @@ def ig_resnext101_32x8d(pretrained=True, **kwargs):
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
model.default_cfg = default_cfgs['ig_resnext101_32x8d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
return _create_resnet('ig_resnext101_32x8d', pretrained, **model_args)
@register_model
@ -771,11 +718,8 @@ def ig_resnext101_32x16d(pretrained=True, **kwargs):
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
model.default_cfg = default_cfgs['ig_resnext101_32x16d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
return _create_resnet('ig_resnext101_32x16d', pretrained, **model_args)
@register_model
@ -785,11 +729,8 @@ def ig_resnext101_32x32d(pretrained=True, **kwargs):
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=32, **kwargs)
model.default_cfg = default_cfgs['ig_resnext101_32x32d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32, **kwargs)
return _create_resnet('ig_resnext101_32x32d', pretrained, **model_args)
@register_model
@ -799,11 +740,8 @@ def ig_resnext101_32x48d(pretrained=True, **kwargs):
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=48, **kwargs)
model.default_cfg = default_cfgs['ig_resnext101_32x48d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=48, **kwargs)
return _create_resnet('ig_resnext101_32x48d', pretrained, **model_args)
@register_model
@ -812,11 +750,8 @@ def ssl_resnet18(pretrained=True, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
model.default_cfg = default_cfgs['ssl_resnet18']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
return _create_resnet('ssl_resnet18', pretrained, **model_args)
@register_model
@ -825,11 +760,8 @@ def ssl_resnet50(pretrained=True, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
model.default_cfg = default_cfgs['ssl_resnet50']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
return _create_resnet('ssl_resnet50', pretrained, **model_args)
@register_model
@ -838,11 +770,8 @@ def ssl_resnext50_32x4d(pretrained=True, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
model.default_cfg = default_cfgs['ssl_resnext50_32x4d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
return _create_resnet('ssl_resnext50_32x4d', pretrained, **model_args)
@register_model
@ -851,11 +780,8 @@ def ssl_resnext101_32x4d(pretrained=True, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
model.default_cfg = default_cfgs['ssl_resnext101_32x4d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
return _create_resnet('ssl_resnext101_32x4d', pretrained, **model_args)
@register_model
@ -864,11 +790,8 @@ def ssl_resnext101_32x8d(pretrained=True, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
model.default_cfg = default_cfgs['ssl_resnext101_32x8d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
return _create_resnet('ssl_resnext101_32x8d', pretrained, **model_args)
@register_model
@ -877,11 +800,8 @@ def ssl_resnext101_32x16d(pretrained=True, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
model.default_cfg = default_cfgs['ssl_resnext101_32x16d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
return _create_resnet('ssl_resnext101_32x16d', pretrained, **model_args)
@register_model
@ -891,11 +811,8 @@ def swsl_resnet18(pretrained=True, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
model.default_cfg = default_cfgs['swsl_resnet18']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
return _create_resnet('swsl_resnet18', pretrained, **model_args)
@register_model
@ -905,11 +822,8 @@ def swsl_resnet50(pretrained=True, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
model.default_cfg = default_cfgs['swsl_resnet50']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
return _create_resnet('swsl_resnet50', pretrained, **model_args)
@register_model
@ -919,11 +833,8 @@ def swsl_resnext50_32x4d(pretrained=True, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
model.default_cfg = default_cfgs['swsl_resnext50_32x4d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
return _create_resnet('swsl_resnext50_32x4d', pretrained, **model_args)
@register_model
@ -933,11 +844,8 @@ def swsl_resnext101_32x4d(pretrained=True, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
model.default_cfg = default_cfgs['swsl_resnext101_32x4d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
return _create_resnet('swsl_resnext101_32x4d', pretrained, **model_args)
@register_model
@ -947,11 +855,8 @@ def swsl_resnext101_32x8d(pretrained=True, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
model.default_cfg = default_cfgs['swsl_resnext101_32x8d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
return _create_resnet('swsl_resnext101_32x8d', pretrained, **model_args)
@register_model
@ -961,61 +866,44 @@ def swsl_resnext101_32x16d(pretrained=True, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
model.default_cfg = default_cfgs['swsl_resnext101_32x16d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
return _create_resnet('swsl_resnext101_32x16d', pretrained, **model_args)
@register_model
def seresnext26d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def seresnext26d_32x4d(pretrained=False, **kwargs):
"""Constructs a SE-ResNeXt-26-D model.
This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for
combination of deep stem and avg_pool in downsample.
"""
default_cfg = default_cfgs['seresnext26d_32x4d']
model = ResNet(
Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnext26d_32x4d', pretrained, **model_args)
@register_model
def seresnext26t_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def seresnext26t_32x4d(pretrained=False, **kwargs):
"""Constructs a SE-ResNet-26-T model.
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 48, 64 channels
in the deep stem.
"""
default_cfg = default_cfgs['seresnext26t_32x4d']
model = ResNet(
Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
stem_width=32, stem_type='deep_tiered', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnext26t_32x4d', pretrained, **model_args)
@register_model
def seresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def seresnext26tn_32x4d(pretrained=False, **kwargs):
"""Constructs a SE-ResNeXt-26-TN model.
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
"""
default_cfg = default_cfgs['seresnext26tn_32x4d']
model = ResNet(
Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
stem_width=32, stem_type='deep_tiered_narrow', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnext26tn_32x4d', pretrained, **model_args)
@register_model
@ -1025,145 +913,91 @@ def ecaresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwarg
in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
this model replaces SE module with the ECA module
"""
default_cfg = default_cfgs['ecaresnext26tn_32x4d']
block_args = dict(attn_layer='eca')
model = ResNet(
Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
stem_width=32, stem_type='deep_tiered_narrow', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnext26tn_32x4d', pretrained, **model_args)
@register_model
def ecaresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def ecaresnet18(pretrained=False, **kwargs):
""" Constructs an ECA-ResNet-18 model.
"""
default_cfg = default_cfgs['ecaresnet18']
block_args = dict(attn_layer='eca')
model = ResNet(
BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnet18', pretrained, **model_args)
@register_model
def ecaresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def ecaresnet50(pretrained=False, **kwargs):
"""Constructs an ECA-ResNet-50 model.
"""
default_cfg = default_cfgs['ecaresnet50']
block_args = dict(attn_layer='eca')
model = ResNet(
Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnet50', pretrained, **model_args)
@register_model
def ecaresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def ecaresnet50d(pretrained=False, **kwargs):
"""Constructs a ResNet-50-D model with eca.
"""
default_cfg = default_cfgs['ecaresnet50d']
model = ResNet(
Bottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnet50d', pretrained, **model_args)
@register_model
def ecaresnet50d_pruned(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def ecaresnet50d_pruned(pretrained=False, **kwargs):
"""Constructs a ResNet-50-D model pruned with eca.
The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
"""
variant = 'ecaresnet50d_pruned'
default_cfg = default_cfgs[variant]
model = ResNet(
Bottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
model.default_cfg = default_cfg
model = adapt_model_from_file(model, variant)
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **model_args)
@register_model
def ecaresnetlight(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def ecaresnetlight(pretrained=False, **kwargs):
"""Constructs a ResNet-50-D light model with eca.
"""
default_cfg = default_cfgs['ecaresnetlight']
model = ResNet(
Bottleneck, [1, 1, 11, 3], stem_width=32, avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=Bottleneck, layers=[1, 1, 11, 3], stem_width=32, avg_down=True,
block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnetlight', pretrained, **model_args)
@register_model
def ecaresnet101d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def ecaresnet101d(pretrained=False, **kwargs):
"""Constructs a ResNet-101-D model with eca.
"""
default_cfg = default_cfgs['ecaresnet101d']
model = ResNet(
Bottleneck, [3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnet101d', pretrained, **model_args)
@register_model
def ecaresnet101d_pruned(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def ecaresnet101d_pruned(pretrained=False, **kwargs):
"""Constructs a ResNet-101-D model pruned with eca.
The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
"""
variant = 'ecaresnet101d_pruned'
default_cfg = default_cfgs[variant]
model = ResNet(
Bottleneck, [3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
model.default_cfg = default_cfg
model = adapt_model_from_file(model, variant)
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **model_args)
@register_model
def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnetblur18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model with blur anti-aliasing
"""
default_cfg = default_cfgs['resnetblur18']
model = ResNet(
BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d, **kwargs)
return _create_resnet('resnetblur18', pretrained, **model_args)
@register_model
def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnetblur50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model with blur anti-aliasing
"""
default_cfg = default_cfgs['resnetblur50']
model = ResNet(
Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs)
return _create_resnet('resnetblur50', pretrained, **model_args)

@ -16,6 +16,7 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .features import FeatureNet
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from .registry import register_model
@ -100,7 +101,8 @@ class SelecSLSBlock(nn.Module):
self.conv6 = conv_bn(2 * mid_chs + (0 if is_first else skip_chs), out_chs, 1)
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
assert isinstance(x, list)
if not isinstance(x, list):
x = [x]
assert len(x) in [1, 2]
d1 = self.conv1(x[0])
@ -163,7 +165,7 @@ class SelecSLS(nn.Module):
def forward_features(self, x):
x = self.stem(x)
x = self.features([x])
x = self.features(x)
x = self.head(x[0])
return x
@ -178,6 +180,7 @@ class SelecSLS(nn.Module):
def _create_model(variant, pretrained, model_kwargs):
cfg = {}
feature_info = [dict(num_chs=32, reduction=2, module='stem.2')]
if variant.startswith('selecsls42'):
cfg['block'] = SelecSLSBlock
# Define configuration of the network after the initial neck
@ -190,7 +193,13 @@ def _create_model(variant, pretrained, model_kwargs):
(288, 0, 304, 304, True, 2),
(304, 304, 304, 480, False, 1),
]
feature_info.extend([
dict(num_chs=128, reduction=4, module='features.1'),
dict(num_chs=288, reduction=8, module='features.3'),
dict(num_chs=480, reduction=16, module='features.5'),
])
# Head can be replaced with alternative configurations depending on the problem
feature_info.append(dict(num_chs=1024, reduction=32, module='head.1'))
if variant == 'selecsls42b':
cfg['head'] = [
(480, 960, 3, 2),
@ -198,6 +207,7 @@ def _create_model(variant, pretrained, model_kwargs):
(1024, 1280, 3, 2),
(1280, 1024, 1, 1),
]
feature_info.append(dict(num_chs=1024, reduction=64, module='head.3'))
cfg['num_features'] = 1024
else:
cfg['head'] = [
@ -206,7 +216,9 @@ def _create_model(variant, pretrained, model_kwargs):
(1024, 1024, 3, 2),
(1024, 1280, 1, 1),
]
feature_info.append(dict(num_chs=1280, reduction=64, module='head.3'))
cfg['num_features'] = 1280
elif variant.startswith('selecsls60'):
cfg['block'] = SelecSLSBlock
# Define configuration of the network after the initial neck
@ -222,7 +234,13 @@ def _create_model(variant, pretrained, model_kwargs):
(288, 288, 288, 288, False, 1),
(288, 288, 288, 416, False, 1),
]
feature_info.extend([
dict(num_chs=128, reduction=4, module='features.1'),
dict(num_chs=288, reduction=8, module='features.4'),
dict(num_chs=416, reduction=16, module='features.8'),
])
# Head can be replaced with alternative configurations depending on the problem
feature_info.append(dict(num_chs=1024, reduction=32, module='head.1'))
if variant == 'selecsls60b':
cfg['head'] = [
(416, 756, 3, 2),
@ -230,6 +248,7 @@ def _create_model(variant, pretrained, model_kwargs):
(1024, 1280, 3, 2),
(1280, 1024, 1, 1),
]
feature_info.append(dict(num_chs=1024, reduction=64, module='head.3'))
cfg['num_features'] = 1024
else:
cfg['head'] = [
@ -238,7 +257,9 @@ def _create_model(variant, pretrained, model_kwargs):
(1024, 1024, 3, 2),
(1024, 1280, 1, 1),
]
feature_info.append(dict(num_chs=1280, reduction=64, module='head.3'))
cfg['num_features'] = 1280
elif variant == 'selecsls84':
cfg['block'] = SelecSLSBlock
# Define configuration of the network after the initial neck
@ -258,6 +279,11 @@ def _create_model(variant, pretrained, model_kwargs):
(304, 304, 304, 304, False, 1),
(304, 304, 304, 512, False, 1),
]
feature_info.extend([
dict(num_chs=144, reduction=4, module='features.1'),
dict(num_chs=304, reduction=8, module='features.6'),
dict(num_chs=512, reduction=16, module='features.12'),
])
# Head can be replaced with alternative configurations depending on the problem
cfg['head'] = [
(512, 960, 3, 2),
@ -266,17 +292,35 @@ def _create_model(variant, pretrained, model_kwargs):
(1024, 1280, 3, 1),
]
cfg['num_features'] = 1280
feature_info.extend([
dict(num_chs=1024, reduction=32, module='head.1'),
dict(num_chs=1280, reduction=64, module='head.3')
])
else:
raise ValueError('Invalid net configuration ' + variant + ' !!!')
load_strict = True
features = False
out_indices = None
if model_kwargs.pop('features_only', False):
load_strict = False
features = True
# this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises?
out_indices = model_kwargs.pop('out_indices', (0, 1, 2, 3, 4))
model_kwargs.pop('num_classes', 0)
model = SelecSLS(cfg, **model_kwargs)
model.default_cfg = default_cfgs[variant]
model.feature_info = feature_info
if pretrained:
load_pretrained(
model,
num_classes=model_kwargs.get('num_classes', 0),
in_chans=model_kwargs.get('in_chans', 3),
strict=True)
strict=load_strict)
if features:
model = FeatureNet(model, out_indices, flatten_sequential=True)
return model

@ -12,11 +12,11 @@ import math
from torch import nn as nn
from .registry import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .layers import SelectiveKernelConv, ConvBnAct, create_attn
from .resnet import ResNet
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .registry import register_model
from .resnet import _create_resnet_with_cfg
def _cfg(url='', **kwargs):
@ -138,101 +138,80 @@ class SelectiveKernelBottleneck(nn.Module):
return x
def _create_skresnet(variant, pretrained=False, **kwargs):
default_cfg = default_cfgs[variant]
return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs)
@register_model
def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def skresnet18(pretrained=False, **kwargs):
"""Constructs a Selective Kernel ResNet-18 model.
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
variation splits the input channels to the selective convolutions to keep param count down.
"""
default_cfg = default_cfgs['skresnet18']
sk_kwargs = dict(
min_attn_channels=16,
attn_reduction=8,
split_input=True
)
model = ResNet(
SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
split_input=True)
model_args = dict(
block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs),
zero_init_last_bn=False, **kwargs)
return _create_skresnet('skresnet18', pretrained, **model_args)
@register_model
def skresnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def skresnet34(pretrained=False, **kwargs):
"""Constructs a Selective Kernel ResNet-34 model.
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
variation splits the input channels to the selective convolutions to keep param count down.
"""
default_cfg = default_cfgs['skresnet34']
sk_kwargs = dict(
min_attn_channels=16,
attn_reduction=8,
split_input=True
)
model = ResNet(
SelectiveKernelBasic, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
split_input=True)
model_args = dict(
block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
zero_init_last_bn=False, **kwargs)
return _create_skresnet('skresnet34', pretrained, **model_args)
@register_model
def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def skresnet50(pretrained=False, **kwargs):
"""Constructs a Select Kernel ResNet-50 model.
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
variation splits the input channels to the selective convolutions to keep param count down.
"""
sk_kwargs = dict(
split_input=True,
)
default_cfg = default_cfgs['skresnet50']
model = ResNet(
SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
sk_kwargs = dict(split_input=True)
model_args = dict(
block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
zero_init_last_bn=False, **kwargs)
return _create_skresnet('skresnet50', pretrained, **model_args)
@register_model
def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def skresnet50d(pretrained=False, **kwargs):
"""Constructs a Select Kernel ResNet-50-D model.
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
variation splits the input channels to the selective convolutions to keep param count down.
"""
sk_kwargs = dict(
split_input=True,
)
default_cfg = default_cfgs['skresnet50d']
model = ResNet(
SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs),
zero_init_last_bn=False, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
sk_kwargs = dict(split_input=True)
model_args = dict(
block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
return _create_skresnet('skresnet50d', pretrained, **model_args)
@register_model
def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def skresnext50_32x4d(pretrained=False, **kwargs):
"""Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to
the SKNet-50 model in the Select Kernel Paper
"""
default_cfg = default_cfgs['skresnext50_32x4d']
model = ResNet(
SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, zero_init_last_bn=False, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
zero_init_last_bn=False, **kwargs)
return _create_skresnet('skresnext50_32x4d', pretrained, **model_args)

@ -20,6 +20,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .registry import register_model
from .helpers import load_pretrained
from .features import FeatureNet
from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, SelectAdaptivePool2d, \
create_attn, create_norm_act, get_norm_act_layer
@ -296,6 +297,9 @@ class VovNet(nn.Module):
conv_type(stem_chs[0], stem_chs[1], 3, stride=1, norm_layer=norm_layer),
conv_type(stem_chs[1], stem_chs[2], 3, stride=last_stem_stride, norm_layer=norm_layer),
])
self.feature_info = [dict(
num_chs=stem_chs[1], reduction=2, module=f'stem.{1 if stem_stride == 4 else 2}')]
current_stride = stem_stride
# OSA stages
in_ch_list = stem_chs[-1:] + stage_out_chs[:-1]
@ -309,6 +313,9 @@ class VovNet(nn.Module):
downsample=downsample, **stage_args)
]
self.num_features = stage_out_chs[i]
current_stride *= 2 if downsample else 1
self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
@ -338,24 +345,24 @@ class VovNet(nn.Module):
def _vovnet(variant, pretrained=False, **kwargs):
load_strict = True
model_class = VovNet
features = False
out_indices = None
if kwargs.pop('features_only', False):
assert False, 'Not Implemented' # TODO
load_strict = False
features = True
kwargs.pop('num_classes', 0)
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
model_cfg = model_cfgs[variant]
default_cfg = default_cfgs[variant]
model = model_class(model_cfg, **kwargs)
model.default_cfg = default_cfg
model = VovNet(model_cfg, **kwargs)
model.default_cfg = default_cfgs[variant]
if pretrained:
load_pretrained(
model, default_cfg,
num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict)
model,
num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features)
if features:
model = FeatureNet(model, out_indices, flatten_sequential=True)
return model
@register_model
def vovnet39a(pretrained=False, **kwargs):
return _vovnet('vovnet39a', pretrained=pretrained, **kwargs)

@ -26,6 +26,7 @@ import torch.nn as nn
import torch.nn.functional as F
from .helpers import load_pretrained
from .features import FeatureNet
from .layers import SelectAdaptivePool2d
from .registry import register_model
@ -49,12 +50,12 @@ default_cfgs = {
class SeparableConv2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1):
super(SeparableConv2d, self).__init__()
self.conv1 = nn.Conv2d(
in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=bias)
self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)
in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=False)
self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=False)
def forward(self, x):
x = self.conv1(x)
@ -63,34 +64,26 @@ class SeparableConv2d(nn.Module):
class Block(nn.Module):
def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
def __init__(self, in_channels, out_channels, reps, strides=1, start_with_relu=True, grow_first=True):
super(Block, self).__init__()
if out_filters != in_filters or strides != 1:
self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
self.skipbn = nn.BatchNorm2d(out_filters)
if out_channels != in_channels or strides != 1:
self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=strides, bias=False)
self.skipbn = nn.BatchNorm2d(out_channels)
else:
self.skip = None
self.relu = nn.ReLU(inplace=True)
rep = []
filters = in_filters
for i in range(reps):
if grow_first:
rep.append(self.relu)
rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False))
rep.append(nn.BatchNorm2d(out_filters))
filters = out_filters
for i in range(reps - 1):
rep.append(self.relu)
rep.append(SeparableConv2d(filters, filters, 3, stride=1, padding=1, bias=False))
rep.append(nn.BatchNorm2d(filters))
if not grow_first:
rep.append(self.relu)
rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False))
rep.append(nn.BatchNorm2d(out_filters))
inc = in_channels if i == 0 else out_channels
outc = out_channels
else:
inc = in_channels
outc = in_channels if i < (reps - 1) else out_channels
rep.append(nn.ReLU(inplace=True))
rep.append(SeparableConv2d(inc, outc, 3, stride=1, padding=1))
rep.append(nn.BatchNorm2d(outc))
if not start_with_relu:
rep = rep[1:]
@ -133,34 +126,35 @@ class Xception(nn.Module):
self.conv1 = nn.Conv2d(in_chans, 32, 3, 2, 0, bias=False)
self.bn1 = nn.BatchNorm2d(32)
self.relu = nn.ReLU(inplace=True)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
self.bn2 = nn.BatchNorm2d(64)
# do relu here
self.act2 = nn.ReLU(inplace=True)
self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)
self.block1 = Block(64, 128, 2, 2, start_with_relu=False)
self.block2 = Block(128, 256, 2, 2)
self.block3 = Block(256, 728, 2, 2)
self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block4 = Block(728, 728, 3, 1)
self.block5 = Block(728, 728, 3, 1)
self.block6 = Block(728, 728, 3, 1)
self.block7 = Block(728, 728, 3, 1)
self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block8 = Block(728, 728, 3, 1)
self.block9 = Block(728, 728, 3, 1)
self.block10 = Block(728, 728, 3, 1)
self.block11 = Block(728, 728, 3, 1)
self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)
self.block12 = Block(728, 1024, 2, 2, grow_first=False)
self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
self.bn3 = nn.BatchNorm2d(1536)
self.act3 = nn.ReLU(inplace=True)
# do relu here
self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1)
self.bn4 = nn.BatchNorm2d(self.num_features)
self.act4 = nn.ReLU(inplace=True)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@ -188,11 +182,11 @@ class Xception(nn.Module):
def forward_features(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.act1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.act2(x)
x = self.block1(x)
x = self.block2(x)
@ -209,11 +203,11 @@ class Xception(nn.Module):
x = self.conv3(x)
x = self.bn3(x)
x = self.relu(x)
x = self.act3(x)
x = self.conv4(x)
x = self.bn4(x)
x = self.relu(x)
x = self.act4(x)
return x
def forward(self, x):
@ -225,12 +219,28 @@ class Xception(nn.Module):
return x
@register_model
def xception(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
default_cfg = default_cfgs['xception']
model = Xception(num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
def _xception(variant, pretrained=False, **kwargs):
load_strict = True
features = False
out_indices = None
if kwargs.pop('features_only', False):
load_strict = False
features = True
kwargs.pop('num_classes', 0)
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
model = Xception(**kwargs)
model.default_cfg = default_cfgs[variant]
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
load_pretrained(
model,
num_classes=kwargs.get('num_classes', 0),
in_chans=kwargs.get('in_chans', 3),
strict=load_strict)
if features:
model = FeatureNet(model, out_indices)
return model
@register_model
def xception(pretrained=False, **kwargs):
return _xception('xception', pretrained=pretrained, **kwargs)

@ -24,9 +24,8 @@ try:
except ImportError:
has_apex = False
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models,\
set_scriptable, set_no_jit
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
torch.backends.cudnn.benchmark = True
@ -76,8 +75,25 @@ parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true',
help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
help='Output csv file for validation results (summary)')
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
help='Real labels JSON file for imagenet evaluation')
def set_jit_legacy():
""" Set JIT executor to legacy w/ support for op fusion
This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
in the JIT exectutor. These API are not supported so could change.
"""
#
assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._jit_override_can_fuse_on_gpu(True)
#torch._C._jit_set_texpr_fuser_enabled(True)
def validate(args):
@ -103,6 +119,8 @@ def validate(args):
model, test_time_pool = apply_test_time_pool(model, data_config, args)
if args.torchscript:
if args.legacy_jit:
set_jit_legacy()
torch.jit.optimized_execution(True)
model = torch.jit.script(model)
@ -116,13 +134,16 @@ def validate(args):
criterion = nn.CrossEntropyLoss().cuda()
#from torchvision.datasets import ImageNet
#dataset = ImageNet(args.data, split='val')
if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
else:
dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
if args.real_labels:
real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
else:
real_labels = None
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
loader = create_loader(
dataset,
@ -148,7 +169,7 @@ def validate(args):
input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
model(input)
end = time.time()
for i, (input, target) in enumerate(loader):
for batch_idx, (input, target) in enumerate(loader):
if args.no_prefetcher:
target = target.cuda()
input = input.cuda()
@ -159,6 +180,9 @@ def validate(args):
output = model(input)
loss = criterion(output, target)
if real_labels is not None:
real_labels.add_result(output)
# measure accuracy and record loss
acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
@ -169,25 +193,35 @@ def validate(args):
batch_time.update(time.time() - end)
end = time.time()
if i % args.log_freq == 0:
if batch_idx % args.log_freq == 0:
logging.info(
'Test: [{0:>4d}/{1}] '
'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
i, len(loader), batch_time=batch_time,
batch_idx, len(loader), batch_time=batch_time,
rate_avg=input.size(0) / batch_time.avg,
loss=losses, top1=top1, top5=top5))
if real_labels is not None:
real_top1 = real_labels.get_accuracy(k=1)
real_top5 = real_labels.get_accuracy(k=5)
results = OrderedDict(
top1=round(real_top1, 4), top1_err=round(100 - real_top1, 4),
top5=round(real_top5, 4), top5_err=round(100 - real_top5, 4),
top1_original=round(top1.avg, 4),
top5_original=round(top5.avg, 4))
else:
results = OrderedDict(
top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4))
results.update(OrderedDict(
param_count=round(param_count / 1e6, 2),
img_size=data_config['input_size'][-1],
cropt_pct=crop_pct,
interpolation=data_config['interpolation'])
interpolation=data_config['interpolation']
))
logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
results['top1'], results['top1_err'], results['top5'], results['top5_err']))

Loading…
Cancel
Save