Layer refactoring continues, ResNet downsample rewrite for proper dilation in 3x3 and avg_pool cases

* select_conv2d -> create_conv2d
* added create_attn to create attention module from string/bool/module
* factor padding helpers into own file, use in both conv2d_same and avg_pool2d_same
* add some more test eca resnet variants
* minor tweaks, naming, comments, consistency
pull/87/head
Ross Wightman 5 years ago
parent a99ec4e7d1
commit f902bcd54c

@ -28,7 +28,7 @@ from .feature_hooks import FeatureHooks
from .registry import register_model
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from timm.models.layers import select_conv2d
from timm.models.layers import create_conv2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@ -220,7 +220,7 @@ class EfficientNet(nn.Module):
def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32,
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
super(EfficientNet, self).__init__()
norm_kwargs = norm_kwargs or {}
@ -232,21 +232,21 @@ class EfficientNet(nn.Module):
# Stem
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
self._in_chs = stem_size
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
channel_multiplier, channel_divisor, channel_min, 32, pad_type, act_layer, se_kwargs,
channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_connect_rate, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.feature_info = builder.features
self._in_chs = builder.in_chs
# Head + Pooling
self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type)
self.conv_head = create_conv2d(self._in_chs, self.num_features, 1, padding=pad_type)
self.bn2 = norm_layer(self.num_features, **norm_kwargs)
self.act2 = act_layer(inplace=True)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@ -314,7 +314,7 @@ class EfficientNetFeatures(nn.Module):
# Stem
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
self._in_chs = stem_size

