diff --git a/timm/models/conv2d_layers.py b/timm/models/conv2d_layers.py index acd14fde..7583263a 100644 --- a/timm/models/conv2d_layers.py +++ b/timm/models/conv2d_layers.py @@ -1,3 +1,5 @@ +from collections import OrderedDict + import torch import torch.nn as nn import torch.nn.functional as F @@ -100,14 +102,11 @@ def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) -class MixedConv2d(nn.Module): +class MixedConv2d(nn.ModuleDict): """ Mixed Grouped Convolution Based on MDConv and GroupedConv in MixNet impl: https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py - - NOTE: This does not currently work with torch.jit.script """ - def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding='', dilation=1, depthwise=False, **kwargs): super(MixedConv2d, self).__init__() @@ -131,7 +130,7 @@ class MixedConv2d(nn.Module): def forward(self, x): x_split = torch.split(x, self.splits, 1) - x_out = [c(x) for x, c in zip(x_split, self._modules.values())] + x_out = [c(x_split[i]) for i, c in enumerate(self.values())] x = torch.cat(x_out, 1) return x @@ -240,6 +239,97 @@ class CondConv2d(nn.Module): return out +class SelectiveKernelAttn(nn.Module): + def __init__(self, channels, num_paths=2, attn_channels=32, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelAttn, self).__init__() + self.num_paths = num_paths + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) + self.bn = norm_layer(attn_channels) + self.act = act_layer(inplace=True) + self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) + + def forward(self, x): + assert x.shape[1] == self.num_paths + x = torch.sum(x, dim=1) + x = self.pool(x) + x = self.fc_reduce(x) + x = self.bn(x) + x = self.act(x) + x = self.fc_select(x) + B, C, H, W = x.shape + x = x.view(B, self.num_paths, C // self.num_paths, H, W) + x = torch.softmax(x, dim=1) + return x + + +def _kernel_valid(k): + if isinstance(k, (list, tuple)): + for ki in k: + return _kernel_valid(ki) + assert k >= 3 and k % 2 + + +class SelectiveKernelConv(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1, + attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelConv, self).__init__() + kernel_size = kernel_size or [3, 5] + _kernel_valid(kernel_size) + if not isinstance(kernel_size, list): + kernel_size = [kernel_size] * 2 + if keep_3x3: + dilation = [dilation * (k - 1) // 2 for k in kernel_size] + kernel_size = [3] * len(kernel_size) + else: + dilation = [dilation] * len(kernel_size) + num_paths = len(kernel_size) + self.num_paths = num_paths + self.split_input = split_input + self.in_channels = in_channels + self.out_channels = out_channels + if split_input: + assert in_channels % num_paths == 0 and out_channels % num_paths == 0 + in_channels = in_channels // num_paths + out_channels = out_channels // num_paths + groups = min(out_channels, groups) + + self.paths = nn.ModuleList() + for k, d in zip(kernel_size, dilation): + p = _get_padding(k, stride, d) + self.paths.append(nn.Sequential(OrderedDict([ + ('conv', nn.Conv2d( + in_channels, out_channels, kernel_size=k, stride=stride, padding=p, + dilation=d, groups=groups, bias=False)), + ('bn', norm_layer(out_channels)), + ('act', act_layer(inplace=True)) + ]))) + + attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) + self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels) + + def forward(self, x): + if self.split_input: + x_split = torch.split(x, self.in_channels // self.num_paths, 1) + x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] + else: + x_paths = [op(x) for op in self.paths] + + x = torch.stack(x_paths, dim=1) + x_attn = self.attn(x) + x = x * x_attn + + if self.split_input: + B, N, C, H, W = x.shape + x = x.reshape(B, N * C, H, W) + else: + x = torch.sum(x, dim=1) + return x + + # helper method def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): assert 'groups' not in kwargs # only use 'depthwise' bool arg @@ -256,5 +346,3 @@ def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): else: m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) return m - - diff --git a/timm/models/res2net.py b/timm/models/res2net.py index da20e7a0..c83aba62 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -54,14 +54,15 @@ class Bottle2neck(nn.Module): def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=26, scale=4, use_se=False, - act_layer=nn.ReLU, norm_layer=None, dilation=1, previous_dilation=1, **_): + act_layer=nn.ReLU, norm_layer=None, dilation=1, first_dilation=None, **_): super(Bottle2neck, self).__init__() self.scale = scale self.is_first = stride > 1 or downsample is not None self.num_scales = max(1, scale - 1) width = int(math.floor(planes * (base_width / 64.0))) * cardinality - outplanes = planes * self.expansion self.width = width + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False) self.bn1 = norm_layer(width * scale) @@ -70,8 +71,8 @@ class Bottle2neck(nn.Module): bns = [] for i in range(self.num_scales): convs.append(nn.Conv2d( - width, width, kernel_size=3, stride=stride, padding=dilation, - dilation=dilation, groups=cardinality, bias=False)) + width, width, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, bias=False)) bns.append(norm_layer(width)) self.convs = nn.ModuleList(convs) self.bns = nn.ModuleList(bns) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index c3104530..976d4234 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -131,24 +131,23 @@ class BasicBlock(nn.Module): __constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat expansion = 1 - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(BasicBlock, self).__init__() assert cardinality == 1, 'BasicBlock only supports cardinality of 1' assert base_width == 64, 'BasicBlock doest not support changing base width' first_planes = planes // reduce_first outplanes = planes * self.expansion + first_dilation = first_dilation or dilation self.conv1 = nn.Conv2d( - inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation, - dilation=dilation, bias=False) + inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) self.conv2 = nn.Conv2d( - first_planes, outplanes, kernel_size=3, padding=previous_dilation, - dilation=previous_dilation, bias=False) + first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) self.bn2 = norm_layer(outplanes) self.se = SEModule(outplanes, planes // 4) if use_se else None self.act2 = act_layer(inplace=True) @@ -181,21 +180,21 @@ class Bottleneck(nn.Module): __constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat expansion = 4 - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(Bottleneck, self).__init__() width = int(math.floor(planes * (base_width / 64)) * cardinality) first_planes = width // reduce_first outplanes = planes * self.expansion + first_dilation = first_dilation or dilation self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) self.conv2 = nn.Conv2d( first_planes, width, kernel_size=3, stride=stride, - padding=dilation, dilation=dilation, groups=cardinality, bias=False) + padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) self.bn2 = norm_layer(width) self.act2 = act_layer(inplace=True) self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) @@ -396,13 +395,11 @@ class ResNet(nn.Module): first_dilation = 1 if dilation in (1, 2) else 2 bkwargs = dict( cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first, - use_se=use_se, **kwargs) - layers = [block( - self.inplanes, planes, stride, downsample, dilation=first_dilation, previous_dilation=dilation, **bkwargs)] + dilation=dilation, use_se=use_se, **kwargs) + layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **bkwargs)] self.inplanes = planes * block.expansion for i in range(1, blocks): - layers.append(block( - self.inplanes, planes, dilation=dilation, previous_dilation=dilation, **bkwargs)) + layers.append(block(self.inplanes, planes, **bkwargs)) return nn.Sequential(*layers) @@ -430,8 +427,8 @@ class ResNet(nn.Module): def forward(self, x): x = self.forward_features(x) x = self.global_pool(x).flatten(1) - if self.drop_rate > 0.: - x = F.dropout(x, p=self.drop_rate, training=self.training) + if self.drop_rate: + x = F.dropout(x, p=float(self.drop_rate), training=self.training) x = self.fc(x) return x diff --git a/timm/models/sknet.py b/timm/models/sknet.py index d3a4fb5d..e9dbf68d 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -1,12 +1,11 @@ import math -from collections import OrderedDict -import torch from torch import nn as nn from timm.models.registry import register_model from timm.models.helpers import load_pretrained -from timm.models.resnet import ResNet, get_padding, SEModule +from timm.models.conv2d_layers import SelectiveKernelConv +from timm.models.resnet import ResNet, SEModule from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -27,113 +26,12 @@ default_cfgs = { } -class SelectiveKernelAttn(nn.Module): - def __init__(self, channels, num_paths=2, attn_channels=32, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(SelectiveKernelAttn, self).__init__() - self.num_paths = num_paths - self.pool = nn.AdaptiveAvgPool2d(1) - self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) - self.bn = norm_layer(attn_channels) - self.act = act_layer(inplace=True) - self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) - - def forward(self, x): - assert x.shape[1] == self.num_paths - x = torch.sum(x, dim=1) - #print('attn sum', x.shape) - x = self.pool(x) - #print('attn pool', x.shape) - x = self.fc_reduce(x) - x = self.bn(x) - x = self.act(x) - x = self.fc_select(x) - #print('attn sel', x.shape) - B, C, H, W = x.shape - x = x.view(B, self.num_paths, C // self.num_paths, H, W) - #print('attn spl', x.shape) - x = torch.softmax(x, dim=1) - return x - - -def _kernel_valid(k): - if isinstance(k, (list, tuple)): - for ki in k: - return _kernel_valid(ki) - assert k >= 3 and k % 2 - - -class SelectiveKernelConv(nn.Module): - - def __init__(self, in_channels, out_channels, kernel_size=[3, 5], stride=1, dilation=1, groups=1, - attn_reduction=16, min_attn_channels=32, keep_3x3=True, use_attn=True, - split_input=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(SelectiveKernelConv, self).__init__() - _kernel_valid(kernel_size) - if not isinstance(kernel_size, list): - kernel_size = [kernel_size] * 2 - if keep_3x3: - dilation = [dilation * (k - 1) // 2 for k in kernel_size] - kernel_size = [3] * len(kernel_size) - else: - dilation = [dilation] * len(kernel_size) - num_paths = len(kernel_size) - self.num_paths = num_paths - self.split_input = split_input - self.in_channels = in_channels - self.out_channels = out_channels - if split_input: - assert in_channels % num_paths == 0 and out_channels % num_paths == 0 - in_channels = in_channels // num_paths - out_channels = out_channels // num_paths - groups = min(out_channels, groups) - - self.paths = nn.ModuleList() - for k, d in zip(kernel_size, dilation): - p = get_padding(k, stride, d) - self.paths.append(nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d( - in_channels, out_channels, kernel_size=k, stride=stride, padding=p, dilation=d, groups=groups)), - ('bn', norm_layer(out_channels)), - ('act', act_layer(inplace=True)) - ]))) - - if use_attn: - attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) - self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels) - else: - self.attn = None - - def forward(self, x): - if self.split_input: - x_split = torch.split(x, self.in_channels // self.num_paths, 1) - x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] - else: - x_paths = [op(x) for op in self.paths] - - if self.attn is not None: - x = torch.stack(x_paths, dim=1) - # print('paths', x_paths.shape) - x_attn = self.attn(x) - #print('attn', x_attn.shape) - x = x * x_attn - #print('amul', x.shape) - - if self.split_input: - B, N, C, H, W = x.shape - x = x.reshape(B, N * C, H, W) - else: - x = torch.sum(x, dim=1) - #print('aout', x.shape) - return x - - class SelectiveKernelBasic(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, sk_kwargs=None, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(SelectiveKernelBasic, self).__init__() sk_kwargs = sk_kwargs or {} @@ -141,24 +39,25 @@ class SelectiveKernelBasic(nn.Module): assert base_width == 64, 'BasicBlock doest not support changing base width' first_planes = planes // reduce_first outplanes = planes * self.expansion + first_dilation = first_dilation or dilation _selective_first = True # FIXME temporary, for experiments if _selective_first: self.conv1 = SelectiveKernelConv( - inplanes, first_planes, stride=stride, dilation=dilation, **sk_kwargs) + inplanes, first_planes, stride=stride, dilation=first_dilation, **sk_kwargs) else: self.conv1 = nn.Conv2d( - inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation, - dilation=dilation, bias=False) + inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) if _selective_first: self.conv2 = nn.Conv2d( - first_planes, outplanes, kernel_size=3, padding=previous_dilation, - dilation=previous_dilation, bias=False) + first_planes, outplanes, kernel_size=3, padding=dilation, + dilation=dilation, bias=False) else: self.conv2 = SelectiveKernelConv( - first_planes, outplanes, dilation=previous_dilation, **sk_kwargs) + first_planes, outplanes, dilation=dilation, **sk_kwargs) self.bn2 = norm_layer(outplanes) self.se = SEModule(outplanes, planes // 4) if use_se else None self.act2 = act_layer(inplace=True) @@ -192,19 +91,20 @@ class SelectiveKernelBottleneck(nn.Module): def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, sk_kwargs=None, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(SelectiveKernelBottleneck, self).__init__() sk_kwargs = sk_kwargs or {} width = int(math.floor(planes * (base_width / 64)) * cardinality) first_planes = width // reduce_first outplanes = planes * self.expansion + first_dilation = first_dilation or dilation self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) self.conv2 = SelectiveKernelConv( - first_planes, width, stride=stride, dilation=dilation, groups=cardinality, **sk_kwargs) + first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, **sk_kwargs) self.bn2 = norm_layer(width) self.act2 = act_layer(inplace=True) self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)