commit
e0685dd415
@ -0,0 +1,10 @@
|
|||||||
|
dependencies = ['torch']
|
||||||
|
|
||||||
|
from timm.models import registry
|
||||||
|
|
||||||
|
current_module = __import__(__name__)
|
||||||
|
current_module.__dict__.update(registry._model_entrypoints)
|
||||||
|
#for fn_name in registry.list_models():
|
||||||
|
# fn = registry.model_entrypoint(fn_name)
|
||||||
|
# setattr(current_module, fn_name, fn)
|
||||||
|
|
|
|
|
|
@ -1,260 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch._six import container_abcs
|
|
||||||
from itertools import repeat
|
|
||||||
from functools import partial
|
|
||||||
import numpy as np
|
|
||||||
import math
|
|
||||||
|
|
||||||
|
|
||||||
# Tuple helpers ripped from PyTorch
|
|
||||||
def _ntuple(n):
|
|
||||||
def parse(x):
|
|
||||||
if isinstance(x, container_abcs.Iterable):
|
|
||||||
return x
|
|
||||||
return tuple(repeat(x, n))
|
|
||||||
return parse
|
|
||||||
|
|
||||||
|
|
||||||
_single = _ntuple(1)
|
|
||||||
_pair = _ntuple(2)
|
|
||||||
_triple = _ntuple(3)
|
|
||||||
_quadruple = _ntuple(4)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
|
|
||||||
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
|
|
||||||
|
|
||||||
|
|
||||||
def _get_padding(kernel_size, stride=1, dilation=1, **_):
|
|
||||||
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
|
||||||
return padding
|
|
||||||
|
|
||||||
|
|
||||||
def _calc_same_pad(i, k, s, d):
|
|
||||||
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
|
|
||||||
|
|
||||||
|
|
||||||
def _split_channels(num_chan, num_groups):
|
|
||||||
split = [num_chan // num_groups for _ in range(num_groups)]
|
|
||||||
split[0] += num_chan - sum(split)
|
|
||||||
return split
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
|
||||||
def conv2d_same(x, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1):
|
|
||||||
ih, iw = x.size()[-2:]
|
|
||||||
kh, kw = weight.size()[-2:]
|
|
||||||
pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
|
|
||||||
pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
|
|
||||||
if pad_h > 0 or pad_w > 0:
|
|
||||||
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
|
|
||||||
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
|
|
||||||
|
|
||||||
|
|
||||||
class Conv2dSame(nn.Conv2d):
|
|
||||||
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
|
|
||||||
"""
|
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
|
||||||
padding=0, dilation=1, groups=1, bias=True):
|
|
||||||
super(Conv2dSame, self).__init__(
|
|
||||||
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
|
||||||
|
|
||||||
|
|
||||||
def get_padding_value(padding, kernel_size, **kwargs):
|
|
||||||
dynamic = False
|
|
||||||
if isinstance(padding, str):
|
|
||||||
# for any string padding, the padding will be calculated for you, one of three ways
|
|
||||||
padding = padding.lower()
|
|
||||||
if padding == 'same':
|
|
||||||
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
|
|
||||||
if _is_static_pad(kernel_size, **kwargs):
|
|
||||||
# static case, no extra overhead
|
|
||||||
padding = _get_padding(kernel_size, **kwargs)
|
|
||||||
else:
|
|
||||||
# dynamic 'SAME' padding, has runtime/GPU memory overhead
|
|
||||||
padding = 0
|
|
||||||
dynamic = True
|
|
||||||
elif padding == 'valid':
|
|
||||||
# 'VALID' padding, same as padding=0
|
|
||||||
padding = 0
|
|
||||||
else:
|
|
||||||
# Default to PyTorch style 'same'-ish symmetric padding
|
|
||||||
padding = _get_padding(kernel_size, **kwargs)
|
|
||||||
return padding, dynamic
|
|
||||||
|
|
||||||
|
|
||||||
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
|
|
||||||
padding = kwargs.pop('padding', '')
|
|
||||||
kwargs.setdefault('bias', False)
|
|
||||||
padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
|
|
||||||
if is_dynamic:
|
|
||||||
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
|
|
||||||
else:
|
|
||||||
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class MixedConv2d(nn.Module):
|
|
||||||
""" Mixed Grouped Convolution
|
|
||||||
Based on MDConv and GroupedConv in MixNet impl:
|
|
||||||
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
|
|
||||||
|
|
||||||
NOTE: This does not currently work with torch.jit.script
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3,
|
|
||||||
stride=1, padding='', dilation=1, depthwise=False, **kwargs):
|
|
||||||
super(MixedConv2d, self).__init__()
|
|
||||||
|
|
||||||
kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
|
|
||||||
num_groups = len(kernel_size)
|
|
||||||
in_splits = _split_channels(in_channels, num_groups)
|
|
||||||
out_splits = _split_channels(out_channels, num_groups)
|
|
||||||
self.in_channels = sum(in_splits)
|
|
||||||
self.out_channels = sum(out_splits)
|
|
||||||
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
|
|
||||||
conv_groups = out_ch if depthwise else 1
|
|
||||||
# use add_module to keep key space clean
|
|
||||||
self.add_module(
|
|
||||||
str(idx),
|
|
||||||
create_conv2d_pad(
|
|
||||||
in_ch, out_ch, k, stride=stride,
|
|
||||||
padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
|
|
||||||
)
|
|
||||||
self.splits = in_splits
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x_split = torch.split(x, self.splits, 1)
|
|
||||||
x_out = [c(x) for x, c in zip(x_split, self._modules.values())]
|
|
||||||
x = torch.cat(x_out, 1)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def get_condconv_initializer(initializer, num_experts, expert_shape):
|
|
||||||
def condconv_initializer(weight):
|
|
||||||
"""CondConv initializer function."""
|
|
||||||
num_params = np.prod(expert_shape)
|
|
||||||
if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
|
|
||||||
weight.shape[1] != num_params):
|
|
||||||
raise (ValueError(
|
|
||||||
'CondConv variables must have shape [num_experts, num_params]'))
|
|
||||||
for i in range(num_experts):
|
|
||||||
initializer(weight[i].view(expert_shape))
|
|
||||||
return condconv_initializer
|
|
||||||
|
|
||||||
|
|
||||||
class CondConv2d(nn.Module):
|
|
||||||
""" Conditional Convolution
|
|
||||||
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
|
|
||||||
|
|
||||||
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
|
|
||||||
https://github.com/pytorch/pytorch/issues/17983
|
|
||||||
"""
|
|
||||||
__constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
|
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3,
|
|
||||||
stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
|
|
||||||
super(CondConv2d, self).__init__()
|
|
||||||
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
self.kernel_size = _pair(kernel_size)
|
|
||||||
self.stride = _pair(stride)
|
|
||||||
padding_val, is_padding_dynamic = get_padding_value(
|
|
||||||
padding, kernel_size, stride=stride, dilation=dilation)
|
|
||||||
self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript
|
|
||||||
self.padding = _pair(padding_val)
|
|
||||||
self.dilation = _pair(dilation)
|
|
||||||
self.groups = groups
|
|
||||||
self.num_experts = num_experts
|
|
||||||
|
|
||||||
self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
|
|
||||||
weight_num_param = 1
|
|
||||||
for wd in self.weight_shape:
|
|
||||||
weight_num_param *= wd
|
|
||||||
self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
|
|
||||||
|
|
||||||
if bias:
|
|
||||||
self.bias_shape = (self.out_channels,)
|
|
||||||
self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
|
|
||||||
else:
|
|
||||||
self.register_parameter('bias', None)
|
|
||||||
|
|
||||||
self.reset_parameters()
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
|
||||||
init_weight = get_condconv_initializer(
|
|
||||||
partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
|
|
||||||
init_weight(self.weight)
|
|
||||||
if self.bias is not None:
|
|
||||||
fan_in = np.prod(self.weight_shape[1:])
|
|
||||||
bound = 1 / math.sqrt(fan_in)
|
|
||||||
init_bias = get_condconv_initializer(
|
|
||||||
partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
|
|
||||||
init_bias(self.bias)
|
|
||||||
|
|
||||||
def forward(self, x, routing_weights):
|
|
||||||
B, C, H, W = x.shape
|
|
||||||
weight = torch.matmul(routing_weights, self.weight)
|
|
||||||
new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
|
|
||||||
weight = weight.view(new_weight_shape)
|
|
||||||
bias = None
|
|
||||||
if self.bias is not None:
|
|
||||||
bias = torch.matmul(routing_weights, self.bias)
|
|
||||||
bias = bias.view(B * self.out_channels)
|
|
||||||
# move batch elements with channels so each batch element can be efficiently convolved with separate kernel
|
|
||||||
x = x.view(1, B * C, H, W)
|
|
||||||
if self.dynamic_padding:
|
|
||||||
out = conv2d_same(
|
|
||||||
x, weight, bias, stride=self.stride, padding=self.padding,
|
|
||||||
dilation=self.dilation, groups=self.groups * B)
|
|
||||||
else:
|
|
||||||
out = F.conv2d(
|
|
||||||
x, weight, bias, stride=self.stride, padding=self.padding,
|
|
||||||
dilation=self.dilation, groups=self.groups * B)
|
|
||||||
out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
|
|
||||||
|
|
||||||
# Literal port (from TF definition)
|
|
||||||
# x = torch.split(x, 1, 0)
|
|
||||||
# weight = torch.split(weight, 1, 0)
|
|
||||||
# if self.bias is not None:
|
|
||||||
# bias = torch.matmul(routing_weights, self.bias)
|
|
||||||
# bias = torch.split(bias, 1, 0)
|
|
||||||
# else:
|
|
||||||
# bias = [None] * B
|
|
||||||
# out = []
|
|
||||||
# for xi, wi, bi in zip(x, weight, bias):
|
|
||||||
# wi = wi.view(*self.weight_shape)
|
|
||||||
# if bi is not None:
|
|
||||||
# bi = bi.view(*self.bias_shape)
|
|
||||||
# out.append(self.conv_fn(
|
|
||||||
# xi, wi, bi, stride=self.stride, padding=self.padding,
|
|
||||||
# dilation=self.dilation, groups=self.groups))
|
|
||||||
# out = torch.cat(out, 0)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
# helper method
|
|
||||||
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
|
|
||||||
assert 'groups' not in kwargs # only use 'depthwise' bool arg
|
|
||||||
if isinstance(kernel_size, list):
|
|
||||||
assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
|
|
||||||
# We're going to use only lists for defining the MixedConv2d kernel groups,
|
|
||||||
# ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
|
|
||||||
m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
|
|
||||||
else:
|
|
||||||
depthwise = kwargs.pop('depthwise', False)
|
|
||||||
groups = out_chs if depthwise else 1
|
|
||||||
if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
|
|
||||||
m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
|
|
||||||
else:
|
|
||||||
m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
|
|
||||||
return m
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,17 @@
|
|||||||
|
from .padding import get_padding
|
||||||
|
from .avg_pool2d_same import AvgPool2dSame
|
||||||
|
from .conv2d_same import Conv2dSame
|
||||||
|
from .conv_bn_act import ConvBnAct
|
||||||
|
from .mixed_conv2d import MixedConv2d
|
||||||
|
from .cond_conv2d import CondConv2d, get_condconv_initializer
|
||||||
|
from .create_conv2d import create_conv2d
|
||||||
|
from .create_attn import create_attn
|
||||||
|
from .selective_kernel import SelectiveKernelConv
|
||||||
|
from .se import SEModule
|
||||||
|
from .eca import EcaModule, CecaModule
|
||||||
|
from .activations import *
|
||||||
|
from .adaptive_avgmax_pool import \
|
||||||
|
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
|
||||||
|
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
|
||||||
|
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
|
||||||
|
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
|
@ -0,0 +1,31 @@
|
|||||||
|
""" AvgPool2d w/ Same Padding
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from typing import List
|
||||||
|
import math
|
||||||
|
|
||||||
|
from .helpers import tup_pair
|
||||||
|
from .padding import pad_same
|
||||||
|
|
||||||
|
|
||||||
|
def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
|
||||||
|
ceil_mode: bool = False, count_include_pad: bool = True):
|
||||||
|
x = pad_same(x, kernel_size, stride)
|
||||||
|
return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
|
||||||
|
|
||||||
|
|
||||||
|
class AvgPool2dSame(nn.AvgPool2d):
|
||||||
|
""" Tensorflow like 'SAME' wrapper for 2D average pooling
|
||||||
|
"""
|
||||||
|
def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
|
||||||
|
kernel_size = tup_pair(kernel_size)
|
||||||
|
stride = tup_pair(stride)
|
||||||
|
super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return avg_pool2d_same(
|
||||||
|
x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
|
@ -0,0 +1,100 @@
|
|||||||
|
""" CBAM (sort-of) Attention
|
||||||
|
|
||||||
|
Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521
|
||||||
|
|
||||||
|
WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on
|
||||||
|
some tasks, especially fine-grained it seems. I may end up removing this impl.
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn as nn
|
||||||
|
from .conv_bn_act import ConvBnAct
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelAttn(nn.Module):
|
||||||
|
""" Original CBAM channel attention module, currently avg + max pool variant only.
|
||||||
|
"""
|
||||||
|
def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
|
||||||
|
super(ChannelAttn, self).__init__()
|
||||||
|
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.max_pool = nn.AdaptiveMaxPool2d(1)
|
||||||
|
self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False)
|
||||||
|
self.act = act_layer(inplace=True)
|
||||||
|
self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x_avg = self.avg_pool(x)
|
||||||
|
x_max = self.max_pool(x)
|
||||||
|
x_avg = self.fc2(self.act(self.fc1(x_avg)))
|
||||||
|
x_max = self.fc2(self.act(self.fc1(x_max)))
|
||||||
|
x_attn = x_avg + x_max
|
||||||
|
return x * x_attn.sigmoid()
|
||||||
|
|
||||||
|
|
||||||
|
class LightChannelAttn(ChannelAttn):
|
||||||
|
"""An experimental 'lightweight' that sums avg + max pool first
|
||||||
|
"""
|
||||||
|
def __init__(self, channels, reduction=16):
|
||||||
|
super(LightChannelAttn, self).__init__(channels, reduction)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x_pool = 0.5 * self.avg_pool(x) + 0.5 * self.max_pool(x)
|
||||||
|
x_attn = self.fc2(self.act(self.fc1(x_pool)))
|
||||||
|
return x * x_attn.sigmoid()
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialAttn(nn.Module):
|
||||||
|
""" Original CBAM spatial attention module
|
||||||
|
"""
|
||||||
|
def __init__(self, kernel_size=7):
|
||||||
|
super(SpatialAttn, self).__init__()
|
||||||
|
self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x_avg = torch.mean(x, dim=1, keepdim=True)
|
||||||
|
x_max = torch.max(x, dim=1, keepdim=True)[0]
|
||||||
|
x_attn = torch.cat([x_avg, x_max], dim=1)
|
||||||
|
x_attn = self.conv(x_attn)
|
||||||
|
return x * x_attn.sigmoid()
|
||||||
|
|
||||||
|
|
||||||
|
class LightSpatialAttn(nn.Module):
|
||||||
|
"""An experimental 'lightweight' variant that sums avg_pool and max_pool results.
|
||||||
|
"""
|
||||||
|
def __init__(self, kernel_size=7):
|
||||||
|
super(LightSpatialAttn, self).__init__()
|
||||||
|
self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x_avg = torch.mean(x, dim=1, keepdim=True)
|
||||||
|
x_max = torch.max(x, dim=1, keepdim=True)[0]
|
||||||
|
x_attn = 0.5 * x_avg + 0.5 * x_max
|
||||||
|
x_attn = self.conv(x_attn)
|
||||||
|
return x * x_attn.sigmoid()
|
||||||
|
|
||||||
|
|
||||||
|
class CbamModule(nn.Module):
|
||||||
|
def __init__(self, channels, spatial_kernel_size=7):
|
||||||
|
super(CbamModule, self).__init__()
|
||||||
|
self.channel = ChannelAttn(channels)
|
||||||
|
self.spatial = SpatialAttn(spatial_kernel_size)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.channel(x)
|
||||||
|
x = self.spatial(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LightCbamModule(nn.Module):
|
||||||
|
def __init__(self, channels, spatial_kernel_size=7):
|
||||||
|
super(LightCbamModule, self).__init__()
|
||||||
|
self.channel = LightChannelAttn(channels)
|
||||||
|
self.spatial = LightSpatialAttn(spatial_kernel_size)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.channel(x)
|
||||||
|
x = self.spatial(x)
|
||||||
|
return x
|
||||||
|
|
@ -0,0 +1,121 @@
|
|||||||
|
""" PyTorch Conditionally Parameterized Convolution (CondConv)
|
||||||
|
|
||||||
|
Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference
|
||||||
|
(https://arxiv.org/abs/1904.04971)
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from functools import partial
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from .helpers import tup_pair
|
||||||
|
from .conv2d_same import get_padding_value, conv2d_same
|
||||||
|
|
||||||
|
|
||||||
|
def get_condconv_initializer(initializer, num_experts, expert_shape):
|
||||||
|
def condconv_initializer(weight):
|
||||||
|
"""CondConv initializer function."""
|
||||||
|
num_params = np.prod(expert_shape)
|
||||||
|
if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
|
||||||
|
weight.shape[1] != num_params):
|
||||||
|
raise (ValueError(
|
||||||
|
'CondConv variables must have shape [num_experts, num_params]'))
|
||||||
|
for i in range(num_experts):
|
||||||
|
initializer(weight[i].view(expert_shape))
|
||||||
|
return condconv_initializer
|
||||||
|
|
||||||
|
|
||||||
|
class CondConv2d(nn.Module):
|
||||||
|
""" Conditionally Parameterized Convolution
|
||||||
|
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
|
||||||
|
|
||||||
|
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
|
||||||
|
https://github.com/pytorch/pytorch/issues/17983
|
||||||
|
"""
|
||||||
|
__constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=3,
|
||||||
|
stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
|
||||||
|
super(CondConv2d, self).__init__()
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.kernel_size = tup_pair(kernel_size)
|
||||||
|
self.stride = tup_pair(stride)
|
||||||
|
padding_val, is_padding_dynamic = get_padding_value(
|
||||||
|
padding, kernel_size, stride=stride, dilation=dilation)
|
||||||
|
self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript
|
||||||
|
self.padding = tup_pair(padding_val)
|
||||||
|
self.dilation = tup_pair(dilation)
|
||||||
|
self.groups = groups
|
||||||
|
self.num_experts = num_experts
|
||||||
|
|
||||||
|
self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
|
||||||
|
weight_num_param = 1
|
||||||
|
for wd in self.weight_shape:
|
||||||
|
weight_num_param *= wd
|
||||||
|
self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
self.bias_shape = (self.out_channels,)
|
||||||
|
self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
|
||||||
|
else:
|
||||||
|
self.register_parameter('bias', None)
|
||||||
|
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
init_weight = get_condconv_initializer(
|
||||||
|
partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
|
||||||
|
init_weight(self.weight)
|
||||||
|
if self.bias is not None:
|
||||||
|
fan_in = np.prod(self.weight_shape[1:])
|
||||||
|
bound = 1 / math.sqrt(fan_in)
|
||||||
|
init_bias = get_condconv_initializer(
|
||||||
|
partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
|
||||||
|
init_bias(self.bias)
|
||||||
|
|
||||||
|
def forward(self, x, routing_weights):
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
weight = torch.matmul(routing_weights, self.weight)
|
||||||
|
new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
|
||||||
|
weight = weight.view(new_weight_shape)
|
||||||
|
bias = None
|
||||||
|
if self.bias is not None:
|
||||||
|
bias = torch.matmul(routing_weights, self.bias)
|
||||||
|
bias = bias.view(B * self.out_channels)
|
||||||
|
# move batch elements with channels so each batch element can be efficiently convolved with separate kernel
|
||||||
|
x = x.view(1, B * C, H, W)
|
||||||
|
if self.dynamic_padding:
|
||||||
|
out = conv2d_same(
|
||||||
|
x, weight, bias, stride=self.stride, padding=self.padding,
|
||||||
|
dilation=self.dilation, groups=self.groups * B)
|
||||||
|
else:
|
||||||
|
out = F.conv2d(
|
||||||
|
x, weight, bias, stride=self.stride, padding=self.padding,
|
||||||
|
dilation=self.dilation, groups=self.groups * B)
|
||||||
|
out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
|
||||||
|
|
||||||
|
# Literal port (from TF definition)
|
||||||
|
# x = torch.split(x, 1, 0)
|
||||||
|
# weight = torch.split(weight, 1, 0)
|
||||||
|
# if self.bias is not None:
|
||||||
|
# bias = torch.matmul(routing_weights, self.bias)
|
||||||
|
# bias = torch.split(bias, 1, 0)
|
||||||
|
# else:
|
||||||
|
# bias = [None] * B
|
||||||
|
# out = []
|
||||||
|
# for xi, wi, bi in zip(x, weight, bias):
|
||||||
|
# wi = wi.view(*self.weight_shape)
|
||||||
|
# if bi is not None:
|
||||||
|
# bi = bi.view(*self.bias_shape)
|
||||||
|
# out.append(self.conv_fn(
|
||||||
|
# xi, wi, bi, stride=self.stride, padding=self.padding,
|
||||||
|
# dilation=self.dilation, groups=self.groups))
|
||||||
|
# out = torch.cat(out, 0)
|
||||||
|
return out
|
@ -0,0 +1,66 @@
|
|||||||
|
""" Conv2d w/ Same Padding
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from typing import Union, List, Tuple, Optional, Callable
|
||||||
|
import math
|
||||||
|
|
||||||
|
from .padding import get_padding, pad_same, is_static_pad
|
||||||
|
|
||||||
|
|
||||||
|
def conv2d_same(
|
||||||
|
x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
|
||||||
|
padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
|
||||||
|
x = pad_same(x, weight.shape[-2:], stride, dilation)
|
||||||
|
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2dSame(nn.Conv2d):
|
||||||
|
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
||||||
|
padding=0, dilation=1, groups=1, bias=True):
|
||||||
|
super(Conv2dSame, self).__init__(
|
||||||
|
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||||
|
|
||||||
|
|
||||||
|
def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
|
||||||
|
dynamic = False
|
||||||
|
if isinstance(padding, str):
|
||||||
|
# for any string padding, the padding will be calculated for you, one of three ways
|
||||||
|
padding = padding.lower()
|
||||||
|
if padding == 'same':
|
||||||
|
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
|
||||||
|
if is_static_pad(kernel_size, **kwargs):
|
||||||
|
# static case, no extra overhead
|
||||||
|
padding = get_padding(kernel_size, **kwargs)
|
||||||
|
else:
|
||||||
|
# dynamic 'SAME' padding, has runtime/GPU memory overhead
|
||||||
|
padding = 0
|
||||||
|
dynamic = True
|
||||||
|
elif padding == 'valid':
|
||||||
|
# 'VALID' padding, same as padding=0
|
||||||
|
padding = 0
|
||||||
|
else:
|
||||||
|
# Default to PyTorch style 'same'-ish symmetric padding
|
||||||
|
padding = get_padding(kernel_size, **kwargs)
|
||||||
|
return padding, dynamic
|
||||||
|
|
||||||
|
|
||||||
|
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
|
||||||
|
padding = kwargs.pop('padding', '')
|
||||||
|
kwargs.setdefault('bias', False)
|
||||||
|
padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
|
||||||
|
if is_dynamic:
|
||||||
|
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
|
||||||
|
else:
|
||||||
|
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
|
||||||
|
|
||||||
|
|
@ -0,0 +1,32 @@
|
|||||||
|
""" Conv2d + BN + Act
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
|
from torch import nn as nn
|
||||||
|
|
||||||
|
from timm.models.layers import get_padding
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBnAct(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1,
|
||||||
|
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
||||||
|
super(ConvBnAct, self).__init__()
|
||||||
|
padding = get_padding(kernel_size, stride, dilation) # assuming PyTorch style padding for this block
|
||||||
|
self.conv = nn.Conv2d(
|
||||||
|
in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
|
||||||
|
padding=padding, dilation=dilation, groups=groups, bias=False)
|
||||||
|
self.bn = norm_layer(out_channels)
|
||||||
|
self.drop_block = drop_block
|
||||||
|
if act_layer is not None:
|
||||||
|
self.act = act_layer(inplace=True)
|
||||||
|
else:
|
||||||
|
self.act = None
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
if self.drop_block is not None:
|
||||||
|
x = self.drop_block(x)
|
||||||
|
if self.act is not None:
|
||||||
|
x = self.act(x)
|
||||||
|
return x
|
@ -0,0 +1,35 @@
|
|||||||
|
""" Select AttentionFactory Method
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
from .se import SEModule
|
||||||
|
from .eca import EcaModule, CecaModule
|
||||||
|
from .cbam import CbamModule, LightCbamModule
|
||||||
|
|
||||||
|
|
||||||
|
def create_attn(attn_type, channels, **kwargs):
|
||||||
|
module_cls = None
|
||||||
|
if attn_type is not None:
|
||||||
|
if isinstance(attn_type, str):
|
||||||
|
attn_type = attn_type.lower()
|
||||||
|
if attn_type == 'se':
|
||||||
|
module_cls = SEModule
|
||||||
|
elif attn_type == 'eca':
|
||||||
|
module_cls = EcaModule
|
||||||
|
elif attn_type == 'eca':
|
||||||
|
module_cls = CecaModule
|
||||||
|
elif attn_type == 'cbam':
|
||||||
|
module_cls = CbamModule
|
||||||
|
elif attn_type == 'lcbam':
|
||||||
|
module_cls = LightCbamModule
|
||||||
|
else:
|
||||||
|
assert False, "Invalid attn module (%s)" % attn_type
|
||||||
|
elif isinstance(attn_type, bool):
|
||||||
|
if attn_type:
|
||||||
|
module_cls = SEModule
|
||||||
|
else:
|
||||||
|
module_cls = attn_type
|
||||||
|
if module_cls is not None:
|
||||||
|
return module_cls(channels, **kwargs)
|
||||||
|
return None
|
@ -0,0 +1,30 @@
|
|||||||
|
""" Create Conv2d Factory Method
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .mixed_conv2d import MixedConv2d
|
||||||
|
from .cond_conv2d import CondConv2d
|
||||||
|
from .conv2d_same import create_conv2d_pad
|
||||||
|
|
||||||
|
|
||||||
|
def create_conv2d(in_chs, out_chs, kernel_size, **kwargs):
|
||||||
|
""" Select a 2d convolution implementation based on arguments
|
||||||
|
Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.
|
||||||
|
|
||||||
|
Used extensively by EfficientNet, MobileNetv3 and related networks.
|
||||||
|
"""
|
||||||
|
assert 'groups' not in kwargs # only use 'depthwise' bool arg
|
||||||
|
if isinstance(kernel_size, list):
|
||||||
|
assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
|
||||||
|
# We're going to use only lists for defining the MixedConv2d kernel groups,
|
||||||
|
# ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
|
||||||
|
m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
|
||||||
|
else:
|
||||||
|
depthwise = kwargs.pop('depthwise', False)
|
||||||
|
groups = out_chs if depthwise else 1
|
||||||
|
if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
|
||||||
|
m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
|
||||||
|
else:
|
||||||
|
m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
|
||||||
|
return m
|
@ -0,0 +1,109 @@
|
|||||||
|
""" DropBlock, DropPath
|
||||||
|
|
||||||
|
PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
|
||||||
|
|
||||||
|
Papers:
|
||||||
|
DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
|
||||||
|
|
||||||
|
Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
|
||||||
|
|
||||||
|
Code:
|
||||||
|
DropBlock impl inspired by two Tensorflow impl that I liked:
|
||||||
|
- https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
|
||||||
|
- https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
def drop_block_2d(x, drop_prob=0.1, training=False, block_size=7, gamma_scale=1.0, drop_with_noise=False):
|
||||||
|
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
|
||||||
|
|
||||||
|
DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
|
||||||
|
runs with success, but needs further validation and possibly optimization for lower runtime impact.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if drop_prob == 0. or not training:
|
||||||
|
return x
|
||||||
|
_, _, height, width = x.shape
|
||||||
|
total_size = width * height
|
||||||
|
clipped_block_size = min(block_size, min(width, height))
|
||||||
|
# seed_drop_rate, the gamma parameter
|
||||||
|
seed_drop_rate = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
|
||||||
|
(width - block_size + 1) *
|
||||||
|
(height - block_size + 1))
|
||||||
|
|
||||||
|
# Forces the block to be inside the feature map.
|
||||||
|
w_i, h_i = torch.meshgrid(torch.arange(width).to(x.device), torch.arange(height).to(x.device))
|
||||||
|
valid_block = ((w_i >= clipped_block_size // 2) & (w_i < width - (clipped_block_size - 1) // 2)) & \
|
||||||
|
((h_i >= clipped_block_size // 2) & (h_i < height - (clipped_block_size - 1) // 2))
|
||||||
|
valid_block = torch.reshape(valid_block, (1, 1, height, width)).float()
|
||||||
|
|
||||||
|
uniform_noise = torch.rand_like(x, dtype=torch.float32)
|
||||||
|
block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).float()
|
||||||
|
block_mask = -F.max_pool2d(
|
||||||
|
-block_mask,
|
||||||
|
kernel_size=clipped_block_size, # block_size, ???
|
||||||
|
stride=1,
|
||||||
|
padding=clipped_block_size // 2)
|
||||||
|
|
||||||
|
if drop_with_noise:
|
||||||
|
normal_noise = torch.randn_like(x)
|
||||||
|
x = x * block_mask + normal_noise * (1 - block_mask)
|
||||||
|
else:
|
||||||
|
normalize_scale = block_mask.numel() / (torch.sum(block_mask) + 1e-7)
|
||||||
|
x = x * block_mask * normalize_scale
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DropBlock2d(nn.Module):
|
||||||
|
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
drop_prob=0.1,
|
||||||
|
block_size=7,
|
||||||
|
gamma_scale=1.0,
|
||||||
|
with_noise=False):
|
||||||
|
super(DropBlock2d, self).__init__()
|
||||||
|
self.drop_prob = drop_prob
|
||||||
|
self.gamma_scale = gamma_scale
|
||||||
|
self.block_size = block_size
|
||||||
|
self.with_noise = with_noise
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return drop_block_2d(x, self.drop_prob, self.training, self.block_size, self.gamma_scale, self.with_noise)
|
||||||
|
|
||||||
|
|
||||||
|
def drop_path(x, drop_prob=0., training=False):
|
||||||
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
|
|
||||||
|
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||||
|
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||||
|
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||||
|
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||||
|
'survival rate' as the argument.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if drop_prob == 0. or not training:
|
||||||
|
return x
|
||||||
|
keep_prob = 1 - drop_prob
|
||||||
|
random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device)
|
||||||
|
random_tensor.floor_() # binarize
|
||||||
|
output = x.div(keep_prob) * random_tensor
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class DropPath(nn.ModuleDict):
|
||||||
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
|
"""
|
||||||
|
def __init__(self, drop_prob=None):
|
||||||
|
super(DropPath, self).__init__()
|
||||||
|
self.drop_prob = drop_prob
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return drop_path(x, self.drop_prob, self.training)
|
@ -0,0 +1,124 @@
|
|||||||
|
"""
|
||||||
|
ECA module from ECAnet
|
||||||
|
|
||||||
|
paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks
|
||||||
|
https://arxiv.org/abs/1910.03151
|
||||||
|
|
||||||
|
Original ECA model borrowed from https://github.com/BangguWu/ECANet
|
||||||
|
|
||||||
|
Modified circular ECA implementation and adaption for use in timm package
|
||||||
|
by Chris Ha https://github.com/VRandme
|
||||||
|
|
||||||
|
Original License:
|
||||||
|
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2019 BangguWu, Qilong Wang
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class EcaModule(nn.Module):
|
||||||
|
"""Constructs an ECA module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channels: Number of channels of the input feature map for use in adaptive kernel sizes
|
||||||
|
for actual calculations according to channel.
|
||||||
|
gamma, beta: when channel is given parameters of mapping function
|
||||||
|
refer to original paper https://arxiv.org/pdf/1910.03151.pdf
|
||||||
|
(default=None. if channel size not given, use k_size given for kernel size.)
|
||||||
|
kernel_size: Adaptive selection of kernel size (default=3)
|
||||||
|
"""
|
||||||
|
def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
|
||||||
|
super(EcaModule, self).__init__()
|
||||||
|
assert kernel_size % 2 == 1
|
||||||
|
|
||||||
|
if channels is not None:
|
||||||
|
t = int(abs(math.log(channels, 2) + beta) / gamma)
|
||||||
|
kernel_size = max(t if t % 2 else t + 1, 3)
|
||||||
|
|
||||||
|
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# Feature descriptor on the global spatial information
|
||||||
|
y = self.avg_pool(x)
|
||||||
|
# Reshape for convolution
|
||||||
|
y = y.view(x.shape[0], 1, -1)
|
||||||
|
# Two different branches of ECA module
|
||||||
|
y = self.conv(y)
|
||||||
|
# Multi-scale information fusion
|
||||||
|
y = y.view(x.shape[0], -1, 1, 1).sigmoid()
|
||||||
|
return x * y.expand_as(x)
|
||||||
|
|
||||||
|
|
||||||
|
class CecaModule(nn.Module):
|
||||||
|
"""Constructs a circular ECA module.
|
||||||
|
|
||||||
|
ECA module where the conv uses circular padding rather than zero padding.
|
||||||
|
Unlike the spatial dimension, the channels do not have inherent ordering nor
|
||||||
|
locality. Although this module in essence, applies such an assumption, it is unnecessary
|
||||||
|
to limit the channels on either "edge" from being circularly adapted to each other.
|
||||||
|
This will fundamentally increase connectivity and possibly increase performance metrics
|
||||||
|
(accuracy, robustness), without signficantly impacting resource metrics
|
||||||
|
(parameter size, throughput,latency, etc)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channels: Number of channels of the input feature map for use in adaptive kernel sizes
|
||||||
|
for actual calculations according to channel.
|
||||||
|
gamma, beta: when channel is given parameters of mapping function
|
||||||
|
refer to original paper https://arxiv.org/pdf/1910.03151.pdf
|
||||||
|
(default=None. if channel size not given, use k_size given for kernel size.)
|
||||||
|
kernel_size: Adaptive selection of kernel size (default=3)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
|
||||||
|
super(CecaModule, self).__init__()
|
||||||
|
assert kernel_size % 2 == 1
|
||||||
|
|
||||||
|
if channels is not None:
|
||||||
|
t = int(abs(math.log(channels, 2) + beta) / gamma)
|
||||||
|
kernel_size = max(t if t % 2 else t + 1, 3)
|
||||||
|
|
||||||
|
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||||
|
#pytorch circular padding mode is buggy as of pytorch 1.4
|
||||||
|
#see https://github.com/pytorch/pytorch/pull/17240
|
||||||
|
|
||||||
|
#implement manual circular padding
|
||||||
|
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
|
||||||
|
self.padding = (kernel_size - 1) // 2
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# Feature descriptor on the global spatial information
|
||||||
|
y = self.avg_pool(x)
|
||||||
|
|
||||||
|
# Manually implement circular padding, F.pad does not seemed to be bugged
|
||||||
|
y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular')
|
||||||
|
|
||||||
|
# Two different branches of ECA module
|
||||||
|
y = self.conv(y)
|
||||||
|
|
||||||
|
# Multi-scale information fusion
|
||||||
|
y = y.view(x.shape[0], -1, 1, 1).sigmoid()
|
||||||
|
|
||||||
|
return x * y.expand_as(x)
|
@ -0,0 +1,27 @@
|
|||||||
|
""" Layer/Module Helpers
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
|
from itertools import repeat
|
||||||
|
from torch._six import container_abcs
|
||||||
|
|
||||||
|
|
||||||
|
# From PyTorch internals
|
||||||
|
def _ntuple(n):
|
||||||
|
def parse(x):
|
||||||
|
if isinstance(x, container_abcs.Iterable):
|
||||||
|
return x
|
||||||
|
return tuple(repeat(x, n))
|
||||||
|
return parse
|
||||||
|
|
||||||
|
|
||||||
|
tup_single = _ntuple(1)
|
||||||
|
tup_pair = _ntuple(2)
|
||||||
|
tup_triple = _ntuple(3)
|
||||||
|
tup_quadruple = _ntuple(4)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,51 @@
|
|||||||
|
""" PyTorch Mixed Convolution
|
||||||
|
|
||||||
|
Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595)
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn as nn
|
||||||
|
|
||||||
|
from .conv2d_same import create_conv2d_pad
|
||||||
|
|
||||||
|
|
||||||
|
def _split_channels(num_chan, num_groups):
|
||||||
|
split = [num_chan // num_groups for _ in range(num_groups)]
|
||||||
|
split[0] += num_chan - sum(split)
|
||||||
|
return split
|
||||||
|
|
||||||
|
|
||||||
|
class MixedConv2d(nn.ModuleDict):
|
||||||
|
""" Mixed Grouped Convolution
|
||||||
|
|
||||||
|
Based on MDConv and GroupedConv in MixNet impl:
|
||||||
|
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
|
||||||
|
"""
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=3,
|
||||||
|
stride=1, padding='', dilation=1, depthwise=False, **kwargs):
|
||||||
|
super(MixedConv2d, self).__init__()
|
||||||
|
|
||||||
|
kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
|
||||||
|
num_groups = len(kernel_size)
|
||||||
|
in_splits = _split_channels(in_channels, num_groups)
|
||||||
|
out_splits = _split_channels(out_channels, num_groups)
|
||||||
|
self.in_channels = sum(in_splits)
|
||||||
|
self.out_channels = sum(out_splits)
|
||||||
|
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
|
||||||
|
conv_groups = out_ch if depthwise else 1
|
||||||
|
# use add_module to keep key space clean
|
||||||
|
self.add_module(
|
||||||
|
str(idx),
|
||||||
|
create_conv2d_pad(
|
||||||
|
in_ch, out_ch, k, stride=stride,
|
||||||
|
padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
|
||||||
|
)
|
||||||
|
self.splits = in_splits
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x_split = torch.split(x, self.splits, 1)
|
||||||
|
x_out = [c(x_split[i]) for i, c in enumerate(self.values())]
|
||||||
|
x = torch.cat(x_out, 1)
|
||||||
|
return x
|
@ -0,0 +1,33 @@
|
|||||||
|
""" Padding Helpers
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
# Calculate symmetric padding for a convolution
|
||||||
|
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
|
||||||
|
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
||||||
|
return padding
|
||||||
|
|
||||||
|
|
||||||
|
# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
|
||||||
|
def get_same_padding(x: int, k: int, s: int, d: int):
|
||||||
|
return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
|
||||||
|
|
||||||
|
|
||||||
|
# Can SAME padding for given args be done statically?
|
||||||
|
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
|
||||||
|
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
|
||||||
|
|
||||||
|
|
||||||
|
# Dynamically pad input x with 'SAME' padding for conv with specified args
|
||||||
|
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1)):
|
||||||
|
ih, iw = x.size()[-2:]
|
||||||
|
pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
|
||||||
|
if pad_h > 0 or pad_w > 0:
|
||||||
|
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
|
||||||
|
return x
|
@ -0,0 +1,21 @@
|
|||||||
|
from torch import nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class SEModule(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
|
||||||
|
super(SEModule, self).__init__()
|
||||||
|
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||||
|
reduction_channels = max(channels // reduction, 8)
|
||||||
|
self.fc1 = nn.Conv2d(
|
||||||
|
channels, reduction_channels, kernel_size=1, padding=0, bias=True)
|
||||||
|
self.act = act_layer(inplace=True)
|
||||||
|
self.fc2 = nn.Conv2d(
|
||||||
|
reduction_channels, channels, kernel_size=1, padding=0, bias=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x_se = self.avg_pool(x)
|
||||||
|
x_se = self.fc1(x_se)
|
||||||
|
x_se = self.act(x_se)
|
||||||
|
x_se = self.fc2(x_se)
|
||||||
|
return x * x_se.sigmoid()
|
@ -0,0 +1,120 @@
|
|||||||
|
""" Selective Kernel Convolution/Attention
|
||||||
|
|
||||||
|
Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn as nn
|
||||||
|
|
||||||
|
from .conv_bn_act import ConvBnAct
|
||||||
|
|
||||||
|
|
||||||
|
def _kernel_valid(k):
|
||||||
|
if isinstance(k, (list, tuple)):
|
||||||
|
for ki in k:
|
||||||
|
return _kernel_valid(ki)
|
||||||
|
assert k >= 3 and k % 2
|
||||||
|
|
||||||
|
|
||||||
|
class SelectiveKernelAttn(nn.Module):
|
||||||
|
def __init__(self, channels, num_paths=2, attn_channels=32,
|
||||||
|
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
||||||
|
""" Selective Kernel Attention Module
|
||||||
|
|
||||||
|
Selective Kernel attention mechanism factored out into its own module.
|
||||||
|
|
||||||
|
"""
|
||||||
|
super(SelectiveKernelAttn, self).__init__()
|
||||||
|
self.num_paths = num_paths
|
||||||
|
self.pool = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
|
||||||
|
self.bn = norm_layer(attn_channels)
|
||||||
|
self.act = act_layer(inplace=True)
|
||||||
|
self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
assert x.shape[1] == self.num_paths
|
||||||
|
x = torch.sum(x, dim=1)
|
||||||
|
x = self.pool(x)
|
||||||
|
x = self.fc_reduce(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.fc_select(x)
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
x = x.view(B, self.num_paths, C // self.num_paths, H, W)
|
||||||
|
x = torch.softmax(x, dim=1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SelectiveKernelConv(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
|
||||||
|
attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
|
||||||
|
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
||||||
|
""" Selective Kernel Convolution Module
|
||||||
|
|
||||||
|
As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications.
|
||||||
|
|
||||||
|
Largest change is the input split, which divides the input channels across each convolution path, this can
|
||||||
|
be viewed as a grouping of sorts, but the output channel counts expand to the module level value. This keeps
|
||||||
|
the parameter count from ballooning when the convolutions themselves don't have groups, but still provides
|
||||||
|
a noteworthy increase in performance over similar param count models without this attention layer. -Ross W
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): module input (feature) channel count
|
||||||
|
out_channels (int): module output (feature) channel count
|
||||||
|
kernel_size (int, list): kernel size for each convolution branch
|
||||||
|
stride (int): stride for convolutions
|
||||||
|
dilation (int): dilation for module as a whole, impacts dilation of each branch
|
||||||
|
groups (int): number of groups for each branch
|
||||||
|
attn_reduction (int, float): reduction factor for attention features
|
||||||
|
min_attn_channels (int): minimum attention feature channels
|
||||||
|
keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
|
||||||
|
split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
|
||||||
|
can be viewed as grouping by path, output expands to module out_channels count
|
||||||
|
drop_block (nn.Module): drop block module
|
||||||
|
act_layer (nn.Module): activation layer to use
|
||||||
|
norm_layer (nn.Module): batchnorm/norm layer to use
|
||||||
|
"""
|
||||||
|
super(SelectiveKernelConv, self).__init__()
|
||||||
|
kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation
|
||||||
|
_kernel_valid(kernel_size)
|
||||||
|
if not isinstance(kernel_size, list):
|
||||||
|
kernel_size = [kernel_size] * 2
|
||||||
|
if keep_3x3:
|
||||||
|
dilation = [dilation * (k - 1) // 2 for k in kernel_size]
|
||||||
|
kernel_size = [3] * len(kernel_size)
|
||||||
|
else:
|
||||||
|
dilation = [dilation] * len(kernel_size)
|
||||||
|
self.num_paths = len(kernel_size)
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.split_input = split_input
|
||||||
|
if self.split_input:
|
||||||
|
assert in_channels % self.num_paths == 0
|
||||||
|
in_channels = in_channels // self.num_paths
|
||||||
|
groups = min(out_channels, groups)
|
||||||
|
|
||||||
|
conv_kwargs = dict(
|
||||||
|
stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
|
||||||
|
self.paths = nn.ModuleList([
|
||||||
|
ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
|
||||||
|
for k, d in zip(kernel_size, dilation)])
|
||||||
|
|
||||||
|
attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
|
||||||
|
self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
|
||||||
|
self.drop_block = drop_block
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.split_input:
|
||||||
|
x_split = torch.split(x, self.in_channels // self.num_paths, 1)
|
||||||
|
x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
|
||||||
|
else:
|
||||||
|
x_paths = [op(x) for op in self.paths]
|
||||||
|
x = torch.stack(x_paths, dim=1)
|
||||||
|
x_attn = self.attn(x)
|
||||||
|
x = x * x_attn
|
||||||
|
x = torch.sum(x, dim=1)
|
||||||
|
return x
|
@ -0,0 +1,237 @@
|
|||||||
|
""" Selective Kernel Networks (ResNet base)
|
||||||
|
|
||||||
|
Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
|
||||||
|
|
||||||
|
This was inspired by reading 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268)
|
||||||
|
and a streamlined impl at https://github.com/clovaai/assembled-cnn but I ended up building something closer
|
||||||
|
to the original paper with some modifications of my own to better balance param count vs accuracy.
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
|
||||||
|
from torch import nn as nn
|
||||||
|
|
||||||
|
from .registry import register_model
|
||||||
|
from .helpers import load_pretrained
|
||||||
|
from .layers import SelectiveKernelConv, ConvBnAct, create_attn
|
||||||
|
from .resnet import ResNet
|
||||||
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
|
||||||
|
|
||||||
|
def _cfg(url='', **kwargs):
|
||||||
|
return {
|
||||||
|
'url': url,
|
||||||
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||||
|
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
||||||
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||||
|
'first_conv': 'conv1', 'classifier': 'fc',
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
default_cfgs = {
|
||||||
|
'skresnet18': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth'),
|
||||||
|
'skresnet34': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth'),
|
||||||
|
'skresnet50': _cfg(),
|
||||||
|
'skresnet50d': _cfg(),
|
||||||
|
'skresnext50_32x4d': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth'),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SelectiveKernelBasic(nn.Module):
|
||||||
|
expansion = 1
|
||||||
|
|
||||||
|
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
|
||||||
|
sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
|
||||||
|
drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None):
|
||||||
|
super(SelectiveKernelBasic, self).__init__()
|
||||||
|
|
||||||
|
sk_kwargs = sk_kwargs or {}
|
||||||
|
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
|
||||||
|
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
|
||||||
|
assert base_width == 64, 'BasicBlock doest not support changing base width'
|
||||||
|
first_planes = planes // reduce_first
|
||||||
|
outplanes = planes * self.expansion
|
||||||
|
first_dilation = first_dilation or dilation
|
||||||
|
|
||||||
|
self.conv1 = SelectiveKernelConv(
|
||||||
|
inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
|
||||||
|
conv_kwargs['act_layer'] = None
|
||||||
|
self.conv2 = ConvBnAct(
|
||||||
|
first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs)
|
||||||
|
self.se = create_attn(attn_layer, outplanes)
|
||||||
|
self.act = act_layer(inplace=True)
|
||||||
|
self.downsample = downsample
|
||||||
|
self.stride = stride
|
||||||
|
self.dilation = dilation
|
||||||
|
self.drop_block = drop_block
|
||||||
|
self.drop_path = drop_path
|
||||||
|
|
||||||
|
def zero_init_last_bn(self):
|
||||||
|
nn.init.zeros_(self.conv2.bn.weight)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
if self.se is not None:
|
||||||
|
x = self.se(x)
|
||||||
|
if self.drop_path is not None:
|
||||||
|
x = self.drop_path(x)
|
||||||
|
if self.downsample is not None:
|
||||||
|
residual = self.downsample(residual)
|
||||||
|
x += residual
|
||||||
|
x = self.act(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SelectiveKernelBottleneck(nn.Module):
|
||||||
|
expansion = 4
|
||||||
|
|
||||||
|
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
||||||
|
cardinality=1, base_width=64, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
|
||||||
|
drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None):
|
||||||
|
super(SelectiveKernelBottleneck, self).__init__()
|
||||||
|
|
||||||
|
sk_kwargs = sk_kwargs or {}
|
||||||
|
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
|
||||||
|
width = int(math.floor(planes * (base_width / 64)) * cardinality)
|
||||||
|
first_planes = width // reduce_first
|
||||||
|
outplanes = planes * self.expansion
|
||||||
|
first_dilation = first_dilation or dilation
|
||||||
|
|
||||||
|
self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
|
||||||
|
self.conv2 = SelectiveKernelConv(
|
||||||
|
first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality,
|
||||||
|
**conv_kwargs, **sk_kwargs)
|
||||||
|
conv_kwargs['act_layer'] = None
|
||||||
|
self.conv3 = ConvBnAct(width, outplanes, kernel_size=1, **conv_kwargs)
|
||||||
|
self.se = create_attn(attn_layer, outplanes)
|
||||||
|
self.act = act_layer(inplace=True)
|
||||||
|
self.downsample = downsample
|
||||||
|
self.stride = stride
|
||||||
|
self.dilation = dilation
|
||||||
|
self.drop_block = drop_block
|
||||||
|
self.drop_path = drop_path
|
||||||
|
|
||||||
|
def zero_init_last_bn(self):
|
||||||
|
nn.init.zeros_(self.conv3.bn.weight)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.conv3(x)
|
||||||
|
if self.se is not None:
|
||||||
|
x = self.se(x)
|
||||||
|
if self.drop_path is not None:
|
||||||
|
x = self.drop_path(x)
|
||||||
|
if self.downsample is not None:
|
||||||
|
residual = self.downsample(residual)
|
||||||
|
x += residual
|
||||||
|
x = self.act(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
|
"""Constructs a Selective Kernel ResNet-18 model.
|
||||||
|
|
||||||
|
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
|
||||||
|
variation splits the input channels to the selective convolutions to keep param count down.
|
||||||
|
"""
|
||||||
|
default_cfg = default_cfgs['skresnet18']
|
||||||
|
sk_kwargs = dict(
|
||||||
|
min_attn_channels=16,
|
||||||
|
attn_reduction=8,
|
||||||
|
split_input=True
|
||||||
|
)
|
||||||
|
model = ResNet(
|
||||||
|
SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
|
||||||
|
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def skresnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
|
"""Constructs a Selective Kernel ResNet-34 model.
|
||||||
|
|
||||||
|
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
|
||||||
|
variation splits the input channels to the selective convolutions to keep param count down.
|
||||||
|
"""
|
||||||
|
default_cfg = default_cfgs['skresnet34']
|
||||||
|
sk_kwargs = dict(
|
||||||
|
min_attn_channels=16,
|
||||||
|
attn_reduction=8,
|
||||||
|
split_input=True
|
||||||
|
)
|
||||||
|
model = ResNet(
|
||||||
|
SelectiveKernelBasic, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
|
||||||
|
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
|
"""Constructs a Select Kernel ResNet-50 model.
|
||||||
|
|
||||||
|
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
|
||||||
|
variation splits the input channels to the selective convolutions to keep param count down.
|
||||||
|
"""
|
||||||
|
sk_kwargs = dict(
|
||||||
|
split_input=True,
|
||||||
|
)
|
||||||
|
default_cfg = default_cfgs['skresnet50']
|
||||||
|
model = ResNet(
|
||||||
|
SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
|
||||||
|
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
|
"""Constructs a Select Kernel ResNet-50-D model.
|
||||||
|
|
||||||
|
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
|
||||||
|
variation splits the input channels to the selective convolutions to keep param count down.
|
||||||
|
"""
|
||||||
|
sk_kwargs = dict(
|
||||||
|
split_input=True,
|
||||||
|
)
|
||||||
|
default_cfg = default_cfgs['skresnet50d']
|
||||||
|
model = ResNet(
|
||||||
|
SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
|
||||||
|
num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs),
|
||||||
|
zero_init_last_bn=False, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
|
"""Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to
|
||||||
|
the SKNet-50 model in the Select Kernel Paper
|
||||||
|
"""
|
||||||
|
default_cfg = default_cfgs['skresnext50_32x4d']
|
||||||
|
model = ResNet(
|
||||||
|
SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
|
||||||
|
num_classes=num_classes, in_chans=in_chans, zero_init_last_bn=False, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
Loading…
Reference in new issue