From 165fb354b2a797c68ec30399971dc1fdfc498509 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 14 May 2021 16:48:58 -0700 Subject: [PATCH 01/48] Add initial RedNet model / Involution layer impl for testing --- timm/models/byoanet.py | 49 +++++++++++++++++++++++++ timm/models/layers/__init__.py | 1 + timm/models/layers/create_self_attn.py | 3 ++ timm/models/layers/involution.py | 50 ++++++++++++++++++++++++++ 4 files changed, 103 insertions(+) create mode 100644 timm/models/layers/involution.py diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index ca49089b..a58eea63 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -58,6 +58,9 @@ default_cfgs = { 'swinnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'swinnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + + 'rednet26t': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)), + 'rednet50ts': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)), } @@ -245,6 +248,38 @@ model_cfgs = dict( self_attn_fixed_size=True, self_attn_kwargs=dict(win_size=8) ), + + rednet26t=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='self_attn', d=2, c=256, s=1, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=512, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', # FIXME RedNet uses involution in middle of stem + stem_pool='maxpool', + num_features=0, + self_attn_layer='involution', + self_attn_fixed_size=False, + self_attn_kwargs=dict() + ), + rednet50ts=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='self_attn', d=3, c=256, s=1, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=4, c=512, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + self_attn_layer='involution', + self_attn_fixed_size=False, + self_attn_kwargs=dict() + ), ) @@ -477,3 +512,17 @@ def swinnet50ts_256(pretrained=False, **kwargs): """ kwargs.setdefault('img_size', 256) return _create_byoanet('swinnet50ts_256', 'swinnet50ts', pretrained=pretrained, **kwargs) + + +@register_model +def rednet26t(pretrained=False, **kwargs): + """ + """ + return _create_byoanet('rednet26t', pretrained=pretrained, **kwargs) + + +@register_model +def rednet50ts(pretrained=False, **kwargs): + """ + """ + return _create_byoanet('rednet50ts', pretrained=pretrained, **kwargs) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 90241f5c..522c27e1 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -18,6 +18,7 @@ from .eca import EcaModule, CecaModule from .evo_norm import EvoNormBatch2d, EvoNormSample2d from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible from .inplace_abn import InplaceAbn +from .involution import Involution from .linear import Linear from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp diff --git a/timm/models/layers/create_self_attn.py b/timm/models/layers/create_self_attn.py index ba208f17..448ddb34 100644 --- a/timm/models/layers/create_self_attn.py +++ b/timm/models/layers/create_self_attn.py @@ -1,5 +1,6 @@ from .bottleneck_attn import BottleneckAttn from .halo_attn import HaloAttn +from .involution import Involution from .lambda_layer import LambdaLayer from .swin_attn import WindowAttention @@ -13,6 +14,8 @@ def get_self_attn(attn_type): return LambdaLayer elif attn_type == 'swin': return WindowAttention + elif attn_type == 'involution': + return Involution else: assert False, f"Unknown attn type ({attn_type})" diff --git a/timm/models/layers/involution.py b/timm/models/layers/involution.py new file mode 100644 index 00000000..0dba9fae --- /dev/null +++ b/timm/models/layers/involution.py @@ -0,0 +1,50 @@ +""" PyTorch Involution Layer + +Official impl: https://github.com/d-li14/involution/blob/main/cls/mmcls/models/utils/involution_naive.py +Paper: `Involution: Inverting the Inherence of Convolution for Visual Recognition` - https://arxiv.org/abs/2103.06255 +""" +import torch.nn as nn +from .conv_bn_act import ConvBnAct +from .create_conv2d import create_conv2d + + +class Involution(nn.Module): + + def __init__( + self, + channels, + kernel_size=3, + stride=1, + group_size=16, + reduction_ratio=4, + norm_layer=nn.BatchNorm2d, + act_layer=nn.ReLU, + ): + super(Involution, self).__init__() + self.kernel_size = kernel_size + self.stride = stride + self.channels = channels + self.group_size = group_size + self.groups = self.channels // self.group_size + self.conv1 = ConvBnAct( + in_channels=channels, + out_channels=channels // reduction_ratio, + kernel_size=1, + norm_layer=norm_layer, + act_layer=act_layer) + self.conv2 = self.conv = create_conv2d( + in_channels=channels // reduction_ratio, + out_channels=kernel_size**2 * self.groups, + kernel_size=1, + stride=1) + self.avgpool = nn.AvgPool2d(stride, stride) if stride == 2 else nn.Identity() + self.unfold = nn.Unfold(kernel_size, 1, (kernel_size-1)//2, stride) + + def forward(self, x): + weight = self.conv2(self.conv1(self.avgpool(x))) + B, C, H, W = weight.shape + KK = int(self.kernel_size ** 2) + weight = weight.view(B, self.groups, KK, H, W).unsqueeze(2) + out = self.unfold(x).view(B, self.groups, self.group_size, KK, H, W) + out = (weight * out).sum(dim=3).view(B, self.channels, H, W) + return out From ecc7552c5c5ab1d177705774e8e4efd16939852c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 14 May 2021 17:10:43 -0700 Subject: [PATCH 02/48] Add levit, levit_c, and visformer model defs. Largely untested and not finished cleanup. --- tests/test_models.py | 2 +- timm/models/__init__.py | 3 + timm/models/layers/patch_embed.py | 7 +- timm/models/levit.py | 440 ++++++++++++++++++++++++++++++ timm/models/levitc.py | 400 +++++++++++++++++++++++++++ timm/models/visformer.py | 377 +++++++++++++++++++++++++ 6 files changed, 1226 insertions(+), 3 deletions(-) create mode 100644 timm/models/levit.py create mode 100644 timm/models/levitc.py create mode 100644 timm/models/visformer.py diff --git a/tests/test_models.py b/tests/test_models.py index ced2fd76..1e1de498 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -15,7 +15,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): torch._C._jit_set_profiling_mode(False) # transformer models don't support many of the spatial / feature based model functionalities -NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', 'mixer_*'] +NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', 'mixer_*', 'levit*', 'visformer*'] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 46ea155f..821012e2 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -15,6 +15,8 @@ from .hrnet import * from .inception_resnet_v2 import * from .inception_v3 import * from .inception_v4 import * +from .levitc import * +from .levit import * from .mlp_mixer import * from .mobilenetv3 import * from .nasnet import * @@ -34,6 +36,7 @@ from .swin_transformer import * from .tnt import * from .tresnet import * from .vgg import * +from .visformer import * from .vision_transformer import * from .vision_transformer_hybrid import * from .vovnet import * diff --git a/timm/models/layers/patch_embed.py b/timm/models/layers/patch_embed.py index b06f9982..42997fb8 100644 --- a/timm/models/layers/patch_embed.py +++ b/timm/models/layers/patch_embed.py @@ -15,7 +15,7 @@ from .helpers import to_2tuple class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None): + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -23,6 +23,7 @@ class PatchEmbed(nn.Module): self.patch_size = patch_size self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() @@ -31,6 +32,8 @@ class PatchEmbed(nn.Module): B, C, H, W = x.shape assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC x = self.norm(x) return x diff --git a/timm/models/levit.py b/timm/models/levit.py new file mode 100644 index 00000000..997b44d7 --- /dev/null +++ b/timm/models/levit.py @@ -0,0 +1,440 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. + +# Modified from +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# Copyright 2020 Ross Wightman, Apache-2.0 License +import itertools + +import torch + +from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN +from .vision_transformer import trunc_normal_ +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +specification = { + 'levit_128s': { + 'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'}, + 'levit_128': { + 'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'}, + 'levit_192': { + 'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'}, + 'levit_256': { + 'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'}, + 'levit_384': { + 'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'}, +} + +__all__ = ['Levit'] + + +@register_model +def levit_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_128s'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_128'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_192'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_256'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_384'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +class ConvNorm(torch.nn.Sequential): + def __init__( + self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000): + super().__init__() + self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = torch.nn.BatchNorm2d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = torch.nn.Conv2d( + w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, + padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class LinearNorm(torch.nn.Sequential): + def __init__(self, a, b, bn_weight_init=1, resolution=-100000): + super().__init__() + self.add_module('c', torch.nn.Linear(a, b, bias=False)) + bn = torch.nn.BatchNorm1d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + @torch.no_grad() + def fuse(self): + l, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = l.weight * w[:, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = torch.nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + def forward(self, x): + l, bn = self._modules.values() + x = l(x) + return bn(x.flatten(0, 1)).reshape_as(x) + + +class NormLinear(torch.nn.Sequential): + def __init__(self, a, b, bias=True, std=0.02): + super().__init__() + self.add_module('bn', torch.nn.BatchNorm1d(a)) + l = torch.nn.Linear(a, b, bias=bias) + trunc_normal_(l.weight, std=std) + if bias: + torch.nn.init.constant_(l.bias, 0) + self.add_module('l', l) + + @torch.no_grad() + def fuse(self): + bn, l = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = l.weight * w[None, :] + if l.bias is None: + b = b @ self.l.weight.T + else: + b = (l.weight @ b[:, None]).view(-1) + self.l.bias + m = torch.nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def b16(n, activation, resolution=224): + return torch.nn.Sequential( + ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution), + activation(), + ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), + activation(), + ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), + activation(), + ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8)) + + +class Residual(torch.nn.Module): + def __init__(self, m, drop): + super().__init__() + self.m = m + self.drop = drop + + def forward(self, x): + if self.training and self.drop > 0: + return x + self.m(x) * torch.rand( + x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach() + else: + return x + self.m(x) + + +class Attention(torch.nn.Module): + def __init__( + self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + self.qkv = LinearNorm(dim, h, resolution=resolution) + self.proj = torch.nn.Sequential( + act_layer(), + LinearNorm(self.dh, dim, bn_weight_init=0, resolution=resolution)) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N)) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,N,C) + B, N, C = x.shape + qkv = self.qkv(x) + q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab + attn = q @ k.transpose(-2, -1) * self.scale + ab + + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class Subsample(torch.nn.Module): + def __init__(self, stride, resolution): + super().__init__() + self.stride = stride + self.resolution = resolution + + def forward(self, x): + B, N, C = x.shape + x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride] + return x.reshape(B, -1, C) + + +class AttentionSubsample(torch.nn.Module): + def __init__(self, in_dim, out_dim, key_dim, num_heads=8, + attn_ratio=2, act_layer=None, stride=2, resolution=14, resolution_=7): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * self.num_heads + self.attn_ratio = attn_ratio + self.resolution_ = resolution_ + self.resolution_2 = resolution_ ** 2 + h = self.dh + nh_kd + self.kv = LinearNorm(in_dim, h, resolution=resolution) + + self.q = torch.nn.Sequential( + Subsample(stride, resolution), + LinearNorm(in_dim, nh_kd, resolution=resolution_)) + self.proj = torch.nn.Sequential( + act_layer(), + LinearNorm(self.dh, out_dim, resolution=resolution_)) + + self.stride = stride + self.resolution = resolution + points = list(itertools.product(range(resolution), range(resolution))) + points_ = list(itertools.product(range(resolution_), range(resolution_))) + N = len(points) + N_ = len(points_) + attention_offsets = {} + idxs = [] + for p1 in points_: + for p2 in points: + size = 1 + offset = ( + abs(p1[0] * stride - p2[0] + (size - 1) / 2), + abs(p1[1] * stride - p2[1] + (size - 1) / 2)) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N)) + + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): + B, N, C = x.shape + k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3) + k = k.permute(0, 2, 1, 3) # BHNC + v = v.permute(0, 2, 1, 3) # BHNC + q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) + + ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab + attn = q @ k.transpose(-2, -1) * self.scale + ab + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) + x = self.proj(x) + return x + + +class Levit(torch.nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=[192], + key_dim=[64], + depth=[12], + num_heads=[3], + attn_ratio=[2], + mlp_ratio=[2], + hybrid_backbone=None, + down_ops=[], + attn_act_layer=torch.nn.Hardswish, + mlp_act_layer=torch.nn.Hardswish, + distillation=True, + drop_path=0): + super().__init__() + global FLOPS_COUNTER + + self.num_classes = num_classes + self.num_features = embed_dim[-1] + self.embed_dim = embed_dim + self.distillation = distillation + + self.patch_embed = hybrid_backbone + + self.blocks = [] + down_ops.append(['']) + resolution = img_size // patch_size + for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( + zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)): + for _ in range(dpth): + self.blocks.append( + Residual( + Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution), + drop_path)) + if mr > 0: + h = int(ed * mr) + self.blocks.append( + Residual(torch.nn.Sequential( + LinearNorm(ed, h, resolution=resolution), + mlp_act_layer(), + LinearNorm(h, ed, bn_weight_init=0, resolution=resolution), + ), drop_path)) + if do[0] == 'Subsample': + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + resolution_ = (resolution - 1) // do[5] + 1 + self.blocks.append( + AttentionSubsample( + *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], + attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5], + resolution=resolution, resolution_=resolution_)) + resolution = resolution_ + if do[4] > 0: # mlp_ratio + h = int(embed_dim[i + 1] * do[4]) + self.blocks.append( + Residual(torch.nn.Sequential( + LinearNorm(embed_dim[i + 1], h, resolution=resolution), + mlp_act_layer(), + LinearNorm(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), + ), drop_path)) + self.blocks = torch.nn.Sequential(*self.blocks) + + # Classifier head + self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + if distillation: + self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + else: + self.head_dist = None + + @torch.jit.ignore + def no_weight_decay(self): + return {x for x in self.state_dict().keys() if 'attention_biases' in x} + + def forward(self, x): + x = self.patch_embed(x) + x = x.flatten(2).transpose(1, 2) + x = self.blocks(x) + x = x.mean(1) + if self.distillation: + x = self.head(x), self.head_dist(x) + if not self.training: + x = (x[0] + x[1]) / 2 + else: + x = self.head(x) + return x + + +def model_factory(C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse): + embed_dim = [int(x) for x in C.split('_')] + num_heads = [int(x) for x in N.split('_')] + depth = [int(x) for x in X.split('_')] + act = torch.nn.Hardswish + model = Levit( + patch_size=16, + embed_dim=embed_dim, + num_heads=num_heads, + key_dim=[D] * 3, + depth=depth, + attn_ratio=[2, 2, 2], + mlp_ratio=[2, 2, 2], + down_ops=[ + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + ['Subsample', D, embed_dim[0] // D, 4, 2, 2], + ['Subsample', D, embed_dim[1] // D, 4, 2, 2], + ], + attn_act_layer=act, + mlp_act_layer=act, + hybrid_backbone=b16(embed_dim[0], activation=act), + num_classes=num_classes, + drop_path=drop_path, + distillation=distillation + ) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url(weights, map_location='cpu') + model.load_state_dict(checkpoint['model']) + #if fuse: + # utils.replace_batchnorm(model) + + return model diff --git a/timm/models/levitc.py b/timm/models/levitc.py new file mode 100644 index 00000000..1a422953 --- /dev/null +++ b/timm/models/levitc.py @@ -0,0 +1,400 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. + +# Modified from +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# Copyright 2020 Ross Wightman, Apache-2.0 License +import itertools + +import torch + +from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN +from .vision_transformer import trunc_normal_ +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +specification = { + 'levit_c_128s': { + 'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'}, + 'levit_c_128': { + 'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'}, + 'levit_c_192': { + 'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'}, + 'levit_c_256': { + 'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'}, + 'levit_c_384': { + 'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'}, +} + +__all__ = ['Levit'] + + +@register_model +def levit_c_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_c_128s'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_c_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_c_128'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_c_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_c_192'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_c_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_c_256'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_c_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_c_384'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +class ConvNorm(torch.nn.Sequential): + def __init__( + self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000): + super().__init__() + self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = torch.nn.BatchNorm2d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps) ** 0.5 + m = torch.nn.Conv2d( + w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, + padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class NormLinear(torch.nn.Sequential): + def __init__(self, a, b, bias=True, std=0.02): + super().__init__() + self.add_module('bn', torch.nn.BatchNorm1d(a)) + l = torch.nn.Linear(a, b, bias=bias) + trunc_normal_(l.weight, std=std) + if bias: + torch.nn.init.constant_(l.bias, 0) + self.add_module('l', l) + + @torch.no_grad() + def fuse(self): + bn, l = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + b = bn.bias - self.bn.running_mean * \ + self.bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = l.weight * w[None, :] + if l.bias is None: + b = b @ self.l.weight.T + else: + b = (l.weight @ b[:, None]).view(-1) + self.l.bias + m = torch.nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def b16(n, activation, resolution=224): + return torch.nn.Sequential( + ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution), + activation(), + ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), + activation(), + ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), + activation(), + ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8)) + + +class Residual(torch.nn.Module): + def __init__(self, m, drop): + super().__init__() + self.m = m + self.drop = drop + + def forward(self, x): + if self.training and self.drop > 0: + return x + self.m(x) * torch.rand( + x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach() + else: + return x + self.m(x) + + +class Attention(torch.nn.Module): + def __init__(self, dim, key_dim, num_heads=8, + attn_ratio=4, act_layer=None, resolution=14): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + self.qkv = ConvNorm(dim, h, resolution=resolution) + self.proj = torch.nn.Sequential( + act_layer(), + ConvNorm(self.dh, dim, bn_weight_init=0, resolution=resolution)) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N)) + self.ab = None + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.ab is not None: + self.ab = None + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,C,H,W) + B, C, H, W = x.shape + q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2) + ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab + attn = (q.transpose(-2, -1) @ k) * self.scale + ab + attn = attn.softmax(dim=-1) + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + x = self.proj(x) + return x + + +class AttentionSubsample(torch.nn.Module): + def __init__( + self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2, + act_layer=None, stride=2, resolution=14, resolution_=7): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * self.num_heads + self.attn_ratio = attn_ratio + self.resolution_ = resolution_ + self.resolution_2 = resolution_ ** 2 + h = self.dh + nh_kd + self.kv = ConvNorm(in_dim, h, resolution=resolution) + self.q = torch.nn.Sequential( + torch.nn.AvgPool2d(1, stride, 0), + ConvNorm(in_dim, nh_kd, resolution=resolution_)) + self.proj = torch.nn.Sequential( + act_layer(), + ConvNorm(self.d * num_heads, out_dim, resolution=resolution_)) + + self.stride = stride + self.resolution = resolution + points = list(itertools.product(range(resolution), range(resolution))) + points_ = list(itertools.product(range(resolution_), range(resolution_))) + N = len(points) + N_ = len(points_) + attention_offsets = {} + idxs = [] + for p1 in points_: + for p2 in points: + size = 1 + offset = ( + abs(p1[0] * stride - p2[0] + (size - 1) / 2), + abs(p1[1] * stride - p2[1] + (size - 1) / 2)) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N)) + self.ab = None + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.ab is not None: + self.ab = None + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): + B, C, H, W = x.shape + k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2) + q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2) + ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab + attn = (q.transpose(-2, -1) @ k) * self.scale + ab + attn = attn.softmax(dim=-1) + + x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_) + x = self.proj(x) + return x + + +class Levit(torch.nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=[192], + key_dim=[64], + depth=[12], + num_heads=[3], + attn_ratio=[2], + mlp_ratio=[2], + hybrid_backbone=None, + down_ops=[], + attn_act_layer=torch.nn.Hardswish, + mlp_act_layer=torch.nn.Hardswish, + distillation=True, + drop_path=0): + super().__init__() + self.num_classes = num_classes + self.num_features = embed_dim[-1] + self.embed_dim = embed_dim + self.distillation = distillation + + self.patch_embed = hybrid_backbone + + self.blocks = [] + down_ops.append(['']) + resolution = img_size // patch_size + for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( + zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)): + for _ in range(dpth): + self.blocks.append( + Residual( + Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution), + drop_path)) + if mr > 0: + h = int(ed * mr) + self.blocks.append( + Residual(torch.nn.Sequential( + ConvNorm(ed, h, resolution=resolution), + mlp_act_layer(), + ConvNorm(h, ed, bn_weight_init=0, resolution=resolution), + ), drop_path)) + if do[0] == 'Subsample': + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + resolution_ = (resolution - 1) // do[5] + 1 + self.blocks.append( + AttentionSubsample( + *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], attn_ratio=do[3], + act_layer=attn_act_layer, stride=do[5], + resolution=resolution, resolution_=resolution_)) + resolution = resolution_ + if do[4] > 0: # mlp_ratio + h = int(embed_dim[i + 1] * do[4]) + self.blocks.append( + Residual(torch.nn.Sequential( + ConvNorm(embed_dim[i + 1], h, resolution=resolution), + mlp_act_layer(), + ConvNorm(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), + ), drop_path)) + self.blocks = torch.nn.Sequential(*self.blocks) + + # Classifier head + self.head = NormLinear( + embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + if distillation: + self.head_dist = NormLinear( + embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + + @torch.jit.ignore + def no_weight_decay(self): + return {x for x in self.state_dict().keys() if 'attention_biases' in x} + + def forward(self, x): + x = self.patch_embed(x) + x = self.blocks(x) + x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1) + if self.distillation: + x = self.head(x), self.head_dist(x) + if not self.training: + x = (x[0] + x[1]) / 2 + else: + x = self.head(x) + return x + + +def model_factory(C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse): + embed_dim = [int(x) for x in C.split('_')] + num_heads = [int(x) for x in N.split('_')] + depth = [int(x) for x in X.split('_')] + act = torch.nn.Hardswish + model = Levit( + patch_size=16, + embed_dim=embed_dim, + num_heads=num_heads, + key_dim=[D] * 3, + depth=depth, + attn_ratio=[2, 2, 2], + mlp_ratio=[2, 2, 2], + down_ops=[ + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + ['Subsample', D, embed_dim[0] // D, 4, 2, 2], + ['Subsample', D, embed_dim[1] // D, 4, 2, 2], + ], + attn_act_layer=act, + mlp_act_layer=act, + hybrid_backbone=b16(embed_dim[0], activation=act), + num_classes=num_classes, + drop_path=drop_path, + distillation=distillation + ) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + weights, map_location='cpu') + d = checkpoint['model'] + D = model.state_dict() + for k in d.keys(): + if D[k].shape != d[k].shape: + d[k] = d[k][:, :, None, None] + model.load_state_dict(d) + #if fuse: + # utils.replace_batchnorm(model) + + return model + diff --git a/timm/models/visformer.py b/timm/models/visformer.py new file mode 100644 index 00000000..0d213ad5 --- /dev/null +++ b/timm/models/visformer.py @@ -0,0 +1,377 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed +from .registry import register_model + + +__all__ = ['Visformer'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +class LayerNormBHWC(nn.LayerNorm): + def __init__(self, dim): + super().__init__(dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.layer_norm( + x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + + +class SpatialMlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, + act_layer=nn.GELU, drop=0., group=8, spatial_conv=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.in_features = in_features + self.out_features = out_features + self.spatial_conv = spatial_conv + if self.spatial_conv: + if group < 2: # net setting + hidden_features = in_features * 5 // 6 + else: + hidden_features = in_features * 2 + self.hidden_features = hidden_features + self.group = group + self.drop = nn.Dropout(drop) + self.conv1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0, bias=False) + self.act1 = act_layer() + if self.spatial_conv: + self.conv2 = nn.Conv2d( + hidden_features, hidden_features, 3, stride=1, padding=1, groups=self.group, bias=False) + self.act2 = act_layer() + else: + self.conv2 = None + self.act2 = None + self.conv3 = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0, bias=False) + + def forward(self, x): + x = self.conv1(x) + x = self.act1(x) + x = self.drop(x) + if self.conv2 is not None: + x = self.conv2(x) + x = self.act2(x) + x = self.conv3(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=0.): + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = round(dim // num_heads * head_dim_ratio) + self.head_dim = head_dim + self.scale = head_dim ** -0.5 + self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, C, H, W = x.shape + x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3) + q, k, v = x[0], x[1], x[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__(self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4., + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormBHWC, + group=8, attn_disabled=False, spatial_conv=False): + super().__init__() + self.spatial_conv = spatial_conv + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + if attn_disabled: + self.norm1 = None + self.attn = None + else: + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, head_dim_ratio=head_dim_ratio, attn_drop=attn_drop, proj_drop=drop) + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = SpatialMlp( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, + group=group, spatial_conv=spatial_conv) # new setting + + def forward(self, x): + if self.attn is not None: + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Visformer(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384, + depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + norm_layer=LayerNormBHWC, attn_stage='111', pos_embed=True, spatial_conv='111', + vit_stem=False, group=8, pool=True, conv_init=False, embed_norm=None): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim + self.init_channels = init_channels + self.img_size = img_size + self.vit_stem = vit_stem + self.pool = pool + self.conv_init = conv_init + if isinstance(depth, (list, tuple)): + self.stage_num1, self.stage_num2, self.stage_num3 = depth + depth = sum(depth) + else: + self.stage_num1 = self.stage_num3 = depth // 3 + self.stage_num2 = depth - self.stage_num1 - self.stage_num3 + self.pos_embed = pos_embed + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + # stage 1 + if self.vit_stem: + self.stem = None + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) + img_size //= 16 + else: + if self.init_channels is None: + self.stem = None + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans, + embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) + img_size //= 8 + else: + self.stem = nn.Sequential( + nn.Conv2d(3, self.init_channels, 7, stride=2, padding=3, bias=False), + nn.BatchNorm2d(self.init_channels), + nn.ReLU(inplace=True) + ) + img_size //= 2 + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels, + embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) + img_size //= 4 + + if self.pos_embed: + if self.vit_stem: + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size)) + else: + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, img_size, img_size)) + self.pos_drop = nn.Dropout(p=drop_rate) + self.stage1 = nn.ModuleList([ + Block( + dim=embed_dim//2, num_heads=num_heads, head_dim_ratio=0.5, mlp_ratio=mlp_ratio, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + group=group, attn_disabled=(attn_stage[0] == '0'), spatial_conv=(spatial_conv[0] == '1') + ) + for i in range(self.stage_num1) + ]) + + #stage2 + if not self.vit_stem: + self.patch_embed2 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2, + embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) + img_size //= 2 + if self.pos_embed: + self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size)) + self.stage2 = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + group=group, attn_disabled=(attn_stage[1] == '0'), spatial_conv=(spatial_conv[1] == '1') + ) + for i in range(self.stage_num1, self.stage_num1+self.stage_num2) + ]) + + # stage 3 + if not self.vit_stem: + self.patch_embed3 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim, + embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False) + img_size //= 2 + if self.pos_embed: + self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, img_size, img_size)) + self.stage3 = nn.ModuleList([ + Block( + dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + group=group, attn_disabled=(attn_stage[2] == '0'), spatial_conv=(spatial_conv[2] == '1') + ) + for i in range(self.stage_num1+self.stage_num2, depth) + ]) + + # head + if self.pool: + self.global_pooling = nn.AdaptiveAvgPool2d(1) + head_dim = embed_dim if self.vit_stem else embed_dim * 2 + self.norm = norm_layer(head_dim) + self.head = nn.Linear(head_dim, num_classes) + + # weights init + if self.pos_embed: + trunc_normal_(self.pos_embed1, std=0.02) + if not self.vit_stem: + trunc_normal_(self.pos_embed2, std=0.02) + trunc_normal_(self.pos_embed3, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + if self.conv_init: + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + else: + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0.) + + def forward(self, x): + if self.stem is not None: + x = self.stem(x) + + # stage 1 + x = self.patch_embed1(x) + if self.pos_embed: + x = x + self.pos_embed1 + x = self.pos_drop(x) + for b in self.stage1: + x = b(x) + + # stage 2 + if not self.vit_stem: + x = self.patch_embed2(x) + if self.pos_embed: + x = x + self.pos_embed2 + x = self.pos_drop(x) + for b in self.stage2: + x = b(x) + + # stage3 + if not self.vit_stem: + x = self.patch_embed3(x) + if self.pos_embed: + x = x + self.pos_embed3 + x = self.pos_drop(x) + for b in self.stage3: + x = b(x) + + # head + x = self.norm(x) + if self.pool: + x = self.global_pooling(x) + else: + x = x[:, :, 0, 0] + + x = self.head(x.view(x.size(0), -1)) + return x + + +@register_model +def visformer_tiny(pretrained=False, **kwargs): + model = Visformer( + img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8, + attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, + embed_norm=nn.BatchNorm2d, **kwargs) + return model + + +@register_model +def visformer_small(pretrained=False, **kwargs): + model = Visformer( + img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8, + attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, + embed_norm=nn.BatchNorm2d, **kwargs) + return model + + +@register_model +def visformer_net1(pretrained=False, **kwargs): + model = Visformer( + init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', + spatial_conv='000', vit_stem=True, conv_init=True, **kwargs) + return model + + +@register_model +def visformer_net2(pretrained=False, **kwargs): + model = Visformer( + init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', + spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) + return model + + +@register_model +def visformer_net3(pretrained=False, **kwargs): + model = Visformer( + init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', + spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) + return model + + +@register_model +def visformer_net4(pretrained=False, **kwargs): + model = Visformer(init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', + spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) + return model + + +@register_model +def visformer_net5(pretrained=False, **kwargs): + model = Visformer( + init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', + spatial_conv='111', vit_stem=False, conv_init=True, **kwargs) + return model + + +@register_model +def visformer_net6(pretrained=False, **kwargs): + model = Visformer( + init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', + pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) + return model + + +@register_model +def visformer_net7(pretrained=False, **kwargs): + model = Visformer( + init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000', + pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) + return model + + + + From 94d4b53352f6824b9cbe41c3dd70c18103714951 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 15 May 2021 08:41:31 -0700 Subject: [PATCH 03/48] Add temporary default_cfgs to visformer models so they pass tests --- timm/models/visformer.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 0d213ad5..aa3bca57 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -306,6 +306,7 @@ def visformer_tiny(pretrained=False, **kwargs): img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, embed_norm=nn.BatchNorm2d, **kwargs) + model.default_cfg = _cfg() return model @@ -315,6 +316,7 @@ def visformer_small(pretrained=False, **kwargs): img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, embed_norm=nn.BatchNorm2d, **kwargs) + model.default_cfg = _cfg() return model @@ -323,6 +325,7 @@ def visformer_net1(pretrained=False, **kwargs): model = Visformer( init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', spatial_conv='000', vit_stem=True, conv_init=True, **kwargs) + model.default_cfg = _cfg() return model @@ -331,6 +334,7 @@ def visformer_net2(pretrained=False, **kwargs): model = Visformer( init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) + model.default_cfg = _cfg() return model @@ -339,13 +343,16 @@ def visformer_net3(pretrained=False, **kwargs): model = Visformer( init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) + model.default_cfg = _cfg() return model @register_model def visformer_net4(pretrained=False, **kwargs): - model = Visformer(init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', - spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) + model = Visformer( + init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', + spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) + model.default_cfg = _cfg() return model @@ -354,6 +361,7 @@ def visformer_net5(pretrained=False, **kwargs): model = Visformer( init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', spatial_conv='111', vit_stem=False, conv_init=True, **kwargs) + model.default_cfg = _cfg() return model @@ -362,6 +370,7 @@ def visformer_net6(pretrained=False, **kwargs): model = Visformer( init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) + model.default_cfg = _cfg() return model @@ -370,6 +379,7 @@ def visformer_net7(pretrained=False, **kwargs): model = Visformer( init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000', pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) + model.default_cfg = _cfg() return model From d53e91218e9c1a6df468740a5d86a956016042f8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 15 May 2021 22:56:12 -0700 Subject: [PATCH 04/48] Fix tf.data options setting for newer TF versions --- timm/data/parsers/parser_tfds.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py index 0b12a4db..2ff90b09 100644 --- a/timm/data/parsers/parser_tfds.py +++ b/timm/data/parsers/parser_tfds.py @@ -25,8 +25,8 @@ from .parser import Parser MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities -SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue -PREFETCH_SIZE = 4096 # samples to prefetch +SHUFFLE_SIZE = 20480 # samples to shuffle in DS queue +PREFETCH_SIZE = 2048 # samples to prefetch def even_split_indices(split, n, num_samples): @@ -144,14 +144,16 @@ class ParserTfds(Parser): ds = self.builder.as_dataset( split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config) # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers - ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) - ds.options().experimental_threading.max_intra_op_parallelism = 1 + options = tf.data.Options() + options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) + options.experimental_threading.max_intra_op_parallelism = 1 + ds = ds.with_options(options) if self.is_training or self.repeats > 1: # to prevent excessive drop_last batch behaviour w/ IterableDatasets # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading ds = ds.repeat() # allow wrap around and break iteration manually if self.shuffle: - ds = ds.shuffle(min(self.num_samples // self._num_pipelines, SHUFFLE_SIZE), seed=0) + ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0) ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE)) self.ds = tfds.as_numpy(ds) From 9a3ae97311d7971c532d19e6262c600fcdd7808d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 15 May 2021 22:56:51 -0700 Subject: [PATCH 05/48] Another set of byoanet models w/ ECA channel + SA + groups --- timm/models/byoanet.py | 101 +++++++++++++++++++++++++++++++++++++++++ timm/models/byobnet.py | 2 +- 2 files changed, 102 insertions(+), 1 deletion(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index a58eea63..c179a01c 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -47,17 +47,21 @@ default_cfgs = { # GPU-Efficient (ResNet) weights 'botnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'botnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + 'eca_botnext26ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'halonet26t': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'halonet50ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), + 'eca_halonext26ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), 'lambda_resnet50t': _cfg(url='', min_input_size=(3, 128, 128)), + 'eca_lambda_resnext26ts': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), 'swinnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'swinnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + 'eca_swinnext26ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'rednet26t': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)), 'rednet50ts': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)), @@ -129,6 +133,23 @@ model_cfgs = dict( self_attn_fixed_size=True, self_attn_kwargs=dict() ), + eca_botnext26ts=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25), + ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25), + interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + attn_layer='eca', + self_attn_layer='bottleneck', + self_attn_fixed_size=True, + self_attn_kwargs=dict() + ), halonet_h1=ByoaCfg( blocks=( @@ -187,6 +208,22 @@ model_cfgs = dict( self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=2) ), + eca_halonext26ts=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), + ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25), + interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + attn_layer='eca', + self_attn_layer='halo', + self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res + ), lambda_resnet26t=ByoaCfg( blocks=( @@ -216,6 +253,22 @@ model_cfgs = dict( self_attn_layer='lambda', self_attn_kwargs=dict() ), + eca_lambda_resnext26ts=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), + ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25), + interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + attn_layer='eca', + self_attn_layer='lambda', + self_attn_kwargs=dict() + ), swinnet26t=ByoaCfg( blocks=( @@ -248,6 +301,24 @@ model_cfgs = dict( self_attn_fixed_size=True, self_attn_kwargs=dict(win_size=8) ), + eca_swinnext26ts=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), + interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=16, br=0.25), + interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + attn_layer='eca', + self_attn_layer='swin', + self_attn_fixed_size=True, + self_attn_kwargs=dict(win_size=8) + ), + rednet26t=ByoaCfg( blocks=( @@ -454,6 +525,14 @@ def botnet50ts_256(pretrained=False, **kwargs): return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs) +@register_model +def eca_botnext26ts_256(pretrained=False, **kwargs): + """ Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage. + """ + kwargs.setdefault('img_size', 256) + return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs) + + @register_model def halonet_h1(pretrained=False, **kwargs): """ HaloNet-H1. Halo attention in all stages as per the paper. @@ -484,6 +563,13 @@ def halonet50ts(pretrained=False, **kwargs): return _create_byoanet('halonet50ts', pretrained=pretrained, **kwargs) +@register_model +def eca_halonext26ts(pretrained=False, **kwargs): + """ HaloNet w/ a ResNet26-t backbone, Hallo attention in final stage + """ + return _create_byoanet('eca_halonext26ts', pretrained=pretrained, **kwargs) + + @register_model def lambda_resnet26t(pretrained=False, **kwargs): """ Lambda-ResNet-26T. Lambda layers in one C4 stage and all C5. @@ -498,6 +584,13 @@ def lambda_resnet50t(pretrained=False, **kwargs): return _create_byoanet('lambda_resnet50t', pretrained=pretrained, **kwargs) +@register_model +def eca_lambda_resnext26ts(pretrained=False, **kwargs): + """ Lambda-ResNet-26T. Lambda layers in one C4 stage and all C5. + """ + return _create_byoanet('eca_lambda_resnext26ts', pretrained=pretrained, **kwargs) + + @register_model def swinnet26t_256(pretrained=False, **kwargs): """ @@ -514,6 +607,14 @@ def swinnet50ts_256(pretrained=False, **kwargs): return _create_byoanet('swinnet50ts_256', 'swinnet50ts', pretrained=pretrained, **kwargs) +@register_model +def eca_swinnext26ts_256(pretrained=False, **kwargs): + """ + """ + kwargs.setdefault('img_size', 256) + return _create_byoanet('eca_swinnext26ts_256', 'eca_swinnext26ts', pretrained=pretrained, **kwargs) + + @register_model def rednet26t(pretrained=False, **kwargs): """ diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 75610f67..8f4a2020 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -98,7 +98,7 @@ class BlocksCfg: s: int = 2 # stride of stage (first block) gs: Optional[Union[int, Callable]] = None # group-size of blocks in stage, conv is depthwise if gs == 1 br: float = 1. # bottleneck-ratio of blocks in stage - no_attn: bool = True # disable channel attn (ie SE) when layer is set for model + no_attn: bool = False # disable channel attn (ie SE) when layer is set for model @dataclass From 00548b8427739bf9954dd8ce522e2833de616baf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E9=91=AB=E6=9D=B0?= Date: Tue, 18 May 2021 19:21:53 +0800 Subject: [PATCH 06/48] Add Twins --- timm/models/__init__.py | 1 + timm/models/twins.py | 625 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 626 insertions(+) create mode 100644 timm/models/twins.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 46ea155f..293b459d 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -39,6 +39,7 @@ from .vision_transformer_hybrid import * from .vovnet import * from .xception import * from .xception_aligned import * +from .twins import * from .factory import create_model, split_model_name, safe_model_name from .helpers import load_checkpoint, resume_checkpoint, model_parameters diff --git a/timm/models/twins.py b/timm/models/twins.py new file mode 100644 index 00000000..27be4cba --- /dev/null +++ b/timm/models/twins.py @@ -0,0 +1,625 @@ +""" Twins +A PyTorch impl of : `Twins: Revisiting the Design of Spatial Attention in Vision Transformers` + - https://arxiv.org/pdf/2104.13840.pdf + +Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/license info below + +""" +# -------------------------------------------------------- +# Twins +# Copyright (c) 2021 Meituan +# Licensed under The Apache 2.0 License [see LICENSE for details] +# Written by Xinjie Li, Xiangxiang Chu +# -------------------------------------------------------- + +import logging +import math +from copy import deepcopy +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .layers import Mlp, DropPath, to_2tuple, trunc_normal_ +from .registry import register_model +from .vision_transformer import _cfg +from .vision_transformer import Block as TimmBlock +from .vision_transformer import Attention as TimmAttention +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .vision_transformer import checkpoint_filter_fn, _init_vit_weights + +_logger = logging.getLogger(__name__) + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + +default_cfgs = { + 'twins_pcpvt_small': _cfg( + url='https://s3plus.meituan.net/v1/mss_9240d97c6bf34ab1b78859c3c2a2a3e4/automl-model-zoo/models/twins/pcpvt_small.pth', + ), + 'twins_pcpvt_base': _cfg( + url='https://s3plus.meituan.net/v1/mss_9240d97c6bf34ab1b78859c3c2a2a3e4/automl-model-zoo/models/twins/pcpvt_base.pth', + ), + 'twins_pcpvt_large': _cfg( + url='https://s3plus.meituan.net/v1/mss_9240d97c6bf34ab1b78859c3c2a2a3e4/automl-model-zoo/models/twins/pcpvt_large.pth', + ), + 'twins_svt_small': _cfg( + url='https://s3plus.meituan.net/v1/mss_9240d97c6bf34ab1b78859c3c2a2a3e4/automl-model-zoo/models/twins/alt_gvt_small.pth', + ), + 'twins_svt_base': _cfg( + url='https://s3plus.meituan.net/v1/mss_9240d97c6bf34ab1b78859c3c2a2a3e4/automl-model-zoo/models/twins/alt_gvt_base.pth', + ), + 'twins_svt_large': _cfg( + url='https://s3plus.meituan.net/v1/mss_9240d97c6bf34ab1b78859c3c2a2a3e4/automl-model-zoo/models/twins/alt_gvt_large.pth', + ), +} + + + +class GroupAttention(nn.Module): + """ + LSA: self attention within a group + """ + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., ws=1): + assert ws != 1 + super(GroupAttention, self).__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, H, W): + """ + There are two implementations for this function, zero padding or mask. We don't observe obvious difference for + both. You can choose any one, we recommend forward_padding because it's neat. However, + the masking implementation is more reasonable and accurate. + Args: + x: + H: + W: + + Returns: + + """ + return self.forward_padding(x, H, W) + + def forward_mask(self, x, H, W): + B, N, C = x.shape + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + mask = torch.zeros((1, Hp, Wp), device=x.device) + mask[:, -pad_b:, :].fill_(1) + mask[:, :, -pad_r:].fill_(1) + + x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) # B, _h, _w, ws, ws, C + mask = mask.reshape(1, _h, self.ws, _w, self.ws).transpose(2, 3).reshape(1, _h*_w, self.ws*self.ws) + attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3) # 1, _h*_w, ws*ws, ws*ws + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-1000.0)).masked_fill(attn_mask == 0, float(0.0)) + qkv = self.qkv(x).reshape(B, _h * _w, self.ws * self.ws, 3, self.num_heads, + C // self.num_heads).permute(3, 0, 1, 4, 2, 5) # n_h, B, _w*_h, nhead, ws*ws, dim + q, k, v = qkv[0], qkv[1], qkv[2] # B, _h*_w, n_head, ws*ws, dim_head + attn = (q @ k.transpose(-2, -1)) * self.scale # B, _h*_w, n_head, ws*ws, ws*ws + attn = attn + attn_mask.unsqueeze(2) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) # attn @v -> B, _h*_w, n_head, ws*ws, dim_head + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def forward_padding(self, x, H, W): + B, N, C = x.shape + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) + qkv = self.qkv(x).reshape(B, _h * _w, self.ws * self.ws, 3, self.num_heads, + C // self.num_heads).permute(3, 0, 1, 4, 2, 5) + q, k, v = qkv[0], qkv[1], qkv[2] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Attention(nn.Module): + """ + GSA: using a key to summarize the information for a group to be efficient. + """ + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + + def forward(self, x, H, W): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr_ratio > 1: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, H, W): + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class SBlock(TimmBlock): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): + super(SBlock, self).__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, + drop_path, act_layer, norm_layer) + + def forward(self, x, H, W): + return super(SBlock, self).forward(x) + + +class GroupBlock(TimmBlock): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=1): + super(GroupBlock, self).__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, + drop_path, act_layer, norm_layer) + del self.attn + if ws == 1: + self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop, drop, sr_ratio) + else: + self.attn = GroupAttention(dim, num_heads, qkv_bias, qk_scale, attn_drop, drop, ws) + + def forward(self, x, H, W): + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + # img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ + f"img_size {img_size} should be divided by patch_size {patch_size}." + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = nn.LayerNorm(embed_dim) + + def forward(self, x): + B, C, H, W = x.shape + + x = self.proj(x).flatten(2).transpose(1, 2) + x = self.norm(x) + H, W = H // self.patch_size[0], W // self.patch_size[1] + + return x, (H, W) + + +# borrow from PVT https://github.com/whai362/PVT.git +class PyramidVisionTransformer(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], block_cls=Block): + super().__init__() + self.num_classes = num_classes + self.depths = depths + + # patch_embed + self.patch_embeds = nn.ModuleList() + self.pos_embeds = nn.ParameterList() + self.pos_drops = nn.ModuleList() + self.blocks = nn.ModuleList() + + for i in range(len(depths)): + if i == 0: + self.patch_embeds.append(PatchEmbed(img_size, patch_size, in_chans, embed_dims[i])) + else: + self.patch_embeds.append( + # PatchEmbed(img_size // patch_size // 2 ** (i - 1), 2, embed_dims[i - 1], embed_dims[i]) + PatchEmbed((img_size[0] // patch_size // 2**(i-1),img_size[1] // patch_size // 2**(i-1)), 2, embed_dims[i - 1], embed_dims[i]) + ) + patch_num = self.patch_embeds[-1].num_patches + 1 if i == len(embed_dims) - 1 else self.patch_embeds[ + -1].num_patches + self.pos_embeds.append(nn.Parameter(torch.zeros(1, patch_num, embed_dims[i]))) + self.pos_drops.append(nn.Dropout(p=drop_rate)) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + for k in range(len(depths)): + _block = nn.ModuleList([block_cls( + dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k], qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[k]) + for i in range(depths[k])]) + self.blocks.append(_block) + cur += depths[k] + + self.norm = norm_layer(embed_dims[-1]) + + # cls_token + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1])) + + # classification head + self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() + + # init weights + for pos_emb in self.pos_embeds: + trunc_normal_(pos_emb, std=.02) + self.apply(self._init_weights) + + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for k in range(len(self.depths)): + for i in range(self.depths[k]): + self.blocks[k][i].drop_path.drop_prob = dpr[cur + i] + cur += self.depths[k] + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + for i in range(len(self.depths)): + x, (H, W) = self.patch_embeds[i](x) + if i == len(self.depths) - 1: + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embeds[i] + x = self.pos_drops[i](x) + for blk in self.blocks[i]: + x = blk(x, H, W) + if i < len(self.depths) - 1: + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + + x = self.norm(x) + + return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + + return x + + +# PEG from https://arxiv.org/abs/2102.10882 +class PosCNN(nn.Module): + def __init__(self, in_chans, embed_dim=768, s=1): + super(PosCNN, self).__init__() + self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, s, 1, bias=True, groups=embed_dim), ) + self.s = s + + def forward(self, x, H, W): + B, N, C = x.shape + feat_token = x + cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W) + if self.s == 1: + x = self.proj(cnn_feat) + cnn_feat + else: + x = self.proj(cnn_feat) + x = x.flatten(2).transpose(1, 2) + return x + + def no_weight_decay(self): + return ['proj.%d.weight' % i for i in range(4)] + + +class CPVTV2(PyramidVisionTransformer): + """ + Use useful results from CPVT. PEG and GAP. + Therefore, cls token is no longer required. + PEG is used to encode the absolute position on the fly, which greatly affects the performance when input resolution + changes during the training (such as segmentation, detection) + """ + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], block_cls=Block): + super(CPVTV2, self).__init__(img_size, patch_size, in_chans, num_classes, embed_dims, num_heads, mlp_ratios, + qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, norm_layer, depths, + sr_ratios, block_cls) + del self.pos_embeds + del self.cls_token + self.pos_block = nn.ModuleList( + [PosCNN(embed_dim, embed_dim) for embed_dim in embed_dims] + ) + self.apply(self._init_weights) + + def _init_weights(self, m): + import math + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def no_weight_decay(self): + return set(['cls_token'] + ['pos_block.' + n for n, p in self.pos_block.named_parameters()]) + + def forward_features(self, x): + B = x.shape[0] + + for i in range(len(self.depths)): + x, (H, W) = self.patch_embeds[i](x) + x = self.pos_drops[i](x) + for j, blk in enumerate(self.blocks[i]): + x = blk(x, H, W) + if j == 0: + x = self.pos_block[i](x, H, W) # PEG here + if i < len(self.depths) - 1: + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + + x = self.norm(x) + + return x.mean(dim=1) # GAP here + + +class Twins_PCPVT(CPVTV2): + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256], + num_heads=[1, 2, 4], mlp_ratios=[4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, + depths=[4, 4, 4], sr_ratios=[4, 2, 1], block_cls=SBlock): + super(Twins_PCPVT, self).__init__(img_size, patch_size, in_chans, num_classes, embed_dims, num_heads, + mlp_ratios, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, + norm_layer, depths, sr_ratios, block_cls) + + +class Twins_SVT(Twins_PCPVT): + """ + alias Twins-SVT + """ + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256], + num_heads=[1, 2, 4], mlp_ratios=[4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, + depths=[4, 4, 4], sr_ratios=[4, 2, 1], block_cls=GroupBlock, wss=[7, 7, 7]): + super(Twins_SVT, self).__init__(img_size, patch_size, in_chans, num_classes, embed_dims, num_heads, + mlp_ratios, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, + norm_layer, depths, sr_ratios, block_cls) + del self.blocks + self.wss = wss + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.blocks = nn.ModuleList() + for k in range(len(depths)): + _block = nn.ModuleList([block_cls( + dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k], qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[k], ws=1 if i % 2 == 1 else wss[k]) for i in range(depths[k])]) + self.blocks.append(_block) + cur += depths[k] + self.apply(self._init_weights) + + +def _conv_filter(state_dict, patch_size=16): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k: + v = v.reshape((v.shape[0], 3, patch_size, patch_size)) + out_dict[k] = v + + return out_dict + +def _create_twins_svt(variant, pretrained=False, default_cfg=None, **kwargs): + if default_cfg is None: + default_cfg = deepcopy(default_cfgs[variant]) + overlay_external_default_cfg(default_cfg, kwargs) + default_num_classes = default_cfg['num_classes'] + default_img_size = default_cfg['input_size'][-2:] + + num_classes = kwargs.pop('num_classes', default_num_classes) + img_size = kwargs.pop('img_size', default_img_size) + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + Twins_SVT, variant, pretrained, + default_cfg=default_cfg, + img_size=img_size, + num_classes=num_classes, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + + return model + +def _create_twins_pcpvt(variant, pretrained=False, default_cfg=None, **kwargs): + if default_cfg is None: + default_cfg = deepcopy(default_cfgs[variant]) + overlay_external_default_cfg(default_cfg, kwargs) + default_num_classes = default_cfg['num_classes'] + default_img_size = default_cfg['input_size'][-2:] + + num_classes = kwargs.pop('num_classes', default_num_classes) + img_size = kwargs.pop('img_size', default_img_size) + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + CPVTV2, variant, pretrained, + default_cfg=default_cfg, + img_size=img_size, + num_classes=num_classes, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + + return model + + +@register_model +def twins_pcpvt_small(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + **kwargs) + return _create_twins_pcpvt('twins_pcpvt_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def twins_pcpvt_base(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + **kwargs) + return _create_twins_pcpvt('twins_pcpvt_base', pretrained=pretrained, **model_kwargs) + + +@register_model +def twins_pcpvt_large(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + **kwargs) + return _create_twins_pcpvt('twins_pcpvt_large', pretrained=pretrained, **model_kwargs) + + +@register_model +def twins_svt_small(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 10, 4], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], + **kwargs) + return _create_twins_svt('twins_svt_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def twins_svt_base(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=4, embed_dims=[96, 192, 384, 768], num_heads=[3, 6, 12, 24], mlp_ratios=[4, 4, 4, 4], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], + **kwargs) + + return _create_twins_svt('twins_svt_base', pretrained=pretrained, **model_kwargs) + + +@register_model +def twins_svt_large(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], + **kwargs) + + return _create_twins_svt('twins_svt_large', pretrained=pretrained, **model_kwargs) From 5bcf686cb0aad39b7c9114931db2e7fc2bc4f24c Mon Sep 17 00:00:00 2001 From: talrid Date: Wed, 19 May 2021 20:51:10 +0300 Subject: [PATCH 07/48] mixer_b16_224_miil --- timm/models/mlp_mixer.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 248568fc..87edbfd6 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -60,6 +60,15 @@ default_cfgs = dict( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth', num_classes=21843 ), + # Mixer ImageNet-21K-P pretraining + mixer_b16_224_miil_in21k=_cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil_in21k.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, + ), + mixer_b16_224_miil=_cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', + ), ) @@ -255,3 +264,21 @@ def mixer_l16_224_in21k(pretrained=False, **kwargs): model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs) model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args) return model + +@register_model +def mixer_b16_224_miil(pretrained=False, **kwargs): + """ Mixer-B/16 224x224. ImageNet-21k pretrained weights. + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs) + model = _create_mixer('mixer_b16_224_miil', pretrained=pretrained, **model_args) + return model + +@register_model +def mixer_b16_224_miil_in21k(pretrained=False, **kwargs): + """ Mixer-B/16 224x224. ImageNet-1k pretrained weights. + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs) + model = _create_mixer('mixer_b16_224_miil_in21k', pretrained=pretrained, **model_args) + return model \ No newline at end of file From d046498e0bf8ee5a8fcc80d91452363c62c262f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E9=91=AB=E6=9D=B0?= Date: Thu, 20 May 2021 11:20:39 +0800 Subject: [PATCH 08/48] update test_models.py --- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index b77b29ff..3013d0b9 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -15,7 +15,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): torch._C._jit_set_profiling_mode(False) # transformer models don't support many of the spatial / feature based model functionalities -NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*'] +NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*'] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures From 240e6677468392283835c372fe2addc72514cff9 Mon Sep 17 00:00:00 2001 From: talrid Date: Thu, 20 May 2021 10:23:07 +0300 Subject: [PATCH 09/48] Revert "mixer_b16_224_miil" --- timm/models/mlp_mixer.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 87edbfd6..248568fc 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -60,15 +60,6 @@ default_cfgs = dict( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth', num_classes=21843 ), - # Mixer ImageNet-21K-P pretraining - mixer_b16_224_miil_in21k=_cfg( - url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil_in21k.pth', - mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, - ), - mixer_b16_224_miil=_cfg( - url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil.pth', - mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', - ), ) @@ -264,21 +255,3 @@ def mixer_l16_224_in21k(pretrained=False, **kwargs): model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs) model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args) return model - -@register_model -def mixer_b16_224_miil(pretrained=False, **kwargs): - """ Mixer-B/16 224x224. ImageNet-21k pretrained weights. - Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K - """ - model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs) - model = _create_mixer('mixer_b16_224_miil', pretrained=pretrained, **model_args) - return model - -@register_model -def mixer_b16_224_miil_in21k(pretrained=False, **kwargs): - """ Mixer-B/16 224x224. ImageNet-1k pretrained weights. - Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K - """ - model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs) - model = _create_mixer('mixer_b16_224_miil_in21k', pretrained=pretrained, **model_args) - return model \ No newline at end of file From dc1a4efd28b335ebd85e13d64edd78404f75aeb7 Mon Sep 17 00:00:00 2001 From: talrid Date: Thu, 20 May 2021 10:35:50 +0300 Subject: [PATCH 10/48] mixer_b16_224_miil, mixer_b16_224_miil_in21k models --- timm/models/mlp_mixer.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 2241fe43..92ca115b 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -80,6 +80,15 @@ default_cfgs = dict( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth', num_classes=21843 ), + # Mixer ImageNet-21K-P pretraining + mixer_b16_224_miil_in21k=_cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil_in21k.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, + ), + mixer_b16_224_miil=_cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', + ), gmixer_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), gmixer_24_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), @@ -365,6 +374,23 @@ def mixer_l16_224_in21k(pretrained=False, **kwargs): model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args) return model +@register_model +def mixer_b16_224_miil(pretrained=False, **kwargs): + """ Mixer-B/16 224x224. ImageNet-21k pretrained weights. + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs) + model = _create_mixer('mixer_b16_224_miil', pretrained=pretrained, **model_args) + return model + +@register_model +def mixer_b16_224_miil_in21k(pretrained=False, **kwargs): + """ Mixer-B/16 224x224. ImageNet-1k pretrained weights. + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs) + model = _create_mixer('mixer_b16_224_miil_in21k', pretrained=pretrained, **model_args) + return model @register_model def gmixer_12_224(pretrained=False, **kwargs): From 8086943b6f4cef1ad7b1f044eafcd8e138dd5cfd Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 20 May 2021 11:27:58 +0100 Subject: [PATCH 11/48] allow resize positional embeddings to non-square grid --- timm/models/vision_transformer.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index cc7e0903..1acdd808 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -352,7 +352,7 @@ def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = Fa nn.init.ones_(m.weight) -def resize_pos_embed(posemb, posemb_new, num_tokens=1): +def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=[]): # Rescale the grid of position embeddings when loading from state_dict. Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) @@ -363,11 +363,12 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1): else: posemb_tok, posemb_grid = posemb[:, :0], posemb[0] gs_old = int(math.sqrt(len(posemb_grid))) - gs_new = int(math.sqrt(ntok_new)) - _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new) + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))]*2 + _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) - posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear') - posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1) + posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear') + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) posemb = torch.cat([posemb_tok, posemb_grid], dim=1) return posemb @@ -385,7 +386,8 @@ def checkpoint_filter_fn(state_dict, model): v = v.reshape(O, -1, H, W) elif k == 'pos_embed' and v.shape != model.pos_embed.shape: # To resize pos embedding when using model at different size from pretrained weights - v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1)) + v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1), + model.patch_embed.grid_size) out_dict[k] = v return out_dict From 79760198640dbb7e63889a322c1c70c1b5113b97 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 20 May 2021 11:55:48 +0100 Subject: [PATCH 12/48] extend positional embedding resizing functionality to tnt --- timm/models/tnt.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/timm/models/tnt.py b/timm/models/tnt.py index cc732677..8e038718 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -14,7 +14,9 @@ from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.helpers import load_pretrained from timm.models.layers import Mlp, DropPath, trunc_normal_ +from timm.models.layers.helpers import to_2tuple from timm.models.registry import register_model +from timm.models.vision_transformer import resize_pos_embed def _cfg(url='', **kwargs): @@ -118,11 +120,15 @@ class PixelEmbed(nn.Module): """ def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4): super().__init__() - num_patches = (img_size // patch_size) ** 2 + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + # grid_size property necessary for resizing positional embedding + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + num_patches = (self.grid_size[0]) * (self.grid_size[1]) self.img_size = img_size self.num_patches = num_patches self.in_dim = in_dim - new_patch_size = math.ceil(patch_size / stride) + new_patch_size = [math.ceil(ps / stride) for ps in patch_size] self.new_patch_size = new_patch_size self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride) @@ -130,11 +136,11 @@ class PixelEmbed(nn.Module): def forward(self, x, pixel_pos): B, C, H, W = x.shape - assert H == self.img_size and W == self.img_size, \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size}*{self.img_size})." + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x) x = self.unfold(x) - x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size, self.new_patch_size) + x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1]) x = x + pixel_pos x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2) return x @@ -155,7 +161,7 @@ class TNT(nn.Module): num_patches = self.pixel_embed.num_patches self.num_patches = num_patches new_patch_size = self.pixel_embed.new_patch_size - num_pixel = new_patch_size ** 2 + num_pixel = new_patch_size[0] * new_patch_size[1] self.norm1_proj = norm_layer(num_pixel * in_dim) self.proj = nn.Linear(num_pixel * in_dim, embed_dim) @@ -163,7 +169,7 @@ class TNT(nn.Module): self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) - self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size, new_patch_size)) + self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size[0], new_patch_size[1])) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule @@ -224,6 +230,14 @@ class TNT(nn.Module): return x +def checkpoint_filter_fn(state_dict, model): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + if state_dict['patch_pos'].shape != model.patch_pos.shape: + state_dict['patch_pos'] = resize_pos_embed(state_dict['patch_pos'], + model.patch_pos, getattr(model, 'num_tokens', 1), model.pixel_embed.grid_size) + return state_dict + + @register_model def tnt_s_patch16_224(pretrained=False, **kwargs): model = TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, @@ -231,7 +245,8 @@ def tnt_s_patch16_224(pretrained=False, **kwargs): model.default_cfg = default_cfgs['tnt_s_patch16_224'] if pretrained: load_pretrained( - model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), + filter_fn=checkpoint_filter_fn) return model From 40c506ba1ecac691955a8e486b99036d294cb763 Mon Sep 17 00:00:00 2001 From: Aman Arora Date: Thu, 20 May 2021 23:17:28 +0000 Subject: [PATCH 13/48] Add ConViT --- timm/models/__init__.py | 1 + timm/models/convit.py | 445 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 446 insertions(+) create mode 100644 timm/models/convit.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 46ea155f..4d1230bd 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -2,6 +2,7 @@ from .byoanet import * from .byobnet import * from .cait import * from .coat import * +from .convit import * from .cspnet import * from .densenet import * from .dla import * diff --git a/timm/models/convit.py b/timm/models/convit.py new file mode 100644 index 00000000..82a0d988 --- /dev/null +++ b/timm/models/convit.py @@ -0,0 +1,445 @@ +"""These modules are adapted from those of timm, see +https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +""" + +import torch +import torch.nn as nn +from functools import partial +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from timm.models.helpers import load_pretrained +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from timm.models.registry import register_model + +import torch +import torch.nn as nn +import matplotlib.pyplot as plt + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + **kwargs + } + + +default_cfgs = { + # ConViT + 'convit_tiny': _cfg( + url="https://dl.fbaipublicfiles.com/convit/convit_tiny.pth"), + 'convit_small': _cfg( + url="https://dl.fbaipublicfiles.com/convit/convit_small.pth"), + 'convit_base': _cfg( + url="https://dl.fbaipublicfiles.com/convit/convit_base.pth") +} + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class GPSA(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., + locality_strength=1., use_local_init=True): + super().__init__() + self.num_heads = num_heads + self.dim = dim + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.pos_proj = nn.Linear(3, num_heads) + self.proj_drop = nn.Dropout(proj_drop) + self.locality_strength = locality_strength + self.gating_param = nn.Parameter(torch.ones(self.num_heads)) + self.apply(self._init_weights) + if use_local_init: + self.local_init(locality_strength=locality_strength) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + B, N, C = x.shape + if not hasattr(self, 'rel_indices') or self.rel_indices.size(1)!=N: + self.get_rel_indices(N) + + attn = self.get_attention(x) + v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def get_attention(self, x): + B, N, C = x.shape + qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k = qk[0], qk[1] + pos_score = self.rel_indices.expand(B, -1, -1,-1) + pos_score = self.pos_proj(pos_score).permute(0,3,1,2) + patch_score = (q @ k.transpose(-2, -1)) * self.scale + patch_score = patch_score.softmax(dim=-1) + pos_score = pos_score.softmax(dim=-1) + + gating = self.gating_param.view(1,-1,1,1) + attn = (1.-torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score + attn /= attn.sum(dim=-1).unsqueeze(-1) + attn = self.attn_drop(attn) + return attn + + def get_attention_map(self, x, return_map = False): + + attn_map = self.get_attention(x).mean(0) # average over batch + distances = self.rel_indices.squeeze()[:,:,-1]**.5 + dist = torch.einsum('nm,hnm->h', (distances, attn_map)) + dist /= distances.size(0) + if return_map: + return dist, attn_map + else: + return dist + + def local_init(self, locality_strength=1.): + + self.v.weight.data.copy_(torch.eye(self.dim)) + locality_distance = 1 #max(1,1/locality_strength**.5) + + kernel_size = int(self.num_heads**.5) + center = (kernel_size-1)/2 if kernel_size%2==0 else kernel_size//2 + for h1 in range(kernel_size): + for h2 in range(kernel_size): + position = h1+kernel_size*h2 + self.pos_proj.weight.data[position,2] = -1 + self.pos_proj.weight.data[position,1] = 2*(h1-center)*locality_distance + self.pos_proj.weight.data[position,0] = 2*(h2-center)*locality_distance + self.pos_proj.weight.data *= locality_strength + + def get_rel_indices(self, num_patches): + img_size = int(num_patches**.5) + rel_indices = torch.zeros(1, num_patches, num_patches, 3) + ind = torch.arange(img_size).view(1,-1) - torch.arange(img_size).view(-1, 1) + indx = ind.repeat(img_size,img_size) + indy = ind.repeat_interleave(img_size,dim=0).repeat_interleave(img_size,dim=1) + indd = indx**2 + indy**2 + rel_indices[:,:,:,2] = indd.unsqueeze(0) + rel_indices[:,:,:,1] = indy.unsqueeze(0) + rel_indices[:,:,:,0] = indx.unsqueeze(0) + device = self.qk.weight.device + self.rel_indices = rel_indices.to(device) + + +class MHSA(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_attention_map(self, x, return_map = False): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + attn_map = (q @ k.transpose(-2, -1)) * self.scale + attn_map = attn_map.softmax(dim=-1).mean(0) + + img_size = int(N**.5) + ind = torch.arange(img_size).view(1,-1) - torch.arange(img_size).view(-1, 1) + indx = ind.repeat(img_size,img_size) + indy = ind.repeat_interleave(img_size,dim=0).repeat_interleave(img_size,dim=1) + indd = indx**2 + indy**2 + distances = indd**.5 + distances = distances.to('cuda') + + dist = torch.einsum('nm,hnm->h', (distances, attn_map)) + dist /= N + + if return_map: + return dist, attn_map + else: + return dist + + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs): + super().__init__() + self.norm1 = norm_layer(dim) + self.use_gpsa = use_gpsa + if self.use_gpsa: + self.attn = GPSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, **kwargs) + else: + self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, **kwargs) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding, from timm + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.apply(self._init_weights) + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding, from timm + """ + def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + self.apply(self._init_weights) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class ConViT(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None, + local_up_to_layer=3, locality_strength=1., use_pos_embed=True): + super().__init__() + embed_dim *= num_heads + self.num_classes = num_classes + self.local_up_to_layer = local_up_to_layer + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.locality_strength = locality_strength + self.use_pos_embed = use_pos_embed + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + self.num_patches = num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + if self.use_pos_embed: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.pos_embed, std=.02) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + use_gpsa=True, + locality_strength=locality_strength) + if i 0 else nn.Identity() + + trunc_normal_(self.cls_token, std=.02) + self.head.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) + + if self.use_pos_embed: + x = x + self.pos_embed + x = self.pos_drop(x) + + for u,blk in enumerate(self.blocks): + if u == self.local_up_to_layer : + x = torch.cat((cls_tokens, x), dim=1) + x = blk(x) + + x = self.norm(x) + return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _create_convit(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + ConViT, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + + +@register_model +def convit_tiny(pretrained=False, **kwargs): + model_args = dict( + local_up_to_layer=10, locality_strength=1.0, embed_dim=48, + num_heads=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model = _create_convit( + variant='convit_tiny', pretrained=pretrained, **model_args) + return model + +@register_model +def convit_small(pretrained=False, **kwargs): + model_args = dict( + local_up_to_layer=10, locality_strength=1.0, embed_dim=48, + num_heads=9, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model = _create_convit( + variant='convit_small', pretrained=pretrained, **model_args) + return model + +@register_model +def convit_base(pretrained=False, **kwargs): + model_args = dict( + local_up_to_layer=10, locality_strength=1.0, embed_dim=48, + num_heads=16, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model = _create_convit( + variant='convit_base', pretrained=pretrained, **model_args) + return model From 8b1f2e8e1f4c73f43d9e956c2162884ab319b1a3 Mon Sep 17 00:00:00 2001 From: Aman Arora Date: Thu, 20 May 2021 23:42:42 +0000 Subject: [PATCH 14/48] remote unused matplotlib import --- timm/models/convit.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/timm/models/convit.py b/timm/models/convit.py index 82a0d988..29970c76 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -9,13 +9,11 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from timm.models.helpers import load_pretrained from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from timm.models.registry import register_model import torch import torch.nn as nn -import matplotlib.pyplot as plt def _cfg(url='', **kwargs): From 163331748935559923ffb6aa5ed1882b47a6a92a Mon Sep 17 00:00:00 2001 From: Aman Arora Date: Fri, 21 May 2021 01:11:56 +0000 Subject: [PATCH 15/48] update tests and exclude convit_base --- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index b77b29ff..1bf0d738 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -24,7 +24,7 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): EXCLUDE_FILTERS = [ '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', - '*resnetrs350*', '*resnetrs420*'] + NON_STD_FILTERS + '*resnetrs350*', '*resnetrs420*', 'convit_base'] + NON_STD_FILTERS else: EXCLUDE_FILTERS = NON_STD_FILTERS From 5db1eb6ba56f35fce8bc06e85c7339e7c714a4f4 Mon Sep 17 00:00:00 2001 From: Aman Arora Date: Fri, 21 May 2021 02:11:20 +0000 Subject: [PATCH 16/48] Add defaults --- timm/models/convit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/convit.py b/timm/models/convit.py index 29970c76..31c05df3 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -19,8 +19,9 @@ import torch.nn as nn def _cfg(url='', **kwargs): return { 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', **kwargs } From 50d6aab0efb53b4072008780fb7ea3cc82e0236f Mon Sep 17 00:00:00 2001 From: Aman Arora Date: Fri, 21 May 2021 03:46:47 +0000 Subject: [PATCH 17/48] Add convit to non-std filters as vit_ --- tests/test_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 1bf0d738..f098fefd 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -15,7 +15,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): torch._C._jit_set_profiling_mode(False) # transformer models don't support many of the spatial / feature based model functionalities -NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*'] +NON_STD_FILTERS = ['vit_*', 'tnt_*', 'convit_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*'] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures @@ -24,7 +24,7 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): EXCLUDE_FILTERS = [ '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', - '*resnetrs350*', '*resnetrs420*', 'convit_base'] + NON_STD_FILTERS + '*resnetrs350*', '*resnetrs420*'] + NON_STD_FILTERS else: EXCLUDE_FILTERS = NON_STD_FILTERS From be99eef9c14fe63a2ebf3cdd2784d16140851004 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 20 May 2021 23:38:35 -0700 Subject: [PATCH 18/48] Remove redundant code, cleanup, fix torchscript. --- timm/models/twins.py | 495 +++++++++++++------------------------------ 1 file changed, 149 insertions(+), 346 deletions(-) diff --git a/timm/models/twins.py b/timm/models/twins.py index 27be4cba..ce51c497 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -11,11 +11,9 @@ Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/li # Licensed under The Apache 2.0 License [see LICENSE for details] # Written by Xinjie Li, Xiangxiang Chu # -------------------------------------------------------- - -import logging import math from copy import deepcopy -from typing import Optional +from typing import Optional, Tuple import torch import torch.nn as nn @@ -25,13 +23,9 @@ from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .layers import Mlp, DropPath, to_2tuple, trunc_normal_ from .registry import register_model -from .vision_transformer import _cfg -from .vision_transformer import Block as TimmBlock -from .vision_transformer import Attention as TimmAttention +from .vision_transformer import Attention from .helpers import build_model_with_cfg, overlay_external_default_cfg -from .vision_transformer import checkpoint_filter_fn, _init_vit_weights -_logger = logging.getLogger(__name__) def _cfg(url='', **kwargs): return { @@ -43,6 +37,7 @@ def _cfg(url='', **kwargs): **kwargs } + default_cfgs = { 'twins_pcpvt_small': _cfg( url='https://s3plus.meituan.net/v1/mss_9240d97c6bf34ab1b78859c3c2a2a3e4/automl-model-zoo/models/twins/pcpvt_small.pth', @@ -64,78 +59,34 @@ default_cfgs = { ), } +Size_ = Tuple[int, int] -class GroupAttention(nn.Module): - """ - LSA: self attention within a group +class LocallyGroupedAttn(nn.Module): + """ LSA: self attention within a group """ - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., ws=1): + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1): assert ws != 1 - super(GroupAttention, self).__init__() + super(LocallyGroupedAttn, self).__init__() assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." self.dim = dim self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = head_dim ** -0.5 - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.qkv = nn.Linear(dim, dim * 3, bias=True) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) self.ws = ws - def forward(self, x, H, W): - """ - There are two implementations for this function, zero padding or mask. We don't observe obvious difference for - both. You can choose any one, we recommend forward_padding because it's neat. However, - the masking implementation is more reasonable and accurate. - Args: - x: - H: - W: - - Returns: - - """ - return self.forward_padding(x, H, W) - - def forward_mask(self, x, H, W): - B, N, C = x.shape - x = x.view(B, H, W, C) - pad_l = pad_t = 0 - pad_r = (self.ws - W % self.ws) % self.ws - pad_b = (self.ws - H % self.ws) % self.ws - x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) - _, Hp, Wp, _ = x.shape - _h, _w = Hp // self.ws, Wp // self.ws - mask = torch.zeros((1, Hp, Wp), device=x.device) - mask[:, -pad_b:, :].fill_(1) - mask[:, :, -pad_r:].fill_(1) - - x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) # B, _h, _w, ws, ws, C - mask = mask.reshape(1, _h, self.ws, _w, self.ws).transpose(2, 3).reshape(1, _h*_w, self.ws*self.ws) - attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3) # 1, _h*_w, ws*ws, ws*ws - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-1000.0)).masked_fill(attn_mask == 0, float(0.0)) - qkv = self.qkv(x).reshape(B, _h * _w, self.ws * self.ws, 3, self.num_heads, - C // self.num_heads).permute(3, 0, 1, 4, 2, 5) # n_h, B, _w*_h, nhead, ws*ws, dim - q, k, v = qkv[0], qkv[1], qkv[2] # B, _h*_w, n_head, ws*ws, dim_head - attn = (q @ k.transpose(-2, -1)) * self.scale # B, _h*_w, n_head, ws*ws, ws*ws - attn = attn + attn_mask.unsqueeze(2) - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) # attn @v -> B, _h*_w, n_head, ws*ws, dim_head - attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) - x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) - if pad_r > 0 or pad_b > 0: - x = x[:, :H, :W, :].contiguous() - x = x.reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def forward_padding(self, x, H, W): + def forward(self, x, size: Size_): + # There are two implementations for this function, zero padding or mask. We don't observe obvious difference for + # both. You can choose any one, we recommend forward_padding because it's neat. However, + # the masking implementation is more reasonable and accurate. B, N, C = x.shape + H, W = size x = x.view(B, H, W, C) pad_l = pad_t = 0 pad_r = (self.ws - W % self.ws) % self.ws @@ -144,8 +95,8 @@ class GroupAttention(nn.Module): _, Hp, Wp, _ = x.shape _h, _w = Hp // self.ws, Wp // self.ws x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) - qkv = self.qkv(x).reshape(B, _h * _w, self.ws * self.ws, 3, self.num_heads, - C // self.num_heads).permute(3, 0, 1, 4, 2, 5) + qkv = self.qkv(x).reshape( + B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) q, k, v = qkv[0], qkv[1], qkv[2] attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) @@ -159,22 +110,56 @@ class GroupAttention(nn.Module): x = self.proj_drop(x) return x - -class Attention(nn.Module): - """ - GSA: using a key to summarize the information for a group to be efficient. + # def forward_mask(self, x, size: Size_): + # B, N, C = x.shape + # H, W = size + # x = x.view(B, H, W, C) + # pad_l = pad_t = 0 + # pad_r = (self.ws - W % self.ws) % self.ws + # pad_b = (self.ws - H % self.ws) % self.ws + # x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + # _, Hp, Wp, _ = x.shape + # _h, _w = Hp // self.ws, Wp // self.ws + # mask = torch.zeros((1, Hp, Wp), device=x.device) + # mask[:, -pad_b:, :].fill_(1) + # mask[:, :, -pad_r:].fill_(1) + # + # x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) # B, _h, _w, ws, ws, C + # mask = mask.reshape(1, _h, self.ws, _w, self.ws).transpose(2, 3).reshape(1, _h * _w, self.ws * self.ws) + # attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3) # 1, _h*_w, ws*ws, ws*ws + # attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-1000.0)).masked_fill(attn_mask == 0, float(0.0)) + # qkv = self.qkv(x).reshape( + # B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) + # # n_h, B, _w*_h, nhead, ws*ws, dim + # q, k, v = qkv[0], qkv[1], qkv[2] # B, _h*_w, n_head, ws*ws, dim_head + # attn = (q @ k.transpose(-2, -1)) * self.scale # B, _h*_w, n_head, ws*ws, ws*ws + # attn = attn + attn_mask.unsqueeze(2) + # attn = attn.softmax(dim=-1) + # attn = self.attn_drop(attn) # attn @v -> B, _h*_w, n_head, ws*ws, dim_head + # attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + # x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + # if pad_r > 0 or pad_b > 0: + # x = x[:, :H, :W, :].contiguous() + # x = x.reshape(B, N, C) + # x = self.proj(x) + # x = self.proj_drop(x) + # return x + + +class GlobalSubSampleAttn(nn.Module): + """ GSA: using a key to summarize the information for a group to be efficient. """ - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1): super().__init__() assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." self.dim = dim self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = head_dim ** -0.5 - self.q = nn.Linear(dim, dim, bias=qkv_bias) - self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.q = nn.Linear(dim, dim, bias=True) + self.kv = nn.Linear(dim, dim * 2, bias=True) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) @@ -183,18 +168,19 @@ class Attention(nn.Module): if sr_ratio > 1: self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None - def forward(self, x, H, W): + def forward(self, x, size: Size_): B, N, C = x.shape q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) - if self.sr_ratio > 1: - x_ = x.permute(0, 2, 1).reshape(B, C, H, W) - x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) - x_ = self.norm(x_) - kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - else: - kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + if self.sr is not None: + x = x.permute(0, 2, 1).reshape(B, C, *size) + x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) + x = self.norm(x) + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) k, v = kv[0], kv[1] attn = (q @ k.transpose(-2, -1)) * self.scale @@ -210,52 +196,46 @@ class Attention(nn.Module): class Block(nn.Module): - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): + def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None): super().__init__() self.norm1 = norm_layer(dim) - self.attn = Attention( - dim, - num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + if ws is None: + self.attn = Attention(dim, num_heads, False, None, attn_drop, drop) + elif ws == 1: + self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, drop, sr_ratio) + else: + self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, drop, ws) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - def forward(self, x, H, W): - x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + def forward(self, x, size: Size_): + x = x + self.drop_path(self.attn(self.norm1(x), size)) x = x + self.drop_path(self.mlp(self.norm2(x))) - return x -class SBlock(TimmBlock): - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): - super(SBlock, self).__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, - drop_path, act_layer, norm_layer) - - def forward(self, x, H, W): - return super(SBlock, self).forward(x) - - -class GroupBlock(TimmBlock): - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=1): - super(GroupBlock, self).__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, - drop_path, act_layer, norm_layer) - del self.attn - if ws == 1: - self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop, drop, sr_ratio) - else: - self.attn = GroupAttention(dim, num_heads, qkv_bias, qk_scale, attn_drop, drop, ws) +class PosConv(nn.Module): + # PEG from https://arxiv.org/abs/2102.10882 + def __init__(self, in_chans, embed_dim=768, stride=1): + super(PosConv, self).__init__() + self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), ) + self.stride = stride - def forward(self, x, H, W): - x = x + self.drop_path(self.attn(self.norm1(x), H, W)) - x = x + self.drop_path(self.mlp(self.norm2(x))) + def forward(self, x, size: Size_): + B, N, C = x.shape + cnn_feat_token = x.transpose(1, 2).view(B, C, *size) + x = self.proj(cnn_feat_token) + if self.stride == 1: + x += cnn_feat_token + x = x.flatten(2).transpose(1, 2) return x + def no_weight_decay(self): + return ['proj.%d.weight' % i for i in range(4)] + class PatchEmbed(nn.Module): """ Image to Patch Embedding @@ -263,7 +243,7 @@ class PatchEmbed(nn.Module): def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): super().__init__() - # img_size = to_2tuple(img_size) + img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) self.img_size = img_size @@ -275,90 +255,62 @@ class PatchEmbed(nn.Module): self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.norm = nn.LayerNorm(embed_dim) - def forward(self, x): + def forward(self, x) -> Tuple[torch.Tensor, Size_]: B, C, H, W = x.shape x = self.proj(x).flatten(2).transpose(1, 2) x = self.norm(x) - H, W = H // self.patch_size[0], W // self.patch_size[1] + out_size = (H // self.patch_size[0], W // self.patch_size[1]) - return x, (H, W) + return x, out_size -# borrow from PVT https://github.com/whai362/PVT.git -class PyramidVisionTransformer(nn.Module): - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], - num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., - attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, - depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], block_cls=Block): +class Twins(nn.Module): + # Adapted from PVT https://github.com/whai362/PVT.git + def __init__( + self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=(64, 128, 256, 512), + num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=(3, 4, 6, 3), sr_ratios=(8, 4, 2, 1), wss=None, + block_cls=Block): super().__init__() self.num_classes = num_classes self.depths = depths - # patch_embed + img_size = to_2tuple(img_size) + prev_chs = in_chans self.patch_embeds = nn.ModuleList() - self.pos_embeds = nn.ParameterList() self.pos_drops = nn.ModuleList() - self.blocks = nn.ModuleList() - for i in range(len(depths)): - if i == 0: - self.patch_embeds.append(PatchEmbed(img_size, patch_size, in_chans, embed_dims[i])) - else: - self.patch_embeds.append( - # PatchEmbed(img_size // patch_size // 2 ** (i - 1), 2, embed_dims[i - 1], embed_dims[i]) - PatchEmbed((img_size[0] // patch_size // 2**(i-1),img_size[1] // patch_size // 2**(i-1)), 2, embed_dims[i - 1], embed_dims[i]) - ) - patch_num = self.patch_embeds[-1].num_patches + 1 if i == len(embed_dims) - 1 else self.patch_embeds[ - -1].num_patches - self.pos_embeds.append(nn.Parameter(torch.zeros(1, patch_num, embed_dims[i]))) + self.patch_embeds.append(PatchEmbed(img_size, patch_size, prev_chs, embed_dims[i])) self.pos_drops.append(nn.Dropout(p=drop_rate)) + prev_chs = embed_dims[i] + img_size = tuple(t // patch_size for t in img_size) + patch_size = 2 + self.blocks = nn.ModuleList() dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule cur = 0 for k in range(len(depths)): _block = nn.ModuleList([block_cls( - dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k], qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, - sr_ratio=sr_ratios[k]) - for i in range(depths[k])]) + dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k], drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[k], + ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])]) self.blocks.append(_block) cur += depths[k] - self.norm = norm_layer(embed_dims[-1]) + self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims]) - # cls_token - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1])) + self.norm = norm_layer(embed_dims[-1]) # classification head self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() # init weights - for pos_emb in self.pos_embeds: - trunc_normal_(pos_emb, std=.02) self.apply(self._init_weights) - def reset_drop_path(self, drop_path_rate): - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] - cur = 0 - for k in range(len(self.depths)): - for i in range(self.depths[k]): - self.blocks[k][i].drop_path.drop_prob = dpr[cur + i] - cur += self.depths[k] - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - @torch.jit.ignore def no_weight_decay(self): - return {'cls_token'} + return set(['pos_block.' + n for n, p in self.pos_block.named_parameters()]) def get_classifier(self): return self.head @@ -367,76 +319,7 @@ class PyramidVisionTransformer(nn.Module): self.num_classes = num_classes self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() - def forward_features(self, x): - B = x.shape[0] - for i in range(len(self.depths)): - x, (H, W) = self.patch_embeds[i](x) - if i == len(self.depths) - 1: - cls_tokens = self.cls_token.expand(B, -1, -1) - x = torch.cat((cls_tokens, x), dim=1) - x = x + self.pos_embeds[i] - x = self.pos_drops[i](x) - for blk in self.blocks[i]: - x = blk(x, H, W) - if i < len(self.depths) - 1: - x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() - - x = self.norm(x) - - return x[:, 0] - - def forward(self, x): - x = self.forward_features(x) - x = self.head(x) - - return x - - -# PEG from https://arxiv.org/abs/2102.10882 -class PosCNN(nn.Module): - def __init__(self, in_chans, embed_dim=768, s=1): - super(PosCNN, self).__init__() - self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, s, 1, bias=True, groups=embed_dim), ) - self.s = s - - def forward(self, x, H, W): - B, N, C = x.shape - feat_token = x - cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W) - if self.s == 1: - x = self.proj(cnn_feat) + cnn_feat - else: - x = self.proj(cnn_feat) - x = x.flatten(2).transpose(1, 2) - return x - - def no_weight_decay(self): - return ['proj.%d.weight' % i for i in range(4)] - - -class CPVTV2(PyramidVisionTransformer): - """ - Use useful results from CPVT. PEG and GAP. - Therefore, cls token is no longer required. - PEG is used to encode the absolute position on the fly, which greatly affects the performance when input resolution - changes during the training (such as segmentation, detection) - """ - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], - num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., - attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, - depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], block_cls=Block): - super(CPVTV2, self).__init__(img_size, patch_size, in_chans, num_classes, embed_dims, num_heads, mlp_ratios, - qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, norm_layer, depths, - sr_ratios, block_cls) - del self.pos_embeds - del self.cls_token - self.pos_block = nn.ModuleList( - [PosCNN(embed_dim, embed_dim) for embed_dim in embed_dims] - ) - self.apply(self._init_weights) - def _init_weights(self, m): - import math if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: @@ -454,98 +337,28 @@ class CPVTV2(PyramidVisionTransformer): m.weight.data.fill_(1.0) m.bias.data.zero_() - def no_weight_decay(self): - return set(['cls_token'] + ['pos_block.' + n for n, p in self.pos_block.named_parameters()]) - def forward_features(self, x): B = x.shape[0] - - for i in range(len(self.depths)): - x, (H, W) = self.patch_embeds[i](x) - x = self.pos_drops[i](x) - for j, blk in enumerate(self.blocks[i]): - x = blk(x, H, W) + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)): + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) if j == 0: - x = self.pos_block[i](x, H, W) # PEG here + x = pos_blk(x, size) # PEG here if i < len(self.depths) - 1: - x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() - + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() x = self.norm(x) - return x.mean(dim=1) # GAP here + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x -class Twins_PCPVT(CPVTV2): - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256], - num_heads=[1, 2, 4], mlp_ratios=[4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., - attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, - depths=[4, 4, 4], sr_ratios=[4, 2, 1], block_cls=SBlock): - super(Twins_PCPVT, self).__init__(img_size, patch_size, in_chans, num_classes, embed_dims, num_heads, - mlp_ratios, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, - norm_layer, depths, sr_ratios, block_cls) - - -class Twins_SVT(Twins_PCPVT): - """ - alias Twins-SVT - """ - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256], - num_heads=[1, 2, 4], mlp_ratios=[4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., - attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, - depths=[4, 4, 4], sr_ratios=[4, 2, 1], block_cls=GroupBlock, wss=[7, 7, 7]): - super(Twins_SVT, self).__init__(img_size, patch_size, in_chans, num_classes, embed_dims, num_heads, - mlp_ratios, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, - norm_layer, depths, sr_ratios, block_cls) - del self.blocks - self.wss = wss - # transformer encoder - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - cur = 0 - self.blocks = nn.ModuleList() - for k in range(len(depths)): - _block = nn.ModuleList([block_cls( - dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k], qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, - sr_ratio=sr_ratios[k], ws=1 if i % 2 == 1 else wss[k]) for i in range(depths[k])]) - self.blocks.append(_block) - cur += depths[k] - self.apply(self._init_weights) - - -def _conv_filter(state_dict, patch_size=16): - """ convert patch embedding weight from manual patchify + linear proj to conv""" - out_dict = {} - for k, v in state_dict.items(): - if 'patch_embed.proj.weight' in k: - v = v.reshape((v.shape[0], 3, patch_size, patch_size)) - out_dict[k] = v - - return out_dict - -def _create_twins_svt(variant, pretrained=False, default_cfg=None, **kwargs): - if default_cfg is None: - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] - - num_classes = kwargs.pop('num_classes', default_num_classes) - img_size = kwargs.pop('img_size', default_img_size) - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') - - model = build_model_with_cfg( - Twins_SVT, variant, pretrained, - default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, - pretrained_filter_fn=checkpoint_filter_fn, - **kwargs) - - return model -def _create_twins_pcpvt(variant, pretrained=False, default_cfg=None, **kwargs): +def _create_twins(variant, pretrained=False, default_cfg=None, **kwargs): if default_cfg is None: default_cfg = deepcopy(default_cfgs[variant]) overlay_external_default_cfg(default_cfg, kwargs) @@ -558,11 +371,10 @@ def _create_twins_pcpvt(variant, pretrained=False, default_cfg=None, **kwargs): raise RuntimeError('features_only not implemented for Vision Transformer models.') model = build_model_with_cfg( - CPVTV2, variant, pretrained, + Twins, variant, pretrained, default_cfg=default_cfg, img_size=img_size, num_classes=num_classes, - pretrained_filter_fn=checkpoint_filter_fn, **kwargs) return model @@ -571,55 +383,46 @@ def _create_twins_pcpvt(variant, pretrained=False, default_cfg=None, **kwargs): @register_model def twins_pcpvt_small(pretrained=False, **kwargs): model_kwargs = dict( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], - **kwargs) - return _create_twins_pcpvt('twins_pcpvt_small', pretrained=pretrained, **model_kwargs) + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs) + return _create_twins('twins_pcpvt_small', pretrained=pretrained, **model_kwargs) @register_model def twins_pcpvt_base(pretrained=False, **kwargs): model_kwargs = dict( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], - **kwargs) - return _create_twins_pcpvt('twins_pcpvt_base', pretrained=pretrained, **model_kwargs) + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], + depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], **kwargs) + return _create_twins('twins_pcpvt_base', pretrained=pretrained, **model_kwargs) @register_model def twins_pcpvt_large(pretrained=False, **kwargs): model_kwargs = dict( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], - **kwargs) - return _create_twins_pcpvt('twins_pcpvt_large', pretrained=pretrained, **model_kwargs) + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], + depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], **kwargs) + return _create_twins('twins_pcpvt_large', pretrained=pretrained, **model_kwargs) @register_model def twins_svt_small(pretrained=False, **kwargs): model_kwargs = dict( - patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4], qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 10, 4], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], - **kwargs) - return _create_twins_svt('twins_svt_small', pretrained=pretrained, **model_kwargs) + patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4], + depths=[2, 2, 10, 4], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) + return _create_twins('twins_svt_small', pretrained=pretrained, **model_kwargs) @register_model def twins_svt_base(pretrained=False, **kwargs): model_kwargs = dict( - patch_size=4, embed_dims=[96, 192, 384, 768], num_heads=[3, 6, 12, 24], mlp_ratios=[4, 4, 4, 4], qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], - **kwargs) - - return _create_twins_svt('twins_svt_base', pretrained=pretrained, **model_kwargs) + patch_size=4, embed_dims=[96, 192, 384, 768], num_heads=[3, 6, 12, 24], mlp_ratios=[4, 4, 4, 4], + depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) + return _create_twins('twins_svt_base', pretrained=pretrained, **model_kwargs) @register_model def twins_svt_large(pretrained=False, **kwargs): model_kwargs = dict( patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], - **kwargs) - - return _create_twins_svt('twins_svt_large', pretrained=pretrained, **model_kwargs) + depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) + return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs) From a569635045b83bfd7f86881694b2515fed575592 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 21 May 2021 16:23:14 -0700 Subject: [PATCH 19/48] Update twin weights to a copy in GitHub releases for faster dl. Tweak model class comment. --- timm/models/twins.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/timm/models/twins.py b/timm/models/twins.py index ce51c497..a534d174 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -40,22 +40,22 @@ def _cfg(url='', **kwargs): default_cfgs = { 'twins_pcpvt_small': _cfg( - url='https://s3plus.meituan.net/v1/mss_9240d97c6bf34ab1b78859c3c2a2a3e4/automl-model-zoo/models/twins/pcpvt_small.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_small-e70e7e7a.pth', ), 'twins_pcpvt_base': _cfg( - url='https://s3plus.meituan.net/v1/mss_9240d97c6bf34ab1b78859c3c2a2a3e4/automl-model-zoo/models/twins/pcpvt_base.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_base-e5ecb09b.pth', ), 'twins_pcpvt_large': _cfg( - url='https://s3plus.meituan.net/v1/mss_9240d97c6bf34ab1b78859c3c2a2a3e4/automl-model-zoo/models/twins/pcpvt_large.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_large-d273f802.pth', ), 'twins_svt_small': _cfg( - url='https://s3plus.meituan.net/v1/mss_9240d97c6bf34ab1b78859c3c2a2a3e4/automl-model-zoo/models/twins/alt_gvt_small.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_small-42e5f78c.pth', ), 'twins_svt_base': _cfg( - url='https://s3plus.meituan.net/v1/mss_9240d97c6bf34ab1b78859c3c2a2a3e4/automl-model-zoo/models/twins/alt_gvt_base.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_base-c2265010.pth', ), 'twins_svt_large': _cfg( - url='https://s3plus.meituan.net/v1/mss_9240d97c6bf34ab1b78859c3c2a2a3e4/automl-model-zoo/models/twins/alt_gvt_large.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_large-90f6aaa9.pth', ), } @@ -266,7 +266,10 @@ class PatchEmbed(nn.Module): class Twins(nn.Module): - # Adapted from PVT https://github.com/whai362/PVT.git + """ Twins Vision Transfomer (Revisiting Spatial Attention) + + Adapted from PVT (PyramidVisionTransformer) class at https://github.com/whai362/PVT.git + """ def __init__( self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=(64, 128, 256, 512), num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), drop_rate=0., attn_drop_rate=0., drop_path_rate=0., From b7de82e835682c2f90b6a5fc9fd325d1457193b6 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 21 May 2021 17:04:23 -0700 Subject: [PATCH 20/48] ConViT cleanup, fix torchscript, bit of reformatting, reuse existing layers. --- timm/models/convit.py | 290 ++++++++++++++---------------------------- 1 file changed, 98 insertions(+), 192 deletions(-) diff --git a/timm/models/convit.py b/timm/models/convit.py index 31c05df3..f6ae3ec1 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -1,6 +1,24 @@ -"""These modules are adapted from those of timm, see -https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +""" ConViT Model + +@article{d2021convit, + title={ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases}, + author={d'Ascoli, St{\'e}phane and Touvron, Hugo and Leavitt, Matthew and Morcos, Ari and Biroli, Giulio and Sagun, Levent}, + journal={arXiv preprint arXiv:2103.10697}, + year={2021} +} + +Paper link: https://arxiv.org/abs/2103.10697 +Original code: https://github.com/facebookresearch/convit, original copyright below """ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree. +# +'''These modules are adapted from those of timm, see +https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +''' import torch import torch.nn as nn @@ -9,8 +27,9 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ -from timm.models.registry import register_model +from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp +from .registry import register_model +from .vision_transformer_hybrid import HybridEmbed import torch import torch.nn as nn @@ -29,7 +48,7 @@ def _cfg(url='', **kwargs): default_cfgs = { # ConViT 'convit_tiny': _cfg( - url="https://dl.fbaipublicfiles.com/convit/convit_tiny.pth"), + url="https://dl.fbaipublicfiles.com/convit/convit_tiny.pth"), 'convit_small': _cfg( url="https://dl.fbaipublicfiles.com/convit/convit_small.pth"), 'convit_base': _cfg( @@ -37,71 +56,31 @@ default_cfgs = { } -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - class GPSA(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., - locality_strength=1., use_local_init=True): + locality_strength=1.): super().__init__() self.num_heads = num_heads self.dim = dim head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 + self.locality_strength = locality_strength + + self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) - self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias) - self.v = nn.Linear(dim, dim, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.pos_proj = nn.Linear(3, num_heads) self.proj_drop = nn.Dropout(proj_drop) self.locality_strength = locality_strength self.gating_param = nn.Parameter(torch.ones(self.num_heads)) - self.apply(self._init_weights) - if use_local_init: - self.local_init(locality_strength=locality_strength) + self.rel_indices: torch.Tensor = torch.zeros(1, 1, 1, 3) # silly torchscript hack, won't work with None - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - def forward(self, x): B, N, C = x.shape - if not hasattr(self, 'rel_indices') or self.rel_indices.size(1)!=N: - self.get_rel_indices(N) - + if self.rel_indices is None or self.rel_indices.shape[1] != N: + self.rel_indices = self.get_rel_indices(N) attn = self.get_attention(x) v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) x = (attn @ v).transpose(1, 2).reshape(B, N, C) @@ -110,61 +89,58 @@ class GPSA(nn.Module): return x def get_attention(self, x): - B, N, C = x.shape + B, N, C = x.shape qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k = qk[0], qk[1] - pos_score = self.rel_indices.expand(B, -1, -1,-1) - pos_score = self.pos_proj(pos_score).permute(0,3,1,2) + pos_score = self.rel_indices.expand(B, -1, -1, -1) + pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2) patch_score = (q @ k.transpose(-2, -1)) * self.scale patch_score = patch_score.softmax(dim=-1) pos_score = pos_score.softmax(dim=-1) - gating = self.gating_param.view(1,-1,1,1) - attn = (1.-torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score + gating = self.gating_param.view(1, -1, 1, 1) + attn = (1. - torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score attn /= attn.sum(dim=-1).unsqueeze(-1) attn = self.attn_drop(attn) return attn - def get_attention_map(self, x, return_map = False): - - attn_map = self.get_attention(x).mean(0) # average over batch - distances = self.rel_indices.squeeze()[:,:,-1]**.5 - dist = torch.einsum('nm,hnm->h', (distances, attn_map)) - dist /= distances.size(0) + def get_attention_map(self, x, return_map=False): + attn_map = self.get_attention(x).mean(0) # average over batch + distances = self.rel_indices.squeeze()[:, :, -1] ** .5 + dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / distances.size(0) if return_map: return dist, attn_map else: return dist - - def local_init(self, locality_strength=1.): - + + def local_init(self): self.v.weight.data.copy_(torch.eye(self.dim)) - locality_distance = 1 #max(1,1/locality_strength**.5) - - kernel_size = int(self.num_heads**.5) - center = (kernel_size-1)/2 if kernel_size%2==0 else kernel_size//2 + locality_distance = 1 # max(1,1/locality_strength**.5) + + kernel_size = int(self.num_heads ** .5) + center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2 for h1 in range(kernel_size): for h2 in range(kernel_size): - position = h1+kernel_size*h2 - self.pos_proj.weight.data[position,2] = -1 - self.pos_proj.weight.data[position,1] = 2*(h1-center)*locality_distance - self.pos_proj.weight.data[position,0] = 2*(h2-center)*locality_distance - self.pos_proj.weight.data *= locality_strength - - def get_rel_indices(self, num_patches): - img_size = int(num_patches**.5) - rel_indices = torch.zeros(1, num_patches, num_patches, 3) - ind = torch.arange(img_size).view(1,-1) - torch.arange(img_size).view(-1, 1) - indx = ind.repeat(img_size,img_size) - indy = ind.repeat_interleave(img_size,dim=0).repeat_interleave(img_size,dim=1) - indd = indx**2 + indy**2 - rel_indices[:,:,:,2] = indd.unsqueeze(0) - rel_indices[:,:,:,1] = indy.unsqueeze(0) - rel_indices[:,:,:,0] = indx.unsqueeze(0) + position = h1 + kernel_size * h2 + self.pos_proj.weight.data[position, 2] = -1 + self.pos_proj.weight.data[position, 1] = 2 * (h1 - center) * locality_distance + self.pos_proj.weight.data[position, 0] = 2 * (h2 - center) * locality_distance + self.pos_proj.weight.data *= self.locality_strength + + def get_rel_indices(self, num_patches: int) -> torch.Tensor: + img_size = int(num_patches ** .5) + rel_indices = torch.zeros(1, num_patches, num_patches, 3) + ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1) + indx = ind.repeat(img_size, img_size) + indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1) + indd = indx ** 2 + indy ** 2 + rel_indices[:, :, :, 2] = indd.unsqueeze(0) + rel_indices[:, :, :, 1] = indy.unsqueeze(0) + rel_indices[:, :, :, 0] = indx.unsqueeze(0) device = self.qk.weight.device - self.rel_indices = rel_indices.to(device) + return rel_indices.to(device) + - class MHSA(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() @@ -176,41 +152,28 @@ class MHSA(nn.Module): self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - def get_attention_map(self, x, return_map = False): + def get_attention_map(self, x, return_map=False): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] attn_map = (q @ k.transpose(-2, -1)) * self.scale attn_map = attn_map.softmax(dim=-1).mean(0) - img_size = int(N**.5) - ind = torch.arange(img_size).view(1,-1) - torch.arange(img_size).view(-1, 1) - indx = ind.repeat(img_size,img_size) - indy = ind.repeat_interleave(img_size,dim=0).repeat_interleave(img_size,dim=1) - indd = indx**2 + indy**2 - distances = indd**.5 + img_size = int(N ** .5) + ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1) + indx = ind.repeat(img_size, img_size) + indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1) + indd = indx ** 2 + indy ** 2 + distances = indd ** .5 distances = distances.to('cuda') - dist = torch.einsum('nm,hnm->h', (distances, attn_map)) - dist /= N - + dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / N if return_map: return dist, attn_map else: return dist - def forward(self, x): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) @@ -228,15 +191,19 @@ class MHSA(nn.Module): class Block(nn.Module): - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs): super().__init__() self.norm1 = norm_layer(dim) self.use_gpsa = use_gpsa if self.use_gpsa: - self.attn = GPSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, **kwargs) + self.attn = GPSA( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, + proj_drop=drop, **kwargs) else: - self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, **kwargs) + self.attn = MHSA( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, + proj_drop=drop, **kwargs) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) @@ -246,75 +213,12 @@ class Block(nn.Module): x = x + self.drop_path(self.attn(self.norm1(x))) x = x + self.drop_path(self.mlp(self.norm2(x))) return x - - -class PatchEmbed(nn.Module): - """ Image to Patch Embedding, from timm - """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) - self.img_size = img_size - self.patch_size = patch_size - self.num_patches = num_patches - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - self.apply(self._init_weights) - - def forward(self, x): - B, C, H, W = x.shape - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) - return x - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - -class HybridEmbed(nn.Module): - """ CNN Feature Map Embedding, from timm - """ - def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): - super().__init__() - assert isinstance(backbone, nn.Module) - img_size = to_2tuple(img_size) - self.img_size = img_size - self.backbone = backbone - if feature_size is None: - with torch.no_grad(): - training = backbone.training - if training: - backbone.eval() - o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] - feature_size = o.shape[-2:] - feature_dim = o.shape[1] - backbone.train(training) - else: - feature_size = to_2tuple(feature_size) - feature_dim = self.backbone.feature_info.channels()[-1] - self.num_patches = feature_size[0] * feature_size[1] - self.proj = nn.Linear(feature_dim, embed_dim) - self.apply(self._init_weights) - - def forward(self, x): - x = self.backbone(x)[-1] - x = x.flatten(2).transpose(1, 2) - x = self.proj(x) - return x class ConViT(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None, @@ -335,7 +239,7 @@ class ConViT(nn.Module): img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.num_patches = num_patches - + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_drop = nn.Dropout(p=drop_rate) @@ -350,7 +254,7 @@ class ConViT(nn.Module): drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, use_gpsa=True, locality_strength=locality_strength) - if i 0 else nn.Identity() trunc_normal_(self.cls_token, std=.02) - self.head.apply(self._init_weights) + self.apply(self._init_weights) + for n, m in self.named_modules(): + if hasattr(m, 'local_init'): + m.local_init() def _init_weights(self, m): if isinstance(m, nn.Linear): @@ -395,8 +302,8 @@ class ConViT(nn.Module): x = x + self.pos_embed x = self.pos_drop(x) - for u,blk in enumerate(self.blocks): - if u == self.local_up_to_layer : + for u, blk in enumerate(self.blocks): + if u == self.local_up_to_layer: x = torch.cat((cls_tokens, x), dim=1) x = blk(x) @@ -415,30 +322,29 @@ def _create_convit(variant, pretrained=False, **kwargs): default_cfg=default_cfgs[variant], **kwargs) - + @register_model def convit_tiny(pretrained=False, **kwargs): model_args = dict( - local_up_to_layer=10, locality_strength=1.0, embed_dim=48, + local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model = _create_convit( - variant='convit_tiny', pretrained=pretrained, **model_args) + model = _create_convit(variant='convit_tiny', pretrained=pretrained, **model_args) return model + @register_model def convit_small(pretrained=False, **kwargs): model_args = dict( - local_up_to_layer=10, locality_strength=1.0, embed_dim=48, + local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=9, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model = _create_convit( - variant='convit_small', pretrained=pretrained, **model_args) + model = _create_convit(variant='convit_small', pretrained=pretrained, **model_args) return model + @register_model def convit_base(pretrained=False, **kwargs): model_args = dict( - local_up_to_layer=10, locality_strength=1.0, embed_dim=48, + local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=16, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model = _create_convit( - variant='convit_base', pretrained=pretrained, **model_args) + model = _create_convit(variant='convit_base', pretrained=pretrained, **model_args) return model From 30b9880d06a7f65edbd6a65aba4b6fca4c735060 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 21 May 2021 17:20:33 -0700 Subject: [PATCH 21/48] Minor adjustment, mutable default arg, extra check of valid len... --- timm/models/vision_transformer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 1acdd808..bef6dfb0 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -352,7 +352,7 @@ def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = Fa nn.init.ones_(m.weight) -def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=[]): +def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): # Rescale the grid of position embeddings when loading from state_dict. Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) @@ -363,8 +363,9 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=[]): else: posemb_tok, posemb_grid = posemb[:, :0], posemb[0] gs_old = int(math.sqrt(len(posemb_grid))) - if not len(gs_new): # backwards compatibility - gs_new = [int(math.sqrt(ntok_new))]*2 + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + assert len(gs_new) >= 2 _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear') From c2ba229d995c33aaaf20e00a5686b4dc857044be Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 21 May 2021 17:47:49 -0700 Subject: [PATCH 22/48] Prep for effcientnetv2_rw_m model weights that started training before official release.. --- timm/models/efficientnet.py | 16 ++++++++++++++-- timm/models/efficientnet_builder.py | 8 ++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 0c0464b5..37c1c745 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -162,6 +162,9 @@ default_cfgs = { 'efficientnetv2_rw_s': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_v2s_ra2_288-a6477665.pth', input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0), + 'efficientnetv2_rw_m': _cfg( + url='', + input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0), 'efficientnetv2_s': _cfg( url='', @@ -173,7 +176,6 @@ default_cfgs = { url='', input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), - 'tf_efficientnet_b0': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', input_size=(3, 224, 224)), @@ -1461,7 +1463,7 @@ def efficientnet_b3_pruned(pretrained=False, **kwargs): @register_model def efficientnetv2_rw_s(pretrained=False, **kwargs): - """ EfficientNet-V2 Small. + """ EfficientNet-V2 Small RW variant. NOTE: This is my initial (pre official code release) w/ some differences. See efficientnetv2_s and tf_efficientnetv2_s for versions that match the official w/ PyTorch vs TF padding """ @@ -1469,6 +1471,16 @@ def efficientnetv2_rw_s(pretrained=False, **kwargs): return model +@register_model +def efficientnetv2_rw_m(pretrained=False, **kwargs): + """ EfficientNet-V2 Medium RW variant. + """ + model = _gen_efficientnetv2_s( + 'efficientnetv2_rw_m', channel_multiplier=1.2, depth_multiplier=(1.2,) * 4 + (1.6,) * 2, rw=True, + pretrained=pretrained, **kwargs) + return model + + @register_model def efficientnetv2_s(pretrained=False, **kwargs): """ EfficientNet-V2 Small. """ diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index 9d5853c7..30739454 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -237,7 +237,11 @@ def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='c def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False): arch_args = [] - for stack_idx, block_strings in enumerate(arch_def): + if isinstance(depth_multiplier, tuple): + assert len(depth_multiplier) == len(arch_def) + else: + depth_multiplier = (depth_multiplier,) * len(arch_def) + for stack_idx, (block_strings, multiplier) in enumerate(zip(arch_def, depth_multiplier)): assert isinstance(block_strings, list) stack_args = [] repeats = [] @@ -251,7 +255,7 @@ def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_ if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1): arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc)) else: - arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) + arch_args.append(_scale_stage_depth(stack_args, repeats, multiplier, depth_trunc)) return arch_args From 23c18a33e4168dc7cb11439c1f9acd38dc8e9824 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 21 May 2021 21:16:25 -0700 Subject: [PATCH 23/48] Add efficientnetv2_rw_m weights trained in PyTorch. 84.8 top-1 @ 416 test. 53M params. --- timm/models/efficientnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 37c1c745..8aa61ec5 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -163,7 +163,7 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_v2s_ra2_288-a6477665.pth', input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0), 'efficientnetv2_rw_m': _cfg( - url='', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnetv2_rw_m_agc-3d90cb1e.pth', input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0), 'efficientnetv2_s': _cfg( From 18bf520ad12297dac4f9992ce497030259ca1aa2 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 22 May 2021 21:55:37 -0700 Subject: [PATCH 24/48] Add eca_nfnet_l2/l3 defs for future training --- timm/models/nfnet.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 3c21eea1..1b67581e 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -110,6 +110,12 @@ default_cfgs = dict( eca_nfnet_l1=_dcfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0), + eca_nfnet_l2=_dcfg( + url='', + pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 352, 352), crop_pct=1.0), + eca_nfnet_l3=_dcfg( + url='', + pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0), nf_regnet_b0=_dcfg( url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'), @@ -244,6 +250,12 @@ model_cfgs = dict( eca_nfnet_l1=_nfnet_cfg( depths=(2, 4, 12, 6), feat_mult=2, group_size=64, bottle_ratio=0.25, attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + eca_nfnet_l2=_nfnet_cfg( + depths=(3, 6, 18, 9), feat_mult=2, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + eca_nfnet_l3=_nfnet_cfg( + depths=(4, 8, 24, 12), feat_mult=2, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), # EffNet influenced RegNet defs. # NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8. @@ -814,6 +826,22 @@ def eca_nfnet_l1(pretrained=False, **kwargs): return _create_normfreenet('eca_nfnet_l1', pretrained=pretrained, **kwargs) +@register_model +def eca_nfnet_l2(pretrained=False, **kwargs): + """ ECA-NFNet-L2 w/ SiLU + My experimental 'light' model w/ F2 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn + """ + return _create_normfreenet('eca_nfnet_l2', pretrained=pretrained, **kwargs) + + +@register_model +def eca_nfnet_l3(pretrained=False, **kwargs): + """ ECA-NFNet-L3 w/ SiLU + My experimental 'light' model w/ F3 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn + """ + return _create_normfreenet('eca_nfnet_l3', pretrained=pretrained, **kwargs) + + @register_model def nf_regnet_b0(pretrained=False, **kwargs): """ Normalization-Free RegNet-B0 From bfc72f75d3f836ca5545cbe6ec1f8ba67b804b8b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 24 May 2021 21:13:26 -0700 Subject: [PATCH 25/48] Expand scope of testing for non-std vision transformer / mlp models. Some related cleanup and create fn cleanup for all vision transformer and mlp models. More CoaT weights. --- tests/test_models.py | 70 ++-- timm/models/__init__.py | 2 +- timm/models/cait.py | 15 +- timm/models/coat.py | 179 +++++---- timm/models/convit.py | 5 +- timm/models/helpers.py | 8 +- timm/models/levit.py | 471 ++++++++++++++--------- timm/models/levitc.py | 400 ------------------- timm/models/mlp_mixer.py | 15 +- timm/models/pit.py | 12 +- timm/models/tnt.py | 31 +- timm/models/twins.py | 17 +- timm/models/visformer.py | 155 ++++---- timm/models/vision_transformer.py | 23 +- timm/models/vision_transformer_hybrid.py | 5 +- 15 files changed, 559 insertions(+), 849 deletions(-) delete mode 100644 timm/models/levitc.py diff --git a/tests/test_models.py b/tests/test_models.py index 5ff9fb33..570b49db 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -26,29 +26,41 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): EXCLUDE_FILTERS = [ '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', - '*resnetrs350*', '*resnetrs420*'] + NON_STD_FILTERS + '*resnetrs350*', '*resnetrs420*'] else: - EXCLUDE_FILTERS = NON_STD_FILTERS + EXCLUDE_FILTERS = [] -MAX_FWD_SIZE = 384 -MAX_BWD_SIZE = 128 +TARGET_FWD_SIZE = MAX_FWD_SIZE = 384 +TARGET_BWD_SIZE = 128 +MAX_BWD_SIZE = 384 MAX_FWD_FEAT_SIZE = 448 +def _get_input_size(model, target=None): + default_cfg = model.default_cfg + input_size = default_cfg['input_size'] + if 'fixed_input_size' in default_cfg and default_cfg['fixed_input_size']: + return input_size + if 'min_input_size' in default_cfg: + if target and max(input_size) > target: + input_size = default_cfg['min_input_size'] + else: + if target and max(input_size) > target: + input_size = tuple([min(x, target) for x in input_size]) + return input_size + + @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-NUM_NON_STD])) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward(model_name, batch_size): """Run a single forward pass with each model""" model = create_model(model_name, pretrained=False) model.eval() - input_size = model.default_cfg['input_size'] - if any([x > MAX_FWD_SIZE for x in input_size]): - if is_model_default_key(model_name, 'fixed_input_size'): - pytest.skip("Fixed input size model > limit.") - # cap forward test at max res 384 * 384 to keep resource down - input_size = tuple([min(x, MAX_FWD_SIZE) for x in input_size]) + input_size = _get_input_size(model, TARGET_FWD_SIZE) + if max(input_size) > MAX_FWD_SIZE: + pytest.skip("Fixed input size model > limit.") inputs = torch.randn((batch_size, *input_size)) outputs = model(inputs) @@ -63,20 +75,16 @@ def test_model_backward(model_name, batch_size): """Run a single forward pass with each model""" model = create_model(model_name, pretrained=False, num_classes=42) num_params = sum([x.numel() for x in model.parameters()]) - model.eval() + model.train() - input_size = model.default_cfg['input_size'] - if not is_model_default_key(model_name, 'fixed_input_size'): - min_input_size = get_model_default_value(model_name, 'min_input_size') - if min_input_size is not None: - input_size = min_input_size - else: - if any([x > MAX_BWD_SIZE for x in input_size]): - # cap backward test at 128 * 128 to keep resource usage down - input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size]) + input_size = _get_input_size(model, TARGET_BWD_SIZE) + if max(input_size) > MAX_BWD_SIZE: + pytest.skip("Fixed input size model > limit.") inputs = torch.randn((batch_size, *input_size)) outputs = model(inputs) + if isinstance(outputs, tuple): + outputs = torch.cat(outputs) outputs.mean().backward() for n, x in model.named_parameters(): assert x.grad is not None, f'No gradient for {n}' @@ -168,12 +176,9 @@ def test_model_forward_torchscript(model_name, batch_size): model = create_model(model_name, pretrained=False) model.eval() - if has_model_default_key(model_name, 'fixed_input_size'): - input_size = get_model_default_value(model_name, 'input_size') - elif has_model_default_key(model_name, 'min_input_size'): - input_size = get_model_default_value(model_name, 'min_input_size') - else: - input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already... + input_size = _get_input_size(model, 128) + if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional + pytest.skip("Fixed input size model > limit.") model = torch.jit.script(model) outputs = model(torch.randn((batch_size, *input_size))) @@ -184,7 +189,7 @@ def test_model_forward_torchscript(model_name, batch_size): EXCLUDE_FEAT_FILTERS = [ '*pruned*', # hopefully fix at some point -] +] + NON_STD_FILTERS if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models EXCLUDE_FEAT_FILTERS += ['*resnext101_32x32d', '*resnext101_32x16d'] @@ -200,12 +205,9 @@ def test_model_forward_features(model_name, batch_size): expected_channels = model.feature_info.channels() assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6 - if has_model_default_key(model_name, 'fixed_input_size'): - input_size = get_model_default_value(model_name, 'input_size') - elif has_model_default_key(model_name, 'min_input_size'): - input_size = get_model_default_value(model_name, 'min_input_size') - else: - input_size = (3, 96, 96) # jit compile is already a bit slow and we've tested normal res already... + input_size = _get_input_size(model, 96) # jit compile is already a bit slow and we've tested normal res already... + if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional + pytest.skip("Fixed input size model > limit.") outputs = model(torch.randn((batch_size, *input_size))) assert len(expected_channels) == len(outputs) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 1a21de09..788b7518 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -16,8 +16,8 @@ from .hrnet import * from .inception_resnet_v2 import * from .inception_v3 import * from .inception_v4 import * -from .levitc import * from .levit import * +#from .levit import * from .mlp_mixer import * from .mobilenetv3 import * from .nasnet import * diff --git a/timm/models/cait.py b/timm/models/cait.py index c5f7742f..aa2e5f07 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -306,26 +306,15 @@ def checkpoint_filter_fn(state_dict, model=None): return checkpoint_no_module -def _create_cait(variant, pretrained=False, default_cfg=None, **kwargs): - if default_cfg is None: - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] - num_classes = kwargs.pop('num_classes', default_num_classes) - img_size = kwargs.pop('img_size', default_img_size) - +def _create_cait(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') model = build_model_with_cfg( Cait, variant, pretrained, - default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, + default_cfg=default_cfgs[variant], pretrained_filter_fn=checkpoint_filter_fn, **kwargs) - return model diff --git a/timm/models/coat.py b/timm/models/coat.py index cb265522..9eb384d8 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -7,19 +7,19 @@ Official CoaT code at: https://github.com/mlpc-ucsd/CoaT Modified from timm/models/vision_transformer.py """ -from typing import Tuple, Dict, Any, Optional +from copy import deepcopy +from functools import partial +from typing import Tuple, List import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.helpers import load_pretrained -from timm.models.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ -from timm.models.registry import register_model +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ +from .registry import register_model -from functools import partial -from torch import nn __all__ = [ "coat_tiny", @@ -34,7 +34,7 @@ def _cfg_coat(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'patch_embed1.proj', 'classifier': 'head', **kwargs @@ -42,15 +42,21 @@ def _cfg_coat(url='', **kwargs): default_cfgs = { - 'coat_tiny': _cfg_coat(), - 'coat_mini': _cfg_coat(), + 'coat_tiny': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_tiny-473c2a20.pth' + ), + 'coat_mini': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_mini-2c6baf49.pth' + ), 'coat_lite_tiny': _cfg_coat( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_tiny-461b07a7.pth' ), 'coat_lite_mini': _cfg_coat( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_mini-d7842000.pth' ), - 'coat_lite_small': _cfg_coat(), + 'coat_lite_small': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_small-fea1d5a1.pth' + ), } @@ -120,11 +126,11 @@ class ConvRelPosEnc(nn.Module): class FactorAtt_ConvRelPosEnc(nn.Module): """ Factorized attention with convolutional relative position encoding class. """ - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., shared_crpe=None): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., shared_crpe=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used. @@ -190,9 +196,8 @@ class ConvPosEnc(nn.Module): class SerialBlock(nn.Module): """ Serial block class. Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """ - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, - shared_cpe=None, shared_crpe=None): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_cpe=None, shared_crpe=None): super().__init__() # Conv-Attention. @@ -200,8 +205,7 @@ class SerialBlock(nn.Module): self.norm1 = norm_layer(dim) self.factoratt_crpe = FactorAtt_ConvRelPosEnc( - dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, - shared_crpe=shared_crpe) + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpe) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() # MLP. @@ -226,27 +230,24 @@ class SerialBlock(nn.Module): class ParallelBlock(nn.Module): """ Parallel block class. """ - def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, - shared_cpes=None, shared_crpes=None): + def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_crpes=None): super().__init__() # Conv-Attention. - self.cpes = shared_cpes - self.norm12 = norm_layer(dims[1]) self.norm13 = norm_layer(dims[2]) self.norm14 = norm_layer(dims[3]) self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc( - dims[1], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + dims[1], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpes[1] ) self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc( - dims[2], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + dims[2], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpes[2] ) self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc( - dims[3], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + dims[3], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpes[3] ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -262,15 +263,15 @@ class ParallelBlock(nn.Module): self.mlp2 = self.mlp3 = self.mlp4 = Mlp( in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - def upsample(self, x, factor, size): + def upsample(self, x, factor: float, size: Tuple[int, int]): """ Feature map up-sampling. """ return self.interpolate(x, scale_factor=factor, size=size) - def downsample(self, x, factor, size): + def downsample(self, x, factor: float, size: Tuple[int, int]): """ Feature map down-sampling. """ return self.interpolate(x, scale_factor=1.0/factor, size=size) - def interpolate(self, x, scale_factor, size): + def interpolate(self, x, scale_factor: float, size: Tuple[int, int]): """ Feature map interpolation. """ B, N, C = x.shape H, W = size @@ -280,33 +281,28 @@ class ParallelBlock(nn.Module): img_tokens = x[:, 1:, :] img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W) - img_tokens = F.interpolate(img_tokens, scale_factor=scale_factor, mode='bilinear') + img_tokens = F.interpolate( + img_tokens, scale_factor=scale_factor, recompute_scale_factor=False, mode='bilinear', align_corners=False) img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2) out = torch.cat((cls_token, img_tokens), dim=1) return out - def forward(self, x1, x2, x3, x4, sizes): - _, (H2, W2), (H3, W3), (H4, W4) = sizes - - # Conv-Attention. - x2 = self.cpes[1](x2, size=(H2, W2)) # Note: x1 is ignored. - x3 = self.cpes[2](x3, size=(H3, W3)) - x4 = self.cpes[3](x4, size=(H4, W4)) - + def forward(self, x1, x2, x3, x4, sizes: List[Tuple[int, int]]): + _, S2, S3, S4 = sizes cur2 = self.norm12(x2) cur3 = self.norm13(x3) cur4 = self.norm14(x4) - cur2 = self.factoratt_crpe2(cur2, size=(H2, W2)) - cur3 = self.factoratt_crpe3(cur3, size=(H3, W3)) - cur4 = self.factoratt_crpe4(cur4, size=(H4, W4)) - upsample3_2 = self.upsample(cur3, factor=2, size=(H3, W3)) - upsample4_3 = self.upsample(cur4, factor=2, size=(H4, W4)) - upsample4_2 = self.upsample(cur4, factor=4, size=(H4, W4)) - downsample2_3 = self.downsample(cur2, factor=2, size=(H2, W2)) - downsample3_4 = self.downsample(cur3, factor=2, size=(H3, W3)) - downsample2_4 = self.downsample(cur2, factor=4, size=(H2, W2)) + cur2 = self.factoratt_crpe2(cur2, size=S2) + cur3 = self.factoratt_crpe3(cur3, size=S3) + cur4 = self.factoratt_crpe4(cur4, size=S4) + upsample3_2 = self.upsample(cur3, factor=2., size=S3) + upsample4_3 = self.upsample(cur4, factor=2., size=S4) + upsample4_2 = self.upsample(cur4, factor=4., size=S4) + downsample2_3 = self.downsample(cur2, factor=2., size=S2) + downsample3_4 = self.downsample(cur3, factor=2., size=S3) + downsample2_4 = self.downsample(cur2, factor=4., size=S2) cur2 = cur2 + upsample3_2 + upsample4_2 cur3 = cur3 + upsample4_3 + downsample2_3 cur4 = cur4 + downsample3_4 + downsample2_4 @@ -330,11 +326,11 @@ class ParallelBlock(nn.Module): class CoaT(nn.Module): """ CoaT class. """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[0, 0, 0, 0], - serial_depths=[0, 0, 0, 0], parallel_depth=0, - num_heads=0, mlp_ratios=[0, 0, 0, 0], qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., - drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), - return_interm_layers=False, out_features = None, crpe_window=None, **kwargs): + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=(0, 0, 0, 0), + serial_depths=(0, 0, 0, 0), parallel_depth=0, num_heads=0, mlp_ratios=(0, 0, 0, 0), qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), + return_interm_layers=False, out_features=None, crpe_window=None, **kwargs): super().__init__() crpe_window = crpe_window or {3: 2, 5: 3, 7: 3} self.return_interm_layers = return_interm_layers @@ -342,17 +338,18 @@ class CoaT(nn.Module): self.num_classes = num_classes # Patch embeddings. + img_size = to_2tuple(img_size) self.patch_embed1 = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0], norm_layer=nn.LayerNorm) self.patch_embed2 = PatchEmbed( - img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], + img_size=[x // 4 for x in img_size], patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1], norm_layer=nn.LayerNorm) self.patch_embed3 = PatchEmbed( - img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], + img_size=[x // 8 for x in img_size], patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2], norm_layer=nn.LayerNorm) self.patch_embed4 = PatchEmbed( - img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], + img_size=[x // 16 for x in img_size], patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3], norm_layer=nn.LayerNorm) # Class tokens. @@ -380,7 +377,7 @@ class CoaT(nn.Module): # Serial blocks 1. self.serial_blocks1 = nn.ModuleList([ SerialBlock( - dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, shared_cpe=self.cpe1, shared_crpe=self.crpe1 ) @@ -390,7 +387,7 @@ class CoaT(nn.Module): # Serial blocks 2. self.serial_blocks2 = nn.ModuleList([ SerialBlock( - dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, shared_cpe=self.cpe2, shared_crpe=self.crpe2 ) @@ -400,7 +397,7 @@ class CoaT(nn.Module): # Serial blocks 3. self.serial_blocks3 = nn.ModuleList([ SerialBlock( - dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, shared_cpe=self.cpe3, shared_crpe=self.crpe3 ) @@ -410,7 +407,7 @@ class CoaT(nn.Module): # Serial blocks 4. self.serial_blocks4 = nn.ModuleList([ SerialBlock( - dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, shared_cpe=self.cpe4, shared_crpe=self.crpe4 ) @@ -422,10 +419,9 @@ class CoaT(nn.Module): if self.parallel_depth > 0: self.parallel_blocks = nn.ModuleList([ ParallelBlock( - dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, - shared_cpes=[self.cpe1, self.cpe2, self.cpe3, self.cpe4], - shared_crpes=[self.crpe1, self.crpe2, self.crpe3, self.crpe4] + dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, + shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4) ) for _ in range(parallel_depth)] ) @@ -434,9 +430,11 @@ class CoaT(nn.Module): # Classification head(s). if not self.return_interm_layers: - self.norm1 = norm_layer(embed_dims[0]) - self.norm2 = norm_layer(embed_dims[1]) - self.norm3 = norm_layer(embed_dims[2]) + if self.parallel_blocks is not None: + self.norm2 = norm_layer(embed_dims[1]) + self.norm3 = norm_layer(embed_dims[2]) + else: + self.norm2 = self.norm3 = None self.norm4 = norm_layer(embed_dims[3]) if self.parallel_depth > 0: @@ -546,6 +544,7 @@ class CoaT(nn.Module): # Parallel blocks. for blk in self.parallel_blocks: + x2, x3, x4 = self.cpe2(x2, (H2, W2)), self.cpe3(x3, (H3, W3)), self.cpe4(x4, (H4, W4)) x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)]) if not torch.jit.is_scripting() and self.return_interm_layers: @@ -590,52 +589,70 @@ class CoaT(nn.Module): return x +def checkpoint_filter_fn(state_dict, model): + out_dict = {} + for k, v in state_dict.items(): + # original model had unused norm layers, removing them requires filtering pretrained checkpoints + if k.startswith('norm1') or \ + (model.norm2 is None and k.startswith('norm2')) or \ + (model.norm3 is None and k.startswith('norm3')): + continue + out_dict[k] = v + return out_dict + + +def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + CoaT, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + @register_model def coat_tiny(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) - model.default_cfg = default_cfgs['coat_tiny'] + model = _create_coat('coat_tiny', pretrained=pretrained, **model_cfg) return model @register_model def coat_mini(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) - model.default_cfg = default_cfgs['coat_mini'] + model = _create_coat('coat_mini', pretrained=pretrained, **model_cfg) return model @register_model def coat_lite_tiny(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) - # FIXME use builder - model.default_cfg = default_cfgs['coat_lite_tiny'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model = _create_coat('coat_lite_tiny', pretrained=pretrained, **model_cfg) return model @register_model def coat_lite_mini(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) - # FIXME use builder - model.default_cfg = default_cfgs['coat_lite_mini'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model = _create_coat('coat_lite_mini', pretrained=pretrained, **model_cfg) return model @register_model def coat_lite_small(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) - model.default_cfg = default_cfgs['coat_lite_small'] + model = _create_coat('coat_lite_small', pretrained=pretrained, **model_cfg) return model \ No newline at end of file diff --git a/timm/models/convit.py b/timm/models/convit.py index f6ae3ec1..b15b46d8 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -39,7 +39,7 @@ def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True, 'first_conv': 'patch_embed.proj', 'classifier': 'head', **kwargs } @@ -317,6 +317,9 @@ class ConViT(nn.Module): def _create_convit(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + return build_model_with_cfg( ConViT, variant, pretrained, default_cfg=default_cfgs[variant], diff --git a/timm/models/helpers.py b/timm/models/helpers.py index e9ac7f00..dfb6b860 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -44,7 +44,7 @@ def load_state_dict(checkpoint_path, use_ema=False): raise FileNotFoundError() -def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): +def load_checkpoint(model, checkpoint_path, use_ema=False, strict=False): state_dict = load_state_dict(checkpoint_path, use_ema) model.load_state_dict(state_dict, strict=strict) @@ -378,7 +378,11 @@ def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter): # Overlay default cfg values from `external_default_cfg` if it exists in kwargs overlay_external_default_cfg(default_cfg, kwargs) # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) - set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg) + default_kwarg_names = ('num_classes', 'global_pool', 'in_chans') + if default_cfg.get('fixed_input_size', False): + # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size + default_kwarg_names += ('img_size',) + set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg) # Filter keyword args for task specific model variants (some 'features only' models, etc.) filter_kwargs(kwargs, names=kwargs_filter) diff --git a/timm/models/levit.py b/timm/models/levit.py index 997b44d7..96a0c85b 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -1,3 +1,22 @@ +""" LeViT + +Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference` + - https://arxiv.org/abs/2104.01136 + +@article{graham2021levit, + title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference}, + author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze}, + journal={arXiv preprint arXiv:22104.01136}, + year={2021} +} + +Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright bellow. + +This version combines both conv/linear models and fixes torchscript compatibility. + +Modifications by/coyright Copyright 2021 Ross Wightman +""" + # Copyright (c) 2015-present, Facebook, Inc. # All rights reserved. @@ -5,10 +24,15 @@ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # Copyright 2020 Ross Wightman, Apache-2.0 License import itertools +from copy import deepcopy +from functools import partial import torch +import torch.nn as nn from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import to_ntuple from .vision_transformer import trunc_normal_ from .registry import register_model @@ -19,70 +43,113 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', + 'first_conv': 'patch_embed.0.c', 'classifier': ('head.l', 'head_dist.l'), **kwargs } -specification = { - 'levit_128s': { - 'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'}, - 'levit_128': { - 'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'}, - 'levit_192': { - 'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'}, - 'levit_256': { - 'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'}, - 'levit_384': { - 'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'}, -} +default_cfgs = dict( + levit_128s=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth' + ), + levit_128=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth' + ), + levit_192=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth' + ), + levit_256=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth' + ), + levit_384=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth' + ), +) + +model_cfgs = dict( + levit_128s=dict( + embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)), + levit_128=dict( + embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)), + levit_192=dict( + embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)), + levit_256=dict( + embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)), + levit_384=dict( + embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)), +) __all__ = ['Levit'] @register_model -def levit_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_128s'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) +def levit_128s(pretrained=False, fuse=False,distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model +def levit_128(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model +def levit_192(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) @register_model -def levit_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_128'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) +def levit_256(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) @register_model -def levit_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_192'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) +def levit_384(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) @register_model -def levit_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_256'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) +def levit_c_128s(pretrained=False, fuse=False, distillation=True, use_conv=True,**kwargs): + return create_levit( + 'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) @register_model -def levit_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_384'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) +def levit_c_128(pretrained=False, fuse=False,distillation=True, use_conv=True, **kwargs): + return create_levit( + 'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) -class ConvNorm(torch.nn.Sequential): +@register_model +def levit_c_192(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs): + return create_levit( + 'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model +def levit_c_256(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs): + return create_levit( + 'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model +def levit_c_384(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs): + return create_levit( + 'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +class ConvNorm(nn.Sequential): def __init__( self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000): super().__init__() - self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) - bn = torch.nn.BatchNorm2d(b) - torch.nn.init.constant_(bn.weight, bn_weight_init) - torch.nn.init.constant_(bn.bias, 0) + self.add_module('c', nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = nn.BatchNorm2d(b) + nn.init.constant_(bn.weight, bn_weight_init) + nn.init.constant_(bn.bias, 0) self.add_module('bn', bn) @torch.no_grad() @@ -91,7 +158,7 @@ class ConvNorm(torch.nn.Sequential): w = bn.weight / (bn.running_var + bn.eps) ** 0.5 w = c.weight * w[:, None, None, None] b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 - m = torch.nn.Conv2d( + m = nn.Conv2d( w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) m.weight.data.copy_(w) @@ -99,13 +166,13 @@ class ConvNorm(torch.nn.Sequential): return m -class LinearNorm(torch.nn.Sequential): +class LinearNorm(nn.Sequential): def __init__(self, a, b, bn_weight_init=1, resolution=-100000): super().__init__() - self.add_module('c', torch.nn.Linear(a, b, bias=False)) - bn = torch.nn.BatchNorm1d(b) - torch.nn.init.constant_(bn.weight, bn_weight_init) - torch.nn.init.constant_(bn.bias, 0) + self.add_module('c', nn.Linear(a, b, bias=False)) + bn = nn.BatchNorm1d(b) + nn.init.constant_(bn.weight, bn_weight_init) + nn.init.constant_(bn.bias, 0) self.add_module('bn', bn) @torch.no_grad() @@ -114,25 +181,24 @@ class LinearNorm(torch.nn.Sequential): w = bn.weight / (bn.running_var + bn.eps) ** 0.5 w = l.weight * w[:, None] b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 - m = torch.nn.Linear(w.size(1), w.size(0)) + m = nn.Linear(w.size(1), w.size(0)) m.weight.data.copy_(w) m.bias.data.copy_(b) return m def forward(self, x): - l, bn = self._modules.values() - x = l(x) - return bn(x.flatten(0, 1)).reshape_as(x) + x = self.c(x) + return self.bn(x.flatten(0, 1)).reshape_as(x) -class NormLinear(torch.nn.Sequential): +class NormLinear(nn.Sequential): def __init__(self, a, b, bias=True, std=0.02): super().__init__() - self.add_module('bn', torch.nn.BatchNorm1d(a)) - l = torch.nn.Linear(a, b, bias=bias) + self.add_module('bn', nn.BatchNorm1d(a)) + l = nn.Linear(a, b, bias=bias) trunc_normal_(l.weight, std=std) if bias: - torch.nn.init.constant_(l.bias, 0) + nn.init.constant_(l.bias, 0) self.add_module('l', l) @torch.no_grad() @@ -145,24 +211,24 @@ class NormLinear(torch.nn.Sequential): b = b @ self.l.weight.T else: b = (l.weight @ b[:, None]).view(-1) + self.l.bias - m = torch.nn.Linear(w.size(1), w.size(0)) + m = nn.Linear(w.size(1), w.size(0)) m.weight.data.copy_(w) m.bias.data.copy_(b) return m -def b16(n, activation, resolution=224): - return torch.nn.Sequential( - ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution), +def stem_b16(in_chs, out_chs, activation, resolution=224): + return nn.Sequential( + ConvNorm(in_chs, out_chs // 8, 3, 2, 1, resolution=resolution), activation(), - ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), + ConvNorm(out_chs // 8, out_chs // 4, 3, 2, 1, resolution=resolution // 2), activation(), - ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), + ConvNorm(out_chs // 4, out_chs // 2, 3, 2, 1, resolution=resolution // 4), activation(), - ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8)) + ConvNorm(out_chs // 2, out_chs, 3, 2, 1, resolution=resolution // 8)) -class Residual(torch.nn.Module): +class Residual(nn.Module): def __init__(self, m, drop): super().__init__() self.m = m @@ -176,10 +242,23 @@ class Residual(torch.nn.Module): return x + self.m(x) -class Attention(torch.nn.Module): +class Subsample(nn.Module): + def __init__(self, stride, resolution): + super().__init__() + self.stride = stride + self.resolution = resolution + + def forward(self, x): + B, N, C = x.shape + x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride] + return x.reshape(B, -1, C) + + +class Attention(nn.Module): def __init__( - self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14): + self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False): super().__init__() + self.num_heads = num_heads self.scale = key_dim ** -0.5 self.key_dim = key_dim @@ -187,11 +266,13 @@ class Attention(torch.nn.Module): self.d = int(attn_ratio * key_dim) self.dh = int(attn_ratio * key_dim) * num_heads self.attn_ratio = attn_ratio + self.use_conv = use_conv + ln_layer = ConvNorm if self.use_conv else LinearNorm h = self.dh + nh_kd * 2 - self.qkv = LinearNorm(dim, h, resolution=resolution) - self.proj = torch.nn.Sequential( + self.qkv = ln_layer(dim, h, resolution=resolution) + self.proj = nn.Sequential( act_layer(), - LinearNorm(self.dh, dim, bn_weight_init=0, resolution=resolution)) + ln_layer(self.dh, dim, bn_weight_init=0, resolution=resolution)) points = list(itertools.product(range(resolution), range(resolution))) N = len(points) @@ -203,68 +284,68 @@ class Attention(torch.nn.Module): if offset not in attention_offsets: attention_offsets[offset] = len(attention_offsets) idxs.append(attention_offsets[offset]) - self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N)) + self.ab = None @torch.no_grad() def train(self, mode=True): super().train(mode) - if mode and hasattr(self, 'ab'): - del self.ab + self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,C,H,W) + if self.use_conv: + B, C, H, W = x.shape + q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2) + ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab + attn = (q.transpose(-2, -1) @ k) * self.scale + ab + attn = attn.softmax(dim=-1) + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) else: - self.ab = self.attention_biases[:, self.attention_bias_idxs] - - def forward(self, x): # x (B,N,C) - B, N, C = x.shape - qkv = self.qkv(x) - q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3) - q = q.permute(0, 2, 1, 3) - k = k.permute(0, 2, 1, 3) - v = v.permute(0, 2, 1, 3) - - ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab - attn = q @ k.transpose(-2, -1) * self.scale + ab - - attn = attn.softmax(dim=-1) - x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + B, N, C = x.shape + qkv = self.qkv(x) + q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab + attn = q @ k.transpose(-2, -1) * self.scale + ab + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) x = self.proj(x) return x -class Subsample(torch.nn.Module): - def __init__(self, stride, resolution): - super().__init__() - self.stride = stride - self.resolution = resolution - - def forward(self, x): - B, N, C = x.shape - x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride] - return x.reshape(B, -1, C) - - -class AttentionSubsample(torch.nn.Module): - def __init__(self, in_dim, out_dim, key_dim, num_heads=8, - attn_ratio=2, act_layer=None, stride=2, resolution=14, resolution_=7): +class AttentionSubsample(nn.Module): + def __init__( + self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2, + act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False): super().__init__() self.num_heads = num_heads self.scale = key_dim ** -0.5 self.key_dim = key_dim self.nh_kd = nh_kd = key_dim * num_heads self.d = int(attn_ratio * key_dim) - self.dh = int(attn_ratio * key_dim) * self.num_heads + self.dh = self.d * self.num_heads self.attn_ratio = attn_ratio self.resolution_ = resolution_ self.resolution_2 = resolution_ ** 2 - h = self.dh + nh_kd - self.kv = LinearNorm(in_dim, h, resolution=resolution) + self.use_conv = use_conv + if self.use_conv: + ln_layer = ConvNorm + sub_layer = partial(nn.AvgPool2d, kernel_size=1, padding=0) + else: + ln_layer = LinearNorm + sub_layer = partial(Subsample, resolution=resolution) - self.q = torch.nn.Sequential( - Subsample(stride, resolution), - LinearNorm(in_dim, nh_kd, resolution=resolution_)) - self.proj = torch.nn.Sequential( + h = self.dh + nh_kd + self.kv = ln_layer(in_dim, h, resolution=resolution) + self.q = nn.Sequential( + sub_layer(stride=stride), + ln_layer(in_dim, nh_kd, resolution=resolution_)) + self.proj = nn.Sequential( act_layer(), - LinearNorm(self.dh, out_dim, resolution=resolution_)) + ln_layer(self.dh, out_dim, resolution=resolution_)) self.stride = stride self.resolution = resolution @@ -283,35 +364,43 @@ class AttentionSubsample(torch.nn.Module): if offset not in attention_offsets: attention_offsets[offset] = len(attention_offsets) idxs.append(attention_offsets[offset]) - self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N)) - + self.ab = None @torch.no_grad() def train(self, mode=True): super().train(mode) - if mode and hasattr(self, 'ab'): - del self.ab - else: - self.ab = self.attention_biases[:, self.attention_bias_idxs] + self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs] def forward(self, x): - B, N, C = x.shape - k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3) - k = k.permute(0, 2, 1, 3) # BHNC - v = v.permute(0, 2, 1, 3) # BHNC - q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) + if self.use_conv: + B, C, H, W = x.shape + k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2) + q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2) + + ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab + attn = (q.transpose(-2, -1) @ k) * self.scale + ab + attn = attn.softmax(dim=-1) + + x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_) + else: + B, N, C = x.shape + k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3) + k = k.permute(0, 2, 1, 3) # BHNC + v = v.permute(0, 2, 1, 3) # BHNC + q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) - ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab - attn = q @ k.transpose(-2, -1) * self.scale + ab - attn = attn.softmax(dim=-1) + ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab + attn = q @ k.transpose(-2, -1) * self.scale + ab + attn = attn.softmax(dim=-1) - x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) + x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) x = self.proj(x) return x -class Levit(torch.nn.Module): +class Levit(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ @@ -321,45 +410,63 @@ class Levit(torch.nn.Module): patch_size=16, in_chans=3, num_classes=1000, - embed_dim=[192], - key_dim=[64], - depth=[12], - num_heads=[3], - attn_ratio=[2], - mlp_ratio=[2], + embed_dim=(192,), + key_dim=64, + depth=(12,), + num_heads=(3,), + attn_ratio=2, + mlp_ratio=2, hybrid_backbone=None, - down_ops=[], - attn_act_layer=torch.nn.Hardswish, - mlp_act_layer=torch.nn.Hardswish, + down_ops=None, + act_layer=nn.Hardswish, + attn_act_layer=nn.Hardswish, distillation=True, + use_conv=False, drop_path=0): super().__init__() - global FLOPS_COUNTER - + if isinstance(img_size, tuple): + # FIXME origin impl passes single img/res dim through whole hierarchy, + # not sure this model will be used enough to spend time fixing it. + assert img_size[0] == img_size[1] + img_size = img_size[0] self.num_classes = num_classes self.num_features = embed_dim[-1] self.embed_dim = embed_dim + N = len(embed_dim) + assert len(depth) == len(num_heads) == N + key_dim = to_ntuple(N)(key_dim) + attn_ratio = to_ntuple(N)(attn_ratio) + mlp_ratio = to_ntuple(N)(mlp_ratio) + down_ops = down_ops or ( + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + ('Subsample', key_dim[0], embed_dim[0] // key_dim[0], 4, 2, 2), + ('Subsample', key_dim[0], embed_dim[1] // key_dim[1], 4, 2, 2), + ('',) + ) self.distillation = distillation + self.use_conv = use_conv + ln_layer = ConvNorm if self.use_conv else LinearNorm - self.patch_embed = hybrid_backbone + self.patch_embed = hybrid_backbone or stem_b16(in_chans, embed_dim[0], activation=act_layer) self.blocks = [] - down_ops.append(['']) resolution = img_size // patch_size for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)): for _ in range(dpth): self.blocks.append( Residual( - Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution), + Attention( + ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, + resolution=resolution, use_conv=use_conv), drop_path)) if mr > 0: h = int(ed * mr) self.blocks.append( - Residual(torch.nn.Sequential( - LinearNorm(ed, h, resolution=resolution), - mlp_act_layer(), - LinearNorm(h, ed, bn_weight_init=0, resolution=resolution), + Residual(nn.Sequential( + ln_layer(ed, h, resolution=resolution), + act_layer(), + ln_layer(h, ed, bn_weight_init=0, resolution=resolution), ), drop_path)) if do[0] == 'Subsample': # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) @@ -368,22 +475,22 @@ class Levit(torch.nn.Module): AttentionSubsample( *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5], - resolution=resolution, resolution_=resolution_)) + resolution=resolution, resolution_=resolution_, use_conv=use_conv)) resolution = resolution_ if do[4] > 0: # mlp_ratio h = int(embed_dim[i + 1] * do[4]) self.blocks.append( - Residual(torch.nn.Sequential( - LinearNorm(embed_dim[i + 1], h, resolution=resolution), - mlp_act_layer(), - LinearNorm(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), + Residual(nn.Sequential( + ln_layer(embed_dim[i + 1], h, resolution=resolution), + act_layer(), + ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), ), drop_path)) - self.blocks = torch.nn.Sequential(*self.blocks) + self.blocks = nn.Sequential(*self.blocks) # Classifier head - self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() if distillation: - self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() else: self.head_dist = None @@ -393,48 +500,44 @@ class Levit(torch.nn.Module): def forward(self, x): x = self.patch_embed(x) - x = x.flatten(2).transpose(1, 2) + if not self.use_conv: + x = x.flatten(2).transpose(1, 2) x = self.blocks(x) - x = x.mean(1) - if self.distillation: - x = self.head(x), self.head_dist(x) - if not self.training: - x = (x[0] + x[1]) / 2 + x = x.mean((-2, -1)) if self.use_conv else x.mean(1) + if self.head_dist is not None: + x, x_dist = self.head(x), self.head_dist(x) + if self.training and not torch.jit.is_scripting(): + return x, x_dist + else: + # during inference, return the average of both classifier predictions + return (x + x_dist) / 2 else: x = self.head(x) return x -def model_factory(C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse): - embed_dim = [int(x) for x in C.split('_')] - num_heads = [int(x) for x in N.split('_')] - depth = [int(x) for x in X.split('_')] - act = torch.nn.Hardswish - model = Levit( - patch_size=16, - embed_dim=embed_dim, - num_heads=num_heads, - key_dim=[D] * 3, - depth=depth, - attn_ratio=[2, 2, 2], - mlp_ratio=[2, 2, 2], - down_ops=[ - # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) - ['Subsample', D, embed_dim[0] // D, 4, 2, 2], - ['Subsample', D, embed_dim[1] // D, 4, 2, 2], - ], - attn_act_layer=act, - mlp_act_layer=act, - hybrid_backbone=b16(embed_dim[0], activation=act), - num_classes=num_classes, - drop_path=drop_path, - distillation=distillation - ) - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url(weights, map_location='cpu') - model.load_state_dict(checkpoint['model']) +def checkpoint_filter_fn(state_dict, model): + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + D = model.state_dict() + for k in state_dict.keys(): + if D[k].ndim == 4 and state_dict[k].ndim == 2: + state_dict[k] = state_dict[k][:, :, None, None] + return state_dict + + +def create_levit(variant, pretrained=False, default_cfg=None, fuse=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model_cfg = dict(**model_cfgs[variant], **kwargs) + model = build_model_with_cfg( + Levit, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **model_cfg) #if fuse: # utils.replace_batchnorm(model) - return model + diff --git a/timm/models/levitc.py b/timm/models/levitc.py deleted file mode 100644 index 1a422953..00000000 --- a/timm/models/levitc.py +++ /dev/null @@ -1,400 +0,0 @@ -# Copyright (c) 2015-present, Facebook, Inc. -# All rights reserved. - -# Modified from -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py -# Copyright 2020 Ross Wightman, Apache-2.0 License -import itertools - -import torch - -from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN -from .vision_transformer import trunc_normal_ -from .registry import register_model - - -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', - **kwargs - } - - -specification = { - 'levit_c_128s': { - 'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'}, - 'levit_c_128': { - 'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'}, - 'levit_c_192': { - 'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'}, - 'levit_c_256': { - 'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'}, - 'levit_c_384': { - 'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'}, -} - -__all__ = ['Levit'] - - -@register_model -def levit_c_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_c_128s'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) - - -@register_model -def levit_c_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_c_128'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) - - -@register_model -def levit_c_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_c_192'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) - - -@register_model -def levit_c_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_c_256'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) - - -@register_model -def levit_c_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_c_384'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) - - -class ConvNorm(torch.nn.Sequential): - def __init__( - self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000): - super().__init__() - self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) - bn = torch.nn.BatchNorm2d(b) - torch.nn.init.constant_(bn.weight, bn_weight_init) - torch.nn.init.constant_(bn.bias, 0) - self.add_module('bn', bn) - - @torch.no_grad() - def fuse(self): - c, bn = self._modules.values() - w = bn.weight / (bn.running_var + bn.eps) ** 0.5 - w = c.weight * w[:, None, None, None] - b = bn.bias - bn.running_mean * bn.weight / \ - (bn.running_var + bn.eps) ** 0.5 - m = torch.nn.Conv2d( - w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, - padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) - m.weight.data.copy_(w) - m.bias.data.copy_(b) - return m - - -class NormLinear(torch.nn.Sequential): - def __init__(self, a, b, bias=True, std=0.02): - super().__init__() - self.add_module('bn', torch.nn.BatchNorm1d(a)) - l = torch.nn.Linear(a, b, bias=bias) - trunc_normal_(l.weight, std=std) - if bias: - torch.nn.init.constant_(l.bias, 0) - self.add_module('l', l) - - @torch.no_grad() - def fuse(self): - bn, l = self._modules.values() - w = bn.weight / (bn.running_var + bn.eps) ** 0.5 - b = bn.bias - self.bn.running_mean * \ - self.bn.weight / (bn.running_var + bn.eps) ** 0.5 - w = l.weight * w[None, :] - if l.bias is None: - b = b @ self.l.weight.T - else: - b = (l.weight @ b[:, None]).view(-1) + self.l.bias - m = torch.nn.Linear(w.size(1), w.size(0)) - m.weight.data.copy_(w) - m.bias.data.copy_(b) - return m - - -def b16(n, activation, resolution=224): - return torch.nn.Sequential( - ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution), - activation(), - ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), - activation(), - ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), - activation(), - ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8)) - - -class Residual(torch.nn.Module): - def __init__(self, m, drop): - super().__init__() - self.m = m - self.drop = drop - - def forward(self, x): - if self.training and self.drop > 0: - return x + self.m(x) * torch.rand( - x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach() - else: - return x + self.m(x) - - -class Attention(torch.nn.Module): - def __init__(self, dim, key_dim, num_heads=8, - attn_ratio=4, act_layer=None, resolution=14): - super().__init__() - self.num_heads = num_heads - self.scale = key_dim ** -0.5 - self.key_dim = key_dim - self.nh_kd = nh_kd = key_dim * num_heads - self.d = int(attn_ratio * key_dim) - self.dh = int(attn_ratio * key_dim) * num_heads - self.attn_ratio = attn_ratio - h = self.dh + nh_kd * 2 - self.qkv = ConvNorm(dim, h, resolution=resolution) - self.proj = torch.nn.Sequential( - act_layer(), - ConvNorm(self.dh, dim, bn_weight_init=0, resolution=resolution)) - - points = list(itertools.product(range(resolution), range(resolution))) - N = len(points) - attention_offsets = {} - idxs = [] - for p1 in points: - for p2 in points: - offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) - if offset not in attention_offsets: - attention_offsets[offset] = len(attention_offsets) - idxs.append(attention_offsets[offset]) - self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) - self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N)) - self.ab = None - - @torch.no_grad() - def train(self, mode=True): - super().train(mode) - if mode and self.ab is not None: - self.ab = None - else: - self.ab = self.attention_biases[:, self.attention_bias_idxs] - - def forward(self, x): # x (B,C,H,W) - B, C, H, W = x.shape - q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2) - ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab - attn = (q.transpose(-2, -1) @ k) * self.scale + ab - attn = attn.softmax(dim=-1) - x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) - x = self.proj(x) - return x - - -class AttentionSubsample(torch.nn.Module): - def __init__( - self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2, - act_layer=None, stride=2, resolution=14, resolution_=7): - super().__init__() - self.num_heads = num_heads - self.scale = key_dim ** -0.5 - self.key_dim = key_dim - self.nh_kd = nh_kd = key_dim * num_heads - self.d = int(attn_ratio * key_dim) - self.dh = int(attn_ratio * key_dim) * self.num_heads - self.attn_ratio = attn_ratio - self.resolution_ = resolution_ - self.resolution_2 = resolution_ ** 2 - h = self.dh + nh_kd - self.kv = ConvNorm(in_dim, h, resolution=resolution) - self.q = torch.nn.Sequential( - torch.nn.AvgPool2d(1, stride, 0), - ConvNorm(in_dim, nh_kd, resolution=resolution_)) - self.proj = torch.nn.Sequential( - act_layer(), - ConvNorm(self.d * num_heads, out_dim, resolution=resolution_)) - - self.stride = stride - self.resolution = resolution - points = list(itertools.product(range(resolution), range(resolution))) - points_ = list(itertools.product(range(resolution_), range(resolution_))) - N = len(points) - N_ = len(points_) - attention_offsets = {} - idxs = [] - for p1 in points_: - for p2 in points: - size = 1 - offset = ( - abs(p1[0] * stride - p2[0] + (size - 1) / 2), - abs(p1[1] * stride - p2[1] + (size - 1) / 2)) - if offset not in attention_offsets: - attention_offsets[offset] = len(attention_offsets) - idxs.append(attention_offsets[offset]) - self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) - self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N)) - self.ab = None - - @torch.no_grad() - def train(self, mode=True): - super().train(mode) - if mode and self.ab is not None: - self.ab = None - else: - self.ab = self.attention_biases[:, self.attention_bias_idxs] - - def forward(self, x): - B, C, H, W = x.shape - k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2) - q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2) - ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab - attn = (q.transpose(-2, -1) @ k) * self.scale + ab - attn = attn.softmax(dim=-1) - - x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_) - x = self.proj(x) - return x - - -class Levit(torch.nn.Module): - """ Vision Transformer with support for patch or hybrid CNN input stage - """ - - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - num_classes=1000, - embed_dim=[192], - key_dim=[64], - depth=[12], - num_heads=[3], - attn_ratio=[2], - mlp_ratio=[2], - hybrid_backbone=None, - down_ops=[], - attn_act_layer=torch.nn.Hardswish, - mlp_act_layer=torch.nn.Hardswish, - distillation=True, - drop_path=0): - super().__init__() - self.num_classes = num_classes - self.num_features = embed_dim[-1] - self.embed_dim = embed_dim - self.distillation = distillation - - self.patch_embed = hybrid_backbone - - self.blocks = [] - down_ops.append(['']) - resolution = img_size // patch_size - for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( - zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)): - for _ in range(dpth): - self.blocks.append( - Residual( - Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution), - drop_path)) - if mr > 0: - h = int(ed * mr) - self.blocks.append( - Residual(torch.nn.Sequential( - ConvNorm(ed, h, resolution=resolution), - mlp_act_layer(), - ConvNorm(h, ed, bn_weight_init=0, resolution=resolution), - ), drop_path)) - if do[0] == 'Subsample': - # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) - resolution_ = (resolution - 1) // do[5] + 1 - self.blocks.append( - AttentionSubsample( - *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], attn_ratio=do[3], - act_layer=attn_act_layer, stride=do[5], - resolution=resolution, resolution_=resolution_)) - resolution = resolution_ - if do[4] > 0: # mlp_ratio - h = int(embed_dim[i + 1] * do[4]) - self.blocks.append( - Residual(torch.nn.Sequential( - ConvNorm(embed_dim[i + 1], h, resolution=resolution), - mlp_act_layer(), - ConvNorm(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), - ), drop_path)) - self.blocks = torch.nn.Sequential(*self.blocks) - - # Classifier head - self.head = NormLinear( - embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() - if distillation: - self.head_dist = NormLinear( - embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() - - @torch.jit.ignore - def no_weight_decay(self): - return {x for x in self.state_dict().keys() if 'attention_biases' in x} - - def forward(self, x): - x = self.patch_embed(x) - x = self.blocks(x) - x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1) - if self.distillation: - x = self.head(x), self.head_dist(x) - if not self.training: - x = (x[0] + x[1]) / 2 - else: - x = self.head(x) - return x - - -def model_factory(C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse): - embed_dim = [int(x) for x in C.split('_')] - num_heads = [int(x) for x in N.split('_')] - depth = [int(x) for x in X.split('_')] - act = torch.nn.Hardswish - model = Levit( - patch_size=16, - embed_dim=embed_dim, - num_heads=num_heads, - key_dim=[D] * 3, - depth=depth, - attn_ratio=[2, 2, 2], - mlp_ratio=[2, 2, 2], - down_ops=[ - # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) - ['Subsample', D, embed_dim[0] // D, 4, 2, 2], - ['Subsample', D, embed_dim[1] // D, 4, 2, 2], - ], - attn_act_layer=act, - mlp_act_layer=act, - hybrid_backbone=b16(embed_dim[0], activation=act), - num_classes=num_classes, - drop_path=drop_path, - distillation=distillation - ) - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - weights, map_location='cpu') - d = checkpoint['model'] - D = model.state_dict() - for k in d.keys(): - if D[k].shape != d[k].shape: - d[k] = d[k][:, :, None, None] - model.load_state_dict(d) - #if fuse: - # utils.replace_batchnorm(model) - - return model - diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 92ca115b..5a6dce6f 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -273,25 +273,14 @@ def _init_weights(m, n: str, head_bias: float = 0.): nn.init.ones_(m.weight) -def _create_mixer(variant, pretrained=False, default_cfg=None, **kwargs): - if default_cfg is None: - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] - num_classes = kwargs.pop('num_classes', default_num_classes) - img_size = kwargs.pop('img_size', default_img_size) - +def _create_mixer(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for MLP-Mixer models.') model = build_model_with_cfg( MlpMixer, variant, pretrained, - default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, + default_cfg=default_cfgs[variant], **kwargs) - return model diff --git a/timm/models/pit.py b/timm/models/pit.py index 040d96db..9c350861 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -251,24 +251,14 @@ def checkpoint_filter_fn(state_dict, model): def _create_pit(variant, pretrained=False, **kwargs): - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] - img_size = kwargs.pop('img_size', default_img_size) - num_classes = kwargs.pop('num_classes', default_num_classes) - if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') model = build_model_with_cfg( PoolingVisionTransformer, variant, pretrained, - default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, + default_cfg=default_cfgs[variant], pretrained_filter_fn=checkpoint_filter_fn, **kwargs) - return model diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 8e038718..8186cc4a 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -12,7 +12,7 @@ import torch.nn as nn from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.helpers import load_pretrained +from timm.models.helpers import build_model_with_cfg from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers.helpers import to_2tuple from timm.models.registry import register_model @@ -238,24 +238,31 @@ def checkpoint_filter_fn(state_dict, model): return state_dict +def _create_tnt(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + TNT, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + @register_model def tnt_s_patch16_224(pretrained=False, **kwargs): - model = TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, + model_cfg = dict( + patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, qkv_bias=False, **kwargs) - model.default_cfg = default_cfgs['tnt_s_patch16_224'] - if pretrained: - load_pretrained( - model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), - filter_fn=checkpoint_filter_fn) + model = _create_tnt('tnt_s_patch16_224', pretrained=pretrained, **model_cfg) return model @register_model def tnt_b_patch16_224(pretrained=False, **kwargs): - model = TNT(patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4, + model_cfg = dict( + patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4, qkv_bias=False, **kwargs) - model.default_cfg = default_cfgs['tnt_b_patch16_224'] - if pretrained: - load_pretrained( - model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model = _create_tnt('tnt_b_patch16_224', pretrained=pretrained, **model_cfg) return model diff --git a/timm/models/twins.py b/timm/models/twins.py index a534d174..793d2ede 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -33,7 +33,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', + 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head', **kwargs } @@ -361,25 +361,14 @@ class Twins(nn.Module): return x -def _create_twins(variant, pretrained=False, default_cfg=None, **kwargs): - if default_cfg is None: - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] - - num_classes = kwargs.pop('num_classes', default_num_classes) - img_size = kwargs.pop('img_size', default_img_size) +def _create_twins(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') model = build_model_with_cfg( Twins, variant, pretrained, - default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, + default_cfg=default_cfgs[variant], **kwargs) - return model diff --git a/timm/models/visformer.py b/timm/models/visformer.py index aa3bca57..df1d502a 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -1,3 +1,12 @@ +""" Visformer + +Paper: Visformer: The Vision-friendly Transformer - https://arxiv.org/abs/2104.12533 + +From original at https://github.com/danczs/Visformer + +""" +from copy import deepcopy + import torch import torch.nn as nn import torch.nn.functional as F @@ -22,6 +31,12 @@ def _cfg(url='', **kwargs): } +default_cfgs = dict( + visformer_tiny=_cfg(), + visformer_small=_cfg(), +) + + class LayerNormBHWC(nn.LayerNorm): def __init__(self, dim): super().__init__(dim) @@ -300,87 +315,97 @@ class Visformer(nn.Module): return x +def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + model = build_model_with_cfg( + Visformer, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + return model + + @register_model def visformer_tiny(pretrained=False, **kwargs): - model = Visformer( + model_cfg = dict( img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, embed_norm=nn.BatchNorm2d, **kwargs) - model.default_cfg = _cfg() + model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg) return model @register_model def visformer_small(pretrained=False, **kwargs): - model = Visformer( + model_cfg = dict( img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, embed_norm=nn.BatchNorm2d, **kwargs) - model.default_cfg = _cfg() + model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg) return model -@register_model -def visformer_net1(pretrained=False, **kwargs): - model = Visformer( - init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', - spatial_conv='000', vit_stem=True, conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model - - -@register_model -def visformer_net2(pretrained=False, **kwargs): - model = Visformer( - init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', - spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model - - -@register_model -def visformer_net3(pretrained=False, **kwargs): - model = Visformer( - init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', - spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model - - -@register_model -def visformer_net4(pretrained=False, **kwargs): - model = Visformer( - init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', - spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model - - -@register_model -def visformer_net5(pretrained=False, **kwargs): - model = Visformer( - init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', - spatial_conv='111', vit_stem=False, conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model - - -@register_model -def visformer_net6(pretrained=False, **kwargs): - model = Visformer( - init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', - pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model - - -@register_model -def visformer_net7(pretrained=False, **kwargs): - model = Visformer( - init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000', - pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model +# @register_model +# def visformer_net1(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=True, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net2(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net3(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net4(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net5(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', +# spatial_conv='111', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net6(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', +# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net7(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000', +# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index bef6dfb0..ff74d836 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -387,21 +387,20 @@ def checkpoint_filter_fn(state_dict, model): v = v.reshape(O, -1, H, W) elif k == 'pos_embed' and v.shape != model.pos_embed.shape: # To resize pos embedding when using model at different size from pretrained weights - v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1), - model.patch_embed.grid_size) + v = resize_pos_embed( + v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) out_dict[k] = v return out_dict def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): - if default_cfg is None: - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] + default_cfg = default_cfg or default_cfgs[variant] + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') - num_classes = kwargs.pop('num_classes', default_num_classes) - img_size = kwargs.pop('img_size', default_img_size) + # NOTE this extra code to support handling of repr size for in21k pretrained models + default_num_classes = default_cfg['num_classes'] + num_classes = kwargs.get('num_classes', default_num_classes) repr_size = kwargs.pop('representation_size', None) if repr_size is not None and num_classes != default_num_classes: # Remove representation layer if fine-tuning. This may not always be the desired action, @@ -409,18 +408,12 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw _logger.warning("Removing representation layer for fine-tuning.") repr_size = None - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') - model = build_model_with_cfg( VisionTransformer, variant, pretrained, default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, representation_size=repr_size, pretrained_filter_fn=checkpoint_filter_fn, **kwargs) - return model diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 1656559f..9e5a62b2 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -27,7 +27,7 @@ def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head', **kwargs @@ -107,11 +107,10 @@ class HybridEmbed(nn.Module): def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs): - default_cfg = deepcopy(default_cfgs[variant]) embed_layer = partial(HybridEmbed, backbone=backbone) kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set return _create_vision_transformer( - variant, pretrained=pretrained, default_cfg=default_cfg, embed_layer=embed_layer, **kwargs) + variant, pretrained=pretrained, embed_layer=embed_layer, default_cfg=default_cfgs[variant], **kwargs) def _resnetv2(layers=(3, 4, 9), **kwargs): From 83487e2a0dd93e9112ee19ee0c8dcdf393463440 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 24 May 2021 21:36:56 -0700 Subject: [PATCH 26/48] Lower max backward size for tests. --- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 570b49db..c1632271 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -32,7 +32,7 @@ else: TARGET_FWD_SIZE = MAX_FWD_SIZE = 384 TARGET_BWD_SIZE = 128 -MAX_BWD_SIZE = 384 +MAX_BWD_SIZE = 320 MAX_FWD_FEAT_SIZE = 448 From c4572cc5aa21dfa6c81b4ef6a3479409e49561f0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 24 May 2021 22:50:12 -0700 Subject: [PATCH 27/48] Add Visformer-small weighs, tweak torchscript jit test img size. --- tests/test_models.py | 18 +++++++++++------- timm/models/visformer.py | 4 +++- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index c1632271..e6a73619 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -33,7 +33,11 @@ else: TARGET_FWD_SIZE = MAX_FWD_SIZE = 384 TARGET_BWD_SIZE = 128 MAX_BWD_SIZE = 320 -MAX_FWD_FEAT_SIZE = 448 +MAX_FWD_OUT_SIZE = 448 +TARGET_JIT_SIZE = 128 +MAX_JIT_SIZE = 320 +TARGET_FFEAT_SIZE = 96 +MAX_FFEAT_SIZE = 256 def _get_input_size(model, target=None): @@ -109,10 +113,10 @@ def test_model_default_cfgs(model_name, batch_size): pool_size = cfg['pool_size'] input_size = model.default_cfg['input_size'] - if all([x <= MAX_FWD_FEAT_SIZE for x in input_size]) and \ + if all([x <= MAX_FWD_OUT_SIZE for x in input_size]) and \ not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]): # output sizes only checked if default res <= 448 * 448 to keep resource down - input_size = tuple([min(x, MAX_FWD_FEAT_SIZE) for x in input_size]) + input_size = tuple([min(x, MAX_FWD_OUT_SIZE) for x in input_size]) input_tensor = torch.randn((batch_size, *input_size)) # test forward_features (always unpooled) @@ -176,8 +180,8 @@ def test_model_forward_torchscript(model_name, batch_size): model = create_model(model_name, pretrained=False) model.eval() - input_size = _get_input_size(model, 128) - if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional + input_size = _get_input_size(model, TARGET_JIT_SIZE) + if max(input_size) > MAX_JIT_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional pytest.skip("Fixed input size model > limit.") model = torch.jit.script(model) @@ -205,8 +209,8 @@ def test_model_forward_features(model_name, batch_size): expected_channels = model.feature_info.channels() assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6 - input_size = _get_input_size(model, 96) # jit compile is already a bit slow and we've tested normal res already... - if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional + input_size = _get_input_size(model, TARGET_FFEAT_SIZE) + if max(input_size) > MAX_FFEAT_SIZE: pytest.skip("Fixed input size model > limit.") outputs = model(torch.randn((batch_size, *input_size))) diff --git a/timm/models/visformer.py b/timm/models/visformer.py index df1d502a..936f1ddf 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -33,7 +33,9 @@ def _cfg(url='', **kwargs): default_cfgs = dict( visformer_tiny=_cfg(), - visformer_small=_cfg(), + visformer_small=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/visformer_small-839e1f5b.pth' + ), ) From d400f1dbddf71faed78e485e0797465298510815 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 25 May 2021 10:14:45 -0700 Subject: [PATCH 28/48] Filter test models before creation for backward/torchscript tests --- tests/test_models.py | 44 ++++++++++++++++++++++++++--------------- timm/models/registry.py | 7 +++++-- 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index e6a73619..63a95fa5 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -40,14 +40,25 @@ TARGET_FFEAT_SIZE = 96 MAX_FFEAT_SIZE = 256 -def _get_input_size(model, target=None): - default_cfg = model.default_cfg - input_size = default_cfg['input_size'] - if 'fixed_input_size' in default_cfg and default_cfg['fixed_input_size']: +def _get_input_size(model=None, model_name='', target=None): + if model is None: + assert model_name, "One of model or model_name must be provided" + input_size = get_model_default_value(model_name, 'input_size') + fixed_input_size = get_model_default_value(model_name, 'fixed_input_size') + min_input_size = get_model_default_value(model_name, 'min_input_size') + else: + default_cfg = model.default_cfg + input_size = default_cfg['input_size'] + fixed_input_size = default_cfg.get('fixed_input_size', None) + min_input_size = default_cfg.get('min_input_size', None) + assert input_size is not None + + if fixed_input_size: return input_size - if 'min_input_size' in default_cfg: + + if min_input_size: if target and max(input_size) > target: - input_size = default_cfg['min_input_size'] + input_size = min_input_size else: if target and max(input_size) > target: input_size = tuple([min(x, target) for x in input_size]) @@ -73,18 +84,18 @@ def test_model_forward(model_name, batch_size): @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True)) @pytest.mark.parametrize('batch_size', [2]) def test_model_backward(model_name, batch_size): """Run a single forward pass with each model""" + input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE) + if max(input_size) > MAX_BWD_SIZE: + pytest.skip("Fixed input size model > limit.") + model = create_model(model_name, pretrained=False, num_classes=42) num_params = sum([x.numel() for x in model.parameters()]) model.train() - input_size = _get_input_size(model, TARGET_BWD_SIZE) - if max(input_size) > MAX_BWD_SIZE: - pytest.skip("Fixed input size model > limit.") - inputs = torch.randn((batch_size, *input_size)) outputs = model(inputs) if isinstance(outputs, tuple): @@ -172,18 +183,19 @@ EXCLUDE_JIT_FILTERS = [ @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS)) +@pytest.mark.parametrize( + 'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True)) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_torchscript(model_name, batch_size): """Run a single forward pass with each model""" + input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) + if max(input_size) > MAX_JIT_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional + pytest.skip("Fixed input size model > limit.") + with set_scriptable(True): model = create_model(model_name, pretrained=False) model.eval() - input_size = _get_input_size(model, TARGET_JIT_SIZE) - if max(input_size) > MAX_JIT_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional - pytest.skip("Fixed input size model > limit.") - model = torch.jit.script(model) outputs = model(torch.randn((batch_size, *input_size))) diff --git a/timm/models/registry.py b/timm/models/registry.py index 9172ac7e..6927b6d6 100644 --- a/timm/models/registry.py +++ b/timm/models/registry.py @@ -50,7 +50,7 @@ def _natural_key(string_): return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] -def list_models(filter='', module='', pretrained=False, exclude_filters=''): +def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False): """ Return list of available model names, sorted alphabetically Args: @@ -58,6 +58,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''): module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') pretrained (bool) - Include only models with pretrained weights if True exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter + name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) Example: model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' @@ -70,7 +71,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''): if filter: models = fnmatch.filter(models, filter) # include these models if exclude_filters: - if not isinstance(exclude_filters, list): + if not isinstance(exclude_filters, (tuple, list)): exclude_filters = [exclude_filters] for xf in exclude_filters: exclude_models = fnmatch.filter(models, xf) # exclude these models @@ -78,6 +79,8 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''): models = set(models).difference(exclude_models) if pretrained: models = _model_has_pretrained.intersection(models) + if name_matches_cfg: + models = set(_model_default_cfgs).intersection(models) return list(sorted(models, key=_natural_key)) From 11ae795e99f6146dfada86eeef8dcd8d1dcb8679 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 25 May 2021 10:15:32 -0700 Subject: [PATCH 29/48] Redo LeViT attention bias caching in a way that works with both torchscript and DataParallel --- timm/models/levit.py | 49 +++++++++++++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/timm/models/levit.py b/timm/models/levit.py index 96a0c85b..5019ee9a 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -26,6 +26,7 @@ Modifications by/coyright Copyright 2021 Ross Wightman import itertools from copy import deepcopy from functools import partial +from typing import Dict import torch import torch.nn as nn @@ -255,6 +256,8 @@ class Subsample(nn.Module): class Attention(nn.Module): + ab: Dict[str, torch.Tensor] + def __init__( self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False): super().__init__() @@ -286,20 +289,31 @@ class Attention(nn.Module): idxs.append(attention_offsets[offset]) self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N)) - self.ab = None + self.ab = {} @torch.no_grad() def train(self, mode=True): super().train(mode) - self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs] + if mode and self.ab: + self.ab = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.ab: + self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.ab[device_key] def forward(self, x): # x (B,C,H,W) if self.use_conv: B, C, H, W = x.shape q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2) - ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab - attn = (q.transpose(-2, -1) @ k) * self.scale + ab + + attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) else: B, N, C = x.shape @@ -308,15 +322,18 @@ class Attention(nn.Module): q = q.permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3) v = v.permute(0, 2, 1, 3) - ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab - attn = q @ k.transpose(-2, -1) * self.scale + ab + + attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) x = self.proj(x) return x class AttentionSubsample(nn.Module): + ab: Dict[str, torch.Tensor] + def __init__( self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2, act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False): @@ -366,12 +383,22 @@ class AttentionSubsample(nn.Module): idxs.append(attention_offsets[offset]) self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N)) - self.ab = None + self.ab = {} # per-device attention_biases cache @torch.no_grad() def train(self, mode=True): super().train(mode) - self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs] + if mode and self.ab: + self.ab = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.ab: + self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.ab[device_key] def forward(self, x): if self.use_conv: @@ -379,8 +406,7 @@ class AttentionSubsample(nn.Module): k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2) q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2) - ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab - attn = (q.transpose(-2, -1) @ k) * self.scale + ab + attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_) @@ -391,8 +417,7 @@ class AttentionSubsample(nn.Module): v = v.permute(0, 2, 1, 3) # BHNC q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) - ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab - attn = q @ k.transpose(-2, -1) * self.scale + ab + attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) From 99d97e0d67d302b4518c16ccd0c74a05a533855f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 25 May 2021 11:10:17 -0700 Subject: [PATCH 30/48] Hopefully the last test update for this PR... --- .github/workflows/tests.yml | 4 ++-- tests/test_models.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1cc44acf..9f7aebdb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,8 +17,8 @@ jobs: matrix: os: [ubuntu-latest, macOS-latest] python: ['3.8'] - torch: ['1.8.0'] - torchvision: ['0.9.0'] + torch: ['1.8.1'] + torchvision: ['0.9.1'] runs-on: ${{ matrix.os }} steps: diff --git a/tests/test_models.py b/tests/test_models.py index 63a95fa5..029ae0dd 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -73,7 +73,7 @@ def test_model_forward(model_name, batch_size): model = create_model(model_name, pretrained=False) model.eval() - input_size = _get_input_size(model, TARGET_FWD_SIZE) + input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE) if max(input_size) > MAX_FWD_SIZE: pytest.skip("Fixed input size model > limit.") inputs = torch.randn((batch_size, *input_size)) @@ -221,7 +221,7 @@ def test_model_forward_features(model_name, batch_size): expected_channels = model.feature_info.channels() assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6 - input_size = _get_input_size(model, TARGET_FFEAT_SIZE) + input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE) if max(input_size) > MAX_FFEAT_SIZE: pytest.skip("Fixed input size model > limit.") From 318360c3f97dd9ab457a44db0b5f3af7cf0e9c88 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 25 May 2021 12:25:53 -0700 Subject: [PATCH 31/48] Update README.md before merge. Bump version to 0.4.10 --- README.md | 37 +++++++++++++------------------------ docs/archived_changes.md | 24 ++++++++++++++++++++++++ docs/changes.md | 28 ++++++++++++++++++++++++++++ timm/version.py | 2 +- 4 files changed, 66 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index ca283605..8ba95e98 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,14 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor ## What's New +### May 25, 2021 +* Add LeViT, Visformer, ConViT (PR by Aman Arora), Twins (PR by paper authors) transformer models +* Add ResMLP and gMLP MLP vision models to the existing MLP Mixer impl +* Fix a number of torchscript issues with various vision transformer models +* Cleanup input_size/img_size override handling and improve testing / test coverage for all vision transformer and MLP models +* More flexible pos embedding resize (non-square) for ViT and TnT. Thanks [Alexander Soare](https://github.com/alexander-soare) +* Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params. + ### May 14, 2021 * Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl. * 1k trained variants: `tf_efficientnetv2_s/m/l` @@ -166,30 +174,6 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor * Misc fixes for SiLU ONNX export, default_cfg missing from Feature extraction models, Linear layer w/ AMP + torchscript * PyPi release @ 0.3.2 (needed by EfficientDet) -### Oct 30, 2020 -* Test with PyTorch 1.7 and fix a small top-n metric view vs reshape issue. -* Convert newly added 224x224 Vision Transformer weights from official JAX repo. 81.8 top-1 for B/16, 83.1 L/16. -* Support PyTorch 1.7 optimized, native SiLU (aka Swish) activation. Add mapping to 'silu' name, custom swish will eventually be deprecated. -* Fix regression for loading pretrained classifier via direct model entrypoint functions. Didn't impact create_model() factory usage. -* PyPi release @ 0.3.0 version! - -### Oct 26, 2020 -* Update Vision Transformer models to be compatible with official code release at https://github.com/google-research/vision_transformer -* Add Vision Transformer weights (ImageNet-21k pretrain) for 384x384 base and large models converted from official jax impl - * ViT-B/16 - 84.2 - * ViT-B/32 - 81.7 - * ViT-L/16 - 85.2 - * ViT-L/32 - 81.5 - -### Oct 21, 2020 -* Weights added for Vision Transformer (ViT) models. 77.86 top-1 for 'small' and 79.35 for 'base'. Thanks to [Christof](https://www.kaggle.com/christofhenkel) for training the base model w/ lots of GPUs. - -### Oct 13, 2020 -* Initial impl of Vision Transformer models. Both patch and hybrid (CNN backbone) variants. Currently trying to train... -* Adafactor and AdaHessian (FP32 only, no AMP) optimizers -* EdgeTPU-M (`efficientnet_em`) model trained in PyTorch, 79.3 top-1 -* Pip release, doc updates pending a few more changes... - ## Introduction @@ -207,6 +191,7 @@ A full version of the list below with source links can be found in the [document * Bottleneck Transformers - https://arxiv.org/abs/2101.11605 * CaiT (Class-Attention in Image Transformers) - https://arxiv.org/abs/2103.17239 * CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399 +* ConViT (Soft Convolutional Inductive Biases Vision Transformers)- https://arxiv.org/abs/2103.10697 * CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929 * DeiT (Vision Transformer) - https://arxiv.org/abs/2012.12877 * DenseNet - https://arxiv.org/abs/1608.06993 @@ -224,6 +209,7 @@ A full version of the list below with source links can be found in the [document * MobileNet-V2 - https://arxiv.org/abs/1801.04381 * Single-Path NAS - https://arxiv.org/abs/1904.02877 * GhostNet - https://arxiv.org/abs/1911.11907 +* gMLP - https://arxiv.org/abs/2105.08050 * GPU-Efficient Networks - https://arxiv.org/abs/2006.14090 * Halo Nets - https://arxiv.org/abs/2103.12731 * HardCoRe-NAS - https://arxiv.org/abs/2102.11646 @@ -231,6 +217,7 @@ A full version of the list below with source links can be found in the [document * Inception-V3 - https://arxiv.org/abs/1512.00567 * Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261 * Lambda Networks - https://arxiv.org/abs/2102.08602 +* LeViT (Vision Transformer in ConvNet's Clothing) - https://arxiv.org/abs/2104.01136 * MLP-Mixer - https://arxiv.org/abs/2105.01601 * MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244 * NASNet-A - https://arxiv.org/abs/1707.07012 @@ -240,6 +227,7 @@ A full version of the list below with source links can be found in the [document * Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302 * RegNet - https://arxiv.org/abs/2003.13678 * RepVGG - https://arxiv.org/abs/2101.03697 +* ResMLP - https://arxiv.org/abs/2105.03404 * ResNet/ResNeXt * ResNet (v1b/v1.5) - https://arxiv.org/abs/1512.03385 * ResNeXt - https://arxiv.org/abs/1611.05431 @@ -257,6 +245,7 @@ A full version of the list below with source links can be found in the [document * Swin Transformer - https://arxiv.org/abs/2103.14030 * Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112 * TResNet - https://arxiv.org/abs/2003.13630 +* Twins (Spatial Attention in Vision Transformers) - https://arxiv.org/pdf/2104.13840.pdf * Vision Transformer - https://arxiv.org/abs/2010.11929 * VovNet V2 and V1 - https://arxiv.org/abs/1911.06667 * Xception - https://arxiv.org/abs/1610.02357 diff --git a/docs/archived_changes.md b/docs/archived_changes.md index 857a914d..56ee706f 100644 --- a/docs/archived_changes.md +++ b/docs/archived_changes.md @@ -1,5 +1,29 @@ # Archived Changes +### Oct 30, 2020 +* Test with PyTorch 1.7 and fix a small top-n metric view vs reshape issue. +* Convert newly added 224x224 Vision Transformer weights from official JAX repo. 81.8 top-1 for B/16, 83.1 L/16. +* Support PyTorch 1.7 optimized, native SiLU (aka Swish) activation. Add mapping to 'silu' name, custom swish will eventually be deprecated. +* Fix regression for loading pretrained classifier via direct model entrypoint functions. Didn't impact create_model() factory usage. +* PyPi release @ 0.3.0 version! + +### Oct 26, 2020 +* Update Vision Transformer models to be compatible with official code release at https://github.com/google-research/vision_transformer +* Add Vision Transformer weights (ImageNet-21k pretrain) for 384x384 base and large models converted from official jax impl + * ViT-B/16 - 84.2 + * ViT-B/32 - 81.7 + * ViT-L/16 - 85.2 + * ViT-L/32 - 81.5 + +### Oct 21, 2020 +* Weights added for Vision Transformer (ViT) models. 77.86 top-1 for 'small' and 79.35 for 'base'. Thanks to [Christof](https://www.kaggle.com/christofhenkel) for training the base model w/ lots of GPUs. + +### Oct 13, 2020 +* Initial impl of Vision Transformer models. Both patch and hybrid (CNN backbone) variants. Currently trying to train... +* Adafactor and AdaHessian (FP32 only, no AMP) optimizers +* EdgeTPU-M (`efficientnet_em`) model trained in PyTorch, 79.3 top-1 +* Pip release, doc updates pending a few more changes... + ### Sept 18, 2020 * New ResNet 'D' weights. 72.7 (top-1) ResNet-18-D, 77.1 ResNet-34-D, 80.5 ResNet-50-D * Added a few untrained defs for other ResNet models (66D, 101D, 152D, 200/200D) diff --git a/docs/changes.md b/docs/changes.md index b0ac125c..9719dd65 100644 --- a/docs/changes.md +++ b/docs/changes.md @@ -1,5 +1,33 @@ # Recent Changes +### May 25, 2021 +* Add LeViT, Visformer, Convit (PR by Aman Arora), Twins (PR by paper authors) transformer models +* Cleanup input_size/img_size override handling and testing for all vision transformer models +* Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params. + +### May 14, 2021 +* Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl. + * 1k trained variants: `tf_efficientnetv2_s/m/l` + * 21k trained variants: `tf_efficientnetv2_s/m/l_in21k` + * 21k pretrained -> 1k fine-tuned: `tf_efficientnetv2_s/m/l_in21ft1k` + * v2 models w/ v1 scaling: `tf_efficientnetv2_b0` through `b3` + * Rename my prev V2 guess `efficientnet_v2s` -> `efficientnetv2_rw_s` + * Some blank `efficientnetv2_*` models in-place for future native PyTorch training + +### May 5, 2021 +* Add MLP-Mixer models and port pretrained weights from [Google JAX impl](https://github.com/google-research/vision_transformer/tree/linen) +* Add CaiT models and pretrained weights from [FB](https://github.com/facebookresearch/deit) +* Add ResNet-RS models and weights from [TF](https://github.com/tensorflow/tpu/tree/master/models/official/resnet/resnet_rs). Thanks [Aman Arora](https://github.com/amaarora) +* Add CoaT models and weights. Thanks [Mohammed Rizin](https://github.com/morizin) +* Add new ImageNet-21k weights & finetuned weights for TResNet, MobileNet-V3, ViT models. Thanks [mrT](https://github.com/mrT23) +* Add GhostNet models and weights. Thanks [Kai Han](https://github.com/iamhankai) +* Update ByoaNet attention modles + * Improve SA module inits + * Hack together experimental stand-alone Swin based attn module and `swinnet` + * Consistent '26t' model defs for experiments. +* Add improved Efficientnet-V2S (prelim model def) weights. 83.8 top-1. +* WandB logging support + ### April 13, 2021 * Add Swin Transformer models and weights from https://github.com/microsoft/Swin-Transformer diff --git a/timm/version.py b/timm/version.py index 2d802716..b94cbb01 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.4.9' +__version__ = '0.4.10' From fd92ba0de85cb9da474a50d765601b36fee778ba Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 25 May 2021 12:52:07 -0700 Subject: [PATCH 32/48] Filter large vit models from torchscript tests --- tests/test_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 029ae0dd..de664068 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -178,7 +178,8 @@ if 'GITHUB_ACTIONS' not in os.environ: EXCLUDE_JIT_FILTERS = [ '*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable - 'dla*', 'hrnet*', 'ghostnet*', # hopefully fix at some point + 'dla*', 'hrnet*', 'ghostnet*', # hopefully fix at some point + 'vit_large_*', 'vit_huge_*', ] From 51c432150a0019f6193ca53b1e0fc21b86dc2c2a Mon Sep 17 00:00:00 2001 From: Peter Vandenabeele Date: Tue, 25 May 2021 22:42:44 +0200 Subject: [PATCH 33/48] README: fix simple typos --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index ca283605..6a8d520e 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor * Add CoaT models and weights. Thanks [Mohammed Rizin](https://github.com/morizin) * Add new ImageNet-21k weights & finetuned weights for TResNet, MobileNet-V3, ViT models. Thanks [mrT](https://github.com/mrT23) * Add GhostNet models and weights. Thanks [Kai Han](https://github.com/iamhankai) -* Update ByoaNet attention modles +* Update ByoaNet attention modules * Improve SA module inits * Hack together experimental stand-alone Swin based attn module and `swinnet` * Consistent '26t' model defs for experiments. @@ -282,7 +282,7 @@ Several (less common) features that I often utilize in my projects are included. * PyTorch DistributedDataParallel w/ multi-gpu, single process (AMP disabled as it crashes when enabled) * PyTorch w/ single GPU single process (AMP optional) * A dynamic global pool implementation that allows selecting from average pooling, max pooling, average + max, or concat([average, max]) at model creation. All global pooling is adaptive average by default and compatible with pretrained weights. -* A 'Test Time Pool' wrapper that can wrap any of the included models and usually provide improved performance doing inference with input images larger than the training size. Idea adapted from original DPN implementation when I ported (https://github.com/cypw/DPNs) +* A 'Test Time Pool' wrapper that can wrap any of the included models and usually provides improved performance doing inference with input images larger than the training size. Idea adapted from original DPN implementation when I ported (https://github.com/cypw/DPNs) * Learning rate schedulers * Ideas adopted from * [AllenNLP schedulers](https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers) @@ -329,7 +329,7 @@ The root folder of the repository contains reference train, validation, and infe ## Awesome PyTorch Resources -One of the greatest assets of PyTorch is the community and their contributions. A few of my favourite resources that pair well with the models and componenets here are listed below. +One of the greatest assets of PyTorch is the community and their contributions. A few of my favourite resources that pair well with the models and components here are listed below. ### Object Detection, Instance and Semantic Segmentation * Detectron2 - https://github.com/facebookresearch/detectron2 @@ -353,7 +353,7 @@ One of the greatest assets of PyTorch is the community and their contributions. ## Licenses ### Code -The code here is licensed Apache 2.0. I've taken care to make sure any third party code included or adapted has compatible (permissive) licenses such as MIT, BSD, etc. I've made an effort to avoid any GPL / LGPL conflicts. That said, it is your responsibility to ensure you comply with license here and conditions of any dependent licenses. Where applicable, I've linked the sources/references for various components in docstrings. If you think I've missed anything please create an issue. +The code here is licensed Apache 2.0. I've taken care to make sure any third party code included or adapted has compatible (permissive) licenses such as MIT, BSD, etc. I've made an effort to avoid any GPL / LGPL conflicts. That said, it is your responsibility to ensure you comply with licenses here and conditions of any dependent licenses. Where applicable, I've linked the sources/references for various components in docstrings. If you think I've missed anything please create an issue. ### Pretrained Weights So far all of the pretrained weights available here are pretrained on ImageNet with a select few that have some additional pretraining (see extra note below). ImageNet was released for non-commercial research purposes only (http://www.image-net.org/download-faq). It's not clear what the implications of that are for the use of pretrained weights from that dataset. Any models I have trained with ImageNet are done for research purposes and one should assume that the original dataset license applies to the weights. It's best to seek legal advice if you intend to use the pretrained weights in a commercial product. From 5db74521736f6d6caef4e0cd8aba1ad624ae9390 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 25 May 2021 14:11:36 -0700 Subject: [PATCH 34/48] Fix visformer in_chans stem handling --- tests/test_models.py | 2 +- timm/models/visformer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index de664068..44cb3ba2 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -190,7 +190,7 @@ EXCLUDE_JIT_FILTERS = [ def test_model_forward_torchscript(model_name, batch_size): """Run a single forward pass with each model""" input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) - if max(input_size) > MAX_JIT_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional + if max(input_size) > MAX_JIT_SIZE: pytest.skip("Fixed input size model > limit.") with set_scriptable(True): diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 936f1ddf..33a2fe87 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -26,7 +26,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', + 'first_conv': 'stem.0', 'classifier': 'head', **kwargs } @@ -183,7 +183,7 @@ class Visformer(nn.Module): img_size //= 8 else: self.stem = nn.Sequential( - nn.Conv2d(3, self.init_channels, 7, stride=2, padding=3, bias=False), + nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False), nn.BatchNorm2d(self.init_channels), nn.ReLU(inplace=True) ) From 9c78de8c024bff0acc68b044dfb935366c6185dc Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 26 May 2021 15:28:42 -0700 Subject: [PATCH 35/48] Fix #661, move hardswish out of default args for LeViT. Enable native torch support for hardswish, hardsigmoid, mish if present. --- tests/test_layers.py | 12 ++--- tests/test_models.py | 2 +- timm/models/efficientnet_blocks.py | 6 +-- timm/models/efficientnet_builder.py | 5 +- timm/models/ghostnet.py | 4 +- timm/models/layers/create_act.py | 74 ++++++++++++++++++----------- timm/models/layers/se.py | 2 +- timm/models/levit.py | 8 ++-- 8 files changed, 66 insertions(+), 47 deletions(-) diff --git a/tests/test_layers.py b/tests/test_layers.py index 714cb444..508a6aae 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -8,10 +8,10 @@ from timm.models.layers import create_act_layer, get_act_layer, set_layer_config class MLP(nn.Module): - def __init__(self, act_layer="relu"): + def __init__(self, act_layer="relu", inplace=True): super(MLP, self).__init__() self.fc1 = nn.Linear(1000, 100) - self.act = create_act_layer(act_layer, inplace=True) + self.act = create_act_layer(act_layer, inplace=inplace) self.fc2 = nn.Linear(100, 10) def forward(self, x): @@ -21,14 +21,14 @@ class MLP(nn.Module): return x -def _run_act_layer_grad(act_type): +def _run_act_layer_grad(act_type, inplace=True): x = torch.rand(10, 1000) * 10 - m = MLP(act_layer=act_type) + m = MLP(act_layer=act_type, inplace=inplace) def _run(x, act_layer=''): if act_layer: # replace act layer if set - m.act = create_act_layer(act_layer, inplace=True) + m.act = create_act_layer(act_layer, inplace=inplace) out = m(x) l = (out - 0).pow(2).sum() return l @@ -58,7 +58,7 @@ def test_mish_grad(): def test_hard_sigmoid_grad(): for _ in range(100): - _run_act_layer_grad('hard_sigmoid') + _run_act_layer_grad('hard_sigmoid', inplace=None) def test_hard_swish_grad(): diff --git a/tests/test_models.py b/tests/test_models.py index 44cb3ba2..18298dff 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -110,7 +110,7 @@ def test_model_backward(model_name, batch_size): assert not torch.isnan(outputs).any(), 'Output included NaNs' -@pytest.mark.timeout(120) +@pytest.mark.timeout(300) @pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS)) @pytest.mark.parametrize('batch_size', [1]) def test_model_default_cfgs(model_name, batch_size): diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index 83b57beb..7853db0e 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn from torch.nn import functional as F -from .layers import create_conv2d, drop_path, make_divisible +from .layers import create_conv2d, drop_path, make_divisible, get_act_fn, create_act_layer from .layers.activations import sigmoid __all__ = [ @@ -36,9 +36,9 @@ class SqueezeExcite(nn.Module): reduced_chs = make_divisible(reduced_chs * se_ratio, divisor) act_layer = force_act_layer or act_layer self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) - self.act1 = act_layer(inplace=True) + self.act1 = create_act_layer(act_layer, inplace=True) self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) - self.gate_fn = gate_fn + self.gate_fn = get_act_fn(gate_fn) def forward(self, x): x_se = x.mean((2, 3), keepdim=True) diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index 30739454..57e2039b 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -50,10 +50,7 @@ def resolve_bn_args(kwargs): def resolve_act_layer(kwargs, default='relu'): - act_layer = kwargs.pop('act_layer', default) - if isinstance(act_layer, str): - act_layer = get_act_layer(act_layer) - return act_layer + return get_act_layer(kwargs.pop('act_layer', default)) def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9): diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index c132142a..1783ff7a 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -13,7 +13,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid, make_divisible +from .layers import SelectAdaptivePool2d, Linear, make_divisible from .efficientnet_blocks import SqueezeExcite, ConvBnAct from .helpers import build_model_with_cfg from .registry import register_model @@ -40,7 +40,7 @@ default_cfgs = { } -_SE_LAYER = partial(SqueezeExcite, gate_fn=hard_sigmoid, divisor=4) +_SE_LAYER = partial(SqueezeExcite, gate_fn='hard_sigmoid', divisor=4) class GhostModule(nn.Module): diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py index 426c3681..aa557692 100644 --- a/timm/models/layers/create_act.py +++ b/timm/models/layers/create_act.py @@ -1,20 +1,26 @@ """ Activation Factory Hacked together by / Copyright 2020 Ross Wightman """ +from typing import Union, Callable, Type + from .activations import * from .activations_jit import * from .activations_me import * from .config import is_exportable, is_scriptable, is_no_jit -# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. This code -# will use native version if present. Eventually, the custom Swish layers will be removed -# and only native 'silu' will be used. +# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. +# Also hardsigmoid, hardswish, and soon mish. This code will use native version if present. +# Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used. _has_silu = 'silu' in dir(torch.nn.functional) +_has_hardswish = 'hardswish' in dir(torch.nn.functional) +_has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional) +_has_mish = 'mish' in dir(torch.nn.functional) + _ACT_FN_DEFAULT = dict( silu=F.silu if _has_silu else swish, swish=F.silu if _has_silu else swish, - mish=mish, + mish=F.mish if _has_mish else mish, relu=F.relu, relu6=F.relu6, leaky_relu=F.leaky_relu, @@ -24,33 +30,39 @@ _ACT_FN_DEFAULT = dict( gelu=gelu, sigmoid=sigmoid, tanh=tanh, - hard_sigmoid=hard_sigmoid, - hard_swish=hard_swish, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid, + hard_swish=F.hardswish if _has_hardswish else hard_swish, hard_mish=hard_mish, ) _ACT_FN_JIT = dict( silu=F.silu if _has_silu else swish_jit, swish=F.silu if _has_silu else swish_jit, - mish=mish_jit, - hard_sigmoid=hard_sigmoid_jit, - hard_swish=hard_swish_jit, + mish=F.mish if _has_mish else mish_jit, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit, + hard_swish=F.hardswish if _has_hardswish else hard_swish_jit, hard_mish=hard_mish_jit ) _ACT_FN_ME = dict( silu=F.silu if _has_silu else swish_me, swish=F.silu if _has_silu else swish_me, - mish=mish_me, - hard_sigmoid=hard_sigmoid_me, - hard_swish=hard_swish_me, + mish=F.mish if _has_mish else mish_me, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me, + hard_swish=F.hardswish if _has_hardswish else hard_swish_me, hard_mish=hard_mish_me, ) +_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT) +for a in _ACT_FNS: + a.setdefault('hardsigmoid', a.get('hard_sigmoid')) + a.setdefault('hardswish', a.get('hard_swish')) + + _ACT_LAYER_DEFAULT = dict( silu=nn.SiLU if _has_silu else Swish, swish=nn.SiLU if _has_silu else Swish, - mish=Mish, + mish=nn.Mish if _has_mish else Mish, relu=nn.ReLU, relu6=nn.ReLU6, leaky_relu=nn.LeakyReLU, @@ -61,37 +73,44 @@ _ACT_LAYER_DEFAULT = dict( gelu=GELU, sigmoid=Sigmoid, tanh=Tanh, - hard_sigmoid=HardSigmoid, - hard_swish=HardSwish, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid, + hard_swish=nn.Hardswish if _has_hardswish else HardSwish, hard_mish=HardMish, ) _ACT_LAYER_JIT = dict( silu=nn.SiLU if _has_silu else SwishJit, swish=nn.SiLU if _has_silu else SwishJit, - mish=MishJit, - hard_sigmoid=HardSigmoidJit, - hard_swish=HardSwishJit, + mish=nn.Mish if _has_mish else MishJit, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit, + hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit, hard_mish=HardMishJit ) _ACT_LAYER_ME = dict( silu=nn.SiLU if _has_silu else SwishMe, swish=nn.SiLU if _has_silu else SwishMe, - mish=MishMe, - hard_sigmoid=HardSigmoidMe, - hard_swish=HardSwishMe, + mish=nn.Mish if _has_mish else MishMe, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe, + hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe, hard_mish=HardMishMe, ) +_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT) +for a in _ACT_LAYERS: + a.setdefault('hardsigmoid', a.get('hard_sigmoid')) + a.setdefault('hardswish', a.get('hard_swish')) + -def get_act_fn(name='relu'): +def get_act_fn(name: Union[Callable, str] = 'relu'): """ Activation Function Factory Fetching activation fns by name with this function allows export or torch script friendly functions to be returned dynamically based on current config. """ if not name: return None + if isinstance(name, Callable): + return name if not (is_no_jit() or is_exportable() or is_scriptable()): # If not exporting or scripting the model, first look for a memory-efficient version with # custom autograd, then fallback @@ -106,13 +125,15 @@ def get_act_fn(name='relu'): return _ACT_FN_DEFAULT[name] -def get_act_layer(name='relu'): +def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): """ Activation Layer Factory Fetching activation layers by name with this function allows export or torch script friendly functions to be returned dynamically based on current config. """ if not name: return None + if isinstance(name, type): + return name if not (is_no_jit() or is_exportable() or is_scriptable()): if name in _ACT_LAYER_ME: return _ACT_LAYER_ME[name] @@ -125,9 +146,8 @@ def get_act_layer(name='relu'): return _ACT_LAYER_DEFAULT[name] -def create_act_layer(name, inplace=False, **kwargs): +def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs): act_layer = get_act_layer(name) - if act_layer is not None: - return act_layer(inplace=inplace, **kwargs) - else: + if act_layer is None: return None + return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs) diff --git a/timm/models/layers/se.py b/timm/models/layers/se.py index 54c0ef33..4354144d 100644 --- a/timm/models/layers/se.py +++ b/timm/models/layers/se.py @@ -42,7 +42,7 @@ class EffectiveSEModule(nn.Module): def __init__(self, channels, gate_layer='hard_sigmoid'): super(EffectiveSEModule, self).__init__() self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) - self.gate = create_act_layer(gate_layer, inplace=True) + self.gate = create_act_layer(gate_layer) def forward(self, x): x_se = x.mean((2, 3), keepdim=True) diff --git a/timm/models/levit.py b/timm/models/levit.py index 5019ee9a..2180254a 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -33,7 +33,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN from .helpers import build_model_with_cfg, overlay_external_default_cfg -from .layers import to_ntuple +from .layers import to_ntuple, get_act_layer from .vision_transformer import trunc_normal_ from .registry import register_model @@ -443,12 +443,14 @@ class Levit(nn.Module): mlp_ratio=2, hybrid_backbone=None, down_ops=None, - act_layer=nn.Hardswish, - attn_act_layer=nn.Hardswish, + act_layer='hard_swish', + attn_act_layer='hard_swish', distillation=True, use_conv=False, drop_path=0): super().__init__() + act_layer = get_act_layer(act_layer) + attn_act_layer = get_act_layer(attn_act_layer) if isinstance(img_size, tuple): # FIXME origin impl passes single img/res dim through whole hierarchy, # not sure this model will be used enough to spend time fixing it. From 742c2d524726d426ea2745055a5b217c020ccc72 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 27 May 2021 18:03:29 -0700 Subject: [PATCH 36/48] Add Gather-Excite and Global Context attn modules. Refactor existing SE-like attn for consistency and refactor byob/byoanet for less redundancy. --- timm/models/__init__.py | 1 - timm/models/byoanet.py | 374 +++++++-------------------- timm/models/byobnet.py | 365 +++++++++++++++++++++----- timm/models/layers/__init__.py | 8 +- timm/models/layers/cbam.py | 71 ++--- timm/models/layers/create_attn.py | 11 +- timm/models/layers/eca.py | 6 + timm/models/layers/gather_excite.py | 90 +++++++ timm/models/layers/global_context.py | 67 +++++ timm/models/layers/involution.py | 6 +- timm/models/layers/mlp.py | 23 ++ timm/models/layers/norm.py | 9 + timm/models/layers/se.py | 50 ---- timm/models/layers/squeeze_excite.py | 74 ++++++ timm/models/nfnet.py | 15 +- timm/models/regnet.py | 2 +- timm/models/resnet.py | 14 +- timm/models/rexnet.py | 29 +-- timm/models/tresnet.py | 6 +- timm/models/visformer.py | 15 +- 20 files changed, 744 insertions(+), 492 deletions(-) create mode 100644 timm/models/layers/gather_excite.py create mode 100644 timm/models/layers/global_context.py delete mode 100644 timm/models/layers/se.py create mode 100644 timm/models/layers/squeeze_excite.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 788b7518..06217e18 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -17,7 +17,6 @@ from .inception_resnet_v2 import * from .inception_v3 import * from .inception_v4 import * from .levit import * -#from .levit import * from .mlp_mixer import * from .mobilenetv3 import * from .nasnet import * diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index c179a01c..73c6811b 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -12,24 +12,12 @@ Consider all of the models definitions here as experimental WIP and likely to ch Hacked together by / copyright Ross Wightman, 2021. """ -import math -from dataclasses import dataclass, field -from collections import OrderedDict -from typing import Tuple, List, Optional, Union, Any, Callable -from functools import partial - -import torch -import torch.nn as nn - from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .byobnet import BlocksCfg, ByobCfg, create_byob_stem, create_byob_stages, create_downsample,\ - reduce_feat_size, register_block, num_groups, LayerFn, _init_weights +from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks from .helpers import build_model_with_cfg -from .layers import ClassifierHead, ConvBnAct, DropPath, get_act_layer, convert_norm_act, get_attn, get_self_attn,\ - make_divisible, to_2tuple from .registry import register_model -__all__ = ['ByoaNet'] +__all__ = [] def _cfg(url='', **kwargs): @@ -63,100 +51,68 @@ default_cfgs = { 'swinnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'eca_swinnext26ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), - 'rednet26t': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)), - 'rednet50ts': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)), + 'rednet26t': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'rednet50ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), } -@dataclass -class ByoaBlocksCfg(BlocksCfg): - # FIXME allow overriding self_attn layer or args per block/stage, - pass - - -@dataclass -class ByoaCfg(ByobCfg): - blocks: Tuple[Union[ByoaBlocksCfg, Tuple[ByoaBlocksCfg, ...]], ...] = None - self_attn_layer: Optional[str] = None - self_attn_fixed_size: bool = False - self_attn_kwargs: dict = field(default_factory=lambda: dict()) - - -def interleave_attn( - types : Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs -) -> Tuple[ByoaBlocksCfg]: - """ interleave attn blocks - """ - assert len(types) == 2 - if isinstance(every, int): - every = list(range(0 if first else every, d, every)) - if not every: - every = [d - 1] - set(every) - blocks = [] - for i in range(d): - block_type = types[1] if i in every else types[0] - blocks += [ByoaBlocksCfg(type=block_type, d=1, **kwargs)] - return tuple(blocks) - - model_cfgs = dict( - botnet26t=ByoaCfg( + botnet26t=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), - ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', stem_pool='maxpool', num_features=0, + fixed_input_size=True, self_attn_layer='bottleneck', - self_attn_fixed_size=True, self_attn_kwargs=dict() ), - botnet50ts=ByoaCfg( + botnet50ts=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', stem_pool='', num_features=0, + fixed_input_size=True, act_layer='silu', self_attn_layer='bottleneck', - self_attn_fixed_size=True, self_attn_kwargs=dict() ), - eca_botnext26ts=ByoaCfg( + eca_botnext26ts=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25), - ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), - ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25), ), stem_chs=64, stem_type='tiered', stem_pool='maxpool', num_features=0, + fixed_input_size=True, act_layer='silu', attn_layer='eca', self_attn_layer='bottleneck', - self_attn_fixed_size=True, self_attn_kwargs=dict() ), - halonet_h1=ByoaCfg( + halonet_h1=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0), - ByoaBlocksCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0), - ByoaBlocksCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0), - ByoaBlocksCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0), + ByoBlockCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0), + ByoBlockCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0), + ByoBlockCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0), + ByoBlockCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0), ), stem_chs=64, stem_type='7x7', @@ -165,12 +121,12 @@ model_cfgs = dict( self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=3), ), - halonet_h1_c4c5=ByoaCfg( + halonet_h1_c4c5=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=3, c=64, s=1, gs=0, br=1.0), - ByoaBlocksCfg(type='bottle', d=3, c=128, s=2, gs=0, br=1.0), - ByoaBlocksCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0), - ByoaBlocksCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0), + ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=0, br=1.0), + ByoBlockCfg(type='bottle', d=3, c=128, s=2, gs=0, br=1.0), + ByoBlockCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0), + ByoBlockCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0), ), stem_chs=64, stem_type='tiered', @@ -179,12 +135,12 @@ model_cfgs = dict( self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=3), ), - halonet26t=ByoaCfg( + halonet26t=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), - ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', @@ -193,12 +149,12 @@ model_cfgs = dict( self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res ), - halonet50ts=ByoaCfg( + halonet50ts=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), - ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', @@ -208,12 +164,12 @@ model_cfgs = dict( self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=2) ), - eca_halonext26ts=ByoaCfg( + eca_halonext26ts=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), - ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), ), stem_chs=64, stem_type='tiered', @@ -225,12 +181,12 @@ model_cfgs = dict( self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res ), - lambda_resnet26t=ByoaCfg( + lambda_resnet26t=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), - ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', @@ -239,12 +195,12 @@ model_cfgs = dict( self_attn_layer='lambda', self_attn_kwargs=dict() ), - lambda_resnet50t=ByoaCfg( + lambda_resnet50t=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), - ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=3, d=6, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=3, d=6, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', @@ -253,12 +209,12 @@ model_cfgs = dict( self_attn_layer='lambda', self_attn_kwargs=dict() ), - eca_lambda_resnext26ts=ByoaCfg( + eca_lambda_resnext26ts=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), - ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), ), stem_chs=64, stem_type='tiered', @@ -270,77 +226,76 @@ model_cfgs = dict( self_attn_kwargs=dict() ), - swinnet26t=ByoaCfg( + swinnet26t=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', stem_pool='maxpool', num_features=0, + fixed_input_size=True, self_attn_layer='swin', - self_attn_fixed_size=True, self_attn_kwargs=dict(win_size=8) ), - swinnet50ts=ByoaCfg( + swinnet50ts=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=4, c=512, s=2, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=4, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', stem_pool='maxpool', num_features=0, + fixed_input_size=True, act_layer='silu', self_attn_layer='swin', - self_attn_fixed_size=True, self_attn_kwargs=dict(win_size=8) ), - eca_swinnext26ts=ByoaCfg( + eca_swinnext26ts=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=16, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=16, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), ), stem_chs=64, stem_type='tiered', stem_pool='maxpool', num_features=0, + fixed_input_size=True, act_layer='silu', attn_layer='eca', self_attn_layer='swin', - self_attn_fixed_size=True, self_attn_kwargs=dict(win_size=8) ), - rednet26t=ByoaCfg( + rednet26t=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='self_attn', d=2, c=256, s=1, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=512, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=512, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', # FIXME RedNet uses involution in middle of stem stem_pool='maxpool', num_features=0, self_attn_layer='involution', - self_attn_fixed_size=False, self_attn_kwargs=dict() ), - rednet50ts=ByoaCfg( + rednet50ts=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='self_attn', d=3, c=256, s=1, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=4, c=512, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=4, c=512, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', @@ -348,161 +303,14 @@ model_cfgs = dict( num_features=0, act_layer='silu', self_attn_layer='involution', - self_attn_fixed_size=False, self_attn_kwargs=dict() ), ) -@dataclass -class ByoaLayerFn(LayerFn): - self_attn: Optional[Callable] = None - - -class SelfAttnBlock(nn.Module): - """ ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1 - """ - - def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None, - downsample='avg', extra_conv=False, linear_out=False, post_attn_na=True, feat_size=None, - layers: ByoaLayerFn = None, drop_block=None, drop_path_rate=0.): - super(SelfAttnBlock, self).__init__() - assert layers is not None - mid_chs = make_divisible(out_chs * bottle_ratio) - groups = num_groups(group_size, mid_chs) - - if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: - self.shortcut = create_downsample( - downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], - apply_act=False, layers=layers) - else: - self.shortcut = nn.Identity() - - self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) - if extra_conv: - self.conv2_kxk = layers.conv_norm_act( - mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], - groups=groups, drop_block=drop_block) - stride = 1 # striding done via conv if enabled - else: - self.conv2_kxk = nn.Identity() - opt_kwargs = {} if feat_size is None else dict(feat_size=feat_size) - # FIXME need to dilate self attn to have dilated network support, moop moop - self.self_attn = layers.self_attn(mid_chs, stride=stride, **opt_kwargs) - self.post_attn = layers.norm_act(mid_chs) if post_attn_na else nn.Identity() - self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False) - self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() - self.act = nn.Identity() if linear_out else layers.act(inplace=True) - - def init_weights(self, zero_init_last_bn=False): - if zero_init_last_bn: - nn.init.zeros_(self.conv3_1x1.bn.weight) - if hasattr(self.self_attn, 'reset_parameters'): - self.self_attn.reset_parameters() - - def forward(self, x): - shortcut = self.shortcut(x) - - x = self.conv1_1x1(x) - x = self.conv2_kxk(x) - x = self.self_attn(x) - x = self.post_attn(x) - x = self.conv3_1x1(x) - x = self.drop_path(x) - - x = self.act(x + shortcut) - return x - -register_block('self_attn', SelfAttnBlock) - - -def _byoa_block_args(block_kwargs, block_cfg: ByoaBlocksCfg, model_cfg: ByoaCfg, feat_size=None): - if block_cfg.type == 'self_attn' and model_cfg.self_attn_fixed_size: - assert feat_size is not None - block_kwargs['feat_size'] = feat_size - return block_kwargs - - -def get_layer_fns(cfg: ByoaCfg): - act = get_act_layer(cfg.act_layer) - norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act) - conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act) - attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None - self_attn = partial(get_self_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None - layer_fn = ByoaLayerFn( - conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn) - return layer_fn - - -class ByoaNet(nn.Module): - """ 'Bring-your-own-attention' Net - - A ResNet inspired backbone that supports interleaving traditional residual blocks with - 'Self Attention' bottleneck blocks that replace the bottleneck kxk conv w/ a self-attention - or similar module. - - FIXME This class network definition is almost the same as ByobNet, I'd like to merge them but - torchscript limitations prevent sensible inheritance overrides. - """ - def __init__(self, cfg: ByoaCfg, num_classes=1000, in_chans=3, output_stride=32, global_pool='avg', - zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.): - super().__init__() - self.num_classes = num_classes - self.drop_rate = drop_rate - layers = get_layer_fns(cfg) - feat_size = to_2tuple(img_size) if img_size is not None else None - - self.feature_info = [] - stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor)) - self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers) - self.feature_info.extend(stem_feat[:-1]) - feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction']) - - self.stages, stage_feat = create_byob_stages( - cfg, drop_path_rate, output_stride, stem_feat[-1], - feat_size=feat_size, layers=layers, extra_args_fn=_byoa_block_args) - self.feature_info.extend(stage_feat[:-1]) - - prev_chs = stage_feat[-1]['num_chs'] - if cfg.num_features: - self.num_features = int(round(cfg.width_factor * cfg.num_features)) - self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1) - else: - self.num_features = prev_chs - self.final_conv = nn.Identity() - self.feature_info += [ - dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')] - - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) - - for n, m in self.named_modules(): - _init_weights(m, n) - for m in self.modules(): - # call each block's weight init for block-specific overrides to init above - if hasattr(m, 'init_weights'): - m.init_weights(zero_init_last_bn=zero_init_last_bn) - - def get_classifier(self): - return self.head.fc - - def reset_classifier(self, num_classes, global_pool='avg'): - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) - - def forward_features(self, x): - x = self.stem(x) - x = self.stages(x) - x = self.final_conv(x) - return x - - def forward(self, x): - x = self.forward_features(x) - x = self.head(x) - return x - - def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs): return build_model_with_cfg( - ByoaNet, variant, pretrained, + ByobNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], feature_cfg=dict(flatten_sequential=True), diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 8f4a2020..3f162c79 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -26,8 +26,7 @@ Hacked together by / copyright Ross Wightman, 2021. """ import math from dataclasses import dataclass, field, replace -from collections import OrderedDict -from typing import Tuple, List, Optional, Union, Any, Callable, Sequence +from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence from functools import partial import torch @@ -36,10 +35,10 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ - create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible + create_conv2d, get_act_layer, convert_norm_act, get_attn, get_self_attn, make_divisible, to_2tuple from .registry import register_model -__all__ = ['ByobNet', 'ByobCfg', 'BlocksCfg', 'create_byob_stem', 'create_block'] +__all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block'] def _cfg(url='', **kwargs): @@ -87,35 +86,52 @@ default_cfgs = { 'repvgg_b3g4': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b3g4-73c370bf.pth', first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')), + + # experimental configs + 'resnet52qs': _cfg(first_conv='stem.conv1.conv'), + 'geresnet50t': _cfg(first_conv='stem.conv1.conv'), + 'gcresnet50t': _cfg(first_conv='stem.conv1.conv'), } @dataclass -class BlocksCfg: +class ByoBlockCfg: type: Union[str, nn.Module] d: int # block depth (number of block repeats in stage) c: int # number of output channels for each block in stage s: int = 2 # stride of stage (first block) gs: Optional[Union[int, Callable]] = None # group-size of blocks in stage, conv is depthwise if gs == 1 br: float = 1. # bottleneck-ratio of blocks in stage - no_attn: bool = False # disable channel attn (ie SE) when layer is set for model + + # NOTE: these config items override the model cfgs that are applied to all blocks by default + attn_layer: Optional[str] = None + attn_kwargs: Optional[Dict[str, Any]] = None + self_attn_layer: Optional[str] = None + self_attn_kwargs: Optional[Dict[str, Any]] = None + block_kwargs: Optional[Dict[str, Any]] = None @dataclass -class ByobCfg: - blocks: Tuple[Union[BlocksCfg, Tuple[BlocksCfg, ...]], ...] +class ByoModelCfg: + blocks: Tuple[Union[ByoBlockCfg, Tuple[ByoBlockCfg, ...]], ...] downsample: str = 'conv1x1' stem_type: str = '3x3' - stem_pool: str = '' + stem_pool: Optional[str] = 'maxpool' stem_chs: int = 32 width_factor: float = 1.0 num_features: int = 0 # num out_channels for final conv, no final 1x1 conv if 0 zero_init_last_bn: bool = True + fixed_input_size: bool = False # model constrained to a fixed-input size / img_size must be provided on creation act_layer: str = 'relu' norm_layer: str = 'batchnorm' + + # NOTE: these config items will be overridden by the block cfg (per-block) if they are set there attn_layer: Optional[str] = None attn_kwargs: dict = field(default_factory=lambda: dict()) + self_attn_layer: Optional[str] = None + self_attn_kwargs: dict = field(default_factory=lambda: dict()) + block_kwargs: Dict[str, Any] = field(default_factory=lambda: dict()) def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0): @@ -123,103 +139,155 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0): group_size = 0 if groups > 0: group_size = lambda chs, idx: chs // groups if (idx + 1) % 2 == 0 else 0 - bcfg = tuple([BlocksCfg(type='rep', d=d, c=c * wf, gs=group_size) for d, c, wf in zip(d, c, wf)]) + bcfg = tuple([ByoBlockCfg(type='rep', d=d, c=c * wf, gs=group_size) for d, c, wf in zip(d, c, wf)]) return bcfg -model_cfgs = dict( +def interleave_blocks( + types: Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs +) -> Tuple[ByoBlockCfg]: + """ interleave 2 block types in stack + """ + assert len(types) == 2 + if isinstance(every, int): + every = list(range(0 if first else every, d, every)) + if not every: + every = [d - 1] + set(every) + blocks = [] + for i in range(d): + block_type = types[1] if i in every else types[0] + blocks += [ByoBlockCfg(type=block_type, d=1, **kwargs)] + return tuple(blocks) - gernet_l=ByobCfg( + +model_cfgs = dict( + gernet_l=ByoModelCfg( blocks=( - BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), - BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), - BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), - BlocksCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.), - BlocksCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.), + ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), + ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), + ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), + ByoBlockCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.), + ByoBlockCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.), ), stem_chs=32, + stem_pool=None, num_features=2560, ), - gernet_m=ByobCfg( + gernet_m=ByoModelCfg( blocks=( - BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), - BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), - BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), - BlocksCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.), - BlocksCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.), + ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), + ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), + ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), + ByoBlockCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.), + ByoBlockCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.), ), stem_chs=32, + stem_pool=None, num_features=2560, ), - gernet_s=ByobCfg( + gernet_s=ByoModelCfg( blocks=( - BlocksCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.), - BlocksCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.), - BlocksCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4), - BlocksCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.), - BlocksCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.), + ByoBlockCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.), + ByoBlockCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.), + ByoBlockCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4), + ByoBlockCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.), + ByoBlockCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.), ), stem_chs=13, + stem_pool=None, num_features=1920, ), - repvgg_a2=ByobCfg( + repvgg_a2=ByoModelCfg( blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1.5, 1.5, 1.5, 2.75)), stem_type='rep', stem_chs=64, ), - repvgg_b0=ByobCfg( + repvgg_b0=ByoModelCfg( blocks=_rep_vgg_bcfg(wf=(1., 1., 1., 2.5)), stem_type='rep', stem_chs=64, ), - repvgg_b1=ByobCfg( + repvgg_b1=ByoModelCfg( blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.)), stem_type='rep', stem_chs=64, ), - repvgg_b1g4=ByobCfg( + repvgg_b1g4=ByoModelCfg( blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.), groups=4), stem_type='rep', stem_chs=64, ), - repvgg_b2=ByobCfg( + repvgg_b2=ByoModelCfg( blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.)), stem_type='rep', stem_chs=64, ), - repvgg_b2g4=ByobCfg( + repvgg_b2g4=ByoModelCfg( blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.), groups=4), stem_type='rep', stem_chs=64, ), - repvgg_b3=ByobCfg( + repvgg_b3=ByoModelCfg( blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.)), stem_type='rep', stem_chs=64, ), - repvgg_b3g4=ByobCfg( + repvgg_b3g4=ByoModelCfg( blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.), groups=4), stem_type='rep', stem_chs=64, ), - resnet52q=ByobCfg( + # WARN: experimental, may vanish/change + resnet52q=ByoModelCfg( blocks=( - BlocksCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25), - BlocksCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25), - BlocksCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25), - BlocksCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0), + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25), + ByoBlockCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0), ), stem_chs=128, stem_type='quad', num_features=2048, act_layer='silu', ), + + # WARN: experimental, may vanish/change + geresnet50t=ByoModelCfg( + blocks=( + ByoBlockCfg(type='edge', d=3, c=256, s=1, br=0.25), + ByoBlockCfg(type='edge', d=4, c=512, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool=None, + attn_layer='ge', + attn_kwargs=dict(extent=8, extra_params=True), + #attn_kwargs=dict(extent=8), + #block_kwargs=dict(attn_last=True) + ), + + # WARN: experimental, may vanish/change + gcresnet50t=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool=None, + attn_layer='gc' + ), ) -def expand_blocks_cfg(stage_blocks_cfg: Union[BlocksCfg, Sequence[BlocksCfg]]) -> List[BlocksCfg]: +def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]: if not isinstance(stage_blocks_cfg, Sequence): stage_blocks_cfg = (stage_blocks_cfg,) block_cfgs = [] @@ -243,6 +311,7 @@ class LayerFn: norm_act: Callable = BatchNormAct2d act: Callable = nn.ReLU attn: Optional[Callable] = None + self_attn: Optional[Callable] = None class DownsampleAvg(nn.Module): @@ -275,7 +344,8 @@ class BasicBlock(nn.Module): def __init__( self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), group_size=None, bottle_ratio=1.0, - downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.): + downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None, + drop_path_rate=0.): super(BasicBlock, self).__init__() layers = layers or LayerFn() mid_chs = make_divisible(out_chs * bottle_ratio) @@ -289,15 +359,19 @@ class BasicBlock(nn.Module): self.shortcut = nn.Identity() self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0]) + self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) self.conv2_kxk = layers.conv_norm_act( mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block, apply_act=False) - self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) + self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() self.act = nn.Identity() if linear_out else layers.act(inplace=True) - def init_weights(self, zero_init_last_bn=False): + def init_weights(self, zero_init_last_bn: bool = False): if zero_init_last_bn: nn.init.zeros_(self.conv2_kxk.bn.weight) + for attn in (self.attn, self.attn_last): + if hasattr(attn, 'reset_parameters'): + attn.reset_parameters() def forward(self, x): shortcut = self.shortcut(x) @@ -317,7 +391,8 @@ class BottleneckBlock(nn.Module): """ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None, - downsample='avg', linear_out=False, layers : LayerFn = None, drop_block=None, drop_path_rate=0.): + downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None, drop_block=None, + drop_path_rate=0.): super(BottleneckBlock, self).__init__() layers = layers or LayerFn() mid_chs = make_divisible(out_chs * bottle_ratio) @@ -334,14 +409,18 @@ class BottleneckBlock(nn.Module): self.conv2_kxk = layers.conv_norm_act( mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], groups=groups, drop_block=drop_block) - self.attn = nn.Identity() if layers.attn is None else layers.attn(mid_chs) + self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False) + self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() self.act = nn.Identity() if linear_out else layers.act(inplace=True) - def init_weights(self, zero_init_last_bn=False): + def init_weights(self, zero_init_last_bn: bool = False): if zero_init_last_bn: nn.init.zeros_(self.conv3_1x1.bn.weight) + for attn in (self.attn, self.attn_last): + if hasattr(attn, 'reset_parameters'): + attn.reset_parameters() def forward(self, x): shortcut = self.shortcut(x) @@ -350,6 +429,7 @@ class BottleneckBlock(nn.Module): x = self.conv2_kxk(x) x = self.attn(x) x = self.conv3_1x1(x) + x = self.attn_last(x) x = self.drop_path(x) x = self.act(x + shortcut) @@ -368,7 +448,8 @@ class DarkBlock(nn.Module): """ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None, - downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.): + downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None, + drop_path_rate=0.): super(DarkBlock, self).__init__() layers = layers or LayerFn() mid_chs = make_divisible(out_chs * bottle_ratio) @@ -382,23 +463,28 @@ class DarkBlock(nn.Module): self.shortcut = nn.Identity() self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) + self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) self.conv2_kxk = layers.conv_norm_act( mid_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0], groups=groups, drop_block=drop_block, apply_act=False) - self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) + self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() self.act = nn.Identity() if linear_out else layers.act(inplace=True) - def init_weights(self, zero_init_last_bn=False): + def init_weights(self, zero_init_last_bn: bool = False): if zero_init_last_bn: nn.init.zeros_(self.conv2_kxk.bn.weight) + for attn in (self.attn, self.attn_last): + if hasattr(attn, 'reset_parameters'): + attn.reset_parameters() def forward(self, x): shortcut = self.shortcut(x) x = self.conv1_1x1(x) - x = self.conv2_kxk(x) x = self.attn(x) + x = self.conv2_kxk(x) + x = self.attn_last(x) x = self.drop_path(x) x = self.act(x + shortcut) return x @@ -415,7 +501,8 @@ class EdgeBlock(nn.Module): """ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None, - downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.): + downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None, + drop_block=None, drop_path_rate=0.): super(EdgeBlock, self).__init__() layers = layers or LayerFn() mid_chs = make_divisible(out_chs * bottle_ratio) @@ -431,14 +518,18 @@ class EdgeBlock(nn.Module): self.conv1_kxk = layers.conv_norm_act( in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], groups=groups, drop_block=drop_block) - self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) + self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False) + self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() self.act = nn.Identity() if linear_out else layers.act(inplace=True) - def init_weights(self, zero_init_last_bn=False): + def init_weights(self, zero_init_last_bn: bool = False): if zero_init_last_bn: nn.init.zeros_(self.conv2_1x1.bn.weight) + for attn in (self.attn, self.attn_last): + if hasattr(attn, 'reset_parameters'): + attn.reset_parameters() def forward(self, x): shortcut = self.shortcut(x) @@ -446,6 +537,7 @@ class EdgeBlock(nn.Module): x = self.conv1_kxk(x) x = self.attn(x) x = self.conv2_1x1(x) + x = self.attn_last(x) x = self.drop_path(x) x = self.act(x + shortcut) return x @@ -460,7 +552,7 @@ class RepVggBlock(nn.Module): """ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None, - downsample='', layers : LayerFn = None, drop_block=None, drop_path_rate=0.): + downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.): super(RepVggBlock, self).__init__() layers = layers or LayerFn() groups = num_groups(group_size, in_chs) @@ -475,12 +567,15 @@ class RepVggBlock(nn.Module): self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity() self.act = layers.act(inplace=True) - def init_weights(self, zero_init_last_bn=False): + def init_weights(self, zero_init_last_bn: bool = False): # NOTE this init overrides that base model init with specific changes for the block type for m in self.modules(): if isinstance(m, nn.BatchNorm2d): nn.init.normal_(m.weight, .1, .1) nn.init.normal_(m.bias, 0, .1) + for attn in (self.attn, self.attn_last): + if hasattr(attn, 'reset_parameters'): + attn.reset_parameters() def forward(self, x): if self.identity is None: @@ -495,12 +590,68 @@ class RepVggBlock(nn.Module): return x +class SelfAttnBlock(nn.Module): + """ ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1 + """ + + def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None, + downsample='avg', extra_conv=False, linear_out=False, post_attn_na=True, feat_size=None, + layers: LayerFn = None, drop_block=None, drop_path_rate=0.): + super(SelfAttnBlock, self).__init__() + assert layers is not None + mid_chs = make_divisible(out_chs * bottle_ratio) + groups = num_groups(group_size, mid_chs) + + if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: + self.shortcut = create_downsample( + downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], + apply_act=False, layers=layers) + else: + self.shortcut = nn.Identity() + + self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) + if extra_conv: + self.conv2_kxk = layers.conv_norm_act( + mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], + groups=groups, drop_block=drop_block) + stride = 1 # striding done via conv if enabled + else: + self.conv2_kxk = nn.Identity() + opt_kwargs = {} if feat_size is None else dict(feat_size=feat_size) + # FIXME need to dilate self attn to have dilated network support, moop moop + self.self_attn = layers.self_attn(mid_chs, stride=stride, **opt_kwargs) + self.post_attn = layers.norm_act(mid_chs) if post_attn_na else nn.Identity() + self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + self.act = nn.Identity() if linear_out else layers.act(inplace=True) + + def init_weights(self, zero_init_last_bn: bool = False): + if zero_init_last_bn: + nn.init.zeros_(self.conv3_1x1.bn.weight) + if hasattr(self.self_attn, 'reset_parameters'): + self.self_attn.reset_parameters() + + def forward(self, x): + shortcut = self.shortcut(x) + + x = self.conv1_1x1(x) + x = self.conv2_kxk(x) + x = self.self_attn(x) + x = self.post_attn(x) + x = self.conv3_1x1(x) + x = self.drop_path(x) + + x = self.act(x + shortcut) + return x + + _block_registry = dict( basic=BasicBlock, bottle=BottleneckBlock, dark=DarkBlock, edge=EdgeBlock, rep=RepVggBlock, + self_attn=SelfAttnBlock, ) @@ -552,7 +703,7 @@ class Stem(nn.Sequential): curr_stride *= s prev_feat = conv_name - if 'max' in pool.lower(): + if pool and 'max' in pool.lower(): self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat)) self.add_module('pool', nn.MaxPool2d(3, 2, 1)) curr_stride *= 2 @@ -601,9 +752,58 @@ def reduce_feat_size(feat_size, stride=2): return None if feat_size is None else tuple([s // stride for s in feat_size]) +def override_kwargs(block_kwargs, model_kwargs): + """ Override model level attn/self-attn/block kwargs w/ block level + + NOTE: kwargs are NOT merged across levels, block_kwargs will fully replace model_kwargs + for the block if set to anything that isn't None. + + i.e. an empty block_kwargs dict will remove kwargs set at model level for that block + """ + out_kwargs = block_kwargs if block_kwargs is not None else model_kwargs + return out_kwargs or {} # make sure None isn't returned + + +def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, model_cfg: ByoModelCfg, ): + layer_fns = block_kwargs['layers'] + + # override attn layer / args with block local config + if block_cfg.attn_kwargs is not None or block_cfg.attn_layer is not None: + # override attn layer config + if not block_cfg.attn_layer: + # empty string for attn_layer type will disable attn for this block + attn_layer = None + else: + attn_kwargs = override_kwargs(block_cfg.attn_kwargs, model_cfg.attn_kwargs) + attn_layer = block_cfg.attn_layer or model_cfg.attn_layer + attn_layer = partial(get_attn(attn_layer), *attn_kwargs) if attn_layer is not None else None + layer_fns = replace(layer_fns, attn=attn_layer) + + # override self-attn layer / args with block local cfg + if block_cfg.self_attn_kwargs is not None or block_cfg.self_attn_layer is not None: + # override attn layer config + if not block_cfg.self_attn_layer: + # empty string for self_attn_layer type will disable attn for this block + self_attn_layer = None + else: + self_attn_kwargs = override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs) + self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer + self_attn_layer = partial(get_self_attn(self_attn_layer), *self_attn_kwargs) \ + if self_attn_layer is not None else None + layer_fns = replace(layer_fns, self_attn=self_attn_layer) + + block_kwargs['layers'] = layer_fns + + # add additional block_kwargs specified in block_cfg or model_cfg, precedence to block if set + block_kwargs.update(override_kwargs(block_cfg.block_kwargs, model_cfg.block_kwargs)) + + def create_byob_stages( - cfg, drop_path_rate, output_stride, stem_feat, - feat_size=None, layers=None, extra_args_fn=None): + cfg: ByoModelCfg, drop_path_rate: float, output_stride: int, stem_feat: Dict[str, Any], + feat_size: Optional[int] = None, + layers: Optional[LayerFn] = None, + block_kwargs_fn: Optional[Callable] = update_block_kwargs): + layers = layers or LayerFn() feature_info = [] block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks] @@ -641,8 +841,10 @@ def create_byob_stages( drop_path_rate=dpr[stage_idx][block_idx], layers=layers, ) - if extra_args_fn is not None: - extra_args_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg, feat_size=feat_size) + if block_cfg.type in ('self_attn',): + # add feat_size arg for blocks that support/need it + block_kwargs['feat_size'] = feat_size + block_kwargs_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg) blocks += [create_block(block_cfg.type, **block_kwargs)] first_dilation = dilation prev_chs = out_chs @@ -656,12 +858,13 @@ def create_byob_stages( return nn.Sequential(*stages), feature_info -def get_layer_fns(cfg: ByobCfg): +def get_layer_fns(cfg: ByoModelCfg): act = get_act_layer(cfg.act_layer) norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act) conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act) attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None - layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn) + self_attn = partial(get_self_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None + layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn) return layer_fn @@ -673,19 +876,24 @@ class ByobNet(nn.Module): Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act). """ - def __init__(self, cfg: ByobCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, - zero_init_last_bn=True, drop_rate=0., drop_path_rate=0.): + def __init__(self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, + zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.): super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate layers = get_layer_fns(cfg) + if cfg.fixed_input_size: + assert img_size is not None, 'img_size argument is required for fixed input size model' + feat_size = to_2tuple(img_size) if img_size is not None else None self.feature_info = [] stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor)) self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers) self.feature_info.extend(stem_feat[:-1]) + feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction']) - self.stages, stage_feat = create_byob_stages(cfg, drop_path_rate, output_stride, stem_feat[-1], layers=layers) + self.stages, stage_feat = create_byob_stages( + cfg, drop_path_rate, output_stride, stem_feat[-1], layers=layers, feat_size=feat_size) self.feature_info.extend(stage_feat[:-1]) prev_chs = stage_feat[-1]['num_chs'] @@ -836,3 +1044,24 @@ def repvgg_b3g4(pretrained=False, **kwargs): `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 """ return _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs) + + +@register_model +def resnet52q(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('geresnet50t', pretrained=pretrained, **kwargs) + + +@register_model +def geresnet50t(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('geresnet50t', pretrained=pretrained, **kwargs) + + +@register_model +def gcresnet50t(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('gcresnet50t', pretrained=pretrained, **kwargs) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index cd192281..30a1b40d 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -14,20 +14,22 @@ from .create_conv2d import create_conv2d from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act from .create_self_attn import get_self_attn, create_self_attn from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path -from .eca import EcaModule, CecaModule +from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn from .evo_norm import EvoNormBatch2d, EvoNormSample2d +from .gather_excite import GatherExcite +from .global_context import GlobalContext from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible from .inplace_abn import InplaceAbn from .involution import Involution from .linear import Linear from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp -from .norm import GroupNorm +from .norm import GroupNorm, LayerNorm2d from .norm_act import BatchNormAct2d, GroupNormAct from .padding import get_padding, get_same_padding, pad_same from .patch_embed import PatchEmbed from .pool2d_same import AvgPool2dSame, create_pool2d -from .se import SEModule +from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite from .selective_kernel import SelectiveKernelConv from .separable_conv import SeparableConv2d, SeparableConvBnAct from .space_to_depth import SpaceToDepthModule diff --git a/timm/models/layers/cbam.py b/timm/models/layers/cbam.py index 44e2fe6d..bacf5cf0 100644 --- a/timm/models/layers/cbam.py +++ b/timm/models/layers/cbam.py @@ -7,78 +7,87 @@ some tasks, especially fine-grained it seems. I may end up removing this impl. Hacked together by / Copyright 2020 Ross Wightman """ - import torch from torch import nn as nn import torch.nn.functional as F + from .conv_bn_act import ConvBnAct +from .create_act import create_act_layer, get_act_layer +from .helpers import make_divisible class ChannelAttn(nn.Module): """ Original CBAM channel attention module, currently avg + max pool variant only. """ - def __init__(self, channels, reduction=16, act_layer=nn.ReLU): + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): super(ChannelAttn, self).__init__() - self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False) + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias) self.act = act_layer(inplace=True) - self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False) + self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias) + self.gate = create_act_layer(gate_layer) def forward(self, x): - x_avg = x.mean((2, 3), keepdim=True) - x_max = F.adaptive_max_pool2d(x, 1) - x_avg = self.fc2(self.act(self.fc1(x_avg))) - x_max = self.fc2(self.act(self.fc1(x_max))) - x_attn = x_avg + x_max - return x * x_attn.sigmoid() + x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True)))) + x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True)))) + return x * self.gate(x_avg + x_max) class LightChannelAttn(ChannelAttn): """An experimental 'lightweight' that sums avg + max pool first """ - def __init__(self, channels, reduction=16): - super(LightChannelAttn, self).__init__(channels, reduction) + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(LightChannelAttn, self).__init__( + channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias) def forward(self, x): - x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * F.adaptive_max_pool2d(x, 1) + x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True) x_attn = self.fc2(self.act(self.fc1(x_pool))) - return x * x_attn.sigmoid() + return x * F.sigmoid(x_attn) class SpatialAttn(nn.Module): """ Original CBAM spatial attention module """ - def __init__(self, kernel_size=7): + def __init__(self, kernel_size=7, gate_layer='sigmoid'): super(SpatialAttn, self).__init__() self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None) + self.gate = create_act_layer(gate_layer) def forward(self, x): - x_avg = torch.mean(x, dim=1, keepdim=True) - x_max = torch.max(x, dim=1, keepdim=True)[0] - x_attn = torch.cat([x_avg, x_max], dim=1) + x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1) x_attn = self.conv(x_attn) - return x * x_attn.sigmoid() + return x * self.gate(x_attn) class LightSpatialAttn(nn.Module): """An experimental 'lightweight' variant that sums avg_pool and max_pool results. """ - def __init__(self, kernel_size=7): + def __init__(self, kernel_size=7, gate_layer='sigmoid'): super(LightSpatialAttn, self).__init__() self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None) + self.gate = create_act_layer(gate_layer) def forward(self, x): - x_avg = torch.mean(x, dim=1, keepdim=True) - x_max = torch.max(x, dim=1, keepdim=True)[0] - x_attn = 0.5 * x_avg + 0.5 * x_max + x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True) x_attn = self.conv(x_attn) - return x * x_attn.sigmoid() + return x * self.gate(x_attn) class CbamModule(nn.Module): - def __init__(self, channels, spatial_kernel_size=7): + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): super(CbamModule, self).__init__() - self.channel = ChannelAttn(channels) - self.spatial = SpatialAttn(spatial_kernel_size) + self.channel = ChannelAttn( + channels, rd_ratio=rd_ratio, rd_channels=rd_channels, + rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) + self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer) def forward(self, x): x = self.channel(x) @@ -87,9 +96,13 @@ class CbamModule(nn.Module): class LightCbamModule(nn.Module): - def __init__(self, channels, spatial_kernel_size=7): + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): super(LightCbamModule, self).__init__() - self.channel = LightChannelAttn(channels) + self.channel = LightChannelAttn( + channels, rd_ratio=rd_ratio, rd_channels=rd_channels, + rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) self.spatial = LightSpatialAttn(spatial_kernel_size) def forward(self, x): diff --git a/timm/models/layers/create_attn.py b/timm/models/layers/create_attn.py index ff20e5df..de866eea 100644 --- a/timm/models/layers/create_attn.py +++ b/timm/models/layers/create_attn.py @@ -3,9 +3,12 @@ Hacked together by / Copyright 2020 Ross Wightman """ import torch -from .se import SEModule, EffectiveSEModule -from .eca import EcaModule, CecaModule + from .cbam import CbamModule, LightCbamModule +from .eca import EcaModule, CecaModule +from .gather_excite import GatherExcite +from .global_context import GlobalContext +from .squeeze_excite import SEModule, EffectiveSEModule def get_attn(attn_type): @@ -23,6 +26,10 @@ def get_attn(attn_type): module_cls = EcaModule elif attn_type == 'ceca': module_cls = CecaModule + elif attn_type == 'ge': + module_cls = GatherExcite + elif attn_type == 'gc': + module_cls = GlobalContext elif attn_type == 'cbam': module_cls = CbamModule elif attn_type == 'lcbam': diff --git a/timm/models/layers/eca.py b/timm/models/layers/eca.py index 3a7f8b82..d0d8f74a 100644 --- a/timm/models/layers/eca.py +++ b/timm/models/layers/eca.py @@ -65,6 +65,9 @@ class EcaModule(nn.Module): return x * y.expand_as(x) +EfficientChannelAttn = EcaModule # alias + + class CecaModule(nn.Module): """Constructs a circular ECA module. @@ -105,3 +108,6 @@ class CecaModule(nn.Module): y = self.conv(y) y = y.view(x.shape[0], -1, 1, 1).sigmoid() return x * y.expand_as(x) + + +CircularEfficientChannelAttn = CecaModule \ No newline at end of file diff --git a/timm/models/layers/gather_excite.py b/timm/models/layers/gather_excite.py new file mode 100644 index 00000000..2d60dc96 --- /dev/null +++ b/timm/models/layers/gather_excite.py @@ -0,0 +1,90 @@ +""" Gather-Excite Attention Block + +Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348 + +Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet + +I've tried to support all of the extent both w/ and w/o params. I don't believe I've seen another +impl that covers all of the cases. + +NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation + +Hacked together by / Copyright 2021 Ross Wightman +""" +import math + +from torch import nn as nn +import torch.nn.functional as F + +from .create_act import create_act_layer, get_act_layer +from .create_conv2d import create_conv2d +from .helpers import make_divisible +from .mlp import ConvMlp + + +class GatherExcite(nn.Module): + """ Gather-Excite Attention Module + """ + def __init__( + self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True, + rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'): + super(GatherExcite, self).__init__() + self.add_maxpool = add_maxpool + act_layer = get_act_layer(act_layer) + self.extent = extent + if extra_params: + self.gather = nn.Sequential() + if extent == 0: + assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params' + self.gather.add_module( + 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True)) + if norm_layer: + self.gather.add_module(f'norm1', nn.BatchNorm2d(channels)) + else: + assert extent % 2 == 0 + num_conv = int(math.log2(extent)) + for i in range(num_conv): + self.gather.add_module( + f'conv{i + 1}', + create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True)) + if norm_layer: + self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels)) + if i != num_conv - 1: + self.gather.add_module(f'act{i + 1}', act_layer(inplace=True)) + else: + self.gather = None + if self.extent == 0: + self.gk = 0 + self.gs = 0 + else: + assert extent % 2 == 0 + self.gk = self.extent * 2 - 1 + self.gs = self.extent + + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity() + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + size = x.shape[-2:] + if self.gather is not None: + x_ge = self.gather(x) + else: + if self.extent == 0: + # global extent + x_ge = x.mean(dim=(2, 3), keepdims=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True) + else: + x_ge = F.avg_pool2d( + x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False) + if self.add_maxpool: + # experimental codepath, may remove or change + x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2) + x_ge = self.mlp(x_ge) + if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1: + x_ge = F.interpolate(x_ge, size=size) + return x * self.gate(x_ge) diff --git a/timm/models/layers/global_context.py b/timm/models/layers/global_context.py new file mode 100644 index 00000000..4c2c82f3 --- /dev/null +++ b/timm/models/layers/global_context.py @@ -0,0 +1,67 @@ +""" Global Context Attention Block + +Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond` + - https://arxiv.org/abs/1904.11492 + +Official code consulted as reference: https://github.com/xvjiarui/GCNet + +Hacked together by / Copyright 2021 Ross Wightman +""" +from torch import nn as nn +import torch.nn.functional as F + +from .create_act import create_act_layer, get_act_layer +from .helpers import make_divisible +from .mlp import ConvMlp +from .norm import LayerNorm2d + + +class GlobalContext(nn.Module): + + def __init__(self, channels, use_attn=True, fuse_add=True, fuse_scale=False, init_last_zero=False, + rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'): + super(GlobalContext, self).__init__() + act_layer = get_act_layer(act_layer) + + self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None + + if rd_channels is None: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + if fuse_add: + self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) + else: + self.mlp_add = None + if fuse_scale: + self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) + else: + self.mlp_scale = None + + self.gate = create_act_layer(gate_layer) + self.init_last_zero = init_last_zero + self.reset_parameters() + + def reset_parameters(self): + if self.conv_attn is not None: + nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu') + if self.mlp_add is not None: + nn.init.zeros_(self.mlp_add.fc2.weight) + + def forward(self, x): + B, C, H, W = x.shape + + if self.conv_attn is not None: + attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) + attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) + context = x.reshape(B, C, H * W).unsqueeze(1) @ attn + context = context.view(B, C, 1, 1) + else: + context = x.mean(dim=(2, 3), keepdim=True) + + if self.mlp_scale is not None: + mlp_x = self.mlp_scale(context) + x = x * self.gate(mlp_x) + if self.mlp_add is not None: + mlp_x = self.mlp_add(context) + x = x + mlp_x + + return x diff --git a/timm/models/layers/involution.py b/timm/models/layers/involution.py index 0dba9fae..ccdeefcb 100644 --- a/timm/models/layers/involution.py +++ b/timm/models/layers/involution.py @@ -16,7 +16,7 @@ class Involution(nn.Module): kernel_size=3, stride=1, group_size=16, - reduction_ratio=4, + rd_ratio=4, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, ): @@ -28,12 +28,12 @@ class Involution(nn.Module): self.groups = self.channels // self.group_size self.conv1 = ConvBnAct( in_channels=channels, - out_channels=channels // reduction_ratio, + out_channels=channels // rd_ratio, kernel_size=1, norm_layer=norm_layer, act_layer=act_layer) self.conv2 = self.conv = create_conv2d( - in_channels=channels // reduction_ratio, + in_channels=channels // rd_ratio, out_channels=kernel_size**2 * self.groups, kernel_size=1, stride=1) diff --git a/timm/models/layers/mlp.py b/timm/models/layers/mlp.py index b3f8de11..4739ba74 100644 --- a/timm/models/layers/mlp.py +++ b/timm/models/layers/mlp.py @@ -77,3 +77,26 @@ class GatedMlp(nn.Module): x = self.fc2(x) x = self.drop(x) return x + + +class ConvMlp(nn.Module): + """ MLP using 1x1 convs that keeps spatial dims + """ + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True) + self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.norm(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + return x diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index 2925e5c7..433552b4 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -12,3 +12,12 @@ class GroupNorm(nn.GroupNorm): def forward(self, x): return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + + +class LayerNorm2d(nn.LayerNorm): + """ Layernorm for channels of '2d' spatial BCHW tensors """ + def __init__(self, num_channels): + super().__init__([num_channels, 1, 1]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) diff --git a/timm/models/layers/se.py b/timm/models/layers/se.py deleted file mode 100644 index 4354144d..00000000 --- a/timm/models/layers/se.py +++ /dev/null @@ -1,50 +0,0 @@ -from torch import nn as nn -import torch.nn.functional as F - -from .create_act import create_act_layer -from .helpers import make_divisible - - -class SEModule(nn.Module): - """ SE Module as defined in original SE-Nets with a few additions - Additions include: - * min_channels can be specified to keep reduced channel count at a minimum (default: 8) - * divisor can be specified to keep channels rounded to specified values (default: 1) - * reduction channels can be specified directly by arg (if reduction_channels is set) - * reduction channels can be specified by float ratio (if reduction_ratio is set) - """ - def __init__(self, channels, reduction=16, act_layer=nn.ReLU, gate_layer='sigmoid', - reduction_ratio=None, reduction_channels=None, min_channels=8, divisor=1): - super(SEModule, self).__init__() - if reduction_channels is not None: - reduction_channels = reduction_channels # direct specification highest priority, no rounding/min done - elif reduction_ratio is not None: - reduction_channels = make_divisible(channels * reduction_ratio, divisor, min_channels) - else: - reduction_channels = make_divisible(channels // reduction, divisor, min_channels) - self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True) - self.act = act_layer(inplace=True) - self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True) - self.gate = create_act_layer(gate_layer) - - def forward(self, x): - x_se = x.mean((2, 3), keepdim=True) - x_se = self.fc1(x_se) - x_se = self.act(x_se) - x_se = self.fc2(x_se) - return x * self.gate(x_se) - - -class EffectiveSEModule(nn.Module): - """ 'Effective Squeeze-Excitation - From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 - """ - def __init__(self, channels, gate_layer='hard_sigmoid'): - super(EffectiveSEModule, self).__init__() - self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) - self.gate = create_act_layer(gate_layer) - - def forward(self, x): - x_se = x.mean((2, 3), keepdim=True) - x_se = self.fc(x_se) - return x * self.gate(x_se) diff --git a/timm/models/layers/squeeze_excite.py b/timm/models/layers/squeeze_excite.py new file mode 100644 index 00000000..3e8a05bb --- /dev/null +++ b/timm/models/layers/squeeze_excite.py @@ -0,0 +1,74 @@ +""" Squeeze-and-Excitation Channel Attention + +An SE implementation originally based on PyTorch SE-Net impl. +Has since evolved with additional functionality / configuration. + +Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507 + +Also included is Effective Squeeze-Excitation (ESE). +Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + +Hacked together by / Copyright 2021 Ross Wightman +""" +from torch import nn as nn + +from .create_act import create_act_layer +from .helpers import make_divisible + + +class SEModule(nn.Module): + """ SE Module as defined in original SE-Nets with a few additions + Additions include: + * divisor can be specified to keep channels % div == 0 (default: 8) + * reduction channels can be specified directly by arg (if rd_channels is set) + * reduction channels can be specified by float rd_ratio (default: 1/16) + * global max pooling can be added to the squeeze aggregation + * customizable activation, normalization, and gate layer + """ + def __init__( + self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, + act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'): + super(SEModule, self).__init__() + self.add_maxpool = add_maxpool + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True) + self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() + self.act = create_act_layer(act_layer, inplace=True) + self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) + x_se = self.fc1(x_se) + x_se = self.act(self.bn(x_se)) + x_se = self.fc2(x_se) + return x * self.gate(x_se) + + +SqueezeExcite = SEModule # alias + + +class EffectiveSEModule(nn.Module): + """ 'Effective Squeeze-Excitation + From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + """ + def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid'): + super(EffectiveSEModule, self).__init__() + self.add_maxpool = add_maxpool + self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) + x_se = self.fc(x_se) + return x * self.gate(x_se) + + +EffectiveSqueezeExcite = EffectiveSEModule # alias diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 1b67581e..593796a5 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -182,7 +182,7 @@ def _nfres_cfg( def _nfreg_cfg(depths, channels=(48, 104, 208, 440)): num_features = 1280 * channels[-1] // 440 - attn_kwargs = dict(reduction_ratio=0.5, divisor=8) + attn_kwargs = dict(rd_ratio=0.5) cfg = NfCfg( depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25, num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs) @@ -193,7 +193,7 @@ def _nfnet_cfg( depths, channels=(256, 512, 1536, 1536), group_size=128, bottle_ratio=0.5, feat_mult=2., act_layer='gelu', attn_layer='se', attn_kwargs=None): num_features = int(channels[-1] * feat_mult) - attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(reduction_ratio=0.5, divisor=8) + attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(rd_ratio=0.5) cfg = NfCfg( depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=group_size, bottle_ratio=bottle_ratio, extra_conv=True, num_features=num_features, act_layer=act_layer, @@ -202,11 +202,10 @@ def _nfnet_cfg( def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True): - attn_kwargs = dict(reduction_ratio=0.5, divisor=8) cfg = NfCfg( depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128, bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit, - num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=attn_kwargs) + num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=dict(rd_ratio=0.5)) return cfg @@ -243,7 +242,7 @@ model_cfgs = dict( # Experimental 'light' versions of NFNet-F that are little leaner nfnet_l0=_nfnet_cfg( depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25, - attn_kwargs=dict(reduction_ratio=0.25, divisor=8), act_layer='silu'), + attn_kwargs=dict(rd_ratio=0.25, rd_divisor=8), act_layer='silu'), eca_nfnet_l0=_nfnet_cfg( depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25, attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), @@ -272,9 +271,9 @@ model_cfgs = dict( nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)), nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)), - nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)), - nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)), - nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)), + nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)), + nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)), + nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)), nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()), nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()), diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 3b7dba52..6a381074 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -146,7 +146,7 @@ class Bottleneck(nn.Module): groups=groups, **cargs) if se_ratio: se_channels = int(round(in_chs * se_ratio)) - self.se = SEModule(bottleneck_chs, reduction_channels=se_channels) + self.se = SEModule(bottleneck_chs, rd_channels=se_channels) else: self.se = None cargs['act_layer'] = None diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 2b0b0339..2f02f12a 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -1122,7 +1122,7 @@ def resnetrs50(pretrained=False, **kwargs): Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs """ - attn_layer = partial(get_attn('se'), reduction_ratio=0.25) + attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) @@ -1135,7 +1135,7 @@ def resnetrs101(pretrained=False, **kwargs): Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs """ - attn_layer = partial(get_attn('se'), reduction_ratio=0.25) + attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) @@ -1148,7 +1148,7 @@ def resnetrs152(pretrained=False, **kwargs): Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs """ - attn_layer = partial(get_attn('se'), reduction_ratio=0.25) + attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) @@ -1161,7 +1161,7 @@ def resnetrs200(pretrained=False, **kwargs): Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs """ - attn_layer = partial(get_attn('se'), reduction_ratio=0.25) + attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) @@ -1174,7 +1174,7 @@ def resnetrs270(pretrained=False, **kwargs): Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs """ - attn_layer = partial(get_attn('se'), reduction_ratio=0.25) + attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) @@ -1188,7 +1188,7 @@ def resnetrs350(pretrained=False, **kwargs): Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs """ - attn_layer = partial(get_attn('se'), reduction_ratio=0.25) + attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) @@ -1201,7 +1201,7 @@ def resnetrs420(pretrained=False, **kwargs): Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs """ - attn_layer = partial(get_attn('se'), reduction_ratio=0.25) + attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index 859b584e..7ab8d659 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -11,11 +11,12 @@ Copyright 2020 Ross Wightman """ import torch.nn as nn +from functools import partial from math import ceil from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible +from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible, SEModule from .registry import register_model from .efficientnet_builder import efficientnet_init_weights @@ -48,26 +49,7 @@ default_cfgs = dict( url=''), ) - -class SEWithNorm(nn.Module): - - def __init__(self, channels, se_ratio=1 / 12., act_layer=nn.ReLU, divisor=1, reduction_channels=None, - gate_layer='sigmoid'): - super(SEWithNorm, self).__init__() - reduction_channels = reduction_channels or make_divisible(int(channels * se_ratio), divisor=divisor) - self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True) - self.bn = nn.BatchNorm2d(reduction_channels) - self.act = act_layer(inplace=True) - self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True) - self.gate = create_act_layer(gate_layer) - - def forward(self, x): - x_se = x.mean((2, 3), keepdim=True) - x_se = self.fc1(x_se) - x_se = self.bn(x_se) - x_se = self.act(x_se) - x_se = self.fc2(x_se) - return x * self.gate(x_se) +SEWithNorm = partial(SEModule, norm_layer=nn.BatchNorm2d) class LinearBottleneck(nn.Module): @@ -86,7 +68,10 @@ class LinearBottleneck(nn.Module): self.conv_exp = None self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False) - self.se = SEWithNorm(dw_chs, se_ratio=se_ratio, divisor=ch_div) if se_ratio > 0. else None + if se_ratio > 0: + self.se = SEWithNorm(dw_chs, rd_channels=make_divisible(int(dw_chs * se_ratio), ch_div)) + else: + self.se = None self.act_dw = create_act_layer(dw_act_layer) self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False) diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 9fb34c20..372bfb7b 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -84,8 +84,8 @@ class BasicBlock(nn.Module): self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride - reduction_chs = max(planes * self.expansion // 4, 64) - self.se = SEModule(planes * self.expansion, reduction_channels=reduction_chs) if use_se else None + rd_chs = max(planes * self.expansion // 4, 64) + self.se = SEModule(planes * self.expansion, rd_channels=rd_chs) if use_se else None def forward(self, x): if self.downsample is not None: @@ -125,7 +125,7 @@ class Bottleneck(nn.Module): aa_layer(channels=planes, filt_size=3, stride=2)) reduction_chs = max(planes * self.expansion // 8, 64) - self.se = SEModule(planes, reduction_channels=reduction_chs) if use_se else None + self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None self.conv3 = conv2d_iabn( planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity") diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 33a2fe87..5583ea3c 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -13,7 +13,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, overlay_external_default_cfg -from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed +from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d from .registry import register_model @@ -39,15 +39,6 @@ default_cfgs = dict( ) -class LayerNormBHWC(nn.LayerNorm): - def __init__(self, dim): - super().__init__(dim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return F.layer_norm( - x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) - - class SpatialMlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., group=8, spatial_conv=False): @@ -119,7 +110,7 @@ class Attention(nn.Module): class Block(nn.Module): def __init__(self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4., - drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormBHWC, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm2d, group=8, attn_disabled=False, spatial_conv=False): super().__init__() self.spatial_conv = spatial_conv @@ -148,7 +139,7 @@ class Block(nn.Module): class Visformer(nn.Module): def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0., - norm_layer=LayerNormBHWC, attn_stage='111', pos_embed=True, spatial_conv='111', + norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111', vit_stem=False, group=8, pool=True, conv_init=False, embed_norm=None): super().__init__() self.num_classes = num_classes From f615474be317b1e015c082b7dabd391f461c10b7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 27 May 2021 18:12:22 -0700 Subject: [PATCH 37/48] Fix broken test, repvgg block doesn't have attn_last attr. --- timm/models/byobnet.py | 5 ++--- timm/models/layers/eca.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 3f162c79..aab44365 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -573,9 +573,8 @@ class RepVggBlock(nn.Module): if isinstance(m, nn.BatchNorm2d): nn.init.normal_(m.weight, .1, .1) nn.init.normal_(m.bias, 0, .1) - for attn in (self.attn, self.attn_last): - if hasattr(attn, 'reset_parameters'): - attn.reset_parameters() + if hasattr(self.attn, 'reset_parameters'): + self.attn.reset_parameters() def forward(self, x): if self.identity is None: diff --git a/timm/models/layers/eca.py b/timm/models/layers/eca.py index d0d8f74a..f2980730 100644 --- a/timm/models/layers/eca.py +++ b/timm/models/layers/eca.py @@ -110,4 +110,4 @@ class CecaModule(nn.Module): return x * y.expand_as(x) -CircularEfficientChannelAttn = CecaModule \ No newline at end of file +CircularEfficientChannelAttn = CecaModule From 02f9d4bc34d8fda03903fe2d8e6f3599e3f1fd38 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 28 May 2021 09:53:16 -0700 Subject: [PATCH 38/48] Add weights for resnet51q model, add 61q def. --- timm/models/byobnet.py | 271 +++++++++++++++++++++++------------------ 1 file changed, 155 insertions(+), 116 deletions(-) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index aab44365..8214b490 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -88,9 +88,16 @@ default_cfgs = { first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')), # experimental configs - 'resnet52qs': _cfg(first_conv='stem.conv1.conv'), - 'geresnet50t': _cfg(first_conv='stem.conv1.conv'), - 'gcresnet50t': _cfg(first_conv='stem.conv1.conv'), + 'resnet51q': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet51q_ra2-d47dcc76.pth', + first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), + test_input_size=(3, 288, 288), crop_pct=1.0), + 'resnet61q': _cfg( + first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), + 'geresnet50t': _cfg( + first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), + 'gcresnet50t': _cfg( + first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), } @@ -241,7 +248,7 @@ model_cfgs = dict( ), # WARN: experimental, may vanish/change - resnet52q=ByoModelCfg( + resnet51q=ByoModelCfg( blocks=( ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25), ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25), @@ -249,9 +256,25 @@ model_cfgs = dict( ByoBlockCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0), ), stem_chs=128, + stem_type='quad2', + stem_pool=None, + num_features=2048, + act_layer='silu', + ), + + resnet61q=ByoModelCfg( + blocks=( + ByoBlockCfg(type='edge', d=1, c=256, s=1, gs=0, br=1.0, block_kwargs=dict()), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25), + ByoBlockCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0), + ), + stem_chs=128, stem_type='quad', + stem_pool=None, num_features=2048, act_layer='silu', + block_kwargs=dict(extra_conv=True), ), # WARN: experimental, may vanish/change @@ -287,6 +310,122 @@ model_cfgs = dict( ) +@register_model +def gernet_l(pretrained=False, **kwargs): + """ GEResNet-Large (GENet-Large from official impl) + `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090 + """ + return _create_byobnet('gernet_l', pretrained=pretrained, **kwargs) + + +@register_model +def gernet_m(pretrained=False, **kwargs): + """ GEResNet-Medium (GENet-Normal from official impl) + `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090 + """ + return _create_byobnet('gernet_m', pretrained=pretrained, **kwargs) + + +@register_model +def gernet_s(pretrained=False, **kwargs): + """ EResNet-Small (GENet-Small from official impl) + `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090 + """ + return _create_byobnet('gernet_s', pretrained=pretrained, **kwargs) + + +@register_model +def repvgg_a2(pretrained=False, **kwargs): + """ RepVGG-A2 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_a2', pretrained=pretrained, **kwargs) + + +@register_model +def repvgg_b0(pretrained=False, **kwargs): + """ RepVGG-B0 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_b0', pretrained=pretrained, **kwargs) + + +@register_model +def repvgg_b1(pretrained=False, **kwargs): + """ RepVGG-B1 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_b1', pretrained=pretrained, **kwargs) + + +@register_model +def repvgg_b1g4(pretrained=False, **kwargs): + """ RepVGG-B1g4 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_b1g4', pretrained=pretrained, **kwargs) + + +@register_model +def repvgg_b2(pretrained=False, **kwargs): + """ RepVGG-B2 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_b2', pretrained=pretrained, **kwargs) + + +@register_model +def repvgg_b2g4(pretrained=False, **kwargs): + """ RepVGG-B2g4 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_b2g4', pretrained=pretrained, **kwargs) + + +@register_model +def repvgg_b3(pretrained=False, **kwargs): + """ RepVGG-B3 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_b3', pretrained=pretrained, **kwargs) + + +@register_model +def repvgg_b3g4(pretrained=False, **kwargs): + """ RepVGG-B3g4 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs) + + +@register_model +def resnet51q(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('resnet51q', pretrained=pretrained, **kwargs) + + +@register_model +def resnet61q(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('resnet61q', pretrained=pretrained, **kwargs) + + +@register_model +def geresnet50t(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('geresnet50t', pretrained=pretrained, **kwargs) + + +@register_model +def gcresnet50t(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('gcresnet50t', pretrained=pretrained, **kwargs) + + def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]: if not isinstance(stage_blocks_cfg, Sequence): stage_blocks_cfg = (stage_blocks_cfg,) @@ -391,8 +530,8 @@ class BottleneckBlock(nn.Module): """ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None, - downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None, drop_block=None, - drop_path_rate=0.): + downsample='avg', attn_last=False, linear_out=False, extra_conv=False, layers: LayerFn = None, + drop_block=None, drop_path_rate=0.): super(BottleneckBlock, self).__init__() layers = layers or LayerFn() mid_chs = make_divisible(out_chs * bottle_ratio) @@ -409,6 +548,14 @@ class BottleneckBlock(nn.Module): self.conv2_kxk = layers.conv_norm_act( mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], groups=groups, drop_block=drop_block) + self.conv2_kxk = layers.conv_norm_act( + mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], + groups=groups, drop_block=drop_block) + if extra_conv: + self.conv2b_kxk = layers.conv_norm_act( + mid_chs, mid_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block) + else: + self.conv2b_kxk = nn.Identity() self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False) self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) @@ -427,6 +574,7 @@ class BottleneckBlock(nn.Module): x = self.conv1_1x1(x) x = self.conv2_kxk(x) + x = self.conv2b_kxk(x) x = self.attn(x) x = self.conv3_1x1(x) x = self.attn_last(x) @@ -714,7 +862,7 @@ class Stem(nn.Sequential): def create_byob_stem(in_chs, out_chs, stem_type='', pool_type='', feat_prefix='stem', layers: LayerFn = None): layers = layers or LayerFn() - assert stem_type in ('', 'quad', 'tiered', 'deep', 'rep', '7x7', '3x3') + assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', '7x7', '3x3') if 'quad' in stem_type: # based on NFNet stem, stack of 4 3x3 convs num_act = 2 if 'quad2' in stem_type else None @@ -955,112 +1103,3 @@ def _create_byobnet(variant, pretrained=False, **kwargs): model_cfg=model_cfgs[variant], feature_cfg=dict(flatten_sequential=True), **kwargs) - - -@register_model -def gernet_l(pretrained=False, **kwargs): - """ GEResNet-Large (GENet-Large from official impl) - `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090 - """ - return _create_byobnet('gernet_l', pretrained=pretrained, **kwargs) - - -@register_model -def gernet_m(pretrained=False, **kwargs): - """ GEResNet-Medium (GENet-Normal from official impl) - `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090 - """ - return _create_byobnet('gernet_m', pretrained=pretrained, **kwargs) - - -@register_model -def gernet_s(pretrained=False, **kwargs): - """ EResNet-Small (GENet-Small from official impl) - `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090 - """ - return _create_byobnet('gernet_s', pretrained=pretrained, **kwargs) - - -@register_model -def repvgg_a2(pretrained=False, **kwargs): - """ RepVGG-A2 - `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 - """ - return _create_byobnet('repvgg_a2', pretrained=pretrained, **kwargs) - - -@register_model -def repvgg_b0(pretrained=False, **kwargs): - """ RepVGG-B0 - `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 - """ - return _create_byobnet('repvgg_b0', pretrained=pretrained, **kwargs) - - -@register_model -def repvgg_b1(pretrained=False, **kwargs): - """ RepVGG-B1 - `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 - """ - return _create_byobnet('repvgg_b1', pretrained=pretrained, **kwargs) - - -@register_model -def repvgg_b1g4(pretrained=False, **kwargs): - """ RepVGG-B1g4 - `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 - """ - return _create_byobnet('repvgg_b1g4', pretrained=pretrained, **kwargs) - - -@register_model -def repvgg_b2(pretrained=False, **kwargs): - """ RepVGG-B2 - `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 - """ - return _create_byobnet('repvgg_b2', pretrained=pretrained, **kwargs) - - -@register_model -def repvgg_b2g4(pretrained=False, **kwargs): - """ RepVGG-B2g4 - `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 - """ - return _create_byobnet('repvgg_b2g4', pretrained=pretrained, **kwargs) - - -@register_model -def repvgg_b3(pretrained=False, **kwargs): - """ RepVGG-B3 - `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 - """ - return _create_byobnet('repvgg_b3', pretrained=pretrained, **kwargs) - - -@register_model -def repvgg_b3g4(pretrained=False, **kwargs): - """ RepVGG-B3g4 - `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 - """ - return _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs) - - -@register_model -def resnet52q(pretrained=False, **kwargs): - """ - """ - return _create_byobnet('geresnet50t', pretrained=pretrained, **kwargs) - - -@register_model -def geresnet50t(pretrained=False, **kwargs): - """ - """ - return _create_byobnet('geresnet50t', pretrained=pretrained, **kwargs) - - -@register_model -def gcresnet50t(pretrained=False, **kwargs): - """ - """ - return _create_byobnet('gcresnet50t', pretrained=pretrained, **kwargs) From d7bab8a6c52a72487d1bed0a28aad41e326d7622 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 28 May 2021 09:54:50 -0700 Subject: [PATCH 39/48] Fix strict flag change for checkpoint load. --- timm/models/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index dfb6b860..adfef550 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -44,7 +44,7 @@ def load_state_dict(checkpoint_path, use_ema=False): raise FileNotFoundError() -def load_checkpoint(model, checkpoint_path, use_ema=False, strict=False): +def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): state_dict = load_state_dict(checkpoint_path, use_ema) model.load_state_dict(state_dict, strict=strict) From 9611458e199793a5a46c3fb5ce7031195e16bfbe Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 28 May 2021 20:47:24 -0700 Subject: [PATCH 40/48] Throw in some FBNetV3 code I had lying around, some refactoring of SE reduction channel calcs for all EffNet archs. --- timm/models/byobnet.py | 2 +- timm/models/efficientnet_blocks.py | 16 ++-- timm/models/efficientnet_builder.py | 21 ++--- timm/models/ghostnet.py | 2 +- timm/models/hardcorenas.py | 4 +- timm/models/layers/helpers.py | 2 +- timm/models/mobilenetv3.py | 119 ++++++++++++++++++++++++++-- 7 files changed, 135 insertions(+), 31 deletions(-) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 8214b490..8ec8690a 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -90,7 +90,7 @@ default_cfgs = { # experimental configs 'resnet51q': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet51q_ra2-d47dcc76.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), + first_conv='stem.conv1', input_size=(3, 256, 256), pool_size=(8, 8), test_input_size=(3, 288, 288), crop_pct=1.0), 'resnet61q': _cfg( first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index 7853db0e..ea0c791e 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -22,18 +22,16 @@ class SqueezeExcite(nn.Module): se_ratio (float): ratio of squeeze reduction act_layer (nn.Module): activation layer of containing block gate_fn (Callable): attention gate function - block_in_chs (int): input channels of containing block (for calculating reduction from) - reduce_from_block (bool): calculate reduction from block input channels if True force_act_layer (nn.Module): override block's activation fn if this is set/bound - divisor (int): make reduction channels divisible by this + round_chs_fn (Callable): specify a fn to calculate rounding of reduced chs """ def __init__( self, in_chs, se_ratio=0.25, act_layer=nn.ReLU, gate_fn=sigmoid, - block_in_chs=None, reduce_from_block=True, force_act_layer=None, divisor=1): + force_act_layer=None, round_chs_fn=None): super(SqueezeExcite, self).__init__() - reduced_chs = (block_in_chs or in_chs) if reduce_from_block else in_chs - reduced_chs = make_divisible(reduced_chs * se_ratio, divisor) + round_chs_fn = round_chs_fn or round + reduced_chs = round_chs_fn(in_chs * se_ratio) act_layer = force_act_layer or act_layer self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) self.act1 = create_act_layer(act_layer, inplace=True) @@ -168,8 +166,7 @@ class InvertedResidual(nn.Module): self.act2 = act_layer(inplace=True) # Squeeze-and-excitation - self.se = se_layer( - mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity() + self.se = se_layer(mid_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity() # Point-wise linear projection self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) @@ -292,8 +289,7 @@ class EdgeResidual(nn.Module): self.act1 = act_layer(inplace=True) # Squeeze-and-excitation - self.se = SqueezeExcite( - mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity() + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity() # Point-wise linear projection self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type) diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index 57e2039b..35019747 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -265,11 +265,12 @@ class EfficientNetBuilder: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py """ - def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels, + def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=False, act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''): self.output_stride = output_stride self.pad_type = pad_type self.round_chs_fn = round_chs_fn + self.se_from_exp = se_from_exp # calculate se channel reduction from expanded (mid) chs self.act_layer = act_layer self.norm_layer = norm_layer self.se_layer = se_layer @@ -301,6 +302,8 @@ class EfficientNetBuilder: ba['norm_layer'] = self.norm_layer if bt != 'cn': ba['se_layer'] = self.se_layer + if not self.se_from_exp and ba['se_ratio']: + ba['se_ratio'] /= ba.get('exp_ratio', 1.0) ba['drop_path_rate'] = drop_path_rate if bt == 'ir': @@ -418,28 +421,28 @@ def _init_weight_goog(m, n='', fix_group_fanout=True): if fix_group_fanout: fan_out //= m.groups init_weight_fn = get_condconv_initializer( - lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) + lambda w: nn.init.normal_(w, 0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) init_weight_fn(m.weight) if m.bias is not None: - m.bias.data.zero_() + nn.init.zeros_(m.bias) elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels if fix_group_fanout: fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out)) if m.bias is not None: - m.bias.data.zero_() + nn.init.zeros_(m.bias) elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1.0) - m.bias.data.zero_() + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): fan_out = m.weight.size(0) # fan-out fan_in = 0 if 'routing_fn' in n: fan_in = m.weight.size(1) init_range = 1.0 / math.sqrt(fan_in + fan_out) - m.weight.data.uniform_(-init_range, init_range) - m.bias.data.zero_() + nn.init.uniform_(m.weight, -init_range, init_range) + nn.init.zeros_(m.bias) def efficientnet_init_weights(model: nn.Module, init_fn=None): diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index 1783ff7a..d82a91b4 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -40,7 +40,7 @@ default_cfgs = { } -_SE_LAYER = partial(SqueezeExcite, gate_fn='hard_sigmoid', divisor=4) +_SE_LAYER = partial(SqueezeExcite, gate_fn='hard_sigmoid', round_chs_fn=partial(make_divisible, divisor=4)) class GhostModule(nn.Module): diff --git a/timm/models/hardcorenas.py b/timm/models/hardcorenas.py index 231bb4b6..16b9c4bc 100644 --- a/timm/models/hardcorenas.py +++ b/timm/models/hardcorenas.py @@ -4,7 +4,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .efficientnet_blocks import SqueezeExcite -from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args +from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels from .helpers import build_model_with_cfg, default_cfg_for_features from .layers import get_act_fn from .mobilenetv3 import MobileNetV3, MobileNetV3Features @@ -40,7 +40,7 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs): """ num_features = 1280 se_layer = partial( - SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, reduce_from_block=False, divisor=8) + SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, round_chs_fn=round_channels) model_kwargs = dict( block_args=decode_arch_def(arch_def), num_features=num_features, diff --git a/timm/models/layers/helpers.py b/timm/models/layers/helpers.py index 64573ef6..cc54ca7f 100644 --- a/timm/models/layers/helpers.py +++ b/timm/models/layers/helpers.py @@ -28,4 +28,4 @@ def make_divisible(v, divisor=8, min_value=None, round_limit=.9): # Make sure that round down does not go down by more than 10%. if new_v < round_limit * v: new_v += divisor - return new_v \ No newline at end of file + return new_v diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 9afa3d75..fad88aa7 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -72,6 +72,10 @@ default_cfgs = { 'tf_mobilenetv3_small_minimal_100': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + + 'fbnetv3_b': _cfg(), + 'fbnetv3_d': _cfg(), + 'fbnetv3_g': _cfg(), } @@ -86,7 +90,7 @@ class MobileNetV3(nn.Module): """ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True, - pad_type='', act_layer=None, norm_layer=None, se_layer=None, + pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True, round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'): super(MobileNetV3, self).__init__() act_layer = act_layer or nn.ReLU @@ -104,7 +108,7 @@ class MobileNetV3(nn.Module): # Middle stages (IR/ER/DS Blocks) builder = EfficientNetBuilder( - output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn, + output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp, act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = builder.features @@ -161,8 +165,8 @@ class MobileNetV3Features(nn.Module): and object detection models. """ - def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', - in_chans=3, stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels, + def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3, + stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=True, act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.): super(MobileNetV3Features, self).__init__() act_layer = act_layer or nn.ReLU @@ -178,7 +182,7 @@ class MobileNetV3Features(nn.Module): # Middle stages (IR/ER/DS Blocks) builder = EfficientNetBuilder( - output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, + output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp, act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate, feature_location=feature_location) self.blocks = nn.Sequential(*builder(stem_size, block_args)) @@ -262,7 +266,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw round_chs_fn=partial(round_channels, multiplier=channel_multiplier), norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), act_layer=resolve_act_layer(kwargs, 'hard_swish'), - se_layer=partial(SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), reduce_from_block=False), + se_layer=partial(SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid')), **kwargs, ) model = _create_mnv3(variant, pretrained, **model_kwargs) @@ -351,7 +355,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg ['cn_r1_k1_s1_c960'], # hard-swish ] se_layer = partial( - SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, reduce_from_block=False, divisor=8) + SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, round_chs_fn=round_channels) model_kwargs = dict( block_args=decode_arch_def(arch_def), num_features=num_features, @@ -366,6 +370,86 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg return model +def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """ FBNetV3 + FIXME untested, this is a preliminary impl of some FBNet-V3 variants. + """ + vl = variant.split('_')[-1] + if vl in ('a', 'b'): + stem_size = 16 + arch_def = [ + # stage 0, 112x112 in + ['ds_r2_k3_s1_e1_c16'], + # stage 1, 112x112 in + ['ir_r1_k5_s2_e4_c24', 'ir_r3_k5_s1_e2_c24'], + # stage 2, 56x56 in + ['ir_r1_k5_s2_e5_c40_se0.25', 'ir_r4_k5_s1_e3_c40_se0.25'], + # stage 3, 28x28 in + ['ir_r1_k5_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'], + # stage 4, 14x14in + ['ir_r1_k3_s1_e5_c120_se0.25', 'ir_r5_k5_s1_e3_c120_se0.25'], + # stage 5, 14x14in + ['ir_r1_k3_s2_e6_c184_se0.25', 'ir_r5_k5_s1_e4_c184_se0.25', 'ir_r1_k5_s1_e6_c224_se0.25'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c1344'], + ] + elif vl == 'd': + stem_size = 24 + arch_def = [ + # stage 0, 112x112 in + ['ds_r2_k3_s1_e1_c16'], + # stage 1, 112x112 in + ['ir_r1_k3_s2_e5_c24', 'ir_r5_k3_s1_e2_c24'], + # stage 2, 56x56 in + ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r4_k3_s1_e3_c40_se0.25'], + # stage 3, 28x28 in + ['ir_r1_k3_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'], + # stage 4, 14x14in + ['ir_r1_k3_s1_e5_c128_se0.25', 'ir_r6_k5_s1_e3_c128_se0.25'], + # stage 5, 14x14in + ['ir_r1_k3_s2_e6_c208_se0.25', 'ir_r5_k5_s1_e5_c208_se0.25', 'ir_r1_k5_s1_e6_c240_se0.25'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c1440'], + ] + elif vl == 'g': + stem_size = 32 + arch_def = [ + # stage 0, 112x112 in + ['ds_r3_k3_s1_e1_c24'], + # stage 1, 112x112 in + ['ir_r1_k5_s2_e4_c40', 'ir_r4_k5_s1_e2_c40'], + # stage 2, 56x56 in + ['ir_r1_k5_s2_e4_c56_se0.25', 'ir_r4_k5_s1_e3_c56_se0.25'], + # stage 3, 28x28 in + ['ir_r1_k5_s2_e5_c104', 'ir_r4_k3_s1_e3_c104'], + # stage 4, 14x14in + ['ir_r1_k3_s1_e5_c160_se0.25', 'ir_r8_k5_s1_e3_c160_se0.25'], + # stage 5, 14x14in + ['ir_r1_k3_s2_e6_c264_se0.25', 'ir_r6_k5_s1_e5_c264_se0.25', 'ir_r2_k5_s1_e6_c288_se0.25'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c1728'], # hard-swish + ] + else: + raise NotImplemented + round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.95) + se_layer = partial(SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), round_chs_fn=round_chs_fn) + act_layer = resolve_act_layer(kwargs, 'hard_swish') + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=1984, + head_bias=False, + stem_size=stem_size, + round_chs_fn=round_chs_fn, + se_from_exp=False, + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=act_layer, + se_layer=se_layer, + **kwargs, + ) + model = _create_mnv3(variant, pretrained, **model_kwargs) + return model + + @register_model def mobilenetv3_large_075(pretrained=False, **kwargs): """ MobileNet V3 """ @@ -474,3 +558,24 @@ def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs): kwargs['pad_type'] = 'same' model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) return model + + +@register_model +def fbnetv3_b(pretrained=False, **kwargs): + """ FBNetV3-B """ + model = _gen_fbnetv3('fbnetv3_b', pretrained=pretrained, **kwargs) + return model + + +@register_model +def fbnetv3_d(pretrained=False, **kwargs): + """ FBNetV3-D """ + model = _gen_fbnetv3('fbnetv3_d', pretrained=pretrained, **kwargs) + return model + + +@register_model +def fbnetv3_g(pretrained=False, **kwargs): + """ FBNetV3-G """ + model = _gen_fbnetv3('fbnetv3_g', pretrained=pretrained, **kwargs) + return model From bcec14d3b585d7b5c469705f99fb2d830bdcdb7d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 29 May 2021 23:41:38 -0700 Subject: [PATCH 41/48] Bring EfficientNet SE layer in line with others, pull se_ratio outside of blocks. Allows swapping w/ other attn layers. --- timm/models/efficientnet_blocks.py | 53 ++++++++++++++--------------- timm/models/efficientnet_builder.py | 33 ++++++++++++------ timm/models/ghostnet.py | 4 +-- timm/models/hardcorenas.py | 3 +- timm/models/mobilenetv3.py | 32 ++++------------- 5 files changed, 57 insertions(+), 68 deletions(-) diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index ea0c791e..b43f38f5 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn from torch.nn import functional as F -from .layers import create_conv2d, drop_path, make_divisible, get_act_fn, create_act_layer +from .layers import create_conv2d, drop_path, make_divisible, create_act_layer from .layers.activations import sigmoid __all__ = [ @@ -19,31 +19,32 @@ class SqueezeExcite(nn.Module): Args: in_chs (int): input channels to layer - se_ratio (float): ratio of squeeze reduction + rd_ratio (float): ratio of squeeze reduction act_layer (nn.Module): activation layer of containing block - gate_fn (Callable): attention gate function + gate_layer (Callable): attention gate function force_act_layer (nn.Module): override block's activation fn if this is set/bound - round_chs_fn (Callable): specify a fn to calculate rounding of reduced chs + rd_round_fn (Callable): specify a fn to calculate rounding of reduced chs """ def __init__( - self, in_chs, se_ratio=0.25, act_layer=nn.ReLU, gate_fn=sigmoid, - force_act_layer=None, round_chs_fn=None): + self, in_chs, rd_ratio=0.25, rd_channels=None, act_layer=nn.ReLU, + gate_layer=nn.Sigmoid, force_act_layer=None, rd_round_fn=None): super(SqueezeExcite, self).__init__() - round_chs_fn = round_chs_fn or round - reduced_chs = round_chs_fn(in_chs * se_ratio) + if rd_channels is None: + rd_round_fn = rd_round_fn or round + rd_channels = rd_round_fn(in_chs * rd_ratio) act_layer = force_act_layer or act_layer - self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) + self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True) self.act1 = create_act_layer(act_layer, inplace=True) - self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) - self.gate_fn = get_act_fn(gate_fn) + self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True) + self.gate = create_act_layer(gate_layer) def forward(self, x): x_se = x.mean((2, 3), keepdim=True) x_se = self.conv_reduce(x_se) x_se = self.act1(x_se) x_se = self.conv_expand(x_se) - return x * self.gate_fn(x_se) + return x * self.gate(x_se) class ConvBnAct(nn.Module): @@ -85,10 +86,9 @@ class DepthwiseSeparableConv(nn.Module): """ def __init__( self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='', - noskip=False, pw_kernel_size=1, pw_act=False, se_ratio=0., - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.): + noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + se_layer=None, drop_path_rate=0.): super(DepthwiseSeparableConv, self).__init__() - has_se = se_layer is not None and se_ratio > 0. self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip self.has_pw_act = pw_act # activation after point-wise conv self.drop_path_rate = drop_path_rate @@ -99,7 +99,7 @@ class DepthwiseSeparableConv(nn.Module): self.act1 = act_layer(inplace=True) # Squeeze-and-excitation - self.se = se_layer(in_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity() + self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity() self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) self.bn2 = norm_layer(out_chs) @@ -144,12 +144,11 @@ class InvertedResidual(nn.Module): def __init__( self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='', - noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0., - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.): + noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.): super(InvertedResidual, self).__init__() conv_kwargs = conv_kwargs or {} mid_chs = make_divisible(in_chs * exp_ratio) - has_se = se_layer is not None and se_ratio > 0. self.has_residual = (in_chs == out_chs and stride == 1) and not noskip self.drop_path_rate = drop_path_rate @@ -166,7 +165,7 @@ class InvertedResidual(nn.Module): self.act2 = act_layer(inplace=True) # Squeeze-and-excitation - self.se = se_layer(mid_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity() + self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() # Point-wise linear projection self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) @@ -212,8 +211,8 @@ class CondConvResidual(InvertedResidual): def __init__( self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='', - noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0., - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.): + noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.): self.num_experts = num_experts conv_kwargs = dict(num_experts=self.num_experts) @@ -221,8 +220,8 @@ class CondConvResidual(InvertedResidual): super(CondConvResidual, self).__init__( in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type, act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, - pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_layer=se_layer, - norm_layer=norm_layer, conv_kwargs=conv_kwargs, drop_path_rate=drop_path_rate) + pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs, + drop_path_rate=drop_path_rate) self.routing_fn = nn.Linear(in_chs, self.num_experts) @@ -271,8 +270,8 @@ class EdgeResidual(nn.Module): def __init__( self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, pad_type='', - force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, se_ratio=0., - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.): + force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.): super(EdgeResidual, self).__init__() if force_in_chs > 0: mid_chs = make_divisible(force_in_chs * exp_ratio) @@ -289,7 +288,7 @@ class EdgeResidual(nn.Module): self.act1 = act_layer(inplace=True) # Squeeze-and-excitation - self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity() + self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() # Point-wise linear projection self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type) diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index 35019747..f44cf158 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -10,11 +10,12 @@ import logging import math import re from copy import deepcopy +from functools import partial import torch.nn as nn from .efficientnet_blocks import * -from .layers import CondConv2d, get_condconv_initializer, get_act_layer, make_divisible +from .layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible __all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights", 'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT'] @@ -120,7 +121,9 @@ def _decode_block_str(block_str): elif v == 'hs': value = get_act_layer('hard_swish') elif v == 'sw': - value = get_act_layer('swish') + value = get_act_layer('swish') # aka SiLU + elif v == 'mi': + value = get_act_layer('mish') else: continue options[key] = value @@ -273,7 +276,12 @@ class EfficientNetBuilder: self.se_from_exp = se_from_exp # calculate se channel reduction from expanded (mid) chs self.act_layer = act_layer self.norm_layer = norm_layer - self.se_layer = se_layer + self.se_layer = get_attn(se_layer) + try: + self.se_layer(8, rd_ratio=1.0) + self.se_has_ratio = True + except RuntimeError as e: + self.se_has_ratio = False self.drop_path_rate = drop_path_rate if feature_location == 'depthwise': # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense @@ -300,18 +308,21 @@ class EfficientNetBuilder: ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer assert ba['act_layer'] is not None ba['norm_layer'] = self.norm_layer + ba['drop_path_rate'] = drop_path_rate if bt != 'cn': - ba['se_layer'] = self.se_layer - if not self.se_from_exp and ba['se_ratio']: - ba['se_ratio'] /= ba.get('exp_ratio', 1.0) - ba['drop_path_rate'] = drop_path_rate + se_ratio = ba.pop('se_ratio') + if se_ratio and self.se_layer is not None: + if not self.se_from_exp: + # adjust se_ratio by expansion ratio if calculating se channels from block input + se_ratio /= ba.get('exp_ratio', 1.0) + if self.se_has_ratio: + ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio) + else: + ba['se_layer'] = self.se_layer if bt == 'ir': _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose) - if ba.get('num_experts', 0) > 0: - block = CondConvResidual(**ba) - else: - block = InvertedResidual(**ba) + block = CondConvResidual(**ba) if ba.get('num_experts', 0) else InvertedResidual(**ba) elif bt == 'ds' or bt == 'dsa': _log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose) block = DepthwiseSeparableConv(**ba) diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index d82a91b4..48dee6ec 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -40,7 +40,7 @@ default_cfgs = { } -_SE_LAYER = partial(SqueezeExcite, gate_fn='hard_sigmoid', round_chs_fn=partial(make_divisible, divisor=4)) +_SE_LAYER = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=partial(make_divisible, divisor=4)) class GhostModule(nn.Module): @@ -92,7 +92,7 @@ class GhostBottleneck(nn.Module): self.bn_dw = None # Squeeze-and-excitation - self.se = _SE_LAYER(mid_chs, se_ratio=se_ratio) if has_se else None + self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio) if has_se else None # Point-wise linear projection self.ghost2 = GhostModule(mid_chs, out_chs, relu=False) diff --git a/timm/models/hardcorenas.py b/timm/models/hardcorenas.py index 16b9c4bc..9988a044 100644 --- a/timm/models/hardcorenas.py +++ b/timm/models/hardcorenas.py @@ -39,8 +39,7 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs): """ num_features = 1280 - se_layer = partial( - SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, round_chs_fn=round_channels) + se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels) model_kwargs = dict( block_args=decode_arch_def(arch_def), num_features=num_features, diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index fad88aa7..e85112e6 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -266,7 +266,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw round_chs_fn=partial(round_channels, multiplier=channel_multiplier), norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), act_layer=resolve_act_layer(kwargs, 'hard_swish'), - se_layer=partial(SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid')), + se_layer=partial(SqueezeExcite, gate_layer='hard_sigmoid'), **kwargs, ) model = _create_mnv3(variant, pretrained, **model_kwargs) @@ -354,8 +354,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg # stage 6, 7x7 in ['cn_r1_k1_s1_c960'], # hard-swish ] - se_layer = partial( - SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, round_chs_fn=round_channels) + se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels) model_kwargs = dict( block_args=decode_arch_def(arch_def), num_features=num_features, @@ -372,67 +371,48 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """ FBNetV3 + Paper: `FBNetV3: Joint Architecture-Recipe Search using Predictor Pretraining` + - https://arxiv.org/abs/2006.02049 FIXME untested, this is a preliminary impl of some FBNet-V3 variants. """ vl = variant.split('_')[-1] if vl in ('a', 'b'): stem_size = 16 arch_def = [ - # stage 0, 112x112 in ['ds_r2_k3_s1_e1_c16'], - # stage 1, 112x112 in ['ir_r1_k5_s2_e4_c24', 'ir_r3_k5_s1_e2_c24'], - # stage 2, 56x56 in ['ir_r1_k5_s2_e5_c40_se0.25', 'ir_r4_k5_s1_e3_c40_se0.25'], - # stage 3, 28x28 in ['ir_r1_k5_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'], - # stage 4, 14x14in ['ir_r1_k3_s1_e5_c120_se0.25', 'ir_r5_k5_s1_e3_c120_se0.25'], - # stage 5, 14x14in ['ir_r1_k3_s2_e6_c184_se0.25', 'ir_r5_k5_s1_e4_c184_se0.25', 'ir_r1_k5_s1_e6_c224_se0.25'], - # stage 6, 7x7 in ['cn_r1_k1_s1_c1344'], ] elif vl == 'd': stem_size = 24 arch_def = [ - # stage 0, 112x112 in ['ds_r2_k3_s1_e1_c16'], - # stage 1, 112x112 in ['ir_r1_k3_s2_e5_c24', 'ir_r5_k3_s1_e2_c24'], - # stage 2, 56x56 in ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r4_k3_s1_e3_c40_se0.25'], - # stage 3, 28x28 in ['ir_r1_k3_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'], - # stage 4, 14x14in ['ir_r1_k3_s1_e5_c128_se0.25', 'ir_r6_k5_s1_e3_c128_se0.25'], - # stage 5, 14x14in ['ir_r1_k3_s2_e6_c208_se0.25', 'ir_r5_k5_s1_e5_c208_se0.25', 'ir_r1_k5_s1_e6_c240_se0.25'], - # stage 6, 7x7 in ['cn_r1_k1_s1_c1440'], ] elif vl == 'g': stem_size = 32 arch_def = [ - # stage 0, 112x112 in ['ds_r3_k3_s1_e1_c24'], - # stage 1, 112x112 in ['ir_r1_k5_s2_e4_c40', 'ir_r4_k5_s1_e2_c40'], - # stage 2, 56x56 in ['ir_r1_k5_s2_e4_c56_se0.25', 'ir_r4_k5_s1_e3_c56_se0.25'], - # stage 3, 28x28 in ['ir_r1_k5_s2_e5_c104', 'ir_r4_k3_s1_e3_c104'], - # stage 4, 14x14in ['ir_r1_k3_s1_e5_c160_se0.25', 'ir_r8_k5_s1_e3_c160_se0.25'], - # stage 5, 14x14in ['ir_r1_k3_s2_e6_c264_se0.25', 'ir_r6_k5_s1_e5_c264_se0.25', 'ir_r2_k5_s1_e6_c288_se0.25'], - # stage 6, 7x7 in - ['cn_r1_k1_s1_c1728'], # hard-swish + ['cn_r1_k1_s1_c1728'], ] else: raise NotImplemented round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.95) - se_layer = partial(SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), round_chs_fn=round_chs_fn) + se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=round_chs_fn) act_layer = resolve_act_layer(kwargs, 'hard_swish') model_kwargs = dict( block_args=decode_arch_def(arch_def), From 8bf63b6c6cc2b4ba69030cb043bf33cd562b399c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 30 May 2021 12:47:02 -0700 Subject: [PATCH 42/48] Able to use other attn layer in EfficientNet now. Create test ECA + GC B0 configs. Make ECA more configurable. --- tests/test_models.py | 2 +- timm/models/efficientnet.py | 24 +++++++++++++++++++ timm/models/efficientnet_builder.py | 4 ++-- timm/models/layers/eca.py | 36 +++++++++++++++++++++-------- 4 files changed, 54 insertions(+), 12 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 18298dff..1093e609 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -24,7 +24,7 @@ NUM_NON_STD = len(NON_STD_FILTERS) if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models EXCLUDE_FILTERS = [ - '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', + '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm' '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*resnetrs350*', '*resnetrs420*'] else: diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 8aa61ec5..09e47684 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -91,6 +91,12 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth', interpolation='bilinear'), + # NOTE experimenting with alternate attention + 'eca_efficientnet_b0': _cfg( + url=''), + 'gc_efficientnet_b0': _cfg( + url=''), + 'efficientnet_b0': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth'), 'efficientnet_b1': _cfg( @@ -1223,6 +1229,24 @@ def efficientnet_b0(pretrained=False, **kwargs): return model +@register_model +def eca_efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0 w/ ECA attn """ + # NOTE experimental config + model = _gen_efficientnet( + 'eca_efficientnet_b0', se_layer='eca', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def gc_efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0 w/ GlobalContext """ + # NOTE experminetal config + model = _gen_efficientnet( + 'gc_efficientnet_b0', se_layer='gc', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + @register_model def efficientnet_b1(pretrained=False, **kwargs): """ EfficientNet-B1 """ diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index f44cf158..a23e8273 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -278,9 +278,9 @@ class EfficientNetBuilder: self.norm_layer = norm_layer self.se_layer = get_attn(se_layer) try: - self.se_layer(8, rd_ratio=1.0) + self.se_layer(8, rd_ratio=1.0) # test if attn layer accepts rd_ratio arg self.se_has_ratio = True - except RuntimeError as e: + except TypeError: self.se_has_ratio = False self.drop_path_rate = drop_path_rate if feature_location == 'depthwise': diff --git a/timm/models/layers/eca.py b/timm/models/layers/eca.py index f2980730..5c024108 100644 --- a/timm/models/layers/eca.py +++ b/timm/models/layers/eca.py @@ -38,6 +38,9 @@ from torch import nn import torch.nn.functional as F +from .create_act import create_act_layer + + class EcaModule(nn.Module): """Constructs an ECA module. @@ -48,20 +51,27 @@ class EcaModule(nn.Module): refer to original paper https://arxiv.org/pdf/1910.03151.pdf (default=None. if channel size not given, use k_size given for kernel size.) kernel_size: Adaptive selection of kernel size (default=3) + gamm: used in kernel_size calc, see above + beta: used in kernel_size calc, see above + act_layer: optional non-linearity after conv, enables conv bias, this is an experiment + gate_layer: gating non-linearity to use """ - def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1): + def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'): super(EcaModule, self).__init__() - assert kernel_size % 2 == 1 if channels is not None: t = int(abs(math.log(channels, 2) + beta) / gamma) kernel_size = max(t if t % 2 else t + 1, 3) - - self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False) + assert kernel_size % 2 == 1 + has_act = act_layer is not None + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=has_act) + self.act = create_act_layer(act_layer) if has_act else nn.Identity() + self.gate = create_act_layer(gate_layer) def forward(self, x): y = x.mean((2, 3)).view(x.shape[0], 1, -1) # view for 1d conv y = self.conv(y) - y = y.view(x.shape[0], -1, 1, 1).sigmoid() + y = self.act(y) # NOTE: usually a no-op, added for experimentation + y = self.gate(y).view(x.shape[0], -1, 1, 1) return x * y.expand_as(x) @@ -86,27 +96,35 @@ class CecaModule(nn.Module): refer to original paper https://arxiv.org/pdf/1910.03151.pdf (default=None. if channel size not given, use k_size given for kernel size.) kernel_size: Adaptive selection of kernel size (default=3) + gamm: used in kernel_size calc, see above + beta: used in kernel_size calc, see above + act_layer: optional non-linearity after conv, enables conv bias, this is an experiment + gate_layer: gating non-linearity to use """ - def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1): + def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'): super(CecaModule, self).__init__() - assert kernel_size % 2 == 1 if channels is not None: t = int(abs(math.log(channels, 2) + beta) / gamma) kernel_size = max(t if t % 2 else t + 1, 3) + has_act = act_layer is not None + assert kernel_size % 2 == 1 # PyTorch circular padding mode is buggy as of pytorch 1.4 # see https://github.com/pytorch/pytorch/pull/17240 # implement manual circular padding - self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False) self.padding = (kernel_size - 1) // 2 + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act) + self.act = create_act_layer(act_layer) if has_act else nn.Identity() + self.gate = create_act_layer(gate_layer) def forward(self, x): y = x.mean((2, 3)).view(x.shape[0], 1, -1) # Manually implement circular padding, F.pad does not seemed to be bugged y = F.pad(y, (self.padding, self.padding), mode='circular') y = self.conv(y) - y = y.view(x.shape[0], -1, 1, 1).sigmoid() + y = self.act(y) # NOTE: usually a no-op, added for experimentation + y = self.gate(y).view(x.shape[0], -1, 1, 1) return x * y.expand_as(x) From 34522097b1d847f11263d9005d8dd1ff584c3edb Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 30 May 2021 21:12:10 -0700 Subject: [PATCH 43/48] See if we can use tcmalloc in test runner --- .github/workflows/tests.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9f7aebdb..f404085a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,11 +36,16 @@ jobs: run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }} - name: Install torch on ubuntu if: startsWith(matrix.os, 'ubuntu') - run: pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html + run: | + pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html + sudo apt update + sudo apt install -y google-perftools - name: Install requirements run: | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git@v1.0.12 - name: Run tests + env: + LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 run: | pytest -vv --durations=0 ./tests From 17dc47c8e64e1452a0f2be7883a55b2f618229eb Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 30 May 2021 22:00:43 -0700 Subject: [PATCH 44/48] Missed comma in test filters. --- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 1093e609..5a31935e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -24,7 +24,7 @@ NUM_NON_STD = len(NON_STD_FILTERS) if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models EXCLUDE_FILTERS = [ - '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm' + '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*resnetrs350*', '*resnetrs420*'] else: From 307a935b790b5af8d551ebecda053cb1a9b16fcb Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 31 May 2021 13:18:11 -0700 Subject: [PATCH 45/48] Add non-local and BAT attention. Merge attn and self-attn factories into one. Add attention references to README. Add mlp 'mode' to ECA. --- README.md | 16 ++- timm/models/byobnet.py | 6 +- timm/models/efficientnet.py | 6 +- timm/models/layers/__init__.py | 6 +- timm/models/layers/create_attn.py | 45 +++++++- timm/models/layers/create_self_attn.py | 25 ----- timm/models/layers/eca.py | 28 +++-- timm/models/layers/non_local_attn.py | 145 +++++++++++++++++++++++++ timm/models/layers/selective_kernel.py | 17 +-- timm/models/layers/split_attn.py | 39 +++---- timm/models/layers/squeeze_excite.py | 2 +- timm/models/resnest.py | 17 ++- timm/models/sknet.py | 16 +-- 13 files changed, 276 insertions(+), 92 deletions(-) delete mode 100644 timm/models/layers/create_self_attn.py create mode 100644 timm/models/layers/non_local_attn.py diff --git a/README.md b/README.md index 06aee7ec..0b878a0a 100644 --- a/README.md +++ b/README.md @@ -295,10 +295,24 @@ Several (less common) features that I often utilize in my projects are included. * SplitBachNorm - allows splitting batch norm layers between clean and augmented (auxiliary batch norm) data * DropPath aka "Stochastic Depth" (https://arxiv.org/abs/1603.09382) * DropBlock (https://arxiv.org/abs/1810.12890) -* Efficient Channel Attention - ECA (https://arxiv.org/abs/1910.03151) * Blur Pooling (https://arxiv.org/abs/1904.11486) * Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper? * Adaptive Gradient Clipping (https://arxiv.org/abs/2102.06171, https://github.com/deepmind/deepmind-research/tree/master/nfnets) +* An extensive selection of channel and/or spatial attention modules: + * Bottleneck Transformer - https://arxiv.org/abs/2101.11605 + * CBAM - https://arxiv.org/abs/1807.06521 + * Effective Squeeze-Excitation (ESE) - https://arxiv.org/abs/1911.06667 + * Efficient Channel Attention (ECA) - https://arxiv.org/abs/1910.03151 + * Gather-Excite (GE) - https://arxiv.org/abs/1810.12348 + * Global Context (GC) - https://arxiv.org/abs/1904.11492 + * Halo - https://arxiv.org/abs/2103.12731 + * Involution - https://arxiv.org/abs/2103.06255 + * Lambda Layer - https://arxiv.org/abs/2102.08602 + * Non-Local (NL) - https://arxiv.org/abs/1711.07971 + * Squeeze-and-Excitation (SE) - https://arxiv.org/abs/1709.01507 + * Selective Kernel (SK) - (https://arxiv.org/abs/1903.06586 + * Split (SPLAT) - https://arxiv.org/abs/2004.08955 + * Shifted Window (SWIN) - https://arxiv.org/abs/2103.14030 ## Results diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 8ec8690a..d41245f5 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -35,7 +35,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ - create_conv2d, get_act_layer, convert_norm_act, get_attn, get_self_attn, make_divisible, to_2tuple + create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple from .registry import register_model __all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block'] @@ -935,7 +935,7 @@ def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, mo else: self_attn_kwargs = override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs) self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer - self_attn_layer = partial(get_self_attn(self_attn_layer), *self_attn_kwargs) \ + self_attn_layer = partial(get_attn(self_attn_layer), *self_attn_kwargs) \ if self_attn_layer is not None else None layer_fns = replace(layer_fns, self_attn=self_attn_layer) @@ -1010,7 +1010,7 @@ def get_layer_fns(cfg: ByoModelCfg): norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act) conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act) attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None - self_attn = partial(get_self_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None + self_attn = partial(get_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn) return layer_fn diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 09e47684..6426b540 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -1234,7 +1234,8 @@ def eca_efficientnet_b0(pretrained=False, **kwargs): """ EfficientNet-B0 w/ ECA attn """ # NOTE experimental config model = _gen_efficientnet( - 'eca_efficientnet_b0', se_layer='eca', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + 'eca_efficientnet_b0', se_layer='ecam', channel_multiplier=1.0, depth_multiplier=1.0, + pretrained=pretrained, **kwargs) return model @@ -1243,7 +1244,8 @@ def gc_efficientnet_b0(pretrained=False, **kwargs): """ EfficientNet-B0 w/ GlobalContext """ # NOTE experminetal config model = _gen_efficientnet( - 'gc_efficientnet_b0', se_layer='gc', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + 'gc_efficientnet_b0', se_layer='gc', channel_multiplier=1.0, depth_multiplier=1.0, + pretrained=pretrained, **kwargs) return model diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 30a1b40d..77d1026e 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -12,7 +12,6 @@ from .create_act import create_act_layer, get_act_layer, get_act_fn from .create_attn import get_attn, create_attn from .create_conv2d import create_conv2d from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act -from .create_self_attn import get_self_attn, create_self_attn from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn from .evo_norm import EvoNormBatch2d, EvoNormSample2d @@ -24,16 +23,17 @@ from .involution import Involution from .linear import Linear from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp +from .non_local_attn import NonLocalAttn, BatNonLocalAttn from .norm import GroupNorm, LayerNorm2d from .norm_act import BatchNormAct2d, GroupNormAct from .padding import get_padding, get_same_padding, pad_same from .patch_embed import PatchEmbed from .pool2d_same import AvgPool2dSame, create_pool2d from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite -from .selective_kernel import SelectiveKernelConv +from .selective_kernel import SelectiveKernel from .separable_conv import SeparableConv2d, SeparableConvBnAct from .space_to_depth import SpaceToDepthModule -from .split_attn import SplitAttnConv2d +from .split_attn import SplitAttn from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame from .test_time_pool import TestTimePoolHead, apply_test_time_pool diff --git a/timm/models/layers/create_attn.py b/timm/models/layers/create_attn.py index de866eea..3fed646b 100644 --- a/timm/models/layers/create_attn.py +++ b/timm/models/layers/create_attn.py @@ -1,14 +1,23 @@ -""" Select AttentionFactory Method +""" Attention Factory -Hacked together by / Copyright 2020 Ross Wightman +Hacked together by / Copyright 2021 Ross Wightman """ import torch +from functools import partial +from .bottleneck_attn import BottleneckAttn from .cbam import CbamModule, LightCbamModule from .eca import EcaModule, CecaModule from .gather_excite import GatherExcite from .global_context import GlobalContext +from .halo_attn import HaloAttn +from .involution import Involution +from .lambda_layer import LambdaLayer +from .non_local_attn import NonLocalAttn, BatNonLocalAttn +from .selective_kernel import SelectiveKernel +from .split_attn import SplitAttn from .squeeze_excite import SEModule, EffectiveSEModule +from .swin_attn import WindowAttention def get_attn(attn_type): @@ -18,12 +27,16 @@ def get_attn(attn_type): if attn_type is not None: if isinstance(attn_type, str): attn_type = attn_type.lower() + # Lightweight attention modules (channel and/or coarse spatial). + # Typically added to existing network architecture blocks in addition to existing convolutions. if attn_type == 'se': module_cls = SEModule elif attn_type == 'ese': module_cls = EffectiveSEModule elif attn_type == 'eca': module_cls = EcaModule + elif attn_type == 'ecam': + module_cls = partial(EcaModule, use_mlp=True) elif attn_type == 'ceca': module_cls = CecaModule elif attn_type == 'ge': @@ -34,6 +47,34 @@ def get_attn(attn_type): module_cls = CbamModule elif attn_type == 'lcbam': module_cls = LightCbamModule + + # Attention / attention-like modules w/ significant params + # Typically replace some of the existing workhorse convs in a network architecture. + # All of these accept a stride argument and can spatially downsample the input. + elif attn_type == 'sk': + module_cls = SelectiveKernel + elif attn_type == 'splat': + module_cls = SplitAttn + + # Self-attention / attention-like modules w/ significant compute and/or params + # Typically replace some of the existing workhorse convs in a network architecture. + # All of these accept a stride argument and can spatially downsample the input. + elif attn_type == 'lambda': + return LambdaLayer + elif attn_type == 'bottleneck': + return BottleneckAttn + elif attn_type == 'halo': + return HaloAttn + elif attn_type == 'swin': + return WindowAttention + elif attn_type == 'involution': + return Involution + elif attn_type == 'nl': + module_cls = NonLocalAttn + elif attn_type == 'bat': + module_cls = BatNonLocalAttn + + # Woops! else: assert False, "Invalid attn module (%s)" % attn_type elif isinstance(attn_type, bool): diff --git a/timm/models/layers/create_self_attn.py b/timm/models/layers/create_self_attn.py deleted file mode 100644 index 448ddb34..00000000 --- a/timm/models/layers/create_self_attn.py +++ /dev/null @@ -1,25 +0,0 @@ -from .bottleneck_attn import BottleneckAttn -from .halo_attn import HaloAttn -from .involution import Involution -from .lambda_layer import LambdaLayer -from .swin_attn import WindowAttention - - -def get_self_attn(attn_type): - if attn_type == 'bottleneck': - return BottleneckAttn - elif attn_type == 'halo': - return HaloAttn - elif attn_type == 'lambda': - return LambdaLayer - elif attn_type == 'swin': - return WindowAttention - elif attn_type == 'involution': - return Involution - else: - assert False, f"Unknown attn type ({attn_type})" - - -def create_self_attn(attn_type, dim, stride=1, **kwargs): - attn_fn = get_self_attn(attn_type) - return attn_fn(dim, stride=stride, **kwargs) diff --git a/timm/models/layers/eca.py b/timm/models/layers/eca.py index 5c024108..e29be6ac 100644 --- a/timm/models/layers/eca.py +++ b/timm/models/layers/eca.py @@ -39,6 +39,7 @@ import torch.nn.functional as F from .create_act import create_act_layer +from .helpers import make_divisible class EcaModule(nn.Module): @@ -56,21 +57,36 @@ class EcaModule(nn.Module): act_layer: optional non-linearity after conv, enables conv bias, this is an experiment gate_layer: gating non-linearity to use """ - def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'): + def __init__( + self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid', + rd_ratio=1/8, rd_channels=None, rd_divisor=8, use_mlp=False): super(EcaModule, self).__init__() if channels is not None: t = int(abs(math.log(channels, 2) + beta) / gamma) kernel_size = max(t if t % 2 else t + 1, 3) assert kernel_size % 2 == 1 - has_act = act_layer is not None - self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=has_act) - self.act = create_act_layer(act_layer) if has_act else nn.Identity() + padding = (kernel_size - 1) // 2 + if use_mlp: + # NOTE 'mlp' mode is a timm experiment, not in paper + assert channels is not None + if rd_channels is None: + rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor) + act_layer = act_layer or nn.ReLU + self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True) + self.act = create_act_layer(act_layer) + self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True) + else: + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) + self.act = None + self.conv2 = None self.gate = create_act_layer(gate_layer) def forward(self, x): y = x.mean((2, 3)).view(x.shape[0], 1, -1) # view for 1d conv y = self.conv(y) - y = self.act(y) # NOTE: usually a no-op, added for experimentation + if self.conv2 is not None: + y = self.act(y) + y = self.conv2(y) y = self.gate(y).view(x.shape[0], -1, 1, 1) return x * y.expand_as(x) @@ -115,7 +131,6 @@ class CecaModule(nn.Module): # implement manual circular padding self.padding = (kernel_size - 1) // 2 self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act) - self.act = create_act_layer(act_layer) if has_act else nn.Identity() self.gate = create_act_layer(gate_layer) def forward(self, x): @@ -123,7 +138,6 @@ class CecaModule(nn.Module): # Manually implement circular padding, F.pad does not seemed to be bugged y = F.pad(y, (self.padding, self.padding), mode='circular') y = self.conv(y) - y = self.act(y) # NOTE: usually a no-op, added for experimentation y = self.gate(y).view(x.shape[0], -1, 1, 1) return x * y.expand_as(x) diff --git a/timm/models/layers/non_local_attn.py b/timm/models/layers/non_local_attn.py new file mode 100644 index 00000000..d20a5f3e --- /dev/null +++ b/timm/models/layers/non_local_attn.py @@ -0,0 +1,145 @@ +""" Bilinear-Attention-Transform and Non-Local Attention + +Paper: `Non-Local Neural Networks With Grouped Bilinear Attentional Transforms` + - https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html +Adapted from original code: https://github.com/BA-Transform/BAT-Image-Classification +""" +import torch +from torch import nn +from torch.nn import functional as F + +from .conv_bn_act import ConvBnAct +from .helpers import make_divisible + + +class NonLocalAttn(nn.Module): + """Spatial NL block for image classification. + + This was adapted from https://github.com/BA-Transform/BAT-Image-Classification + Their NonLocal impl inspired by https://github.com/facebookresearch/video-nonlocal-net. + """ + + def __init__(self, in_channels, use_scale=True, rd_ratio=1/8, rd_channels=None, rd_divisor=8, **kwargs): + super(NonLocalAttn, self).__init__() + if rd_channels is None: + rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) + self.scale = in_channels ** -0.5 if use_scale else 1.0 + self.t = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) + self.p = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) + self.g = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) + self.z = nn.Conv2d(rd_channels, in_channels, kernel_size=1, stride=1, bias=True) + self.norm = nn.BatchNorm2d(in_channels) + self.reset_parameters() + + def forward(self, x): + shortcut = x + + t = self.t(x) + p = self.p(x) + g = self.g(x) + + B, C, H, W = t.size() + t = t.view(B, C, -1).permute(0, 2, 1) + p = p.view(B, C, -1) + g = g.view(B, C, -1).permute(0, 2, 1) + + att = torch.bmm(t, p) * self.scale + att = F.softmax(att, dim=2) + x = torch.bmm(att, g) + + x = x.permute(0, 2, 1).reshape(B, C, H, W) + x = self.z(x) + x = self.norm(x) + shortcut + + return x + + def reset_parameters(self): + for name, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + if len(list(m.parameters())) > 1: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 0) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 0) + nn.init.constant_(m.bias, 0) + + +class BilinearAttnTransform(nn.Module): + + def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(BilinearAttnTransform, self).__init__() + + self.conv1 = ConvBnAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer) + self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1)) + self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size)) + self.conv2 = ConvBnAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.block_size = block_size + self.groups = groups + self.in_channels = in_channels + + def resize_mat(self, x, t): + B, C, block_size, block_size1 = x.shape + assert block_size == block_size1 + if t <= 1: + return x + x = x.view(B * C, -1, 1, 1) + x = x * torch.eye(t, t, dtype=x.dtype, device=x.device) + x = x.view(B * C, block_size, block_size, t, t) + x = torch.cat(torch.split(x, 1, dim=1), dim=3) + x = torch.cat(torch.split(x, 1, dim=2), dim=4) + x = x.view(B, C, block_size * t, block_size * t) + return x + + def forward(self, x): + assert x.shape[-1] % self.block_size == 0 and x.shape[-2] % self.block_size == 0 + B, C, H, W = x.shape + out = self.conv1(x) + rp = F.adaptive_max_pool2d(out, (self.block_size, 1)) + cp = F.adaptive_max_pool2d(out, (1, self.block_size)) + p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size) + q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size) + p = F.sigmoid(p) + q = F.sigmoid(q) + p = p / p.sum(dim=3, keepdim=True) + q = q / q.sum(dim=2, keepdim=True) + p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size( + 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous() + p = p.view(B, C, self.block_size, self.block_size) + q = q.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size( + 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous() + q = q.view(B, C, self.block_size, self.block_size) + p = self.resize_mat(p, H // self.block_size) + q = self.resize_mat(q, W // self.block_size) + y = p.matmul(x) + y = y.matmul(q) + + y = self.conv2(y) + return y + + +class BatNonLocalAttn(nn.Module): + """ BAT + Adapted from: https://github.com/BA-Transform/BAT-Image-Classification + """ + + def __init__( + self, in_channels, block_size=7, groups=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, + drop_rate=0.2, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, **_): + super().__init__() + if rd_channels is None: + rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) + self.conv1 = ConvBnAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.ba = BilinearAttnTransform(rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer) + self.conv2 = ConvBnAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.dropout = nn.Dropout2d(p=drop_rate) + + def forward(self, x): + xl = self.conv1(x) + y = self.ba(xl) + y = self.conv2(y) + y = self.dropout(y) + return y + x diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py index 10bfd0e0..246f72a6 100644 --- a/timm/models/layers/selective_kernel.py +++ b/timm/models/layers/selective_kernel.py @@ -8,6 +8,7 @@ import torch from torch import nn as nn from .conv_bn_act import ConvBnAct +from .helpers import make_divisible def _kernel_valid(k): @@ -45,10 +46,10 @@ class SelectiveKernelAttn(nn.Module): return x -class SelectiveKernelConv(nn.Module): +class SelectiveKernel(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1, - attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False, + def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1, + rd_ratio=1./16, rd_channels=None, min_rd_channels=16, rd_divisor=8, keep_3x3=True, split_input=True, drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None): """ Selective Kernel Convolution Module @@ -66,8 +67,8 @@ class SelectiveKernelConv(nn.Module): stride (int): stride for convolutions dilation (int): dilation for module as a whole, impacts dilation of each branch groups (int): number of groups for each branch - attn_reduction (int, float): reduction factor for attention features - min_attn_channels (int): minimum attention feature channels + rd_ratio (int, float): reduction factor for attention features + min_rd_channels (int): minimum attention feature channels keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations split_input (bool): split input channels evenly across each convolution branch, keeps param count lower, can be viewed as grouping by path, output expands to module out_channels count @@ -75,7 +76,8 @@ class SelectiveKernelConv(nn.Module): act_layer (nn.Module): activation layer to use norm_layer (nn.Module): batchnorm/norm layer to use """ - super(SelectiveKernelConv, self).__init__() + super(SelectiveKernel, self).__init__() + out_channels = out_channels or in_channels kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation _kernel_valid(kernel_size) if not isinstance(kernel_size, list): @@ -101,7 +103,8 @@ class SelectiveKernelConv(nn.Module): ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) for k, d in zip(kernel_size, dilation)]) - attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) + attn_channels = rd_channels or make_divisible( + out_channels * rd_ratio, min_value=min_rd_channels, divisor=rd_divisor) self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) self.drop_block = drop_block diff --git a/timm/models/layers/split_attn.py b/timm/models/layers/split_attn.py index 5615aa0b..dde601be 100644 --- a/timm/models/layers/split_attn.py +++ b/timm/models/layers/split_attn.py @@ -10,6 +10,8 @@ import torch import torch.nn.functional as F from torch import nn +from .helpers import make_divisible + class RadixSoftmax(nn.Module): def __init__(self, radix, cardinality): @@ -28,41 +30,37 @@ class RadixSoftmax(nn.Module): return x -class SplitAttnConv2d(nn.Module): - """Split-Attention Conv2d +class SplitAttn(nn.Module): + """Split-Attention (aka Splat) """ - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, - dilation=1, groups=1, bias=False, radix=2, reduction_factor=4, + def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None, + dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs): - super(SplitAttnConv2d, self).__init__() + super(SplitAttn, self).__init__() + out_channels = out_channels or in_channels self.radix = radix self.drop_block = drop_block mid_chs = out_channels * radix - attn_chs = max(in_channels * radix // reduction_factor, 32) + if rd_channels is None: + attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor) + else: + attn_chs = rd_channels * radix + padding = kernel_size // 2 if padding is None else padding self.conv = nn.Conv2d( in_channels, mid_chs, kernel_size, stride, padding, dilation, groups=groups * radix, bias=bias, **kwargs) - self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None + self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity() self.act0 = act_layer(inplace=True) self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) - self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None + self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() self.act1 = act_layer(inplace=True) self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) self.rsoftmax = RadixSoftmax(radix, groups) - @property - def in_channels(self): - return self.conv.in_channels - - @property - def out_channels(self): - return self.fc1.out_channels - def forward(self, x): x = self.conv(x) - if self.bn0 is not None: - x = self.bn0(x) + x = self.bn0(x) if self.drop_block is not None: x = self.drop_block(x) x = self.act0(x) @@ -73,10 +71,9 @@ class SplitAttnConv2d(nn.Module): x_gap = x.sum(dim=1) else: x_gap = x - x_gap = F.adaptive_avg_pool2d(x_gap, 1) + x_gap = x_gap.mean((2, 3), keepdim=True) x_gap = self.fc1(x_gap) - if self.bn1 is not None: - x_gap = self.bn1(x_gap) + x_gap = self.bn1(x_gap) x_gap = self.act1(x_gap) x_attn = self.fc2(x_gap) diff --git a/timm/models/layers/squeeze_excite.py b/timm/models/layers/squeeze_excite.py index 3e8a05bb..e5da29ef 100644 --- a/timm/models/layers/squeeze_excite.py +++ b/timm/models/layers/squeeze_excite.py @@ -56,7 +56,7 @@ class EffectiveSEModule(nn.Module): """ 'Effective Squeeze-Excitation From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 """ - def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid'): + def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_): super(EffectiveSEModule, self).__init__() self.add_maxpool = add_maxpool self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) diff --git a/timm/models/resnest.py b/timm/models/resnest.py index ac3b2559..31eebd80 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -11,7 +11,7 @@ from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from .layers import SplitAttnConv2d +from .layers import SplitAttn from .registry import register_model from .resnet import ResNet @@ -83,11 +83,11 @@ class ResNestBottleneck(nn.Module): self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None if self.radix >= 1: - self.conv2 = SplitAttnConv2d( + self.conv2 = SplitAttn( group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block) - self.bn2 = None # FIXME revisit, here to satisfy current torchscript fussyness - self.act2 = None + self.bn2 = nn.Identity() + self.act2 = nn.Identity() else: self.conv2 = nn.Conv2d( group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, @@ -117,11 +117,10 @@ class ResNestBottleneck(nn.Module): out = self.avd_first(out) out = self.conv2(out) - if self.bn2 is not None: - out = self.bn2(out) - if self.drop_block is not None: - out = self.drop_block(out) - out = self.act2(out) + out = self.bn2(out) + if self.drop_block is not None: + out = self.drop_block(out) + out = self.act2(out) if self.avd_last is not None: out = self.avd_last(out) diff --git a/timm/models/sknet.py b/timm/models/sknet.py index eb7ad8c3..82ca5bfe 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -14,7 +14,7 @@ from torch import nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from .layers import SelectiveKernelConv, ConvBnAct, create_attn +from .layers import SelectiveKernel, ConvBnAct, create_attn from .registry import register_model from .resnet import ResNet @@ -59,7 +59,7 @@ class SelectiveKernelBasic(nn.Module): outplanes = planes * self.expansion first_dilation = first_dilation or dilation - self.conv1 = SelectiveKernelConv( + self.conv1 = SelectiveKernel( inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs) conv_kwargs['act_layer'] = None self.conv2 = ConvBnAct( @@ -107,7 +107,7 @@ class SelectiveKernelBottleneck(nn.Module): first_dilation = first_dilation or dilation self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs) - self.conv2 = SelectiveKernelConv( + self.conv2 = SelectiveKernel( first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, **conv_kwargs, **sk_kwargs) conv_kwargs['act_layer'] = None @@ -153,10 +153,7 @@ def skresnet18(pretrained=False, **kwargs): Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this variation splits the input channels to the selective convolutions to keep param count down. """ - sk_kwargs = dict( - min_attn_channels=16, - attn_reduction=8, - split_input=True) + sk_kwargs = dict(min_rd_channels=16, rd_ratio=1/8, split_input=True) model_args = dict( block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) @@ -170,10 +167,7 @@ def skresnet34(pretrained=False, **kwargs): Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this variation splits the input channels to the selective convolutions to keep param count down. """ - sk_kwargs = dict( - min_attn_channels=16, - attn_reduction=8, - split_input=True) + sk_kwargs = dict(min_rd_channels=16, rd_ratio=1/8, split_input=True) model_args = dict( block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) From a27f4aec4aaa22c6a6e82c7d8a9a69d73176525e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 31 May 2021 14:06:34 -0700 Subject: [PATCH 46/48] Missed args for skresnext w/ refactoring. --- timm/models/layers/selective_kernel.py | 2 +- timm/models/sknet.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py index 246f72a6..bf7df4d2 100644 --- a/timm/models/layers/selective_kernel.py +++ b/timm/models/layers/selective_kernel.py @@ -49,7 +49,7 @@ class SelectiveKernelAttn(nn.Module): class SelectiveKernel(nn.Module): def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1, - rd_ratio=1./16, rd_channels=None, min_rd_channels=16, rd_divisor=8, keep_3x3=True, split_input=True, + rd_ratio=1./16, rd_channels=None, min_rd_channels=32, rd_divisor=8, keep_3x3=True, split_input=True, drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None): """ Selective Kernel Convolution Module diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 82ca5bfe..bba8bcf9 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -207,8 +207,9 @@ def skresnext50_32x4d(pretrained=False, **kwargs): """Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to the SKNet-50 model in the Select Kernel Paper """ + sk_kwargs = dict(min_rd_channels=32, rd_ratio=1/16, split_input=False) model_args = dict( block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, - zero_init_last_bn=False, **kwargs) + block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) return _create_skresnet('skresnext50_32x4d', pretrained, **model_args) From bda8ab015ac5ee0ec75b8c59d5c0c3b399abda94 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 31 May 2021 15:38:56 -0700 Subject: [PATCH 47/48] Remove min channels for SelectiveKernel, divisor should cover cases well enough. --- timm/models/layers/selective_kernel.py | 6 ++---- timm/models/sknet.py | 6 +++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py index bf7df4d2..f28b8d2e 100644 --- a/timm/models/layers/selective_kernel.py +++ b/timm/models/layers/selective_kernel.py @@ -49,7 +49,7 @@ class SelectiveKernelAttn(nn.Module): class SelectiveKernel(nn.Module): def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1, - rd_ratio=1./16, rd_channels=None, min_rd_channels=32, rd_divisor=8, keep_3x3=True, split_input=True, + rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True, drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None): """ Selective Kernel Convolution Module @@ -68,7 +68,6 @@ class SelectiveKernel(nn.Module): dilation (int): dilation for module as a whole, impacts dilation of each branch groups (int): number of groups for each branch rd_ratio (int, float): reduction factor for attention features - min_rd_channels (int): minimum attention feature channels keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations split_input (bool): split input channels evenly across each convolution branch, keeps param count lower, can be viewed as grouping by path, output expands to module out_channels count @@ -103,8 +102,7 @@ class SelectiveKernel(nn.Module): ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) for k, d in zip(kernel_size, dilation)]) - attn_channels = rd_channels or make_divisible( - out_channels * rd_ratio, min_value=min_rd_channels, divisor=rd_divisor) + attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor) self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) self.drop_block = drop_block diff --git a/timm/models/sknet.py b/timm/models/sknet.py index bba8bcf9..4dc2aa53 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -153,7 +153,7 @@ def skresnet18(pretrained=False, **kwargs): Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this variation splits the input channels to the selective convolutions to keep param count down. """ - sk_kwargs = dict(min_rd_channels=16, rd_ratio=1/8, split_input=True) + sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True) model_args = dict( block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) @@ -167,7 +167,7 @@ def skresnet34(pretrained=False, **kwargs): Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this variation splits the input channels to the selective convolutions to keep param count down. """ - sk_kwargs = dict(min_rd_channels=16, rd_ratio=1/8, split_input=True) + sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True) model_args = dict( block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) @@ -207,7 +207,7 @@ def skresnext50_32x4d(pretrained=False, **kwargs): """Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to the SKNet-50 model in the Select Kernel Paper """ - sk_kwargs = dict(min_rd_channels=32, rd_ratio=1/16, split_input=False) + sk_kwargs = dict(rd_ratio=1/16, rd_divisor=32, split_input=False) model_args = dict( block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) From 02320c3e3d217c90860e2576f1d644fdac09c09b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 31 May 2021 15:41:51 -0700 Subject: [PATCH 48/48] Bump version to 0.4.11 --- timm/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/version.py b/timm/version.py index b94cbb01..d4f33464 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.4.10' +__version__ = '0.4.11'