diff --git a/timm/layers/create_norm.py b/timm/layers/create_norm.py index b9efae8c..a1941b5a 100644 --- a/timm/layers/create_norm.py +++ b/timm/layers/create_norm.py @@ -23,9 +23,9 @@ _NORM_MAP = dict( _NORM_TYPES = {m for n, m in _NORM_MAP.items()} -def create_norm_layer(layer_name, num_features, act_layer=None, apply_act=True, **kwargs): - layer = get_norm_layer(layer_name, act_layer=act_layer) - layer_instance = layer(num_features, apply_act=apply_act, **kwargs) +def create_norm_layer(layer_name, num_features, **kwargs): + layer = get_norm_layer(layer_name) + layer_instance = layer(num_features, **kwargs) return layer_instance diff --git a/timm/models/__init__.py b/timm/models/__init__.py index ea945ccd..a9fbbc26 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -15,6 +15,7 @@ from .dla import * from .dpn import * from .edgenext import * from .efficientformer import * +from .efficientformer_v2 import * from .efficientnet import * from .gcvit import * from .ghostnet import * diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index 4f33f29a..c6920020 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -20,34 +20,13 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._pretrained import generate_default_cfgs from ._registry import register_model __all__ = ['EfficientFormer'] # model_registry will add each entrypoint fn to this -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True, - 'crop_pct': .95, 'interpolation': 'bicubic', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stem.conv1', 'classifier': ('head', 'head_dist'), - **kwargs - } - - -default_cfgs = dict( - efficientformer_l1=_cfg( - url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l1_1000d_224-5b08fab0.pth", - ), - efficientformer_l3=_cfg( - url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l3_300d_224-6816624f.pth", - ), - efficientformer_l7=_cfg( - url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l7_300d_224-e957ab75.pth", - ), -) - EfficientFormer_width = { 'l1': (48, 96, 224, 448), 'l3': (64, 128, 320, 512), @@ -99,7 +78,7 @@ class Attention(torch.nn.Module): self.attention_bias_cache = {} # clear ab cache def get_attention_biases(self, device: torch.device) -> torch.Tensor: - if self.training: + if torch.jit.is_tracing() or self.training: return self.attention_biases[:, self.attention_bias_idxs] else: device_key = str(device) @@ -279,16 +258,17 @@ class MetaBlock2d(nn.Module): ): super().__init__() self.token_mixer = Pooling(pool_size=pool_size) + self.ls1 = LayerScale2d(dim, layer_scale_init_value) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.mlp = ConvMlpWithNorm( dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, norm_layer=norm_layer, drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.ls1 = LayerScale2d(dim, layer_scale_init_value) self.ls2 = LayerScale2d(dim, layer_scale_init_value) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def forward(self, x): - x = x + self.drop_path(self.ls1(self.token_mixer(x))) - x = x + self.drop_path(self.ls2(self.mlp(x))) + x = x + self.drop_path1(self.ls1(self.token_mixer(x))) + x = x + self.drop_path2(self.ls2(self.mlp(x))) return x @@ -356,7 +336,10 @@ class EfficientFormerStage(nn.Module): def forward(self, x): x = self.downsample(x) - x = self.blocks(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) return x @@ -514,6 +497,30 @@ def _checkpoint_filter_fn(state_dict, model): return out_dict +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True, + 'crop_pct': .95, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv1', 'classifier': ('head', 'head_dist'), + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'efficientformer_l1.snap_dist_in1k': _cfg( + hf_hub_id='timm/', + ), + 'efficientformer_l3.snap_dist_in1k': _cfg( + hf_hub_id='timm/', + ), + 'efficientformer_l7.snap_dist_in1k': _cfg( + hf_hub_id='timm/', + ), +}) + + def _create_efficientformer(variant, pretrained=False, **kwargs): model = build_model_with_cfg( EfficientFormer, variant, pretrained, @@ -524,30 +531,30 @@ def _create_efficientformer(variant, pretrained=False, **kwargs): @register_model def efficientformer_l1(pretrained=False, **kwargs): - model_kwargs = dict( + model_args = dict( depths=EfficientFormer_depth['l1'], embed_dims=EfficientFormer_width['l1'], num_vit=1, - **kwargs) - return _create_efficientformer('efficientformer_l1', pretrained=pretrained, **model_kwargs) + ) + return _create_efficientformer('efficientformer_l1', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def efficientformer_l3(pretrained=False, **kwargs): - model_kwargs = dict( + model_args = dict( depths=EfficientFormer_depth['l3'], embed_dims=EfficientFormer_width['l3'], num_vit=4, - **kwargs) - return _create_efficientformer('efficientformer_l3', pretrained=pretrained, **model_kwargs) + ) + return _create_efficientformer('efficientformer_l3', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def efficientformer_l7(pretrained=False, **kwargs): - model_kwargs = dict( + model_args = dict( depths=EfficientFormer_depth['l7'], embed_dims=EfficientFormer_width['l7'], num_vit=8, - **kwargs) - return _create_efficientformer('efficientformer_l7', pretrained=pretrained, **model_kwargs) + ) + return _create_efficientformer('efficientformer_l7', pretrained=pretrained, **dict(model_args, **kwargs)) diff --git a/timm/models/efficientformer_v2.py b/timm/models/efficientformer_v2.py new file mode 100644 index 00000000..54cc3318 --- /dev/null +++ b/timm/models/efficientformer_v2.py @@ -0,0 +1,732 @@ +""" EfficientFormer-V2 + +@article{ + li2022rethinking, + title={Rethinking Vision Transformers for MobileNet Size and Speed}, + author={Li, Yanyu and Hu, Ju and Wen, Yang and Evangelidis, Georgios and Salahi, Kamyar and Wang, Yanzhi and Tulyakov, Sergey and Ren, Jian}, + journal={arXiv preprint arXiv:2212.08059}, + year={2022} +} + +Significantly refactored and cleaned up for timm from original at: https://github.com/snap-research/EfficientFormer + +Original code licensed Apache 2.0, Copyright (c) 2022 Snap Inc. 
+ +Modifications and timm support by / Copyright 2023, Ross Wightman +""" +import math +from functools import partial +from typing import Dict + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct +from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._pretrained import generate_default_cfgs +from ._registry import register_model + + +EfficientFormer_width = { + 'L': (40, 80, 192, 384), # 26m 83.3% 6attn + 'S2': (32, 64, 144, 288), # 12m 81.6% 4attn dp0.02 + 'S1': (32, 48, 120, 224), # 6.1m 79.0 + 'S0': (32, 48, 96, 176), # 75.0 75.7 +} + +EfficientFormer_depth = { + 'L': (5, 5, 15, 10), # 26m 83.3% + 'S2': (4, 4, 12, 8), # 12m + 'S1': (3, 3, 9, 6), # 79.0 + 'S0': (2, 2, 6, 4), # 75.7 +} + +EfficientFormer_expansion_ratios = { + 'L': (4, 4, (4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4), (4, 4, 4, 3, 3, 3, 3, 4, 4, 4)), + 'S2': (4, 4, (4, 4, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4), (4, 4, 3, 3, 3, 3, 4, 4)), + 'S1': (4, 4, (4, 4, 3, 3, 3, 3, 4, 4, 4), (4, 4, 3, 3, 4, 4)), + 'S0': (4, 4, (4, 3, 3, 3, 4, 4), (4, 3, 3, 4)), +} + + +class ConvNorm(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding='', + dilation=1, + groups=1, + bias=True, + norm_layer='batchnorm2d', + norm_kwargs=None, + ): + norm_kwargs = norm_kwargs or {} + super(ConvNorm, self).__init__() + self.conv = create_conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + self.bn = create_norm_layer(norm_layer, out_channels, **norm_kwargs) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class Attention2d(torch.nn.Module): + attention_bias_cache: Dict[str, torch.Tensor] + + def __init__( + self, + dim=384, + key_dim=32, + num_heads=8, + attn_ratio=4, + resolution=7, + act_layer=nn.GELU, + stride=None, + ): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + + resolution = to_2tuple(resolution) + if stride is not None: + resolution = tuple([math.ceil(r / stride) for r in resolution]) + self.stride_conv = ConvNorm(dim, dim, kernel_size=3, stride=stride, groups=dim) + self.upsample = nn.Upsample(scale_factor=stride, mode='bilinear') + else: + self.stride_conv = None + self.upsample = None + + self.resolution = resolution + self.N = self.resolution[0] * self.resolution[1] + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + kh = self.key_dim * self.num_heads + + self.q = ConvNorm(dim, kh) + self.k = ConvNorm(dim, kh) + self.v = ConvNorm(dim, self.dh) + self.v_local = ConvNorm(self.dh, self.dh, kernel_size=3, groups=self.dh) + self.talking_head1 = nn.Conv2d(self.num_heads, self.num_heads, kernel_size=1) + self.talking_head2 = nn.Conv2d(self.num_heads, self.num_heads, kernel_size=1) + + self.act = act_layer() + self.proj = ConvNorm(self.dh, dim, 1) + + pos = torch.stack(torch.meshgrid(torch.arange(self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1) + rel_pos = (pos[..., :, None] - pos[..., None, :]).abs() + rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1] + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, self.N)) + 
self.register_buffer('attention_bias_idxs', torch.LongTensor(rel_pos), persistent=False) + self.attention_bias_cache = {} # per-device attention_biases cache (data-parallel compat) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.attention_bias_cache: + self.attention_bias_cache = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if torch.jit.is_tracing() or self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.attention_bias_cache: + self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.attention_bias_cache[device_key] + + def forward(self, x): + B, C, H, W = x.shape + if self.stride_conv is not None: + x = self.stride_conv(x) + + q = self.q(x).reshape(B, self.num_heads, -1, self.N).permute(0, 1, 3, 2) + k = self.k(x).reshape(B, self.num_heads, -1, self.N).permute(0, 1, 2, 3) + v = self.v(x) + v_local = self.v_local(v) + v = v.reshape(B, self.num_heads, -1, self.N).permute(0, 1, 3, 2) + + attn = (q @ k) * self.scale + attn = attn + self.get_attention_biases(x.device) + attn = self.talking_head1(attn) + attn = attn.softmax(dim=-1) + attn = self.talking_head2(attn) + + x = (attn @ v).transpose(2, 3) + x = x.reshape(B, self.dh, self.resolution[0], self.resolution[1]) + v_local + if self.upsample is not None: + x = self.upsample(x) + + x = self.act(x) + x = self.proj(x) + return x + + +class LocalGlobalQuery(torch.nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + self.pool = nn.AvgPool2d(1, 2, 0) + self.local = nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=2, padding=1, groups=in_dim) + self.proj = ConvNorm(in_dim, out_dim, 1) + + def forward(self, x): + local_q = self.local(x) + pool_q = self.pool(x) + q = local_q + pool_q + q = self.proj(q) + return q + + +class Attention2dDownsample(torch.nn.Module): + attention_bias_cache: Dict[str, torch.Tensor] + + def __init__( + self, + dim=384, + key_dim=16, + num_heads=8, + attn_ratio=4, + resolution=7, + out_dim=None, + act_layer=nn.GELU, + ): + super().__init__() + + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.resolution = to_2tuple(resolution) + self.resolution2 = tuple([math.ceil(r / 2) for r in self.resolution]) + self.N = self.resolution[0] * self.resolution[1] + self.N2 = self.resolution2[0] * self.resolution2[1] + + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + self.out_dim = out_dim or dim + kh = self.key_dim * self.num_heads + + self.q = LocalGlobalQuery(dim, kh) + self.k = ConvNorm(dim, kh, 1) + self.v = ConvNorm(dim, self.dh, 1) + self.v_local = ConvNorm(self.dh, self.dh, kernel_size=3, stride=2, groups=self.dh) + + self.act = act_layer() + self.proj = ConvNorm(self.dh, self.out_dim, 1) + + self.attention_biases = nn.Parameter(torch.zeros(num_heads, self.N)) + k_pos = torch.stack(torch.meshgrid(torch.arange( + self.resolution[1]), + torch.arange(self.resolution[1]))).flatten(1) + q_pos = torch.stack(torch.meshgrid( + torch.arange(0, self.resolution[0], step=2), + torch.arange(0, self.resolution[1], step=2))).flatten(1) + rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs() + rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1] + self.register_buffer('attention_bias_idxs', rel_pos, persistent=False) + self.attention_bias_cache = {} # per-device attention_biases cache 
(data-parallel compat) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.attention_bias_cache: + self.attention_bias_cache = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if torch.jit.is_tracing() or self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.attention_bias_cache: + self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.attention_bias_cache[device_key] + + def forward(self, x): + B, C, H, W = x.shape + + q = self.q(x).reshape(B, self.num_heads, -1, self.N2).permute(0, 1, 3, 2) + k = self.k(x).reshape(B, self.num_heads, -1, self.N).permute(0, 1, 2, 3) + v = self.v(x) + v_local = self.v_local(v) + v = v.reshape(B, self.num_heads, -1, self.N).permute(0, 1, 3, 2) + + attn = (q @ k) * self.scale + attn = attn + self.get_attention_biases(x.device) + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(2, 3) + x = x.reshape(B, self.dh, self.resolution2[0], self.resolution2[1]) + v_local + x = self.act(x) + x = self.proj(x) + return x + + +class Downsample(nn.Module): + def __init__( + self, + in_chs, + out_chs, + kernel_size=3, + stride=2, + padding=1, + resolution=7, + use_attn=False, + act_layer=nn.GELU, + norm_layer=nn.BatchNorm2d, + ): + super().__init__() + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + padding = to_2tuple(padding) + norm_layer = norm_layer or nn.Identity() + self.conv = ConvNorm( + in_chs, + out_chs, + kernel_size=kernel_size, + stride=stride, + padding=padding, + norm_layer=norm_layer, + ) + + if use_attn: + self.attn = Attention2dDownsample( + dim=in_chs, + out_dim=out_chs, + resolution=resolution, + act_layer=act_layer, + ) + else: + self.attn = None + + def forward(self, x): + out = self.conv(x) + if self.attn is not None: + return self.attn(x) + out + return out + + +class ConvMlpWithNorm(nn.Module): + """ + Implementation of MLP with 1*1 convolutions. 
+ Input: tensor with shape [B, C, H, W] + """ + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=nn.BatchNorm2d, + drop=0., + mid_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ConvNormAct( + in_features, hidden_features, 1, bias=True, norm_layer=norm_layer, act_layer=act_layer) + if mid_conv: + self.mid = ConvNormAct( + hidden_features, hidden_features, 3, + groups=hidden_features, bias=True, norm_layer=norm_layer, act_layer=act_layer) + else: + self.mid = nn.Identity() + self.drop1 = nn.Dropout(drop) + self.fc2 = ConvNorm(hidden_features, out_features, 1, norm_layer=norm_layer) + self.drop2 = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.mid(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class LayerScale2d(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + gamma = self.gamma.view(1, -1, 1, 1) + return x.mul_(gamma) if self.inplace else x * gamma + + +class EfficientFormerV2Block(nn.Module): + def __init__( + self, + dim, + mlp_ratio=4., + act_layer=nn.GELU, + norm_layer=nn.BatchNorm2d, + drop=0., + drop_path=0., + layer_scale_init_value=1e-5, + resolution=7, + stride=None, + use_attn=True, + ): + super().__init__() + + if use_attn: + self.token_mixer = Attention2d( + dim, + resolution=resolution, + act_layer=act_layer, + stride=stride, + ) + self.ls1 = LayerScale2d( + dim, layer_scale_init_value) if layer_scale_init_value is not None else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + else: + self.token_mixer = None + self.ls1 = None + self.drop_path1 = None + + self.mlp = ConvMlpWithNorm( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + norm_layer=norm_layer, + drop=drop, + mid_conv=True, + ) + self.ls2 = LayerScale2d( + dim, layer_scale_init_value) if layer_scale_init_value is not None else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x): + if self.token_mixer is not None: + x = x + self.drop_path1(self.ls1(self.token_mixer(x))) + x = x + self.drop_path2(self.ls2(self.mlp(x))) + return x + + +class Stem4(nn.Sequential): + def __init__(self, in_chs, out_chs, act_layer=nn.GELU, norm_layer=nn.BatchNorm2d): + super().__init__() + self.stride = 4 + self.conv1 = ConvNormAct( + in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1, bias=True, + norm_layer=norm_layer, act_layer=act_layer + ) + self.conv2 = ConvNormAct( + out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1, bias=True, + norm_layer=norm_layer, act_layer=act_layer + ) + + +class EfficientFormerV2Stage(nn.Module): + + def __init__( + self, + dim, + dim_out, + depth, + resolution=7, + downsample=True, + block_stride=None, + downsample_use_attn=False, + block_use_attn=False, + num_vit=1, + mlp_ratio=4., + drop=.0, + drop_path=0., + layer_scale_init_value=1e-5, + act_layer=nn.GELU, + norm_layer=nn.BatchNorm2d, + + ): + super().__init__() + self.grad_checkpointing = False + mlp_ratio = to_ntuple(depth)(mlp_ratio) + resolution = to_2tuple(resolution) + + if downsample: + self.downsample = Downsample( + dim, + dim_out, + use_attn=downsample_use_attn, + resolution=resolution, + norm_layer=norm_layer, + act_layer=act_layer, + ) + dim = dim_out + resolution = tuple([math.ceil(r / 2) for r in resolution]) + else: + assert dim == dim_out + self.downsample = nn.Identity() + + blocks = [] + for block_idx in range(depth): + remain_idx = depth - num_vit - 1 + b = EfficientFormerV2Block( + dim, + resolution=resolution, + stride=block_stride, + mlp_ratio=mlp_ratio[block_idx], + use_attn=block_use_attn and block_idx > remain_idx, + drop=drop, + drop_path=drop_path[block_idx], + layer_scale_init_value=layer_scale_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + blocks += [b] + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + x = self.downsample(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class EfficientFormerV2(nn.Module): + def __init__( + self, + depths, + in_chans=3, + img_size=224, + global_pool='avg', + embed_dims=None, + downsamples=None, + mlp_ratios=4, + norm_layer='batchnorm2d', + norm_eps=1e-5, + act_layer='gelu', + num_classes=1000, + drop_rate=0., + drop_path_rate=0., + layer_scale_init_value=1e-5, + num_vit=0, + distillation=True, + ): + super().__init__() + assert global_pool in ('avg', '') + self.num_classes = num_classes + self.global_pool = global_pool + self.feature_info = [] + img_size = to_2tuple(img_size) + norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps) + act_layer = get_act_layer(act_layer) + + self.stem = Stem4(in_chans, embed_dims[0], act_layer=act_layer, norm_layer=norm_layer) + prev_dim = embed_dims[0] + stride = 4 + + num_stages = len(depths) + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + downsamples = downsamples or (False,) + (True,) * (len(depths) - 1) + mlp_ratios = to_ntuple(num_stages)(mlp_ratios) + stages = [] + for i in range(num_stages): + curr_resolution = tuple([math.ceil(s / stride) for s in img_size]) + stage = EfficientFormerV2Stage( + prev_dim, + embed_dims[i], + depth=depths[i], + resolution=curr_resolution, + downsample=downsamples[i], + block_stride=2 if i == 2 else None, + downsample_use_attn=i >= 3, + block_use_attn=i >= 2, + num_vit=num_vit, + mlp_ratio=mlp_ratios[i], + drop=drop_rate, + 
drop_path=dpr[i], + layer_scale_init_value=layer_scale_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + if downsamples[i]: + stride *= 2 + prev_dim = embed_dims[i] + self.feature_info += [dict(num_chs=prev_dim, reduction=stride, module=f'stages.{i}')] + stages.append(stage) + self.stages = nn.Sequential(*stages) + + # Classifier head + self.num_features = embed_dims[-1] + self.norm = norm_layer(embed_dims[-1]) + self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() + self.dist = distillation + if self.dist: + self.head_dist = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() + else: + self.head_dist = None + + self.apply(self.init_weights) + self.distilled_training = False + + # init for classification + def init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def no_weight_decay(self): + return {k for k, _ in self.named_parameters() if 'attention_biases' in k} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', # stem and embed + blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head, self.head_dist + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + @torch.jit.ignore + def set_distilled_training(self, enable=True): + self.distilled_training = enable + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool == 'avg': + x = x.mean(dim=(2, 3)) + if pre_logits: + return x + x, x_dist = self.head(x), self.head_dist(x) + if self.distilled_training and self.training and not torch.jit.is_scripting(): + # only return separate classification predictions when training in distilled mode + return x, x_dist + else: + # during standard train/finetune, inference average the classifier predictions + return (x + x_dist) / 2 + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True, + 'crop_pct': .95, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'classifier': ('head', 'head_dist'), 'first_conv': 'stem.conv1.conv', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'efficientformerv2_s0.snap_dist_in1k': _cfg( + hf_hub_id='timm/', + ), + 'efficientformerv2_s1.snap_dist_in1k': _cfg( + hf_hub_id='timm/', + ), + 'efficientformerv2_s2.snap_dist_in1k': _cfg( + hf_hub_id='timm/', + ), + 'efficientformerv2_l.snap_dist_in1k': _cfg( + hf_hub_id='timm/', + ), +}) + + +def _create_efficientformerv2(variant, pretrained=False, **kwargs): + out_indices = kwargs.pop('out_indices', (0, 1, 2, 3)) + model = build_model_with_cfg( + EfficientFormerV2, variant, pretrained, + 
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs) + return model + + +@register_model +def efficientformerv2_s0(pretrained=False, **kwargs): + model_args = dict( + depths=EfficientFormer_depth['S0'], + embed_dims=EfficientFormer_width['S0'], + num_vit=2, + drop_path_rate=0.0, + mlp_ratios=EfficientFormer_expansion_ratios['S0'], + ) + return _create_efficientformerv2('efficientformerv2_s0', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientformerv2_s1(pretrained=False, **kwargs): + model_args = dict( + depths=EfficientFormer_depth['S1'], + embed_dims=EfficientFormer_width['S1'], + num_vit=2, + drop_path_rate=0.0, + mlp_ratios=EfficientFormer_expansion_ratios['S1'], + ) + return _create_efficientformerv2('efficientformerv2_s1', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientformerv2_s2(pretrained=False, **kwargs): + model_args = dict( + depths=EfficientFormer_depth['S2'], + embed_dims=EfficientFormer_width['S2'], + num_vit=4, + drop_path_rate=0.02, + mlp_ratios=EfficientFormer_expansion_ratios['S2'], + ) + return _create_efficientformerv2('efficientformerv2_s2', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientformerv2_l(pretrained=False, **kwargs): + model_args = dict( + depths=EfficientFormer_depth['L'], + embed_dims=EfficientFormer_width['L'], + num_vit=6, + drop_path_rate=0.1, + mlp_ratios=EfficientFormer_expansion_ratios['L'], + ) + return _create_efficientformerv2('efficientformerv2_l', pretrained=pretrained, **dict(model_args, **kwargs)) + diff --git a/timm/models/levit.py b/timm/models/levit.py index 8dc11309..ea731c09 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -23,6 +23,8 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W # Modified from # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # Copyright 2020 Ross Wightman, Apache-2.0 License +from collections import OrderedDict +from dataclasses import dataclass from functools import partial from typing import Dict @@ -30,135 +32,53 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN -from timm.layers import to_ntuple, get_act_layer, trunc_normal_ +from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_ from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq +from ._pretrained import generate_default_cfgs from ._registry import register_model -__all__ = ['LevitDistilled'] # model_registry will add each entrypoint fn to this - - -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.0.c', 'classifier': ('head.l', 'head_dist.l'), - **kwargs - } - - -default_cfgs = dict( - levit_128s=_cfg( - url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth' - ), - levit_128=_cfg( - url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth' - ), - levit_192=_cfg( - url='https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth' - ), - levit_256=_cfg( - url='https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth' - ), - levit_384=_cfg( - url='https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth' - ), - - levit_256d=_cfg(url='', 
classifier='head.l'), -) - -model_cfgs = dict( - levit_128s=dict( - embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)), - levit_128=dict( - embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)), - levit_192=dict( - embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)), - levit_256=dict( - embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)), - levit_384=dict( - embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)), - - levit_256d=dict( - embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 8, 6)), -) __all__ = ['Levit'] -@register_model -def levit_128s(pretrained=False, use_conv=False, **kwargs): - return create_levit( - 'levit_128s', pretrained=pretrained, use_conv=use_conv, **kwargs) - - -@register_model -def levit_128(pretrained=False, use_conv=False, **kwargs): - return create_levit( - 'levit_128', pretrained=pretrained, use_conv=use_conv, **kwargs) - - -@register_model -def levit_192(pretrained=False, use_conv=False, **kwargs): - return create_levit( - 'levit_192', pretrained=pretrained, use_conv=use_conv, **kwargs) - - -@register_model -def levit_256(pretrained=False, use_conv=False, **kwargs): - return create_levit( - 'levit_256', pretrained=pretrained, use_conv=use_conv, **kwargs) - - -@register_model -def levit_384(pretrained=False, use_conv=False, **kwargs): - return create_levit( - 'levit_384', pretrained=pretrained, use_conv=use_conv, **kwargs) - - -@register_model -def levit_256d(pretrained=False, use_conv=False, **kwargs): - return create_levit( - 'levit_256d', pretrained=pretrained, use_conv=use_conv, distilled=False, **kwargs) - - -class ConvNorm(nn.Sequential): +class ConvNorm(nn.Module): def __init__( - self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1, - groups=1, bn_weight_init=1, resolution=-10000): + self, in_chs, out_chs, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bn_weight_init=1): super().__init__() - self.add_module('c', nn.Conv2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False)) - self.add_module('bn', nn.BatchNorm2d(out_chs)) + self.linear = nn.Conv2d(in_chs, out_chs, kernel_size, stride, padding, dilation, groups, bias=False) + self.bn = nn.BatchNorm2d(out_chs) nn.init.constant_(self.bn.weight, bn_weight_init) @torch.no_grad() def fuse(self): - c, bn = self._modules.values() + c, bn = self.linear, self.bn w = bn.weight / (bn.running_var + bn.eps) ** 0.5 w = c.weight * w[:, None, None, None] b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 m = nn.Conv2d( - w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, - padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + w.size(1), w.size(0), w.shape[2:], stride=self.linear.stride, + padding=self.linear.padding, dilation=self.linear.dilation, groups=self.linear.groups) m.weight.data.copy_(w) m.bias.data.copy_(b) return m + def forward(self, x): + return self.bn(self.linear(x)) + -class LinearNorm(nn.Sequential): - def __init__(self, in_features, out_features, bn_weight_init=1, resolution=-100000): +class LinearNorm(nn.Module): + def __init__(self, in_features, out_features, bn_weight_init=1): super().__init__() - self.add_module('c', nn.Linear(in_features, out_features, bias=False)) - self.add_module('bn', nn.BatchNorm1d(out_features)) + self.linear = nn.Linear(in_features, out_features, bias=False) + self.bn = nn.BatchNorm1d(out_features) nn.init.constant_(self.bn.weight, 
bn_weight_init) @torch.no_grad() def fuse(self): - l, bn = self._modules.values() + l, bn = self.linear, self.bn w = bn.weight / (bn.running_var + bn.eps) ** 0.5 w = l.weight * w[:, None] b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 @@ -168,80 +88,100 @@ class LinearNorm(nn.Sequential): return m def forward(self, x): - x = self.c(x) + x = self.linear(x) return self.bn(x.flatten(0, 1)).reshape_as(x) -class NormLinear(nn.Sequential): - def __init__(self, in_features, out_features, bias=True, std=0.02): +class NormLinear(nn.Module): + def __init__(self, in_features, out_features, bias=True, std=0.02, drop=0.): super().__init__() - self.add_module('bn', nn.BatchNorm1d(in_features)) - self.add_module('l', nn.Linear(in_features, out_features, bias=bias)) + self.bn = nn.BatchNorm1d(in_features) + self.drop = nn.Dropout(drop) + self.linear = nn.Linear(in_features, out_features, bias=bias) - trunc_normal_(self.l.weight, std=std) - if self.l.bias is not None: - nn.init.constant_(self.l.bias, 0) + trunc_normal_(self.linear.weight, std=std) + if self.linear.bias is not None: + nn.init.constant_(self.linear.bias, 0) @torch.no_grad() def fuse(self): - bn, l = self._modules.values() + bn, l = self.bn, self.linear w = bn.weight / (bn.running_var + bn.eps) ** 0.5 b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5 w = l.weight * w[None, :] if l.bias is None: - b = b @ self.l.weight.T + b = b @ self.linear.weight.T else: - b = (l.weight @ b[:, None]).view(-1) + self.l.bias + b = (l.weight @ b[:, None]).view(-1) + self.linear.bias m = nn.Linear(w.size(1), w.size(0)) m.weight.data.copy_(w) m.bias.data.copy_(b) return m + def forward(self, x): + return self.linear(self.drop(self.bn(x))) -def stem_b16(in_chs, out_chs, activation, resolution=224): - return nn.Sequential( - ConvNorm(in_chs, out_chs // 8, 3, 2, 1, resolution=resolution), - activation(), - ConvNorm(out_chs // 8, out_chs // 4, 3, 2, 1, resolution=resolution // 2), - activation(), - ConvNorm(out_chs // 4, out_chs // 2, 3, 2, 1, resolution=resolution // 4), - activation(), - ConvNorm(out_chs // 2, out_chs, 3, 2, 1, resolution=resolution // 8)) +class Stem8(nn.Sequential): + def __init__(self, in_chs, out_chs, act_layer): + super().__init__() + self.stride = 8 -class Residual(nn.Module): - def __init__(self, m, drop): + self.add_module('conv1', ConvNorm(in_chs, out_chs // 4, 3, stride=2, padding=1)) + self.add_module('act1', act_layer()) + self.add_module('conv2', ConvNorm(out_chs // 4, out_chs // 2, 3, stride=2, padding=1)) + self.add_module('act2', act_layer()) + self.add_module('conv3', ConvNorm(out_chs // 2, out_chs, 3, stride=2, padding=1)) + + +class Stem16(nn.Sequential): + def __init__(self, in_chs, out_chs, act_layer): super().__init__() - self.m = m - self.drop = drop + self.stride = 16 - def forward(self, x): - if self.training and self.drop > 0: - return x + self.m(x) * torch.rand( - x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach() - else: - return x + self.m(x) + self.add_module('conv1', ConvNorm(in_chs, out_chs // 8, 3, stride=2, padding=1)) + self.add_module('act1', act_layer()) + self.add_module('conv2', ConvNorm(out_chs // 8, out_chs // 4, 3, stride=2, padding=1)) + self.add_module('act2', act_layer()) + self.add_module('conv3', ConvNorm(out_chs // 4, out_chs // 2, 3, stride=2, padding=1)) + self.add_module('act3', act_layer()) + self.add_module('conv4', ConvNorm(out_chs // 2, out_chs, 3, stride=2, padding=1)) -class Subsample(nn.Module): - 
def __init__(self, stride, resolution): +class Downsample(nn.Module): + def __init__(self, stride, resolution, use_pool=False): super().__init__() self.stride = stride - self.resolution = resolution + self.resolution = to_2tuple(resolution) + self.pool = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) if use_pool else None def forward(self, x): B, N, C = x.shape - x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride] + x = x.view(B, self.resolution[0], self.resolution[1], C) + if self.pool is not None: + x = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + else: + x = x[:, ::self.stride, ::self.stride] return x.reshape(B, -1, C) class Attention(nn.Module): - ab: Dict[str, torch.Tensor] + attention_bias_cache: Dict[str, torch.Tensor] def __init__( - self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False): + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4., + resolution=14, + use_conv=False, + act_layer=nn.SiLU, + ): super().__init__() ln_layer = ConvNorm if use_conv else LinearNorm + resolution = to_2tuple(resolution) + self.use_conv = use_conv self.num_heads = num_heads self.scale = key_dim ** -0.5 @@ -250,33 +190,33 @@ class Attention(nn.Module): self.val_dim = int(attn_ratio * key_dim) self.val_attn_dim = int(attn_ratio * key_dim) * num_heads - self.qkv = ln_layer(dim, self.val_attn_dim + self.key_attn_dim * 2, resolution=resolution) - self.proj = nn.Sequential( - act_layer(), - ln_layer(self.val_attn_dim, dim, bn_weight_init=0, resolution=resolution) - ) + self.qkv = ln_layer(dim, self.val_attn_dim + self.key_attn_dim * 2) + self.proj = nn.Sequential(OrderedDict([ + ('act', act_layer()), + ('ln', ln_layer(self.val_attn_dim, dim, bn_weight_init=0)) + ])) - self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution ** 2)) - pos = torch.stack(torch.meshgrid(torch.arange(resolution), torch.arange(resolution))).flatten(1) + self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1])) + pos = torch.stack(torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1) rel_pos = (pos[..., :, None] - pos[..., None, :]).abs() - rel_pos = (rel_pos[0] * resolution) + rel_pos[1] - self.register_buffer('attention_bias_idxs', rel_pos) - self.ab = {} + rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1] + self.register_buffer('attention_bias_idxs', rel_pos, persistent=False) + self.attention_bias_cache = {} @torch.no_grad() def train(self, mode=True): super().train(mode) - if mode and self.ab: - self.ab = {} # clear ab cache + if mode and self.attention_bias_cache: + self.attention_bias_cache = {} # clear ab cache def get_attention_biases(self, device: torch.device) -> torch.Tensor: - if self.training: + if torch.jit.is_tracing() or self.training: return self.attention_biases[:, self.attention_bias_idxs] else: device_key = str(device) - if device_key not in self.ab: - self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs] - return self.ab[device_key] + if device_key not in self.attention_bias_cache: + self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.attention_bias_cache[device_key] def forward(self, x): # x (B,C,H,W) if self.use_conv: @@ -304,83 +244,97 @@ class Attention(nn.Module): return x -class AttentionSubsample(nn.Module): - ab: Dict[str, torch.Tensor] +class AttentionDownsample(nn.Module): + attention_bias_cache: Dict[str, torch.Tensor] def __init__( - self, 
in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2, - act_layer=None, stride=2, resolution=14, resolution_out=7, use_conv=False): + self, + in_dim, + out_dim, + key_dim, + num_heads=8, + attn_ratio=2.0, + stride=2, + resolution=14, + use_conv=False, + use_pool=False, + act_layer=nn.SiLU, + ): super().__init__() + resolution = to_2tuple(resolution) + self.stride = stride + self.resolution = resolution self.num_heads = num_heads - self.scale = key_dim ** -0.5 self.key_dim = key_dim self.key_attn_dim = key_dim * num_heads self.val_dim = int(attn_ratio * key_dim) self.val_attn_dim = self.val_dim * self.num_heads - self.resolution = resolution - self.resolution_out_area = resolution_out ** 2 - + self.scale = key_dim ** -0.5 self.use_conv = use_conv + if self.use_conv: ln_layer = ConvNorm - sub_layer = partial(nn.AvgPool2d, kernel_size=1, padding=0) + sub_layer = partial( + nn.AvgPool2d, + kernel_size=3 if use_pool else 1, padding=1 if use_pool else 0, count_include_pad=False) else: ln_layer = LinearNorm - sub_layer = partial(Subsample, resolution=resolution) - - self.kv = ln_layer(in_dim, self.val_attn_dim + self.key_attn_dim, resolution=resolution) - self.q = nn.Sequential( - sub_layer(stride=stride), - ln_layer(in_dim, self.key_attn_dim, resolution=resolution_out) - ) - self.proj = nn.Sequential( - act_layer(), - ln_layer(self.val_attn_dim, out_dim, resolution=resolution_out) - ) - - self.attention_biases = nn.Parameter(torch.zeros(num_heads, self.resolution ** 2)) - k_pos = torch.stack(torch.meshgrid(torch.arange(resolution), torch.arange(resolution))).flatten(1) + sub_layer = partial(Downsample, resolution=resolution, use_pool=use_pool) + + self.kv = ln_layer(in_dim, self.val_attn_dim + self.key_attn_dim) + self.q = nn.Sequential(OrderedDict([ + ('down', sub_layer(stride=stride)), + ('ln', ln_layer(in_dim, self.key_attn_dim)) + ])) + self.proj = nn.Sequential(OrderedDict([ + ('act', act_layer()), + ('ln', ln_layer(self.val_attn_dim, out_dim)) + ])) + + self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1])) + k_pos = torch.stack(torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1) q_pos = torch.stack(torch.meshgrid( - torch.arange(0, resolution, step=stride), - torch.arange(0, resolution, step=stride))).flatten(1) + torch.arange(0, resolution[0], step=stride), + torch.arange(0, resolution[1], step=stride))).flatten(1) rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs() - rel_pos = (rel_pos[0] * resolution) + rel_pos[1] - self.register_buffer('attention_bias_idxs', rel_pos) + rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1] + self.register_buffer('attention_bias_idxs', rel_pos, persistent=False) - self.ab = {} # per-device attention_biases cache + self.attention_bias_cache = {} # per-device attention_biases cache @torch.no_grad() def train(self, mode=True): super().train(mode) - if mode and self.ab: - self.ab = {} # clear ab cache + if mode and self.attention_bias_cache: + self.attention_bias_cache = {} # clear ab cache def get_attention_biases(self, device: torch.device) -> torch.Tensor: - if self.training: + if torch.jit.is_tracing() or self.training: return self.attention_biases[:, self.attention_bias_idxs] else: device_key = str(device) - if device_key not in self.ab: - self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs] - return self.ab[device_key] + if device_key not in self.attention_bias_cache: + self.attention_bias_cache[device_key] = self.attention_biases[:, 
self.attention_bias_idxs] + return self.attention_bias_cache[device_key] def forward(self, x): if self.use_conv: B, C, H, W = x.shape + HH, WW = (H - 1) // self.stride + 1, (W - 1) // self.stride + 1 k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.val_dim], dim=2) - q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_out_area) + q = self.q(x).view(B, self.num_heads, self.key_dim, -1) attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) - x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution, self.resolution) + x = (v @ attn.transpose(-2, -1)).reshape(B, self.val_attn_dim, HH, WW) else: B, N, C = x.shape k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.val_dim], dim=3) k = k.permute(0, 2, 3, 1) # BHCN v = v.permute(0, 2, 1, 3) # BHNC - q = self.q(x).view(B, self.resolution_out_area, self.num_heads, self.key_dim).permute(0, 2, 1, 3) + q = self.q(x).view(B, -1, self.num_heads, self.key_dim).permute(0, 2, 1, 3) attn = q @ k * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) @@ -390,6 +344,184 @@ class AttentionSubsample(nn.Module): return x +class LevitMlp(nn.Module): + """ MLP for Levit w/ normalization + ability to switch btw conv and linear + """ + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + use_conv=False, + act_layer=nn.SiLU, + drop=0. + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + ln_layer = ConvNorm if use_conv else LinearNorm + + self.ln1 = ln_layer(in_features, hidden_features) + self.act = act_layer() + self.drop = nn.Dropout(drop) + self.ln2 = ln_layer(hidden_features, out_features, bn_weight_init=0) + + def forward(self, x): + x = self.ln1(x) + x = self.act(x) + x = self.drop(x) + x = self.ln2(x) + return x + + +class LevitDownsample(nn.Module): + def __init__( + self, + in_dim, + out_dim, + key_dim, + num_heads=8, + attn_ratio=4., + mlp_ratio=2., + act_layer=nn.SiLU, + attn_act_layer=None, + resolution=14, + use_conv=False, + use_pool=False, + drop_path=0., + ): + super().__init__() + attn_act_layer = attn_act_layer or act_layer + + self.attn_downsample = AttentionDownsample( + in_dim=in_dim, + out_dim=out_dim, + key_dim=key_dim, + num_heads=num_heads, + attn_ratio=attn_ratio, + act_layer=attn_act_layer, + resolution=resolution, + use_conv=use_conv, + use_pool=use_pool, + ) + + self.mlp = LevitMlp( + out_dim, + int(out_dim * mlp_ratio), + use_conv=use_conv, + act_layer=act_layer + ) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + x = self.attn_downsample(x) + x = x + self.drop_path(self.mlp(x)) + return x + + +class LevitBlock(nn.Module): + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4., + mlp_ratio=2., + resolution=14, + use_conv=False, + act_layer=nn.SiLU, + attn_act_layer=None, + drop_path=0., + ): + super().__init__() + attn_act_layer = attn_act_layer or act_layer + + self.attn = Attention( + dim=dim, + key_dim=key_dim, + num_heads=num_heads, + attn_ratio=attn_ratio, + resolution=resolution, + use_conv=use_conv, + act_layer=attn_act_layer, + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.mlp = LevitMlp( + dim, + int(dim * mlp_ratio), + use_conv=use_conv, + act_layer=act_layer + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x): + x = x + self.drop_path1(self.attn(x)) + x = x + self.drop_path2(self.mlp(x)) + return x + + +class LevitStage(nn.Module): + def __init__( + self, + in_dim, + out_dim, + key_dim, + depth=4, + num_heads=8, + attn_ratio=4.0, + mlp_ratio=4.0, + act_layer=nn.SiLU, + attn_act_layer=None, + resolution=14, + downsample='', + use_conv=False, + drop_path=0., + ): + super().__init__() + resolution = to_2tuple(resolution) + + if downsample: + self.downsample = LevitDownsample( + in_dim, + out_dim, + key_dim=key_dim, + num_heads=in_dim // key_dim, + attn_ratio=4., + mlp_ratio=2., + act_layer=act_layer, + attn_act_layer=attn_act_layer, + resolution=resolution, + use_conv=use_conv, + drop_path=drop_path, + ) + resolution = [(r - 1) // 2 + 1 for r in resolution] + else: + assert in_dim == out_dim + self.downsample = nn.Identity() + + blocks = [] + for _ in range(depth): + blocks += [LevitBlock( + out_dim, + key_dim, + num_heads=num_heads, + attn_ratio=attn_ratio, + mlp_ratio=mlp_ratio, + act_layer=act_layer, + attn_act_layer=attn_act_layer, + resolution=resolution, + use_conv=use_conv, + drop_path=drop_path, + )] + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + x = self.downsample(x) + x = self.blocks(x) + return x + + class Levit(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage @@ -400,93 +532,82 @@ class Levit(nn.Module): def __init__( self, img_size=224, - patch_size=16, in_chans=3, num_classes=1000, embed_dim=(192,), key_dim=64, depth=(12,), num_heads=(3,), - attn_ratio=2, - mlp_ratio=2, - hybrid_backbone=None, - down_ops=None, + attn_ratio=2., + mlp_ratio=2., + stem_backbone=None, + stem_stride=None, + stem_type='s16', + down_op='subsample', act_layer='hard_swish', - attn_act_layer='hard_swish', + attn_act_layer=None, use_conv=False, global_pool='avg', drop_rate=0., drop_path_rate=0.): super().__init__() act_layer = get_act_layer(act_layer) - attn_act_layer = get_act_layer(attn_act_layer) - ln_layer = ConvNorm if use_conv else LinearNorm + attn_act_layer = get_act_layer(attn_act_layer or act_layer) self.use_conv = use_conv - if isinstance(img_size, tuple): - # FIXME origin impl passes single img/res dim through whole hierarchy, - # not sure this model will be used enough to spend time fixing it. 
- assert img_size[0] == img_size[1] - img_size = img_size[0] self.num_classes = num_classes self.global_pool = global_pool self.num_features = embed_dim[-1] self.embed_dim = embed_dim + self.drop_rate = drop_rate self.grad_checkpointing = False + self.feature_info = [] num_stages = len(embed_dim) - assert len(depth) == len(num_heads) == num_stages - key_dim = to_ntuple(num_stages)(key_dim) + assert len(depth) == num_stages + num_heads = to_ntuple(num_stages)(num_heads) attn_ratio = to_ntuple(num_stages)(attn_ratio) mlp_ratio = to_ntuple(num_stages)(mlp_ratio) - down_ops = down_ops or ( - # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) - ('Subsample', key_dim[0], embed_dim[0] // key_dim[0], 4, 2, 2), - ('Subsample', key_dim[0], embed_dim[1] // key_dim[1], 4, 2, 2), - ('',) - ) - self.patch_embed = hybrid_backbone or stem_b16(in_chans, embed_dim[0], activation=act_layer) - - self.blocks = [] - resolution = img_size // patch_size - for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( - zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)): - for _ in range(dpth): - self.blocks.append( - Residual( - Attention( - ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, - resolution=resolution, use_conv=use_conv), - drop_path_rate)) - if mr > 0: - h = int(ed * mr) - self.blocks.append( - Residual(nn.Sequential( - ln_layer(ed, h, resolution=resolution), - act_layer(), - ln_layer(h, ed, bn_weight_init=0, resolution=resolution), - ), drop_path_rate)) - if do[0] == 'Subsample': - # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) - resolution_out = (resolution - 1) // do[5] + 1 - self.blocks.append( - AttentionSubsample( - *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], - attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5], - resolution=resolution, resolution_out=resolution_out, use_conv=use_conv)) - resolution = resolution_out - if do[4] > 0: # mlp_ratio - h = int(embed_dim[i + 1] * do[4]) - self.blocks.append( - Residual(nn.Sequential( - ln_layer(embed_dim[i + 1], h, resolution=resolution), - act_layer(), - ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), - ), drop_path_rate)) - self.blocks = nn.Sequential(*self.blocks) + if stem_backbone is not None: + assert stem_stride >= 2 + self.stem = stem_backbone + stride = stem_stride + else: + assert stem_type in ('s16', 's8') + if stem_type == 's16': + self.stem = Stem16(in_chans, embed_dim[0], act_layer=act_layer) + else: + self.stem = Stem8(in_chans, embed_dim[0], act_layer=act_layer) + stride = self.stem.stride + resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))]) + + in_dim = embed_dim[0] + stages = [] + for i in range(num_stages): + stage_stride = 2 if i > 0 else 1 + stages += [LevitStage( + in_dim, + embed_dim[i], + key_dim, + depth=depth[i], + num_heads=num_heads[i], + attn_ratio=attn_ratio[i], + mlp_ratio=mlp_ratio[i], + act_layer=act_layer, + attn_act_layer=attn_act_layer, + resolution=resolution, + use_conv=use_conv, + downsample=down_op if stage_stride == 2 else '', + drop_path=drop_path_rate + )] + stride *= stage_stride + resolution = tuple([(r - 1) // stage_stride + 1 for r in resolution]) + self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')] + in_dim = embed_dim[i] + self.stages = nn.Sequential(*stages) # Classifier head - self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() + self.head = NormLinear(embed_dim[-1], num_classes, drop=drop_rate) if 
num_classes > 0 else nn.Identity() @torch.jit.ignore def no_weight_decay(self): @@ -512,16 +633,17 @@ class Levit(nn.Module): self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool - self.head = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() + self.head = NormLinear( + self.embed_dim[-1], num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity() def forward_features(self, x): - x = self.patch_embed(x) + x = self.stem(x) if not self.use_conv: x = x.flatten(2).transpose(1, 2) if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.blocks, x) + x = checkpoint_seq(self.stages, x) else: - x = self.blocks(x) + x = self.stages(x) return x def forward_head(self, x, pre_logits: bool = False): @@ -549,16 +671,19 @@ class LevitDistilled(Levit): self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool - self.head = NormLinear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head = NormLinear( + self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity() self.head_dist = NormLinear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() @torch.jit.ignore def set_distilled_training(self, enable=True): self.distilled_training = enable - def forward_head(self, x): + def forward_head(self, x, pre_logits: bool = False): if self.global_pool == 'avg': x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1) + if pre_logits: + return x x, x_dist = self.head(x), self.head_dist(x) if self.distilled_training and self.training and not torch.jit.is_scripting(): # only return separate classification predictions when training in distilled mode @@ -570,23 +695,241 @@ class LevitDistilled(Levit): def checkpoint_filter_fn(state_dict, model): if 'model' in state_dict: - # For deit models state_dict = state_dict['model'] + + # filter out attn biases, should not have been persistent + state_dict = {k: v for k, v in state_dict.items() if 'attention_bias_idxs' not in k} + D = model.state_dict() - for k in state_dict.keys(): - if k in D and D[k].ndim == 4 and state_dict[k].ndim == 2: - state_dict[k] = state_dict[k][:, :, None, None] - return state_dict + out_dict = {} + for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()): + if va.ndim == 4 and vb.ndim == 2: + vb = vb[:, :, None, None] + if va.shape != vb.shape: + # head or first-conv shapes may change for fine-tune + assert 'head' in ka or 'stem.conv1.linear' in ka + out_dict[ka] = vb + return out_dict -def create_levit(variant, pretrained=False, distilled=True, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') - model_cfg = dict(**model_cfgs[variant], **kwargs) +model_cfgs = dict( + levit_128s=dict( + embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)), + levit_128=dict( + embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)), + levit_192=dict( + embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)), + levit_256=dict( + embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)), + levit_384=dict( + embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)), + + # stride-8 stem experiments + levit_384_s8=dict( + embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4), + act_layer='silu', stem_type='s8'), + 
levit_512_s8=dict( + embed_dim=(512, 640, 896), key_dim=64, num_heads=(8, 10, 14), depth=(4, 4, 4), + act_layer='silu', stem_type='s8'), + + # wider experiments + levit_512=dict( + embed_dim=(512, 768, 1024), key_dim=64, num_heads=(8, 12, 16), depth=(4, 4, 4), act_layer='silu'), + + # deeper experiments + levit_256d=dict( + embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 8, 6), act_layer='silu'), + levit_512d=dict( + embed_dim=(512, 640, 768), key_dim=64, num_heads=(8, 10, 12), depth=(4, 8, 6), act_layer='silu'), +) + + +def create_levit(variant, cfg_variant=None, pretrained=False, distilled=True, **kwargs): + is_conv = '_conv' in variant + out_indices = kwargs.pop('out_indices', (0, 1, 2)) + if kwargs.get('features_only', None): + if not is_conv: + raise RuntimeError('features_only not implemented for LeVit in non-convolutional mode.') + if cfg_variant is None: + if variant in model_cfgs: + cfg_variant = variant + elif is_conv: + cfg_variant = variant.replace('_conv', '') + + model_cfg = dict(model_cfgs[cfg_variant], **kwargs) model = build_model_with_cfg( - LevitDistilled if distilled else Levit, variant, pretrained, + LevitDistilled if distilled else Levit, + variant, + pretrained, pretrained_filter_fn=checkpoint_filter_fn, - **model_cfg) + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **model_cfg, + ) return model + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv1.linear', 'classifier': ('head.linear', 'head_dist.linear'), + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + # weights in nn.Linear mode + 'levit_128s.fb_dist_in1k': _cfg( + hf_hub_id='timm/', + ), + 'levit_128.fb_dist_in1k': _cfg( + hf_hub_id='timm/', + ), + 'levit_192.fb_dist_in1k': _cfg( + hf_hub_id='timm/', + ), + 'levit_256.fb_dist_in1k': _cfg( + hf_hub_id='timm/', + ), + 'levit_384.fb_dist_in1k': _cfg( + hf_hub_id='timm/', + ), + + # weights in nn.Conv2d mode + 'levit_conv_128s.fb_dist_in1k': _cfg( + hf_hub_id='timm/', + pool_size=(4, 4), + ), + 'levit_conv_128.fb_dist_in1k': _cfg( + hf_hub_id='timm/', + pool_size=(4, 4), + ), + 'levit_conv_192.fb_dist_in1k': _cfg( + hf_hub_id='timm/', + pool_size=(4, 4), + ), + 'levit_conv_256.fb_dist_in1k': _cfg( + hf_hub_id='timm/', + pool_size=(4, 4), + ), + 'levit_conv_384.fb_dist_in1k': _cfg( + hf_hub_id='timm/', + pool_size=(4, 4), + ), + + 'levit_384_s8.untrained': _cfg(classifier='head.linear'), + 'levit_512_s8.untrained': _cfg(classifier='head.linear'), + 'levit_512.untrained': _cfg(classifier='head.linear'), + 'levit_256d.untrained': _cfg(classifier='head.linear'), + 'levit_512d.untrained': _cfg(classifier='head.linear'), + + 'levit_conv_384_s8.untrained': _cfg(classifier='head.linear'), + 'levit_conv_512_s8.untrained': _cfg(classifier='head.linear'), + 'levit_conv_512.untrained': _cfg(classifier='head.linear'), + 'levit_conv_256d.untrained': _cfg(classifier='head.linear'), + 'levit_conv_512d.untrained': _cfg(classifier='head.linear'), +}) + + +@register_model +def levit_128s(pretrained=False, **kwargs): + return create_levit('levit_128s', pretrained=pretrained, **kwargs) + + +@register_model +def levit_128(pretrained=False, **kwargs): + return create_levit('levit_128', pretrained=pretrained, **kwargs) + + +@register_model +def levit_192(pretrained=False, **kwargs): + return 
create_levit('levit_192', pretrained=pretrained, **kwargs) + + +@register_model +def levit_256(pretrained=False, **kwargs): + return create_levit('levit_256', pretrained=pretrained, **kwargs) + + +@register_model +def levit_384(pretrained=False, **kwargs): + return create_levit('levit_384', pretrained=pretrained, **kwargs) + + +@register_model +def levit_384_s8(pretrained=False, **kwargs): + return create_levit('levit_384_s8', pretrained=pretrained, **kwargs) + + +@register_model +def levit_512_s8(pretrained=False, **kwargs): + return create_levit('levit_512_s8', pretrained=pretrained, distilled=False, **kwargs) + + +@register_model +def levit_512(pretrained=False, **kwargs): + return create_levit('levit_512', pretrained=pretrained, distilled=False, **kwargs) + + +@register_model +def levit_256d(pretrained=False, **kwargs): + return create_levit('levit_256d', pretrained=pretrained, distilled=False, **kwargs) + + +@register_model +def levit_512d(pretrained=False, **kwargs): + return create_levit('levit_512d', pretrained=pretrained, distilled=False, **kwargs) + + +@register_model +def levit_conv_128s(pretrained=False, **kwargs): + return create_levit('levit_conv_128s', pretrained=pretrained, use_conv=True, **kwargs) + + +@register_model +def levit_conv_128(pretrained=False, **kwargs): + return create_levit('levit_conv_128', pretrained=pretrained, use_conv=True, **kwargs) + + +@register_model +def levit_conv_192(pretrained=False, **kwargs): + return create_levit('levit_conv_192', pretrained=pretrained, use_conv=True, **kwargs) + + +@register_model +def levit_conv_256(pretrained=False, **kwargs): + return create_levit('levit_conv_256', pretrained=pretrained, use_conv=True, **kwargs) + + +@register_model +def levit_conv_384(pretrained=False, **kwargs): + return create_levit('levit_conv_384', pretrained=pretrained, use_conv=True, **kwargs) + + +@register_model +def levit_conv_384_s8(pretrained=False, **kwargs): + return create_levit('levit_conv_384_s8', pretrained=pretrained, use_conv=True, **kwargs) + + +@register_model +def levit_conv_512_s8(pretrained=False, **kwargs): + return create_levit('levit_conv_512_s8', pretrained=pretrained, use_conv=True, distilled=False, **kwargs) + + +@register_model +def levit_conv_512(pretrained=False, **kwargs): + return create_levit('levit_conv_512', pretrained=pretrained, use_conv=True, distilled=False, **kwargs) + + +@register_model +def levit_conv_256d(pretrained=False, **kwargs): + return create_levit('levit_conv_256d', pretrained=pretrained, use_conv=True, distilled=False, **kwargs) + + +@register_model +def levit_conv_512d(pretrained=False, **kwargs): + return create_levit('levit_conv_512d', pretrained=pretrained, use_conv=True, distilled=False, **kwargs) +
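
A minimal smoke test of the new entrypoints, assuming a timm checkout that includes this patch; the model name ('efficientformerv2_s0'), input shape, and expected output shape are taken from the registrations and _cfg defaults added above, and no pretrained weights are downloaded:

import torch
import timm

# Build one of the newly registered EfficientFormer-V2 variants without weights,
# exercising only the architecture wiring added in this patch.
model = timm.create_model('efficientformerv2_s0', pretrained=False)
model.eval()

x = torch.randn(1, 3, 224, 224)  # matches the (3, 224, 224) input_size in _cfg
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 1000]) -- head and head_dist logits are averaged at eval time

The same pattern should apply to the levit_* and levit_conv_* entrypoints; given the feature_cfg and feature_info additions, features_only=True extraction is also expected to work for EfficientFormer-V2 and the conv-mode LeViT variants, though that is not exercised in this sketch.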