From 165fb354b2a797c68ec30399971dc1fdfc498509 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 14 May 2021 16:48:58 -0700 Subject: [PATCH 01/18] Add initial RedNet model / Involution layer impl for testing --- timm/models/byoanet.py | 49 +++++++++++++++++++++++++ timm/models/layers/__init__.py | 1 + timm/models/layers/create_self_attn.py | 3 ++ timm/models/layers/involution.py | 50 ++++++++++++++++++++++++++ 4 files changed, 103 insertions(+) create mode 100644 timm/models/layers/involution.py diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index ca49089b..a58eea63 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -58,6 +58,9 @@ default_cfgs = { 'swinnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'swinnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + + 'rednet26t': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)), + 'rednet50ts': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)), } @@ -245,6 +248,38 @@ model_cfgs = dict( self_attn_fixed_size=True, self_attn_kwargs=dict(win_size=8) ), + + rednet26t=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='self_attn', d=2, c=256, s=1, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=512, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', # FIXME RedNet uses involution in middle of stem + stem_pool='maxpool', + num_features=0, + self_attn_layer='involution', + self_attn_fixed_size=False, + self_attn_kwargs=dict() + ), + rednet50ts=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='self_attn', d=3, c=256, s=1, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=4, c=512, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + self_attn_layer='involution', + self_attn_fixed_size=False, + self_attn_kwargs=dict() + ), ) @@ -477,3 +512,17 @@ def swinnet50ts_256(pretrained=False, **kwargs): """ kwargs.setdefault('img_size', 256) return _create_byoanet('swinnet50ts_256', 'swinnet50ts', pretrained=pretrained, **kwargs) + + +@register_model +def rednet26t(pretrained=False, **kwargs): + """ + """ + return _create_byoanet('rednet26t', pretrained=pretrained, **kwargs) + + +@register_model +def rednet50ts(pretrained=False, **kwargs): + """ + """ + return _create_byoanet('rednet50ts', pretrained=pretrained, **kwargs) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 90241f5c..522c27e1 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -18,6 +18,7 @@ from .eca import EcaModule, CecaModule from .evo_norm import EvoNormBatch2d, EvoNormSample2d from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible from .inplace_abn import InplaceAbn +from .involution import Involution from .linear import Linear from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp diff --git a/timm/models/layers/create_self_attn.py b/timm/models/layers/create_self_attn.py index ba208f17..448ddb34 100644 --- a/timm/models/layers/create_self_attn.py +++ b/timm/models/layers/create_self_attn.py @@ -1,5 +1,6 @@ from .bottleneck_attn import BottleneckAttn from .halo_attn 
import HaloAttn +from .involution import Involution from .lambda_layer import LambdaLayer from .swin_attn import WindowAttention @@ -13,6 +14,8 @@ def get_self_attn(attn_type): return LambdaLayer elif attn_type == 'swin': return WindowAttention + elif attn_type == 'involution': + return Involution else: assert False, f"Unknown attn type ({attn_type})" diff --git a/timm/models/layers/involution.py b/timm/models/layers/involution.py new file mode 100644 index 00000000..0dba9fae --- /dev/null +++ b/timm/models/layers/involution.py @@ -0,0 +1,50 @@ +""" PyTorch Involution Layer + +Official impl: https://github.com/d-li14/involution/blob/main/cls/mmcls/models/utils/involution_naive.py +Paper: `Involution: Inverting the Inherence of Convolution for Visual Recognition` - https://arxiv.org/abs/2103.06255 +""" +import torch.nn as nn +from .conv_bn_act import ConvBnAct +from .create_conv2d import create_conv2d + + +class Involution(nn.Module): + + def __init__( + self, + channels, + kernel_size=3, + stride=1, + group_size=16, + reduction_ratio=4, + norm_layer=nn.BatchNorm2d, + act_layer=nn.ReLU, + ): + super(Involution, self).__init__() + self.kernel_size = kernel_size + self.stride = stride + self.channels = channels + self.group_size = group_size + self.groups = self.channels // self.group_size + self.conv1 = ConvBnAct( + in_channels=channels, + out_channels=channels // reduction_ratio, + kernel_size=1, + norm_layer=norm_layer, + act_layer=act_layer) + self.conv2 = self.conv = create_conv2d( + in_channels=channels // reduction_ratio, + out_channels=kernel_size**2 * self.groups, + kernel_size=1, + stride=1) + self.avgpool = nn.AvgPool2d(stride, stride) if stride == 2 else nn.Identity() + self.unfold = nn.Unfold(kernel_size, 1, (kernel_size-1)//2, stride) + + def forward(self, x): + weight = self.conv2(self.conv1(self.avgpool(x))) + B, C, H, W = weight.shape + KK = int(self.kernel_size ** 2) + weight = weight.view(B, self.groups, KK, H, W).unsqueeze(2) + out = self.unfold(x).view(B, self.groups, self.group_size, KK, H, W) + out = (weight * out).sum(dim=3).view(B, self.channels, H, W) + return out From ecc7552c5c5ab1d177705774e8e4efd16939852c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 14 May 2021 17:10:43 -0700 Subject: [PATCH 02/18] Add levit, levit_c, and visformer model defs. Largely untested and not finished cleanup. 
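
A quick smoke-test sketch for the new defs (a hedged example, assuming a tree with
these files imports cleanly at this still-untested state; 'levit_256' and
'visformer_small' are names registered in the diff below, and both default to
224x224 inputs):

    import torch
    import timm  # the new modules register their models via timm.models.__init__

    for name in ('levit_256', 'visformer_small'):
        model = timm.create_model(name, pretrained=False)
        model.eval()
        with torch.no_grad():
            out = model(torch.randn(1, 3, 224, 224))  # input size hard-coded here
        print(name, tuple(out.shape))  # expect (1, 1000) with the default heads

In eval mode the LeViT distillation head is averaged into the main head, so a single
tensor comes back; the visformer defs don't attach a default_cfg yet (added in a
later patch), which is why the input size is hard-coded above.
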
--- tests/test_models.py | 2 +- timm/models/__init__.py | 3 + timm/models/layers/patch_embed.py | 7 +- timm/models/levit.py | 440 ++++++++++++++++++++++++++++++ timm/models/levitc.py | 400 +++++++++++++++++++++++++++ timm/models/visformer.py | 377 +++++++++++++++++++++++++ 6 files changed, 1226 insertions(+), 3 deletions(-) create mode 100644 timm/models/levit.py create mode 100644 timm/models/levitc.py create mode 100644 timm/models/visformer.py diff --git a/tests/test_models.py b/tests/test_models.py index ced2fd76..1e1de498 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -15,7 +15,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): torch._C._jit_set_profiling_mode(False) # transformer models don't support many of the spatial / feature based model functionalities -NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', 'mixer_*'] +NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', 'mixer_*', 'levit*', 'visformer*'] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 46ea155f..821012e2 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -15,6 +15,8 @@ from .hrnet import * from .inception_resnet_v2 import * from .inception_v3 import * from .inception_v4 import * +from .levitc import * +from .levit import * from .mlp_mixer import * from .mobilenetv3 import * from .nasnet import * @@ -34,6 +36,7 @@ from .swin_transformer import * from .tnt import * from .tresnet import * from .vgg import * +from .visformer import * from .vision_transformer import * from .vision_transformer_hybrid import * from .vovnet import * diff --git a/timm/models/layers/patch_embed.py b/timm/models/layers/patch_embed.py index b06f9982..42997fb8 100644 --- a/timm/models/layers/patch_embed.py +++ b/timm/models/layers/patch_embed.py @@ -15,7 +15,7 @@ from .helpers import to_2tuple class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None): + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -23,6 +23,7 @@ class PatchEmbed(nn.Module): self.patch_size = patch_size self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() @@ -31,6 +32,8 @@ class PatchEmbed(nn.Module): B, C, H, W = x.shape assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC x = self.norm(x) return x diff --git a/timm/models/levit.py b/timm/models/levit.py new file mode 100644 index 00000000..997b44d7 --- /dev/null +++ b/timm/models/levit.py @@ -0,0 +1,440 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. 
+ +# Modified from +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# Copyright 2020 Ross Wightman, Apache-2.0 License +import itertools + +import torch + +from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN +from .vision_transformer import trunc_normal_ +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +specification = { + 'levit_128s': { + 'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'}, + 'levit_128': { + 'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'}, + 'levit_192': { + 'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'}, + 'levit_256': { + 'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'}, + 'levit_384': { + 'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'}, +} + +__all__ = ['Levit'] + + +@register_model +def levit_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_128s'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_128'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_192'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_256'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_384'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +class ConvNorm(torch.nn.Sequential): + def __init__( + self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000): + super().__init__() + self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = torch.nn.BatchNorm2d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = torch.nn.Conv2d( 
+ w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, + padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class LinearNorm(torch.nn.Sequential): + def __init__(self, a, b, bn_weight_init=1, resolution=-100000): + super().__init__() + self.add_module('c', torch.nn.Linear(a, b, bias=False)) + bn = torch.nn.BatchNorm1d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + @torch.no_grad() + def fuse(self): + l, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = l.weight * w[:, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = torch.nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + def forward(self, x): + l, bn = self._modules.values() + x = l(x) + return bn(x.flatten(0, 1)).reshape_as(x) + + +class NormLinear(torch.nn.Sequential): + def __init__(self, a, b, bias=True, std=0.02): + super().__init__() + self.add_module('bn', torch.nn.BatchNorm1d(a)) + l = torch.nn.Linear(a, b, bias=bias) + trunc_normal_(l.weight, std=std) + if bias: + torch.nn.init.constant_(l.bias, 0) + self.add_module('l', l) + + @torch.no_grad() + def fuse(self): + bn, l = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = l.weight * w[None, :] + if l.bias is None: + b = b @ self.l.weight.T + else: + b = (l.weight @ b[:, None]).view(-1) + self.l.bias + m = torch.nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def b16(n, activation, resolution=224): + return torch.nn.Sequential( + ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution), + activation(), + ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), + activation(), + ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), + activation(), + ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8)) + + +class Residual(torch.nn.Module): + def __init__(self, m, drop): + super().__init__() + self.m = m + self.drop = drop + + def forward(self, x): + if self.training and self.drop > 0: + return x + self.m(x) * torch.rand( + x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach() + else: + return x + self.m(x) + + +class Attention(torch.nn.Module): + def __init__( + self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + self.qkv = LinearNorm(dim, h, resolution=resolution) + self.proj = torch.nn.Sequential( + act_layer(), + LinearNorm(self.dh, dim, bn_weight_init=0, resolution=resolution)) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', 
torch.LongTensor(idxs).view(N, N)) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,N,C) + B, N, C = x.shape + qkv = self.qkv(x) + q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab + attn = q @ k.transpose(-2, -1) * self.scale + ab + + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class Subsample(torch.nn.Module): + def __init__(self, stride, resolution): + super().__init__() + self.stride = stride + self.resolution = resolution + + def forward(self, x): + B, N, C = x.shape + x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride] + return x.reshape(B, -1, C) + + +class AttentionSubsample(torch.nn.Module): + def __init__(self, in_dim, out_dim, key_dim, num_heads=8, + attn_ratio=2, act_layer=None, stride=2, resolution=14, resolution_=7): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * self.num_heads + self.attn_ratio = attn_ratio + self.resolution_ = resolution_ + self.resolution_2 = resolution_ ** 2 + h = self.dh + nh_kd + self.kv = LinearNorm(in_dim, h, resolution=resolution) + + self.q = torch.nn.Sequential( + Subsample(stride, resolution), + LinearNorm(in_dim, nh_kd, resolution=resolution_)) + self.proj = torch.nn.Sequential( + act_layer(), + LinearNorm(self.dh, out_dim, resolution=resolution_)) + + self.stride = stride + self.resolution = resolution + points = list(itertools.product(range(resolution), range(resolution))) + points_ = list(itertools.product(range(resolution_), range(resolution_))) + N = len(points) + N_ = len(points_) + attention_offsets = {} + idxs = [] + for p1 in points_: + for p2 in points: + size = 1 + offset = ( + abs(p1[0] * stride - p2[0] + (size - 1) / 2), + abs(p1[1] * stride - p2[1] + (size - 1) / 2)) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N)) + + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): + B, N, C = x.shape + k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3) + k = k.permute(0, 2, 1, 3) # BHNC + v = v.permute(0, 2, 1, 3) # BHNC + q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) + + ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab + attn = q @ k.transpose(-2, -1) * self.scale + ab + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) + x = self.proj(x) + return x + + +class Levit(torch.nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__( + self, + img_size=224, + 
patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=[192], + key_dim=[64], + depth=[12], + num_heads=[3], + attn_ratio=[2], + mlp_ratio=[2], + hybrid_backbone=None, + down_ops=[], + attn_act_layer=torch.nn.Hardswish, + mlp_act_layer=torch.nn.Hardswish, + distillation=True, + drop_path=0): + super().__init__() + global FLOPS_COUNTER + + self.num_classes = num_classes + self.num_features = embed_dim[-1] + self.embed_dim = embed_dim + self.distillation = distillation + + self.patch_embed = hybrid_backbone + + self.blocks = [] + down_ops.append(['']) + resolution = img_size // patch_size + for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( + zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)): + for _ in range(dpth): + self.blocks.append( + Residual( + Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution), + drop_path)) + if mr > 0: + h = int(ed * mr) + self.blocks.append( + Residual(torch.nn.Sequential( + LinearNorm(ed, h, resolution=resolution), + mlp_act_layer(), + LinearNorm(h, ed, bn_weight_init=0, resolution=resolution), + ), drop_path)) + if do[0] == 'Subsample': + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + resolution_ = (resolution - 1) // do[5] + 1 + self.blocks.append( + AttentionSubsample( + *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], + attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5], + resolution=resolution, resolution_=resolution_)) + resolution = resolution_ + if do[4] > 0: # mlp_ratio + h = int(embed_dim[i + 1] * do[4]) + self.blocks.append( + Residual(torch.nn.Sequential( + LinearNorm(embed_dim[i + 1], h, resolution=resolution), + mlp_act_layer(), + LinearNorm(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), + ), drop_path)) + self.blocks = torch.nn.Sequential(*self.blocks) + + # Classifier head + self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + if distillation: + self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + else: + self.head_dist = None + + @torch.jit.ignore + def no_weight_decay(self): + return {x for x in self.state_dict().keys() if 'attention_biases' in x} + + def forward(self, x): + x = self.patch_embed(x) + x = x.flatten(2).transpose(1, 2) + x = self.blocks(x) + x = x.mean(1) + if self.distillation: + x = self.head(x), self.head_dist(x) + if not self.training: + x = (x[0] + x[1]) / 2 + else: + x = self.head(x) + return x + + +def model_factory(C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse): + embed_dim = [int(x) for x in C.split('_')] + num_heads = [int(x) for x in N.split('_')] + depth = [int(x) for x in X.split('_')] + act = torch.nn.Hardswish + model = Levit( + patch_size=16, + embed_dim=embed_dim, + num_heads=num_heads, + key_dim=[D] * 3, + depth=depth, + attn_ratio=[2, 2, 2], + mlp_ratio=[2, 2, 2], + down_ops=[ + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + ['Subsample', D, embed_dim[0] // D, 4, 2, 2], + ['Subsample', D, embed_dim[1] // D, 4, 2, 2], + ], + attn_act_layer=act, + mlp_act_layer=act, + hybrid_backbone=b16(embed_dim[0], activation=act), + num_classes=num_classes, + drop_path=drop_path, + distillation=distillation + ) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url(weights, map_location='cpu') + model.load_state_dict(checkpoint['model']) + #if fuse: + # utils.replace_batchnorm(model) + + return model diff --git 
a/timm/models/levitc.py b/timm/models/levitc.py new file mode 100644 index 00000000..1a422953 --- /dev/null +++ b/timm/models/levitc.py @@ -0,0 +1,400 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. + +# Modified from +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# Copyright 2020 Ross Wightman, Apache-2.0 License +import itertools + +import torch + +from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN +from .vision_transformer import trunc_normal_ +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +specification = { + 'levit_c_128s': { + 'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'}, + 'levit_c_128': { + 'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'}, + 'levit_c_192': { + 'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'}, + 'levit_c_256': { + 'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'}, + 'levit_c_384': { + 'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'}, +} + +__all__ = ['Levit'] + + +@register_model +def levit_c_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_c_128s'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_c_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_c_128'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_c_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_c_192'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_c_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_c_256'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_c_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_c_384'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +class ConvNorm(torch.nn.Sequential): + def __init__( + self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000): + super().__init__() + self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = torch.nn.BatchNorm2d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + 
@torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps) ** 0.5 + m = torch.nn.Conv2d( + w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, + padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class NormLinear(torch.nn.Sequential): + def __init__(self, a, b, bias=True, std=0.02): + super().__init__() + self.add_module('bn', torch.nn.BatchNorm1d(a)) + l = torch.nn.Linear(a, b, bias=bias) + trunc_normal_(l.weight, std=std) + if bias: + torch.nn.init.constant_(l.bias, 0) + self.add_module('l', l) + + @torch.no_grad() + def fuse(self): + bn, l = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + b = bn.bias - self.bn.running_mean * \ + self.bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = l.weight * w[None, :] + if l.bias is None: + b = b @ self.l.weight.T + else: + b = (l.weight @ b[:, None]).view(-1) + self.l.bias + m = torch.nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def b16(n, activation, resolution=224): + return torch.nn.Sequential( + ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution), + activation(), + ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), + activation(), + ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), + activation(), + ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8)) + + +class Residual(torch.nn.Module): + def __init__(self, m, drop): + super().__init__() + self.m = m + self.drop = drop + + def forward(self, x): + if self.training and self.drop > 0: + return x + self.m(x) * torch.rand( + x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach() + else: + return x + self.m(x) + + +class Attention(torch.nn.Module): + def __init__(self, dim, key_dim, num_heads=8, + attn_ratio=4, act_layer=None, resolution=14): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + self.qkv = ConvNorm(dim, h, resolution=resolution) + self.proj = torch.nn.Sequential( + act_layer(), + ConvNorm(self.dh, dim, bn_weight_init=0, resolution=resolution)) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N)) + self.ab = None + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.ab is not None: + self.ab = None + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,C,H,W) + B, C, H, W = x.shape + q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2) + ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab + attn = (q.transpose(-2, -1) 
@ k) * self.scale + ab + attn = attn.softmax(dim=-1) + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + x = self.proj(x) + return x + + +class AttentionSubsample(torch.nn.Module): + def __init__( + self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2, + act_layer=None, stride=2, resolution=14, resolution_=7): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * self.num_heads + self.attn_ratio = attn_ratio + self.resolution_ = resolution_ + self.resolution_2 = resolution_ ** 2 + h = self.dh + nh_kd + self.kv = ConvNorm(in_dim, h, resolution=resolution) + self.q = torch.nn.Sequential( + torch.nn.AvgPool2d(1, stride, 0), + ConvNorm(in_dim, nh_kd, resolution=resolution_)) + self.proj = torch.nn.Sequential( + act_layer(), + ConvNorm(self.d * num_heads, out_dim, resolution=resolution_)) + + self.stride = stride + self.resolution = resolution + points = list(itertools.product(range(resolution), range(resolution))) + points_ = list(itertools.product(range(resolution_), range(resolution_))) + N = len(points) + N_ = len(points_) + attention_offsets = {} + idxs = [] + for p1 in points_: + for p2 in points: + size = 1 + offset = ( + abs(p1[0] * stride - p2[0] + (size - 1) / 2), + abs(p1[1] * stride - p2[1] + (size - 1) / 2)) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N)) + self.ab = None + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.ab is not None: + self.ab = None + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): + B, C, H, W = x.shape + k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2) + q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2) + ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab + attn = (q.transpose(-2, -1) @ k) * self.scale + ab + attn = attn.softmax(dim=-1) + + x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_) + x = self.proj(x) + return x + + +class Levit(torch.nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=[192], + key_dim=[64], + depth=[12], + num_heads=[3], + attn_ratio=[2], + mlp_ratio=[2], + hybrid_backbone=None, + down_ops=[], + attn_act_layer=torch.nn.Hardswish, + mlp_act_layer=torch.nn.Hardswish, + distillation=True, + drop_path=0): + super().__init__() + self.num_classes = num_classes + self.num_features = embed_dim[-1] + self.embed_dim = embed_dim + self.distillation = distillation + + self.patch_embed = hybrid_backbone + + self.blocks = [] + down_ops.append(['']) + resolution = img_size // patch_size + for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( + zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)): + for _ in range(dpth): + self.blocks.append( + Residual( + Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution), + drop_path)) + if mr > 0: + h = int(ed * mr) + self.blocks.append( + Residual(torch.nn.Sequential( + 
ConvNorm(ed, h, resolution=resolution), + mlp_act_layer(), + ConvNorm(h, ed, bn_weight_init=0, resolution=resolution), + ), drop_path)) + if do[0] == 'Subsample': + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + resolution_ = (resolution - 1) // do[5] + 1 + self.blocks.append( + AttentionSubsample( + *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], attn_ratio=do[3], + act_layer=attn_act_layer, stride=do[5], + resolution=resolution, resolution_=resolution_)) + resolution = resolution_ + if do[4] > 0: # mlp_ratio + h = int(embed_dim[i + 1] * do[4]) + self.blocks.append( + Residual(torch.nn.Sequential( + ConvNorm(embed_dim[i + 1], h, resolution=resolution), + mlp_act_layer(), + ConvNorm(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), + ), drop_path)) + self.blocks = torch.nn.Sequential(*self.blocks) + + # Classifier head + self.head = NormLinear( + embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + if distillation: + self.head_dist = NormLinear( + embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + + @torch.jit.ignore + def no_weight_decay(self): + return {x for x in self.state_dict().keys() if 'attention_biases' in x} + + def forward(self, x): + x = self.patch_embed(x) + x = self.blocks(x) + x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1) + if self.distillation: + x = self.head(x), self.head_dist(x) + if not self.training: + x = (x[0] + x[1]) / 2 + else: + x = self.head(x) + return x + + +def model_factory(C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse): + embed_dim = [int(x) for x in C.split('_')] + num_heads = [int(x) for x in N.split('_')] + depth = [int(x) for x in X.split('_')] + act = torch.nn.Hardswish + model = Levit( + patch_size=16, + embed_dim=embed_dim, + num_heads=num_heads, + key_dim=[D] * 3, + depth=depth, + attn_ratio=[2, 2, 2], + mlp_ratio=[2, 2, 2], + down_ops=[ + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + ['Subsample', D, embed_dim[0] // D, 4, 2, 2], + ['Subsample', D, embed_dim[1] // D, 4, 2, 2], + ], + attn_act_layer=act, + mlp_act_layer=act, + hybrid_backbone=b16(embed_dim[0], activation=act), + num_classes=num_classes, + drop_path=drop_path, + distillation=distillation + ) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + weights, map_location='cpu') + d = checkpoint['model'] + D = model.state_dict() + for k in d.keys(): + if D[k].shape != d[k].shape: + d[k] = d[k][:, :, None, None] + model.load_state_dict(d) + #if fuse: + # utils.replace_batchnorm(model) + + return model + diff --git a/timm/models/visformer.py b/timm/models/visformer.py new file mode 100644 index 00000000..0d213ad5 --- /dev/null +++ b/timm/models/visformer.py @@ -0,0 +1,377 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed +from .registry import register_model + + +__all__ = ['Visformer'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +class LayerNormBHWC(nn.LayerNorm): + def 
__init__(self, dim): + super().__init__(dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.layer_norm( + x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + + +class SpatialMlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, + act_layer=nn.GELU, drop=0., group=8, spatial_conv=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.in_features = in_features + self.out_features = out_features + self.spatial_conv = spatial_conv + if self.spatial_conv: + if group < 2: # net setting + hidden_features = in_features * 5 // 6 + else: + hidden_features = in_features * 2 + self.hidden_features = hidden_features + self.group = group + self.drop = nn.Dropout(drop) + self.conv1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0, bias=False) + self.act1 = act_layer() + if self.spatial_conv: + self.conv2 = nn.Conv2d( + hidden_features, hidden_features, 3, stride=1, padding=1, groups=self.group, bias=False) + self.act2 = act_layer() + else: + self.conv2 = None + self.act2 = None + self.conv3 = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0, bias=False) + + def forward(self, x): + x = self.conv1(x) + x = self.act1(x) + x = self.drop(x) + if self.conv2 is not None: + x = self.conv2(x) + x = self.act2(x) + x = self.conv3(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=0.): + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = round(dim // num_heads * head_dim_ratio) + self.head_dim = head_dim + self.scale = head_dim ** -0.5 + self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, C, H, W = x.shape + x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3) + q, k, v = x[0], x[1], x[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__(self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4., + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormBHWC, + group=8, attn_disabled=False, spatial_conv=False): + super().__init__() + self.spatial_conv = spatial_conv + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + if attn_disabled: + self.norm1 = None + self.attn = None + else: + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, head_dim_ratio=head_dim_ratio, attn_drop=attn_drop, proj_drop=drop) + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = SpatialMlp( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, + group=group, spatial_conv=spatial_conv) # new setting + + def forward(self, x): + if self.attn is not None: + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Visformer(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384, + depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + norm_layer=LayerNormBHWC, attn_stage='111', pos_embed=True, spatial_conv='111', + vit_stem=False, group=8, pool=True, conv_init=False, embed_norm=None): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim + self.init_channels = init_channels + self.img_size = img_size + self.vit_stem = vit_stem + self.pool = pool + self.conv_init = conv_init + if isinstance(depth, (list, tuple)): + self.stage_num1, self.stage_num2, self.stage_num3 = depth + depth = sum(depth) + else: + self.stage_num1 = self.stage_num3 = depth // 3 + self.stage_num2 = depth - self.stage_num1 - self.stage_num3 + self.pos_embed = pos_embed + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + # stage 1 + if self.vit_stem: + self.stem = None + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) + img_size //= 16 + else: + if self.init_channels is None: + self.stem = None + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans, + embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) + img_size //= 8 + else: + self.stem = nn.Sequential( + nn.Conv2d(3, self.init_channels, 7, stride=2, padding=3, bias=False), + nn.BatchNorm2d(self.init_channels), + nn.ReLU(inplace=True) + ) + img_size //= 2 + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels, + embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) + img_size //= 4 + + if self.pos_embed: + if self.vit_stem: + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size)) + else: + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, img_size, img_size)) + self.pos_drop = nn.Dropout(p=drop_rate) + self.stage1 = nn.ModuleList([ + Block( + dim=embed_dim//2, num_heads=num_heads, head_dim_ratio=0.5, mlp_ratio=mlp_ratio, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + group=group, attn_disabled=(attn_stage[0] == '0'), spatial_conv=(spatial_conv[0] == '1') + ) + for i in range(self.stage_num1) + ]) + + #stage2 + if not self.vit_stem: + self.patch_embed2 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2, + embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) + img_size //= 2 + if self.pos_embed: + self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size)) + self.stage2 = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, + drop=drop_rate, 
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + group=group, attn_disabled=(attn_stage[1] == '0'), spatial_conv=(spatial_conv[1] == '1') + ) + for i in range(self.stage_num1, self.stage_num1+self.stage_num2) + ]) + + # stage 3 + if not self.vit_stem: + self.patch_embed3 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim, + embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False) + img_size //= 2 + if self.pos_embed: + self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, img_size, img_size)) + self.stage3 = nn.ModuleList([ + Block( + dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + group=group, attn_disabled=(attn_stage[2] == '0'), spatial_conv=(spatial_conv[2] == '1') + ) + for i in range(self.stage_num1+self.stage_num2, depth) + ]) + + # head + if self.pool: + self.global_pooling = nn.AdaptiveAvgPool2d(1) + head_dim = embed_dim if self.vit_stem else embed_dim * 2 + self.norm = norm_layer(head_dim) + self.head = nn.Linear(head_dim, num_classes) + + # weights init + if self.pos_embed: + trunc_normal_(self.pos_embed1, std=0.02) + if not self.vit_stem: + trunc_normal_(self.pos_embed2, std=0.02) + trunc_normal_(self.pos_embed3, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + if self.conv_init: + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + else: + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0.) 
+ + def forward(self, x): + if self.stem is not None: + x = self.stem(x) + + # stage 1 + x = self.patch_embed1(x) + if self.pos_embed: + x = x + self.pos_embed1 + x = self.pos_drop(x) + for b in self.stage1: + x = b(x) + + # stage 2 + if not self.vit_stem: + x = self.patch_embed2(x) + if self.pos_embed: + x = x + self.pos_embed2 + x = self.pos_drop(x) + for b in self.stage2: + x = b(x) + + # stage3 + if not self.vit_stem: + x = self.patch_embed3(x) + if self.pos_embed: + x = x + self.pos_embed3 + x = self.pos_drop(x) + for b in self.stage3: + x = b(x) + + # head + x = self.norm(x) + if self.pool: + x = self.global_pooling(x) + else: + x = x[:, :, 0, 0] + + x = self.head(x.view(x.size(0), -1)) + return x + + +@register_model +def visformer_tiny(pretrained=False, **kwargs): + model = Visformer( + img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8, + attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, + embed_norm=nn.BatchNorm2d, **kwargs) + return model + + +@register_model +def visformer_small(pretrained=False, **kwargs): + model = Visformer( + img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8, + attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, + embed_norm=nn.BatchNorm2d, **kwargs) + return model + + +@register_model +def visformer_net1(pretrained=False, **kwargs): + model = Visformer( + init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', + spatial_conv='000', vit_stem=True, conv_init=True, **kwargs) + return model + + +@register_model +def visformer_net2(pretrained=False, **kwargs): + model = Visformer( + init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', + spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) + return model + + +@register_model +def visformer_net3(pretrained=False, **kwargs): + model = Visformer( + init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', + spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) + return model + + +@register_model +def visformer_net4(pretrained=False, **kwargs): + model = Visformer(init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', + spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) + return model + + +@register_model +def visformer_net5(pretrained=False, **kwargs): + model = Visformer( + init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', + spatial_conv='111', vit_stem=False, conv_init=True, **kwargs) + return model + + +@register_model +def visformer_net6(pretrained=False, **kwargs): + model = Visformer( + init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', + pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) + return model + + +@register_model +def visformer_net7(pretrained=False, **kwargs): + model = Visformer( + init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000', + pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) + return model + + + + From 94d4b53352f6824b9cbe41c3dd70c18103714951 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 15 May 2021 08:41:31 -0700 Subject: [PATCH 03/18] Add temporary default_cfgs to visformer models so they pass tests --- timm/models/visformer.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) 
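
A rough sketch of why this is needed (assumes the visformer defs from the previous
patch; the test suite reads fields such as 'input_size' and 'num_classes' from
model.default_cfg when building its dummy forward-pass inputs, as seen in
tests/test_models.py):

    import torch
    import timm

    model = timm.create_model('visformer_tiny', pretrained=False).eval()
    cfg = model.default_cfg                 # the temporary _cfg() dict attached here
    x = torch.randn(1, *cfg['input_size'])  # (3, 224, 224)
    with torch.no_grad():
        assert model(x).shape[-1] == cfg['num_classes']  # 1000-class head
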
diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 0d213ad5..aa3bca57 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -306,6 +306,7 @@ def visformer_tiny(pretrained=False, **kwargs): img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, embed_norm=nn.BatchNorm2d, **kwargs) + model.default_cfg = _cfg() return model @@ -315,6 +316,7 @@ def visformer_small(pretrained=False, **kwargs): img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, embed_norm=nn.BatchNorm2d, **kwargs) + model.default_cfg = _cfg() return model @@ -323,6 +325,7 @@ def visformer_net1(pretrained=False, **kwargs): model = Visformer( init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', spatial_conv='000', vit_stem=True, conv_init=True, **kwargs) + model.default_cfg = _cfg() return model @@ -331,6 +334,7 @@ def visformer_net2(pretrained=False, **kwargs): model = Visformer( init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) + model.default_cfg = _cfg() return model @@ -339,13 +343,16 @@ def visformer_net3(pretrained=False, **kwargs): model = Visformer( init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) + model.default_cfg = _cfg() return model @register_model def visformer_net4(pretrained=False, **kwargs): - model = Visformer(init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', - spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) + model = Visformer( + init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', + spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) + model.default_cfg = _cfg() return model @@ -354,6 +361,7 @@ def visformer_net5(pretrained=False, **kwargs): model = Visformer( init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', spatial_conv='111', vit_stem=False, conv_init=True, **kwargs) + model.default_cfg = _cfg() return model @@ -362,6 +370,7 @@ def visformer_net6(pretrained=False, **kwargs): model = Visformer( init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) + model.default_cfg = _cfg() return model @@ -370,6 +379,7 @@ def visformer_net7(pretrained=False, **kwargs): model = Visformer( init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000', pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) + model.default_cfg = _cfg() return model From d53e91218e9c1a6df468740a5d86a956016042f8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 15 May 2021 22:56:12 -0700 Subject: [PATCH 04/18] Fix tf.data options setting for newer TF versions --- timm/data/parsers/parser_tfds.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py index 0b12a4db..2ff90b09 100644 --- a/timm/data/parsers/parser_tfds.py +++ b/timm/data/parsers/parser_tfds.py @@ -25,8 +25,8 @@ from .parser import Parser MAX_TP_SIZE = 8 # maximum TF 
threadpool size, only doing jpeg decodes and queuing activities -SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue -PREFETCH_SIZE = 4096 # samples to prefetch +SHUFFLE_SIZE = 20480 # samples to shuffle in DS queue +PREFETCH_SIZE = 2048 # samples to prefetch def even_split_indices(split, n, num_samples): @@ -144,14 +144,16 @@ class ParserTfds(Parser): ds = self.builder.as_dataset( split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config) # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers - ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) - ds.options().experimental_threading.max_intra_op_parallelism = 1 + options = tf.data.Options() + options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) + options.experimental_threading.max_intra_op_parallelism = 1 + ds = ds.with_options(options) if self.is_training or self.repeats > 1: # to prevent excessive drop_last batch behaviour w/ IterableDatasets # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading ds = ds.repeat() # allow wrap around and break iteration manually if self.shuffle: - ds = ds.shuffle(min(self.num_samples // self._num_pipelines, SHUFFLE_SIZE), seed=0) + ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0) ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE)) self.ds = tfds.as_numpy(ds) From 9a3ae97311d7971c532d19e6262c600fcdd7808d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 15 May 2021 22:56:51 -0700 Subject: [PATCH 05/18] Another set of byoanet models w/ ECA channel + SA + groups --- timm/models/byoanet.py | 101 +++++++++++++++++++++++++++++++++++++++++ timm/models/byobnet.py | 2 +- 2 files changed, 102 insertions(+), 1 deletion(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index a58eea63..c179a01c 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -47,17 +47,21 @@ default_cfgs = { # GPU-Efficient (ResNet) weights 'botnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'botnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + 'eca_botnext26ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'halonet26t': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'halonet50ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), + 'eca_halonext26ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), 'lambda_resnet50t': _cfg(url='', min_input_size=(3, 128, 128)), + 'eca_lambda_resnext26ts': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), 'swinnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'swinnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + 'eca_swinnext26ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'rednet26t': _cfg(url='', fixed_input_size=False, 
input_size=(3, 256, 256), pool_size=(8, 8)), 'rednet50ts': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)), @@ -129,6 +133,23 @@ model_cfgs = dict( self_attn_fixed_size=True, self_attn_kwargs=dict() ), + eca_botnext26ts=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25), + ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25), + interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + attn_layer='eca', + self_attn_layer='bottleneck', + self_attn_fixed_size=True, + self_attn_kwargs=dict() + ), halonet_h1=ByoaCfg( blocks=( @@ -187,6 +208,22 @@ model_cfgs = dict( self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=2) ), + eca_halonext26ts=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), + ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25), + interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + attn_layer='eca', + self_attn_layer='halo', + self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res + ), lambda_resnet26t=ByoaCfg( blocks=( @@ -216,6 +253,22 @@ model_cfgs = dict( self_attn_layer='lambda', self_attn_kwargs=dict() ), + eca_lambda_resnext26ts=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), + ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25), + interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + attn_layer='eca', + self_attn_layer='lambda', + self_attn_kwargs=dict() + ), swinnet26t=ByoaCfg( blocks=( @@ -248,6 +301,24 @@ model_cfgs = dict( self_attn_fixed_size=True, self_attn_kwargs=dict(win_size=8) ), + eca_swinnext26ts=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), + interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=16, br=0.25), + interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + attn_layer='eca', + self_attn_layer='swin', + self_attn_fixed_size=True, + self_attn_kwargs=dict(win_size=8) + ), + rednet26t=ByoaCfg( blocks=( @@ -454,6 +525,14 @@ def botnet50ts_256(pretrained=False, **kwargs): return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs) +@register_model +def eca_botnext26ts_256(pretrained=False, **kwargs): + """ Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage. + """ + kwargs.setdefault('img_size', 256) + return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs) + + @register_model def halonet_h1(pretrained=False, **kwargs): """ HaloNet-H1. Halo attention in all stages as per the paper. 
@@ -484,6 +563,13 @@ def halonet50ts(pretrained=False, **kwargs): return _create_byoanet('halonet50ts', pretrained=pretrained, **kwargs) +@register_model +def eca_halonext26ts(pretrained=False, **kwargs): + """ HaloNet w/ a ResNet26-t backbone, Hallo attention in final stage + """ + return _create_byoanet('eca_halonext26ts', pretrained=pretrained, **kwargs) + + @register_model def lambda_resnet26t(pretrained=False, **kwargs): """ Lambda-ResNet-26T. Lambda layers in one C4 stage and all C5. @@ -498,6 +584,13 @@ def lambda_resnet50t(pretrained=False, **kwargs): return _create_byoanet('lambda_resnet50t', pretrained=pretrained, **kwargs) +@register_model +def eca_lambda_resnext26ts(pretrained=False, **kwargs): + """ Lambda-ResNet-26T. Lambda layers in one C4 stage and all C5. + """ + return _create_byoanet('eca_lambda_resnext26ts', pretrained=pretrained, **kwargs) + + @register_model def swinnet26t_256(pretrained=False, **kwargs): """ @@ -514,6 +607,14 @@ def swinnet50ts_256(pretrained=False, **kwargs): return _create_byoanet('swinnet50ts_256', 'swinnet50ts', pretrained=pretrained, **kwargs) +@register_model +def eca_swinnext26ts_256(pretrained=False, **kwargs): + """ + """ + kwargs.setdefault('img_size', 256) + return _create_byoanet('eca_swinnext26ts_256', 'eca_swinnext26ts', pretrained=pretrained, **kwargs) + + @register_model def rednet26t(pretrained=False, **kwargs): """ diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 75610f67..8f4a2020 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -98,7 +98,7 @@ class BlocksCfg: s: int = 2 # stride of stage (first block) gs: Optional[Union[int, Callable]] = None # group-size of blocks in stage, conv is depthwise if gs == 1 br: float = 1. # bottleneck-ratio of blocks in stage - no_attn: bool = True # disable channel attn (ie SE) when layer is set for model + no_attn: bool = False # disable channel attn (ie SE) when layer is set for model @dataclass From 18bf520ad12297dac4f9992ce497030259ca1aa2 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 22 May 2021 21:55:37 -0700 Subject: [PATCH 06/18] Add eca_nfnet_l2/l3 defs for future training --- timm/models/nfnet.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 3c21eea1..1b67581e 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -110,6 +110,12 @@ default_cfgs = dict( eca_nfnet_l1=_dcfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0), + eca_nfnet_l2=_dcfg( + url='', + pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 352, 352), crop_pct=1.0), + eca_nfnet_l3=_dcfg( + url='', + pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0), nf_regnet_b0=_dcfg( url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'), @@ -244,6 +250,12 @@ model_cfgs = dict( eca_nfnet_l1=_nfnet_cfg( depths=(2, 4, 12, 6), feat_mult=2, group_size=64, bottle_ratio=0.25, attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + eca_nfnet_l2=_nfnet_cfg( + depths=(3, 6, 18, 9), feat_mult=2, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + eca_nfnet_l3=_nfnet_cfg( + depths=(4, 8, 24, 12), feat_mult=2, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), # EffNet 
influenced RegNet defs. # NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8. @@ -814,6 +826,22 @@ def eca_nfnet_l1(pretrained=False, **kwargs): return _create_normfreenet('eca_nfnet_l1', pretrained=pretrained, **kwargs) +@register_model +def eca_nfnet_l2(pretrained=False, **kwargs): + """ ECA-NFNet-L2 w/ SiLU + My experimental 'light' model w/ F2 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn + """ + return _create_normfreenet('eca_nfnet_l2', pretrained=pretrained, **kwargs) + + +@register_model +def eca_nfnet_l3(pretrained=False, **kwargs): + """ ECA-NFNet-L3 w/ SiLU + My experimental 'light' model w/ F3 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn + """ + return _create_normfreenet('eca_nfnet_l3', pretrained=pretrained, **kwargs) + + @register_model def nf_regnet_b0(pretrained=False, **kwargs): """ Normalization-Free RegNet-B0 From bfc72f75d3f836ca5545cbe6ec1f8ba67b804b8b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 24 May 2021 21:13:26 -0700 Subject: [PATCH 07/18] Expand scope of testing for non-std vision transformer / mlp models. Some related cleanup and create fn cleanup for all vision transformer and mlp models. More CoaT weights. --- tests/test_models.py | 70 ++-- timm/models/__init__.py | 2 +- timm/models/cait.py | 15 +- timm/models/coat.py | 179 +++++---- timm/models/convit.py | 5 +- timm/models/helpers.py | 8 +- timm/models/levit.py | 471 ++++++++++++++--------- timm/models/levitc.py | 400 ------------------- timm/models/mlp_mixer.py | 15 +- timm/models/pit.py | 12 +- timm/models/tnt.py | 31 +- timm/models/twins.py | 17 +- timm/models/visformer.py | 155 ++++---- timm/models/vision_transformer.py | 23 +- timm/models/vision_transformer_hybrid.py | 5 +- 15 files changed, 559 insertions(+), 849 deletions(-) delete mode 100644 timm/models/levitc.py diff --git a/tests/test_models.py b/tests/test_models.py index 5ff9fb33..570b49db 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -26,29 +26,41 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): EXCLUDE_FILTERS = [ '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', - '*resnetrs350*', '*resnetrs420*'] + NON_STD_FILTERS + '*resnetrs350*', '*resnetrs420*'] else: - EXCLUDE_FILTERS = NON_STD_FILTERS + EXCLUDE_FILTERS = [] -MAX_FWD_SIZE = 384 -MAX_BWD_SIZE = 128 +TARGET_FWD_SIZE = MAX_FWD_SIZE = 384 +TARGET_BWD_SIZE = 128 +MAX_BWD_SIZE = 384 MAX_FWD_FEAT_SIZE = 448 +def _get_input_size(model, target=None): + default_cfg = model.default_cfg + input_size = default_cfg['input_size'] + if 'fixed_input_size' in default_cfg and default_cfg['fixed_input_size']: + return input_size + if 'min_input_size' in default_cfg: + if target and max(input_size) > target: + input_size = default_cfg['min_input_size'] + else: + if target and max(input_size) > target: + input_size = tuple([min(x, target) for x in input_size]) + return input_size + + @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-NUM_NON_STD])) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward(model_name, batch_size): """Run a single forward pass with each model""" model = create_model(model_name, pretrained=False) model.eval() - input_size = model.default_cfg['input_size'] - if 
any([x > MAX_FWD_SIZE for x in input_size]): - if is_model_default_key(model_name, 'fixed_input_size'): - pytest.skip("Fixed input size model > limit.") - # cap forward test at max res 384 * 384 to keep resource down - input_size = tuple([min(x, MAX_FWD_SIZE) for x in input_size]) + input_size = _get_input_size(model, TARGET_FWD_SIZE) + if max(input_size) > MAX_FWD_SIZE: + pytest.skip("Fixed input size model > limit.") inputs = torch.randn((batch_size, *input_size)) outputs = model(inputs) @@ -63,20 +75,16 @@ def test_model_backward(model_name, batch_size): """Run a single forward pass with each model""" model = create_model(model_name, pretrained=False, num_classes=42) num_params = sum([x.numel() for x in model.parameters()]) - model.eval() + model.train() - input_size = model.default_cfg['input_size'] - if not is_model_default_key(model_name, 'fixed_input_size'): - min_input_size = get_model_default_value(model_name, 'min_input_size') - if min_input_size is not None: - input_size = min_input_size - else: - if any([x > MAX_BWD_SIZE for x in input_size]): - # cap backward test at 128 * 128 to keep resource usage down - input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size]) + input_size = _get_input_size(model, TARGET_BWD_SIZE) + if max(input_size) > MAX_BWD_SIZE: + pytest.skip("Fixed input size model > limit.") inputs = torch.randn((batch_size, *input_size)) outputs = model(inputs) + if isinstance(outputs, tuple): + outputs = torch.cat(outputs) outputs.mean().backward() for n, x in model.named_parameters(): assert x.grad is not None, f'No gradient for {n}' @@ -168,12 +176,9 @@ def test_model_forward_torchscript(model_name, batch_size): model = create_model(model_name, pretrained=False) model.eval() - if has_model_default_key(model_name, 'fixed_input_size'): - input_size = get_model_default_value(model_name, 'input_size') - elif has_model_default_key(model_name, 'min_input_size'): - input_size = get_model_default_value(model_name, 'min_input_size') - else: - input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already... + input_size = _get_input_size(model, 128) + if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional + pytest.skip("Fixed input size model > limit.") model = torch.jit.script(model) outputs = model(torch.randn((batch_size, *input_size))) @@ -184,7 +189,7 @@ def test_model_forward_torchscript(model_name, batch_size): EXCLUDE_FEAT_FILTERS = [ '*pruned*', # hopefully fix at some point -] +] + NON_STD_FILTERS if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models EXCLUDE_FEAT_FILTERS += ['*resnext101_32x32d', '*resnext101_32x16d'] @@ -200,12 +205,9 @@ def test_model_forward_features(model_name, batch_size): expected_channels = model.feature_info.channels() assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6 - if has_model_default_key(model_name, 'fixed_input_size'): - input_size = get_model_default_value(model_name, 'input_size') - elif has_model_default_key(model_name, 'min_input_size'): - input_size = get_model_default_value(model_name, 'min_input_size') - else: - input_size = (3, 96, 96) # jit compile is already a bit slow and we've tested normal res already... + input_size = _get_input_size(model, 96) # jit compile is already a bit slow and we've tested normal res already... 
+ if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional + pytest.skip("Fixed input size model > limit.") outputs = model(torch.randn((batch_size, *input_size))) assert len(expected_channels) == len(outputs) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 1a21de09..788b7518 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -16,8 +16,8 @@ from .hrnet import * from .inception_resnet_v2 import * from .inception_v3 import * from .inception_v4 import * -from .levitc import * from .levit import * +#from .levit import * from .mlp_mixer import * from .mobilenetv3 import * from .nasnet import * diff --git a/timm/models/cait.py b/timm/models/cait.py index c5f7742f..aa2e5f07 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -306,26 +306,15 @@ def checkpoint_filter_fn(state_dict, model=None): return checkpoint_no_module -def _create_cait(variant, pretrained=False, default_cfg=None, **kwargs): - if default_cfg is None: - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] - num_classes = kwargs.pop('num_classes', default_num_classes) - img_size = kwargs.pop('img_size', default_img_size) - +def _create_cait(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') model = build_model_with_cfg( Cait, variant, pretrained, - default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, + default_cfg=default_cfgs[variant], pretrained_filter_fn=checkpoint_filter_fn, **kwargs) - return model diff --git a/timm/models/coat.py b/timm/models/coat.py index cb265522..9eb384d8 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -7,19 +7,19 @@ Official CoaT code at: https://github.com/mlpc-ucsd/CoaT Modified from timm/models/vision_transformer.py """ -from typing import Tuple, Dict, Any, Optional +from copy import deepcopy +from functools import partial +from typing import Tuple, List import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.helpers import load_pretrained -from timm.models.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ -from timm.models.registry import register_model +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ +from .registry import register_model -from functools import partial -from torch import nn __all__ = [ "coat_tiny", @@ -34,7 +34,7 @@ def _cfg_coat(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'patch_embed1.proj', 'classifier': 'head', **kwargs @@ -42,15 +42,21 @@ def _cfg_coat(url='', **kwargs): default_cfgs = { - 'coat_tiny': _cfg_coat(), - 'coat_mini': _cfg_coat(), + 'coat_tiny': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_tiny-473c2a20.pth' + ), + 'coat_mini': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_mini-2c6baf49.pth' + ), 'coat_lite_tiny': 
_cfg_coat( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_tiny-461b07a7.pth' ), 'coat_lite_mini': _cfg_coat( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_mini-d7842000.pth' ), - 'coat_lite_small': _cfg_coat(), + 'coat_lite_small': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_small-fea1d5a1.pth' + ), } @@ -120,11 +126,11 @@ class ConvRelPosEnc(nn.Module): class FactorAtt_ConvRelPosEnc(nn.Module): """ Factorized attention with convolutional relative position encoding class. """ - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., shared_crpe=None): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., shared_crpe=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used. @@ -190,9 +196,8 @@ class ConvPosEnc(nn.Module): class SerialBlock(nn.Module): """ Serial block class. Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """ - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, - shared_cpe=None, shared_crpe=None): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_cpe=None, shared_crpe=None): super().__init__() # Conv-Attention. @@ -200,8 +205,7 @@ class SerialBlock(nn.Module): self.norm1 = norm_layer(dim) self.factoratt_crpe = FactorAtt_ConvRelPosEnc( - dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, - shared_crpe=shared_crpe) + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpe) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() # MLP. @@ -226,27 +230,24 @@ class SerialBlock(nn.Module): class ParallelBlock(nn.Module): """ Parallel block class. """ - def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, - shared_cpes=None, shared_crpes=None): + def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_crpes=None): super().__init__() # Conv-Attention. 
- self.cpes = shared_cpes - self.norm12 = norm_layer(dims[1]) self.norm13 = norm_layer(dims[2]) self.norm14 = norm_layer(dims[3]) self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc( - dims[1], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + dims[1], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpes[1] ) self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc( - dims[2], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + dims[2], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpes[2] ) self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc( - dims[3], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + dims[3], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpes[3] ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -262,15 +263,15 @@ class ParallelBlock(nn.Module): self.mlp2 = self.mlp3 = self.mlp4 = Mlp( in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - def upsample(self, x, factor, size): + def upsample(self, x, factor: float, size: Tuple[int, int]): """ Feature map up-sampling. """ return self.interpolate(x, scale_factor=factor, size=size) - def downsample(self, x, factor, size): + def downsample(self, x, factor: float, size: Tuple[int, int]): """ Feature map down-sampling. """ return self.interpolate(x, scale_factor=1.0/factor, size=size) - def interpolate(self, x, scale_factor, size): + def interpolate(self, x, scale_factor: float, size: Tuple[int, int]): """ Feature map interpolation. """ B, N, C = x.shape H, W = size @@ -280,33 +281,28 @@ class ParallelBlock(nn.Module): img_tokens = x[:, 1:, :] img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W) - img_tokens = F.interpolate(img_tokens, scale_factor=scale_factor, mode='bilinear') + img_tokens = F.interpolate( + img_tokens, scale_factor=scale_factor, recompute_scale_factor=False, mode='bilinear', align_corners=False) img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2) out = torch.cat((cls_token, img_tokens), dim=1) return out - def forward(self, x1, x2, x3, x4, sizes): - _, (H2, W2), (H3, W3), (H4, W4) = sizes - - # Conv-Attention. - x2 = self.cpes[1](x2, size=(H2, W2)) # Note: x1 is ignored. 
- x3 = self.cpes[2](x3, size=(H3, W3)) - x4 = self.cpes[3](x4, size=(H4, W4)) - + def forward(self, x1, x2, x3, x4, sizes: List[Tuple[int, int]]): + _, S2, S3, S4 = sizes cur2 = self.norm12(x2) cur3 = self.norm13(x3) cur4 = self.norm14(x4) - cur2 = self.factoratt_crpe2(cur2, size=(H2, W2)) - cur3 = self.factoratt_crpe3(cur3, size=(H3, W3)) - cur4 = self.factoratt_crpe4(cur4, size=(H4, W4)) - upsample3_2 = self.upsample(cur3, factor=2, size=(H3, W3)) - upsample4_3 = self.upsample(cur4, factor=2, size=(H4, W4)) - upsample4_2 = self.upsample(cur4, factor=4, size=(H4, W4)) - downsample2_3 = self.downsample(cur2, factor=2, size=(H2, W2)) - downsample3_4 = self.downsample(cur3, factor=2, size=(H3, W3)) - downsample2_4 = self.downsample(cur2, factor=4, size=(H2, W2)) + cur2 = self.factoratt_crpe2(cur2, size=S2) + cur3 = self.factoratt_crpe3(cur3, size=S3) + cur4 = self.factoratt_crpe4(cur4, size=S4) + upsample3_2 = self.upsample(cur3, factor=2., size=S3) + upsample4_3 = self.upsample(cur4, factor=2., size=S4) + upsample4_2 = self.upsample(cur4, factor=4., size=S4) + downsample2_3 = self.downsample(cur2, factor=2., size=S2) + downsample3_4 = self.downsample(cur3, factor=2., size=S3) + downsample2_4 = self.downsample(cur2, factor=4., size=S2) cur2 = cur2 + upsample3_2 + upsample4_2 cur3 = cur3 + upsample4_3 + downsample2_3 cur4 = cur4 + downsample3_4 + downsample2_4 @@ -330,11 +326,11 @@ class ParallelBlock(nn.Module): class CoaT(nn.Module): """ CoaT class. """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[0, 0, 0, 0], - serial_depths=[0, 0, 0, 0], parallel_depth=0, - num_heads=0, mlp_ratios=[0, 0, 0, 0], qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., - drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), - return_interm_layers=False, out_features = None, crpe_window=None, **kwargs): + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=(0, 0, 0, 0), + serial_depths=(0, 0, 0, 0), parallel_depth=0, num_heads=0, mlp_ratios=(0, 0, 0, 0), qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), + return_interm_layers=False, out_features=None, crpe_window=None, **kwargs): super().__init__() crpe_window = crpe_window or {3: 2, 5: 3, 7: 3} self.return_interm_layers = return_interm_layers @@ -342,17 +338,18 @@ class CoaT(nn.Module): self.num_classes = num_classes # Patch embeddings. + img_size = to_2tuple(img_size) self.patch_embed1 = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0], norm_layer=nn.LayerNorm) self.patch_embed2 = PatchEmbed( - img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], + img_size=[x // 4 for x in img_size], patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1], norm_layer=nn.LayerNorm) self.patch_embed3 = PatchEmbed( - img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], + img_size=[x // 8 for x in img_size], patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2], norm_layer=nn.LayerNorm) self.patch_embed4 = PatchEmbed( - img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], + img_size=[x // 16 for x in img_size], patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3], norm_layer=nn.LayerNorm) # Class tokens. @@ -380,7 +377,7 @@ class CoaT(nn.Module): # Serial blocks 1. 
self.serial_blocks1 = nn.ModuleList([ SerialBlock( - dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, shared_cpe=self.cpe1, shared_crpe=self.crpe1 ) @@ -390,7 +387,7 @@ class CoaT(nn.Module): # Serial blocks 2. self.serial_blocks2 = nn.ModuleList([ SerialBlock( - dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, shared_cpe=self.cpe2, shared_crpe=self.crpe2 ) @@ -400,7 +397,7 @@ class CoaT(nn.Module): # Serial blocks 3. self.serial_blocks3 = nn.ModuleList([ SerialBlock( - dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, shared_cpe=self.cpe3, shared_crpe=self.crpe3 ) @@ -410,7 +407,7 @@ class CoaT(nn.Module): # Serial blocks 4. self.serial_blocks4 = nn.ModuleList([ SerialBlock( - dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, shared_cpe=self.cpe4, shared_crpe=self.crpe4 ) @@ -422,10 +419,9 @@ class CoaT(nn.Module): if self.parallel_depth > 0: self.parallel_blocks = nn.ModuleList([ ParallelBlock( - dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, - shared_cpes=[self.cpe1, self.cpe2, self.cpe3, self.cpe4], - shared_crpes=[self.crpe1, self.crpe2, self.crpe3, self.crpe4] + dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, + shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4) ) for _ in range(parallel_depth)] ) @@ -434,9 +430,11 @@ class CoaT(nn.Module): # Classification head(s). if not self.return_interm_layers: - self.norm1 = norm_layer(embed_dims[0]) - self.norm2 = norm_layer(embed_dims[1]) - self.norm3 = norm_layer(embed_dims[2]) + if self.parallel_blocks is not None: + self.norm2 = norm_layer(embed_dims[1]) + self.norm3 = norm_layer(embed_dims[2]) + else: + self.norm2 = self.norm3 = None self.norm4 = norm_layer(embed_dims[3]) if self.parallel_depth > 0: @@ -546,6 +544,7 @@ class CoaT(nn.Module): # Parallel blocks. 
for blk in self.parallel_blocks: + x2, x3, x4 = self.cpe2(x2, (H2, W2)), self.cpe3(x3, (H3, W3)), self.cpe4(x4, (H4, W4)) x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)]) if not torch.jit.is_scripting() and self.return_interm_layers: @@ -590,52 +589,70 @@ class CoaT(nn.Module): return x +def checkpoint_filter_fn(state_dict, model): + out_dict = {} + for k, v in state_dict.items(): + # original model had unused norm layers, removing them requires filtering pretrained checkpoints + if k.startswith('norm1') or \ + (model.norm2 is None and k.startswith('norm2')) or \ + (model.norm3 is None and k.startswith('norm3')): + continue + out_dict[k] = v + return out_dict + + +def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + CoaT, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + @register_model def coat_tiny(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) - model.default_cfg = default_cfgs['coat_tiny'] + model = _create_coat('coat_tiny', pretrained=pretrained, **model_cfg) return model @register_model def coat_mini(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) - model.default_cfg = default_cfgs['coat_mini'] + model = _create_coat('coat_mini', pretrained=pretrained, **model_cfg) return model @register_model def coat_lite_tiny(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) - # FIXME use builder - model.default_cfg = default_cfgs['coat_lite_tiny'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model = _create_coat('coat_lite_tiny', pretrained=pretrained, **model_cfg) return model @register_model def coat_lite_mini(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) - # FIXME use builder - model.default_cfg = default_cfgs['coat_lite_mini'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model = _create_coat('coat_lite_mini', pretrained=pretrained, **model_cfg) return model @register_model def coat_lite_small(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) - model.default_cfg = default_cfgs['coat_lite_small'] + model = _create_coat('coat_lite_small', pretrained=pretrained, **model_cfg) return model \ No newline at end of file diff --git a/timm/models/convit.py b/timm/models/convit.py index f6ae3ec1..b15b46d8 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -39,7 +39,7 @@ def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True, 'first_conv': 'patch_embed.proj', 'classifier': 'head', **kwargs } @@ -317,6 +317,9 @@ class ConViT(nn.Module): def _create_convit(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + return build_model_with_cfg( ConViT, variant, pretrained, default_cfg=default_cfgs[variant], diff --git a/timm/models/helpers.py b/timm/models/helpers.py index e9ac7f00..dfb6b860 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -44,7 +44,7 @@ def load_state_dict(checkpoint_path, use_ema=False): raise FileNotFoundError() -def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): +def load_checkpoint(model, checkpoint_path, use_ema=False, strict=False): state_dict = load_state_dict(checkpoint_path, use_ema) model.load_state_dict(state_dict, strict=strict) @@ -378,7 +378,11 @@ def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter): # Overlay default cfg values from `external_default_cfg` if it exists in kwargs overlay_external_default_cfg(default_cfg, kwargs) # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) - set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg) + default_kwarg_names = ('num_classes', 'global_pool', 'in_chans') + if default_cfg.get('fixed_input_size', False): + # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size + default_kwarg_names += ('img_size',) + set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg) # Filter keyword args for task specific model variants (some 'features only' models, etc.) filter_kwargs(kwargs, names=kwargs_filter) diff --git a/timm/models/levit.py b/timm/models/levit.py index 997b44d7..96a0c85b 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -1,3 +1,22 @@ +""" LeViT + +Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference` + - https://arxiv.org/abs/2104.01136 + +@article{graham2021levit, + title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference}, + author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze}, + journal={arXiv preprint arXiv:22104.01136}, + year={2021} +} + +Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright bellow. + +This version combines both conv/linear models and fixes torchscript compatibility. + +Modifications by/coyright Copyright 2021 Ross Wightman +""" + # Copyright (c) 2015-present, Facebook, Inc. # All rights reserved. 
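The ConvNorm and LinearNorm modules touched in the hunks below retain their fuse() helpers, which fold the trailing BatchNorm statistics into the preceding conv/linear weights for inference. As a standalone illustration of that algebra only (a sketch with a made-up helper name, not the patch's actual method):

import torch
import torch.nn as nn

@torch.no_grad()
def fuse_linear_bn(linear: nn.Linear, bn: nn.BatchNorm1d) -> nn.Linear:
    # BN(Wx + b) = scale * (Wx + b - mean) + beta with scale = gamma / sqrt(var + eps),
    # so the pair collapses to a single Linear with rescaled weight and shifted bias
    scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
    fused = nn.Linear(linear.in_features, linear.out_features, bias=True)
    fused.weight.copy_(linear.weight * scale[:, None])
    bias = linear.bias if linear.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.copy_((bias - bn.running_mean) * scale + bn.bias)
    return fused

# quick numerical check against the unfused pair (eval mode, so BN uses running stats)
lin, bn = nn.Linear(8, 16, bias=False), nn.BatchNorm1d(16)
seq = nn.Sequential(lin, bn).eval()
x = torch.randn(4, 8)
assert torch.allclose(seq(x), fuse_linear_bn(lin, bn)(x), atol=1e-6)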
@@ -5,10 +24,15 @@ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # Copyright 2020 Ross Wightman, Apache-2.0 License import itertools +from copy import deepcopy +from functools import partial import torch +import torch.nn as nn from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import to_ntuple from .vision_transformer import trunc_normal_ from .registry import register_model @@ -19,70 +43,113 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', + 'first_conv': 'patch_embed.0.c', 'classifier': ('head.l', 'head_dist.l'), **kwargs } -specification = { - 'levit_128s': { - 'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'}, - 'levit_128': { - 'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'}, - 'levit_192': { - 'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'}, - 'levit_256': { - 'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'}, - 'levit_384': { - 'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'}, -} +default_cfgs = dict( + levit_128s=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth' + ), + levit_128=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth' + ), + levit_192=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth' + ), + levit_256=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth' + ), + levit_384=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth' + ), +) + +model_cfgs = dict( + levit_128s=dict( + embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)), + levit_128=dict( + embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)), + levit_192=dict( + embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)), + levit_256=dict( + embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)), + levit_384=dict( + embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)), +) __all__ = ['Levit'] @register_model -def levit_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_128s'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) +def levit_128s(pretrained=False, fuse=False,distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model +def levit_128(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model 
+def levit_192(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) @register_model -def levit_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_128'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) +def levit_256(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) @register_model -def levit_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_192'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) +def levit_384(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) @register_model -def levit_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_256'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) +def levit_c_128s(pretrained=False, fuse=False, distillation=True, use_conv=True,**kwargs): + return create_levit( + 'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) @register_model -def levit_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_384'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) +def levit_c_128(pretrained=False, fuse=False,distillation=True, use_conv=True, **kwargs): + return create_levit( + 'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) -class ConvNorm(torch.nn.Sequential): +@register_model +def levit_c_192(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs): + return create_levit( + 'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model +def levit_c_256(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs): + return create_levit( + 'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model +def levit_c_384(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs): + return create_levit( + 'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +class ConvNorm(nn.Sequential): def __init__( self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000): super().__init__() - self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) - bn = torch.nn.BatchNorm2d(b) - torch.nn.init.constant_(bn.weight, bn_weight_init) - torch.nn.init.constant_(bn.bias, 0) + self.add_module('c', nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = nn.BatchNorm2d(b) + nn.init.constant_(bn.weight, bn_weight_init) + nn.init.constant_(bn.bias, 0) self.add_module('bn', bn) @torch.no_grad() @@ -91,7 +158,7 @@ class ConvNorm(torch.nn.Sequential): w = bn.weight / (bn.running_var 
+ bn.eps) ** 0.5 w = c.weight * w[:, None, None, None] b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 - m = torch.nn.Conv2d( + m = nn.Conv2d( w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) m.weight.data.copy_(w) @@ -99,13 +166,13 @@ class ConvNorm(torch.nn.Sequential): return m -class LinearNorm(torch.nn.Sequential): +class LinearNorm(nn.Sequential): def __init__(self, a, b, bn_weight_init=1, resolution=-100000): super().__init__() - self.add_module('c', torch.nn.Linear(a, b, bias=False)) - bn = torch.nn.BatchNorm1d(b) - torch.nn.init.constant_(bn.weight, bn_weight_init) - torch.nn.init.constant_(bn.bias, 0) + self.add_module('c', nn.Linear(a, b, bias=False)) + bn = nn.BatchNorm1d(b) + nn.init.constant_(bn.weight, bn_weight_init) + nn.init.constant_(bn.bias, 0) self.add_module('bn', bn) @torch.no_grad() @@ -114,25 +181,24 @@ class LinearNorm(torch.nn.Sequential): w = bn.weight / (bn.running_var + bn.eps) ** 0.5 w = l.weight * w[:, None] b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 - m = torch.nn.Linear(w.size(1), w.size(0)) + m = nn.Linear(w.size(1), w.size(0)) m.weight.data.copy_(w) m.bias.data.copy_(b) return m def forward(self, x): - l, bn = self._modules.values() - x = l(x) - return bn(x.flatten(0, 1)).reshape_as(x) + x = self.c(x) + return self.bn(x.flatten(0, 1)).reshape_as(x) -class NormLinear(torch.nn.Sequential): +class NormLinear(nn.Sequential): def __init__(self, a, b, bias=True, std=0.02): super().__init__() - self.add_module('bn', torch.nn.BatchNorm1d(a)) - l = torch.nn.Linear(a, b, bias=bias) + self.add_module('bn', nn.BatchNorm1d(a)) + l = nn.Linear(a, b, bias=bias) trunc_normal_(l.weight, std=std) if bias: - torch.nn.init.constant_(l.bias, 0) + nn.init.constant_(l.bias, 0) self.add_module('l', l) @torch.no_grad() @@ -145,24 +211,24 @@ class NormLinear(torch.nn.Sequential): b = b @ self.l.weight.T else: b = (l.weight @ b[:, None]).view(-1) + self.l.bias - m = torch.nn.Linear(w.size(1), w.size(0)) + m = nn.Linear(w.size(1), w.size(0)) m.weight.data.copy_(w) m.bias.data.copy_(b) return m -def b16(n, activation, resolution=224): - return torch.nn.Sequential( - ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution), +def stem_b16(in_chs, out_chs, activation, resolution=224): + return nn.Sequential( + ConvNorm(in_chs, out_chs // 8, 3, 2, 1, resolution=resolution), activation(), - ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), + ConvNorm(out_chs // 8, out_chs // 4, 3, 2, 1, resolution=resolution // 2), activation(), - ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), + ConvNorm(out_chs // 4, out_chs // 2, 3, 2, 1, resolution=resolution // 4), activation(), - ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8)) + ConvNorm(out_chs // 2, out_chs, 3, 2, 1, resolution=resolution // 8)) -class Residual(torch.nn.Module): +class Residual(nn.Module): def __init__(self, m, drop): super().__init__() self.m = m @@ -176,10 +242,23 @@ class Residual(torch.nn.Module): return x + self.m(x) -class Attention(torch.nn.Module): +class Subsample(nn.Module): + def __init__(self, stride, resolution): + super().__init__() + self.stride = stride + self.resolution = resolution + + def forward(self, x): + B, N, C = x.shape + x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride] + return x.reshape(B, -1, C) + + +class Attention(nn.Module): def __init__( - self, dim, key_dim, num_heads=8, attn_ratio=4, 
act_layer=None, resolution=14): + self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False): super().__init__() + self.num_heads = num_heads self.scale = key_dim ** -0.5 self.key_dim = key_dim @@ -187,11 +266,13 @@ class Attention(torch.nn.Module): self.d = int(attn_ratio * key_dim) self.dh = int(attn_ratio * key_dim) * num_heads self.attn_ratio = attn_ratio + self.use_conv = use_conv + ln_layer = ConvNorm if self.use_conv else LinearNorm h = self.dh + nh_kd * 2 - self.qkv = LinearNorm(dim, h, resolution=resolution) - self.proj = torch.nn.Sequential( + self.qkv = ln_layer(dim, h, resolution=resolution) + self.proj = nn.Sequential( act_layer(), - LinearNorm(self.dh, dim, bn_weight_init=0, resolution=resolution)) + ln_layer(self.dh, dim, bn_weight_init=0, resolution=resolution)) points = list(itertools.product(range(resolution), range(resolution))) N = len(points) @@ -203,68 +284,68 @@ class Attention(torch.nn.Module): if offset not in attention_offsets: attention_offsets[offset] = len(attention_offsets) idxs.append(attention_offsets[offset]) - self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N)) + self.ab = None @torch.no_grad() def train(self, mode=True): super().train(mode) - if mode and hasattr(self, 'ab'): - del self.ab + self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,C,H,W) + if self.use_conv: + B, C, H, W = x.shape + q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2) + ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab + attn = (q.transpose(-2, -1) @ k) * self.scale + ab + attn = attn.softmax(dim=-1) + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) else: - self.ab = self.attention_biases[:, self.attention_bias_idxs] - - def forward(self, x): # x (B,N,C) - B, N, C = x.shape - qkv = self.qkv(x) - q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3) - q = q.permute(0, 2, 1, 3) - k = k.permute(0, 2, 1, 3) - v = v.permute(0, 2, 1, 3) - - ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab - attn = q @ k.transpose(-2, -1) * self.scale + ab - - attn = attn.softmax(dim=-1) - x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + B, N, C = x.shape + qkv = self.qkv(x) + q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab + attn = q @ k.transpose(-2, -1) * self.scale + ab + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) x = self.proj(x) return x -class Subsample(torch.nn.Module): - def __init__(self, stride, resolution): - super().__init__() - self.stride = stride - self.resolution = resolution - - def forward(self, x): - B, N, C = x.shape - x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride] - return x.reshape(B, -1, C) - - -class AttentionSubsample(torch.nn.Module): - def __init__(self, in_dim, out_dim, key_dim, num_heads=8, - attn_ratio=2, act_layer=None, stride=2, resolution=14, resolution_=7): +class AttentionSubsample(nn.Module): + def 
__init__( + self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2, + act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False): super().__init__() self.num_heads = num_heads self.scale = key_dim ** -0.5 self.key_dim = key_dim self.nh_kd = nh_kd = key_dim * num_heads self.d = int(attn_ratio * key_dim) - self.dh = int(attn_ratio * key_dim) * self.num_heads + self.dh = self.d * self.num_heads self.attn_ratio = attn_ratio self.resolution_ = resolution_ self.resolution_2 = resolution_ ** 2 - h = self.dh + nh_kd - self.kv = LinearNorm(in_dim, h, resolution=resolution) + self.use_conv = use_conv + if self.use_conv: + ln_layer = ConvNorm + sub_layer = partial(nn.AvgPool2d, kernel_size=1, padding=0) + else: + ln_layer = LinearNorm + sub_layer = partial(Subsample, resolution=resolution) - self.q = torch.nn.Sequential( - Subsample(stride, resolution), - LinearNorm(in_dim, nh_kd, resolution=resolution_)) - self.proj = torch.nn.Sequential( + h = self.dh + nh_kd + self.kv = ln_layer(in_dim, h, resolution=resolution) + self.q = nn.Sequential( + sub_layer(stride=stride), + ln_layer(in_dim, nh_kd, resolution=resolution_)) + self.proj = nn.Sequential( act_layer(), - LinearNorm(self.dh, out_dim, resolution=resolution_)) + ln_layer(self.dh, out_dim, resolution=resolution_)) self.stride = stride self.resolution = resolution @@ -283,35 +364,43 @@ class AttentionSubsample(torch.nn.Module): if offset not in attention_offsets: attention_offsets[offset] = len(attention_offsets) idxs.append(attention_offsets[offset]) - self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N)) - + self.ab = None @torch.no_grad() def train(self, mode=True): super().train(mode) - if mode and hasattr(self, 'ab'): - del self.ab - else: - self.ab = self.attention_biases[:, self.attention_bias_idxs] + self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs] def forward(self, x): - B, N, C = x.shape - k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3) - k = k.permute(0, 2, 1, 3) # BHNC - v = v.permute(0, 2, 1, 3) # BHNC - q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) + if self.use_conv: + B, C, H, W = x.shape + k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2) + q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2) + + ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab + attn = (q.transpose(-2, -1) @ k) * self.scale + ab + attn = attn.softmax(dim=-1) + + x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_) + else: + B, N, C = x.shape + k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3) + k = k.permute(0, 2, 1, 3) # BHNC + v = v.permute(0, 2, 1, 3) # BHNC + q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) - ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab - attn = q @ k.transpose(-2, -1) * self.scale + ab - attn = attn.softmax(dim=-1) + ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab + attn = q @ k.transpose(-2, -1) * self.scale + ab + attn = attn.softmax(dim=-1) - x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) + x = (attn @ v).transpose(1, 
2).reshape(B, -1, self.dh) x = self.proj(x) return x -class Levit(torch.nn.Module): +class Levit(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ @@ -321,45 +410,63 @@ class Levit(torch.nn.Module): patch_size=16, in_chans=3, num_classes=1000, - embed_dim=[192], - key_dim=[64], - depth=[12], - num_heads=[3], - attn_ratio=[2], - mlp_ratio=[2], + embed_dim=(192,), + key_dim=64, + depth=(12,), + num_heads=(3,), + attn_ratio=2, + mlp_ratio=2, hybrid_backbone=None, - down_ops=[], - attn_act_layer=torch.nn.Hardswish, - mlp_act_layer=torch.nn.Hardswish, + down_ops=None, + act_layer=nn.Hardswish, + attn_act_layer=nn.Hardswish, distillation=True, + use_conv=False, drop_path=0): super().__init__() - global FLOPS_COUNTER - + if isinstance(img_size, tuple): + # FIXME origin impl passes single img/res dim through whole hierarchy, + # not sure this model will be used enough to spend time fixing it. + assert img_size[0] == img_size[1] + img_size = img_size[0] self.num_classes = num_classes self.num_features = embed_dim[-1] self.embed_dim = embed_dim + N = len(embed_dim) + assert len(depth) == len(num_heads) == N + key_dim = to_ntuple(N)(key_dim) + attn_ratio = to_ntuple(N)(attn_ratio) + mlp_ratio = to_ntuple(N)(mlp_ratio) + down_ops = down_ops or ( + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + ('Subsample', key_dim[0], embed_dim[0] // key_dim[0], 4, 2, 2), + ('Subsample', key_dim[0], embed_dim[1] // key_dim[1], 4, 2, 2), + ('',) + ) self.distillation = distillation + self.use_conv = use_conv + ln_layer = ConvNorm if self.use_conv else LinearNorm - self.patch_embed = hybrid_backbone + self.patch_embed = hybrid_backbone or stem_b16(in_chans, embed_dim[0], activation=act_layer) self.blocks = [] - down_ops.append(['']) resolution = img_size // patch_size for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)): for _ in range(dpth): self.blocks.append( Residual( - Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution), + Attention( + ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, + resolution=resolution, use_conv=use_conv), drop_path)) if mr > 0: h = int(ed * mr) self.blocks.append( - Residual(torch.nn.Sequential( - LinearNorm(ed, h, resolution=resolution), - mlp_act_layer(), - LinearNorm(h, ed, bn_weight_init=0, resolution=resolution), + Residual(nn.Sequential( + ln_layer(ed, h, resolution=resolution), + act_layer(), + ln_layer(h, ed, bn_weight_init=0, resolution=resolution), ), drop_path)) if do[0] == 'Subsample': # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) @@ -368,22 +475,22 @@ class Levit(torch.nn.Module): AttentionSubsample( *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5], - resolution=resolution, resolution_=resolution_)) + resolution=resolution, resolution_=resolution_, use_conv=use_conv)) resolution = resolution_ if do[4] > 0: # mlp_ratio h = int(embed_dim[i + 1] * do[4]) self.blocks.append( - Residual(torch.nn.Sequential( - LinearNorm(embed_dim[i + 1], h, resolution=resolution), - mlp_act_layer(), - LinearNorm(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), + Residual(nn.Sequential( + ln_layer(embed_dim[i + 1], h, resolution=resolution), + act_layer(), + ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), ), drop_path)) - self.blocks = torch.nn.Sequential(*self.blocks) + self.blocks = nn.Sequential(*self.blocks) # 
Classifier head - self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() if distillation: - self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() else: self.head_dist = None @@ -393,48 +500,44 @@ class Levit(torch.nn.Module): def forward(self, x): x = self.patch_embed(x) - x = x.flatten(2).transpose(1, 2) + if not self.use_conv: + x = x.flatten(2).transpose(1, 2) x = self.blocks(x) - x = x.mean(1) - if self.distillation: - x = self.head(x), self.head_dist(x) - if not self.training: - x = (x[0] + x[1]) / 2 + x = x.mean((-2, -1)) if self.use_conv else x.mean(1) + if self.head_dist is not None: + x, x_dist = self.head(x), self.head_dist(x) + if self.training and not torch.jit.is_scripting(): + return x, x_dist + else: + # during inference, return the average of both classifier predictions + return (x + x_dist) / 2 else: x = self.head(x) return x -def model_factory(C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse): - embed_dim = [int(x) for x in C.split('_')] - num_heads = [int(x) for x in N.split('_')] - depth = [int(x) for x in X.split('_')] - act = torch.nn.Hardswish - model = Levit( - patch_size=16, - embed_dim=embed_dim, - num_heads=num_heads, - key_dim=[D] * 3, - depth=depth, - attn_ratio=[2, 2, 2], - mlp_ratio=[2, 2, 2], - down_ops=[ - # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) - ['Subsample', D, embed_dim[0] // D, 4, 2, 2], - ['Subsample', D, embed_dim[1] // D, 4, 2, 2], - ], - attn_act_layer=act, - mlp_act_layer=act, - hybrid_backbone=b16(embed_dim[0], activation=act), - num_classes=num_classes, - drop_path=drop_path, - distillation=distillation - ) - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url(weights, map_location='cpu') - model.load_state_dict(checkpoint['model']) +def checkpoint_filter_fn(state_dict, model): + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + D = model.state_dict() + for k in state_dict.keys(): + if D[k].ndim == 4 and state_dict[k].ndim == 2: + state_dict[k] = state_dict[k][:, :, None, None] + return state_dict + + +def create_levit(variant, pretrained=False, default_cfg=None, fuse=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model_cfg = dict(**model_cfgs[variant], **kwargs) + model = build_model_with_cfg( + Levit, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **model_cfg) #if fuse: # utils.replace_batchnorm(model) - return model + diff --git a/timm/models/levitc.py b/timm/models/levitc.py deleted file mode 100644 index 1a422953..00000000 --- a/timm/models/levitc.py +++ /dev/null @@ -1,400 +0,0 @@ -# Copyright (c) 2015-present, Facebook, Inc. -# All rights reserved. 
- -# Modified from -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py -# Copyright 2020 Ross Wightman, Apache-2.0 License -import itertools - -import torch - -from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN -from .vision_transformer import trunc_normal_ -from .registry import register_model - - -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', - **kwargs - } - - -specification = { - 'levit_c_128s': { - 'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'}, - 'levit_c_128': { - 'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'}, - 'levit_c_192': { - 'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'}, - 'levit_c_256': { - 'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'}, - 'levit_c_384': { - 'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'}, -} - -__all__ = ['Levit'] - - -@register_model -def levit_c_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_c_128s'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) - - -@register_model -def levit_c_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_c_128'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) - - -@register_model -def levit_c_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_c_192'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) - - -@register_model -def levit_c_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_c_256'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) - - -@register_model -def levit_c_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_c_384'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) - - -class ConvNorm(torch.nn.Sequential): - def __init__( - self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000): - super().__init__() - self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) - bn = torch.nn.BatchNorm2d(b) - torch.nn.init.constant_(bn.weight, bn_weight_init) - torch.nn.init.constant_(bn.bias, 0) - self.add_module('bn', bn) - - @torch.no_grad() - def fuse(self): - c, bn = self._modules.values() - w = bn.weight / (bn.running_var + bn.eps) ** 0.5 - w = c.weight * w[:, None, None, None] - b = bn.bias - bn.running_mean * bn.weight / \ - (bn.running_var + 
bn.eps) ** 0.5 - m = torch.nn.Conv2d( - w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, - padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) - m.weight.data.copy_(w) - m.bias.data.copy_(b) - return m - - -class NormLinear(torch.nn.Sequential): - def __init__(self, a, b, bias=True, std=0.02): - super().__init__() - self.add_module('bn', torch.nn.BatchNorm1d(a)) - l = torch.nn.Linear(a, b, bias=bias) - trunc_normal_(l.weight, std=std) - if bias: - torch.nn.init.constant_(l.bias, 0) - self.add_module('l', l) - - @torch.no_grad() - def fuse(self): - bn, l = self._modules.values() - w = bn.weight / (bn.running_var + bn.eps) ** 0.5 - b = bn.bias - self.bn.running_mean * \ - self.bn.weight / (bn.running_var + bn.eps) ** 0.5 - w = l.weight * w[None, :] - if l.bias is None: - b = b @ self.l.weight.T - else: - b = (l.weight @ b[:, None]).view(-1) + self.l.bias - m = torch.nn.Linear(w.size(1), w.size(0)) - m.weight.data.copy_(w) - m.bias.data.copy_(b) - return m - - -def b16(n, activation, resolution=224): - return torch.nn.Sequential( - ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution), - activation(), - ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), - activation(), - ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), - activation(), - ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8)) - - -class Residual(torch.nn.Module): - def __init__(self, m, drop): - super().__init__() - self.m = m - self.drop = drop - - def forward(self, x): - if self.training and self.drop > 0: - return x + self.m(x) * torch.rand( - x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach() - else: - return x + self.m(x) - - -class Attention(torch.nn.Module): - def __init__(self, dim, key_dim, num_heads=8, - attn_ratio=4, act_layer=None, resolution=14): - super().__init__() - self.num_heads = num_heads - self.scale = key_dim ** -0.5 - self.key_dim = key_dim - self.nh_kd = nh_kd = key_dim * num_heads - self.d = int(attn_ratio * key_dim) - self.dh = int(attn_ratio * key_dim) * num_heads - self.attn_ratio = attn_ratio - h = self.dh + nh_kd * 2 - self.qkv = ConvNorm(dim, h, resolution=resolution) - self.proj = torch.nn.Sequential( - act_layer(), - ConvNorm(self.dh, dim, bn_weight_init=0, resolution=resolution)) - - points = list(itertools.product(range(resolution), range(resolution))) - N = len(points) - attention_offsets = {} - idxs = [] - for p1 in points: - for p2 in points: - offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) - if offset not in attention_offsets: - attention_offsets[offset] = len(attention_offsets) - idxs.append(attention_offsets[offset]) - self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) - self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N)) - self.ab = None - - @torch.no_grad() - def train(self, mode=True): - super().train(mode) - if mode and self.ab is not None: - self.ab = None - else: - self.ab = self.attention_biases[:, self.attention_bias_idxs] - - def forward(self, x): # x (B,C,H,W) - B, C, H, W = x.shape - q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2) - ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab - attn = (q.transpose(-2, -1) @ k) * self.scale + ab - attn = attn.softmax(dim=-1) - x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) - x = self.proj(x) - return x - - -class AttentionSubsample(torch.nn.Module): - def __init__( - self, in_dim, out_dim, 
key_dim, num_heads=8, attn_ratio=2, - act_layer=None, stride=2, resolution=14, resolution_=7): - super().__init__() - self.num_heads = num_heads - self.scale = key_dim ** -0.5 - self.key_dim = key_dim - self.nh_kd = nh_kd = key_dim * num_heads - self.d = int(attn_ratio * key_dim) - self.dh = int(attn_ratio * key_dim) * self.num_heads - self.attn_ratio = attn_ratio - self.resolution_ = resolution_ - self.resolution_2 = resolution_ ** 2 - h = self.dh + nh_kd - self.kv = ConvNorm(in_dim, h, resolution=resolution) - self.q = torch.nn.Sequential( - torch.nn.AvgPool2d(1, stride, 0), - ConvNorm(in_dim, nh_kd, resolution=resolution_)) - self.proj = torch.nn.Sequential( - act_layer(), - ConvNorm(self.d * num_heads, out_dim, resolution=resolution_)) - - self.stride = stride - self.resolution = resolution - points = list(itertools.product(range(resolution), range(resolution))) - points_ = list(itertools.product(range(resolution_), range(resolution_))) - N = len(points) - N_ = len(points_) - attention_offsets = {} - idxs = [] - for p1 in points_: - for p2 in points: - size = 1 - offset = ( - abs(p1[0] * stride - p2[0] + (size - 1) / 2), - abs(p1[1] * stride - p2[1] + (size - 1) / 2)) - if offset not in attention_offsets: - attention_offsets[offset] = len(attention_offsets) - idxs.append(attention_offsets[offset]) - self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) - self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N)) - self.ab = None - - @torch.no_grad() - def train(self, mode=True): - super().train(mode) - if mode and self.ab is not None: - self.ab = None - else: - self.ab = self.attention_biases[:, self.attention_bias_idxs] - - def forward(self, x): - B, C, H, W = x.shape - k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2) - q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2) - ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab - attn = (q.transpose(-2, -1) @ k) * self.scale + ab - attn = attn.softmax(dim=-1) - - x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_) - x = self.proj(x) - return x - - -class Levit(torch.nn.Module): - """ Vision Transformer with support for patch or hybrid CNN input stage - """ - - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - num_classes=1000, - embed_dim=[192], - key_dim=[64], - depth=[12], - num_heads=[3], - attn_ratio=[2], - mlp_ratio=[2], - hybrid_backbone=None, - down_ops=[], - attn_act_layer=torch.nn.Hardswish, - mlp_act_layer=torch.nn.Hardswish, - distillation=True, - drop_path=0): - super().__init__() - self.num_classes = num_classes - self.num_features = embed_dim[-1] - self.embed_dim = embed_dim - self.distillation = distillation - - self.patch_embed = hybrid_backbone - - self.blocks = [] - down_ops.append(['']) - resolution = img_size // patch_size - for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( - zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)): - for _ in range(dpth): - self.blocks.append( - Residual( - Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution), - drop_path)) - if mr > 0: - h = int(ed * mr) - self.blocks.append( - Residual(torch.nn.Sequential( - ConvNorm(ed, h, resolution=resolution), - mlp_act_layer(), - ConvNorm(h, ed, bn_weight_init=0, resolution=resolution), - ), drop_path)) - if do[0] == 'Subsample': - # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, 
stride) - resolution_ = (resolution - 1) // do[5] + 1 - self.blocks.append( - AttentionSubsample( - *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], attn_ratio=do[3], - act_layer=attn_act_layer, stride=do[5], - resolution=resolution, resolution_=resolution_)) - resolution = resolution_ - if do[4] > 0: # mlp_ratio - h = int(embed_dim[i + 1] * do[4]) - self.blocks.append( - Residual(torch.nn.Sequential( - ConvNorm(embed_dim[i + 1], h, resolution=resolution), - mlp_act_layer(), - ConvNorm(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), - ), drop_path)) - self.blocks = torch.nn.Sequential(*self.blocks) - - # Classifier head - self.head = NormLinear( - embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() - if distillation: - self.head_dist = NormLinear( - embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() - - @torch.jit.ignore - def no_weight_decay(self): - return {x for x in self.state_dict().keys() if 'attention_biases' in x} - - def forward(self, x): - x = self.patch_embed(x) - x = self.blocks(x) - x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1) - if self.distillation: - x = self.head(x), self.head_dist(x) - if not self.training: - x = (x[0] + x[1]) / 2 - else: - x = self.head(x) - return x - - -def model_factory(C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse): - embed_dim = [int(x) for x in C.split('_')] - num_heads = [int(x) for x in N.split('_')] - depth = [int(x) for x in X.split('_')] - act = torch.nn.Hardswish - model = Levit( - patch_size=16, - embed_dim=embed_dim, - num_heads=num_heads, - key_dim=[D] * 3, - depth=depth, - attn_ratio=[2, 2, 2], - mlp_ratio=[2, 2, 2], - down_ops=[ - # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) - ['Subsample', D, embed_dim[0] // D, 4, 2, 2], - ['Subsample', D, embed_dim[1] // D, 4, 2, 2], - ], - attn_act_layer=act, - mlp_act_layer=act, - hybrid_backbone=b16(embed_dim[0], activation=act), - num_classes=num_classes, - drop_path=drop_path, - distillation=distillation - ) - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - weights, map_location='cpu') - d = checkpoint['model'] - D = model.state_dict() - for k in d.keys(): - if D[k].shape != d[k].shape: - d[k] = d[k][:, :, None, None] - model.load_state_dict(d) - #if fuse: - # utils.replace_batchnorm(model) - - return model - diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 92ca115b..5a6dce6f 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -273,25 +273,14 @@ def _init_weights(m, n: str, head_bias: float = 0.): nn.init.ones_(m.weight) -def _create_mixer(variant, pretrained=False, default_cfg=None, **kwargs): - if default_cfg is None: - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] - num_classes = kwargs.pop('num_classes', default_num_classes) - img_size = kwargs.pop('img_size', default_img_size) - +def _create_mixer(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for MLP-Mixer models.') model = build_model_with_cfg( MlpMixer, variant, pretrained, - default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, + default_cfg=default_cfgs[variant], **kwargs) - return model diff --git a/timm/models/pit.py b/timm/models/pit.py index 040d96db..9c350861 100644 --- 
a/timm/models/pit.py +++ b/timm/models/pit.py @@ -251,24 +251,14 @@ def checkpoint_filter_fn(state_dict, model): def _create_pit(variant, pretrained=False, **kwargs): - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] - img_size = kwargs.pop('img_size', default_img_size) - num_classes = kwargs.pop('num_classes', default_num_classes) - if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') model = build_model_with_cfg( PoolingVisionTransformer, variant, pretrained, - default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, + default_cfg=default_cfgs[variant], pretrained_filter_fn=checkpoint_filter_fn, **kwargs) - return model diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 8e038718..8186cc4a 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -12,7 +12,7 @@ import torch.nn as nn from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.helpers import load_pretrained +from timm.models.helpers import build_model_with_cfg from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers.helpers import to_2tuple from timm.models.registry import register_model @@ -238,24 +238,31 @@ def checkpoint_filter_fn(state_dict, model): return state_dict +def _create_tnt(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + TNT, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + @register_model def tnt_s_patch16_224(pretrained=False, **kwargs): - model = TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, + model_cfg = dict( + patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, qkv_bias=False, **kwargs) - model.default_cfg = default_cfgs['tnt_s_patch16_224'] - if pretrained: - load_pretrained( - model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), - filter_fn=checkpoint_filter_fn) + model = _create_tnt('tnt_s_patch16_224', pretrained=pretrained, **model_cfg) return model @register_model def tnt_b_patch16_224(pretrained=False, **kwargs): - model = TNT(patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4, + model_cfg = dict( + patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4, qkv_bias=False, **kwargs) - model.default_cfg = default_cfgs['tnt_b_patch16_224'] - if pretrained: - load_pretrained( - model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model = _create_tnt('tnt_b_patch16_224', pretrained=pretrained, **model_cfg) return model diff --git a/timm/models/twins.py b/timm/models/twins.py index a534d174..793d2ede 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -33,7 +33,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', + 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head', **kwargs } @@ -361,25 +361,14 @@ class Twins(nn.Module): return x -def 
_create_twins(variant, pretrained=False, default_cfg=None, **kwargs): - if default_cfg is None: - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] - - num_classes = kwargs.pop('num_classes', default_num_classes) - img_size = kwargs.pop('img_size', default_img_size) +def _create_twins(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') model = build_model_with_cfg( Twins, variant, pretrained, - default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, + default_cfg=default_cfgs[variant], **kwargs) - return model diff --git a/timm/models/visformer.py b/timm/models/visformer.py index aa3bca57..df1d502a 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -1,3 +1,12 @@ +""" Visformer + +Paper: Visformer: The Vision-friendly Transformer - https://arxiv.org/abs/2104.12533 + +From original at https://github.com/danczs/Visformer + +""" +from copy import deepcopy + import torch import torch.nn as nn import torch.nn.functional as F @@ -22,6 +31,12 @@ def _cfg(url='', **kwargs): } +default_cfgs = dict( + visformer_tiny=_cfg(), + visformer_small=_cfg(), +) + + class LayerNormBHWC(nn.LayerNorm): def __init__(self, dim): super().__init__(dim) @@ -300,87 +315,97 @@ class Visformer(nn.Module): return x +def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + model = build_model_with_cfg( + Visformer, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + return model + + @register_model def visformer_tiny(pretrained=False, **kwargs): - model = Visformer( + model_cfg = dict( img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, embed_norm=nn.BatchNorm2d, **kwargs) - model.default_cfg = _cfg() + model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg) return model @register_model def visformer_small(pretrained=False, **kwargs): - model = Visformer( + model_cfg = dict( img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, embed_norm=nn.BatchNorm2d, **kwargs) - model.default_cfg = _cfg() + model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg) return model -@register_model -def visformer_net1(pretrained=False, **kwargs): - model = Visformer( - init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', - spatial_conv='000', vit_stem=True, conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model - - -@register_model -def visformer_net2(pretrained=False, **kwargs): - model = Visformer( - init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', - spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model - - -@register_model -def visformer_net3(pretrained=False, **kwargs): - model = Visformer( - init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', - spatial_conv='000', vit_stem=False, 
conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model - - -@register_model -def visformer_net4(pretrained=False, **kwargs): - model = Visformer( - init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', - spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model - - -@register_model -def visformer_net5(pretrained=False, **kwargs): - model = Visformer( - init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', - spatial_conv='111', vit_stem=False, conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model - - -@register_model -def visformer_net6(pretrained=False, **kwargs): - model = Visformer( - init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', - pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model - - -@register_model -def visformer_net7(pretrained=False, **kwargs): - model = Visformer( - init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000', - pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model +# @register_model +# def visformer_net1(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=True, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net2(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net3(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net4(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net5(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', +# spatial_conv='111', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net6(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', +# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net7(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000', +# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index bef6dfb0..ff74d836 100644 --- a/timm/models/vision_transformer.py +++ 
b/timm/models/vision_transformer.py @@ -387,21 +387,20 @@ def checkpoint_filter_fn(state_dict, model): v = v.reshape(O, -1, H, W) elif k == 'pos_embed' and v.shape != model.pos_embed.shape: # To resize pos embedding when using model at different size from pretrained weights - v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1), - model.patch_embed.grid_size) + v = resize_pos_embed( + v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) out_dict[k] = v return out_dict def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): - if default_cfg is None: - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] + default_cfg = default_cfg or default_cfgs[variant] + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') - num_classes = kwargs.pop('num_classes', default_num_classes) - img_size = kwargs.pop('img_size', default_img_size) + # NOTE this extra code to support handling of repr size for in21k pretrained models + default_num_classes = default_cfg['num_classes'] + num_classes = kwargs.get('num_classes', default_num_classes) repr_size = kwargs.pop('representation_size', None) if repr_size is not None and num_classes != default_num_classes: # Remove representation layer if fine-tuning. This may not always be the desired action, @@ -409,18 +408,12 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw _logger.warning("Removing representation layer for fine-tuning.") repr_size = None - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') - model = build_model_with_cfg( VisionTransformer, variant, pretrained, default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, representation_size=repr_size, pretrained_filter_fn=checkpoint_filter_fn, **kwargs) - return model diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 1656559f..9e5a62b2 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -27,7 +27,7 @@ def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head', **kwargs @@ -107,11 +107,10 @@ class HybridEmbed(nn.Module): def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs): - default_cfg = deepcopy(default_cfgs[variant]) embed_layer = partial(HybridEmbed, backbone=backbone) kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set return _create_vision_transformer( - variant, pretrained=pretrained, default_cfg=default_cfg, embed_layer=embed_layer, **kwargs) + variant, pretrained=pretrained, embed_layer=embed_layer, default_cfg=default_cfgs[variant], **kwargs) def _resnetv2(layers=(3, 4, 9), **kwargs): From 83487e2a0dd93e9112ee19ee0c8dcdf393463440 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 24 May 2021 21:36:56 -0700 Subject: [PATCH 08/18] Lower max backward size for tests. 
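Taken together, the factory cleanups in the patches above (mlp_mixer, pit, tnt, twins, vision_transformer) converge on one pattern: each entrypoint passes `default_cfgs[variant]` and an optional `pretrained_filter_fn` to `build_model_with_cfg`, which now owns the `num_classes`/`img_size` override handling that used to be copied into every `_create_*` helper. A quick usage-level sketch, assuming a checkout with these patches applied (existing timm model names, no weights downloaded):

```python
import torch
import timm

# num_classes overrides flow through build_model_with_cfg instead of per-model boilerplate.
pit = timm.create_model('pit_s_224', pretrained=False, num_classes=10)
logits = pit(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 10])

# img_size overrides are likewise passed straight to the model constructor via kwargs.
vit = timm.create_model('vit_base_patch16_224', pretrained=False, img_size=256, num_classes=0)
feats = vit.forward_features(torch.randn(1, 3, 256, 256))
```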
--- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 570b49db..c1632271 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -32,7 +32,7 @@ else: TARGET_FWD_SIZE = MAX_FWD_SIZE = 384 TARGET_BWD_SIZE = 128 -MAX_BWD_SIZE = 384 +MAX_BWD_SIZE = 320 MAX_FWD_FEAT_SIZE = 448 From c4572cc5aa21dfa6c81b4ef6a3479409e49561f0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 24 May 2021 22:50:12 -0700 Subject: [PATCH 09/18] Add Visformer-small weighs, tweak torchscript jit test img size. --- tests/test_models.py | 18 +++++++++++------- timm/models/visformer.py | 4 +++- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index c1632271..e6a73619 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -33,7 +33,11 @@ else: TARGET_FWD_SIZE = MAX_FWD_SIZE = 384 TARGET_BWD_SIZE = 128 MAX_BWD_SIZE = 320 -MAX_FWD_FEAT_SIZE = 448 +MAX_FWD_OUT_SIZE = 448 +TARGET_JIT_SIZE = 128 +MAX_JIT_SIZE = 320 +TARGET_FFEAT_SIZE = 96 +MAX_FFEAT_SIZE = 256 def _get_input_size(model, target=None): @@ -109,10 +113,10 @@ def test_model_default_cfgs(model_name, batch_size): pool_size = cfg['pool_size'] input_size = model.default_cfg['input_size'] - if all([x <= MAX_FWD_FEAT_SIZE for x in input_size]) and \ + if all([x <= MAX_FWD_OUT_SIZE for x in input_size]) and \ not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]): # output sizes only checked if default res <= 448 * 448 to keep resource down - input_size = tuple([min(x, MAX_FWD_FEAT_SIZE) for x in input_size]) + input_size = tuple([min(x, MAX_FWD_OUT_SIZE) for x in input_size]) input_tensor = torch.randn((batch_size, *input_size)) # test forward_features (always unpooled) @@ -176,8 +180,8 @@ def test_model_forward_torchscript(model_name, batch_size): model = create_model(model_name, pretrained=False) model.eval() - input_size = _get_input_size(model, 128) - if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional + input_size = _get_input_size(model, TARGET_JIT_SIZE) + if max(input_size) > MAX_JIT_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional pytest.skip("Fixed input size model > limit.") model = torch.jit.script(model) @@ -205,8 +209,8 @@ def test_model_forward_features(model_name, batch_size): expected_channels = model.feature_info.channels() assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6 - input_size = _get_input_size(model, 96) # jit compile is already a bit slow and we've tested normal res already... 
- if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional + input_size = _get_input_size(model, TARGET_FFEAT_SIZE) + if max(input_size) > MAX_FFEAT_SIZE: pytest.skip("Fixed input size model > limit.") outputs = model(torch.randn((batch_size, *input_size))) diff --git a/timm/models/visformer.py b/timm/models/visformer.py index df1d502a..936f1ddf 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -33,7 +33,9 @@ def _cfg(url='', **kwargs): default_cfgs = dict( visformer_tiny=_cfg(), - visformer_small=_cfg(), + visformer_small=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/visformer_small-839e1f5b.pth' + ), ) From d400f1dbddf71faed78e485e0797465298510815 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 25 May 2021 10:14:45 -0700 Subject: [PATCH 10/18] Filter test models before creation for backward/torchscript tests --- tests/test_models.py | 44 ++++++++++++++++++++++++++--------------- timm/models/registry.py | 7 +++++-- 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index e6a73619..63a95fa5 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -40,14 +40,25 @@ TARGET_FFEAT_SIZE = 96 MAX_FFEAT_SIZE = 256 -def _get_input_size(model, target=None): - default_cfg = model.default_cfg - input_size = default_cfg['input_size'] - if 'fixed_input_size' in default_cfg and default_cfg['fixed_input_size']: +def _get_input_size(model=None, model_name='', target=None): + if model is None: + assert model_name, "One of model or model_name must be provided" + input_size = get_model_default_value(model_name, 'input_size') + fixed_input_size = get_model_default_value(model_name, 'fixed_input_size') + min_input_size = get_model_default_value(model_name, 'min_input_size') + else: + default_cfg = model.default_cfg + input_size = default_cfg['input_size'] + fixed_input_size = default_cfg.get('fixed_input_size', None) + min_input_size = default_cfg.get('min_input_size', None) + assert input_size is not None + + if fixed_input_size: return input_size - if 'min_input_size' in default_cfg: + + if min_input_size: if target and max(input_size) > target: - input_size = default_cfg['min_input_size'] + input_size = min_input_size else: if target and max(input_size) > target: input_size = tuple([min(x, target) for x in input_size]) @@ -73,18 +84,18 @@ def test_model_forward(model_name, batch_size): @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True)) @pytest.mark.parametrize('batch_size', [2]) def test_model_backward(model_name, batch_size): """Run a single forward pass with each model""" + input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE) + if max(input_size) > MAX_BWD_SIZE: + pytest.skip("Fixed input size model > limit.") + model = create_model(model_name, pretrained=False, num_classes=42) num_params = sum([x.numel() for x in model.parameters()]) model.train() - input_size = _get_input_size(model, TARGET_BWD_SIZE) - if max(input_size) > MAX_BWD_SIZE: - pytest.skip("Fixed input size model > limit.") - inputs = torch.randn((batch_size, *input_size)) outputs = model(inputs) if isinstance(outputs, tuple): @@ -172,18 +183,19 @@ EXCLUDE_JIT_FILTERS = [ @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', 
list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS)) +@pytest.mark.parametrize( + 'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True)) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_torchscript(model_name, batch_size): """Run a single forward pass with each model""" + input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) + if max(input_size) > MAX_JIT_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional + pytest.skip("Fixed input size model > limit.") + with set_scriptable(True): model = create_model(model_name, pretrained=False) model.eval() - input_size = _get_input_size(model, TARGET_JIT_SIZE) - if max(input_size) > MAX_JIT_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional - pytest.skip("Fixed input size model > limit.") - model = torch.jit.script(model) outputs = model(torch.randn((batch_size, *input_size))) diff --git a/timm/models/registry.py b/timm/models/registry.py index 9172ac7e..6927b6d6 100644 --- a/timm/models/registry.py +++ b/timm/models/registry.py @@ -50,7 +50,7 @@ def _natural_key(string_): return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] -def list_models(filter='', module='', pretrained=False, exclude_filters=''): +def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False): """ Return list of available model names, sorted alphabetically Args: @@ -58,6 +58,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''): module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') pretrained (bool) - Include only models with pretrained weights if True exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter + name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) Example: model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' @@ -70,7 +71,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''): if filter: models = fnmatch.filter(models, filter) # include these models if exclude_filters: - if not isinstance(exclude_filters, list): + if not isinstance(exclude_filters, (tuple, list)): exclude_filters = [exclude_filters] for xf in exclude_filters: exclude_models = fnmatch.filter(models, xf) # exclude these models @@ -78,6 +79,8 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''): models = set(models).difference(exclude_models) if pretrained: models = _model_has_pretrained.intersection(models) + if name_matches_cfg: + models = set(_model_default_cfgs).intersection(models) return list(sorted(models, key=_natural_key)) From 11ae795e99f6146dfada86eeef8dcd8d1dcb8679 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 25 May 2021 10:15:32 -0700 Subject: [PATCH 11/18] Redo LeViT attention bias caching in a way that works with both torchscript and DataParallel --- timm/models/levit.py | 49 +++++++++++++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/timm/models/levit.py b/timm/models/levit.py index 96a0c85b..5019ee9a 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -26,6 +26,7 @@ Modifications by/coyright Copyright 2021 Ross Wightman import itertools from copy import deepcopy from functools import partial +from typing import Dict import torch import torch.nn as nn @@ -255,6 +256,8 @@ class 
Subsample(nn.Module): class Attention(nn.Module): + ab: Dict[str, torch.Tensor] + def __init__( self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False): super().__init__() @@ -286,20 +289,31 @@ class Attention(nn.Module): idxs.append(attention_offsets[offset]) self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N)) - self.ab = None + self.ab = {} @torch.no_grad() def train(self, mode=True): super().train(mode) - self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs] + if mode and self.ab: + self.ab = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.ab: + self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.ab[device_key] def forward(self, x): # x (B,C,H,W) if self.use_conv: B, C, H, W = x.shape q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2) - ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab - attn = (q.transpose(-2, -1) @ k) * self.scale + ab + + attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) else: B, N, C = x.shape @@ -308,15 +322,18 @@ class Attention(nn.Module): q = q.permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3) v = v.permute(0, 2, 1, 3) - ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab - attn = q @ k.transpose(-2, -1) * self.scale + ab + + attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) x = self.proj(x) return x class AttentionSubsample(nn.Module): + ab: Dict[str, torch.Tensor] + def __init__( self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2, act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False): @@ -366,12 +383,22 @@ class AttentionSubsample(nn.Module): idxs.append(attention_offsets[offset]) self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N)) - self.ab = None + self.ab = {} # per-device attention_biases cache @torch.no_grad() def train(self, mode=True): super().train(mode) - self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs] + if mode and self.ab: + self.ab = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.ab: + self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.ab[device_key] def forward(self, x): if self.use_conv: @@ -379,8 +406,7 @@ class AttentionSubsample(nn.Module): k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2) q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2) - ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab - attn = (q.transpose(-2, -1) @ k) * self.scale + ab + attn = (q.transpose(-2, -1) @ k) * self.scale + 
self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_) @@ -391,8 +417,7 @@ class AttentionSubsample(nn.Module): v = v.permute(0, 2, 1, 3) # BHNC q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) - ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab - attn = q @ k.transpose(-2, -1) * self.scale + ab + attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) From 99d97e0d67d302b4518c16ccd0c74a05a533855f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 25 May 2021 11:10:17 -0700 Subject: [PATCH 12/18] Hopefully the last test update for this PR... --- .github/workflows/tests.yml | 4 ++-- tests/test_models.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1cc44acf..9f7aebdb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,8 +17,8 @@ jobs: matrix: os: [ubuntu-latest, macOS-latest] python: ['3.8'] - torch: ['1.8.0'] - torchvision: ['0.9.0'] + torch: ['1.8.1'] + torchvision: ['0.9.1'] runs-on: ${{ matrix.os }} steps: diff --git a/tests/test_models.py b/tests/test_models.py index 63a95fa5..029ae0dd 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -73,7 +73,7 @@ def test_model_forward(model_name, batch_size): model = create_model(model_name, pretrained=False) model.eval() - input_size = _get_input_size(model, TARGET_FWD_SIZE) + input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE) if max(input_size) > MAX_FWD_SIZE: pytest.skip("Fixed input size model > limit.") inputs = torch.randn((batch_size, *input_size)) @@ -221,7 +221,7 @@ def test_model_forward_features(model_name, batch_size): expected_channels = model.feature_info.channels() assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6 - input_size = _get_input_size(model, TARGET_FFEAT_SIZE) + input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE) if max(input_size) > MAX_FFEAT_SIZE: pytest.skip("Fixed input size model > limit.") From 318360c3f97dd9ab457a44db0b5f3af7cf0e9c88 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 25 May 2021 12:25:53 -0700 Subject: [PATCH 13/18] Update README.md before merge. Bump version to 0.4.10 --- README.md | 37 +++++++++++++------------------------ docs/archived_changes.md | 24 ++++++++++++++++++++++++ docs/changes.md | 28 ++++++++++++++++++++++++++++ timm/version.py | 2 +- 4 files changed, 66 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index ca283605..8ba95e98 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,14 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor ## What's New +### May 25, 2021 +* Add LeViT, Visformer, ConViT (PR by Aman Arora), Twins (PR by paper authors) transformer models +* Add ResMLP and gMLP MLP vision models to the existing MLP Mixer impl +* Fix a number of torchscript issues with various vision transformer models +* Cleanup input_size/img_size override handling and improve testing / test coverage for all vision transformer and MLP models +* More flexible pos embedding resize (non-square) for ViT and TnT. 
Thanks [Alexander Soare](https://github.com/alexander-soare) +* Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params. + ### May 14, 2021 * Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl. * 1k trained variants: `tf_efficientnetv2_s/m/l` @@ -166,30 +174,6 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor * Misc fixes for SiLU ONNX export, default_cfg missing from Feature extraction models, Linear layer w/ AMP + torchscript * PyPi release @ 0.3.2 (needed by EfficientDet) -### Oct 30, 2020 -* Test with PyTorch 1.7 and fix a small top-n metric view vs reshape issue. -* Convert newly added 224x224 Vision Transformer weights from official JAX repo. 81.8 top-1 for B/16, 83.1 L/16. -* Support PyTorch 1.7 optimized, native SiLU (aka Swish) activation. Add mapping to 'silu' name, custom swish will eventually be deprecated. -* Fix regression for loading pretrained classifier via direct model entrypoint functions. Didn't impact create_model() factory usage. -* PyPi release @ 0.3.0 version! - -### Oct 26, 2020 -* Update Vision Transformer models to be compatible with official code release at https://github.com/google-research/vision_transformer -* Add Vision Transformer weights (ImageNet-21k pretrain) for 384x384 base and large models converted from official jax impl - * ViT-B/16 - 84.2 - * ViT-B/32 - 81.7 - * ViT-L/16 - 85.2 - * ViT-L/32 - 81.5 - -### Oct 21, 2020 -* Weights added for Vision Transformer (ViT) models. 77.86 top-1 for 'small' and 79.35 for 'base'. Thanks to [Christof](https://www.kaggle.com/christofhenkel) for training the base model w/ lots of GPUs. - -### Oct 13, 2020 -* Initial impl of Vision Transformer models. Both patch and hybrid (CNN backbone) variants. Currently trying to train... -* Adafactor and AdaHessian (FP32 only, no AMP) optimizers -* EdgeTPU-M (`efficientnet_em`) model trained in PyTorch, 79.3 top-1 -* Pip release, doc updates pending a few more changes... 
- ## Introduction @@ -207,6 +191,7 @@ A full version of the list below with source links can be found in the [document * Bottleneck Transformers - https://arxiv.org/abs/2101.11605 * CaiT (Class-Attention in Image Transformers) - https://arxiv.org/abs/2103.17239 * CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399 +* ConViT (Soft Convolutional Inductive Biases Vision Transformers)- https://arxiv.org/abs/2103.10697 * CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929 * DeiT (Vision Transformer) - https://arxiv.org/abs/2012.12877 * DenseNet - https://arxiv.org/abs/1608.06993 @@ -224,6 +209,7 @@ A full version of the list below with source links can be found in the [document * MobileNet-V2 - https://arxiv.org/abs/1801.04381 * Single-Path NAS - https://arxiv.org/abs/1904.02877 * GhostNet - https://arxiv.org/abs/1911.11907 +* gMLP - https://arxiv.org/abs/2105.08050 * GPU-Efficient Networks - https://arxiv.org/abs/2006.14090 * Halo Nets - https://arxiv.org/abs/2103.12731 * HardCoRe-NAS - https://arxiv.org/abs/2102.11646 @@ -231,6 +217,7 @@ A full version of the list below with source links can be found in the [document * Inception-V3 - https://arxiv.org/abs/1512.00567 * Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261 * Lambda Networks - https://arxiv.org/abs/2102.08602 +* LeViT (Vision Transformer in ConvNet's Clothing) - https://arxiv.org/abs/2104.01136 * MLP-Mixer - https://arxiv.org/abs/2105.01601 * MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244 * NASNet-A - https://arxiv.org/abs/1707.07012 @@ -240,6 +227,7 @@ A full version of the list below with source links can be found in the [document * Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302 * RegNet - https://arxiv.org/abs/2003.13678 * RepVGG - https://arxiv.org/abs/2101.03697 +* ResMLP - https://arxiv.org/abs/2105.03404 * ResNet/ResNeXt * ResNet (v1b/v1.5) - https://arxiv.org/abs/1512.03385 * ResNeXt - https://arxiv.org/abs/1611.05431 @@ -257,6 +245,7 @@ A full version of the list below with source links can be found in the [document * Swin Transformer - https://arxiv.org/abs/2103.14030 * Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112 * TResNet - https://arxiv.org/abs/2003.13630 +* Twins (Spatial Attention in Vision Transformers) - https://arxiv.org/pdf/2104.13840.pdf * Vision Transformer - https://arxiv.org/abs/2010.11929 * VovNet V2 and V1 - https://arxiv.org/abs/1911.06667 * Xception - https://arxiv.org/abs/1610.02357 diff --git a/docs/archived_changes.md b/docs/archived_changes.md index 857a914d..56ee706f 100644 --- a/docs/archived_changes.md +++ b/docs/archived_changes.md @@ -1,5 +1,29 @@ # Archived Changes +### Oct 30, 2020 +* Test with PyTorch 1.7 and fix a small top-n metric view vs reshape issue. +* Convert newly added 224x224 Vision Transformer weights from official JAX repo. 81.8 top-1 for B/16, 83.1 L/16. +* Support PyTorch 1.7 optimized, native SiLU (aka Swish) activation. Add mapping to 'silu' name, custom swish will eventually be deprecated. +* Fix regression for loading pretrained classifier via direct model entrypoint functions. Didn't impact create_model() factory usage. +* PyPi release @ 0.3.0 version! 
+ +### Oct 26, 2020 +* Update Vision Transformer models to be compatible with official code release at https://github.com/google-research/vision_transformer +* Add Vision Transformer weights (ImageNet-21k pretrain) for 384x384 base and large models converted from official jax impl + * ViT-B/16 - 84.2 + * ViT-B/32 - 81.7 + * ViT-L/16 - 85.2 + * ViT-L/32 - 81.5 + +### Oct 21, 2020 +* Weights added for Vision Transformer (ViT) models. 77.86 top-1 for 'small' and 79.35 for 'base'. Thanks to [Christof](https://www.kaggle.com/christofhenkel) for training the base model w/ lots of GPUs. + +### Oct 13, 2020 +* Initial impl of Vision Transformer models. Both patch and hybrid (CNN backbone) variants. Currently trying to train... +* Adafactor and AdaHessian (FP32 only, no AMP) optimizers +* EdgeTPU-M (`efficientnet_em`) model trained in PyTorch, 79.3 top-1 +* Pip release, doc updates pending a few more changes... + ### Sept 18, 2020 * New ResNet 'D' weights. 72.7 (top-1) ResNet-18-D, 77.1 ResNet-34-D, 80.5 ResNet-50-D * Added a few untrained defs for other ResNet models (66D, 101D, 152D, 200/200D) diff --git a/docs/changes.md b/docs/changes.md index b0ac125c..9719dd65 100644 --- a/docs/changes.md +++ b/docs/changes.md @@ -1,5 +1,33 @@ # Recent Changes +### May 25, 2021 +* Add LeViT, Visformer, Convit (PR by Aman Arora), Twins (PR by paper authors) transformer models +* Cleanup input_size/img_size override handling and testing for all vision transformer models +* Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params. + +### May 14, 2021 +* Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl. + * 1k trained variants: `tf_efficientnetv2_s/m/l` + * 21k trained variants: `tf_efficientnetv2_s/m/l_in21k` + * 21k pretrained -> 1k fine-tuned: `tf_efficientnetv2_s/m/l_in21ft1k` + * v2 models w/ v1 scaling: `tf_efficientnetv2_b0` through `b3` + * Rename my prev V2 guess `efficientnet_v2s` -> `efficientnetv2_rw_s` + * Some blank `efficientnetv2_*` models in-place for future native PyTorch training + +### May 5, 2021 +* Add MLP-Mixer models and port pretrained weights from [Google JAX impl](https://github.com/google-research/vision_transformer/tree/linen) +* Add CaiT models and pretrained weights from [FB](https://github.com/facebookresearch/deit) +* Add ResNet-RS models and weights from [TF](https://github.com/tensorflow/tpu/tree/master/models/official/resnet/resnet_rs). Thanks [Aman Arora](https://github.com/amaarora) +* Add CoaT models and weights. Thanks [Mohammed Rizin](https://github.com/morizin) +* Add new ImageNet-21k weights & finetuned weights for TResNet, MobileNet-V3, ViT models. Thanks [mrT](https://github.com/mrT23) +* Add GhostNet models and weights. Thanks [Kai Han](https://github.com/iamhankai) +* Update ByoaNet attention modles + * Improve SA module inits + * Hack together experimental stand-alone Swin based attn module and `swinnet` + * Consistent '26t' model defs for experiments. +* Add improved Efficientnet-V2S (prelim model def) weights. 83.8 top-1. 
+* WandB logging support + ### April 13, 2021 * Add Swin Transformer models and weights from https://github.com/microsoft/Swin-Transformer diff --git a/timm/version.py b/timm/version.py index 2d802716..b94cbb01 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.4.9' +__version__ = '0.4.10' From fd92ba0de85cb9da474a50d765601b36fee778ba Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 25 May 2021 12:52:07 -0700 Subject: [PATCH 14/18] Filter large vit models from torchscript tests --- tests/test_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 029ae0dd..de664068 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -178,7 +178,8 @@ if 'GITHUB_ACTIONS' not in os.environ: EXCLUDE_JIT_FILTERS = [ '*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable - 'dla*', 'hrnet*', 'ghostnet*', # hopefully fix at some point + 'dla*', 'hrnet*', 'ghostnet*', # hopefully fix at some point + 'vit_large_*', 'vit_huge_*', ] From 51c432150a0019f6193ca53b1e0fc21b86dc2c2a Mon Sep 17 00:00:00 2001 From: Peter Vandenabeele Date: Tue, 25 May 2021 22:42:44 +0200 Subject: [PATCH 15/18] README: fix simple typos --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index ca283605..6a8d520e 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor * Add CoaT models and weights. Thanks [Mohammed Rizin](https://github.com/morizin) * Add new ImageNet-21k weights & finetuned weights for TResNet, MobileNet-V3, ViT models. Thanks [mrT](https://github.com/mrT23) * Add GhostNet models and weights. Thanks [Kai Han](https://github.com/iamhankai) -* Update ByoaNet attention modles +* Update ByoaNet attention modules * Improve SA module inits * Hack together experimental stand-alone Swin based attn module and `swinnet` * Consistent '26t' model defs for experiments. @@ -282,7 +282,7 @@ Several (less common) features that I often utilize in my projects are included. * PyTorch DistributedDataParallel w/ multi-gpu, single process (AMP disabled as it crashes when enabled) * PyTorch w/ single GPU single process (AMP optional) * A dynamic global pool implementation that allows selecting from average pooling, max pooling, average + max, or concat([average, max]) at model creation. All global pooling is adaptive average by default and compatible with pretrained weights. -* A 'Test Time Pool' wrapper that can wrap any of the included models and usually provide improved performance doing inference with input images larger than the training size. Idea adapted from original DPN implementation when I ported (https://github.com/cypw/DPNs) +* A 'Test Time Pool' wrapper that can wrap any of the included models and usually provides improved performance doing inference with input images larger than the training size. Idea adapted from original DPN implementation when I ported (https://github.com/cypw/DPNs) * Learning rate schedulers * Ideas adopted from * [AllenNLP schedulers](https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers) @@ -329,7 +329,7 @@ The root folder of the repository contains reference train, validation, and infe ## Awesome PyTorch Resources -One of the greatest assets of PyTorch is the community and their contributions. 
A few of my favourite resources that pair well with the models and componenets here are listed below. +One of the greatest assets of PyTorch is the community and their contributions. A few of my favourite resources that pair well with the models and components here are listed below. ### Object Detection, Instance and Semantic Segmentation * Detectron2 - https://github.com/facebookresearch/detectron2 @@ -353,7 +353,7 @@ One of the greatest assets of PyTorch is the community and their contributions. ## Licenses ### Code -The code here is licensed Apache 2.0. I've taken care to make sure any third party code included or adapted has compatible (permissive) licenses such as MIT, BSD, etc. I've made an effort to avoid any GPL / LGPL conflicts. That said, it is your responsibility to ensure you comply with license here and conditions of any dependent licenses. Where applicable, I've linked the sources/references for various components in docstrings. If you think I've missed anything please create an issue. +The code here is licensed Apache 2.0. I've taken care to make sure any third party code included or adapted has compatible (permissive) licenses such as MIT, BSD, etc. I've made an effort to avoid any GPL / LGPL conflicts. That said, it is your responsibility to ensure you comply with licenses here and conditions of any dependent licenses. Where applicable, I've linked the sources/references for various components in docstrings. If you think I've missed anything please create an issue. ### Pretrained Weights So far all of the pretrained weights available here are pretrained on ImageNet with a select few that have some additional pretraining (see extra note below). ImageNet was released for non-commercial research purposes only (http://www.image-net.org/download-faq). It's not clear what the implications of that are for the use of pretrained weights from that dataset. Any models I have trained with ImageNet are done for research purposes and one should assume that the original dataset license applies to the weights. It's best to seek legal advice if you intend to use the pretrained weights in a commercial product. 
From 5db74521736f6d6caef4e0cd8aba1ad624ae9390 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 25 May 2021 14:11:36 -0700 Subject: [PATCH 16/18] Fix visformer in_chans stem handling --- tests/test_models.py | 2 +- timm/models/visformer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index de664068..44cb3ba2 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -190,7 +190,7 @@ EXCLUDE_JIT_FILTERS = [ def test_model_forward_torchscript(model_name, batch_size): """Run a single forward pass with each model""" input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) - if max(input_size) > MAX_JIT_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional + if max(input_size) > MAX_JIT_SIZE: pytest.skip("Fixed input size model > limit.") with set_scriptable(True): diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 936f1ddf..33a2fe87 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -26,7 +26,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', + 'first_conv': 'stem.0', 'classifier': 'head', **kwargs } @@ -183,7 +183,7 @@ class Visformer(nn.Module): img_size //= 8 else: self.stem = nn.Sequential( - nn.Conv2d(3, self.init_channels, 7, stride=2, padding=3, bias=False), + nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False), nn.BatchNorm2d(self.init_channels), nn.ReLU(inplace=True) ) From 9c78de8c024bff0acc68b044dfb935366c6185dc Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 26 May 2021 15:28:42 -0700 Subject: [PATCH 17/18] Fix #661, move hardswish out of default args for LeViT. Enable native torch support for hardswish, hardsigmoid, mish if present. 
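The diff below (together with the `create_act` changes it includes) leans on timm's string-based activation resolution: layers are referenced by name and resolved at build time, so a default can stay a string such as `'hard_swish'` and map to torch's native Hardswish/Hardsigmoid/Mish when the installed PyTorch provides them, falling back to timm's own implementations otherwise. A minimal sketch of that mechanism, independent of LeViT itself (which class or function is chosen also depends on any `set_layer_config` scriptable/exportable overrides):

```python
import torch
from timm.models.layers import get_act_layer, create_act_layer, get_act_fn

act_cls = get_act_layer('hard_swish')           # resolves a name to a layer class
act = create_act_layer('hard_swish', inplace=True)
gate_fn = get_act_fn('hard_sigmoid')            # functional form, as used for SE gating

x = torch.randn(2, 8)
print(act_cls.__name__, act(x.clone()).shape, gate_fn(x).shape)
```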
--- tests/test_layers.py | 12 ++--- tests/test_models.py | 2 +- timm/models/efficientnet_blocks.py | 6 +-- timm/models/efficientnet_builder.py | 5 +- timm/models/ghostnet.py | 4 +- timm/models/layers/create_act.py | 74 ++++++++++++++++++----------- timm/models/layers/se.py | 2 +- timm/models/levit.py | 8 ++-- 8 files changed, 66 insertions(+), 47 deletions(-) diff --git a/tests/test_layers.py b/tests/test_layers.py index 714cb444..508a6aae 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -8,10 +8,10 @@ from timm.models.layers import create_act_layer, get_act_layer, set_layer_config class MLP(nn.Module): - def __init__(self, act_layer="relu"): + def __init__(self, act_layer="relu", inplace=True): super(MLP, self).__init__() self.fc1 = nn.Linear(1000, 100) - self.act = create_act_layer(act_layer, inplace=True) + self.act = create_act_layer(act_layer, inplace=inplace) self.fc2 = nn.Linear(100, 10) def forward(self, x): @@ -21,14 +21,14 @@ class MLP(nn.Module): return x -def _run_act_layer_grad(act_type): +def _run_act_layer_grad(act_type, inplace=True): x = torch.rand(10, 1000) * 10 - m = MLP(act_layer=act_type) + m = MLP(act_layer=act_type, inplace=inplace) def _run(x, act_layer=''): if act_layer: # replace act layer if set - m.act = create_act_layer(act_layer, inplace=True) + m.act = create_act_layer(act_layer, inplace=inplace) out = m(x) l = (out - 0).pow(2).sum() return l @@ -58,7 +58,7 @@ def test_mish_grad(): def test_hard_sigmoid_grad(): for _ in range(100): - _run_act_layer_grad('hard_sigmoid') + _run_act_layer_grad('hard_sigmoid', inplace=None) def test_hard_swish_grad(): diff --git a/tests/test_models.py b/tests/test_models.py index 44cb3ba2..18298dff 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -110,7 +110,7 @@ def test_model_backward(model_name, batch_size): assert not torch.isnan(outputs).any(), 'Output included NaNs' -@pytest.mark.timeout(120) +@pytest.mark.timeout(300) @pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS)) @pytest.mark.parametrize('batch_size', [1]) def test_model_default_cfgs(model_name, batch_size): diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index 83b57beb..7853db0e 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn from torch.nn import functional as F -from .layers import create_conv2d, drop_path, make_divisible +from .layers import create_conv2d, drop_path, make_divisible, get_act_fn, create_act_layer from .layers.activations import sigmoid __all__ = [ @@ -36,9 +36,9 @@ class SqueezeExcite(nn.Module): reduced_chs = make_divisible(reduced_chs * se_ratio, divisor) act_layer = force_act_layer or act_layer self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) - self.act1 = act_layer(inplace=True) + self.act1 = create_act_layer(act_layer, inplace=True) self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) - self.gate_fn = gate_fn + self.gate_fn = get_act_fn(gate_fn) def forward(self, x): x_se = x.mean((2, 3), keepdim=True) diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index 30739454..57e2039b 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -50,10 +50,7 @@ def resolve_bn_args(kwargs): def resolve_act_layer(kwargs, default='relu'): - act_layer = kwargs.pop('act_layer', default) - if isinstance(act_layer, str): - act_layer = get_act_layer(act_layer) - return act_layer 
+ return get_act_layer(kwargs.pop('act_layer', default)) def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9): diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index c132142a..1783ff7a 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -13,7 +13,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid, make_divisible +from .layers import SelectAdaptivePool2d, Linear, make_divisible from .efficientnet_blocks import SqueezeExcite, ConvBnAct from .helpers import build_model_with_cfg from .registry import register_model @@ -40,7 +40,7 @@ default_cfgs = { } -_SE_LAYER = partial(SqueezeExcite, gate_fn=hard_sigmoid, divisor=4) +_SE_LAYER = partial(SqueezeExcite, gate_fn='hard_sigmoid', divisor=4) class GhostModule(nn.Module): diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py index 426c3681..aa557692 100644 --- a/timm/models/layers/create_act.py +++ b/timm/models/layers/create_act.py @@ -1,20 +1,26 @@ """ Activation Factory Hacked together by / Copyright 2020 Ross Wightman """ +from typing import Union, Callable, Type + from .activations import * from .activations_jit import * from .activations_me import * from .config import is_exportable, is_scriptable, is_no_jit -# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. This code -# will use native version if present. Eventually, the custom Swish layers will be removed -# and only native 'silu' will be used. +# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. +# Also hardsigmoid, hardswish, and soon mish. This code will use native version if present. +# Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used. 
_has_silu = 'silu' in dir(torch.nn.functional) +_has_hardswish = 'hardswish' in dir(torch.nn.functional) +_has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional) +_has_mish = 'mish' in dir(torch.nn.functional) + _ACT_FN_DEFAULT = dict( silu=F.silu if _has_silu else swish, swish=F.silu if _has_silu else swish, - mish=mish, + mish=F.mish if _has_mish else mish, relu=F.relu, relu6=F.relu6, leaky_relu=F.leaky_relu, @@ -24,33 +30,39 @@ _ACT_FN_DEFAULT = dict( gelu=gelu, sigmoid=sigmoid, tanh=tanh, - hard_sigmoid=hard_sigmoid, - hard_swish=hard_swish, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid, + hard_swish=F.hardswish if _has_hardswish else hard_swish, hard_mish=hard_mish, ) _ACT_FN_JIT = dict( silu=F.silu if _has_silu else swish_jit, swish=F.silu if _has_silu else swish_jit, - mish=mish_jit, - hard_sigmoid=hard_sigmoid_jit, - hard_swish=hard_swish_jit, + mish=F.mish if _has_mish else mish_jit, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit, + hard_swish=F.hardswish if _has_hardswish else hard_swish_jit, hard_mish=hard_mish_jit ) _ACT_FN_ME = dict( silu=F.silu if _has_silu else swish_me, swish=F.silu if _has_silu else swish_me, - mish=mish_me, - hard_sigmoid=hard_sigmoid_me, - hard_swish=hard_swish_me, + mish=F.mish if _has_mish else mish_me, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me, + hard_swish=F.hardswish if _has_hardswish else hard_swish_me, hard_mish=hard_mish_me, ) +_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT) +for a in _ACT_FNS: + a.setdefault('hardsigmoid', a.get('hard_sigmoid')) + a.setdefault('hardswish', a.get('hard_swish')) + + _ACT_LAYER_DEFAULT = dict( silu=nn.SiLU if _has_silu else Swish, swish=nn.SiLU if _has_silu else Swish, - mish=Mish, + mish=nn.Mish if _has_mish else Mish, relu=nn.ReLU, relu6=nn.ReLU6, leaky_relu=nn.LeakyReLU, @@ -61,37 +73,44 @@ _ACT_LAYER_DEFAULT = dict( gelu=GELU, sigmoid=Sigmoid, tanh=Tanh, - hard_sigmoid=HardSigmoid, - hard_swish=HardSwish, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid, + hard_swish=nn.Hardswish if _has_hardswish else HardSwish, hard_mish=HardMish, ) _ACT_LAYER_JIT = dict( silu=nn.SiLU if _has_silu else SwishJit, swish=nn.SiLU if _has_silu else SwishJit, - mish=MishJit, - hard_sigmoid=HardSigmoidJit, - hard_swish=HardSwishJit, + mish=nn.Mish if _has_mish else MishJit, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit, + hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit, hard_mish=HardMishJit ) _ACT_LAYER_ME = dict( silu=nn.SiLU if _has_silu else SwishMe, swish=nn.SiLU if _has_silu else SwishMe, - mish=MishMe, - hard_sigmoid=HardSigmoidMe, - hard_swish=HardSwishMe, + mish=nn.Mish if _has_mish else MishMe, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe, + hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe, hard_mish=HardMishMe, ) +_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT) +for a in _ACT_LAYERS: + a.setdefault('hardsigmoid', a.get('hard_sigmoid')) + a.setdefault('hardswish', a.get('hard_swish')) + -def get_act_fn(name='relu'): +def get_act_fn(name: Union[Callable, str] = 'relu'): """ Activation Function Factory Fetching activation fns by name with this function allows export or torch script friendly functions to be returned dynamically based on current config. 
""" if not name: return None + if isinstance(name, Callable): + return name if not (is_no_jit() or is_exportable() or is_scriptable()): # If not exporting or scripting the model, first look for a memory-efficient version with # custom autograd, then fallback @@ -106,13 +125,15 @@ def get_act_fn(name='relu'): return _ACT_FN_DEFAULT[name] -def get_act_layer(name='relu'): +def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): """ Activation Layer Factory Fetching activation layers by name with this function allows export or torch script friendly functions to be returned dynamically based on current config. """ if not name: return None + if isinstance(name, type): + return name if not (is_no_jit() or is_exportable() or is_scriptable()): if name in _ACT_LAYER_ME: return _ACT_LAYER_ME[name] @@ -125,9 +146,8 @@ def get_act_layer(name='relu'): return _ACT_LAYER_DEFAULT[name] -def create_act_layer(name, inplace=False, **kwargs): +def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs): act_layer = get_act_layer(name) - if act_layer is not None: - return act_layer(inplace=inplace, **kwargs) - else: + if act_layer is None: return None + return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs) diff --git a/timm/models/layers/se.py b/timm/models/layers/se.py index 54c0ef33..4354144d 100644 --- a/timm/models/layers/se.py +++ b/timm/models/layers/se.py @@ -42,7 +42,7 @@ class EffectiveSEModule(nn.Module): def __init__(self, channels, gate_layer='hard_sigmoid'): super(EffectiveSEModule, self).__init__() self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) - self.gate = create_act_layer(gate_layer, inplace=True) + self.gate = create_act_layer(gate_layer) def forward(self, x): x_se = x.mean((2, 3), keepdim=True) diff --git a/timm/models/levit.py b/timm/models/levit.py index 5019ee9a..2180254a 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -33,7 +33,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN from .helpers import build_model_with_cfg, overlay_external_default_cfg -from .layers import to_ntuple +from .layers import to_ntuple, get_act_layer from .vision_transformer import trunc_normal_ from .registry import register_model @@ -443,12 +443,14 @@ class Levit(nn.Module): mlp_ratio=2, hybrid_backbone=None, down_ops=None, - act_layer=nn.Hardswish, - attn_act_layer=nn.Hardswish, + act_layer='hard_swish', + attn_act_layer='hard_swish', distillation=True, use_conv=False, drop_path=0): super().__init__() + act_layer = get_act_layer(act_layer) + attn_act_layer = get_act_layer(attn_act_layer) if isinstance(img_size, tuple): # FIXME origin impl passes single img/res dim through whole hierarchy, # not sure this model will be used enough to spend time fixing it. From d7bab8a6c52a72487d1bed0a28aad41e326d7622 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 28 May 2021 09:54:50 -0700 Subject: [PATCH 18/18] Fix strict flag change for checkpoint load. 
--- timm/models/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index dfb6b860..adfef550 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -44,7 +44,7 @@ def load_state_dict(checkpoint_path, use_ema=False): raise FileNotFoundError() -def load_checkpoint(model, checkpoint_path, use_ema=False, strict=False): +def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): state_dict = load_state_dict(checkpoint_path, use_ema) model.load_state_dict(state_dict, strict=strict)
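
Stepping back to the activation factory changes in the previous patch: with strings resolving through get_act_layer / get_act_fn and inplace defaulting to None, model configs can name activations symbolically and still pick up the native torch ops when present. A short usage sketch, assuming a recent PyTorch (older versions fall back to timm's own classes):

from timm.models.layers import create_act_layer, get_act_layer, get_act_fn

# Strings resolve to the best available implementation for the current config.
act_layer = get_act_layer('hard_swish')   # nn.Hardswish on recent torch, timm's HardSwish otherwise
gate_fn = get_act_fn('hard_sigmoid')      # F.hardsigmoid or the python fallback

# inplace is only forwarded when explicitly set, so activation classes without an
# inplace argument can still be constructed through the factory.
act = create_act_layer('relu', inplace=True)
gelu = create_act_layer('gelu')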