A bunch more layer reorg, splitting many layers into their own files. Improve TorchScript compatibility.

pull/87/head
Ross Wightman 5 years ago
parent 13746a33fc
commit a99ec4e7d1

@@ -27,7 +27,8 @@ from .efficientnet_builder import *
from .feature_hooks import FeatureHooks
from .registry import register_model
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d, select_conv2d
from .layers import SelectAdaptivePool2d
from timm.models.layers import select_conv2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

@@ -1,11 +1,8 @@
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import functional as F
from .layers.activations import sigmoid
from .layers.conv2d_layers import *
from .layers import select_conv2d
# Defaults used for Google/Tensorflow training of mobile networks w/ RMSprop as per
@@ -72,7 +69,7 @@ def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
return make_divisible(channels, divisor, channel_min)
def drop_connect(inputs, training=False, drop_connect_rate=0.):
def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.):
"""Apply drop connect."""
if not training:
return inputs
@@ -160,7 +157,7 @@ class DepthwiseSeparableConv(nn.Module):
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
super(DepthwiseSeparableConv, self).__init__()
norm_kwargs = norm_kwargs or {}
self.has_se = se_ratio is not None and se_ratio > 0.
has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
self.has_pw_act = pw_act # activation after point-wise conv
self.drop_connect_rate = drop_connect_rate
@@ -171,9 +168,11 @@ class DepthwiseSeparableConv(nn.Module):
self.act1 = act_layer(inplace=True)
# Squeeze-and-excitation
if self.has_se:
if has_se:
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
else:
self.se = None
self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
self.bn2 = norm_layer(out_chs, **norm_kwargs)
@@ -193,7 +192,7 @@ class DepthwiseSeparableConv(nn.Module):
x = self.bn1(x)
x = self.act1(x)
if self.has_se:
if self.se is not None:
x = self.se(x)
x = self.conv_pw(x)
@@ -219,7 +218,7 @@ class InvertedResidual(nn.Module):
norm_kwargs = norm_kwargs or {}
conv_kwargs = conv_kwargs or {}
mid_chs = make_divisible(in_chs * exp_ratio)
self.has_se = se_ratio is not None and se_ratio > 0.
has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_connect_rate = drop_connect_rate
@@ -236,9 +235,11 @@ class InvertedResidual(nn.Module):
self.act2 = act_layer(inplace=True)
# Squeeze-and-excitation
if self.has_se:
if has_se:
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
else:
self.se = None
# Point-wise linear projection
self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
@@ -269,7 +270,7 @@ class InvertedResidual(nn.Module):
x = self.act2(x)
# Squeeze-and-excitation
if self.has_se:
if self.se is not None:
x = self.se(x)
# Point-wise linear projection
@@ -323,7 +324,7 @@ class CondConvResidual(InvertedResidual):
x = self.act2(x)
# Squeeze-and-excitation
if self.has_se:
if self.se is not None:
x = self.se(x)
# Point-wise linear projection
@@ -350,7 +351,7 @@ class EdgeResidual(nn.Module):
mid_chs = make_divisible(fake_in_chs * exp_ratio)
else:
mid_chs = make_divisible(in_chs * exp_ratio)
self.has_se = se_ratio is not None and se_ratio > 0.
has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_connect_rate = drop_connect_rate
@@ -360,9 +361,11 @@ class EdgeResidual(nn.Module):
self.act1 = act_layer(inplace=True)
# Squeeze-and-excitation
if self.has_se:
if has_se:
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
else:
self.se = None
# Point-wise linear projection
self.conv_pwl = select_conv2d(
@@ -389,7 +392,7 @@ class EdgeResidual(nn.Module):
x = self.act1(x)
# Squeeze-and-excitation
if self.has_se:
if self.se is not None:
x = self.se(x)
# Point-wise linear projection
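The SE changes above swap a `self.has_se` flag (with a conditionally defined `self.se`) for an attribute that always exists and is checked with `is not None`. A minimal sketch, not part of this diff, of why that pattern is friendlier to torch.jit.script:

import torch
import torch.nn as nn

class BlockWithSE(nn.Module):
    """Toy block: `se` always exists as an attribute, set to None when disabled."""
    def __init__(self, chs: int, use_se: bool = True):
        super(BlockWithSE, self).__init__()
        self.conv = nn.Conv2d(chs, chs, 3, padding=1)
        self.se = nn.Conv2d(chs, chs, 1) if use_se else None  # never left undefined

    def forward(self, x):
        x = self.conv(x)
        if self.se is not None:  # scriptable check; a sometimes-missing attribute is not
            x = x * torch.sigmoid(self.se(x))
        return x

scripted = torch.jit.script(BlockWithSE(8, use_se=False))  # both variants script with this pattern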

@@ -5,7 +5,8 @@ from collections.__init__ import OrderedDict
from copy import deepcopy
import torch.nn as nn
from .layers.activations import sigmoid, HardSwish, Swish
from .layers import CondConv2d, get_condconv_initializer
from .layers.activations import HardSwish, Swish
from .efficientnet_blocks import *

@@ -1,8 +1,12 @@
from .conv2d_layers import select_conv2d, MixedConv2d, CondConv2d, ConvBnAct, SelectiveKernelConv
from .conv_bn_act import ConvBnAct
from .mixed_conv2d import MixedConv2d
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .select_conv2d import select_conv2d
from .selective_kernel import SelectiveKernelConv
from .eca import EcaModule, CecaModule
from .activations import *
from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .nn_ops import DropBlock2d, DropPath
from .drop import DropBlock2d, DropPath
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
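With the reorg, callers import individual layers from the `timm.models.layers` package instead of the old monolithic `conv2d_layers` module; the efficientnet.py hunk at the top of this diff already does so. A quick sketch of the new import surface:

from timm.models.layers import SelectAdaptivePool2d, select_conv2d, CondConv2d, MixedConv2d

conv = select_conv2d(16, 32, kernel_size=3)   # resolves to a plain nn.Conv2d with padding=1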

@@ -1,9 +1,18 @@
""" Activations
A collection of activation functions and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.
Hacked together by Ross Wightman
"""
import torch
from torch import nn as nn
from torch.nn import functional as F
_USE_MEM_EFFICIENT_ISH = True
_USE_MEM_EFFICIENT_ISH = False
if _USE_MEM_EFFICIENT_ISH:
# This version reduces memory overhead of Swish during training by
# recomputing torch.sigmoid(x) in backward instead of saving it.
@@ -66,20 +75,20 @@ if _USE_MEM_EFFICIENT_ISH:
return MishJitAutoFn.apply(x)
else:
def swish(x, inplace=False):
def swish(x, inplace: bool = False):
"""Swish - Described in: https://arxiv.org/abs/1710.05941
"""
return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
def mish(x, _inplace=False):
def mish(x, _inplace: bool = False):
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
"""
return x.mul(F.softplus(x).tanh())
class Swish(nn.Module):
def __init__(self, inplace=False):
def __init__(self, inplace: bool = False):
super(Swish, self).__init__()
self.inplace = inplace
@@ -88,7 +97,7 @@ class Swish(nn.Module):
class Mish(nn.Module):
def __init__(self, inplace=False):
def __init__(self, inplace: bool = False):
super(Mish, self).__init__()
self.inplace = inplace
@@ -96,13 +105,13 @@ class Mish(nn.Module):
return mish(x, self.inplace)
def sigmoid(x, inplace=False):
def sigmoid(x, inplace: bool = False):
return x.sigmoid_() if inplace else x.sigmoid()
# PyTorch has this, but not with a consistent inplace argument interface
class Sigmoid(nn.Module):
def __init__(self, inplace=False):
def __init__(self, inplace: bool = False):
super(Sigmoid, self).__init__()
self.inplace = inplace
@@ -110,13 +119,13 @@ class Sigmoid(nn.Module):
return x.sigmoid_() if self.inplace else x.sigmoid()
def tanh(x, inplace=False):
def tanh(x, inplace: bool = False):
return x.tanh_() if inplace else x.tanh()
# PyTorch has this, but not with a consistent inplace argument interface
class Tanh(nn.Module):
def __init__(self, inplace=False):
def __init__(self, inplace: bool = False):
super(Tanh, self).__init__()
self.inplace = inplace
@@ -124,13 +133,13 @@ class Tanh(nn.Module):
return x.tanh_() if self.inplace else x.tanh()
def hard_swish(x, inplace=False):
def hard_swish(x, inplace: bool = False):
inner = F.relu6(x + 3.).div_(6.)
return x.mul_(inner) if inplace else x.mul(inner)
class HardSwish(nn.Module):
def __init__(self, inplace=False):
def __init__(self, inplace: bool = False):
super(HardSwish, self).__init__()
self.inplace = inplace
@@ -138,7 +147,7 @@ class HardSwish(nn.Module):
return hard_swish(x, self.inplace)
def hard_sigmoid(x, inplace=False):
def hard_sigmoid(x, inplace: bool = False):
if inplace:
return x.add_(3.).clamp_(0., 6.).div_(6.)
else:
@@ -146,7 +155,7 @@ def hard_sigmoid(x, inplace=False):
class HardSigmoid(nn.Module):
def __init__(self, inplace=False):
def __init__(self, inplace: bool = False):
super(HardSigmoid, self).__init__()
self.inplace = inplace
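All of these activations, functional and module forms alike, keep the same `inplace` argument, so any of them can be dropped in as an `act_layer`. A small sketch (assumes the module path shown in this diff):

import torch
from timm.models.layers.activations import HardSwish

act_layer = HardSwish              # could equally be Swish, Mish, Sigmoid, Tanh, ...
act = act_layer(inplace=True)      # every module accepts `inplace`, even if unused
y = act(torch.randn(2, 8, 4, 4))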

@@ -0,0 +1,118 @@
""" Conditional Convolution
Hacked together by Ross Wightman
"""
import math
from functools import partial
import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F
from .conv2d_same import get_padding_value, conv2d_same
from .conv_helpers import tup_pair
def get_condconv_initializer(initializer, num_experts, expert_shape):
def condconv_initializer(weight):
"""CondConv initializer function."""
num_params = np.prod(expert_shape)
if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
weight.shape[1] != num_params):
raise (ValueError(
'CondConv variables must have shape [num_experts, num_params]'))
for i in range(num_experts):
initializer(weight[i].view(expert_shape))
return condconv_initializer
class CondConv2d(nn.Module):
""" Conditional Convolution
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
https://github.com/pytorch/pytorch/issues/17983
"""
__constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
super(CondConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = tup_pair(kernel_size)
self.stride = tup_pair(stride)
padding_val, is_padding_dynamic = get_padding_value(
padding, kernel_size, stride=stride, dilation=dilation)
self.dynamic_padding = is_padding_dynamic  # if-check done in forward() to keep this torchscript compatible
self.padding = tup_pair(padding_val)
self.dilation = tup_pair(dilation)
self.groups = groups
self.num_experts = num_experts
self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
weight_num_param = 1
for wd in self.weight_shape:
weight_num_param *= wd
self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
if bias:
self.bias_shape = (self.out_channels,)
self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
init_weight = get_condconv_initializer(
partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
init_weight(self.weight)
if self.bias is not None:
fan_in = np.prod(self.weight_shape[1:])
bound = 1 / math.sqrt(fan_in)
init_bias = get_condconv_initializer(
partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
init_bias(self.bias)
def forward(self, x, routing_weights):
B, C, H, W = x.shape
weight = torch.matmul(routing_weights, self.weight)
new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
weight = weight.view(new_weight_shape)
bias = None
if self.bias is not None:
bias = torch.matmul(routing_weights, self.bias)
bias = bias.view(B * self.out_channels)
# move batch elements with channels so each batch element can be efficiently convolved with separate kernel
x = x.view(1, B * C, H, W)
if self.dynamic_padding:
out = conv2d_same(
x, weight, bias, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups * B)
else:
out = F.conv2d(
x, weight, bias, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups * B)
out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
# Literal port (from TF definition)
# x = torch.split(x, 1, 0)
# weight = torch.split(weight, 1, 0)
# if self.bias is not None:
# bias = torch.matmul(routing_weights, self.bias)
# bias = torch.split(bias, 1, 0)
# else:
# bias = [None] * B
# out = []
# for xi, wi, bi in zip(x, weight, bias):
# wi = wi.view(*self.weight_shape)
# if bi is not None:
# bi = bi.view(*self.bias_shape)
# out.append(self.conv_fn(
# xi, wi, bi, stride=self.stride, padding=self.padding,
# dilation=self.dilation, groups=self.groups))
# out = torch.cat(out, 0)
return out
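A hypothetical usage sketch for CondConv2d: routing weights are produced per sample (in a real model, by a small routing head over pooled features) and mix the experts' kernels.

import torch
from timm.models.layers import CondConv2d

cc = CondConv2d(32, 64, kernel_size=3, num_experts=4)
x = torch.randn(8, 32, 56, 56)
routing = torch.softmax(torch.randn(8, 4), dim=-1)   # [batch, num_experts]
y = cc(x, routing)                                   # -> [8, 64, 56, 56]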

@@ -1,361 +0,0 @@
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._six import container_abcs
from itertools import repeat
from functools import partial
import numpy as np
import math
# Tuple helpers ripped from PyTorch
def _ntuple(n):
def parse(x):
if isinstance(x, container_abcs.Iterable):
return x
return tuple(repeat(x, n))
return parse
_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)
def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
def _get_padding(kernel_size, stride=1, dilation=1, **_):
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
def _calc_same_pad(i, k, s, d):
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
def _split_channels(num_chan, num_groups):
split = [num_chan // num_groups for _ in range(num_groups)]
split[0] += num_chan - sum(split)
return split
# pylint: disable=unused-argument
def conv2d_same(x, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1):
ih, iw = x.size()[-2:]
kh, kw = weight.size()[-2:]
pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
if pad_h > 0 or pad_w > 0:
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
class Conv2dSame(nn.Conv2d):
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
"""
# pylint: disable=unused-argument
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(Conv2dSame, self).__init__(
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
def forward(self, x):
return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
def get_padding_value(padding, kernel_size, **kwargs):
dynamic = False
if isinstance(padding, str):
# for any string padding, the padding will be calculated for you, one of three ways
padding = padding.lower()
if padding == 'same':
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
if _is_static_pad(kernel_size, **kwargs):
# static case, no extra overhead
padding = _get_padding(kernel_size, **kwargs)
else:
# dynamic 'SAME' padding, has runtime/GPU memory overhead
padding = 0
dynamic = True
elif padding == 'valid':
# 'VALID' padding, same as padding=0
padding = 0
else:
# Default to PyTorch style 'same'-ish symmetric padding
padding = _get_padding(kernel_size, **kwargs)
return padding, dynamic
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
padding = kwargs.pop('padding', '')
kwargs.setdefault('bias', False)
padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
if is_dynamic:
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
else:
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
class MixedConv2d(nn.ModuleDict):
""" Mixed Grouped Convolution
Based on MDConv and GroupedConv in MixNet impl:
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
"""
def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding='', dilation=1, depthwise=False, **kwargs):
super(MixedConv2d, self).__init__()
kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
num_groups = len(kernel_size)
in_splits = _split_channels(in_channels, num_groups)
out_splits = _split_channels(out_channels, num_groups)
self.in_channels = sum(in_splits)
self.out_channels = sum(out_splits)
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
conv_groups = out_ch if depthwise else 1
# use add_module to keep key space clean
self.add_module(
str(idx),
create_conv2d_pad(
in_ch, out_ch, k, stride=stride,
padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
)
self.splits = in_splits
def forward(self, x):
x_split = torch.split(x, self.splits, 1)
x_out = [c(x_split[i]) for i, c in enumerate(self.values())]
x = torch.cat(x_out, 1)
return x
def get_condconv_initializer(initializer, num_experts, expert_shape):
def condconv_initializer(weight):
"""CondConv initializer function."""
num_params = np.prod(expert_shape)
if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
weight.shape[1] != num_params):
raise (ValueError(
'CondConv variables must have shape [num_experts, num_params]'))
for i in range(num_experts):
initializer(weight[i].view(expert_shape))
return condconv_initializer
class CondConv2d(nn.Module):
""" Conditional Convolution
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
https://github.com/pytorch/pytorch/issues/17983
"""
__constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
super(CondConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
padding_val, is_padding_dynamic = get_padding_value(
padding, kernel_size, stride=stride, dilation=dilation)
self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript
self.padding = _pair(padding_val)
self.dilation = _pair(dilation)
self.groups = groups
self.num_experts = num_experts
self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
weight_num_param = 1
for wd in self.weight_shape:
weight_num_param *= wd
self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
if bias:
self.bias_shape = (self.out_channels,)
self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
init_weight = get_condconv_initializer(
partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
init_weight(self.weight)
if self.bias is not None:
fan_in = np.prod(self.weight_shape[1:])
bound = 1 / math.sqrt(fan_in)
init_bias = get_condconv_initializer(
partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
init_bias(self.bias)
def forward(self, x, routing_weights):
B, C, H, W = x.shape
weight = torch.matmul(routing_weights, self.weight)
new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
weight = weight.view(new_weight_shape)
bias = None
if self.bias is not None:
bias = torch.matmul(routing_weights, self.bias)
bias = bias.view(B * self.out_channels)
# move batch elements with channels so each batch element can be efficiently convolved with separate kernel
x = x.view(1, B * C, H, W)
if self.dynamic_padding:
out = conv2d_same(
x, weight, bias, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups * B)
else:
out = F.conv2d(
x, weight, bias, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups * B)
out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
# Literal port (from TF definition)
# x = torch.split(x, 1, 0)
# weight = torch.split(weight, 1, 0)
# if self.bias is not None:
# bias = torch.matmul(routing_weights, self.bias)
# bias = torch.split(bias, 1, 0)
# else:
# bias = [None] * B
# out = []
# for xi, wi, bi in zip(x, weight, bias):
# wi = wi.view(*self.weight_shape)
# if bi is not None:
# bi = bi.view(*self.bias_shape)
# out.append(self.conv_fn(
# xi, wi, bi, stride=self.stride, padding=self.padding,
# dilation=self.dilation, groups=self.groups))
# out = torch.cat(out, 0)
return out
class SelectiveKernelAttn(nn.Module):
def __init__(self, channels, num_paths=2, attn_channels=32,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(SelectiveKernelAttn, self).__init__()
self.num_paths = num_paths
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
self.bn = norm_layer(attn_channels)
self.act = act_layer(inplace=True)
self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
def forward(self, x):
assert x.shape[1] == self.num_paths
x = torch.sum(x, dim=1)
x = self.pool(x)
x = self.fc_reduce(x)
x = self.bn(x)
x = self.act(x)
x = self.fc_select(x)
B, C, H, W = x.shape
x = x.view(B, self.num_paths, C // self.num_paths, H, W)
x = torch.softmax(x, dim=1)
return x
def _kernel_valid(k):
if isinstance(k, (list, tuple)):
for ki in k:
return _kernel_valid(ki)
assert k >= 3 and k % 2
class ConvBnAct(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1,
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(ConvBnAct, self).__init__()
padding = _get_padding(kernel_size, stride, dilation) # assuming PyTorch style padding for this block
self.conv = nn.Conv2d(
in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=False)
self.bn = norm_layer(out_channels)
self.drop_block = drop_block
if act_layer is not None:
self.act = act_layer(inplace=True)
else:
self.act = None
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.drop_block is not None:
x = self.drop_block(x)
if self.act is not None:
x = self.act(x)
return x
class SelectiveKernelConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(SelectiveKernelConv, self).__init__()
kernel_size = kernel_size or [3, 5]
_kernel_valid(kernel_size)
if not isinstance(kernel_size, list):
kernel_size = [kernel_size] * 2
if keep_3x3:
dilation = [dilation * (k - 1) // 2 for k in kernel_size]
kernel_size = [3] * len(kernel_size)
else:
dilation = [dilation] * len(kernel_size)
self.num_paths = len(kernel_size)
self.in_channels = in_channels
self.out_channels = out_channels
self.split_input = split_input
if self.split_input:
assert in_channels % self.num_paths == 0
in_channels = in_channels // self.num_paths
groups = min(out_channels, groups)
conv_kwargs = dict(
stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
self.paths = nn.ModuleList([
ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
for k, d in zip(kernel_size, dilation)])
attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
self.drop_block = drop_block
def forward(self, x):
if self.split_input:
x_split = torch.split(x, self.in_channels // self.num_paths, 1)
x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
else:
x_paths = [op(x) for op in self.paths]
x = torch.stack(x_paths, dim=1)
x_attn = self.attn(x)
x = x * x_attn
x = torch.sum(x, dim=1)
return x
# helper method
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
assert 'groups' not in kwargs # only use 'depthwise' bool arg
if isinstance(kernel_size, list):
assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
# We're going to use only lists for defining the MixedConv2d kernel groups,
# ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
else:
depthwise = kwargs.pop('depthwise', False)
groups = out_chs if depthwise else 1
if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
else:
m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
return m

@@ -0,0 +1,79 @@
""" Conv2d w/ Same Padding
Hacked together by Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union, List, Tuple, Optional, Callable
import math
from .conv_helpers import get_padding
def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
def _calc_same_pad(i: int, k: int, s: int, d: int):
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
def conv2d_same(
x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
ih, iw = x.size()[-2:]
kh, kw = weight.size()[-2:]
pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
if pad_h > 0 or pad_w > 0:
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
class Conv2dSame(nn.Conv2d):
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(Conv2dSame, self).__init__(
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
def forward(self, x):
return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
dynamic = False
if isinstance(padding, str):
# for any string padding, the padding will be calculated for you, one of three ways
padding = padding.lower()
if padding == 'same':
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
if _is_static_pad(kernel_size, **kwargs):
# static case, no extra overhead
padding = get_padding(kernel_size, **kwargs)
else:
# dynamic 'SAME' padding, has runtime/GPU memory overhead
padding = 0
dynamic = True
elif padding == 'valid':
# 'VALID' padding, same as padding=0
padding = 0
else:
# Default to PyTorch style 'same'-ish symmetric padding
padding = get_padding(kernel_size, **kwargs)
return padding, dynamic
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
padding = kwargs.pop('padding', '')
kwargs.setdefault('bias', False)
padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
if is_dynamic:
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
else:
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
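How string padding resolves under the logic above (a sketch; values follow from _is_static_pad and get_padding):

from timm.models.layers.conv2d_same import get_padding_value

get_padding_value('same', 3, stride=1)    # (1, False) -> static symmetric padding
get_padding_value('same', 3, stride=2)    # (0, True)  -> dynamic pad, i.e. Conv2dSame
get_padding_value('valid', 3)             # (0, False) -> no padding
get_padding_value(2, 3)                   # (2, False) -> explicit ints pass through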

@@ -0,0 +1,32 @@
""" Conv2d + BN + Act
Hacked together by Ross Wightman
"""
from torch import nn as nn
from timm.models.layers.conv_helpers import get_padding
class ConvBnAct(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1,
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(ConvBnAct, self).__init__()
padding = get_padding(kernel_size, stride, dilation) # assuming PyTorch style padding for this block
self.conv = nn.Conv2d(
in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=False)
self.bn = norm_layer(out_channels)
self.drop_block = drop_block
if act_layer is not None:
self.act = act_layer(inplace=True)
else:
self.act = None
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.drop_block is not None:
x = self.drop_block(x)
if self.act is not None:
x = self.act(x)
return x
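Usage sketch for the block above; padding comes from get_padding, so output size follows the usual conv arithmetic:

import torch
from timm.models.layers import ConvBnAct

block = ConvBnAct(32, 64, kernel_size=3, stride=2)   # get_padding(3, stride=2) == 1
y = block(torch.randn(1, 32, 56, 56))                # -> [1, 64, 28, 28]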

@@ -0,0 +1,27 @@
""" Common Helpers
Hacked together by Ross Wightman
"""
from itertools import repeat
from torch._six import container_abcs
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, container_abcs.Iterable):
return x
return tuple(repeat(x, n))
return parse
tup_single = _ntuple(1)
tup_pair = _ntuple(2)
tup_triple = _ntuple(3)
tup_quadruple = _ntuple(4)
# Calculate symmetric padding for a convolution
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
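For reference, the symmetric padding this helper yields (sketch):

from timm.models.layers.conv_helpers import get_padding

get_padding(3)                # 1
get_padding(5)                # 2
get_padding(3, stride=2)      # 1
get_padding(3, dilation=2)    # 2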

@@ -0,0 +1,49 @@
""" Conditional Convolution
Hacked together by Ross Wightman
"""
import torch
from torch import nn as nn
from .conv2d_same import create_conv2d_pad
def _split_channels(num_chan, num_groups):
split = [num_chan // num_groups for _ in range(num_groups)]
split[0] += num_chan - sum(split)
return split
class MixedConv2d(nn.ModuleDict):
""" Mixed Grouped Convolution
Based on MDConv and GroupedConv in MixNet impl:
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
"""
def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding='', dilation=1, depthwise=False, **kwargs):
super(MixedConv2d, self).__init__()
kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
num_groups = len(kernel_size)
in_splits = _split_channels(in_channels, num_groups)
out_splits = _split_channels(out_channels, num_groups)
self.in_channels = sum(in_splits)
self.out_channels = sum(out_splits)
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
conv_groups = out_ch if depthwise else 1
# use add_module to keep key space clean
self.add_module(
str(idx),
create_conv2d_pad(
in_ch, out_ch, k, stride=stride,
padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
)
self.splits = in_splits
def forward(self, x):
x_split = torch.split(x, self.splits, 1)
x_out = [c(x_split[i]) for i, c in enumerate(self.values())]
x = torch.cat(x_out, 1)
return x
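Usage sketch: a list of kernel sizes splits the channels into per-kernel groups (16/16/16 here), as in MixNet's MDConv.

import torch
from timm.models.layers import MixedConv2d

mc = MixedConv2d(48, 48, kernel_size=[3, 5, 7], depthwise=True)
y = mc(torch.randn(2, 48, 32, 32))    # -> [2, 48, 32, 32]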

@@ -0,0 +1,30 @@
""" Select Conv2d Factory Method
Hacked together by Ross Wightman
"""
from .mixed_conv2d import MixedConv2d
from .cond_conv2d import CondConv2d
from .conv2d_same import create_conv2d_pad
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
""" Select a 2d convolution implementation based on arguments
Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.
Used extensively by EfficientNet, MobileNetv3 and related networks.
"""
assert 'groups' not in kwargs # only use 'depthwise' bool arg
if isinstance(kernel_size, list):
assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
# We're going to use only lists for defining the MixedConv2d kernel groups,
# ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
else:
depthwise = kwargs.pop('depthwise', False)
groups = out_chs if depthwise else 1
if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
else:
m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
return m
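What the factory returns for different arguments, per the branches above (sketch):

from timm.models.layers import select_conv2d

select_conv2d(32, 64, 3)                             # nn.Conv2d, padding=1
select_conv2d(32, 64, 3, padding='same', stride=2)   # Conv2dSame (dynamic 'SAME' pad)
select_conv2d(32, 64, [3, 5], depthwise=True)        # MixedConv2d
select_conv2d(32, 64, 3, num_experts=4)              # CondConv2d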

@@ -0,0 +1,88 @@
""" Selective Kernel Convolution Attention
Hacked together by Ross Wightman
"""
import torch
from torch import nn as nn
from .conv_bn_act import ConvBnAct
def _kernel_valid(k):
if isinstance(k, (list, tuple)):
for ki in k:
return _kernel_valid(ki)
assert k >= 3 and k % 2
class SelectiveKernelAttn(nn.Module):
def __init__(self, channels, num_paths=2, attn_channels=32,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(SelectiveKernelAttn, self).__init__()
self.num_paths = num_paths
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
self.bn = norm_layer(attn_channels)
self.act = act_layer(inplace=True)
self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
def forward(self, x):
assert x.shape[1] == self.num_paths
x = torch.sum(x, dim=1)
x = self.pool(x)
x = self.fc_reduce(x)
x = self.bn(x)
x = self.act(x)
x = self.fc_select(x)
B, C, H, W = x.shape
x = x.view(B, self.num_paths, C // self.num_paths, H, W)
x = torch.softmax(x, dim=1)
return x
class SelectiveKernelConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(SelectiveKernelConv, self).__init__()
kernel_size = kernel_size or [3, 5]
_kernel_valid(kernel_size)
if not isinstance(kernel_size, list):
kernel_size = [kernel_size] * 2
if keep_3x3:
dilation = [dilation * (k - 1) // 2 for k in kernel_size]
kernel_size = [3] * len(kernel_size)
else:
dilation = [dilation] * len(kernel_size)
self.num_paths = len(kernel_size)
self.in_channels = in_channels
self.out_channels = out_channels
self.split_input = split_input
if self.split_input:
assert in_channels % self.num_paths == 0
in_channels = in_channels // self.num_paths
groups = min(out_channels, groups)
conv_kwargs = dict(
stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
self.paths = nn.ModuleList([
ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
for k, d in zip(kernel_size, dilation)])
attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
self.drop_block = drop_block
def forward(self, x):
if self.split_input:
x_split = torch.split(x, self.in_channels // self.num_paths, 1)
x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
else:
x_paths = [op(x) for op in self.paths]
x = torch.stack(x_paths, dim=1)
x_attn = self.attn(x)
x = x * x_attn
x = torch.sum(x, dim=1)
return x
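Usage sketch: with the default keep_3x3=True, a [3, 5] spec becomes two 3x3 paths at dilations 1 and 2, and the attention mixes them per channel.

import torch
from timm.models.layers import SelectiveKernelConv

sk = SelectiveKernelConv(64, 64, kernel_size=[3, 5], split_input=True)
y = sk(torch.randn(2, 64, 32, 32))    # -> [2, 64, 32, 32]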

@@ -1,3 +1,8 @@
""" Test Time Pooling (Average-Max Pool)
Hacked together by Ross Wightman
"""
import logging
from torch import nn
import torch.nn.functional as F

@@ -7,8 +7,6 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
Hacked together by Ross Wightman
"""
import torch.nn as nn
import torch.nn.functional as F
from .efficientnet_builder import *
from .registry import register_model
