diff --git a/sotabench.py b/sotabench.py index bd5b0b81..66b7d323 100644 --- a/sotabench.py +++ b/sotabench.py @@ -54,6 +54,8 @@ model_list = [ model_desc='Trained from scratch in PyTorch w/ RandAugment'), _entry('efficientnet_b3a', 'EfficientNet-B3 (320x320, 1.0 crop)', '1905.11946', model_desc='Trained from scratch in PyTorch w/ RandAugment'), + _entry('efficientnet_es', 'EfficientNet-EdgeTPU-S', '1905.11946', + model_desc='Trained from scratch in PyTorch w/ RandAugment'), _entry('fbnetc_100', 'FBNet-C', '1812.03443', model_desc='Trained in PyTorch with RMSProp, exponential LR decay'), _entry('gluon_inception_v3', 'Inception V3', '1512.00567', model_desc='Ported from GluonCV Model Zoo'), diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 0fa4d210..cc4d470e 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -16,9 +16,10 @@ from .gluon_xception import * from .res2net import * from .dla import * from .hrnet import * +from .sknet import * from .registry import * from .factory import create_model from .helpers import load_checkpoint, resume_checkpoint -from .test_time_pool import TestTimePoolHead, apply_test_time_pool -from .split_batchnorm import convert_splitbn_model +from .layers import TestTimePoolHead, apply_test_time_pool +from .layers import convert_splitbn_model diff --git a/timm/models/conv2d_layers.py b/timm/models/conv2d_layers.py deleted file mode 100644 index acd14fde..00000000 --- a/timm/models/conv2d_layers.py +++ /dev/null @@ -1,260 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch._six import container_abcs -from itertools import repeat -from functools import partial -import numpy as np -import math - - -# Tuple helpers ripped from PyTorch -def _ntuple(n): - def parse(x): - if isinstance(x, container_abcs.Iterable): - return x - return tuple(repeat(x, n)) - return parse - - -_single = _ntuple(1) -_pair = _ntuple(2) -_triple = _ntuple(3) -_quadruple = _ntuple(4) - - -def _is_static_pad(kernel_size, stride=1, dilation=1, **_): - return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 - - -def _get_padding(kernel_size, stride=1, dilation=1, **_): - padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 - return padding - - -def _calc_same_pad(i, k, s, d): - return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) - - -def _split_channels(num_chan, num_groups): - split = [num_chan // num_groups for _ in range(num_groups)] - split[0] += num_chan - sum(split) - return split - - -# pylint: disable=unused-argument -def conv2d_same(x, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1): - ih, iw = x.size()[-2:] - kh, kw = weight.size()[-2:] - pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) - pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) - if pad_h > 0 or pad_w > 0: - x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) - return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) - - -class Conv2dSame(nn.Conv2d): - """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions - """ - - # pylint: disable=unused-argument - def __init__(self, in_channels, out_channels, kernel_size, stride=1, - padding=0, dilation=1, groups=1, bias=True): - super(Conv2dSame, self).__init__( - in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) - - def forward(self, x): - return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) - - -def get_padding_value(padding, kernel_size, **kwargs): - dynamic = False - if isinstance(padding, str): - # for any string padding, the padding will be calculated for you, one of three ways - padding = padding.lower() - if padding == 'same': - # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact - if _is_static_pad(kernel_size, **kwargs): - # static case, no extra overhead - padding = _get_padding(kernel_size, **kwargs) - else: - # dynamic 'SAME' padding, has runtime/GPU memory overhead - padding = 0 - dynamic = True - elif padding == 'valid': - # 'VALID' padding, same as padding=0 - padding = 0 - else: - # Default to PyTorch style 'same'-ish symmetric padding - padding = _get_padding(kernel_size, **kwargs) - return padding, dynamic - - -def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): - padding = kwargs.pop('padding', '') - kwargs.setdefault('bias', False) - padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) - if is_dynamic: - return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) - else: - return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) - - -class MixedConv2d(nn.Module): - """ Mixed Grouped Convolution - Based on MDConv and GroupedConv in MixNet impl: - https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py - - NOTE: This does not currently work with torch.jit.script - """ - - def __init__(self, in_channels, out_channels, kernel_size=3, - stride=1, padding='', dilation=1, depthwise=False, **kwargs): - super(MixedConv2d, self).__init__() - - kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] - num_groups = len(kernel_size) - in_splits = _split_channels(in_channels, num_groups) - out_splits = _split_channels(out_channels, num_groups) - self.in_channels = sum(in_splits) - self.out_channels = sum(out_splits) - for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): - conv_groups = out_ch if depthwise else 1 - # use add_module to keep key space clean - self.add_module( - str(idx), - create_conv2d_pad( - in_ch, out_ch, k, stride=stride, - padding=padding, dilation=dilation, groups=conv_groups, **kwargs) - ) - self.splits = in_splits - - def forward(self, x): - x_split = torch.split(x, self.splits, 1) - x_out = [c(x) for x, c in zip(x_split, self._modules.values())] - x = torch.cat(x_out, 1) - return x - - -def get_condconv_initializer(initializer, num_experts, expert_shape): - def condconv_initializer(weight): - """CondConv initializer function.""" - num_params = np.prod(expert_shape) - if (len(weight.shape) != 2 or weight.shape[0] != num_experts or - weight.shape[1] != num_params): - raise (ValueError( - 'CondConv variables must have shape [num_experts, num_params]')) - for i in range(num_experts): - initializer(weight[i].view(expert_shape)) - return condconv_initializer - - -class CondConv2d(nn.Module): - """ Conditional Convolution - Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py - - Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: - https://github.com/pytorch/pytorch/issues/17983 - """ - __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding'] - - def __init__(self, in_channels, out_channels, kernel_size=3, - stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): - super(CondConv2d, self).__init__() - - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = _pair(kernel_size) - self.stride = _pair(stride) - padding_val, is_padding_dynamic = get_padding_value( - padding, kernel_size, stride=stride, dilation=dilation) - self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript - self.padding = _pair(padding_val) - self.dilation = _pair(dilation) - self.groups = groups - self.num_experts = num_experts - - self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size - weight_num_param = 1 - for wd in self.weight_shape: - weight_num_param *= wd - self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) - - if bias: - self.bias_shape = (self.out_channels,) - self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) - else: - self.register_parameter('bias', None) - - self.reset_parameters() - - def reset_parameters(self): - init_weight = get_condconv_initializer( - partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) - init_weight(self.weight) - if self.bias is not None: - fan_in = np.prod(self.weight_shape[1:]) - bound = 1 / math.sqrt(fan_in) - init_bias = get_condconv_initializer( - partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) - init_bias(self.bias) - - def forward(self, x, routing_weights): - B, C, H, W = x.shape - weight = torch.matmul(routing_weights, self.weight) - new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size - weight = weight.view(new_weight_shape) - bias = None - if self.bias is not None: - bias = torch.matmul(routing_weights, self.bias) - bias = bias.view(B * self.out_channels) - # move batch elements with channels so each batch element can be efficiently convolved with separate kernel - x = x.view(1, B * C, H, W) - if self.dynamic_padding: - out = conv2d_same( - x, weight, bias, stride=self.stride, padding=self.padding, - dilation=self.dilation, groups=self.groups * B) - else: - out = F.conv2d( - x, weight, bias, stride=self.stride, padding=self.padding, - dilation=self.dilation, groups=self.groups * B) - out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) - - # Literal port (from TF definition) - # x = torch.split(x, 1, 0) - # weight = torch.split(weight, 1, 0) - # if self.bias is not None: - # bias = torch.matmul(routing_weights, self.bias) - # bias = torch.split(bias, 1, 0) - # else: - # bias = [None] * B - # out = [] - # for xi, wi, bi in zip(x, weight, bias): - # wi = wi.view(*self.weight_shape) - # if bi is not None: - # bi = bi.view(*self.bias_shape) - # out.append(self.conv_fn( - # xi, wi, bi, stride=self.stride, padding=self.padding, - # dilation=self.dilation, groups=self.groups)) - # out = torch.cat(out, 0) - return out - - -# helper method -def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): - assert 'groups' not in kwargs # only use 'depthwise' bool arg - if isinstance(kernel_size, list): - assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently - # We're going to use only lists for defining the MixedConv2d kernel groups, - # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. - m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) - else: - depthwise = kwargs.pop('depthwise', False) - groups = out_chs if depthwise else 1 - if 'num_experts' in kwargs and kwargs['num_experts'] > 0: - m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs) - else: - m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) - return m - - diff --git a/timm/models/densenet.py b/timm/models/densenet.py index d1ac5857..4235c0f7 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD import re diff --git a/timm/models/dla.py b/timm/models/dla.py index cd560f44..a9e81d16 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -13,7 +13,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD diff --git a/timm/models/dpn.py b/timm/models/dpn.py index 7f46e8e0..fd58e516 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -16,7 +16,7 @@ from collections import OrderedDict from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 8d07a2ca..ea71c873 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -27,8 +27,8 @@ from .efficientnet_builder import * from .feature_hooks import FeatureHooks from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d -from .conv2d_layers import select_conv2d +from .layers import SelectAdaptivePool2d +from timm.models.layers import create_conv2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD @@ -220,7 +220,7 @@ class EfficientNet(nn.Module): def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None, - pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., + output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'): super(EfficientNet, self).__init__() norm_kwargs = norm_kwargs or {} @@ -232,21 +232,21 @@ class EfficientNet(nn.Module): # Stem stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) - self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) + self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) self.bn1 = norm_layer(stem_size, **norm_kwargs) self.act1 = act_layer(inplace=True) self._in_chs = stem_size # Middle stages (IR/ER/DS Blocks) builder = EfficientNetBuilder( - channel_multiplier, channel_divisor, channel_min, 32, pad_type, act_layer, se_kwargs, + channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs, drop_connect_rate, verbose=_DEBUG) self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) self.feature_info = builder.features self._in_chs = builder.in_chs # Head + Pooling - self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type) + self.conv_head = create_conv2d(self._in_chs, self.num_features, 1, padding=pad_type) self.bn2 = norm_layer(self.num_features, **norm_kwargs) self.act2 = act_layer(inplace=True) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) @@ -314,7 +314,7 @@ class EfficientNetFeatures(nn.Module): # Stem stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) - self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) + self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) self.bn1 = norm_layer(stem_size, **norm_kwargs) self.act1 = act_layer(inplace=True) self._in_chs = stem_size diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index 13ab051a..c87c2237 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -1,11 +1,8 @@ - -from functools import partial - import torch import torch.nn as nn -import torch.nn.functional as F -from .activations import sigmoid -from .conv2d_layers import * +from torch.nn import functional as F +from .layers.activations import sigmoid +from .layers import create_conv2d # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per @@ -72,7 +69,7 @@ def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): return make_divisible(channels, divisor, channel_min) -def drop_connect(inputs, training=False, drop_connect_rate=0.): +def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.): """Apply drop connect.""" if not training: return inputs @@ -132,7 +129,7 @@ class ConvBnAct(nn.Module): norm_layer=nn.BatchNorm2d, norm_kwargs=None): super(ConvBnAct, self).__init__() norm_kwargs = norm_kwargs or {} - self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type) + self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type) self.bn1 = norm_layer(out_chs, **norm_kwargs) self.act1 = act_layer(inplace=True) @@ -160,22 +157,24 @@ class DepthwiseSeparableConv(nn.Module): norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): super(DepthwiseSeparableConv, self).__init__() norm_kwargs = norm_kwargs or {} - self.has_se = se_ratio is not None and se_ratio > 0. + has_se = se_ratio is not None and se_ratio > 0. self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip self.has_pw_act = pw_act # activation after point-wise conv self.drop_connect_rate = drop_connect_rate - self.conv_dw = select_conv2d( + self.conv_dw = create_conv2d( in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True) self.bn1 = norm_layer(in_chs, **norm_kwargs) self.act1 = act_layer(inplace=True) # Squeeze-and-excitation - if self.has_se: + if has_se: se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = None - self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) + self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) self.bn2 = norm_layer(out_chs, **norm_kwargs) self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity() @@ -193,7 +192,7 @@ class DepthwiseSeparableConv(nn.Module): x = self.bn1(x) x = self.act1(x) - if self.has_se: + if self.se is not None: x = self.se(x) x = self.conv_pw(x) @@ -219,29 +218,31 @@ class InvertedResidual(nn.Module): norm_kwargs = norm_kwargs or {} conv_kwargs = conv_kwargs or {} mid_chs = make_divisible(in_chs * exp_ratio) - self.has_se = se_ratio is not None and se_ratio > 0. + has_se = se_ratio is not None and se_ratio > 0. self.has_residual = (in_chs == out_chs and stride == 1) and not noskip self.drop_connect_rate = drop_connect_rate # Point-wise expansion - self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) + self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) self.bn1 = norm_layer(mid_chs, **norm_kwargs) self.act1 = act_layer(inplace=True) # Depth-wise convolution - self.conv_dw = select_conv2d( + self.conv_dw = create_conv2d( mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True, **conv_kwargs) self.bn2 = norm_layer(mid_chs, **norm_kwargs) self.act2 = act_layer(inplace=True) # Squeeze-and-excitation - if self.has_se: + if has_se: se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = None # Point-wise linear projection - self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) + self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) self.bn3 = norm_layer(out_chs, **norm_kwargs) def feature_module(self, location): @@ -269,7 +270,7 @@ class InvertedResidual(nn.Module): x = self.act2(x) # Squeeze-and-excitation - if self.has_se: + if self.se is not None: x = self.se(x) # Point-wise linear projection @@ -323,7 +324,7 @@ class CondConvResidual(InvertedResidual): x = self.act2(x) # Squeeze-and-excitation - if self.has_se: + if self.se is not None: x = self.se(x) # Point-wise linear projection @@ -350,22 +351,24 @@ class EdgeResidual(nn.Module): mid_chs = make_divisible(fake_in_chs * exp_ratio) else: mid_chs = make_divisible(in_chs * exp_ratio) - self.has_se = se_ratio is not None and se_ratio > 0. + has_se = se_ratio is not None and se_ratio > 0. self.has_residual = (in_chs == out_chs and stride == 1) and not noskip self.drop_connect_rate = drop_connect_rate # Expansion convolution - self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) + self.conv_exp = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) self.bn1 = norm_layer(mid_chs, **norm_kwargs) self.act1 = act_layer(inplace=True) # Squeeze-and-excitation - if self.has_se: + if has_se: se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = None # Point-wise linear projection - self.conv_pwl = select_conv2d( + self.conv_pwl = create_conv2d( mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type) self.bn2 = norm_layer(out_chs, **norm_kwargs) @@ -389,7 +392,7 @@ class EdgeResidual(nn.Module): x = self.act1(x) # Squeeze-and-excitation - if self.has_se: + if self.se is not None: x = self.se(x) # Point-wise linear projection diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index db6f54f9..954420fb 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -5,7 +5,8 @@ from collections.__init__ import OrderedDict from copy import deepcopy import torch.nn as nn -from .activations import sigmoid, HardSwish, Swish +from .layers import CondConv2d, get_condconv_initializer +from .layers.activations import HardSwish, Swish from .efficientnet_blocks import * @@ -358,15 +359,24 @@ class EfficientNetBuilder: return stages -def _init_weight_goog(m, n=''): +def _init_weight_goog(m, n='', fix_group_fanout=False): """ Weight initialization as per Tensorflow official implementations. + Args: + m (nn.Module): module to init + n (str): module name + fix_group_fanout (bool): enable correct fanout calculation w/ group convs + + FIXME change fix_group_fanout to default to True if experiments show better training results + Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc: * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py """ if isinstance(m, CondConv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups init_weight_fn = get_condconv_initializer( lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) init_weight_fn(m.weight) @@ -374,6 +384,8 @@ def _init_weight_goog(m, n=''): m.bias.data.zero_() elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: m.bias.data.zero_() @@ -390,21 +402,6 @@ def _init_weight_goog(m, n=''): m.bias.data.zero_() -def _init_weight_default(m, n=''): - """ Basic ResNet (Kaiming) style weight init""" - if isinstance(m, CondConv2d): - init_fn = get_condconv_initializer(partial( - nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape) - init_fn(m.weight) - elif isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1.0) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') - - def efficientnet_init_weights(model: nn.Module, init_fn=None): init_fn = init_fn or _init_weight_goog for n, m in model.named_modules(): diff --git a/timm/models/gluon_resnet.py b/timm/models/gluon_resnet.py index f835a485..6ccc4c53 100644 --- a/timm/models/gluon_resnet.py +++ b/timm/models/gluon_resnet.py @@ -11,6 +11,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained +from .layers import SEModule from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .resnet import ResNet, Bottleneck, BasicBlock @@ -319,8 +320,8 @@ def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kw """ default_cfg = default_cfgs['gluon_seresnext50_32x4d'] model = ResNet( - Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, use_se=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) + Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, + num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SEModule), **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -333,8 +334,8 @@ def gluon_seresnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **k """ default_cfg = default_cfgs['gluon_seresnext101_32x4d'] model = ResNet( - Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, use_se=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) + Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, + num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SEModule), **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -346,9 +347,10 @@ def gluon_seresnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **k """Constructs a SEResNeXt-101-64x4d model. """ default_cfg = default_cfgs['gluon_seresnext101_64x4d'] + block_args = dict(attn_layer=SEModule) model = ResNet( - Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4, use_se=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) + Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4, + num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -360,10 +362,10 @@ def gluon_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs an SENet-154 model. """ default_cfg = default_cfgs['gluon_senet154'] + block_args = dict(attn_layer=SEModule) model = ResNet( - Bottleneck, [3, 8, 36, 3], cardinality=64, base_width=4, use_se=True, - stem_type='deep', down_kernel_size=3, block_reduce_first=2, - num_classes=num_classes, in_chans=in_chans, **kwargs) + Bottleneck, [3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', down_kernel_size=3, + block_reduce_first=2, num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py index 5a35d226..2fc8e699 100644 --- a/timm/models/gluon_xception.py +++ b/timm/models/gluon_xception.py @@ -13,7 +13,7 @@ from collections import OrderedDict from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = ['Xception65', 'Xception71'] diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 99a2bd91..16df5bc1 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -25,7 +25,7 @@ import torch.nn.functional as F from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD _BN_MOMENTUM = 0.1 diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index 285863f5..13ad0e9d 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD __all__ = ['InceptionResnetV2'] diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index 8c3dee86..16080554 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD __all__ = ['InceptionV4'] diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py new file mode 100644 index 00000000..828c20b2 --- /dev/null +++ b/timm/models/layers/__init__.py @@ -0,0 +1,17 @@ +from .padding import get_padding +from .avg_pool2d_same import AvgPool2dSame +from .conv2d_same import Conv2dSame +from .conv_bn_act import ConvBnAct +from .mixed_conv2d import MixedConv2d +from .cond_conv2d import CondConv2d, get_condconv_initializer +from .create_conv2d import create_conv2d +from .create_attn import create_attn +from .selective_kernel import SelectiveKernelConv +from .se import SEModule +from .eca import EcaModule, CecaModule +from .activations import * +from .adaptive_avgmax_pool import \ + adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d +from .drop import DropBlock2d, DropPath +from .test_time_pool import TestTimePoolHead, apply_test_time_pool +from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model diff --git a/timm/models/activations.py b/timm/models/layers/activations.py similarity index 83% rename from timm/models/activations.py rename to timm/models/layers/activations.py index aafa290c..6f8d2f89 100644 --- a/timm/models/activations.py +++ b/timm/models/layers/activations.py @@ -1,3 +1,12 @@ +""" Activations + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +Hacked together by Ross Wightman +""" + + import torch from torch import nn as nn from torch.nn import functional as F @@ -66,20 +75,20 @@ if _USE_MEM_EFFICIENT_ISH: return MishJitAutoFn.apply(x) else: - def swish(x, inplace=False): + def swish(x, inplace: bool = False): """Swish - Described in: https://arxiv.org/abs/1710.05941 """ return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) - def mish(x, _inplace=False): + def mish(x, _inplace: bool = False): """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 """ return x.mul(F.softplus(x).tanh()) class Swish(nn.Module): - def __init__(self, inplace=False): + def __init__(self, inplace: bool = False): super(Swish, self).__init__() self.inplace = inplace @@ -88,7 +97,7 @@ class Swish(nn.Module): class Mish(nn.Module): - def __init__(self, inplace=False): + def __init__(self, inplace: bool = False): super(Mish, self).__init__() self.inplace = inplace @@ -96,13 +105,13 @@ class Mish(nn.Module): return mish(x, self.inplace) -def sigmoid(x, inplace=False): +def sigmoid(x, inplace: bool = False): return x.sigmoid_() if inplace else x.sigmoid() # PyTorch has this, but not with a consistent inplace argmument interface class Sigmoid(nn.Module): - def __init__(self, inplace=False): + def __init__(self, inplace: bool = False): super(Sigmoid, self).__init__() self.inplace = inplace @@ -110,13 +119,13 @@ class Sigmoid(nn.Module): return x.sigmoid_() if self.inplace else x.sigmoid() -def tanh(x, inplace=False): +def tanh(x, inplace: bool = False): return x.tanh_() if inplace else x.tanh() # PyTorch has this, but not with a consistent inplace argmument interface class Tanh(nn.Module): - def __init__(self, inplace=False): + def __init__(self, inplace: bool = False): super(Tanh, self).__init__() self.inplace = inplace @@ -124,13 +133,13 @@ class Tanh(nn.Module): return x.tanh_() if self.inplace else x.tanh() -def hard_swish(x, inplace=False): +def hard_swish(x, inplace: bool = False): inner = F.relu6(x + 3.).div_(6.) return x.mul_(inner) if inplace else x.mul(inner) class HardSwish(nn.Module): - def __init__(self, inplace=False): + def __init__(self, inplace: bool = False): super(HardSwish, self).__init__() self.inplace = inplace @@ -138,7 +147,7 @@ class HardSwish(nn.Module): return hard_swish(x, self.inplace) -def hard_sigmoid(x, inplace=False): +def hard_sigmoid(x, inplace: bool = False): if inplace: return x.add_(3.).clamp_(0., 6.).div_(6.) else: @@ -146,7 +155,7 @@ def hard_sigmoid(x, inplace=False): class HardSigmoid(nn.Module): - def __init__(self, inplace=False): + def __init__(self, inplace: bool = False): super(HardSigmoid, self).__init__() self.inplace = inplace diff --git a/timm/models/adaptive_avgmax_pool.py b/timm/models/layers/adaptive_avgmax_pool.py similarity index 100% rename from timm/models/adaptive_avgmax_pool.py rename to timm/models/layers/adaptive_avgmax_pool.py diff --git a/timm/models/layers/avg_pool2d_same.py b/timm/models/layers/avg_pool2d_same.py new file mode 100644 index 00000000..33656e79 --- /dev/null +++ b/timm/models/layers/avg_pool2d_same.py @@ -0,0 +1,31 @@ +""" AvgPool2d w/ Same Padding + +Hacked together by Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import List +import math + +from .helpers import tup_pair +from .padding import pad_same + + +def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), + ceil_mode: bool = False, count_include_pad: bool = True): + x = pad_same(x, kernel_size, stride) + return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) + + +class AvgPool2dSame(nn.AvgPool2d): + """ Tensorflow like 'SAME' wrapper for 2D average pooling + """ + def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): + kernel_size = tup_pair(kernel_size) + stride = tup_pair(stride) + super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) + + def forward(self, x): + return avg_pool2d_same( + x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) diff --git a/timm/models/layers/cbam.py b/timm/models/layers/cbam.py new file mode 100644 index 00000000..37ba1c35 --- /dev/null +++ b/timm/models/layers/cbam.py @@ -0,0 +1,97 @@ +""" CBAM (sort-of) Attention + +Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521 + +Hacked together by Ross Wightman +""" + +import torch +from torch import nn as nn +from .conv_bn_act import ConvBnAct + + +class ChannelAttn(nn.Module): + """ Original CBAM channel attention module, currently avg + max pool variant only. + """ + def __init__(self, channels, reduction=16, act_layer=nn.ReLU): + super(ChannelAttn, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.max_pool = nn.AdaptiveMaxPool2d(1) + self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False) + self.act = act_layer(inplace=True) + self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False) + + def forward(self, x): + x_avg = self.avg_pool(x) + x_max = self.max_pool(x) + x_avg = self.fc2(self.act(self.fc1(x_avg))) + x_max = self.fc2(self.act(self.fc1(x_max))) + x_attn = x_avg + x_max + return x * x_attn.sigmoid() + + +class LightChannelAttn(ChannelAttn): + """An experimental 'lightweight' that sums avg + max pool first + """ + def __init__(self, channels, reduction=16): + super(LightChannelAttn, self).__init__(channels, reduction) + + def forward(self, x): + x_pool = 0.5 * self.avg_pool(x) + 0.5 * self.max_pool(x) + x_attn = self.fc2(self.act(self.fc1(x_pool))) + return x * x_attn.sigmoid() + + +class SpatialAttn(nn.Module): + """ Original CBAM spatial attention module + """ + def __init__(self, kernel_size=7): + super(SpatialAttn, self).__init__() + self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None) + + def forward(self, x): + x_avg = torch.mean(x, dim=1, keepdim=True) + x_max = torch.max(x, dim=1, keepdim=True)[0] + x_attn = torch.cat([x_avg, x_max], dim=1) + x_attn = self.conv(x_attn) + return x * x_attn.sigmoid() + + +class LightSpatialAttn(nn.Module): + """An experimental 'lightweight' variant that sums avg_pool and max_pool results. + """ + def __init__(self, kernel_size=7): + super(LightSpatialAttn, self).__init__() + self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None) + + def forward(self, x): + x_avg = torch.mean(x, dim=1, keepdim=True) + x_max = torch.max(x, dim=1, keepdim=True)[0] + x_attn = 0.5 * x_avg + 0.5 * x_max + x_attn = self.conv(x_attn) + return x * x_attn.sigmoid() + + +class CbamModule(nn.Module): + def __init__(self, channels, spatial_kernel_size=7): + super(CbamModule, self).__init__() + self.channel = ChannelAttn(channels) + self.spatial = SpatialAttn(spatial_kernel_size) + + def forward(self, x): + x = self.channel(x) + x = self.spatial(x) + return x + + +class LightCbamModule(nn.Module): + def __init__(self, channels, spatial_kernel_size=7): + super(LightCbamModule, self).__init__() + self.channel = LightChannelAttn(channels) + self.spatial = LightSpatialAttn(spatial_kernel_size) + + def forward(self, x): + x = self.channel(x) + x = self.spatial(x) + return x + diff --git a/timm/models/layers/cond_conv2d.py b/timm/models/layers/cond_conv2d.py new file mode 100644 index 00000000..a7a424a6 --- /dev/null +++ b/timm/models/layers/cond_conv2d.py @@ -0,0 +1,118 @@ +""" Conditional Convolution + +Hacked together by Ross Wightman +""" + +import math +from functools import partial +import numpy as np +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .helpers import tup_pair +from .conv2d_same import get_padding_value, conv2d_same + + +def get_condconv_initializer(initializer, num_experts, expert_shape): + def condconv_initializer(weight): + """CondConv initializer function.""" + num_params = np.prod(expert_shape) + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or + weight.shape[1] != num_params): + raise (ValueError( + 'CondConv variables must have shape [num_experts, num_params]')) + for i in range(num_experts): + initializer(weight[i].view(expert_shape)) + return condconv_initializer + + +class CondConv2d(nn.Module): + """ Conditional Convolution + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py + + Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: + https://github.com/pytorch/pytorch/issues/17983 + """ + __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding'] + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): + super(CondConv2d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = tup_pair(kernel_size) + self.stride = tup_pair(stride) + padding_val, is_padding_dynamic = get_padding_value( + padding, kernel_size, stride=stride, dilation=dilation) + self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript + self.padding = tup_pair(padding_val) + self.dilation = tup_pair(dilation) + self.groups = groups + self.num_experts = num_experts + + self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight_num_param = 1 + for wd in self.weight_shape: + weight_num_param *= wd + self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) + + if bias: + self.bias_shape = (self.out_channels,) + self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + init_weight = get_condconv_initializer( + partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) + init_weight(self.weight) + if self.bias is not None: + fan_in = np.prod(self.weight_shape[1:]) + bound = 1 / math.sqrt(fan_in) + init_bias = get_condconv_initializer( + partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) + init_bias(self.bias) + + def forward(self, x, routing_weights): + B, C, H, W = x.shape + weight = torch.matmul(routing_weights, self.weight) + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight = weight.view(new_weight_shape) + bias = None + if self.bias is not None: + bias = torch.matmul(routing_weights, self.bias) + bias = bias.view(B * self.out_channels) + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel + x = x.view(1, B * C, H, W) + if self.dynamic_padding: + out = conv2d_same( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + else: + out = F.conv2d( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) + + # Literal port (from TF definition) + # x = torch.split(x, 1, 0) + # weight = torch.split(weight, 1, 0) + # if self.bias is not None: + # bias = torch.matmul(routing_weights, self.bias) + # bias = torch.split(bias, 1, 0) + # else: + # bias = [None] * B + # out = [] + # for xi, wi, bi in zip(x, weight, bias): + # wi = wi.view(*self.weight_shape) + # if bi is not None: + # bi = bi.view(*self.bias_shape) + # out.append(self.conv_fn( + # xi, wi, bi, stride=self.stride, padding=self.padding, + # dilation=self.dilation, groups=self.groups)) + # out = torch.cat(out, 0) + return out diff --git a/timm/models/layers/conv2d_same.py b/timm/models/layers/conv2d_same.py new file mode 100644 index 00000000..0e29ae8c --- /dev/null +++ b/timm/models/layers/conv2d_same.py @@ -0,0 +1,66 @@ +""" Conv2d w/ Same Padding + +Hacked together by Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Union, List, Tuple, Optional, Callable +import math + +from .padding import get_padding, pad_same, is_static_pad + + +def conv2d_same( + x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): + x = pad_same(x, weight.shape[-2:], stride, dilation) + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) + + +class Conv2dSame(nn.Conv2d): + """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSame, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + + def forward(self, x): + return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == 'same': + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = get_padding(kernel_size, **kwargs) + else: + # dynamic 'SAME' padding, has runtime/GPU memory overhead + padding = 0 + dynamic = True + elif padding == 'valid': + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = get_padding(kernel_size, **kwargs) + return padding, dynamic + + +def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): + padding = kwargs.pop('padding', '') + kwargs.setdefault('bias', False) + padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) + if is_dynamic: + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + else: + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + + diff --git a/timm/models/layers/conv_bn_act.py b/timm/models/layers/conv_bn_act.py new file mode 100644 index 00000000..f5c94720 --- /dev/null +++ b/timm/models/layers/conv_bn_act.py @@ -0,0 +1,32 @@ +""" Conv2d + BN + Act + +Hacked together by Ross Wightman +""" +from torch import nn as nn + +from timm.models.layers import get_padding + + +class ConvBnAct(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1, + drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(ConvBnAct, self).__init__() + padding = get_padding(kernel_size, stride, dilation) # assuming PyTorch style padding for this block + self.conv = nn.Conv2d( + in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups, bias=False) + self.bn = norm_layer(out_channels) + self.drop_block = drop_block + if act_layer is not None: + self.act = act_layer(inplace=True) + else: + self.act = None + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.drop_block is not None: + x = self.drop_block(x) + if self.act is not None: + x = self.act(x) + return x diff --git a/timm/models/layers/create_attn.py b/timm/models/layers/create_attn.py new file mode 100644 index 00000000..3bca254f --- /dev/null +++ b/timm/models/layers/create_attn.py @@ -0,0 +1,35 @@ +""" Select AttentionFactory Method + +Hacked together by Ross Wightman +""" +import torch +from .se import SEModule +from .eca import EcaModule, CecaModule +from .cbam import CbamModule, LightCbamModule + + +def create_attn(attn_type, channels, **kwargs): + module_cls = None + if attn_type is not None: + if isinstance(attn_type, str): + attn_type = attn_type.lower() + if attn_type == 'se': + module_cls = SEModule + elif attn_type == 'eca': + module_cls = EcaModule + elif attn_type == 'eca': + module_cls = CecaModule + elif attn_type == 'cbam': + module_cls = CbamModule + elif attn_type == 'lcbam': + module_cls = LightCbamModule + else: + assert False, "Invalid attn module (%s)" % attn_type + elif isinstance(attn_type, bool): + if attn_type: + module_cls = SEModule + else: + module_cls = attn_type + if module_cls is not None: + return module_cls(channels, **kwargs) + return None diff --git a/timm/models/layers/create_conv2d.py b/timm/models/layers/create_conv2d.py new file mode 100644 index 00000000..527c80a3 --- /dev/null +++ b/timm/models/layers/create_conv2d.py @@ -0,0 +1,30 @@ +""" Create Conv2d Factory Method + +Hacked together by Ross Wightman +""" + +from .mixed_conv2d import MixedConv2d +from .cond_conv2d import CondConv2d +from .conv2d_same import create_conv2d_pad + + +def create_conv2d(in_chs, out_chs, kernel_size, **kwargs): + """ Select a 2d convolution implementation based on arguments + Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. + + Used extensively by EfficientNet, MobileNetv3 and related networks. + """ + assert 'groups' not in kwargs # only use 'depthwise' bool arg + if isinstance(kernel_size, list): + assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently + # We're going to use only lists for defining the MixedConv2d kernel groups, + # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. + m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) + else: + depthwise = kwargs.pop('depthwise', False) + groups = out_chs if depthwise else 1 + if 'num_experts' in kwargs and kwargs['num_experts'] > 0: + m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + else: + m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + return m diff --git a/timm/models/layers/drop.py b/timm/models/layers/drop.py new file mode 100644 index 00000000..46d5d20b --- /dev/null +++ b/timm/models/layers/drop.py @@ -0,0 +1,88 @@ +""" DropBlock, DropPath + +PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. + +Hacked together by Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math + + +def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noise=False): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + """ + _, _, height, width = x.shape + total_size = width * height + clipped_block_size = min(block_size, min(width, height)) + # seed_drop_rate, the gamma parameter + seed_drop_rate = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (width - block_size + 1) * + (height - block_size + 1)) + + # Forces the block to be inside the feature map. + w_i, h_i = torch.meshgrid(torch.arange(width).to(x.device), torch.arange(height).to(x.device)) + valid_block = ((w_i >= clipped_block_size // 2) & (w_i < width - (clipped_block_size - 1) // 2)) & \ + ((h_i >= clipped_block_size // 2) & (h_i < height - (clipped_block_size - 1) // 2)) + valid_block = torch.reshape(valid_block, (1, 1, height, width)).float() + + uniform_noise = torch.rand_like(x, dtype=torch.float32) + block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).float() + block_mask = -F.max_pool2d( + -block_mask, + kernel_size=clipped_block_size, # block_size, ??? + stride=1, + padding=clipped_block_size // 2) + + if drop_with_noise: + normal_noise = torch.randn_like(x) + x = x * block_mask + normal_noise * (1 - block_mask) + else: + normalize_scale = block_mask.numel() / (torch.sum(block_mask) + 1e-7) + x = x * block_mask * normalize_scale + return x + + +class DropBlock2d(nn.Module): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + """ + def __init__(self, + drop_prob=0.1, + block_size=7, + gamma_scale=1.0, + with_noise=False): + super(DropBlock2d, self).__init__() + self.drop_prob = drop_prob + self.gamma_scale = gamma_scale + self.block_size = block_size + self.with_noise = with_noise + + def forward(self, x): + if not self.training or not self.drop_prob: + return x + return drop_block_2d(x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise) + + +def drop_path(x, drop_prob=0.): + """Drop paths (Stochastic Depth) per sample (when applied in residual blocks). + """ + keep_prob = 1 - drop_prob + random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.ModuleDict): + """Drop paths (Stochastic Depth) per sample (when applied in residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + if not self.training or not self.drop_prob: + return x + return drop_path(x, self.drop_prob) diff --git a/timm/models/EcaModule.py b/timm/models/layers/eca.py similarity index 69% rename from timm/models/EcaModule.py rename to timm/models/layers/eca.py index 54971436..f4072aeb 100644 --- a/timm/models/EcaModule.py +++ b/timm/models/layers/eca.py @@ -1,14 +1,16 @@ -''' +""" ECA module from ECAnet -original paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks + +paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks https://arxiv.org/abs/1910.03151 -https://github.com/BangguWu/ECANet -original ECA model borrowed from original github -modified circular ECA implementation and -adoptation for use in pytorch image models package +Original ECA model borrowed from https://github.com/BangguWu/ECANet + +Modified circular ECA implementation and adaption for use in timm package by Chris Ha https://github.com/VRandme +Original License: + MIT License Copyright (c) 2019 BangguWu, Qilong Wang @@ -30,14 +32,15 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import math import torch from torch import nn import torch.nn.functional as F + class EcaModule(nn.Module): - """Constructs a ECA module. + """Constructs an ECA module. Args: channel: Number of channels of the input feature map for use in adaptive kernel sizes @@ -45,35 +48,36 @@ class EcaModule(nn.Module): gamma, beta: when channel is given parameters of mapping function refer to original paper https://arxiv.org/pdf/1910.03151.pdf (default=None. if channel size not given, use k_size given for kernel size.) - k_size: Adaptive selection of kernel size (default=3) + kernel_size: Adaptive selection of kernel size (default=3) """ - def __init__(self, channel=None, k_size=3, gamma=2, beta=1): + def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1): super(EcaModule, self).__init__() - assert k_size % 2 == 1 + assert kernel_size % 2 == 1 - if channel is not None: - t = int(abs(math.log(channel, 2)+beta) / gamma) - k_size = t if t % 2 else t + 1 + if channels is not None: + t = int(abs(math.log(channels, 2) + beta) / gamma) + kernel_size = max(t if t % 2 else t + 1, 3) self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) - self.sigmoid = nn.Sigmoid() + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False) def forward(self, x): - # feature descriptor on the global spatial information + # Feature descriptor on the global spatial information y = self.avg_pool(x) - # reshape for convolution + # Reshape for convolution y = y.view(x.shape[0], 1, -1) # Two different branches of ECA module y = self.conv(y) # Multi-scale information fusion - y = self.sigmoid(y.view(x.shape[0], -1, 1, 1)) + y = y.view(x.shape[0], -1, 1, 1).sigmoid() return x * y.expand_as(x) + class CecaModule(nn.Module): """Constructs a circular ECA module. - the primary difference is that the conv uses a circular padding rather than zero padding. - This is because unlike images, the channels themselves do not have inherent ordering nor + + ECA module where the conv uses circular padding rather than zero padding. + Unlike the spatial dimension, the channels do not have inherent ordering nor locality. Although this module in essence, applies such an assumption, it is unnecessary to limit the channels on either "edge" from being circularly adapted to each other. This will fundamentally increase connectivity and possibly increase performance metrics @@ -81,43 +85,42 @@ class CecaModule(nn.Module): (parameter size, throughput,latency, etc) Args: - channel: Number of channels of the input feature map for use in adaptive kernel sizes + channels: Number of channels of the input feature map for use in adaptive kernel sizes for actual calculations according to channel. gamma, beta: when channel is given parameters of mapping function refer to original paper https://arxiv.org/pdf/1910.03151.pdf (default=None. if channel size not given, use k_size given for kernel size.) - k_size: Adaptive selection of kernel size (default=3) + kernel_size: Adaptive selection of kernel size (default=3) """ - def __init__(self, channel=None, k_size=3, gamma=2, beta=1): + def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1): super(CecaModule, self).__init__() - assert k_size % 2 == 1 + assert kernel_size % 2 == 1 - if channel is not None: - t = int(abs(math.log(channel, 2)+beta) / gamma) - k_size = t if t % 2 else t + 1 + if channels is not None: + t = int(abs(math.log(channels, 2) + beta) / gamma) + kernel_size = max(t if t % 2 else t + 1, 3) self.avg_pool = nn.AdaptiveAvgPool2d(1) - #pytorch circular padding mode is bugged as of pytorch 1.4 + #pytorch circular padding mode is buggy as of pytorch 1.4 #see https://github.com/pytorch/pytorch/pull/17240 #implement manual circular padding - self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=0, bias=False) - self.padding = (k_size - 1) // 2 - self.sigmoid = nn.Sigmoid() + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False) + self.padding = (kernel_size - 1) // 2 def forward(self, x): - # feature descriptor on the global spatial information + # Feature descriptor on the global spatial information y = self.avg_pool(x) - #manually implement circular padding, F.pad does not seemed to be bugged + # Manually implement circular padding, F.pad does not seemed to be bugged y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular') # Two different branches of ECA module y = self.conv(y) # Multi-scale information fusion - y = self.sigmoid(y.view(x.shape[0], -1, 1, 1)) + y = y.view(x.shape[0], -1, 1, 1).sigmoid() return x * y.expand_as(x) diff --git a/timm/models/layers/helpers.py b/timm/models/layers/helpers.py new file mode 100644 index 00000000..967c2f4c --- /dev/null +++ b/timm/models/layers/helpers.py @@ -0,0 +1,27 @@ +""" Layer/Module Helpers + +Hacked together by Ross Wightman +""" +from itertools import repeat +from torch._six import container_abcs + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, container_abcs.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +tup_single = _ntuple(1) +tup_pair = _ntuple(2) +tup_triple = _ntuple(3) +tup_quadruple = _ntuple(4) + + + + + + diff --git a/timm/models/median_pool.py b/timm/models/layers/median_pool.py similarity index 100% rename from timm/models/median_pool.py rename to timm/models/layers/median_pool.py diff --git a/timm/models/layers/mixed_conv2d.py b/timm/models/layers/mixed_conv2d.py new file mode 100644 index 00000000..3e280c03 --- /dev/null +++ b/timm/models/layers/mixed_conv2d.py @@ -0,0 +1,49 @@ +""" Conditional Convolution + +Hacked together by Ross Wightman +""" + +import torch +from torch import nn as nn + +from .conv2d_same import create_conv2d_pad + + +def _split_channels(num_chan, num_groups): + split = [num_chan // num_groups for _ in range(num_groups)] + split[0] += num_chan - sum(split) + return split + + +class MixedConv2d(nn.ModuleDict): + """ Mixed Grouped Convolution + + Based on MDConv and GroupedConv in MixNet impl: + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py + """ + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, depthwise=False, **kwargs): + super(MixedConv2d, self).__init__() + + kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] + num_groups = len(kernel_size) + in_splits = _split_channels(in_channels, num_groups) + out_splits = _split_channels(out_channels, num_groups) + self.in_channels = sum(in_splits) + self.out_channels = sum(out_splits) + for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): + conv_groups = out_ch if depthwise else 1 + # use add_module to keep key space clean + self.add_module( + str(idx), + create_conv2d_pad( + in_ch, out_ch, k, stride=stride, + padding=padding, dilation=dilation, groups=conv_groups, **kwargs) + ) + self.splits = in_splits + + def forward(self, x): + x_split = torch.split(x, self.splits, 1) + x_out = [c(x_split[i]) for i, c in enumerate(self.values())] + x = torch.cat(x_out, 1) + return x diff --git a/timm/models/layers/padding.py b/timm/models/layers/padding.py new file mode 100644 index 00000000..b3653866 --- /dev/null +++ b/timm/models/layers/padding.py @@ -0,0 +1,33 @@ +""" Padding Helpers + +Hacked together by Ross Wightman +""" +import math +from typing import List + +import torch.nn.functional as F + + +# Calculate symmetric padding for a convolution +def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution +def get_same_padding(x: int, k: int, s: int, d: int): + return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) + + +# Can SAME padding for given args be done statically? +def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +# Dynamically pad input x with 'SAME' padding for conv with specified args +def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1)): + ih, iw = x.size()[-2:] + pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return x diff --git a/timm/models/layers/se.py b/timm/models/layers/se.py new file mode 100644 index 00000000..de87ccf5 --- /dev/null +++ b/timm/models/layers/se.py @@ -0,0 +1,21 @@ +from torch import nn as nn + + +class SEModule(nn.Module): + + def __init__(self, channels, reduction=16, act_layer=nn.ReLU): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + reduction_channels = max(channels // reduction, 8) + self.fc1 = nn.Conv2d( + channels, reduction_channels, kernel_size=1, padding=0, bias=True) + self.act = act_layer(inplace=True) + self.fc2 = nn.Conv2d( + reduction_channels, channels, kernel_size=1, padding=0, bias=True) + + def forward(self, x): + x_se = self.avg_pool(x) + x_se = self.fc1(x_se) + x_se = self.act(x_se) + x_se = self.fc2(x_se) + return x * x_se.sigmoid() diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py new file mode 100644 index 00000000..4100aa02 --- /dev/null +++ b/timm/models/layers/selective_kernel.py @@ -0,0 +1,88 @@ +""" Selective Kernel Convolution Attention + +Hacked together by Ross Wightman +""" + +import torch +from torch import nn as nn + +from .conv_bn_act import ConvBnAct + + +def _kernel_valid(k): + if isinstance(k, (list, tuple)): + for ki in k: + return _kernel_valid(ki) + assert k >= 3 and k % 2 + + +class SelectiveKernelAttn(nn.Module): + def __init__(self, channels, num_paths=2, attn_channels=32, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelAttn, self).__init__() + self.num_paths = num_paths + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) + self.bn = norm_layer(attn_channels) + self.act = act_layer(inplace=True) + self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) + + def forward(self, x): + assert x.shape[1] == self.num_paths + x = torch.sum(x, dim=1) + x = self.pool(x) + x = self.fc_reduce(x) + x = self.bn(x) + x = self.act(x) + x = self.fc_select(x) + B, C, H, W = x.shape + x = x.view(B, self.num_paths, C // self.num_paths, H, W) + x = torch.softmax(x, dim=1) + return x + + +class SelectiveKernelConv(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1, + attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False, + drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelConv, self).__init__() + kernel_size = kernel_size or [3, 5] + _kernel_valid(kernel_size) + if not isinstance(kernel_size, list): + kernel_size = [kernel_size] * 2 + if keep_3x3: + dilation = [dilation * (k - 1) // 2 for k in kernel_size] + kernel_size = [3] * len(kernel_size) + else: + dilation = [dilation] * len(kernel_size) + self.num_paths = len(kernel_size) + self.in_channels = in_channels + self.out_channels = out_channels + self.split_input = split_input + if self.split_input: + assert in_channels % self.num_paths == 0 + in_channels = in_channels // self.num_paths + groups = min(out_channels, groups) + + conv_kwargs = dict( + stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer) + self.paths = nn.ModuleList([ + ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) + for k, d in zip(kernel_size, dilation)]) + + attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) + self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) + self.drop_block = drop_block + + def forward(self, x): + if self.split_input: + x_split = torch.split(x, self.in_channels // self.num_paths, 1) + x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] + else: + x_paths = [op(x) for op in self.paths] + x = torch.stack(x_paths, dim=1) + x_attn = self.attn(x) + x = x * x_attn + x = torch.sum(x, dim=1) + return x diff --git a/timm/models/split_batchnorm.py b/timm/models/layers/split_batchnorm.py similarity index 100% rename from timm/models/split_batchnorm.py rename to timm/models/layers/split_batchnorm.py diff --git a/timm/models/test_time_pool.py b/timm/models/layers/test_time_pool.py similarity index 90% rename from timm/models/test_time_pool.py rename to timm/models/layers/test_time_pool.py index ce6ddf07..dcfc66ca 100644 --- a/timm/models/test_time_pool.py +++ b/timm/models/layers/test_time_pool.py @@ -1,3 +1,8 @@ +""" Test Time Pooling (Average-Max Pool) + +Hacked together by Ross Wightman +""" + import logging from torch import nn import torch.nn.functional as F @@ -29,6 +34,8 @@ class TestTimePoolHead(nn.Module): def apply_test_time_pool(model, config, args): test_time_pool = False + if not hasattr(model, 'default_cfg') or not model.default_cfg: + return model, False if not args.no_test_pool and \ config['input_size'][-1] > model.default_cfg['input_size'][-1] and \ config['input_size'][-2] > model.default_cfg['input_size'][-2]: diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index a6b67532..c74f4224 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -7,15 +7,12 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 Hacked together by Ross Wightman """ -import torch.nn as nn -import torch.nn.functional as F from .efficientnet_builder import * -from .activations import HardSwish, hard_sigmoid from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d -from .conv2d_layers import select_conv2d +from .layers import SelectAdaptivePool2d, create_conv2d +from .layers.activations import HardSwish, hard_sigmoid from .feature_hooks import FeatureHooks from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD @@ -85,7 +82,7 @@ class MobileNetV3(nn.Module): # Stem stem_size = round_channels(stem_size, channel_multiplier) - self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) + self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) self.bn1 = norm_layer(stem_size, **norm_kwargs) self.act1 = act_layer(inplace=True) self._in_chs = stem_size @@ -100,7 +97,7 @@ class MobileNetV3(nn.Module): # Head + Pooling self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type, bias=head_bias) + self.conv_head = create_conv2d(self._in_chs, self.num_features, 1, padding=pad_type, bias=head_bias) self.act2 = act_layer(inplace=True) # Classifier @@ -165,7 +162,7 @@ class MobileNetV3Features(nn.Module): # Stem stem_size = round_channels(stem_size, channel_multiplier) - self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) + self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) self.bn1 = norm_layer(stem_size, **norm_kwargs) self.act1 = act_layer(inplace=True) self._in_chs = stem_size diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 009c62d3..8847b1de 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d __all__ = ['NASNetALarge'] diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 396e6157..dc9b3e20 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -14,7 +14,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d __all__ = ['PNASNet5Large'] diff --git a/timm/models/res2net.py b/timm/models/res2net.py index da20e7a0..134cf00d 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -8,10 +8,10 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .resnet import ResNet, SEModule +from .resnet import ResNet from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SEModule from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = [] @@ -53,15 +53,16 @@ class Bottle2neck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=26, scale=4, use_se=False, - act_layer=nn.ReLU, norm_layer=None, dilation=1, previous_dilation=1, **_): + cardinality=1, base_width=26, scale=4, dilation=1, first_dilation=None, + act_layer=nn.ReLU, norm_layer=None, attn_layer=None, **_): super(Bottle2neck, self).__init__() self.scale = scale self.is_first = stride > 1 or downsample is not None self.num_scales = max(1, scale - 1) width = int(math.floor(planes * (base_width / 64.0))) * cardinality - outplanes = planes * self.expansion self.width = width + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False) self.bn1 = norm_layer(width * scale) @@ -70,8 +71,8 @@ class Bottle2neck(nn.Module): bns = [] for i in range(self.num_scales): convs.append(nn.Conv2d( - width, width, kernel_size=3, stride=stride, padding=dilation, - dilation=dilation, groups=cardinality, bias=False)) + width, width, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, bias=False)) bns.append(norm_layer(width)) self.convs = nn.ModuleList(convs) self.bns = nn.ModuleList(bns) @@ -81,11 +82,14 @@ class Bottle2neck(nn.Module): self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False) self.bn3 = norm_layer(outplanes) - self.se = SEModule(outplanes, planes // 4) if use_se else None + self.se = attn_layer(outplanes) if attn_layer is not None else None self.relu = act_layer(inplace=True) self.downsample = downsample + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) + def forward(self, x): residual = x diff --git a/timm/models/resnet.py b/timm/models/resnet.py index da755373..5b020272 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -7,14 +7,12 @@ ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered ste """ import math -import torch import torch.nn as nn import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d -from .EcaModule import EcaModule +from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -104,147 +102,179 @@ default_cfgs = { 'ecaresnext26tn_32x4d': _cfg( url='', interpolation='bicubic'), - + 'ecaresnet18': _cfg(), + 'ecaresnet50': _cfg(), } -def _get_padding(kernel_size, stride, dilation=1): +def get_padding(kernel_size, stride, dilation=1): padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 return padding -class SEModule(nn.Module): - - def __init__(self, channels, reduction_channels): - super(SEModule, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.fc1 = nn.Conv2d( - channels, reduction_channels, kernel_size=1, padding=0, bias=True) - self.relu = nn.ReLU(inplace=True) - self.fc2 = nn.Conv2d( - reduction_channels, channels, kernel_size=1, padding=0, bias=True) - - def forward(self, x): - x_se = self.avg_pool(x) - x_se = self.fc1(x_se) - x_se = self.relu(x_se) - x_se = self.fc2(x_se) - return x * x_se.sigmoid() - - class BasicBlock(nn.Module): - __constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat expansion = 1 - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, use_eca = False, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + attn_layer=None, drop_block=None, drop_path=None): super(BasicBlock, self).__init__() assert cardinality == 1, 'BasicBlock only supports cardinality of 1' assert base_width == 64, 'BasicBlock doest not support changing base width' first_planes = planes // reduce_first outplanes = planes * self.expansion + first_dilation = first_dilation or dilation self.conv1 = nn.Conv2d( - inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation, - dilation=dilation, bias=False) + inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) self.conv2 = nn.Conv2d( - first_planes, outplanes, kernel_size=3, padding=previous_dilation, - dilation=previous_dilation, bias=False) + first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) self.bn2 = norm_layer(outplanes) - self.se = SEModule(outplanes, planes // 4) if use_se else None - self.eca = EcaModule(outplanes) if use_eca else None + self.se = create_attn(attn_layer, outplanes) self.act2 = act_layer(inplace=True) self.downsample = downsample self.stride = stride self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn2.weight) def forward(self, x): residual = x - out = self.conv1(x) - out = self.bn1(out) - out = self.act1(out) - out = self.conv2(out) - out = self.bn2(out) + x = self.conv1(x) + x = self.bn1(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.bn2(x) + if self.drop_block is not None: + x = self.drop_block(x) if self.se is not None: - out = self.se(out) - if self.eca is not None: - out = self.eca(out) + x = self.se(x) - if self.downsample is not None: - residual = self.downsample(x) + if self.drop_path is not None: + x = self.drop_path(x) - out += residual - out = self.act2(out) + if self.downsample is not None: + residual = self.downsample(residual) + x += residual + x = self.act2(x) - return out + return x class Bottleneck(nn.Module): __constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat expansion = 4 - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, use_eca=False, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + attn_layer=None, drop_block=None, drop_path=None): super(Bottleneck, self).__init__() width = int(math.floor(planes * (base_width / 64)) * cardinality) first_planes = width // reduce_first outplanes = planes * self.expansion + first_dilation = first_dilation or dilation self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) self.conv2 = nn.Conv2d( first_planes, width, kernel_size=3, stride=stride, - padding=dilation, dilation=dilation, groups=cardinality, bias=False) + padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) self.bn2 = norm_layer(width) self.act2 = act_layer(inplace=True) self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) self.bn3 = norm_layer(outplanes) - self.se = SEModule(outplanes, planes // 4) if use_se else None - self.eca = Eca_Module(outplanes) if use_eca else None - + self.se = create_attn(attn_layer, outplanes) + self.act3 = act_layer(inplace=True) self.downsample = downsample self.stride = stride self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) def forward(self, x): residual = x - out = self.conv1(x) - out = self.bn1(out) - out = self.act1(out) + x = self.conv1(x) + x = self.bn1(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act1(x) - out = self.conv2(out) - out = self.bn2(out) - out = self.act2(out) + x = self.conv2(x) + x = self.bn2(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act2(x) - out = self.conv3(out) - out = self.bn3(out) + x = self.conv3(x) + x = self.bn3(x) + if self.drop_block is not None: + x = self.drop_block(x) if self.se is not None: - out = self.se(out) - if self.eca is not None: - out = self.eca(out) + x = self.se(x) + + if self.drop_path is not None: + x = self.drop_path(x) if self.downsample is not None: - residual = self.downsample(x) + residual = self.downsample(residual) + x += residual + x = self.act3(x) + + return x + + +def downsample_conv( + in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): + norm_layer = norm_layer or nn.BatchNorm2d + kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size + first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1 + p = get_padding(kernel_size, stride, first_dilation) + + return nn.Sequential(*[ + nn.Conv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=p, dilation=first_dilation, bias=False), + norm_layer(out_channels) + ]) + - out += residual - out = self.act3(out) +def downsample_avg( + in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): + norm_layer = norm_layer or nn.BatchNorm2d + avg_stride = stride if dilation == 1 else 1 + if stride == 1 and dilation == 1: + pool = nn.Identity() + else: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) - return out + return nn.Sequential(*[ + pool, + nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False), + norm_layer(out_channels) + ]) class ResNet(nn.Module): @@ -288,10 +318,6 @@ class ResNet(nn.Module): Number of classification classes. in_chans : int, default 3 Number of input (color) channels. - use_se : bool, default False - Enable Squeeze-Excitation module in blocks - use_eca : bool, default False - Enable ECA module in blocks cardinality : int, default 1 Number of convolution groups for 3x3 conv in Bottleneck. base_width : int, default 64 @@ -320,11 +346,11 @@ class ResNet(nn.Module): global_pool : str, default 'avg' Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' """ - def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False, use_eca=False, + def __init__(self, block, layers, num_classes=1000, in_chans=3, cardinality=1, base_width=64, stem_width=64, stem_type='', block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg', - zero_init_last_bn=True, block_args=None): + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0., + drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None): block_args = block_args or dict() self.num_classes = num_classes deep_stem = 'deep' in stem_type @@ -356,6 +382,9 @@ class ResNet(nn.Module): self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Feature Blocks + dp = DropPath(drop_path_rate) if drop_block_rate else None + db_3 = DropBlock2d(drop_block_rate, 7, 0.25) if drop_block_rate else None + db_4 = DropBlock2d(drop_block_rate, 7, 1.00) if drop_block_rate else None channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4 if output_stride == 16: strides[3] = 1 @@ -365,61 +394,47 @@ class ResNet(nn.Module): dilations[2:4] = [2, 4] else: assert output_stride == 32 - llargs = list(zip(channels, layers, strides, dilations)) - lkwargs = dict( - use_se=use_se, use_eca=use_eca, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, - avg_down=avg_down, down_kernel_size=down_kernel_size, **block_args) - self.layer1 = self._make_layer(block, *llargs[0], **lkwargs) - self.layer2 = self._make_layer(block, *llargs[1], **lkwargs) - self.layer3 = self._make_layer(block, *llargs[2], **lkwargs) - self.layer4 = self._make_layer(block, *llargs[3], **lkwargs) + layer_args = list(zip(channels, layers, strides, dilations)) + layer_kwargs = dict( + reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, + avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args) + self.layer1 = self._make_layer(block, *layer_args[0], **layer_kwargs) + self.layer2 = self._make_layer(block, *layer_args[1], **layer_kwargs) + self.layer3 = self._make_layer(block, drop_block=db_3, *layer_args[2], **layer_kwargs) + self.layer4 = self._make_layer(block, drop_block=db_4, *layer_args[3], **layer_kwargs) # Head (Pooling and Classifier) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.num_features = 512 * block.expansion self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) - last_bn_name = 'bn3' if 'Bottle' in block.__name__ else 'bn2' for n, m in self.named_modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, nn.BatchNorm2d): - if zero_init_last_bn and 'layer' in n and last_bn_name in n: - # Initialize weight/gamma of last BN in each residual block to zero - nn.init.constant_(m.weight, 0.) - else: - nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.weight, 1.) nn.init.constant_(m.bias, 0.) + if zero_init_last_bn: + for m in self.modules(): + if hasattr(m, 'zero_init_last_bn'): + m.zero_init_last_bn() def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1, - use_se=False, use_eca=False,avg_down=False, down_kernel_size=1, **kwargs): - norm_layer = kwargs.get('norm_layer') + avg_down=False, down_kernel_size=1, **kwargs): downsample = None - down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size + first_dilation = 1 if dilation in (1, 2) else 2 if stride != 1 or self.inplanes != planes * block.expansion: - downsample_padding = _get_padding(down_kernel_size, stride) - downsample_layers = [] - conv_stride = stride - if avg_down: - avg_stride = stride if dilation == 1 else 1 - conv_stride = 1 - downsample_layers = [nn.AvgPool2d(avg_stride, avg_stride, ceil_mode=True, count_include_pad=False)] - downsample_layers += [ - nn.Conv2d(self.inplanes, planes * block.expansion, down_kernel_size, - stride=conv_stride, padding=downsample_padding, bias=False), - norm_layer(planes * block.expansion)] - downsample = nn.Sequential(*downsample_layers) + downsample_args = dict( + in_channels=self.inplanes, out_channels=planes * block.expansion, kernel_size=down_kernel_size, + stride=stride, dilation=dilation, first_dilation=first_dilation, norm_layer=kwargs.get('norm_layer')) + downsample = downsample_avg(**downsample_args) if avg_down else downsample_conv(**downsample_args) - first_dilation = 1 if dilation in (1, 2) else 2 - bkwargs = dict( + block_kwargs = dict( cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first, - use_se=use_se, use_eca=use_eca, **kwargs) - layers = [block( - self.inplanes, planes, stride, downsample, dilation=first_dilation, previous_dilation=dilation, **bkwargs)] + dilation=dilation, **kwargs) + layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)] self.inplanes = planes * block.expansion - for i in range(1, blocks): - layers.append(block( - self.inplanes, planes, dilation=dilation, previous_dilation=dilation, **bkwargs)) + layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)] return nn.Sequential(*layers) @@ -447,8 +462,8 @@ class ResNet(nn.Module): def forward(self, x): x = self.forward_features(x) x = self.global_pool(x).flatten(1) - if self.drop_rate > 0.: - x = F.dropout(x, p=self.drop_rate, training=self.training) + if self.drop_rate: + x = F.dropout(x, p=float(self.drop_rate), training=self.training) x = self.fc(x) return x @@ -920,9 +935,8 @@ def seresnext26d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs) """ default_cfg = default_cfgs['seresnext26d_32x4d'] model = ResNet( - Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, - stem_width=32, stem_type='deep', avg_down=True, use_se=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) + Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, stem_type='deep', avg_down=True, + num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -938,8 +952,8 @@ def seresnext26t_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs) default_cfg = default_cfgs['seresnext26t_32x4d'] model = ResNet( Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, - stem_width=32, stem_type='deep_tiered', avg_down=True, use_se=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) + stem_width=32, stem_type='deep_tiered', avg_down=True, + num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -955,25 +969,55 @@ def seresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs default_cfg = default_cfgs['seresnext26tn_32x4d'] model = ResNet( Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, - stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, use_se=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) + stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, + num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) return model + @register_model def ecaresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a eca-ResNeXt-26-TN model. + """Constructs an ECA-ResNeXt-26-TN model. This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant. this model replaces SE module with the ECA module """ default_cfg = default_cfgs['ecaresnext26tn_32x4d'] + block_args = dict(attn_layer='eca') model = ResNet( Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, - stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, use_eca=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) + stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, + num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def ecaresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ Constructs an ECA-ResNet-18 model. + """ + default_cfg = default_cfgs['ecaresnet18'] + block_args = dict(attn_layer='eca') + model = ResNet( + BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def ecaresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs an ECA-ResNet-50 model. + """ + default_cfg = default_cfgs['ecaresnet50'] + block_args = dict(attn_layer='eca') + model = ResNet( + Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index 17796700..2f369e99 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this diff --git a/timm/models/senet.py b/timm/models/senet.py index 90ef5ae1..efbf4657 100644 --- a/timm/models/senet.py +++ b/timm/models/senet.py @@ -16,7 +16,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = ['SENet'] diff --git a/timm/models/sknet.py b/timm/models/sknet.py new file mode 100644 index 00000000..6db37da5 --- /dev/null +++ b/timm/models/sknet.py @@ -0,0 +1,240 @@ +import math + +from torch import nn as nn + +from .registry import register_model +from .helpers import load_pretrained +from .layers import SelectiveKernelConv, ConvBnAct, create_attn +from .resnet import ResNet +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'skresnet18': _cfg(url=''), + 'skresnet26d': _cfg(), + 'skresnet50': _cfg(), + 'skresnet50d': _cfg(), + 'skresnext50_32x4d': _cfg(), +} + + +class SelectiveKernelBasic(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, + sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, + drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None): + super(SelectiveKernelBasic, self).__init__() + + sk_kwargs = sk_kwargs or {} + conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer) + assert cardinality == 1, 'BasicBlock only supports cardinality of 1' + assert base_width == 64, 'BasicBlock doest not support changing base width' + first_planes = planes // reduce_first + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation + + _selective_first = True # FIXME temporary, for experiments + if _selective_first: + self.conv1 = SelectiveKernelConv( + inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs) + conv_kwargs['act_layer'] = None + self.conv2 = ConvBnAct( + first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs) + else: + self.conv1 = ConvBnAct( + inplanes, first_planes, kernel_size=3, stride=stride, dilation=first_dilation, **conv_kwargs) + conv_kwargs['act_layer'] = None + self.conv2 = SelectiveKernelConv( + first_planes, outplanes, dilation=dilation, **conv_kwargs, **sk_kwargs) + self.se = create_attn(attn_layer, outplanes) + self.act = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.conv2.bn.weight) + + def forward(self, x): + residual = x + x = self.conv1(x) + x = self.conv2(x) + if self.se is not None: + x = self.se(x) + if self.drop_path is not None: + x = self.drop_path(x) + if self.downsample is not None: + residual = self.downsample(residual) + x += residual + x = self.act(x) + return x + + +class SelectiveKernelBottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + cardinality=1, base_width=64, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, + drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None): + super(SelectiveKernelBottleneck, self).__init__() + + sk_kwargs = sk_kwargs or {} + conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer) + width = int(math.floor(planes * (base_width / 64)) * cardinality) + first_planes = width // reduce_first + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation + + self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs) + self.conv2 = SelectiveKernelConv( + first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, + **conv_kwargs, **sk_kwargs) + conv_kwargs['act_layer'] = None + self.conv3 = ConvBnAct(width, outplanes, kernel_size=1, **conv_kwargs) + self.se = create_attn(attn_layer, outplanes) + self.act = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.conv3.bn.weight) + + def forward(self, x): + residual = x + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + if self.se is not None: + x = self.se(x) + if self.drop_path is not None: + x = self.drop_path(x) + if self.downsample is not None: + residual = self.downsample(residual) + x += residual + x = self.act(x) + return x + + +@register_model +def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-18 model. + """ + default_cfg = default_cfgs['skresnet18'] + sk_kwargs = dict( + min_attn_channels=16, + ) + model = ResNet( + SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, + block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-18 model. + """ + default_cfg = default_cfgs['skresnet18'] + sk_kwargs = dict( + min_attn_channels=16, + attn_reduction=8, + split_input=True + ) + model = ResNet( + SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, + block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-26 model. + """ + default_cfg = default_cfgs['skresnet26d'] + sk_kwargs = dict( + keep_3x3=False, + ) + model = ResNet( + SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, + num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False + **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a Select Kernel ResNet-50 model. + Based on config in "Compounding the Performance Improvements of Assembled Techniques in a + Convolutional Neural Network" + """ + sk_kwargs = dict( + attn_reduction=2, + ) + default_cfg = default_cfgs['skresnet50'] + model = ResNet( + SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, + block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a Select Kernel ResNet-50-D model. + Based on config in "Compounding the Performance Improvements of Assembled Techniques in a + Convolutional Neural Network" + """ + sk_kwargs = dict( + attn_reduction=2, + ) + default_cfg = default_cfgs['skresnet50d'] + model = ResNet( + SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), + zero_init_last_bn=False, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to + the SKNet50 model in the Select Kernel Paper + """ + default_cfg = default_cfgs['skresnext50_32x4d'] + model = ResNet( + SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, + num_classes=num_classes, in_chans=in_chans, zero_init_last_bn=False, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model diff --git a/timm/models/xception.py b/timm/models/xception.py index 2dc334fa..cb98bbc9 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -29,7 +29,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d __all__ = ['Xception']