diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 7261fe10..c5dcacd3 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -27,7 +27,8 @@ from .efficientnet_builder import * from .feature_hooks import FeatureHooks from .registry import register_model from .helpers import load_pretrained -from .layers import SelectAdaptivePool2d, select_conv2d +from .layers import SelectAdaptivePool2d +from timm.models.layers import select_conv2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index 78d451be..a231fa31 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -1,11 +1,8 @@ - -from functools import partial - import torch import torch.nn as nn -import torch.nn.functional as F +from torch.nn import functional as F from .layers.activations import sigmoid -from .layers.conv2d_layers import * +from .layers import select_conv2d # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per @@ -72,7 +69,7 @@ def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): return make_divisible(channels, divisor, channel_min) -def drop_connect(inputs, training=False, drop_connect_rate=0.): +def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.): """Apply drop connect.""" if not training: return inputs @@ -160,7 +157,7 @@ class DepthwiseSeparableConv(nn.Module): norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): super(DepthwiseSeparableConv, self).__init__() norm_kwargs = norm_kwargs or {} - self.has_se = se_ratio is not None and se_ratio > 0. + has_se = se_ratio is not None and se_ratio > 0. self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip self.has_pw_act = pw_act # activation after point-wise conv self.drop_connect_rate = drop_connect_rate @@ -171,9 +168,11 @@ class DepthwiseSeparableConv(nn.Module): self.act1 = act_layer(inplace=True) # Squeeze-and-excitation - if self.has_se: + if has_se: se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = None self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) self.bn2 = norm_layer(out_chs, **norm_kwargs) @@ -193,7 +192,7 @@ class DepthwiseSeparableConv(nn.Module): x = self.bn1(x) x = self.act1(x) - if self.has_se: + if self.se is not None: x = self.se(x) x = self.conv_pw(x) @@ -219,7 +218,7 @@ class InvertedResidual(nn.Module): norm_kwargs = norm_kwargs or {} conv_kwargs = conv_kwargs or {} mid_chs = make_divisible(in_chs * exp_ratio) - self.has_se = se_ratio is not None and se_ratio > 0. + has_se = se_ratio is not None and se_ratio > 0. self.has_residual = (in_chs == out_chs and stride == 1) and not noskip self.drop_connect_rate = drop_connect_rate @@ -236,9 +235,11 @@ class InvertedResidual(nn.Module): self.act2 = act_layer(inplace=True) # Squeeze-and-excitation - if self.has_se: + if has_se: se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = None # Point-wise linear projection self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) @@ -269,7 +270,7 @@ class InvertedResidual(nn.Module): x = self.act2(x) # Squeeze-and-excitation - if self.has_se: + if self.se is not None: x = self.se(x) # Point-wise linear projection @@ -323,7 +324,7 @@ class CondConvResidual(InvertedResidual): x = self.act2(x) # Squeeze-and-excitation - if self.has_se: + if self.se is not None: x = self.se(x) # Point-wise linear projection @@ -350,7 +351,7 @@ class EdgeResidual(nn.Module): mid_chs = make_divisible(fake_in_chs * exp_ratio) else: mid_chs = make_divisible(in_chs * exp_ratio) - self.has_se = se_ratio is not None and se_ratio > 0. + has_se = se_ratio is not None and se_ratio > 0. self.has_residual = (in_chs == out_chs and stride == 1) and not noskip self.drop_connect_rate = drop_connect_rate @@ -360,9 +361,11 @@ class EdgeResidual(nn.Module): self.act1 = act_layer(inplace=True) # Squeeze-and-excitation - if self.has_se: + if has_se: se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = None # Point-wise linear projection self.conv_pwl = select_conv2d( @@ -389,7 +392,7 @@ class EdgeResidual(nn.Module): x = self.act1(x) # Squeeze-and-excitation - if self.has_se: + if self.se is not None: x = self.se(x) # Point-wise linear projection diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index b159eefe..954420fb 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -5,7 +5,8 @@ from collections.__init__ import OrderedDict from copy import deepcopy import torch.nn as nn -from .layers.activations import sigmoid, HardSwish, Swish +from .layers import CondConv2d, get_condconv_initializer +from .layers.activations import HardSwish, Swish from .efficientnet_blocks import * diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 8e9fcae2..79aa9ac2 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -1,8 +1,12 @@ -from .conv2d_layers import select_conv2d, MixedConv2d, CondConv2d, ConvBnAct, SelectiveKernelConv +from .conv_bn_act import ConvBnAct +from .mixed_conv2d import MixedConv2d +from .cond_conv2d import CondConv2d, get_condconv_initializer +from .select_conv2d import select_conv2d +from .selective_kernel import SelectiveKernelConv from .eca import EcaModule, CecaModule from .activations import * from .adaptive_avgmax_pool import \ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d -from .nn_ops import DropBlock2d, DropPath +from .drop import DropBlock2d, DropPath from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model diff --git a/timm/models/layers/activations.py b/timm/models/layers/activations.py index aafa290c..165b7951 100644 --- a/timm/models/layers/activations.py +++ b/timm/models/layers/activations.py @@ -1,9 +1,18 @@ +""" Activations + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +Hacked together by Ross Wightman +""" + + import torch from torch import nn as nn from torch.nn import functional as F -_USE_MEM_EFFICIENT_ISH = True +_USE_MEM_EFFICIENT_ISH = False if _USE_MEM_EFFICIENT_ISH: # This version reduces memory overhead of Swish during training by # recomputing torch.sigmoid(x) in backward instead of saving it. @@ -66,20 +75,20 @@ if _USE_MEM_EFFICIENT_ISH: return MishJitAutoFn.apply(x) else: - def swish(x, inplace=False): + def swish(x, inplace: bool = False): """Swish - Described in: https://arxiv.org/abs/1710.05941 """ return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) - def mish(x, _inplace=False): + def mish(x, _inplace: bool = False): """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 """ return x.mul(F.softplus(x).tanh()) class Swish(nn.Module): - def __init__(self, inplace=False): + def __init__(self, inplace: bool = False): super(Swish, self).__init__() self.inplace = inplace @@ -88,7 +97,7 @@ class Swish(nn.Module): class Mish(nn.Module): - def __init__(self, inplace=False): + def __init__(self, inplace: bool = False): super(Mish, self).__init__() self.inplace = inplace @@ -96,13 +105,13 @@ class Mish(nn.Module): return mish(x, self.inplace) -def sigmoid(x, inplace=False): +def sigmoid(x, inplace: bool = False): return x.sigmoid_() if inplace else x.sigmoid() # PyTorch has this, but not with a consistent inplace argmument interface class Sigmoid(nn.Module): - def __init__(self, inplace=False): + def __init__(self, inplace: bool = False): super(Sigmoid, self).__init__() self.inplace = inplace @@ -110,13 +119,13 @@ class Sigmoid(nn.Module): return x.sigmoid_() if self.inplace else x.sigmoid() -def tanh(x, inplace=False): +def tanh(x, inplace: bool = False): return x.tanh_() if inplace else x.tanh() # PyTorch has this, but not with a consistent inplace argmument interface class Tanh(nn.Module): - def __init__(self, inplace=False): + def __init__(self, inplace: bool = False): super(Tanh, self).__init__() self.inplace = inplace @@ -124,13 +133,13 @@ class Tanh(nn.Module): return x.tanh_() if self.inplace else x.tanh() -def hard_swish(x, inplace=False): +def hard_swish(x, inplace: bool = False): inner = F.relu6(x + 3.).div_(6.) return x.mul_(inner) if inplace else x.mul(inner) class HardSwish(nn.Module): - def __init__(self, inplace=False): + def __init__(self, inplace: bool = False): super(HardSwish, self).__init__() self.inplace = inplace @@ -138,7 +147,7 @@ class HardSwish(nn.Module): return hard_swish(x, self.inplace) -def hard_sigmoid(x, inplace=False): +def hard_sigmoid(x, inplace: bool = False): if inplace: return x.add_(3.).clamp_(0., 6.).div_(6.) else: @@ -146,7 +155,7 @@ def hard_sigmoid(x, inplace=False): class HardSigmoid(nn.Module): - def __init__(self, inplace=False): + def __init__(self, inplace: bool = False): super(HardSigmoid, self).__init__() self.inplace = inplace diff --git a/timm/models/layers/cond_conv2d.py b/timm/models/layers/cond_conv2d.py new file mode 100644 index 00000000..d6cba889 --- /dev/null +++ b/timm/models/layers/cond_conv2d.py @@ -0,0 +1,118 @@ +""" Conditional Convolution + +Hacked together by Ross Wightman +""" + +import math +from functools import partial +import numpy as np +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .conv2d_same import get_padding_value, conv2d_same +from .conv_helpers import tup_pair + + +def get_condconv_initializer(initializer, num_experts, expert_shape): + def condconv_initializer(weight): + """CondConv initializer function.""" + num_params = np.prod(expert_shape) + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or + weight.shape[1] != num_params): + raise (ValueError( + 'CondConv variables must have shape [num_experts, num_params]')) + for i in range(num_experts): + initializer(weight[i].view(expert_shape)) + return condconv_initializer + + +class CondConv2d(nn.Module): + """ Conditional Convolution + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py + + Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: + https://github.com/pytorch/pytorch/issues/17983 + """ + __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding'] + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): + super(CondConv2d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = tup_pair(kernel_size) + self.stride = tup_pair(stride) + padding_val, is_padding_dynamic = get_padding_value( + padding, kernel_size, stride=stride, dilation=dilation) + self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript + self.padding = tup_pair(padding_val) + self.dilation = tup_pair(dilation) + self.groups = groups + self.num_experts = num_experts + + self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight_num_param = 1 + for wd in self.weight_shape: + weight_num_param *= wd + self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) + + if bias: + self.bias_shape = (self.out_channels,) + self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + init_weight = get_condconv_initializer( + partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) + init_weight(self.weight) + if self.bias is not None: + fan_in = np.prod(self.weight_shape[1:]) + bound = 1 / math.sqrt(fan_in) + init_bias = get_condconv_initializer( + partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) + init_bias(self.bias) + + def forward(self, x, routing_weights): + B, C, H, W = x.shape + weight = torch.matmul(routing_weights, self.weight) + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight = weight.view(new_weight_shape) + bias = None + if self.bias is not None: + bias = torch.matmul(routing_weights, self.bias) + bias = bias.view(B * self.out_channels) + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel + x = x.view(1, B * C, H, W) + if self.dynamic_padding: + out = conv2d_same( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + else: + out = F.conv2d( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) + + # Literal port (from TF definition) + # x = torch.split(x, 1, 0) + # weight = torch.split(weight, 1, 0) + # if self.bias is not None: + # bias = torch.matmul(routing_weights, self.bias) + # bias = torch.split(bias, 1, 0) + # else: + # bias = [None] * B + # out = [] + # for xi, wi, bi in zip(x, weight, bias): + # wi = wi.view(*self.weight_shape) + # if bi is not None: + # bi = bi.view(*self.bias_shape) + # out.append(self.conv_fn( + # xi, wi, bi, stride=self.stride, padding=self.padding, + # dilation=self.dilation, groups=self.groups)) + # out = torch.cat(out, 0) + return out diff --git a/timm/models/layers/conv2d_layers.py b/timm/models/layers/conv2d_layers.py deleted file mode 100644 index feaf653c..00000000 --- a/timm/models/layers/conv2d_layers.py +++ /dev/null @@ -1,361 +0,0 @@ -from collections import OrderedDict - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch._six import container_abcs -from itertools import repeat -from functools import partial -import numpy as np -import math - - -# Tuple helpers ripped from PyTorch -def _ntuple(n): - def parse(x): - if isinstance(x, container_abcs.Iterable): - return x - return tuple(repeat(x, n)) - return parse - - -_single = _ntuple(1) -_pair = _ntuple(2) -_triple = _ntuple(3) -_quadruple = _ntuple(4) - - -def _is_static_pad(kernel_size, stride=1, dilation=1, **_): - return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 - - -def _get_padding(kernel_size, stride=1, dilation=1, **_): - padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 - return padding - - -def _calc_same_pad(i, k, s, d): - return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) - - -def _split_channels(num_chan, num_groups): - split = [num_chan // num_groups for _ in range(num_groups)] - split[0] += num_chan - sum(split) - return split - - -# pylint: disable=unused-argument -def conv2d_same(x, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1): - ih, iw = x.size()[-2:] - kh, kw = weight.size()[-2:] - pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) - pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) - if pad_h > 0 or pad_w > 0: - x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) - return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) - - -class Conv2dSame(nn.Conv2d): - """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions - """ - - # pylint: disable=unused-argument - def __init__(self, in_channels, out_channels, kernel_size, stride=1, - padding=0, dilation=1, groups=1, bias=True): - super(Conv2dSame, self).__init__( - in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) - - def forward(self, x): - return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) - - -def get_padding_value(padding, kernel_size, **kwargs): - dynamic = False - if isinstance(padding, str): - # for any string padding, the padding will be calculated for you, one of three ways - padding = padding.lower() - if padding == 'same': - # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact - if _is_static_pad(kernel_size, **kwargs): - # static case, no extra overhead - padding = _get_padding(kernel_size, **kwargs) - else: - # dynamic 'SAME' padding, has runtime/GPU memory overhead - padding = 0 - dynamic = True - elif padding == 'valid': - # 'VALID' padding, same as padding=0 - padding = 0 - else: - # Default to PyTorch style 'same'-ish symmetric padding - padding = _get_padding(kernel_size, **kwargs) - return padding, dynamic - - -def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): - padding = kwargs.pop('padding', '') - kwargs.setdefault('bias', False) - padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) - if is_dynamic: - return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) - else: - return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) - - -class MixedConv2d(nn.ModuleDict): - """ Mixed Grouped Convolution - Based on MDConv and GroupedConv in MixNet impl: - https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py - """ - def __init__(self, in_channels, out_channels, kernel_size=3, - stride=1, padding='', dilation=1, depthwise=False, **kwargs): - super(MixedConv2d, self).__init__() - - kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] - num_groups = len(kernel_size) - in_splits = _split_channels(in_channels, num_groups) - out_splits = _split_channels(out_channels, num_groups) - self.in_channels = sum(in_splits) - self.out_channels = sum(out_splits) - for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): - conv_groups = out_ch if depthwise else 1 - # use add_module to keep key space clean - self.add_module( - str(idx), - create_conv2d_pad( - in_ch, out_ch, k, stride=stride, - padding=padding, dilation=dilation, groups=conv_groups, **kwargs) - ) - self.splits = in_splits - - def forward(self, x): - x_split = torch.split(x, self.splits, 1) - x_out = [c(x_split[i]) for i, c in enumerate(self.values())] - x = torch.cat(x_out, 1) - return x - - -def get_condconv_initializer(initializer, num_experts, expert_shape): - def condconv_initializer(weight): - """CondConv initializer function.""" - num_params = np.prod(expert_shape) - if (len(weight.shape) != 2 or weight.shape[0] != num_experts or - weight.shape[1] != num_params): - raise (ValueError( - 'CondConv variables must have shape [num_experts, num_params]')) - for i in range(num_experts): - initializer(weight[i].view(expert_shape)) - return condconv_initializer - - -class CondConv2d(nn.Module): - """ Conditional Convolution - Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py - - Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: - https://github.com/pytorch/pytorch/issues/17983 - """ - __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding'] - - def __init__(self, in_channels, out_channels, kernel_size=3, - stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): - super(CondConv2d, self).__init__() - - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = _pair(kernel_size) - self.stride = _pair(stride) - padding_val, is_padding_dynamic = get_padding_value( - padding, kernel_size, stride=stride, dilation=dilation) - self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript - self.padding = _pair(padding_val) - self.dilation = _pair(dilation) - self.groups = groups - self.num_experts = num_experts - - self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size - weight_num_param = 1 - for wd in self.weight_shape: - weight_num_param *= wd - self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) - - if bias: - self.bias_shape = (self.out_channels,) - self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) - else: - self.register_parameter('bias', None) - - self.reset_parameters() - - def reset_parameters(self): - init_weight = get_condconv_initializer( - partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) - init_weight(self.weight) - if self.bias is not None: - fan_in = np.prod(self.weight_shape[1:]) - bound = 1 / math.sqrt(fan_in) - init_bias = get_condconv_initializer( - partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) - init_bias(self.bias) - - def forward(self, x, routing_weights): - B, C, H, W = x.shape - weight = torch.matmul(routing_weights, self.weight) - new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size - weight = weight.view(new_weight_shape) - bias = None - if self.bias is not None: - bias = torch.matmul(routing_weights, self.bias) - bias = bias.view(B * self.out_channels) - # move batch elements with channels so each batch element can be efficiently convolved with separate kernel - x = x.view(1, B * C, H, W) - if self.dynamic_padding: - out = conv2d_same( - x, weight, bias, stride=self.stride, padding=self.padding, - dilation=self.dilation, groups=self.groups * B) - else: - out = F.conv2d( - x, weight, bias, stride=self.stride, padding=self.padding, - dilation=self.dilation, groups=self.groups * B) - out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) - - # Literal port (from TF definition) - # x = torch.split(x, 1, 0) - # weight = torch.split(weight, 1, 0) - # if self.bias is not None: - # bias = torch.matmul(routing_weights, self.bias) - # bias = torch.split(bias, 1, 0) - # else: - # bias = [None] * B - # out = [] - # for xi, wi, bi in zip(x, weight, bias): - # wi = wi.view(*self.weight_shape) - # if bi is not None: - # bi = bi.view(*self.bias_shape) - # out.append(self.conv_fn( - # xi, wi, bi, stride=self.stride, padding=self.padding, - # dilation=self.dilation, groups=self.groups)) - # out = torch.cat(out, 0) - return out - - -class SelectiveKernelAttn(nn.Module): - def __init__(self, channels, num_paths=2, attn_channels=32, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(SelectiveKernelAttn, self).__init__() - self.num_paths = num_paths - self.pool = nn.AdaptiveAvgPool2d(1) - self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) - self.bn = norm_layer(attn_channels) - self.act = act_layer(inplace=True) - self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) - - def forward(self, x): - assert x.shape[1] == self.num_paths - x = torch.sum(x, dim=1) - x = self.pool(x) - x = self.fc_reduce(x) - x = self.bn(x) - x = self.act(x) - x = self.fc_select(x) - B, C, H, W = x.shape - x = x.view(B, self.num_paths, C // self.num_paths, H, W) - x = torch.softmax(x, dim=1) - return x - - -def _kernel_valid(k): - if isinstance(k, (list, tuple)): - for ki in k: - return _kernel_valid(ki) - assert k >= 3 and k % 2 - - -class ConvBnAct(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1, - drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(ConvBnAct, self).__init__() - padding = _get_padding(kernel_size, stride, dilation) # assuming PyTorch style padding for this block - self.conv = nn.Conv2d( - in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, - padding=padding, dilation=dilation, groups=groups, bias=False) - self.bn = norm_layer(out_channels) - self.drop_block = drop_block - if act_layer is not None: - self.act = act_layer(inplace=True) - else: - self.act = None - - def forward(self, x): - x = self.conv(x) - x = self.bn(x) - if self.drop_block is not None: - x = self.drop_block(x) - if self.act is not None: - x = self.act(x) - return x - - -class SelectiveKernelConv(nn.Module): - - def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1, - attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False, - drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(SelectiveKernelConv, self).__init__() - kernel_size = kernel_size or [3, 5] - _kernel_valid(kernel_size) - if not isinstance(kernel_size, list): - kernel_size = [kernel_size] * 2 - if keep_3x3: - dilation = [dilation * (k - 1) // 2 for k in kernel_size] - kernel_size = [3] * len(kernel_size) - else: - dilation = [dilation] * len(kernel_size) - self.num_paths = len(kernel_size) - self.in_channels = in_channels - self.out_channels = out_channels - self.split_input = split_input - if self.split_input: - assert in_channels % self.num_paths == 0 - in_channels = in_channels // self.num_paths - groups = min(out_channels, groups) - - conv_kwargs = dict( - stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer) - self.paths = nn.ModuleList([ - ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) - for k, d in zip(kernel_size, dilation)]) - - attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) - self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) - self.drop_block = drop_block - - def forward(self, x): - if self.split_input: - x_split = torch.split(x, self.in_channels // self.num_paths, 1) - x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] - else: - x_paths = [op(x) for op in self.paths] - x = torch.stack(x_paths, dim=1) - x_attn = self.attn(x) - x = x * x_attn - x = torch.sum(x, dim=1) - return x - - -# helper method -def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): - assert 'groups' not in kwargs # only use 'depthwise' bool arg - if isinstance(kernel_size, list): - assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently - # We're going to use only lists for defining the MixedConv2d kernel groups, - # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. - m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) - else: - depthwise = kwargs.pop('depthwise', False) - groups = out_chs if depthwise else 1 - if 'num_experts' in kwargs and kwargs['num_experts'] > 0: - m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs) - else: - m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) - return m diff --git a/timm/models/layers/conv2d_same.py b/timm/models/layers/conv2d_same.py new file mode 100644 index 00000000..579757b8 --- /dev/null +++ b/timm/models/layers/conv2d_same.py @@ -0,0 +1,79 @@ +""" Conv2d w/ Same Padding + +Hacked together by Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Union, List, Tuple, Optional, Callable +import math + +from .conv_helpers import get_padding + + +def _is_static_pad(kernel_size, stride=1, dilation=1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +def _calc_same_pad(i: int, k: int, s: int, d: int): + return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) + + +def conv2d_same( + x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): + ih, iw = x.size()[-2:] + kh, kw = weight.size()[-2:] + pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) + pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) + + +class Conv2dSame(nn.Conv2d): + """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSame, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + + def forward(self, x): + return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == 'same': + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if _is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = get_padding(kernel_size, **kwargs) + else: + # dynamic 'SAME' padding, has runtime/GPU memory overhead + padding = 0 + dynamic = True + elif padding == 'valid': + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = get_padding(kernel_size, **kwargs) + return padding, dynamic + + +def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): + padding = kwargs.pop('padding', '') + kwargs.setdefault('bias', False) + padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) + if is_dynamic: + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + else: + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + + diff --git a/timm/models/layers/conv_bn_act.py b/timm/models/layers/conv_bn_act.py new file mode 100644 index 00000000..a10c1d38 --- /dev/null +++ b/timm/models/layers/conv_bn_act.py @@ -0,0 +1,32 @@ +""" Conv2d + BN + Act + +Hacked together by Ross Wightman +""" +from torch import nn as nn + +from timm.models.layers.conv_helpers import get_padding + + +class ConvBnAct(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1, + drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(ConvBnAct, self).__init__() + padding = get_padding(kernel_size, stride, dilation) # assuming PyTorch style padding for this block + self.conv = nn.Conv2d( + in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups, bias=False) + self.bn = norm_layer(out_channels) + self.drop_block = drop_block + if act_layer is not None: + self.act = act_layer(inplace=True) + else: + self.act = None + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.drop_block is not None: + x = self.drop_block(x) + if self.act is not None: + x = self.act(x) + return x diff --git a/timm/models/layers/conv_helpers.py b/timm/models/layers/conv_helpers.py new file mode 100644 index 00000000..3f8b160e --- /dev/null +++ b/timm/models/layers/conv_helpers.py @@ -0,0 +1,27 @@ +""" Common Helpers + +Hacked together by Ross Wightman +""" +from itertools import repeat +from torch._six import container_abcs + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, container_abcs.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +tup_single = _ntuple(1) +tup_pair = _ntuple(2) +tup_triple = _ntuple(3) +tup_quadruple = _ntuple(4) + + +# Calculate symmetric padding for a convolution +def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding diff --git a/timm/models/layers/nn_ops.py b/timm/models/layers/drop.py similarity index 100% rename from timm/models/layers/nn_ops.py rename to timm/models/layers/drop.py diff --git a/timm/models/layers/mixed_conv2d.py b/timm/models/layers/mixed_conv2d.py new file mode 100644 index 00000000..3e280c03 --- /dev/null +++ b/timm/models/layers/mixed_conv2d.py @@ -0,0 +1,49 @@ +""" Conditional Convolution + +Hacked together by Ross Wightman +""" + +import torch +from torch import nn as nn + +from .conv2d_same import create_conv2d_pad + + +def _split_channels(num_chan, num_groups): + split = [num_chan // num_groups for _ in range(num_groups)] + split[0] += num_chan - sum(split) + return split + + +class MixedConv2d(nn.ModuleDict): + """ Mixed Grouped Convolution + + Based on MDConv and GroupedConv in MixNet impl: + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py + """ + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, depthwise=False, **kwargs): + super(MixedConv2d, self).__init__() + + kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] + num_groups = len(kernel_size) + in_splits = _split_channels(in_channels, num_groups) + out_splits = _split_channels(out_channels, num_groups) + self.in_channels = sum(in_splits) + self.out_channels = sum(out_splits) + for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): + conv_groups = out_ch if depthwise else 1 + # use add_module to keep key space clean + self.add_module( + str(idx), + create_conv2d_pad( + in_ch, out_ch, k, stride=stride, + padding=padding, dilation=dilation, groups=conv_groups, **kwargs) + ) + self.splits = in_splits + + def forward(self, x): + x_split = torch.split(x, self.splits, 1) + x_out = [c(x_split[i]) for i, c in enumerate(self.values())] + x = torch.cat(x_out, 1) + return x diff --git a/timm/models/layers/select_conv2d.py b/timm/models/layers/select_conv2d.py new file mode 100644 index 00000000..a8713b0b --- /dev/null +++ b/timm/models/layers/select_conv2d.py @@ -0,0 +1,30 @@ +""" Select Conv2d Factory Method + +Hacked together by Ross Wightman +""" + +from .mixed_conv2d import MixedConv2d +from .cond_conv2d import CondConv2d +from .conv2d_same import create_conv2d_pad + + +def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): + """ Select a 2d convolution implementation based on arguments + Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. + + Used extensively by EfficientNet, MobileNetv3 and related networks. + """ + assert 'groups' not in kwargs # only use 'depthwise' bool arg + if isinstance(kernel_size, list): + assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently + # We're going to use only lists for defining the MixedConv2d kernel groups, + # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. + m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) + else: + depthwise = kwargs.pop('depthwise', False) + groups = out_chs if depthwise else 1 + if 'num_experts' in kwargs and kwargs['num_experts'] > 0: + m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + else: + m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + return m diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py new file mode 100644 index 00000000..4100aa02 --- /dev/null +++ b/timm/models/layers/selective_kernel.py @@ -0,0 +1,88 @@ +""" Selective Kernel Convolution Attention + +Hacked together by Ross Wightman +""" + +import torch +from torch import nn as nn + +from .conv_bn_act import ConvBnAct + + +def _kernel_valid(k): + if isinstance(k, (list, tuple)): + for ki in k: + return _kernel_valid(ki) + assert k >= 3 and k % 2 + + +class SelectiveKernelAttn(nn.Module): + def __init__(self, channels, num_paths=2, attn_channels=32, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelAttn, self).__init__() + self.num_paths = num_paths + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) + self.bn = norm_layer(attn_channels) + self.act = act_layer(inplace=True) + self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) + + def forward(self, x): + assert x.shape[1] == self.num_paths + x = torch.sum(x, dim=1) + x = self.pool(x) + x = self.fc_reduce(x) + x = self.bn(x) + x = self.act(x) + x = self.fc_select(x) + B, C, H, W = x.shape + x = x.view(B, self.num_paths, C // self.num_paths, H, W) + x = torch.softmax(x, dim=1) + return x + + +class SelectiveKernelConv(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1, + attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False, + drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelConv, self).__init__() + kernel_size = kernel_size or [3, 5] + _kernel_valid(kernel_size) + if not isinstance(kernel_size, list): + kernel_size = [kernel_size] * 2 + if keep_3x3: + dilation = [dilation * (k - 1) // 2 for k in kernel_size] + kernel_size = [3] * len(kernel_size) + else: + dilation = [dilation] * len(kernel_size) + self.num_paths = len(kernel_size) + self.in_channels = in_channels + self.out_channels = out_channels + self.split_input = split_input + if self.split_input: + assert in_channels % self.num_paths == 0 + in_channels = in_channels // self.num_paths + groups = min(out_channels, groups) + + conv_kwargs = dict( + stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer) + self.paths = nn.ModuleList([ + ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) + for k, d in zip(kernel_size, dilation)]) + + attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) + self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) + self.drop_block = drop_block + + def forward(self, x): + if self.split_input: + x_split = torch.split(x, self.in_channels // self.num_paths, 1) + x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] + else: + x_paths = [op(x) for op in self.paths] + x = torch.stack(x_paths, dim=1) + x_attn = self.attn(x) + x = x * x_attn + x = torch.sum(x, dim=1) + return x diff --git a/timm/models/layers/test_time_pool.py b/timm/models/layers/test_time_pool.py index ce6ddf07..33e24970 100644 --- a/timm/models/layers/test_time_pool.py +++ b/timm/models/layers/test_time_pool.py @@ -1,3 +1,8 @@ +""" Test Time Pooling (Average-Max Pool) + +Hacked together by Ross Wightman +""" + import logging from torch import nn import torch.nn.functional as F diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 76f6363c..9d4de856 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -7,8 +7,6 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 Hacked together by Ross Wightman """ -import torch.nn as nn -import torch.nn.functional as F from .efficientnet_builder import * from .registry import register_model