Merge branch 'attention' into EcaBam

pull/87/head
Chris Ha 6 years ago
commit 26922eb2ce

@ -54,6 +54,8 @@ model_list = [
model_desc='Trained from scratch in PyTorch w/ RandAugment'),
_entry('efficientnet_b3a', 'EfficientNet-B3 (320x320, 1.0 crop)', '1905.11946',
model_desc='Trained from scratch in PyTorch w/ RandAugment'),
_entry('efficientnet_es', 'EfficientNet-EdgeTPU-S', '1905.11946',
model_desc='Trained from scratch in PyTorch w/ RandAugment'),
_entry('fbnetc_100', 'FBNet-C', '1812.03443',
model_desc='Trained in PyTorch with RMSProp, exponential LR decay'),
_entry('gluon_inception_v3', 'Inception V3', '1512.00567', model_desc='Ported from GluonCV Model Zoo'),

@ -16,9 +16,10 @@ from .gluon_xception import *
from .res2net import *
from .dla import *
from .hrnet import *
from .sknet import *
from .registry import *
from .factory import create_model
from .helpers import load_checkpoint, resume_checkpoint
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .split_batchnorm import convert_splitbn_model
from .layers import TestTimePoolHead, apply_test_time_pool
from .layers import convert_splitbn_model

@ -1,260 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._six import container_abcs
from itertools import repeat
from functools import partial
import numpy as np
import math
# Tuple helpers ripped from PyTorch
def _ntuple(n):
def parse(x):
if isinstance(x, container_abcs.Iterable):
return x
return tuple(repeat(x, n))
return parse
_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)
def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
def _get_padding(kernel_size, stride=1, dilation=1, **_):
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
def _calc_same_pad(i, k, s, d):
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
def _split_channels(num_chan, num_groups):
split = [num_chan // num_groups for _ in range(num_groups)]
split[0] += num_chan - sum(split)
return split
# pylint: disable=unused-argument
def conv2d_same(x, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1):
ih, iw = x.size()[-2:]
kh, kw = weight.size()[-2:]
pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
if pad_h > 0 or pad_w > 0:
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
class Conv2dSame(nn.Conv2d):
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
"""
# pylint: disable=unused-argument
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(Conv2dSame, self).__init__(
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
def forward(self, x):
return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
def get_padding_value(padding, kernel_size, **kwargs):
dynamic = False
if isinstance(padding, str):
# for any string padding, the padding will be calculated for you, one of three ways
padding = padding.lower()
if padding == 'same':
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
if _is_static_pad(kernel_size, **kwargs):
# static case, no extra overhead
padding = _get_padding(kernel_size, **kwargs)
else:
# dynamic 'SAME' padding, has runtime/GPU memory overhead
padding = 0
dynamic = True
elif padding == 'valid':
# 'VALID' padding, same as padding=0
padding = 0
else:
# Default to PyTorch style 'same'-ish symmetric padding
padding = _get_padding(kernel_size, **kwargs)
return padding, dynamic
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
padding = kwargs.pop('padding', '')
kwargs.setdefault('bias', False)
padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
if is_dynamic:
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
else:
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
class MixedConv2d(nn.Module):
""" Mixed Grouped Convolution
Based on MDConv and GroupedConv in MixNet impl:
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
NOTE: This does not currently work with torch.jit.script
"""
def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding='', dilation=1, depthwise=False, **kwargs):
super(MixedConv2d, self).__init__()
kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
num_groups = len(kernel_size)
in_splits = _split_channels(in_channels, num_groups)
out_splits = _split_channels(out_channels, num_groups)
self.in_channels = sum(in_splits)
self.out_channels = sum(out_splits)
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
conv_groups = out_ch if depthwise else 1
# use add_module to keep key space clean
self.add_module(
str(idx),
create_conv2d_pad(
in_ch, out_ch, k, stride=stride,
padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
)
self.splits = in_splits
def forward(self, x):
x_split = torch.split(x, self.splits, 1)
x_out = [c(x) for x, c in zip(x_split, self._modules.values())]
x = torch.cat(x_out, 1)
return x
def get_condconv_initializer(initializer, num_experts, expert_shape):
def condconv_initializer(weight):
"""CondConv initializer function."""
num_params = np.prod(expert_shape)
if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
weight.shape[1] != num_params):
raise (ValueError(
'CondConv variables must have shape [num_experts, num_params]'))
for i in range(num_experts):
initializer(weight[i].view(expert_shape))
return condconv_initializer
class CondConv2d(nn.Module):
""" Conditional Convolution
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
https://github.com/pytorch/pytorch/issues/17983
"""
__constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
super(CondConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
padding_val, is_padding_dynamic = get_padding_value(
padding, kernel_size, stride=stride, dilation=dilation)
self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript
self.padding = _pair(padding_val)
self.dilation = _pair(dilation)
self.groups = groups
self.num_experts = num_experts
self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
weight_num_param = 1
for wd in self.weight_shape:
weight_num_param *= wd
self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
if bias:
self.bias_shape = (self.out_channels,)
self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
init_weight = get_condconv_initializer(
partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
init_weight(self.weight)
if self.bias is not None:
fan_in = np.prod(self.weight_shape[1:])
bound = 1 / math.sqrt(fan_in)
init_bias = get_condconv_initializer(
partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
init_bias(self.bias)
def forward(self, x, routing_weights):
B, C, H, W = x.shape
weight = torch.matmul(routing_weights, self.weight)
new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
weight = weight.view(new_weight_shape)
bias = None
if self.bias is not None:
bias = torch.matmul(routing_weights, self.bias)
bias = bias.view(B * self.out_channels)
# move batch elements with channels so each batch element can be efficiently convolved with separate kernel
x = x.view(1, B * C, H, W)
if self.dynamic_padding:
out = conv2d_same(
x, weight, bias, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups * B)
else:
out = F.conv2d(
x, weight, bias, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups * B)
out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
# Literal port (from TF definition)
# x = torch.split(x, 1, 0)
# weight = torch.split(weight, 1, 0)
# if self.bias is not None:
# bias = torch.matmul(routing_weights, self.bias)
# bias = torch.split(bias, 1, 0)
# else:
# bias = [None] * B
# out = []
# for xi, wi, bi in zip(x, weight, bias):
# wi = wi.view(*self.weight_shape)
# if bi is not None:
# bi = bi.view(*self.bias_shape)
# out.append(self.conv_fn(
# xi, wi, bi, stride=self.stride, padding=self.padding,
# dilation=self.dilation, groups=self.groups))
# out = torch.cat(out, 0)
return out
# helper method
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
assert 'groups' not in kwargs # only use 'depthwise' bool arg
if isinstance(kernel_size, list):
assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
# We're going to use only lists for defining the MixedConv2d kernel groups,
# ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
else:
depthwise = kwargs.pop('depthwise', False)
groups = out_chs if depthwise else 1
if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
else:
m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
return m

@ -10,7 +10,7 @@ import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import re

@ -13,7 +13,7 @@ import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

@ -16,7 +16,7 @@ from collections import OrderedDict
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD

@ -27,8 +27,8 @@ from .efficientnet_builder import *
from .feature_hooks import FeatureHooks
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .conv2d_layers import select_conv2d
from .layers import SelectAdaptivePool2d
from timm.models.layers import create_conv2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@ -220,7 +220,7 @@ class EfficientNet(nn.Module):
def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32,
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
super(EfficientNet, self).__init__()
norm_kwargs = norm_kwargs or {}
@ -232,21 +232,21 @@ class EfficientNet(nn.Module):
# Stem
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
self._in_chs = stem_size
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
channel_multiplier, channel_divisor, channel_min, 32, pad_type, act_layer, se_kwargs,
channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_connect_rate, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.feature_info = builder.features
self._in_chs = builder.in_chs
# Head + Pooling
self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type)
self.conv_head = create_conv2d(self._in_chs, self.num_features, 1, padding=pad_type)
self.bn2 = norm_layer(self.num_features, **norm_kwargs)
self.act2 = act_layer(inplace=True)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@ -314,7 +314,7 @@ class EfficientNetFeatures(nn.Module):
# Stem
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
self._in_chs = stem_size

@ -1,11 +1,8 @@
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from .activations import sigmoid
from .conv2d_layers import *
from torch.nn import functional as F
from .layers.activations import sigmoid
from .layers import create_conv2d
# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
@ -72,7 +69,7 @@ def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
return make_divisible(channels, divisor, channel_min)
def drop_connect(inputs, training=False, drop_connect_rate=0.):
def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.):
"""Apply drop connect."""
if not training:
return inputs
@ -132,7 +129,7 @@ class ConvBnAct(nn.Module):
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
super(ConvBnAct, self).__init__()
norm_kwargs = norm_kwargs or {}
self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
self.bn1 = norm_layer(out_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
@ -160,22 +157,24 @@ class DepthwiseSeparableConv(nn.Module):
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
super(DepthwiseSeparableConv, self).__init__()
norm_kwargs = norm_kwargs or {}
self.has_se = se_ratio is not None and se_ratio > 0.
has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
self.has_pw_act = pw_act # activation after point-wise conv
self.drop_connect_rate = drop_connect_rate
self.conv_dw = select_conv2d(
self.conv_dw = create_conv2d(
in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
self.bn1 = norm_layer(in_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
# Squeeze-and-excitation
if self.has_se:
if has_se:
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
else:
self.se = None
self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
self.bn2 = norm_layer(out_chs, **norm_kwargs)
self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()
@ -193,7 +192,7 @@ class DepthwiseSeparableConv(nn.Module):
x = self.bn1(x)
x = self.act1(x)
if self.has_se:
if self.se is not None:
x = self.se(x)
x = self.conv_pw(x)
@ -219,29 +218,31 @@ class InvertedResidual(nn.Module):
norm_kwargs = norm_kwargs or {}
conv_kwargs = conv_kwargs or {}
mid_chs = make_divisible(in_chs * exp_ratio)
self.has_se = se_ratio is not None and se_ratio > 0.
has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_connect_rate = drop_connect_rate
# Point-wise expansion
self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
# Depth-wise convolution
self.conv_dw = select_conv2d(
self.conv_dw = create_conv2d(
mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
padding=pad_type, depthwise=True, **conv_kwargs)
self.bn2 = norm_layer(mid_chs, **norm_kwargs)
self.act2 = act_layer(inplace=True)
# Squeeze-and-excitation
if self.has_se:
if has_se:
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
else:
self.se = None
# Point-wise linear projection
self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
self.bn3 = norm_layer(out_chs, **norm_kwargs)
def feature_module(self, location):
@ -269,7 +270,7 @@ class InvertedResidual(nn.Module):
x = self.act2(x)
# Squeeze-and-excitation
if self.has_se:
if self.se is not None:
x = self.se(x)
# Point-wise linear projection
@ -323,7 +324,7 @@ class CondConvResidual(InvertedResidual):
x = self.act2(x)
# Squeeze-and-excitation
if self.has_se:
if self.se is not None:
x = self.se(x)
# Point-wise linear projection
@ -350,22 +351,24 @@ class EdgeResidual(nn.Module):
mid_chs = make_divisible(fake_in_chs * exp_ratio)
else:
mid_chs = make_divisible(in_chs * exp_ratio)
self.has_se = se_ratio is not None and se_ratio > 0.
has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_connect_rate = drop_connect_rate
# Expansion convolution
self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
self.conv_exp = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
# Squeeze-and-excitation
if self.has_se:
if has_se:
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
else:
self.se = None
# Point-wise linear projection
self.conv_pwl = select_conv2d(
self.conv_pwl = create_conv2d(
mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
self.bn2 = norm_layer(out_chs, **norm_kwargs)
@ -389,7 +392,7 @@ class EdgeResidual(nn.Module):
x = self.act1(x)
# Squeeze-and-excitation
if self.has_se:
if self.se is not None:
x = self.se(x)
# Point-wise linear projection

@ -5,7 +5,8 @@ from collections.__init__ import OrderedDict
from copy import deepcopy
import torch.nn as nn
from .activations import sigmoid, HardSwish, Swish
from .layers import CondConv2d, get_condconv_initializer
from .layers.activations import HardSwish, Swish
from .efficientnet_blocks import *
@ -358,15 +359,24 @@ class EfficientNetBuilder:
return stages
def _init_weight_goog(m, n=''):
def _init_weight_goog(m, n='', fix_group_fanout=False):
""" Weight initialization as per Tensorflow official implementations.
Args:
m (nn.Module): module to init
n (str): module name
fix_group_fanout (bool): enable correct fanout calculation w/ group convs
FIXME change fix_group_fanout to default to True if experiments show better training results
Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
* https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
* https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
"""
if isinstance(m, CondConv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
if fix_group_fanout:
fan_out //= m.groups
init_weight_fn = get_condconv_initializer(
lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
init_weight_fn(m.weight)
@ -374,6 +384,8 @@ def _init_weight_goog(m, n=''):
m.bias.data.zero_()
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
if fix_group_fanout:
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
@ -390,21 +402,6 @@ def _init_weight_goog(m, n=''):
m.bias.data.zero_()
def _init_weight_default(m, n=''):
""" Basic ResNet (Kaiming) style weight init"""
if isinstance(m, CondConv2d):
init_fn = get_condconv_initializer(partial(
nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape)
init_fn(m.weight)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1.0)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
def efficientnet_init_weights(model: nn.Module, init_fn=None):
init_fn = init_fn or _init_weight_goog
for n, m in model.named_modules():
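A quick sketch of what the new fix_group_fanout flag changes in the 'goog' init above; the depthwise conv and numbers here are illustrative, not taken from the PR:

import math
import torch.nn as nn

# depthwise 3x3 conv: groups == out_channels == 64
conv = nn.Conv2d(64, 64, 3, groups=64, bias=False)
fan_out = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels
print(fan_out)                    # 576, the historical (uncorrected) fanout
print(fan_out // conv.groups)     # 9, per-filter fanout with fix_group_fanout=True
print(math.sqrt(2.0 / fan_out))   # std of the normal init used by _init_weight_goog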

@ -11,6 +11,7 @@ import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .layers import SEModule
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .resnet import ResNet, Bottleneck, BasicBlock
@ -319,8 +320,8 @@ def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kw
"""
default_cfg = default_cfgs['gluon_seresnext50_32x4d']
model = ResNet(
Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, use_se=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SEModule), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
@ -333,8 +334,8 @@ def gluon_seresnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **k
"""
default_cfg = default_cfgs['gluon_seresnext101_32x4d']
model = ResNet(
Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, use_se=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SEModule), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
@ -346,9 +347,10 @@ def gluon_seresnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **k
"""Constructs a SEResNeXt-101-64x4d model.
"""
default_cfg = default_cfgs['gluon_seresnext101_64x4d']
block_args = dict(attn_layer=SEModule)
model = ResNet(
Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4, use_se=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4,
num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
@ -360,10 +362,10 @@ def gluon_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs an SENet-154 model.
"""
default_cfg = default_cfgs['gluon_senet154']
block_args = dict(attn_layer=SEModule)
model = ResNet(
Bottleneck, [3, 8, 36, 3], cardinality=64, base_width=4, use_se=True,
stem_type='deep', down_kernel_size=3, block_reduce_first=2,
num_classes=num_classes, in_chans=in_chans, **kwargs)
Bottleneck, [3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', down_kernel_size=3,
block_reduce_first=2, num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)

@ -13,7 +13,7 @@ from collections import OrderedDict
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['Xception65', 'Xception71']

@ -25,7 +25,7 @@ import torch.nn.functional as F
from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
_BN_MOMENTUM = 0.1

@ -8,7 +8,7 @@ import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
__all__ = ['InceptionResnetV2']

@ -8,7 +8,7 @@ import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
__all__ = ['InceptionV4']

@ -0,0 +1,17 @@
from .padding import get_padding
from .avg_pool2d_same import AvgPool2dSame
from .conv2d_same import Conv2dSame
from .conv_bn_act import ConvBnAct
from .mixed_conv2d import MixedConv2d
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .create_conv2d import create_conv2d
from .create_attn import create_attn
from .selective_kernel import SelectiveKernelConv
from .se import SEModule
from .eca import EcaModule, CecaModule
from .activations import *
from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .drop import DropBlock2d, DropPath
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model

@ -1,3 +1,12 @@
""" Activations
A collection of activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.
Hacked together by Ross Wightman
"""
import torch
from torch import nn as nn
from torch.nn import functional as F
@ -66,20 +75,20 @@ if _USE_MEM_EFFICIENT_ISH:
return MishJitAutoFn.apply(x)
else:
def swish(x, inplace=False):
def swish(x, inplace: bool = False):
"""Swish - Described in: https://arxiv.org/abs/1710.05941
"""
return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
def mish(x, _inplace=False):
def mish(x, _inplace: bool = False):
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
"""
return x.mul(F.softplus(x).tanh())
class Swish(nn.Module):
def __init__(self, inplace=False):
def __init__(self, inplace: bool = False):
super(Swish, self).__init__()
self.inplace = inplace
@ -88,7 +97,7 @@ class Swish(nn.Module):
class Mish(nn.Module):
def __init__(self, inplace=False):
def __init__(self, inplace: bool = False):
super(Mish, self).__init__()
self.inplace = inplace
@ -96,13 +105,13 @@ class Mish(nn.Module):
return mish(x, self.inplace)
def sigmoid(x, inplace=False):
def sigmoid(x, inplace: bool = False):
return x.sigmoid_() if inplace else x.sigmoid()
# PyTorch has this, but not with a consistent inplace argument interface
class Sigmoid(nn.Module):
def __init__(self, inplace=False):
def __init__(self, inplace: bool = False):
super(Sigmoid, self).__init__()
self.inplace = inplace
@ -110,13 +119,13 @@ class Sigmoid(nn.Module):
return x.sigmoid_() if self.inplace else x.sigmoid()
def tanh(x, inplace=False):
def tanh(x, inplace: bool = False):
return x.tanh_() if inplace else x.tanh()
# PyTorch has this, but not with a consistent inplace argument interface
class Tanh(nn.Module):
def __init__(self, inplace=False):
def __init__(self, inplace: bool = False):
super(Tanh, self).__init__()
self.inplace = inplace
@ -124,13 +133,13 @@ class Tanh(nn.Module):
return x.tanh_() if self.inplace else x.tanh()
def hard_swish(x, inplace=False):
def hard_swish(x, inplace: bool = False):
inner = F.relu6(x + 3.).div_(6.)
return x.mul_(inner) if inplace else x.mul(inner)
class HardSwish(nn.Module):
def __init__(self, inplace=False):
def __init__(self, inplace: bool = False):
super(HardSwish, self).__init__()
self.inplace = inplace
@ -138,7 +147,7 @@ class HardSwish(nn.Module):
return hard_swish(x, self.inplace)
def hard_sigmoid(x, inplace=False):
def hard_sigmoid(x, inplace: bool = False):
if inplace:
return x.add_(3.).clamp_(0., 6.).div_(6.)
else:
@ -146,7 +155,7 @@ def hard_sigmoid(x, inplace=False):
class HardSigmoid(nn.Module):
def __init__(self, inplace=False):
def __init__(self, inplace: bool = False):
super(HardSigmoid, self).__init__()
self.inplace = inplace

@ -0,0 +1,31 @@
""" AvgPool2d w/ Same Padding
Hacked together by Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
import math
from .helpers import tup_pair
from .padding import pad_same
def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
ceil_mode: bool = False, count_include_pad: bool = True):
x = pad_same(x, kernel_size, stride)
return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
class AvgPool2dSame(nn.AvgPool2d):
""" Tensorflow like 'SAME' wrapper for 2D average pooling
"""
def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
kernel_size = tup_pair(kernel_size)
stride = tup_pair(stride)
super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
def forward(self, x):
return avg_pool2d_same(
x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
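Minimal usage sketch of the new AvgPool2dSame wrapper; the import path assumes the post-merge timm.models.layers package referenced elsewhere in this diff:

import torch
from timm.models.layers import AvgPool2dSame

# TF-style 'SAME' average pooling: a 7x7 map pooled with k=3, s=2 keeps ceil(7/2)=4
pool = AvgPool2dSame(kernel_size=3, stride=2)
x = torch.randn(1, 8, 7, 7)
print(pool(x).shape)  # torch.Size([1, 8, 4, 4])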

@ -0,0 +1,97 @@
""" CBAM (sort-of) Attention
Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521
Hacked together by Ross Wightman
"""
import torch
from torch import nn as nn
from .conv_bn_act import ConvBnAct
class ChannelAttn(nn.Module):
""" Original CBAM channel attention module, currently avg + max pool variant only.
"""
def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
super(ChannelAttn, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False)
self.act = act_layer(inplace=True)
self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False)
def forward(self, x):
x_avg = self.avg_pool(x)
x_max = self.max_pool(x)
x_avg = self.fc2(self.act(self.fc1(x_avg)))
x_max = self.fc2(self.act(self.fc1(x_max)))
x_attn = x_avg + x_max
return x * x_attn.sigmoid()
class LightChannelAttn(ChannelAttn):
"""An experimental 'lightweight' that sums avg + max pool first
"""
def __init__(self, channels, reduction=16):
super(LightChannelAttn, self).__init__(channels, reduction)
def forward(self, x):
x_pool = 0.5 * self.avg_pool(x) + 0.5 * self.max_pool(x)
x_attn = self.fc2(self.act(self.fc1(x_pool)))
return x * x_attn.sigmoid()
class SpatialAttn(nn.Module):
""" Original CBAM spatial attention module
"""
def __init__(self, kernel_size=7):
super(SpatialAttn, self).__init__()
self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None)
def forward(self, x):
x_avg = torch.mean(x, dim=1, keepdim=True)
x_max = torch.max(x, dim=1, keepdim=True)[0]
x_attn = torch.cat([x_avg, x_max], dim=1)
x_attn = self.conv(x_attn)
return x * x_attn.sigmoid()
class LightSpatialAttn(nn.Module):
"""An experimental 'lightweight' variant that sums avg_pool and max_pool results.
"""
def __init__(self, kernel_size=7):
super(LightSpatialAttn, self).__init__()
self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None)
def forward(self, x):
x_avg = torch.mean(x, dim=1, keepdim=True)
x_max = torch.max(x, dim=1, keepdim=True)[0]
x_attn = 0.5 * x_avg + 0.5 * x_max
x_attn = self.conv(x_attn)
return x * x_attn.sigmoid()
class CbamModule(nn.Module):
def __init__(self, channels, spatial_kernel_size=7):
super(CbamModule, self).__init__()
self.channel = ChannelAttn(channels)
self.spatial = SpatialAttn(spatial_kernel_size)
def forward(self, x):
x = self.channel(x)
x = self.spatial(x)
return x
class LightCbamModule(nn.Module):
def __init__(self, channels, spatial_kernel_size=7):
super(LightCbamModule, self).__init__()
self.channel = LightChannelAttn(channels)
self.spatial = LightSpatialAttn(spatial_kernel_size)
def forward(self, x):
x = self.channel(x)
x = self.spatial(x)
return x
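A hedged usage sketch of the experimental CBAM modules; the module path layers/cbam.py is inferred from the `from .cbam import ...` line in create_attn further down, so treat the import as an assumption:

import torch
from timm.models.layers.cbam import CbamModule, LightCbamModule

x = torch.randn(2, 64, 14, 14)
cbam = CbamModule(64)        # channel attention followed by spatial attention
lcbam = LightCbamModule(64)  # 'lightweight' variant that sums avg/max pools first
print(cbam(x).shape, lcbam(x).shape)  # both stay (2, 64, 14, 14)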

@ -0,0 +1,118 @@
""" Conditional Convolution
Hacked together by Ross Wightman
"""
import math
from functools import partial
import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F
from .helpers import tup_pair
from .conv2d_same import get_padding_value, conv2d_same
def get_condconv_initializer(initializer, num_experts, expert_shape):
def condconv_initializer(weight):
"""CondConv initializer function."""
num_params = np.prod(expert_shape)
if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
weight.shape[1] != num_params):
raise (ValueError(
'CondConv variables must have shape [num_experts, num_params]'))
for i in range(num_experts):
initializer(weight[i].view(expert_shape))
return condconv_initializer
class CondConv2d(nn.Module):
""" Conditional Convolution
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
https://github.com/pytorch/pytorch/issues/17983
"""
__constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
super(CondConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = tup_pair(kernel_size)
self.stride = tup_pair(stride)
padding_val, is_padding_dynamic = get_padding_value(
padding, kernel_size, stride=stride, dilation=dilation)
self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript
self.padding = tup_pair(padding_val)
self.dilation = tup_pair(dilation)
self.groups = groups
self.num_experts = num_experts
self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
weight_num_param = 1
for wd in self.weight_shape:
weight_num_param *= wd
self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
if bias:
self.bias_shape = (self.out_channels,)
self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
init_weight = get_condconv_initializer(
partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
init_weight(self.weight)
if self.bias is not None:
fan_in = np.prod(self.weight_shape[1:])
bound = 1 / math.sqrt(fan_in)
init_bias = get_condconv_initializer(
partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
init_bias(self.bias)
def forward(self, x, routing_weights):
B, C, H, W = x.shape
weight = torch.matmul(routing_weights, self.weight)
new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
weight = weight.view(new_weight_shape)
bias = None
if self.bias is not None:
bias = torch.matmul(routing_weights, self.bias)
bias = bias.view(B * self.out_channels)
# move batch elements with channels so each batch element can be efficiently convolved with separate kernel
x = x.view(1, B * C, H, W)
if self.dynamic_padding:
out = conv2d_same(
x, weight, bias, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups * B)
else:
out = F.conv2d(
x, weight, bias, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups * B)
out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
# Literal port (from TF definition)
# x = torch.split(x, 1, 0)
# weight = torch.split(weight, 1, 0)
# if self.bias is not None:
# bias = torch.matmul(routing_weights, self.bias)
# bias = torch.split(bias, 1, 0)
# else:
# bias = [None] * B
# out = []
# for xi, wi, bi in zip(x, weight, bias):
# wi = wi.view(*self.weight_shape)
# if bi is not None:
# bi = bi.view(*self.bias_shape)
# out.append(self.conv_fn(
# xi, wi, bi, stride=self.stride, padding=self.padding,
# dilation=self.dilation, groups=self.groups))
# out = torch.cat(out, 0)
return out
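A minimal sketch of how CondConv2d is driven at runtime; the routing weights here are a stand-in, since the real per-sample routing head lives in the EfficientNet blocks rather than in this layer:

import torch
from timm.models.layers import CondConv2d

cond = CondConv2d(16, 32, kernel_size=3, padding='same', num_experts=4)
x = torch.randn(2, 16, 14, 14)
# stand-in routing weights, one per expert and sample; normally produced by a
# small pooled-feature -> linear -> sigmoid head
routing = torch.sigmoid(torch.randn(2, 4))
out = cond(x, routing)
print(out.shape)  # (2, 32, 14, 14), each sample convolved with its own mixed kernel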

@ -0,0 +1,66 @@
""" Conv2d w/ Same Padding
Hacked together by Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union, List, Tuple, Optional, Callable
import math
from .padding import get_padding, pad_same, is_static_pad
def conv2d_same(
x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
x = pad_same(x, weight.shape[-2:], stride, dilation)
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
class Conv2dSame(nn.Conv2d):
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(Conv2dSame, self).__init__(
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
def forward(self, x):
return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
dynamic = False
if isinstance(padding, str):
# for any string padding, the padding will be calculated for you, one of three ways
padding = padding.lower()
if padding == 'same':
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
if is_static_pad(kernel_size, **kwargs):
# static case, no extra overhead
padding = get_padding(kernel_size, **kwargs)
else:
# dynamic 'SAME' padding, has runtime/GPU memory overhead
padding = 0
dynamic = True
elif padding == 'valid':
# 'VALID' padding, same as padding=0
padding = 0
else:
# Default to PyTorch style 'same'-ish symmetric padding
padding = get_padding(kernel_size, **kwargs)
return padding, dynamic
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
padding = kwargs.pop('padding', '')
kwargs.setdefault('bias', False)
padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
if is_dynamic:
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
else:
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)

@ -0,0 +1,32 @@
""" Conv2d + BN + Act
Hacked together by Ross Wightman
"""
from torch import nn as nn
from timm.models.layers import get_padding
class ConvBnAct(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1,
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(ConvBnAct, self).__init__()
padding = get_padding(kernel_size, stride, dilation) # assuming PyTorch style padding for this block
self.conv = nn.Conv2d(
in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=False)
self.bn = norm_layer(out_channels)
self.drop_block = drop_block
if act_layer is not None:
self.act = act_layer(inplace=True)
else:
self.act = None
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.drop_block is not None:
x = self.drop_block(x)
if self.act is not None:
x = self.act(x)
return x

@ -0,0 +1,35 @@
""" Select AttentionFactory Method
Hacked together by Ross Wightman
"""
import torch
from .se import SEModule
from .eca import EcaModule, CecaModule
from .cbam import CbamModule, LightCbamModule
def create_attn(attn_type, channels, **kwargs):
module_cls = None
if attn_type is not None:
if isinstance(attn_type, str):
attn_type = attn_type.lower()
if attn_type == 'se':
module_cls = SEModule
elif attn_type == 'eca':
module_cls = EcaModule
elif attn_type == 'ceca':
module_cls = CecaModule
elif attn_type == 'cbam':
module_cls = CbamModule
elif attn_type == 'lcbam':
module_cls = LightCbamModule
else:
assert False, "Invalid attn module (%s)" % attn_type
elif isinstance(attn_type, bool):
if attn_type:
module_cls = SEModule
else:
module_cls = attn_type
if module_cls is not None:
return module_cls(channels, **kwargs)
return None
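Sketch of the factory in use, assuming the exports listed in the layers __init__ above:

import torch
from timm.models.layers import create_attn, SEModule, EcaModule

x = torch.randn(2, 128, 7, 7)
se = create_attn('se', 128)        # SEModule
eca = create_attn('eca', 128)      # EcaModule with adaptively chosen kernel size
legacy = create_attn(True, 128)    # bool True keeps the old use_se behaviour -> SEModule
disabled = create_attn(None, 128)  # None -> no attention module
print(isinstance(se, SEModule), isinstance(eca, EcaModule), disabled is None)
print(se(x).shape)  # (2, 128, 7, 7)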

@ -0,0 +1,30 @@
""" Create Conv2d Factory Method
Hacked together by Ross Wightman
"""
from .mixed_conv2d import MixedConv2d
from .cond_conv2d import CondConv2d
from .conv2d_same import create_conv2d_pad
def create_conv2d(in_chs, out_chs, kernel_size, **kwargs):
""" Select a 2d convolution implementation based on arguments
Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.
Used extensively by EfficientNet, MobileNetv3 and related networks.
"""
assert 'groups' not in kwargs # only use 'depthwise' bool arg
if isinstance(kernel_size, list):
assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
# We're going to use only lists for defining the MixedConv2d kernel groups,
# ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
else:
depthwise = kwargs.pop('depthwise', False)
groups = out_chs if depthwise else 1
if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
else:
m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
return m
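How the selection logic above plays out, as a small sketch; shapes and argument values are illustrative only:

import torch
from timm.models.layers import create_conv2d

x = torch.randn(2, 32, 28, 28)
# plain conv with PyTorch-style symmetric padding
conv = create_conv2d(32, 64, kernel_size=3, padding='')
# a list of kernel sizes selects MixedConv2d (depthwise here, so in_chs == out_chs)
mixed = create_conv2d(32, 32, kernel_size=[3, 5, 7], depthwise=True, padding='')
# num_experts > 0 selects CondConv2d (its forward then needs routing weights)
cond = create_conv2d(32, 64, kernel_size=3, num_experts=4, padding='')
print(conv(x).shape, mixed(x).shape)  # (2, 64, 28, 28) and (2, 32, 28, 28)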

@ -0,0 +1,88 @@
""" DropBlock, DropPath
PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
Hacked together by Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noise=False):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
"""
_, _, height, width = x.shape
total_size = width * height
clipped_block_size = min(block_size, min(width, height))
# seed_drop_rate, the gamma parameter
seed_drop_rate = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
(width - block_size + 1) *
(height - block_size + 1))
# Forces the block to be inside the feature map.
w_i, h_i = torch.meshgrid(torch.arange(width).to(x.device), torch.arange(height).to(x.device))
valid_block = ((w_i >= clipped_block_size // 2) & (w_i < width - (clipped_block_size - 1) // 2)) & \
((h_i >= clipped_block_size // 2) & (h_i < height - (clipped_block_size - 1) // 2))
valid_block = torch.reshape(valid_block, (1, 1, height, width)).float()
uniform_noise = torch.rand_like(x, dtype=torch.float32)
block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).float()
block_mask = -F.max_pool2d(
-block_mask,
kernel_size=clipped_block_size, # block_size, ???
stride=1,
padding=clipped_block_size // 2)
if drop_with_noise:
normal_noise = torch.randn_like(x)
x = x * block_mask + normal_noise * (1 - block_mask)
else:
normalize_scale = block_mask.numel() / (torch.sum(block_mask) + 1e-7)
x = x * block_mask * normalize_scale
return x
class DropBlock2d(nn.Module):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
"""
def __init__(self,
drop_prob=0.1,
block_size=7,
gamma_scale=1.0,
with_noise=False):
super(DropBlock2d, self).__init__()
self.drop_prob = drop_prob
self.gamma_scale = gamma_scale
self.block_size = block_size
self.with_noise = with_noise
def forward(self, x):
if not self.training or not self.drop_prob:
return x
return drop_block_2d(x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise)
def drop_path(x, drop_prob=0.):
"""Drop paths (Stochastic Depth) per sample (when applied in residual blocks).
"""
keep_prob = 1 - drop_prob
random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
if not self.training or not self.drop_prob:
return x
return drop_path(x, self.drop_prob)
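Small usage sketch of the two regularizers; the residual-branch wiring is illustrative, not taken from a specific model in this PR:

import torch
from timm.models.layers import DropPath, DropBlock2d

x = torch.randn(4, 64, 14, 14)
dp = DropPath(drop_prob=0.1).train()
# stochastic depth: whole residual branches are zeroed per sample, survivors
# rescaled by 1 / keep_prob
y = x + dp(torch.randn_like(x))
db = DropBlock2d(drop_prob=0.1, block_size=7).train()
z = db(x)  # contiguous 7x7 regions masked out during training
print(dp.eval()(x).equal(x), z.shape)  # identity at eval time, (4, 64, 14, 14)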

@ -1,14 +1,16 @@
'''
"""
ECA module from ECAnet
original paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks
paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks
https://arxiv.org/abs/1910.03151
https://github.com/BangguWu/ECANet
original ECA model borrowed from original github
modified circular ECA implementation and
adoptation for use in pytorch image models package
Original ECA model borrowed from https://github.com/BangguWu/ECANet
Modified circular ECA implementation and adaption for use in timm package
by Chris Ha https://github.com/VRandme
Original License:
MIT License
Copyright (c) 2019 BangguWu, Qilong Wang
@ -30,14 +32,15 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
'''
"""
import math
import torch
from torch import nn
import torch.nn.functional as F
class EcaModule(nn.Module):
"""Constructs a ECA module.
"""Constructs an ECA module.
Args:
channel: Number of channels of the input feature map for use in adaptive kernel sizes
@ -45,35 +48,36 @@ class EcaModule(nn.Module):
gamma, beta: when channel is given parameters of mapping function
refer to original paper https://arxiv.org/pdf/1910.03151.pdf
(default=None. if channel size not given, use k_size given for kernel size.)
k_size: Adaptive selection of kernel size (default=3)
kernel_size: Adaptive selection of kernel size (default=3)
"""
def __init__(self, channel=None, k_size=3, gamma=2, beta=1):
def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
super(EcaModule, self).__init__()
assert k_size % 2 == 1
assert kernel_size % 2 == 1
if channel is not None:
t = int(abs(math.log(channel, 2)+beta) / gamma)
k_size = t if t % 2 else t + 1
if channels is not None:
t = int(abs(math.log(channels, 2) + beta) / gamma)
kernel_size = max(t if t % 2 else t + 1, 3)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
self.sigmoid = nn.Sigmoid()
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
def forward(self, x):
# feature descriptor on the global spatial information
# Feature descriptor on the global spatial information
y = self.avg_pool(x)
# reshape for convolution
# Reshape for convolution
y = y.view(x.shape[0], 1, -1)
# Two different branches of ECA module
y = self.conv(y)
# Multi-scale information fusion
y = self.sigmoid(y.view(x.shape[0], -1, 1, 1))
y = y.view(x.shape[0], -1, 1, 1).sigmoid()
return x * y.expand_as(x)
class CecaModule(nn.Module):
"""Constructs a circular ECA module.
the primary difference is that the conv uses a circular padding rather than zero padding.
This is because unlike images, the channels themselves do not have inherent ordering nor
ECA module where the conv uses circular padding rather than zero padding.
Unlike the spatial dimension, the channels do not have inherent ordering nor
locality. Although this module, in essence, applies such an assumption, it is unnecessary
to limit the channels on either "edge" from being circularly adapted to each other.
This will fundamentally increase connectivity and possibly increase performance metrics
@ -81,43 +85,42 @@ class CecaModule(nn.Module):
(parameter size, throughput, latency, etc)
Args:
channel: Number of channels of the input feature map for use in adaptive kernel sizes
channels: Number of channels of the input feature map for use in adaptive kernel sizes
for actual calculations according to channel.
gamma, beta: when channel is given parameters of mapping function
refer to original paper https://arxiv.org/pdf/1910.03151.pdf
(default=None. if channel size not given, use k_size given for kernel size.)
k_size: Adaptive selection of kernel size (default=3)
kernel_size: Adaptive selection of kernel size (default=3)
"""
def __init__(self, channel=None, k_size=3, gamma=2, beta=1):
def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
super(CecaModule, self).__init__()
assert k_size % 2 == 1
assert kernel_size % 2 == 1
if channel is not None:
t = int(abs(math.log(channel, 2)+beta) / gamma)
k_size = t if t % 2 else t + 1
if channels is not None:
t = int(abs(math.log(channels, 2) + beta) / gamma)
kernel_size = max(t if t % 2 else t + 1, 3)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
#pytorch circular padding mode is bugged as of pytorch 1.4
#pytorch circular padding mode is buggy as of pytorch 1.4
#see https://github.com/pytorch/pytorch/pull/17240
#implement manual circular padding
self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=0, bias=False)
self.padding = (k_size - 1) // 2
self.sigmoid = nn.Sigmoid()
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
self.padding = (kernel_size - 1) // 2
def forward(self, x):
# feature descriptor on the global spatial information
# Feature descriptor on the global spatial information
y = self.avg_pool(x)
#manually implement circular padding, F.pad does not seemed to be bugged
# Manually implement circular padding, F.pad does not seem to be affected by the bug
y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular')
# Two different branches of ECA module
y = self.conv(y)
# Multi-scale information fusion
y = self.sigmoid(y.view(x.shape[0], -1, 1, 1))
y = y.view(x.shape[0], -1, 1, 1).sigmoid()
return x * y.expand_as(x)
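Worked sketch of the adaptive kernel-size rule and the two module variants introduced above:

import math
import torch
from timm.models.layers import EcaModule, CecaModule

channels = 512
# kernel size from channel count: t = int(abs(log2(C) + beta) / gamma), odd, >= 3
t = int(abs(math.log(channels, 2) + 1) / 2)   # gamma=2, beta=1 -> t = 5
eca = EcaModule(channels)    # picks kernel_size = 5 for 512 channels
ceca = CecaModule(channels)  # same rule, circular padding over the channel dim
x = torch.randn(2, channels, 7, 7)
print(t, eca(x).shape, ceca(x).shape)  # 5, (2, 512, 7, 7), (2, 512, 7, 7)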

@ -0,0 +1,27 @@
""" Layer/Module Helpers
Hacked together by Ross Wightman
"""
from itertools import repeat
from torch._six import container_abcs
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, container_abcs.Iterable):
return x
return tuple(repeat(x, n))
return parse
tup_single = _ntuple(1)
tup_pair = _ntuple(2)
tup_triple = _ntuple(3)
tup_quadruple = _ntuple(4)

@ -0,0 +1,49 @@
""" Conditional Convolution
Hacked together by Ross Wightman
"""
import torch
from torch import nn as nn
from .conv2d_same import create_conv2d_pad
def _split_channels(num_chan, num_groups):
split = [num_chan // num_groups for _ in range(num_groups)]
split[0] += num_chan - sum(split)
return split
class MixedConv2d(nn.ModuleDict):
""" Mixed Grouped Convolution
Based on MDConv and GroupedConv in MixNet impl:
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
"""
def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding='', dilation=1, depthwise=False, **kwargs):
super(MixedConv2d, self).__init__()
kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
num_groups = len(kernel_size)
in_splits = _split_channels(in_channels, num_groups)
out_splits = _split_channels(out_channels, num_groups)
self.in_channels = sum(in_splits)
self.out_channels = sum(out_splits)
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
conv_groups = out_ch if depthwise else 1
# use add_module to keep key space clean
self.add_module(
str(idx),
create_conv2d_pad(
in_ch, out_ch, k, stride=stride,
padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
)
self.splits = in_splits
def forward(self, x):
x_split = torch.split(x, self.splits, 1)
x_out = [c(x_split[i]) for i, c in enumerate(self.values())]
x = torch.cat(x_out, 1)
return x
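Sketch of the channel splitting: with three kernels, 48 channels split evenly and each group gets its own (here depthwise) conv:

import torch
from timm.models.layers import MixedConv2d

m = MixedConv2d(48, 48, kernel_size=[3, 5, 7], depthwise=True, padding='')
x = torch.randn(2, 48, 32, 32)
print(m(x).shape)                            # (2, 48, 32, 32)
print([c.kernel_size for c in m.values()])   # [(3, 3), (5, 5), (7, 7)]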

@ -0,0 +1,33 @@
""" Padding Helpers
Hacked together by Ross Wightman
"""
import math
from typing import List
import torch.nn.functional as F
# Calculate symmetric padding for a convolution
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
def get_same_padding(x: int, k: int, s: int, d: int):
return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
# Can SAME padding for given args be done statically?
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
# Dynamically pad input x with 'SAME' padding for conv with specified args
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1)):
ih, iw = x.size()[-2:]
pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
if pad_h > 0 or pad_w > 0:
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
return x
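Worked numbers for the helpers above; note that get_same_padding and pad_same are imported from the padding submodule directly since the layers __init__ only re-exports get_padding, so that import path is an assumption about where this file lands:

import torch
from timm.models.layers import get_padding
from timm.models.layers.padding import get_same_padding, pad_same

print(get_padding(3), get_padding(5))  # 1, 2: symmetric PyTorch-style padding
# 'SAME' amount for a 7-wide input, k=3, s=2: (ceil(7/2)-1)*2 + 3 - 7 = 2 total
print(get_same_padding(7, 3, 2, 1))    # 2
x = torch.randn(1, 1, 7, 7)
print(pad_same(x, [3, 3], [2, 2]).shape)  # padded asymmetrically to (1, 1, 9, 9)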

@ -0,0 +1,21 @@
from torch import nn as nn
class SEModule(nn.Module):
def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
reduction_channels = max(channels // reduction, 8)
self.fc1 = nn.Conv2d(
channels, reduction_channels, kernel_size=1, padding=0, bias=True)
self.act = act_layer(inplace=True)
self.fc2 = nn.Conv2d(
reduction_channels, channels, kernel_size=1, padding=0, bias=True)
def forward(self, x):
x_se = self.avg_pool(x)
x_se = self.fc1(x_se)
x_se = self.act(x_se)
x_se = self.fc2(x_se)
return x * x_se.sigmoid()

@ -0,0 +1,88 @@
""" Selective Kernel Convolution Attention
Hacked together by Ross Wightman
"""
import torch
from torch import nn as nn
from .conv_bn_act import ConvBnAct
def _kernel_valid(k):
if isinstance(k, (list, tuple)):
for ki in k:
return _kernel_valid(ki)
assert k >= 3 and k % 2
class SelectiveKernelAttn(nn.Module):
def __init__(self, channels, num_paths=2, attn_channels=32,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(SelectiveKernelAttn, self).__init__()
self.num_paths = num_paths
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
self.bn = norm_layer(attn_channels)
self.act = act_layer(inplace=True)
self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
def forward(self, x):
assert x.shape[1] == self.num_paths
x = torch.sum(x, dim=1)
x = self.pool(x)
x = self.fc_reduce(x)
x = self.bn(x)
x = self.act(x)
x = self.fc_select(x)
B, C, H, W = x.shape
x = x.view(B, self.num_paths, C // self.num_paths, H, W)
x = torch.softmax(x, dim=1)
return x
class SelectiveKernelConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(SelectiveKernelConv, self).__init__()
kernel_size = kernel_size or [3, 5]
_kernel_valid(kernel_size)
if not isinstance(kernel_size, list):
kernel_size = [kernel_size] * 2
if keep_3x3:
dilation = [dilation * (k - 1) // 2 for k in kernel_size]
kernel_size = [3] * len(kernel_size)
else:
dilation = [dilation] * len(kernel_size)
self.num_paths = len(kernel_size)
self.in_channels = in_channels
self.out_channels = out_channels
self.split_input = split_input
if self.split_input:
assert in_channels % self.num_paths == 0
in_channels = in_channels // self.num_paths
groups = min(out_channels, groups)
conv_kwargs = dict(
stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
self.paths = nn.ModuleList([
ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
for k, d in zip(kernel_size, dilation)])
attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
self.drop_block = drop_block
def forward(self, x):
if self.split_input:
x_split = torch.split(x, self.in_channels // self.num_paths, 1)
x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
else:
x_paths = [op(x) for op in self.paths]
x = torch.stack(x_paths, dim=1)
x_attn = self.attn(x)
x = x * x_attn
x = torch.sum(x, dim=1)
return x
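Minimal sketch of the new conv in isolation; with the default keep_3x3=True the 5x5 path becomes a dilated 3x3, and the attention softly weights the two paths per channel:

import torch
from timm.models.layers import SelectiveKernelConv

sk = SelectiveKernelConv(64, 64, kernel_size=[3, 5], stride=1)
x = torch.randn(2, 64, 28, 28)
print(sk(x).shape)   # (2, 64, 28, 28)
print(len(sk.paths)) # 2 ConvBnAct paths, mixed by SelectiveKernelAttn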

@ -1,3 +1,8 @@
""" Test Time Pooling (Average-Max Pool)
Hacked together by Ross Wightman
"""
import logging
from torch import nn
import torch.nn.functional as F
@ -29,6 +34,8 @@ class TestTimePoolHead(nn.Module):
def apply_test_time_pool(model, config, args):
test_time_pool = False
if not hasattr(model, 'default_cfg') or not model.default_cfg:
return model, False
if not args.no_test_pool and \
config['input_size'][-1] > model.default_cfg['input_size'][-1] and \
config['input_size'][-2] > model.default_cfg['input_size'][-2]:

@@ -7,15 +7,12 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
Hacked together by Ross Wightman
"""
import torch.nn as nn
import torch.nn.functional as F
from .efficientnet_builder import *
from .activations import HardSwish, hard_sigmoid
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .conv2d_layers import select_conv2d
from .layers import SelectAdaptivePool2d, create_conv2d
from .layers.activations import HardSwish, hard_sigmoid
from .feature_hooks import FeatureHooks
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@@ -85,7 +82,7 @@ class MobileNetV3(nn.Module):
# Stem
stem_size = round_channels(stem_size, channel_multiplier)
self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
self._in_chs = stem_size
@@ -100,7 +97,7 @@ class MobileNetV3(nn.Module):
# Head + Pooling
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
self.conv_head = create_conv2d(self._in_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
self.act2 = act_layer(inplace=True)
# Classifier
@@ -165,7 +162,7 @@ class MobileNetV3Features(nn.Module):
# Stem
stem_size = round_channels(stem_size, channel_multiplier)
self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
self._in_chs = stem_size

@@ -4,7 +4,7 @@ import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .layers import SelectAdaptivePool2d
__all__ = ['NASNetALarge']

@@ -14,7 +14,7 @@ import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .layers import SelectAdaptivePool2d
__all__ = ['PNASNet5Large']

@@ -8,10 +8,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .resnet import ResNet, SEModule
from .resnet import ResNet
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .layers import SEModule
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = []
@@ -53,15 +53,16 @@ class Bottle2neck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=26, scale=4, use_se=False,
act_layer=nn.ReLU, norm_layer=None, dilation=1, previous_dilation=1, **_):
cardinality=1, base_width=26, scale=4, dilation=1, first_dilation=None,
act_layer=nn.ReLU, norm_layer=None, attn_layer=None, **_):
super(Bottle2neck, self).__init__()
self.scale = scale
self.is_first = stride > 1 or downsample is not None
self.num_scales = max(1, scale - 1)
width = int(math.floor(planes * (base_width / 64.0))) * cardinality
outplanes = planes * self.expansion
self.width = width
outplanes = planes * self.expansion
first_dilation = first_dilation or dilation
self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False)
self.bn1 = norm_layer(width * scale)
@@ -70,8 +71,8 @@ class Bottle2neck(nn.Module):
bns = []
for i in range(self.num_scales):
convs.append(nn.Conv2d(
width, width, kernel_size=3, stride=stride, padding=dilation,
dilation=dilation, groups=cardinality, bias=False))
width, width, kernel_size=3, stride=stride, padding=first_dilation,
dilation=first_dilation, groups=cardinality, bias=False))
bns.append(norm_layer(width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
@@ -81,11 +82,14 @@ class Bottle2neck(nn.Module):
self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None
self.se = attn_layer(outplanes) if attn_layer is not None else None
self.relu = act_layer(inplace=True)
self.downsample = downsample
def zero_init_last_bn(self):
nn.init.zeros_(self.bn3.weight)
def forward(self, x):
residual = x

@@ -7,14 +7,12 @@ ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered ste
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .EcaModule import EcaModule
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -104,147 +102,179 @@ default_cfgs = {
'ecaresnext26tn_32x4d': _cfg(
url='',
interpolation='bicubic'),
'ecaresnet18': _cfg(),
'ecaresnet50': _cfg(),
}
def _get_padding(kernel_size, stride, dilation=1):
def get_padding(kernel_size, stride, dilation=1):
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
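# e.g. get_padding(3, 1) == ((1 - 1) + 1 * (3 - 1)) // 2 == 1 and get_padding(7, 2) == ((2 - 1) + 1 * (7 - 1)) // 2 == 3,
# the usual 'same'-style padding for a 3x3 stride-1 conv and the 7x7 stride-2 ResNet stem conv.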
class SEModule(nn.Module):
def __init__(self, channels, reduction_channels):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Conv2d(
channels, reduction_channels, kernel_size=1, padding=0, bias=True)
self.relu = nn.ReLU(inplace=True)
self.fc2 = nn.Conv2d(
reduction_channels, channels, kernel_size=1, padding=0, bias=True)
def forward(self, x):
x_se = self.avg_pool(x)
x_se = self.fc1(x_se)
x_se = self.relu(x_se)
x_se = self.fc2(x_se)
return x * x_se.sigmoid()
class BasicBlock(nn.Module):
__constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, use_se=False, use_eca = False,
reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, drop_block=None, drop_path=None):
super(BasicBlock, self).__init__()
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
        assert base_width == 64, 'BasicBlock does not support changing base width'
first_planes = planes // reduce_first
outplanes = planes * self.expansion
first_dilation = first_dilation or dilation
self.conv1 = nn.Conv2d(
inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation,
dilation=dilation, bias=False)
inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
dilation=first_dilation, bias=False)
self.bn1 = norm_layer(first_planes)
self.act1 = act_layer(inplace=True)
self.conv2 = nn.Conv2d(
first_planes, outplanes, kernel_size=3, padding=previous_dilation,
dilation=previous_dilation, bias=False)
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
self.bn2 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None
self.eca = EcaModule(outplanes) if use_eca else None
self.se = create_attn(attn_layer, outplanes)
self.act2 = act_layer(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.drop_block = drop_block
self.drop_path = drop_path
def zero_init_last_bn(self):
nn.init.zeros_(self.bn2.weight)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.act1(out)
out = self.conv2(out)
out = self.bn2(out)
x = self.conv1(x)
x = self.bn1(x)
if self.drop_block is not None:
x = self.drop_block(x)
x = self.act1(x)
x = self.conv2(x)
x = self.bn2(x)
if self.drop_block is not None:
x = self.drop_block(x)
if self.se is not None:
out = self.se(out)
if self.eca is not None:
out = self.eca(out)
x = self.se(x)
if self.downsample is not None:
residual = self.downsample(x)
if self.drop_path is not None:
x = self.drop_path(x)
out += residual
out = self.act2(out)
if self.downsample is not None:
residual = self.downsample(residual)
x += residual
x = self.act2(x)
return out
return x
class Bottleneck(nn.Module):
__constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, use_se=False, use_eca=False,
reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, drop_block=None, drop_path=None):
super(Bottleneck, self).__init__()
width = int(math.floor(planes * (base_width / 64)) * cardinality)
first_planes = width // reduce_first
outplanes = planes * self.expansion
first_dilation = first_dilation or dilation
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
self.bn1 = norm_layer(first_planes)
self.act1 = act_layer(inplace=True)
self.conv2 = nn.Conv2d(
first_planes, width, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation, groups=cardinality, bias=False)
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
self.bn2 = norm_layer(width)
self.act2 = act_layer(inplace=True)
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None
self.eca = Eca_Module(outplanes) if use_eca else None
self.se = create_attn(attn_layer, outplanes)
self.act3 = act_layer(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.drop_block = drop_block
self.drop_path = drop_path
def zero_init_last_bn(self):
nn.init.zeros_(self.bn3.weight)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.act1(out)
x = self.conv1(x)
x = self.bn1(x)
if self.drop_block is not None:
x = self.drop_block(x)
x = self.act1(x)
out = self.conv2(out)
out = self.bn2(out)
out = self.act2(out)
x = self.conv2(x)
x = self.bn2(x)
if self.drop_block is not None:
x = self.drop_block(x)
x = self.act2(x)
out = self.conv3(out)
out = self.bn3(out)
x = self.conv3(x)
x = self.bn3(x)
if self.drop_block is not None:
x = self.drop_block(x)
if self.se is not None:
out = self.se(out)
if self.eca is not None:
out = self.eca(out)
x = self.se(x)
if self.drop_path is not None:
x = self.drop_path(x)
if self.downsample is not None:
residual = self.downsample(x)
residual = self.downsample(residual)
x += residual
x = self.act3(x)
return x
def downsample_conv(
in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
norm_layer = norm_layer or nn.BatchNorm2d
kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1
p = get_padding(kernel_size, stride, first_dilation)
return nn.Sequential(*[
nn.Conv2d(
in_channels, out_channels, kernel_size, stride=stride, padding=p, dilation=first_dilation, bias=False),
norm_layer(out_channels)
])
out += residual
out = self.act3(out)
def downsample_avg(
in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
norm_layer = norm_layer or nn.BatchNorm2d
avg_stride = stride if dilation == 1 else 1
if stride == 1 and dilation == 1:
pool = nn.Identity()
else:
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
return out
return nn.Sequential(*[
pool,
nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False),
norm_layer(out_channels)
])
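# Illustrative comparison of the two shortcut builders above (args/shapes are examples, not part of the diff):
#   downsample_conv(64, 256, kernel_size=1, stride=2) -> Conv2d(64, 256, 1, stride=2, bias=False) + BN
#   downsample_avg(64, 256, kernel_size=1, stride=2)  -> AvgPool2d(2, 2) + Conv2d(64, 256, 1, stride=1, bias=False) + BN
# i.e. avg_down=True moves the stride into an average pool (ResNet-D style) instead of a strided conv.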
class ResNet(nn.Module):
@@ -288,10 +318,6 @@ class ResNet(nn.Module):
Number of classification classes.
in_chans : int, default 3
Number of input (color) channels.
use_se : bool, default False
Enable Squeeze-Excitation module in blocks
use_eca : bool, default False
Enable ECA module in blocks
cardinality : int, default 1
Number of convolution groups for 3x3 conv in Bottleneck.
base_width : int, default 64
@@ -320,11 +346,11 @@ class ResNet(nn.Module):
global_pool : str, default 'avg'
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
"""
def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False, use_eca=False,
def __init__(self, block, layers, num_classes=1000, in_chans=3,
cardinality=1, base_width=64, stem_width=64, stem_type='',
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg',
zero_init_last_bn=True, block_args=None):
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0.,
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
block_args = block_args or dict()
self.num_classes = num_classes
deep_stem = 'deep' in stem_type
@@ -356,6 +382,9 @@ class ResNet(nn.Module):
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# Feature Blocks
        dp = DropPath(drop_path_rate) if drop_path_rate else None
db_3 = DropBlock2d(drop_block_rate, 7, 0.25) if drop_block_rate else None
db_4 = DropBlock2d(drop_block_rate, 7, 1.00) if drop_block_rate else None
channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4
if output_stride == 16:
strides[3] = 1
@@ -365,61 +394,47 @@
dilations[2:4] = [2, 4]
else:
assert output_stride == 32
llargs = list(zip(channels, layers, strides, dilations))
lkwargs = dict(
use_se=use_se, use_eca=use_eca, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
avg_down=avg_down, down_kernel_size=down_kernel_size, **block_args)
self.layer1 = self._make_layer(block, *llargs[0], **lkwargs)
self.layer2 = self._make_layer(block, *llargs[1], **lkwargs)
self.layer3 = self._make_layer(block, *llargs[2], **lkwargs)
self.layer4 = self._make_layer(block, *llargs[3], **lkwargs)
layer_args = list(zip(channels, layers, strides, dilations))
layer_kwargs = dict(
reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
self.layer1 = self._make_layer(block, *layer_args[0], **layer_kwargs)
self.layer2 = self._make_layer(block, *layer_args[1], **layer_kwargs)
self.layer3 = self._make_layer(block, drop_block=db_3, *layer_args[2], **layer_kwargs)
self.layer4 = self._make_layer(block, drop_block=db_4, *layer_args[3], **layer_kwargs)
# Head (Pooling and Classifier)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_features = 512 * block.expansion
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
last_bn_name = 'bn3' if 'Bottle' in block.__name__ else 'bn2'
for n, m in self.named_modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
if zero_init_last_bn and 'layer' in n and last_bn_name in n:
# Initialize weight/gamma of last BN in each residual block to zero
nn.init.constant_(m.weight, 0.)
else:
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.bias, 0.)
if zero_init_last_bn:
for m in self.modules():
if hasattr(m, 'zero_init_last_bn'):
m.zero_init_last_bn()
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
use_se=False, use_eca=False,avg_down=False, down_kernel_size=1, **kwargs):
norm_layer = kwargs.get('norm_layer')
avg_down=False, down_kernel_size=1, **kwargs):
downsample = None
down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size
first_dilation = 1 if dilation in (1, 2) else 2
if stride != 1 or self.inplanes != planes * block.expansion:
downsample_padding = _get_padding(down_kernel_size, stride)
downsample_layers = []
conv_stride = stride
if avg_down:
avg_stride = stride if dilation == 1 else 1
conv_stride = 1
downsample_layers = [nn.AvgPool2d(avg_stride, avg_stride, ceil_mode=True, count_include_pad=False)]
downsample_layers += [
nn.Conv2d(self.inplanes, planes * block.expansion, down_kernel_size,
stride=conv_stride, padding=downsample_padding, bias=False),
norm_layer(planes * block.expansion)]
downsample = nn.Sequential(*downsample_layers)
downsample_args = dict(
in_channels=self.inplanes, out_channels=planes * block.expansion, kernel_size=down_kernel_size,
stride=stride, dilation=dilation, first_dilation=first_dilation, norm_layer=kwargs.get('norm_layer'))
downsample = downsample_avg(**downsample_args) if avg_down else downsample_conv(**downsample_args)
first_dilation = 1 if dilation in (1, 2) else 2
bkwargs = dict(
block_kwargs = dict(
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
use_se=use_se, use_eca=use_eca, **kwargs)
layers = [block(
self.inplanes, planes, stride, downsample, dilation=first_dilation, previous_dilation=dilation, **bkwargs)]
dilation=dilation, **kwargs)
layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(
self.inplanes, planes, dilation=dilation, previous_dilation=dilation, **bkwargs))
layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]
return nn.Sequential(*layers)
@@ -447,8 +462,8 @@ class ResNet(nn.Module):
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x).flatten(1)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
x = self.fc(x)
return x
@@ -920,9 +935,8 @@ def seresnext26d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
"""
default_cfg = default_cfgs['seresnext26d_32x4d']
model = ResNet(
Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
stem_width=32, stem_type='deep', avg_down=True, use_se=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
@@ -938,8 +952,8 @@ def seresnext26t_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
default_cfg = default_cfgs['seresnext26t_32x4d']
model = ResNet(
Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
stem_width=32, stem_type='deep_tiered', avg_down=True, use_se=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
stem_width=32, stem_type='deep_tiered', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
@@ -955,25 +969,55 @@ def seresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs
default_cfg = default_cfgs['seresnext26tn_32x4d']
model = ResNet(
Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, use_se=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
stem_width=32, stem_type='deep_tiered_narrow', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def ecaresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a eca-ResNeXt-26-TN model.
"""Constructs an ECA-ResNeXt-26-TN model.
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
    This model replaces the SE module with an ECA module.
"""
default_cfg = default_cfgs['ecaresnext26tn_32x4d']
block_args = dict(attn_layer='eca')
model = ResNet(
Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, use_eca=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
stem_width=32, stem_type='deep_tiered_narrow', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def ecaresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" Constructs an ECA-ResNet-18 model.
"""
default_cfg = default_cfgs['ecaresnet18']
block_args = dict(attn_layer='eca')
model = ResNet(
BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def ecaresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs an ECA-ResNet-50 model.
"""
default_cfg = default_cfgs['ecaresnet50']
block_args = dict(attn_layer='eca')
model = ResNet(
Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
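# Sanity-check sketch for the new ECA entrypoints (illustrative; the default_cfg urls above are empty,
# so only pretrained=False will work):
#   m = timm.create_model('ecaresnet50', pretrained=False)
#   m(torch.randn(1, 3, 224, 224)).shape   # -> torch.Size([1, 1000])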

@@ -17,7 +17,7 @@ import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this

@@ -16,7 +16,7 @@ import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['SENet']

@@ -0,0 +1,240 @@
import math
from torch import nn as nn
from .registry import register_model
from .helpers import load_pretrained
from .layers import SelectiveKernelConv, ConvBnAct, create_attn
from .resnet import ResNet
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv1', 'classifier': 'fc',
**kwargs
}
default_cfgs = {
'skresnet18': _cfg(url=''),
'skresnet26d': _cfg(),
'skresnet50': _cfg(),
'skresnet50d': _cfg(),
'skresnext50_32x4d': _cfg(),
}
class SelectiveKernelBasic(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None):
super(SelectiveKernelBasic, self).__init__()
sk_kwargs = sk_kwargs or {}
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
        assert base_width == 64, 'BasicBlock does not support changing base width'
first_planes = planes // reduce_first
outplanes = planes * self.expansion
first_dilation = first_dilation or dilation
_selective_first = True # FIXME temporary, for experiments
if _selective_first:
self.conv1 = SelectiveKernelConv(
inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
conv_kwargs['act_layer'] = None
self.conv2 = ConvBnAct(
first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs)
else:
self.conv1 = ConvBnAct(
inplanes, first_planes, kernel_size=3, stride=stride, dilation=first_dilation, **conv_kwargs)
conv_kwargs['act_layer'] = None
self.conv2 = SelectiveKernelConv(
first_planes, outplanes, dilation=dilation, **conv_kwargs, **sk_kwargs)
self.se = create_attn(attn_layer, outplanes)
self.act = act_layer(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.drop_block = drop_block
self.drop_path = drop_path
def zero_init_last_bn(self):
nn.init.zeros_(self.conv2.bn.weight)
def forward(self, x):
residual = x
x = self.conv1(x)
x = self.conv2(x)
if self.se is not None:
x = self.se(x)
if self.drop_path is not None:
x = self.drop_path(x)
if self.downsample is not None:
residual = self.downsample(residual)
x += residual
x = self.act(x)
return x
class SelectiveKernelBottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None):
super(SelectiveKernelBottleneck, self).__init__()
sk_kwargs = sk_kwargs or {}
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
width = int(math.floor(planes * (base_width / 64)) * cardinality)
first_planes = width // reduce_first
outplanes = planes * self.expansion
first_dilation = first_dilation or dilation
self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
self.conv2 = SelectiveKernelConv(
first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality,
**conv_kwargs, **sk_kwargs)
conv_kwargs['act_layer'] = None
self.conv3 = ConvBnAct(width, outplanes, kernel_size=1, **conv_kwargs)
self.se = create_attn(attn_layer, outplanes)
self.act = act_layer(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.drop_block = drop_block
self.drop_path = drop_path
def zero_init_last_bn(self):
nn.init.zeros_(self.conv3.bn.weight)
def forward(self, x):
residual = x
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
if self.se is not None:
x = self.se(x)
if self.drop_path is not None:
x = self.drop_path(x)
if self.downsample is not None:
residual = self.downsample(residual)
x += residual
x = self.act(x)
return x
@register_model
def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-18 model.
"""
default_cfg = default_cfgs['skresnet18']
sk_kwargs = dict(
min_attn_channels=16,
)
model = ResNet(
SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
block_args=dict(sk_kwargs=sk_kwargs), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-18 model.
"""
default_cfg = default_cfgs['skresnet18']
sk_kwargs = dict(
min_attn_channels=16,
attn_reduction=8,
split_input=True
)
model = ResNet(
SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-26 model.
"""
default_cfg = default_cfgs['skresnet26d']
sk_kwargs = dict(
keep_3x3=False,
)
model = ResNet(
SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True,
        num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False,
**kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Select Kernel ResNet-50 model.
Based on config in "Compounding the Performance Improvements of Assembled Techniques in a
Convolutional Neural Network"
"""
sk_kwargs = dict(
attn_reduction=2,
)
default_cfg = default_cfgs['skresnet50']
model = ResNet(
SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Select Kernel ResNet-50-D model.
Based on config in "Compounding the Performance Improvements of Assembled Techniques in a
Convolutional Neural Network"
"""
sk_kwargs = dict(
attn_reduction=2,
)
default_cfg = default_cfgs['skresnet50d']
model = ResNet(
SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs),
zero_init_last_bn=False, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to
the SKNet50 model in the Select Kernel Paper
"""
default_cfg = default_cfgs['skresnext50_32x4d']
model = ResNet(
SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, zero_init_last_bn=False, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
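# Usage sketch for the SK models registered above (illustrative; the cfg urls are empty, so no
# pretrained weights are available yet):
#   m = timm.create_model('skresnet18')            # SelectiveKernelBasic blocks, min_attn_channels=16
#   m = timm.create_model('skresnext50_32x4d')     # SK bottlenecks, cardinality=32, base_width=4
#   m(torch.randn(1, 3, 224, 224)).shape           # -> torch.Size([1, 1000])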

@@ -29,7 +29,7 @@ import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .layers import SelectAdaptivePool2d
__all__ = ['Xception']