@ -2,7 +2,7 @@ import torch
import torch.nn as nn
from torch.nn import functional as F
from .layers.activations import sigmoid
from .layers import select_conv2d
from .layers import create_conv2d
# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
@ -129,7 +129,7 @@ class ConvBnAct(nn.Module):
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
super(ConvBnAct, self).__init__()
norm_kwargs = norm_kwargs or {}
self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
self.bn1 = norm_layer(out_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
@ -162,7 +162,7 @@ class DepthwiseSeparableConv(nn.Module):
self.has_pw_act = pw_act # activation after point-wise conv
self.drop_connect_rate = drop_connect_rate
self.conv_dw = select_conv2d(
self.conv_dw = create_conv2d(
in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
self.bn1 = norm_layer(in_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
@ -174,7 +174,7 @@ class DepthwiseSeparableConv(nn.Module):
else:
self.se = None
self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
self.bn2 = norm_layer(out_chs, **norm_kwargs)
self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()
@ -223,12 +223,12 @@ class InvertedResidual(nn.Module):
self.drop_connect_rate = drop_connect_rate
# Point-wise expansion
self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
# Depth-wise convolution
self.conv_dw = select_conv2d(
self.conv_dw = create_conv2d(
mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
padding=pad_type, depthwise=True, **conv_kwargs)
self.bn2 = norm_layer(mid_chs, **norm_kwargs)
@ -242,7 +242,7 @@ class InvertedResidual(nn.Module):
self.se = None
# Point-wise linear projection
self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
self.bn3 = norm_layer(out_chs, **norm_kwargs)
def feature_module(self, location):
@ -356,7 +356,7 @@ class EdgeResidual(nn.Module):
self.drop_connect_rate = drop_connect_rate
# Expansion convolution
self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
self.conv_exp = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
@ -368,7 +368,7 @@ class EdgeResidual(nn.Module):
self.se = None
# Point-wise linear projection
self.conv_pwl = select_conv2d(
self.conv_pwl = create_conv2d(
mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
self.bn2 = norm_layer(out_chs, **norm_kwargs)

@ -11,6 +11,7 @@ import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .layers import SEModule
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .resnet import ResNet, Bottleneck, BasicBlock
@ -319,8 +320,8 @@ def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kw
"""
default_cfg = default_cfgs['gluon_seresnext50_32x4d']
model = ResNet(
Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, use_se=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SEModule), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
@ -333,8 +334,8 @@ def gluon_seresnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **k
"""
default_cfg = default_cfgs['gluon_seresnext101_32x4d']
model = ResNet(
Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, use_se=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SEModule), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
@ -346,9 +347,10 @@ def gluon_seresnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **k
"""Constructs a SEResNeXt-101-64x4d model.
"""
default_cfg = default_cfgs['gluon_seresnext101_64x4d']
block_args = dict(attn_layer=SEModule)
model = ResNet(
Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4, use_se=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4,
num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
@ -360,10 +362,10 @@ def gluon_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs an SENet-154 model.
"""
default_cfg = default_cfgs['gluon_senet154']
block_args = dict(attn_layer=SEModule)
model = ResNet(
Bottleneck, [3, 8, 36, 3], cardinality=64, base_width=4, use_se=True,
stem_type='deep', down_kernel_size=3, block_reduce_first=2,
num_classes=num_classes, in_chans=in_chans, **kwargs)
Bottleneck, [3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', down_kernel_size=3,
block_reduce_first=2, num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)

@ -1,8 +1,13 @@
from .padding import get_padding
from .avg_pool2d_same import AvgPool2dSame
from .conv2d_same import Conv2dSame
from .conv_bn_act import ConvBnAct
from .mixed_conv2d import MixedConv2d
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .select_conv2d import select_conv2d
from .create_conv2d import create_conv2d
from .create_attn import create_attn
from .selective_kernel import SelectiveKernelConv
from .se import SEModule
from .eca import EcaModule, CecaModule
from .activations import *
from .adaptive_avgmax_pool import \

@ -0,0 +1,31 @@
""" AvgPool2d w/ Same Padding
Hacked together by Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
import math
from .helpers import tup_pair
from .padding import pad_same
def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
ceil_mode: bool = False, count_include_pad: bool = True):
x = pad_same(x, kernel_size, stride)
return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
class AvgPool2dSame(nn.AvgPool2d):
""" Tensorflow like 'SAME' wrapper for 2D average pooling
"""
def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
kernel_size = tup_pair(kernel_size)
stride = tup_pair(stride)
super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
def forward(self, x):
return avg_pool2d_same(
x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)

@ -10,8 +10,8 @@ import torch
from torch import nn as nn
from torch.nn import functional as F
from .helpers import tup_pair
from .conv2d_same import get_padding_value, conv2d_same
from .conv_helpers import tup_pair
def get_condconv_initializer(initializer, num_experts, expert_shape):

@ -8,26 +8,13 @@ import torch.nn.functional as F
from typing import Union, List, Tuple, Optional, Callable
import math
from .conv_helpers import get_padding
def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
def _calc_same_pad(i: int, k: int, s: int, d: int):
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
from .padding import get_padding, pad_same, is_static_pad
def conv2d_same(
x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
ih, iw = x.size()[-2:]
kh, kw = weight.size()[-2:]
pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
if pad_h > 0 or pad_w > 0:
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
x = pad_same(x, weight.shape[-2:], stride, dilation)
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
@ -51,7 +38,7 @@ def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
padding = padding.lower()
if padding == 'same':
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
if _is_static_pad(kernel_size, **kwargs):
if is_static_pad(kernel_size, **kwargs):
# static case, no extra overhead
padding = get_padding(kernel_size, **kwargs)
else:

@ -4,7 +4,7 @@ Hacked together by Ross Wightman
"""
from torch import nn as nn
from timm.models.layers.conv_helpers import get_padding
from timm.models.layers import get_padding
class ConvBnAct(nn.Module):

@ -0,0 +1,30 @@
""" Select AttentionFactory Method
Hacked together by Ross Wightman
"""
import torch
from .se import SEModule
from .eca import EcaModule, CecaModule
def create_attn(attn_type, channels, **kwargs):
module_cls = None
if attn_type is not None:
if isinstance(attn_type, str):
attn_type = attn_type.lower()
if attn_type == 'se':
module_cls = SEModule
elif attn_type == 'eca':
module_cls = EcaModule
elif attn_type == 'eca':
module_cls = CecaModule
else:
assert False, "Invalid attn module (%s)" % attn_type
elif isinstance(attn_type, bool):
if attn_type:
module_cls = SEModule
else:
module_cls = attn_type
if module_cls is not None:
return module_cls(channels, **kwargs)
return None

@ -1,4 +1,4 @@
""" Select Conv2d Factory Method
""" Create Conv2d Factory Method
Hacked together by Ross Wightman
"""
@ -8,7 +8,7 @@ from .cond_conv2d import CondConv2d
from .conv2d_same import create_conv2d_pad
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
def create_conv2d(in_chs, out_chs, kernel_size, **kwargs):
""" Select a 2d convolution implementation based on arguments
Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.

@ -1,3 +1,9 @@
""" DropBlock, DropPath
PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
Hacked together by Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -6,6 +12,8 @@ import math
def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noise=False):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
"""
_, _, height, width = x.shape
total_size = width * height
clipped_block_size = min(block_size, min(width, height))
@ -24,7 +32,7 @@ def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noi
block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).float()
block_mask = -F.max_pool2d(
-block_mask,
kernel_size=clipped_block_size, # block_size,
kernel_size=clipped_block_size, # block_size, ???
stride=1,
padding=clipped_block_size // 2)
@ -58,7 +66,8 @@ class DropBlock2d(nn.Module):
def drop_path(x, drop_prob=0.):
"""Drop paths (Stochastic Depth) per sample (when applied in residual blocks)."""
"""Drop paths (Stochastic Depth) per sample (when applied in residual blocks).
"""
keep_prob = 1 - drop_prob
random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
@ -67,6 +76,8 @@ def drop_path(x, drop_prob=0.):
class DropPath(nn.ModuleDict):
"""Drop paths (Stochastic Depth) per sample (when applied in residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob

@ -47,19 +47,20 @@ class EcaModule(nn.Module):
gamma, beta: when channel is given parameters of mapping function
refer to original paper https://arxiv.org/pdf/1910.03151.pdf
(default=None. if channel size not given, use k_size given for kernel size.)
k_size: Adaptive selection of kernel size (default=3)
kernel_size: Adaptive selection of kernel size (default=3)
"""
def __init__(self, channel=None, k_size=3, gamma=2, beta=1):
def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
super(EcaModule, self).__init__()
assert k_size % 2 == 1
assert kernel_size % 2 == 1
if channels is not None:
t = int(abs(math.log(channels, 2) + beta) / gamma)
kernel_size = max(t if t % 2 else t + 1, 3)
if channel is not None:
t = int(abs(math.log(channel, 2)+beta) / gamma)
k_size = t if t % 2 else t + 1
print('florg', kernel_size)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
self.sigmoid = nn.Sigmoid()
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
def forward(self, x):
# Feature descriptor on the global spatial information
@ -69,7 +70,7 @@ class EcaModule(nn.Module):
# Two different branches of ECA module
y = self.conv(y)
# Multi-scale information fusion
y = self.sigmoid(y.view(x.shape[0], -1, 1, 1))
y = y.view(x.shape[0], -1, 1, 1).sigmoid()
return x * y.expand_as(x)
@ -93,22 +94,21 @@ class CecaModule(nn.Module):
k_size: Adaptive selection of kernel size (default=3)
"""
def __init__(self, channel=None, k_size=3, gamma=2, beta=1):
def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
super(CecaModule, self).__init__()
assert k_size % 2 == 1
assert kernel_size % 2 == 1
if channel is not None:
t = int(abs(math.log(channel, 2)+beta) / gamma)
k_size = t if t % 2 else t + 1
if channels is not None:
t = int(abs(math.log(channels, 2) + beta) / gamma)
kernel_size = max(t if t % 2 else t + 1, 3)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
#pytorch circular padding mode is buggy as of pytorch 1.4
#see https://github.com/pytorch/pytorch/pull/17240
#implement manual circular padding
self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=0, bias=False)
self.padding = (k_size - 1) // 2
self.sigmoid = nn.Sigmoid()
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
self.padding = (kernel_size - 1) // 2
def forward(self, x):
# Feature descriptor on the global spatial information
@ -121,6 +121,6 @@ class CecaModule(nn.Module):
y = self.conv(y)
# Multi-scale information fusion
y = self.sigmoid(y.view(x.shape[0], -1, 1, 1))
y = y.view(x.shape[0], -1, 1, 1).sigmoid()
return x * y.expand_as(x)

@ -1,4 +1,4 @@
""" Common Helpers
""" Layer/Module Helpers
Hacked together by Ross Wightman
"""
@ -21,7 +21,7 @@ tup_triple = _ntuple(3)
tup_quadruple = _ntuple(4)
# Calculate symmetric padding for a convolution
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding

@ -0,0 +1,33 @@
""" Padding Helpers
Hacked together by Ross Wightman
"""
import math
from typing import List
import torch.nn.functional as F
# Calculate symmetric padding for a convolution
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
def get_same_padding(x: int, k: int, s: int, d: int):
return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
# Can SAME padding for given args be done statically?
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
# Dynamically pad input x with 'SAME' padding for conv with specified args
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1)):
ih, iw = x.size()[-2:]
pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
if pad_h > 0 or pad_w > 0:
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
return x

@ -0,0 +1,21 @@
from torch import nn as nn
class SEModule(nn.Module):
def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
reduction_channels = max(channels // reduction, 8)
self.fc1 = nn.Conv2d(
channels, reduction_channels, kernel_size=1, padding=0, bias=True)
self.act = act_layer(inplace=True)
self.fc2 = nn.Conv2d(
reduction_channels, channels, kernel_size=1, padding=0, bias=True)
def forward(self, x):
x_se = self.avg_pool(x)
x_se = self.fc1(x_se)
x_se = self.act(x_se)
x_se = self.fc2(x_se)
return x * x_se.sigmoid()

@ -34,6 +34,8 @@ class TestTimePoolHead(nn.Module):
def apply_test_time_pool(model, config, args):
test_time_pool = False
if not hasattr(model, 'default_cfg') or not model.default_cfg:
return model, False
if not args.no_test_pool and \
config['input_size'][-1] > model.default_cfg['input_size'][-1] and \
config['input_size'][-2] > model.default_cfg['input_size'][-2]:

@ -11,7 +11,7 @@ Hacked together by Ross Wightman
from .efficientnet_builder import *
from .registry import register_model
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d, select_conv2d
from .layers import SelectAdaptivePool2d, create_conv2d
from .layers.activations import HardSwish, hard_sigmoid
from .feature_hooks import FeatureHooks
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@ -82,7 +82,7 @@ class MobileNetV3(nn.Module):
# Stem
stem_size = round_channels(stem_size, channel_multiplier)
self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
self._in_chs = stem_size
@ -97,7 +97,7 @@ class MobileNetV3(nn.Module):
# Head + Pooling
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
self.conv_head = create_conv2d(self._in_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
self.act2 = act_layer(inplace=True)
# Classifier
@ -162,7 +162,7 @@ class MobileNetV3Features(nn.Module):
# Stem
stem_size = round_channels(stem_size, channel_multiplier)
self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
self._in_chs = stem_size

@ -8,10 +8,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .resnet import ResNet, SEModule
from .resnet import ResNet
from .registry import register_model
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from .layers import SEModule
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = []
@ -53,8 +53,8 @@ class Bottle2neck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=26, scale=4, use_se=False,
act_layer=nn.ReLU, norm_layer=None, dilation=1, first_dilation=None, **_):
cardinality=1, base_width=26, scale=4, dilation=1, first_dilation=None,
act_layer=nn.ReLU, norm_layer=None, attn_layer=None, **_):
super(Bottle2neck, self).__init__()
self.scale = scale
self.is_first = stride > 1 or downsample is not None
@ -82,7 +82,7 @@ class Bottle2neck(nn.Module):
self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None
self.se = attn_layer(outplanes) if attn_layer is not None else None
self.relu = act_layer(inplace=True)
self.downsample = downsample

@ -7,13 +7,12 @@ ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered ste
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .layers import EcaModule, SelectAdaptivePool2d, DropBlock2d, DropPath
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@ -103,7 +102,8 @@ default_cfgs = {
'ecaresnext26tn_32x4d': _cfg(
url='',
interpolation='bicubic'),
'ecaresnet18': _cfg(),
'ecaresnet50': _cfg(),
}
@ -112,32 +112,12 @@ def get_padding(kernel_size, stride, dilation=1):
return padding
class SEModule(nn.Module):
def __init__(self, channels, reduction_channels):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Conv2d(
channels, reduction_channels, kernel_size=1, padding=0, bias=True)
self.relu = nn.ReLU(inplace=True)
self.fc2 = nn.Conv2d(
reduction_channels, channels, kernel_size=1, padding=0, bias=True)
def forward(self, x):
x_se = self.avg_pool(x)
x_se = self.fc1(x_se)
x_se = self.relu(x_se)
x_se = self.fc2(x_se)
return x * x_se.sigmoid()
class BasicBlock(nn.Module):
__constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False,
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
drop_block=None, drop_path=None):
attn_layer=None, drop_block=None, drop_path=None):
super(BasicBlock, self).__init__()
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
@ -155,7 +135,7 @@ class BasicBlock(nn.Module):
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
self.bn2 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None
self.se = create_attn(attn_layer, outplanes)
self.act2 = act_layer(inplace=True)
self.downsample = downsample
@ -199,9 +179,9 @@ class Bottleneck(nn.Module):
__constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False,
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
drop_block=None, drop_path=None):
attn_layer=None, drop_block=None, drop_path=None):
super(Bottleneck, self).__init__()
width = int(math.floor(planes * (base_width / 64)) * cardinality)
@ -220,7 +200,7 @@ class Bottleneck(nn.Module):
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None
self.se = create_attn(attn_layer, outplanes)
self.act3 = act_layer(inplace=True)
self.downsample = downsample
@ -266,6 +246,37 @@ class Bottleneck(nn.Module):
return x
def downsample_conv(
in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
norm_layer = norm_layer or nn.BatchNorm2d
kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1
p = get_padding(kernel_size, stride, first_dilation)
return nn.Sequential(*[
nn.Conv2d(
in_channels, out_channels, kernel_size, stride=stride, padding=p, dilation=first_dilation, bias=False),
norm_layer(out_channels)
])
def downsample_avg(
in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
norm_layer = norm_layer or nn.BatchNorm2d
avg_stride = stride if dilation == 1 else 1
if stride == 1 and dilation == 1:
pool = nn.Identity()
else:
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
return nn.Sequential(*[
pool,
nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False),
norm_layer(out_channels)
])
class ResNet(nn.Module):
"""ResNet / ResNeXt / SE-ResNeXt / SE-Net
@ -307,8 +318,6 @@ class ResNet(nn.Module):
Number of classification classes.
in_chans : int, default 3
Number of input (color) channels.
use_se : bool, default False
Enable Squeeze-Excitation module in blocks
cardinality : int, default 1
Number of convolution groups for 3x3 conv in Bottleneck.
base_width : int, default 64
@ -337,7 +346,7 @@ class ResNet(nn.Module):
global_pool : str, default 'avg'
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
"""
def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False, use_eca=False,
def __init__(self, block, layers, num_classes=1000, in_chans=3,
cardinality=1, base_width=64, stem_width=64, stem_type='',
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0.,
@ -385,14 +394,14 @@ class ResNet(nn.Module):
dilations[2:4] = [2, 4]
else:
assert output_stride == 32
llargs = list(zip(channels, layers, strides, dilations))
lkwargs = dict(
use_se=use_se, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
layer_args = list(zip(channels, layers, strides, dilations))
layer_kwargs = dict(
reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
self.layer1 = self._make_layer(block, *llargs[0], **lkwargs)
self.layer2 = self._make_layer(block, *llargs[1], **lkwargs)
self.layer3 = self._make_layer(block, drop_block=db_3, *llargs[2], **lkwargs)
self.layer4 = self._make_layer(block, drop_block=db_4, *llargs[3], **lkwargs)
self.layer1 = self._make_layer(block, *layer_args[0], **layer_kwargs)
self.layer2 = self._make_layer(block, *layer_args[1], **layer_kwargs)
self.layer3 = self._make_layer(block, drop_block=db_3, *layer_args[2], **layer_kwargs)
self.layer4 = self._make_layer(block, drop_block=db_4, *layer_args[3], **layer_kwargs)
# Head (Pooling and Classifier)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@ -411,31 +420,21 @@ class ResNet(nn.Module):
m.zero_init_last_bn()
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
use_se=False, use_eca=False,avg_down=False, down_kernel_size=1, **kwargs):
norm_layer = kwargs.get('norm_layer')
avg_down=False, down_kernel_size=1, **kwargs):
downsample = None
down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size
first_dilation = 1 if dilation in (1, 2) else 2
if stride != 1 or self.inplanes != planes * block.expansion:
downsample_padding = get_padding(down_kernel_size, stride)
downsample_layers = []
conv_stride = stride
if avg_down:
avg_stride = stride if dilation == 1 else 1
conv_stride = 1
downsample_layers = [nn.AvgPool2d(avg_stride, avg_stride, ceil_mode=True, count_include_pad=False)]
downsample_layers += [
nn.Conv2d(self.inplanes, planes * block.expansion, down_kernel_size,
stride=conv_stride, padding=downsample_padding, bias=False),
norm_layer(planes * block.expansion)]
downsample = nn.Sequential(*downsample_layers)
downsample_args = dict(
in_channels=self.inplanes, out_channels=planes * block.expansion, kernel_size=down_kernel_size,
stride=stride, dilation=dilation, first_dilation=first_dilation, norm_layer=kwargs.get('norm_layer'))
downsample = downsample_avg(**downsample_args) if avg_down else downsample_conv(**downsample_args)
first_dilation = 1 if dilation in (1, 2) else 2
bkwargs = dict(
block_kwargs = dict(
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
dilation=dilation, use_se=use_se, **kwargs)
layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **bkwargs)]
dilation=dilation, **kwargs)
layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
self.inplanes = planes * block.expansion
layers += [block(self.inplanes, planes, **bkwargs) for _ in range(1, blocks)]
layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]
return nn.Sequential(*layers)
@ -936,9 +935,8 @@ def seresnext26d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
"""
default_cfg = default_cfgs['seresnext26d_32x4d']
model = ResNet(
Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
stem_width=32, stem_type='deep', avg_down=True, use_se=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
@ -954,8 +952,8 @@ def seresnext26t_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
default_cfg = default_cfgs['seresnext26t_32x4d']
model = ResNet(
Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
stem_width=32, stem_type='deep_tiered', avg_down=True, use_se=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
stem_width=32, stem_type='deep_tiered', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
@ -971,25 +969,55 @@ def seresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs
default_cfg = default_cfgs['seresnext26tn_32x4d']
model = ResNet(
Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, use_se=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
stem_width=32, stem_type='deep_tiered_narrow', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def ecaresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a eca-ResNeXt-26-TN model.
"""Constructs an ECA-ResNeXt-26-TN model.
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
this model replaces SE module with the ECA module
"""
default_cfg = default_cfgs['ecaresnext26tn_32x4d']
block_args = dict(attn_layer='eca')
model = ResNet(
Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, use_eca=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
stem_width=32, stem_type='deep_tiered_narrow', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def ecaresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" Constructs an ECA-ResNet-18 model.
"""
default_cfg = default_cfgs['ecaresnet18']
block_args = dict(attn_layer='eca')
model = ResNet(
BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def ecaresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs an ECA-ResNet-50 model.
"""
default_cfg = default_cfgs['ecaresnet50']
block_args = dict(attn_layer='eca')
model = ResNet(
Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)

@ -4,8 +4,8 @@ from torch import nn as nn
from .registry import register_model
from .helpers import load_pretrained
from .layers import SelectiveKernelConv, ConvBnAct
from .resnet import ResNet, SEModule
from .layers import SelectiveKernelConv, ConvBnAct, create_attn
from .resnet import ResNet
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@ -33,8 +33,8 @@ class SelectiveKernelBasic(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
use_se=False, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None):
super(SelectiveKernelBasic, self).__init__()
sk_kwargs = sk_kwargs or {}
@ -42,7 +42,7 @@ class SelectiveKernelBasic(nn.Module):
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
assert base_width == 64, 'BasicBlock doest not support changing base width'
first_planes = planes // reduce_first
out_planes = planes * self.expansion
outplanes = planes * self.expansion
first_dilation = first_dilation or dilation
_selective_first = True # FIXME temporary, for experiments
@ -51,14 +51,14 @@ class SelectiveKernelBasic(nn.Module):
inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
conv_kwargs['act_layer'] = None
self.conv2 = ConvBnAct(
first_planes, out_planes, kernel_size=3, dilation=dilation, **conv_kwargs)
first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs)
else:
self.conv1 = ConvBnAct(
inplanes, first_planes, kernel_size=3, stride=stride, dilation=first_dilation, **conv_kwargs)
conv_kwargs['act_layer'] = None
self.conv2 = SelectiveKernelConv(
first_planes, out_planes, dilation=dilation, **conv_kwargs, **sk_kwargs)
self.se = SEModule(out_planes, planes // 4) if use_se else None
first_planes, outplanes, dilation=dilation, **conv_kwargs, **sk_kwargs)
self.se = create_attn(attn_layer, outplanes)
self.act = act_layer(inplace=True)
self.downsample = downsample
self.stride = stride
@ -88,17 +88,15 @@ class SelectiveKernelBottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, use_se=False, sk_kwargs=None,
reduce_first=1, dilation=1, first_dilation=None,
drop_block=None, drop_path=None,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
cardinality=1, base_width=64, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None):
super(SelectiveKernelBottleneck, self).__init__()
sk_kwargs = sk_kwargs or {}
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
width = int(math.floor(planes * (base_width / 64)) * cardinality)
first_planes = width // reduce_first
out_planes = planes * self.expansion
outplanes = planes * self.expansion
first_dilation = first_dilation or dilation
self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
@ -106,8 +104,8 @@ class SelectiveKernelBottleneck(nn.Module):
first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality,
**conv_kwargs, **sk_kwargs)
conv_kwargs['act_layer'] = None
self.conv3 = ConvBnAct(width, out_planes, kernel_size=1, **conv_kwargs)
self.se = SEModule(out_planes, planes // 4) if use_se else None
self.conv3 = ConvBnAct(width, outplanes, kernel_size=1, **conv_kwargs)
self.se = create_attn(attn_layer, outplanes)
self.act = act_layer(inplace=True)
self.downsample = downsample
self.stride = stride

Loading…
Cancel
Save