diff --git a/README.md b/README.md index adb43971..019fdae2 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,20 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before ## What's New +### Aug 15, 2022 +* ConvNeXt atto weights added + * `convnext_atto` - 75.7 @ 224, 77.0 @ 288 + * `convnext_atto_ols` - 75.9 @ 224, 77.2 @ 288 + +### Aug 5, 2022 +* More custom ConvNeXt smaller model defs with weights + * `convnext_femto` - 77.5 @ 224, 78.7 @ 288 + * `convnext_femto_ols` - 77.9 @ 224, 78.9 @ 288 + * `convnext_pico` - 79.5 @ 224, 80.4 @ 288 + * `convnext_pico_ols` - 79.5 @ 224, 80.5 @ 288 + * `convnext_nano_ols` - 80.9 @ 224, 81.6 @ 288 +* Updated EdgeNeXt to improve ONNX export, add new base variant and weights from original (https://github.com/mmaaz60/EdgeNeXt) + ### July 28, 2022 * Add freshly minted DeiT-III Medium (width=512, depth=12, num_heads=8) model weights. Thanks [Hugo Touvron](https://github.com/TouvronHugo)! diff --git a/tests/test_models.py b/tests/test_models.py index 94744483..175137e2 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -27,7 +27,9 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*', - 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*'] + 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*', + 'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', +] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 195e451b..51a38d0c 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -13,7 +13,9 @@ from .densenet import * from .dla import * from .dpn import * from .edgenext import * +from .efficientformer import * from .efficientnet import * +from .gcvit import * from .ghostnet import * from .gluon_resnet import * from .gluon_xception import * @@ -23,15 +25,18 @@ from .inception_resnet_v2 import * from .inception_v3 import * from .inception_v4 import * from .levit import * +from .maxxvit import * from .mlp_mixer import * from .mobilenetv3 import * from .mobilevit import * +from .mvitv2 import * from .nasnet import * from .nest import * from .nfnet import * from .pit import * from .pnasnet import * from .poolformer import * +from .pvt_v2 import * from .regnet import * from .res2net import * from .resnest import * diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 4b22c929..15000b40 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -4,6 +4,8 @@ Paper: `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf Original code and weights from https://github.com/facebookresearch/ConvNeXt, original copyright below +Model defs atto, femto, pico, nano and _ols / _hnf variants are timm specific. + Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman """ # Copyright (c) Meta Platforms, Inc. and affiliates. 
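The README entries and the timm-specific ConvNeXt defs noted above can be exercised directly through `timm.create_model`. A minimal usage sketch (not part of this diff; assumes a timm build that includes this commit):

```python
import torch
import timm

# Any of the new variants works here: convnext_atto, convnext_atto_ols, convnext_femto, convnext_pico, ...
model = timm.create_model('convnext_atto', pretrained=False)  # pretrained=True pulls the released weights
model.eval()

# The default_cfgs below advertise a 288x288 test_input_size; ConvNeXt is not fixed-size,
# so 224 and 288 inputs both work.
with torch.no_grad():
    out = model(torch.randn(1, 3, 288, 288))
print(out.shape)  # torch.Size([1, 1000])
```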
@@ -14,11 +16,11 @@ from functools import partial import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import named_apply, build_model_with_cfg, checkpoint_seq -from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, create_conv2d +from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \ + create_conv2d, get_act_layer, make_divisible, to_ntuple from .registry import register_model @@ -37,31 +39,63 @@ def _cfg(url='', **kwargs): default_cfgs = dict( - convnext_tiny=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth"), - convnext_small=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth"), - convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"), - convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"), - # timm specific variants + convnext_atto=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + convnext_atto_ols=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + convnext_femto=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + convnext_femto_ols=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + convnext_pico=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + convnext_pico_ols=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), convnext_nano=_cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_nano_hnf=_cfg(url=''), - convnext_nano_ols=_cfg(url=''), + convnext_nano_ols=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), convnext_tiny_hnf=_cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + convnext_tiny=_cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", + test_input_size=(3, 288, 288), test_crop_pct=1.0), + convnext_small=_cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", + test_input_size=(3, 288, 288), test_crop_pct=1.0), + convnext_base=_cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", + test_input_size=(3, 288, 288), test_crop_pct=1.0), + convnext_large=_cfg( + 
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", + test_input_size=(3, 288, 288), test_crop_pct=1.0), + convnext_tiny_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth'), + url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth', + test_input_size=(3, 288, 288), test_crop_pct=1.0), convnext_small_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth'), + url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth', + test_input_size=(3, 288, 288), test_crop_pct=1.0), convnext_base_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'), + url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth', + test_input_size=(3, 288, 288), test_crop_pct=1.0), convnext_large_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth'), + url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth', + test_input_size=(3, 288, 288), test_crop_pct=1.0), convnext_xlarge_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth'), + url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth', + test_input_size=(3, 288, 288), test_crop_pct=1.0), convnext_tiny_384_in22ft1k=_cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth', @@ -103,37 +137,39 @@ class ConvNeXtBlock(nn.Module): is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW. Args: - dim (int): Number of input channels. + in_chs (int): Number of input channels. drop_path (float): Stochastic depth rate. Default: 0.0 ls_init_value (float): Init value for Layer Scale. Default: 1e-6. """ def __init__( self, - dim, - dim_out=None, + in_chs, + out_chs=None, + kernel_size=7, stride=1, dilation=1, mlp_ratio=4, conv_mlp=False, conv_bias=True, ls_init_value=1e-6, + act_layer='gelu', norm_layer=None, - act_layer=nn.GELU, drop_path=0., ): super().__init__() - dim_out = dim_out or dim + out_chs = out_chs or in_chs + act_layer = get_act_layer(act_layer) if not norm_layer: - norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6) + norm_layer = LayerNorm2d if conv_mlp else LayerNorm mlp_layer = ConvMlp if conv_mlp else Mlp self.use_conv_mlp = conv_mlp self.conv_dw = create_conv2d( - dim, dim_out, kernel_size=7, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias) - self.norm = norm_layer(dim_out) - self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer) - self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None + in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias) + self.norm = norm_layer(out_chs) + self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer) + self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value > 0 else None self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def forward(self, x): @@ -160,6 +196,7 @@ class ConvNeXtStage(nn.Module): self, in_chs, out_chs, + kernel_size=7, stride=2, depth=2, dilation=(1, 1), @@ -167,6 +204,7 @@ class ConvNeXtStage(nn.Module): ls_init_value=1.0, conv_mlp=False, conv_bias=True, + act_layer='gelu', norm_layer=None, norm_layer_cl=None ): @@ -190,13 +228,15 @@ class ConvNeXtStage(nn.Module): stage_blocks = [] for i in range(depth): stage_blocks.append(ConvNeXtBlock( - dim=in_chs, - dim_out=out_chs, + in_chs=in_chs, + out_chs=out_chs, + kernel_size=kernel_size, dilation=dilation[1], drop_path=drop_path_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp, conv_bias=conv_bias, + act_layer=act_layer, norm_layer=norm_layer if conv_mlp else norm_layer_cl )) in_chs = out_chs @@ -234,23 +274,25 @@ class ConvNeXt(nn.Module): output_stride=32, depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), + kernel_sizes=7, ls_init_value=1e-6, stem_type='patch', - stem_kernel_size=4, - stem_stride=4, + patch_size=4, head_init_scale=1., head_norm_first=False, conv_mlp=False, conv_bias=True, + act_layer='gelu', norm_layer=None, drop_rate=0., drop_path_rate=0., ): super().__init__() assert output_stride in (8, 16, 32) + kernel_sizes = to_ntuple(4)(kernel_sizes) if norm_layer is None: - norm_layer = partial(LayerNorm2d, eps=1e-6) - norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6) + norm_layer = LayerNorm2d + norm_layer_cl = norm_layer if conv_mlp else LayerNorm else: assert conv_mlp,\ 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input' @@ -260,21 +302,22 @@ class ConvNeXt(nn.Module): self.drop_rate = drop_rate self.feature_info = [] - assert stem_type in ('patch', 'overlap') + assert stem_type in ('patch', 'overlap', 'overlap_tiered') if stem_type == 'patch': - assert stem_kernel_size == stem_stride # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 self.stem = nn.Sequential( - nn.Conv2d(in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride, bias=conv_bias), + nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias), norm_layer(dims[0]) ) + stem_stride = patch_size else: + mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0] self.stem = nn.Sequential( - nn.Conv2d( - in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride, - padding=stem_kernel_size // 2, bias=conv_bias), + nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias), + nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias), norm_layer(dims[0]), ) + stem_stride = 4 self.stages = nn.Sequential() dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] @@ -294,6 +337,7 @@ class ConvNeXt(nn.Module): stages.append(ConvNeXtStage( prev_chs, out_chs, + kernel_size=kernel_sizes[i], stride=stride, dilation=(first_dilation, dilation), depth=depths[i], @@ -301,6 +345,7 @@ class ConvNeXt(nn.Module): ls_init_value=ls_init_value, conv_mlp=conv_mlp, conv_bias=conv_bias, + act_layer=act_layer, norm_layer=norm_layer, norm_layer_cl=norm_layer_cl )) @@ -415,20 +460,65 @@ def _create_convnext(variant, pretrained=False, **kwargs): @register_model -def convnext_nano(pretrained=False, **kwargs): - # timm nano variant with standard stem and head +def convnext_atto(pretrained=False, **kwargs): + # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M 
     model_args = dict(
-        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, **kwargs)
-    model = _create_convnext('convnext_nano', pretrained=pretrained, **model_args)
+        depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, **kwargs)
+    model = _create_convnext('convnext_atto', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def convnext_atto_ols(pretrained=False, **kwargs):
+    # timm atto variant with overlapping 3x3 conv stem, current param count 3.7M
+    model_args = dict(
+        depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
+    model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def convnext_femto(pretrained=False, **kwargs):
+    # timm femto variant
+    model_args = dict(
+        depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, **kwargs)
+    model = _create_convnext('convnext_femto', pretrained=pretrained, **model_args)
     return model
 
 
 @register_model
-def convnext_nano_hnf(pretrained=False, **kwargs):
-    # experimental nano variant with normalization before pooling in head (head norm first)
+def convnext_femto_ols(pretrained=False, **kwargs):
+    # timm femto variant with overlapping 3x3 conv stem
     model_args = dict(
-        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs)
-    model = _create_convnext('convnext_nano_hnf', pretrained=pretrained, **model_args)
+        depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
+    model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def convnext_pico(pretrained=False, **kwargs):
+    # timm pico variant
+    model_args = dict(
+        depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, **kwargs)
+    model = _create_convnext('convnext_pico', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def convnext_pico_ols(pretrained=False, **kwargs):
+    # timm pico variant with overlapping 3x3 conv stem
+    model_args = dict(
+        depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
+    model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def convnext_nano(pretrained=False, **kwargs):
+    # timm nano variant with standard stem and head
+    model_args = dict(
+        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, **kwargs)
+    model = _create_convnext('convnext_nano', pretrained=pretrained, **model_args)
     return model
 
 
@@ -436,8 +526,7 @@ def convnext_nano_hnf(pretrained=False, **kwargs):
 def convnext_nano_ols(pretrained=False, **kwargs):
     # experimental nano variant with overlapping conv stem
     model_args = dict(
-        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True,
-        stem_type='overlap', stem_kernel_size=9, **kwargs)
+        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap', **kwargs)
     model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args)
     return model
 
diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py
index a81bd9b0..422d4f2c 100644
--- a/timm/models/edgenext.py
+++ b/timm/models/edgenext.py
@@ -50,6 +50,12 @@ default_cfgs = dict(
         url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth",
         crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
     ),
+    # edgenext_base=_cfg(
+    #
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth"), + edgenext_base=_cfg( # USI weights + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth", + crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, + ), edgenext_small_rw=_cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth', @@ -154,7 +160,7 @@ class CrossCovarianceAttn(nn.Module): def forward(self, x): B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 4, 1) q, k, v = qkv.unbind(0) # NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map @@ -217,7 +223,8 @@ class SplitTransposeBlock(nn.Module): shortcut = x # scales code re-written for torchscript as per my res2net fixes -rw - spx = torch.split(x, self.width, 1) + # NOTE torch.split(x, self.width, 1) causing issues with ONNX export + spx = x.chunk(len(self.convs) + 1, dim=1) spo = [] sp = spx[0] for i, conv in enumerate(self.convs): @@ -545,13 +552,19 @@ def edgenext_small(pretrained=False, **kwargs): return _create_edgenext('edgenext_small', pretrained=pretrained, **model_kwargs) +@register_model +def edgenext_base(pretrained=False, **kwargs): + # 18.51M & 3840.93M @ 256 resolution + # 82.5% (normal) 83.7% (USI) Top-1 accuracy + # AA=True, Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=xx.xx versus xx.xx for MobileViT_S + # For A100: FPS @ BS=1: xxx.xx & @ BS=256: xxxx.xx + model_kwargs = dict(depths=[3, 3, 9, 3], dims=[80, 160, 288, 584], **kwargs) + return _create_edgenext('edgenext_base', pretrained=pretrained, **model_kwargs) + + @register_model def edgenext_small_rw(pretrained=False, **kwargs): - # 5.59M & 1260.59M @ 256 resolution - # 79.43% Top-1 accuracy - # AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler - # Jetson FPS=20.47 versus 18.86 for MobileViT_S - # For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S model_kwargs = dict( depths=(3, 3, 9, 3), dims=(48, 96, 192, 384), downsample_block=True, conv_bias=False, stem_type='overlap', **kwargs) diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py new file mode 100644 index 00000000..4749d93a --- /dev/null +++ b/timm/models/efficientformer.py @@ -0,0 +1,551 @@ +""" EfficientFormer + +@article{li2022efficientformer, + title={EfficientFormer: Vision Transformers at MobileNet Speed}, + author={Li, Yanyu and Yuan, Geng and Wen, Yang and Hu, Eric and Evangelidis, Georgios and Tulyakov, + Sergey and Wang, Yanzhi and Ren, Jian}, + journal={arXiv preprint arXiv:2206.01191}, + year={2022} +} + +Based on Apache 2.0 licensed code at https://github.com/snap-research/EfficientFormer, Copyright (c) 2022 Snap Inc. 
+ +Modifications and timm support by / Copyright 2022, Ross Wightman +""" +from typing import Dict + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import DropPath, trunc_normal_, to_2tuple, Mlp +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True, + 'crop_pct': .95, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv1', 'classifier': ('head', 'head_dist'), + **kwargs + } + + +default_cfgs = dict( + efficientformer_l1=_cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l1_1000d_224-5b08fab0.pth", + ), + efficientformer_l3=_cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l3_300d_224-6816624f.pth", + ), + efficientformer_l7=_cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l7_300d_224-e957ab75.pth", + ), +) + +EfficientFormer_width = { + 'l1': (48, 96, 224, 448), + 'l3': (64, 128, 320, 512), + 'l7': (96, 192, 384, 768), +} + +EfficientFormer_depth = { + 'l1': (3, 2, 6, 4), + 'l3': (4, 4, 12, 6), + 'l7': (6, 6, 18, 8), +} + + +class Attention(torch.nn.Module): + attention_bias_cache: Dict[str, torch.Tensor] + + def __init__( + self, + dim=384, + key_dim=32, + num_heads=8, + attn_ratio=4, + resolution=7 + ): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.key_attn_dim = key_dim * num_heads + self.val_dim = int(attn_ratio * key_dim) + self.val_attn_dim = self.val_dim * num_heads + self.attn_ratio = attn_ratio + + self.qkv = nn.Linear(dim, self.key_attn_dim * 2 + self.val_attn_dim) + self.proj = nn.Linear(self.val_attn_dim, dim) + + resolution = to_2tuple(resolution) + pos = torch.stack(torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1) + rel_pos = (pos[..., :, None] - pos[..., None, :]).abs() + rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1] + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1])) + self.register_buffer('attention_bias_idxs', torch.LongTensor(rel_pos)) + self.attention_bias_cache = {} # per-device attention_biases cache (data-parallel compat) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.attention_bias_cache: + self.attention_bias_cache = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.attention_bias_cache: + self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.attention_bias_cache[device_key] + + def forward(self, x): # x (B,N,C) + B, N, C = x.shape + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + q, k, v = qkv.split([self.key_dim, self.key_dim, self.val_dim], dim=3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn + self.get_attention_biases(x.device) + + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.val_attn_dim) + x = self.proj(x) + return x + + 
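The `attention_bias_idxs` buffer built in `Attention.__init__` above is just a table of flattened relative offsets into the learnable `attention_biases`. A standalone sketch of that indexing for a 7x7 token grid and 8 heads (illustration only, not part of this diff):

```python
import torch

# Mirror the constructor logic for resolution=(7, 7), num_heads=8.
resolution = (7, 7)
num_heads = 8

# indexing='ij' matches the default meshgrid behaviour used above (torch >= 1.10).
pos = torch.stack(torch.meshgrid(
    torch.arange(resolution[0]), torch.arange(resolution[1]), indexing='ij')).flatten(1)  # (2, 49)
rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()            # (2, 49, 49) |dy|, |dx| per token pair
rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]                 # (49, 49) flat offset ids in [0, 48]

attention_biases = torch.zeros(num_heads, resolution[0] * resolution[1])  # learnable nn.Parameter in the model
bias = attention_biases[:, rel_pos]   # (num_heads, 49, 49), broadcast-added to the attention logits
print(bias.shape)                     # torch.Size([8, 49, 49])
```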
+class Stem4(nn.Sequential): + def __init__(self, in_chs, out_chs, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super().__init__() + self.stride = 4 + + self.add_module('conv1', nn.Conv2d(in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1)) + self.add_module('norm1', norm_layer(out_chs // 2)) + self.add_module('act1', act_layer()) + self.add_module('conv2', nn.Conv2d(out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1)) + self.add_module('norm2', norm_layer(out_chs)) + self.add_module('act2', act_layer()) + + +class Downsample(nn.Module): + """ + Downsampling via strided conv w/ norm + Input: tensor in shape [B, C, H, W] + Output: tensor in shape [B, C, H/stride, W/stride] + """ + + def __init__(self, in_chs, out_chs, kernel_size=3, stride=2, padding=None, norm_layer=nn.BatchNorm2d): + super().__init__() + if padding is None: + padding = kernel_size // 2 + self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, padding=padding) + self.norm = norm_layer(out_chs) + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class Flat(nn.Module): + + def __init__(self, ): + super().__init__() + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) + return x + + +class Pooling(nn.Module): + """ + Implementation of pooling for PoolFormer + --pool_size: pooling size + """ + + def __init__(self, pool_size=3): + super().__init__() + self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False) + + def forward(self, x): + return self.pool(x) - x + + +class ConvMlpWithNorm(nn.Module): + """ + Implementation of MLP with 1*1 convolutions. + Input: tensor with shape [B, C, H, W] + """ + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=nn.BatchNorm2d, + drop=0. + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.norm1 = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.norm2 = norm_layer(out_features) if norm_layer is not None else nn.Identity() + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.norm1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.norm2(x) + x = self.drop(x) + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class MetaBlock1d(nn.Module): + + def __init__( + self, + dim, + mlp_ratio=4., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + drop=0., + drop_path=0., + layer_scale_init_value=1e-5 + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.token_mixer = Attention(dim) + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.ls1 = LayerScale(dim, layer_scale_init_value) + self.ls2 = LayerScale(dim, layer_scale_init_value) + + def forward(self, x): + x = x + self.drop_path(self.ls1(self.token_mixer(self.norm1(x)))) + x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class LayerScale2d(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + gamma = self.gamma.view(1, -1, 1, 1) + return x.mul_(gamma) if self.inplace else x * gamma + + +class MetaBlock2d(nn.Module): + + def __init__( + self, + dim, + pool_size=3, + mlp_ratio=4., + act_layer=nn.GELU, + norm_layer=nn.BatchNorm2d, + drop=0., + drop_path=0., + layer_scale_init_value=1e-5 + ): + super().__init__() + self.token_mixer = Pooling(pool_size=pool_size) + self.mlp = ConvMlpWithNorm( + dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, norm_layer=norm_layer, drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.ls1 = LayerScale2d(dim, layer_scale_init_value) + self.ls2 = LayerScale2d(dim, layer_scale_init_value) + + def forward(self, x): + x = x + self.drop_path(self.ls1(self.token_mixer(x))) + x = x + self.drop_path(self.ls2(self.mlp(x))) + return x + + +class EfficientFormerStage(nn.Module): + + def __init__( + self, + dim, + dim_out, + depth, + downsample=True, + num_vit=1, + pool_size=3, + mlp_ratio=4., + act_layer=nn.GELU, + norm_layer=nn.BatchNorm2d, + norm_layer_cl=nn.LayerNorm, + drop=.0, + drop_path=0., + layer_scale_init_value=1e-5, +): + super().__init__() + self.grad_checkpointing = False + + if downsample: + self.downsample = Downsample(in_chs=dim, out_chs=dim_out, norm_layer=norm_layer) + dim = dim_out + else: + assert dim == dim_out + self.downsample = nn.Identity() + + blocks = [] + if num_vit and num_vit >= depth: + blocks.append(Flat()) + + for block_idx in range(depth): + remain_idx = depth - block_idx - 1 + if num_vit and num_vit > remain_idx: + blocks.append( + MetaBlock1d( + dim, + mlp_ratio=mlp_ratio, + act_layer=act_layer, + norm_layer=norm_layer_cl, + drop=drop, + drop_path=drop_path[block_idx], + layer_scale_init_value=layer_scale_init_value, + )) + else: + blocks.append( + MetaBlock2d( + dim, + pool_size=pool_size, + mlp_ratio=mlp_ratio, + act_layer=act_layer, + norm_layer=norm_layer, + drop=drop, + drop_path=drop_path[block_idx], + layer_scale_init_value=layer_scale_init_value, + )) + if num_vit and num_vit == remain_idx: + blocks.append(Flat()) + + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + x = self.downsample(x) + x = self.blocks(x) + return x + + +class EfficientFormer(nn.Module): + + def __init__( + self, + depths, + embed_dims=None, + in_chans=3, + num_classes=1000, + global_pool='avg', + downsamples=None, + num_vit=0, + mlp_ratios=4, + pool_size=3, + layer_scale_init_value=1e-5, + act_layer=nn.GELU, + norm_layer=nn.BatchNorm2d, + norm_layer_cl=nn.LayerNorm, + drop_rate=0., + drop_path_rate=0., + **kwargs + ): + super().__init__() + self.num_classes = num_classes + self.global_pool = global_pool + + self.stem = Stem4(in_chans, embed_dims[0], norm_layer=norm_layer) + prev_dim = embed_dims[0] + + # stochastic depth decay rule + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + downsamples = downsamples or (False,) + (True,) * (len(depths) - 1) + stages = [] + for i in range(len(depths)): + stage = EfficientFormerStage( 
+ prev_dim, + embed_dims[i], + depths[i], + downsample=downsamples[i], + num_vit=num_vit if i == 3 else 0, + pool_size=pool_size, + mlp_ratio=mlp_ratios, + act_layer=act_layer, + norm_layer_cl=norm_layer_cl, + norm_layer=norm_layer, + drop=drop_rate, + drop_path=dpr[i], + layer_scale_init_value=layer_scale_init_value, + ) + prev_dim = embed_dims[i] + stages.append(stage) + + self.stages = nn.Sequential(*stages) + + # Classifier head + self.num_features = embed_dims[-1] + self.norm = norm_layer_cl(self.num_features) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + # assuming model is always distilled (valid for current checkpoints, will split def if that changes) + self.head_dist = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() + self.distilled_training = False # must set this True to train w/ distillation token + + self.apply(self._init_weights) + + # init for classification + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def no_weight_decay(self): + return {k for k, _ in self.named_parameters() if 'attention_biases' in k} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', # stem and embed + blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head, self.head_dist + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + @torch.jit.ignore + def set_distilled_training(self, enable=True): + self.distilled_training = enable + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool == 'avg': + x = x.mean(dim=1) + if pre_logits: + return x + x, x_dist = self.head(x), self.head_dist(x) + if self.distilled_training and self.training and not torch.jit.is_scripting(): + # only return separate classification predictions when training in distilled mode + return x, x_dist + else: + # during standard train/finetune, inference average the classifier predictions + return (x + x_dist) / 2 + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _checkpoint_filter_fn(state_dict, model): + """ Remap original checkpoints -> timm """ + if 'stem.0.weight' in state_dict: + return state_dict # non-original checkpoint, no remapping needed + + out_dict = {} + import re + stage_idx = 0 + for k, v in state_dict.items(): + if k.startswith('patch_embed'): + k = k.replace('patch_embed.0', 'stem.conv1') + k = k.replace('patch_embed.1', 'stem.norm1') + k = k.replace('patch_embed.3', 'stem.conv2') + k = k.replace('patch_embed.4', 'stem.norm2') + + if re.match(r'network\.(\d+)\.proj\.weight', k): + stage_idx += 1 + k = re.sub(r'network.(\d+).(\d+)', f'stages.{stage_idx}.blocks.\\2', k) + k = re.sub(r'network.(\d+).proj', f'stages.{stage_idx}.downsample.conv', k) + k = 
re.sub(r'network.(\d+).norm', f'stages.{stage_idx}.downsample.norm', k) + + k = re.sub(r'layer_scale_([0-9])', r'ls\1.gamma', k) + k = k.replace('dist_head', 'head_dist') + out_dict[k] = v + return out_dict + + +def _create_efficientformer(variant, pretrained=False, **kwargs): + model = build_model_with_cfg( + EfficientFormer, variant, pretrained, + pretrained_filter_fn=_checkpoint_filter_fn, + **kwargs) + return model + + +@register_model +def efficientformer_l1(pretrained=False, **kwargs): + model_kwargs = dict( + depths=EfficientFormer_depth['l1'], + embed_dims=EfficientFormer_width['l1'], + num_vit=1, + **kwargs) + return _create_efficientformer('efficientformer_l1', pretrained=pretrained, **model_kwargs) + + +@register_model +def efficientformer_l3(pretrained=False, **kwargs): + model_kwargs = dict( + depths=EfficientFormer_depth['l3'], + embed_dims=EfficientFormer_width['l3'], + num_vit=4, + **kwargs) + return _create_efficientformer('efficientformer_l3', pretrained=pretrained, **model_kwargs) + + +@register_model +def efficientformer_l7(pretrained=False, **kwargs): + model_kwargs = dict( + depths=EfficientFormer_depth['l7'], + embed_dims=EfficientFormer_width['l7'], + num_vit=8, + **kwargs) + return _create_efficientformer('efficientformer_l7', pretrained=pretrained, **model_kwargs) + diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py new file mode 100644 index 00000000..bad40bd6 --- /dev/null +++ b/timm/models/gcvit.py @@ -0,0 +1,588 @@ +""" Global Context ViT + +From scratch implementation of GCViT in the style of timm swin_transformer_v2_cr.py + +Global Context Vision Transformers -https://arxiv.org/abs/2206.09959 + +@article{hatamizadeh2022global, + title={Global Context Vision Transformers}, + author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo}, + journal={arXiv preprint arXiv:2206.09959}, + year={2022} +} + +Free of any code related to NVIDIA GCVit impl at https://github.com/NVlabs/GCVit. +The license for this code release is Apache 2.0 with no commercial restrictions. + +However, weight files adapted from NVIDIA GCVit impl ARE under a non-commercial share-alike license +(https://creativecommons.org/licenses/by-nc-sa/4.0/) until I have a chance to train new ones... 
+ +Hacked together by / Copyright 2022, Ross Wightman +""" +import math +from functools import partial +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_notrace_function +from .helpers import build_model_with_cfg, named_apply +from .layers import trunc_normal_tf_, DropPath, to_2tuple, Mlp, get_attn, get_act_layer, get_norm_layer, \ + ClassifierHead, LayerNorm2d, _assert +from .registry import register_model +from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move to common location + +__all__ = ['GlobalContextVit'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv1', 'classifier': 'head.fc', + 'fixed_input_size': True, + **kwargs + } + + +default_cfgs = { + 'gcvit_xxtiny': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xxtiny_224_nvidia-d1d86009.pth'), + 'gcvit_xtiny': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xtiny_224_nvidia-274b92b7.pth'), + 'gcvit_tiny': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_tiny_224_nvidia-ac783954.pth'), + 'gcvit_small': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_small_224_nvidia-4e98afa2.pth'), + 'gcvit_base': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_base_224_nvidia-f009139b.pth'), +} + + +class MbConvBlock(nn.Module): + """ A depthwise separable / fused mbconv style residual block with SE, `no norm. 
+ """ + def __init__( + self, + in_chs, + out_chs=None, + expand_ratio=1.0, + attn_layer='se', + bias=False, + act_layer=nn.GELU, + ): + super().__init__() + attn_kwargs = dict(act_layer=act_layer) + if isinstance(attn_layer, str) and attn_layer == 'se' or attn_layer == 'eca': + attn_kwargs['rd_ratio'] = 0.25 + attn_kwargs['bias'] = False + attn_layer = get_attn(attn_layer) + out_chs = out_chs or in_chs + mid_chs = int(expand_ratio * in_chs) + + self.conv_dw = nn.Conv2d(in_chs, mid_chs, 3, 1, 1, groups=in_chs, bias=bias) + self.act = act_layer() + self.se = attn_layer(mid_chs, **attn_kwargs) + self.conv_pw = nn.Conv2d(mid_chs, out_chs, 1, 1, 0, bias=bias) + + def forward(self, x): + shortcut = x + x = self.conv_dw(x) + x = self.act(x) + x = self.se(x) + x = self.conv_pw(x) + x = x + shortcut + return x + + +class Downsample2d(nn.Module): + def __init__( + self, + dim, + dim_out=None, + reduction='conv', + act_layer=nn.GELU, + norm_layer=LayerNorm2d, # NOTE in NCHW + ): + super().__init__() + dim_out = dim_out or dim + + self.norm1 = norm_layer(dim) if norm_layer is not None else nn.Identity() + self.conv_block = MbConvBlock(dim, act_layer=act_layer) + assert reduction in ('conv', 'max', 'avg') + if reduction == 'conv': + self.reduction = nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False) + elif reduction == 'max': + assert dim == dim_out + self.reduction = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + else: + assert dim == dim_out + self.reduction = nn.AvgPool2d(kernel_size=2) + self.norm2 = norm_layer(dim_out) if norm_layer is not None else nn.Identity() + + def forward(self, x): + x = self.norm1(x) + x = self.conv_block(x) + x = self.reduction(x) + x = self.norm2(x) + return x + + +class FeatureBlock(nn.Module): + def __init__( + self, + dim, + levels=0, + reduction='max', + act_layer=nn.GELU, + ): + super().__init__() + reductions = levels + levels = max(1, levels) + if reduction == 'avg': + pool_fn = partial(nn.AvgPool2d, kernel_size=2) + else: + pool_fn = partial(nn.MaxPool2d, kernel_size=3, stride=2, padding=1) + self.blocks = nn.Sequential() + for i in range(levels): + self.blocks.add_module(f'conv{i+1}', MbConvBlock(dim, act_layer=act_layer)) + if reductions: + self.blocks.add_module(f'pool{i+1}', pool_fn()) + reductions -= 1 + + def forward(self, x): + return self.blocks(x) + + +class Stem(nn.Module): + def __init__( + self, + in_chs: int = 3, + out_chs: int = 96, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm2d, # NOTE stem in NCHW + ): + super().__init__() + self.conv1 = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1) + self.down = Downsample2d(out_chs, act_layer=act_layer, norm_layer=norm_layer) + + def forward(self, x): + x = self.conv1(x) + x = self.down(x) + return x + + +class WindowAttentionGlobal(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + window_size: Tuple[int, int], + use_global: bool = True, + qkv_bias: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + ): + super().__init__() + window_size = to_2tuple(window_size) + self.window_size = window_size + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.use_global = use_global + + self.rel_pos = RelPosBias(window_size=window_size, num_heads=num_heads) + if self.use_global: + self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias) + else: + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = 
nn.Dropout(proj_drop) + + def forward(self, x, q_global: Optional[torch.Tensor] = None): + B, N, C = x.shape + if self.use_global and q_global is not None: + _assert(x.shape[-1] == q_global.shape[-1], 'x and q_global seq lengths should be equal') + + kv = self.qkv(x) + kv = kv.reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + k, v = kv.unbind(0) + + q = q_global.repeat(B // q_global.shape[0], 1, 1, 1) + q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + else: + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q = q * self.scale + + attn = (q @ k.transpose(-2, -1)) + attn = self.rel_pos(attn) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def window_partition(x, window_size: Tuple[int, int]): + B, H, W, C = x.shape + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]): + H, W = img_size + B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) + x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class GlobalContextVitBlock(nn.Module): + def __init__( + self, + dim: int, + feat_size: Tuple[int, int], + num_heads: int, + window_size: int = 7, + mlp_ratio: float = 4., + use_global: bool = True, + qkv_bias: bool = True, + layer_scale: Optional[float] = None, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + attn_layer: Callable = WindowAttentionGlobal, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + ): + super().__init__() + feat_size = to_2tuple(feat_size) + window_size = to_2tuple(window_size) + self.window_size = window_size + self.num_windows = int((feat_size[0] // window_size[0]) * (feat_size[1] // window_size[1])) + + self.norm1 = norm_layer(dim) + self.attn = attn_layer( + dim, + num_heads=num_heads, + window_size=window_size, + use_global=use_global, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) + self.ls1 = LayerScale(dim, layer_scale) if layer_scale is not None else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop) + self.ls2 = LayerScale(dim, layer_scale) if layer_scale is not None else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def _window_attn(self, x, q_global: Optional[torch.Tensor] = None): + B, H, W, C = x.shape + x_win = window_partition(x, self.window_size) + x_win = x_win.view(-1, self.window_size[0] * self.window_size[1], C) + attn_win = self.attn(x_win, q_global) + x = window_reverse(attn_win, self.window_size, (H, W)) + return x + + def forward(self, x, q_global: Optional[torch.Tensor] = None): + x = x + self.drop_path1(self.ls1(self._window_attn(self.norm1(x), q_global))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class GlobalContextVitStage(nn.Module): + def __init__( + self, + dim, + depth: int, + num_heads: int, + feat_size: Tuple[int, int], + window_size: int, + downsample: bool = True, + global_norm: bool = False, + stage_norm: bool = False, + mlp_ratio: float = 4., + qkv_bias: bool = True, + layer_scale: Optional[float] = None, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: Union[List[float], float] = 0.0, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + norm_layer_cl: Callable = LayerNorm2d, + ): + super().__init__() + if downsample: + self.downsample = Downsample2d( + dim=dim, + dim_out=dim * 2, + norm_layer=norm_layer, + ) + dim = dim * 2 + feat_size = (feat_size[0] // 2, feat_size[1] // 2) + else: + self.downsample = nn.Identity() + self.feat_size = feat_size + + feat_levels = int(math.log2(min(feat_size) / window_size)) + self.global_block = FeatureBlock(dim, feat_levels) + self.global_norm = norm_layer_cl(dim) if global_norm else nn.Identity() + + self.blocks = nn.ModuleList([ + GlobalContextVitBlock( + dim=dim, + num_heads=num_heads, + feat_size=feat_size, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + use_global=(i % 2 != 0), + layer_scale=layer_scale, + proj_drop=proj_drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + act_layer=act_layer, + norm_layer=norm_layer_cl, + ) + for i in range(depth) + ]) + self.norm = norm_layer_cl(dim) if stage_norm else nn.Identity() + self.dim = dim + self.feat_size = feat_size + self.grad_checkpointing = False + + def forward(self, x): + # input NCHW, downsample & global block are 2d conv + pooling + x = self.downsample(x) + global_query = self.global_block(x) + + # reshape NCHW --> NHWC for transformer blocks + x = x.permute(0, 2, 3, 1) + global_query = self.global_norm(global_query.permute(0, 2, 3, 1)) + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x, global_query) + x = self.norm(x) + x = x.permute(0, 3, 1, 2).contiguous() # back to NCHW + return x + + +class GlobalContextVit(nn.Module): + def __init__( + self, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + img_size: Tuple[int, int] = 224, + window_size: Tuple[int, ...] = (7, 7, 14, 7), + embed_dim: int = 64, + depths: Tuple[int, ...] = (3, 4, 19, 5), + num_heads: Tuple[int, ...] 
= (2, 4, 8, 16), + mlp_ratio: float = 3.0, + qkv_bias: bool = True, + layer_scale: Optional[float] = None, + drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + weight_init='vit', + act_layer: str = 'gelu', + norm_layer: str = 'layernorm2d', + norm_layer_cl: str = 'layernorm', + norm_eps: float = 1e-5, + ): + super().__init__() + act_layer = get_act_layer(act_layer) + norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps) + norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps) + + img_size = to_2tuple(img_size) + feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4 + self.global_pool = global_pool + self.num_classes = num_classes + self.drop_rate = drop_rate + num_stages = len(depths) + self.num_features = int(embed_dim * 2 ** (num_stages - 1)) + + self.stem = Stem( + in_chs=in_chans, + out_chs=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer + ) + + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + stages = [] + for i in range(num_stages): + last_stage = i == num_stages - 1 + stage_scale = 2 ** max(i - 1, 0) + stages.append(GlobalContextVitStage( + dim=embed_dim * stage_scale, + depth=depths[i], + num_heads=num_heads[i], + feat_size=(feat_size[0] // stage_scale, feat_size[1] // stage_scale), + window_size=window_size[i], + downsample=i != 0, + stage_norm=last_stage, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + layer_scale=layer_scale, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + act_layer=act_layer, + norm_layer=norm_layer, + norm_layer_cl=norm_layer_cl, + )) + self.stages = nn.Sequential(*stages) + + # Classifier head + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + + if weight_init: + named_apply(partial(self._init_weights, scheme=weight_init), self) + + def _init_weights(self, module, name, scheme='vit'): + # note Conv2d left as default init + if scheme == 'vit': + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + else: + if isinstance(module, nn.Linear): + trunc_normal_tf_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + @torch.jit.ignore + def no_weight_decay(self): + return { + k for k, _ in self.named_parameters() + if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])} + + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', # stem and embed + blocks=r'^stages\.(\d+)' + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is None: + global_pool = self.head.global_pool.pool_type + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.stem(x) + x = self.stages(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + x = self.forward_head(x) + return 
x + + +def _create_gcvit(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + model = build_model_with_cfg(GlobalContextVit, variant, pretrained, **kwargs) + return model + + +@register_model +def gcvit_xxtiny(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(2, 2, 6, 2), + num_heads=(2, 4, 8, 16), + **kwargs) + return _create_gcvit('gcvit_xxtiny', pretrained=pretrained, **model_kwargs) + + +@register_model +def gcvit_xtiny(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(3, 4, 6, 5), + num_heads=(2, 4, 8, 16), + **kwargs) + return _create_gcvit('gcvit_xtiny', pretrained=pretrained, **model_kwargs) + + +@register_model +def gcvit_tiny(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(3, 4, 19, 5), + num_heads=(2, 4, 8, 16), + **kwargs) + return _create_gcvit('gcvit_tiny', pretrained=pretrained, **model_kwargs) + + +@register_model +def gcvit_small(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(3, 4, 19, 5), + num_heads=(3, 6, 12, 24), + window_size=(7, 7, 14, 7), + embed_dim=96, + mlp_ratio=2, + layer_scale=1e-5, + **kwargs) + return _create_gcvit('gcvit_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def gcvit_base(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(3, 4, 19, 5), + num_heads=(4, 8, 16, 32), + window_size=(7, 7, 14, 7), + embed_dim=128, + mlp_ratio=2, + layer_scale=1e-5, + **kwargs) + return _create_gcvit('gcvit_base', pretrained=pretrained, **model_kwargs) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index b9eeec0f..21c641b6 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -11,21 +11,23 @@ from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct from .create_act import create_act_layer, get_act_layer, get_act_fn from .create_attn import get_attn, create_attn from .create_conv2d import create_conv2d +from .create_norm import get_norm_layer, create_norm_layer from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a +from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d from .gather_excite import GatherExcite from .global_context import GlobalContext -from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible +from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple from .inplace_abn import InplaceAbn from .linear import Linear from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp from .non_local_attn import NonLocalAttn, BatNonLocalAttn -from .norm import GroupNorm, GroupNorm1, LayerNorm2d +from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm from .padding import get_padding, get_same_padding, pad_same from .patch_embed import PatchEmbed diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py index e38f2e03..a3044a3d 100644 --- a/timm/models/layers/create_act.py +++ 
b/timm/models/layers/create_act.py @@ -145,4 +145,10 @@ def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs): act_layer = get_act_layer(name) if act_layer is None: return None - return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs) + if inplace is None: + return act_layer(**kwargs) + try: + return act_layer(inplace=inplace, **kwargs) + except TypeError: + # recover if act layer doesn't have inplace arg + return act_layer(**kwargs) diff --git a/timm/models/layers/create_norm.py b/timm/models/layers/create_norm.py new file mode 100644 index 00000000..b9efae8c --- /dev/null +++ b/timm/models/layers/create_norm.py @@ -0,0 +1,56 @@ +""" Norm Layer Factory + +Create norm modules by string (to mirror create_act and creat_norm-act fns) + +Copyright 2022 Ross Wightman +""" +import types +import functools + +import torch.nn as nn + +from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d + +_NORM_MAP = dict( + batchnorm=nn.BatchNorm2d, + batchnorm2d=nn.BatchNorm2d, + batchnorm1d=nn.BatchNorm1d, + groupnorm=GroupNorm, + groupnorm1=GroupNorm1, + layernorm=LayerNorm, + layernorm2d=LayerNorm2d, +) +_NORM_TYPES = {m for n, m in _NORM_MAP.items()} + + +def create_norm_layer(layer_name, num_features, act_layer=None, apply_act=True, **kwargs): + layer = get_norm_layer(layer_name, act_layer=act_layer) + layer_instance = layer(num_features, apply_act=apply_act, **kwargs) + return layer_instance + + +def get_norm_layer(norm_layer): + assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) + norm_kwargs = {} + + # unbind partial fn, so args can be rebound later + if isinstance(norm_layer, functools.partial): + norm_kwargs.update(norm_layer.keywords) + norm_layer = norm_layer.func + + if isinstance(norm_layer, str): + layer_name = norm_layer.replace('_', '') + norm_layer = _NORM_MAP.get(layer_name, None) + elif norm_layer in _NORM_TYPES: + norm_layer = norm_layer + elif isinstance(norm_layer, types.FunctionType): + # if function type, assume it is a lambda/fn that creates a norm layer + norm_layer = norm_layer + else: + type_name = norm_layer.__name__.lower().replace('_', '') + norm_layer = _NORM_MAP.get(type_name, None) + assert norm_layer is not None, f"No equivalent norm layer for {type_name}" + + if norm_kwargs: + norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args + return norm_layer diff --git a/timm/models/layers/create_norm_act.py b/timm/models/layers/create_norm_act.py index cd15c2f8..78dd9a51 100644 --- a/timm/models/layers/create_norm_act.py +++ b/timm/models/layers/create_norm_act.py @@ -18,6 +18,7 @@ _NORM_ACT_MAP = dict( batchnorm=BatchNormAct2d, batchnorm2d=BatchNormAct2d, groupnorm=GroupNormAct, + groupnorm1=functools.partial(GroupNormAct, num_groups=1), layernorm=LayerNormAct, layernorm2d=LayerNormAct2d, evonormb0=EvoNorm2dB0, @@ -72,6 +73,8 @@ def get_norm_act_layer(norm_layer, act_layer=None): norm_act_layer = BatchNormAct2d elif type_name.startswith('groupnorm'): norm_act_layer = GroupNormAct + elif type_name.startswith('groupnorm1'): + norm_act_layer = functools.partial(GroupNormAct, num_groups=1) elif type_name.startswith('layernorm2d'): norm_act_layer = LayerNormAct2d elif type_name.startswith('layernorm'): diff --git a/timm/models/layers/fast_norm.py b/timm/models/layers/fast_norm.py new file mode 100644 index 00000000..9a34a15e --- /dev/null +++ b/timm/models/layers/fast_norm.py @@ -0,0 +1,68 @@ +from typing import List, Optional + +import torch +from torch.nn import 
functional as F + +try: + from apex.normalization.fused_layer_norm import fused_layer_norm_affine + has_apex = True +except ImportError: + has_apex = False + + +# fast (ie lower precision LN) can be disabled with this flag if issues crop up +_USE_FAST_NORM = False # defaulting to False for now + + +def is_fast_norm(): + return _USE_FAST_NORM + + +def set_fast_norm(enable=True): + global _USE_FAST_NORM + _USE_FAST_NORM = enable + + +def fast_group_norm( + x: torch.Tensor, + num_groups: int, + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + eps: float = 1e-5 +) -> torch.Tensor: + if torch.jit.is_scripting(): + # currently cannot use is_autocast_enabled within torchscript + return F.group_norm(x, num_groups, weight, bias, eps) + + if torch.is_autocast_enabled(): + # normally native AMP casts GN inputs to float32 + # here we use the low precision autocast dtype + dt = torch.get_autocast_gpu_dtype() + x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) + + with torch.cuda.amp.autocast(enabled=False): + return F.group_norm(x, num_groups, weight, bias, eps) + + +def fast_layer_norm( + x: torch.Tensor, + normalized_shape: List[int], + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + eps: float = 1e-5 +) -> torch.Tensor: + if torch.jit.is_scripting(): + # currently cannot use is_autocast_enabled within torchscript + return F.layer_norm(x, normalized_shape, weight, bias, eps) + + if has_apex: + return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps) + + if torch.is_autocast_enabled(): + # normally native AMP casts LN inputs to float32 + # apex LN does not, this is behaving like Apex + dt = torch.get_autocast_gpu_dtype() + x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) + + with torch.cuda.amp.autocast(enabled=False): + return F.layer_norm(x, normalized_shape, weight, bias, eps) diff --git a/timm/models/layers/helpers.py b/timm/models/layers/helpers.py index cc54ca7f..2fa296bc 100644 --- a/timm/models/layers/helpers.py +++ b/timm/models/layers/helpers.py @@ -9,7 +9,7 @@ import collections.abc # From PyTorch internals def _ntuple(n): def parse(x): - if isinstance(x, collections.abc.Iterable): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): return x return tuple(repeat(x, n)) return parse @@ -29,3 +29,15 @@ def make_divisible(v, divisor=8, min_value=None, round_limit=.9): if new_v < round_limit * v: new_v += divisor return new_v + + +def extend_tuple(x, n): + # pdas a tuple to specified n by padding with last value + if not isinstance(x, (tuple, list)): + x = (x,) + else: + x = tuple(x) + pad_n = n - len(x) + if pad_n <= 0: + return x[:n] + return x + (x[-1],) * pad_n diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index 1677dbfa..2ff8fc08 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -1,17 +1,24 @@ """ Normalization layers and wrappers """ + import torch import torch.nn as nn import torch.nn.functional as F +from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm + class GroupNorm(nn.GroupNorm): def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True): # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN super().__init__(num_groups, num_channels, eps=eps, affine=affine) + self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) def forward(self, x): - return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + if 
self.fast_norm: + return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + else: + return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) class GroupNorm1(nn.GroupNorm): @@ -21,22 +28,48 @@ class GroupNorm1(nn.GroupNorm): def __init__(self, num_channels, **kwargs): super().__init__(1, num_channels, **kwargs) + self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.fast_norm: + return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + else: + return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + + +class LayerNorm(nn.LayerNorm): + """ LayerNorm w/ fast norm option + """ + def __init__(self, num_channels, eps=1e-6, affine=True): + super().__init__(num_channels, eps=eps, elementwise_affine=affine) + self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self._fast_norm: + x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + else: + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x class LayerNorm2d(nn.LayerNorm): """ LayerNorm for channels of '2D' spatial NCHW tensors """ def __init__(self, num_channels, eps=1e-6, affine=True): super().__init__(num_channels, eps=eps, elementwise_affine=affine) + self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) def forward(self, x: torch.Tensor) -> torch.Tensor: - return F.layer_norm( - x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + x = x.permute(0, 2, 3, 1) + if self._fast_norm: + x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + else: + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = x.permute(0, 3, 1, 2) + return x def _is_contiguous(tensor: torch.Tensor) -> bool: # jit is oh so lovely :/ - # if torch.jit.is_tracing(): - # return True if torch.jit.is_scripting(): return tensor.is_contiguous() else: @@ -51,6 +84,14 @@ def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, ep return x +def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float): + u = x.mean(dim=1, keepdim=True) + s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0) + x = (x - u) * torch.rsqrt(s + eps) + x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1) + return x + + class LayerNormExp2d(nn.LayerNorm): """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). 
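A short usage sketch of the two small factory changes above (my example, not part of the patch): create_act_layer now degrades gracefully when an activation class has no inplace argument, and 'groupnorm1' becomes a resolvable alias in the norm-act map. nn.PReLU is just a convenient stand-in for an activation without inplace, and get_act_layer passes module classes through unchanged per its Union[nn.Module, str] signature.

    import torch.nn as nn
    from timm.models.layers import create_act_layer, get_norm_act_layer

    # inplace is forwarded when the activation supports it and silently dropped
    # by the new try/except fallback when it does not
    relu = create_act_layer('relu', inplace=True)     # nn.ReLU(inplace=True), unchanged behaviour
    prelu = create_act_layer(nn.PReLU, inplace=True)  # nn.PReLU(), no longer raises TypeError

    # 'groupnorm1' resolves to GroupNormAct pinned to a single group
    gn1_act = get_norm_act_layer('groupnorm1')        # functools.partial(GroupNormAct, num_groups=1)
    layer = gn1_act(64)                               # group norm with num_groups=1 over 64 chs, then the default ReLU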
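And a minimal sketch of how the new norm factory, the fast-norm toggle and the fast-norm-aware wrappers compose (again my example, not from the patch; the channel and spatial sizes are arbitrary and the AMP portion assumes a CUDA device). Note that the wrappers capture is_fast_norm() in __init__, so set_fast_norm() only affects modules constructed after it is called.

    import torch
    from timm.models.layers import get_norm_layer, set_fast_norm, LayerNorm2d

    # string -> class via the new factory; underscores are stripped, so
    # 'layernorm2d' and 'layer_norm_2d' both resolve to the timm LayerNorm2d
    norm_cls = get_norm_layer('layernorm2d')
    assert norm_cls is LayerNorm2d

    set_fast_norm(True)   # must run before layers are built (the flag is read in __init__)
    norm = norm_cls(64)   # layer norm over the channel dim of an NCHW tensor

    if torch.cuda.is_available():
        norm = norm.cuda()
        with torch.cuda.amp.autocast():
            # fast path: inputs stay in the autocast dtype rather than being upcast to fp32
            y = norm(torch.randn(2, 64, 56, 56, device='cuda'))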
diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py index 6bfdf201..3cd9fb36 100644 --- a/timm/models/layers/norm_act.py +++ b/timm/models/layers/norm_act.py @@ -6,8 +6,9 @@ import torch from torch import nn as nn from torch.nn import functional as F -from .trace_utils import _assert from .create_act import get_act_layer +from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm +from .trace_utils import _assert class BatchNormAct2d(nn.BatchNorm2d): @@ -202,12 +203,16 @@ class GroupNormAct(nn.GroupNorm): self.act = act_layer(**act_args) else: self.act = nn.Identity() + self._fast_norm = is_fast_norm() def forward(self, x): if False: # FIXME TPU temporary while resolving some performance issues x = group_norm_tpu(x, self.weight, self.bias, self.num_groups, self.eps) else: - x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + if self._fast_norm: + x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + else: + x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) x = self.drop(x) x = self.act(x) return x @@ -225,9 +230,13 @@ class LayerNormAct(nn.LayerNorm): self.act = act_layer(**act_args) else: self.act = nn.Identity() + self._fast_norm = is_fast_norm() def forward(self, x): - x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + if self._fast_norm: + x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + else: + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) x = self.drop(x) x = self.act(x) return x @@ -245,10 +254,15 @@ class LayerNormAct2d(nn.LayerNorm): self.act = act_layer(**act_args) else: self.act = nn.Identity() + self._fast_norm = is_fast_norm() def forward(self, x): - x = F.layer_norm( - x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + x = x.permute(0, 2, 3, 1) + if self._fast_norm: + x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + else: + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = x.permute(0, 3, 1, 2) x = self.drop(x) x = self.act(x) return x diff --git a/timm/models/layers/squeeze_excite.py b/timm/models/layers/squeeze_excite.py index e5da29ef..2e41d956 100644 --- a/timm/models/layers/squeeze_excite.py +++ b/timm/models/layers/squeeze_excite.py @@ -27,15 +27,15 @@ class SEModule(nn.Module): """ def __init__( self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, - act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'): + bias=True, act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'): super(SEModule, self).__init__() self.add_maxpool = add_maxpool if not rd_channels: rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
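# Sketch (mine, not part of the patch): the bias arg added to the signature above is threaded to
# both 1x1 convs just below. Turning it off pairs naturally with a norm layer after fc1 (the bn
# attribute below), which makes the conv bias redundant; the default of True keeps existing
# checkpoints loading unchanged. The 256-channel and BatchNorm2d choices here are mine.
from torch import nn
from timm.models.layers.squeeze_excite import SEModule
se = SEModule(256, bias=False, norm_layer=nn.BatchNorm2d)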
- self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True) + self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias) self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() self.act = create_act_layer(act_layer, inplace=True) - self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True) + self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias) self.gate = create_act_layer(gate_layer) def forward(self, x): diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py new file mode 100644 index 00000000..898e1685 --- /dev/null +++ b/timm/models/maxxvit.py @@ -0,0 +1,1693 @@ +""" MaxVit and CoAtNet Vision Transformer - CNN Hybrids in PyTorch + +This is a from-scratch implementation of both CoAtNet and MaxVit in PyTorch. + +99% of the implementation was done from papers, however last minute some adjustments were made +based on the (as yet unfinished?) public code release https://github.com/google-research/maxvit + +There are multiple sets of models defined for both architectures. Typically, names with a + `_rw` suffix are my own original configs prior to referencing https://github.com/google-research/maxvit. +These configs work well and appear to be a bit faster / lower resource than the paper. + +The models without extra prefix / suffix' (coatnet_0_224, maxvit_tiny_224, etc), are intended to +match paper, BUT, without any official pretrained weights it's difficult to confirm a 100% match. + +# FIXME / WARNING +This impl remains a WIP, some configs and models may vanish or change... + +Papers: + +MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697 +@article{tu2022maxvit, + title={MaxViT: Multi-Axis Vision Transformer}, + author={Tu, Zhengzhong and Talebi, Hossein and Zhang, Han and Yang, Feng and Milanfar, Peyman and Bovik, Alan and Li, Yinxiao}, + journal={ECCV}, + year={2022}, +} + +CoAtNet: Marrying Convolution and Attention for All Data Sizes - https://arxiv.org/abs/2106.04803 +@article{DBLP:journals/corr/abs-2106-04803, + author = {Zihang Dai and Hanxiao Liu and Quoc V. 
Le and Mingxing Tan}, + title = {CoAtNet: Marrying Convolution and Attention for All Data Sizes}, + journal = {CoRR}, + volume = {abs/2106.04803}, + year = {2021} +} + +Hacked together by / Copyright 2022, Ross Wightman +""" + +import math +from collections import OrderedDict +from dataclasses import dataclass +from functools import partial +from typing import Callable, Optional, Union, Tuple, List + +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg, checkpoint_seq, named_apply +from .fx_features import register_notrace_function +from .layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm2d, LayerNorm +from .layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d +from .layers import to_2tuple, extend_tuple, make_divisible, _assert +from .registry import register_model +from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move these to common location + +__all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.95, 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + 'first_conv': 'stem.conv1', 'classifier': 'head.fc', + 'fixed_input_size': True, + **kwargs + } + + +default_cfgs = { + # Fiddling with configs / defaults / still pretraining + 'coatnet_pico_rw_224': _cfg(url=''), + 'coatnet_nano_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth', + crop_pct=0.9), + 'coatnet_0_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'), + 'coatnet_1_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth' + ), + 'coatnet_2_rw_224': _cfg(url=''), + + # Highly experimental configs + 'coatnet_bn_0_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + crop_pct=0.95), + 'coatnet_rmlp_nano_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth', + crop_pct=0.9), + 'coatnet_rmlp_0_rw_224': _cfg(url=''), + 'coatnet_rmlp_1_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'), + 'coatnet_nano_cc_224': _cfg(url=''), + 'coatnext_nano_rw_224': _cfg(url=''), + + # Trying to be like the CoAtNet paper configs + 'coatnet_0_224': _cfg(url=''), + 'coatnet_1_224': _cfg(url=''), + 'coatnet_2_224': _cfg(url=''), + 'coatnet_3_224': _cfg(url=''), + 'coatnet_4_224': _cfg(url=''), + 'coatnet_5_224': _cfg(url=''), + + # Experimental configs + 'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_nano_rw_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-3e790ce3.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_tiny_rw_224': _cfg(url=''), + 'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 
256), pool_size=(8, 8)), + 'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + + # Trying to be like the MaxViT paper configs + 'maxvit_tiny_224': _cfg(url=''), + 'maxvit_small_224': _cfg(url=''), + 'maxvit_base_224': _cfg(url=''), + 'maxvit_large_224': _cfg(url=''), + 'maxvit_xlarge_224': _cfg(url=''), +} + + +@dataclass +class MaxxVitTransformerCfg: + dim_head: int = 32 + expand_ratio: float = 4.0 + expand_first: bool = True + shortcut_bias: bool = True, + attn_bias: bool = True + attn_drop: float = 0. + proj_drop: float = 0. + pool_type: str = 'avg2' + rel_pos_type: str = 'bias' + rel_pos_dim: int = 512 # for relative position types w/ MLP + window_size: Tuple[int, int] = (7, 7) + grid_size: Tuple[int, int] = (7, 7) + init_values: Optional[float] = None + act_layer: str = 'gelu' + norm_layer: str = 'layernorm2d' + norm_layer_cl: str = 'layernorm' + norm_eps: float = 1e-6 + + +@dataclass +class MaxxVitConvCfg: + block_type: str = 'mbconv' + expand_ratio: float = 4.0 + expand_output: bool = True # calculate expansion channels from output (vs input chs) + kernel_size: int = 3 + group_size: int = 1 # 1 == depthwise + pre_norm_act: bool = False # activation after pre-norm + output_bias: bool = True # bias for shortcut + final 1x1 projection conv + stride_mode: str = 'dw' # stride done via one of 'pool', '1x1', 'dw' + pool_type: str = 'avg2' + downsample_pool_type: str = 'avg2' + attn_early: bool = False # apply attn between conv2 and norm2, instead of after norm2 + attn_layer: str = 'se' + attn_act_layer: str = 'silu' + attn_ratio: float = 0.25 + init_values: Optional[float] = 1e-5 # for ConvNeXt block + act_layer: str = 'gelu' + norm_layer: str = '' + norm_layer_cl: str = '' + norm_eps: Optional[float] = None + + def __post_init__(self): + # mbconv vs convnext blocks have different defaults, set in post_init to avoid explicit config args + assert self.block_type in ('mbconv', 'convnext') + use_mbconv = self.block_type == 'mbconv' + if not self.norm_layer: + self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d' + if not self.norm_layer_cl and not use_mbconv: + self.norm_layer_cl = 'layernorm' + if self.norm_eps is None: + self.norm_eps = 1e-5 if use_mbconv else 1e-6 + self.downsample_pool_type = self.downsample_pool_type or self.pool_type + + +@dataclass +class MaxxVitCfg: + embed_dim: Tuple[int, ...] = (96, 192, 384, 768) + depths: Tuple[int, ...] = (2, 3, 5, 2) + block_type: Tuple[Union[str, Tuple[str, ...]], ...] 
= ('C', 'C', 'T', 'T') + stem_width: Union[int, Tuple[int, int]] = 64 + stem_bias: bool = True + conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg() + transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg() + weight_init: str = 'vit_eff' + + +def _rw_coat_cfg( + stride_mode='pool', + pool_type='avg2', + conv_output_bias=False, + conv_attn_early=False, + conv_norm_layer='', + transformer_shortcut_bias=True, + transformer_norm_layer='layernorm2d', + transformer_norm_layer_cl='layernorm', + rel_pos_type='bias', + rel_pos_dim=512, +): + # 'RW' timm variant models were created and trained before seeing https://github.com/google-research/maxvit + # Common differences for initial timm models: + # - pre-norm layer in MZBConv included an activation after norm + # - mbconv expansion calculated from input instead of output chs + # - mbconv shortcut and final 1x1 conv did not have a bias + # - SE act layer was relu, not silu + # - mbconv uses silu in timm, not gelu + # - expansion in attention block done via output proj, not input proj + # Variable differences (evolved over training initial models): + # - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat) + # - SE attention was between conv2 and norm/act + # - default to avg pool for mbconv downsample instead of 1x1 or dw conv + # - transformer block shortcut has no bias + return dict( + conv_cfg=MaxxVitConvCfg( + stride_mode=stride_mode, + pool_type=pool_type, + pre_norm_act=True, + expand_output=False, + output_bias=conv_output_bias, + attn_early=conv_attn_early, + attn_act_layer='relu', + act_layer='silu', + norm_layer=conv_norm_layer, + ), + transformer_cfg=MaxxVitTransformerCfg( + expand_first=False, + shortcut_bias=transformer_shortcut_bias, + pool_type=pool_type, + norm_layer=transformer_norm_layer, + norm_layer_cl=transformer_norm_layer_cl, + rel_pos_type=rel_pos_type, + rel_pos_dim=rel_pos_dim, + ), + ) + + +def _rw_max_cfg( + stride_mode='dw', + pool_type='avg2', + conv_output_bias=False, + conv_attn_ratio=1 / 16, + conv_norm_layer='', + transformer_norm_layer='layernorm2d', + transformer_norm_layer_cl='layernorm', + window_size=7, + dim_head=32, + rel_pos_type='bias', + rel_pos_dim=512, +): + # 'RW' timm variant models were created and trained before seeing https://github.com/google-research/maxvit + # Differences of initial timm models: + # - mbconv expansion calculated from input instead of output chs + # - mbconv shortcut and final 1x1 conv did not have a bias + # - mbconv uses silu in timm, not gelu + # - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat) + # - default to avg pool for mbconv downsample instead of 1x1 or dw conv + # - expansion in attention block done via output proj, not input proj + return dict( + conv_cfg=MaxxVitConvCfg( + stride_mode=stride_mode, + pool_type=pool_type, + expand_output=False, + output_bias=conv_output_bias, + attn_ratio=conv_attn_ratio, + act_layer='silu', + norm_layer=conv_norm_layer, + ), + transformer_cfg=MaxxVitTransformerCfg( + expand_first=False, + pool_type=pool_type, + dim_head=dim_head, + window_size=to_2tuple(window_size), + grid_size=to_2tuple(window_size), + norm_layer=transformer_norm_layer, + norm_layer_cl=transformer_norm_layer_cl, + rel_pos_type=rel_pos_type, + rel_pos_dim=rel_pos_dim, + ), + ) + + +def _next_cfg( + stride_mode='dw', + pool_type='avg2', + conv_norm_layer='layernorm2d', + conv_norm_layer_cl='layernorm', + transformer_norm_layer='layernorm2d', + transformer_norm_layer_cl='layernorm', + window_size=7, + 
rel_pos_type='bias', + rel_pos_dim=512, +): + # For experimental models with convnext instead of mbconv + return dict( + conv_cfg=MaxxVitConvCfg( + block_type='convnext', + stride_mode=stride_mode, + pool_type=pool_type, + expand_output=False, + norm_layer=conv_norm_layer, + norm_layer_cl=conv_norm_layer_cl, + ), + transformer_cfg=MaxxVitTransformerCfg( + expand_first=False, + pool_type=pool_type, + window_size=to_2tuple(window_size), + grid_size=to_2tuple(window_size), + norm_layer=transformer_norm_layer, + norm_layer_cl=transformer_norm_layer_cl, + rel_pos_type=rel_pos_type, + rel_pos_dim=rel_pos_dim, + ), + ) + + +model_cfgs = dict( + # Fiddling with configs / defaults / still pretraining + coatnet_pico_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 3, 5, 2), + stem_width=(32, 64), + **_rw_max_cfg( # using newer max defaults here + conv_output_bias=True, + conv_attn_ratio=0.25, + ), + ), + coatnet_nano_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + **_rw_max_cfg( # using newer max defaults here + stride_mode='pool', + conv_output_bias=True, + conv_attn_ratio=0.25, + ), + ), + coatnet_0_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 7, 2), # deeper than paper '0' model + stem_width=(32, 64), + **_rw_coat_cfg( + conv_attn_early=True, + transformer_shortcut_bias=False, + ), + ), + coatnet_1_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + stem_width=(32, 64), + **_rw_coat_cfg( + stride_mode='dw', + conv_attn_early=True, + transformer_shortcut_bias=False, + ) + ), + coatnet_2_rw_224=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 6, 14, 2), + stem_width=(64, 128), + **_rw_coat_cfg(stride_mode='dw'), + ), + + # Highly experimental configs + coatnet_bn_0_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 7, 2), # deeper than paper '0' model + stem_width=(32, 64), + **_rw_coat_cfg( + stride_mode='dw', + conv_attn_early=True, + transformer_shortcut_bias=False, + transformer_norm_layer='batchnorm2d', + ) + ), + coatnet_rmlp_nano_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + **_rw_max_cfg( + conv_output_bias=True, + conv_attn_ratio=0.25, + rel_pos_type='mlp', + rel_pos_dim=384, + ), + ), + coatnet_rmlp_0_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 7, 2), # deeper than paper '0' model + stem_width=(32, 64), + **_rw_coat_cfg( + stride_mode='dw', + rel_pos_type='mlp', + ), + ), + coatnet_rmlp_1_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + stem_width=(32, 64), + **_rw_coat_cfg( + pool_type='max', + conv_attn_early=True, + transformer_shortcut_bias=False, + rel_pos_type='mlp', + rel_pos_dim=384, # was supposed to be 512, woops + ), + ), + coatnext_nano_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + **_next_cfg(), + ), + coatnet_nano_cc_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + block_type=('C', 'C', ('C', 'T'), ('C', 'T')), + **_rw_coat_cfg(), + ), + + # Trying to be like the CoAtNet paper configs + coatnet_0_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 5, 2), + stem_width=64, + ), + coatnet_1_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + stem_width=64, + ), + coatnet_2_224=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 6, 14, 2), + stem_width=128, + ), + 
coatnet_3_224=MaxxVitCfg( + embed_dim=(192, 384, 768, 1536), + depths=(2, 6, 14, 2), + stem_width=192, + ), + coatnet_4_224=MaxxVitCfg( + embed_dim=(192, 384, 768, 1536), + depths=(2, 12, 28, 2), + stem_width=192, + ), + coatnet_5_224=MaxxVitCfg( + embed_dim=(256, 512, 1280, 2048), + depths=(2, 12, 28, 2), + stem_width=192, + ), + + # Experimental MaxVit configs + maxvit_pico_rw_256=MaxxVitCfg( + embed_dim=(32, 64, 128, 256), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(24, 32), + **_rw_max_cfg(window_size=8), + ), + maxvit_nano_rw_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(1, 2, 3, 1), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(window_size=8), + ), + maxvit_tiny_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(), + ), + maxvit_tiny_rw_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(window_size=8), + ), + maxvit_tiny_pm_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('PM',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(window_size=8), + ), + maxxvit_nano_rw_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(1, 2, 3, 1), + block_type=('M',) * 4, + stem_width=(32, 64), + **_next_cfg(window_size=8), + ), + + # Trying to be like the MaxViT paper configs + maxvit_tiny_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=64, + ), + maxvit_small_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=64, + ), + maxvit_base_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + block_type=('M',) * 4, + stem_width=64, + ), + maxvit_large_224=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 6, 14, 2), + block_type=('M',) * 4, + stem_width=128, + ), + maxvit_xlarge_224=MaxxVitCfg( + embed_dim=(192, 384, 768, 1536), + depths=(2, 6, 14, 2), + block_type=('M',) * 4, + stem_width=192, + ), + +) + + +class Attention2d(nn.Module): + """ multi-head attention for 2D NCHW tensors""" + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + dim_head: int = 32, + bias: bool = True, + expand_first: bool = True, + rel_pos_cls: Callable = None, + attn_drop: float = 0., + proj_drop: float = 0. 
+ ): + super().__init__() + dim_out = dim_out or dim + dim_attn = dim_out if expand_first else dim + self.num_heads = dim_attn // dim_head + self.dim_head = dim_head + self.scale = dim_head ** -0.5 + + self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias) + self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + B, C, H, W = x.shape + + q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2) + + attn = (q.transpose(-2, -1) @ k) * self.scale + if self.rel_pos is not None: + attn = self.rel_pos(attn) + elif shared_rel_pos is not None: + attn = attn + shared_rel_pos + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class AttentionCl(nn.Module): + """ Channels-last multi-head attention (B, ..., C) """ + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + dim_head: int = 32, + bias: bool = True, + expand_first: bool = True, + rel_pos_cls: Callable = None, + attn_drop: float = 0., + proj_drop: float = 0. + ): + super().__init__() + dim_out = dim_out or dim + dim_attn = dim_out if expand_first and dim_out > dim else dim + assert dim_attn % dim_head == 0, 'attn dim should be divisible by head_dim' + self.num_heads = dim_attn // dim_head + self.dim_head = dim_head + self.scale = dim_head ** -0.5 + + self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias) + self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim_attn, dim_out, bias=bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + B = x.shape[0] + restore_shape = x.shape[:-1] + + q, k, v = self.qkv(x).view(B, -1, self.num_heads, self.dim_head * 3).transpose(1, 2).chunk(3, dim=3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + if self.rel_pos is not None: + attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos) + elif shared_rel_pos is not None: + attn = attn + shared_rel_pos + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(restore_shape + (-1,)) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + gamma = self.gamma + return x.mul_(gamma) if self.inplace else x * gamma + + +class LayerScale2d(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + gamma = self.gamma.view(1, -1, 1, 1) + return x.mul_(gamma) if self.inplace else x * gamma + + +class Downsample2d(nn.Module): + """ A downsample pooling module for Coat that handles 2d <-> 1d conversion + """ + + def __init__( + self, + dim: int, + dim_out: int, + pool_type: str = 'avg2', + bias: bool = True, + ): + super().__init__() + assert pool_type in ('max', 'max2', 'avg', 'avg2') + if pool_type == 'max': + self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + elif pool_type == 'max2': + self.pool = 
nn.MaxPool2d(2) # kernel_size == stride == 2 + elif pool_type == 'avg': + self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False) + else: + self.pool = nn.AvgPool2d(2) # kernel_size == stride == 2 + + if dim != dim_out: + self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) + else: + self.expand = nn.Identity() + + def forward(self, x): + x = self.pool(x) # spatial downsample + x = self.expand(x) # expand chs + return x + + +def _init_transformer(module, name, scheme=''): + if isinstance(module, (nn.Conv2d, nn.Linear)): + if scheme == 'normal': + nn.init.normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 'trunc_normal': + trunc_normal_tf_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 'xavier_normal': + nn.init.xavier_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + else: + # vit like + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + + +class TransformerBlock2d(nn.Module): + """ Transformer block with 2D downsampling + '2D' NCHW tensor layout + """ + + def __init__( + self, + dim: int, + dim_out: int, + stride: int = 1, + rel_pos_cls: Callable = None, + cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path: float = 0., + ): + super().__init__() + norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) + act_layer = get_act_layer(cfg.act_layer) + + if stride == 2: + self.shortcut = Downsample2d(dim, dim_out, pool_type=cfg.pool_type, bias=cfg.shortcut_bias) + self.norm1 = nn.Sequential(OrderedDict([ + ('norm', norm_layer(dim)), + ('down', Downsample2d(dim, dim, pool_type=cfg.pool_type)), + ])) + else: + assert dim == dim_out + self.shortcut = nn.Identity() + self.norm1 = norm_layer(dim) + + self.attn = Attention2d( + dim, + dim_out, + dim_head=cfg.dim_head, + expand_first=cfg.expand_first, + bias=cfg.attn_bias, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop + ) + self.ls1 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = ConvMlp( + in_features=dim_out, + hidden_features=int(dim_out * cfg.expand_ratio), + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def init_weights(self, scheme=''): + named_apply(partial(_init_transformer, scheme=scheme), self) + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + x = self.shortcut(x) + self.drop_path1(self.ls1(self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +def _init_conv(module, name, scheme=''): + if isinstance(module, nn.Conv2d): + if scheme == 'normal': + nn.init.normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 'trunc_normal': + trunc_normal_tf_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 'xavier_normal': + nn.init.xavier_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + else: + # efficientnet like + fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels + fan_out //= module.groups + nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out)) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def num_groups(group_size, channels): + if not group_size: # 0 or None + return 1 # normal conv with 1 group + else: + # NOTE group_size == 1 -> depthwise conv + assert channels % group_size == 0 + return channels // group_size + + +class MbConvBlock(nn.Module): + """ Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand) + """ + def __init__( + self, + in_chs: int, + out_chs: int, + stride: int = 1, + dilation: Tuple[int, int] = (1, 1), + cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + drop_path: float = 0. + ): + super(MbConvBlock, self).__init__() + norm_act_layer = partial(get_norm_act_layer(cfg.norm_layer, cfg.act_layer), eps=cfg.norm_eps) + mid_chs = make_divisible((out_chs if cfg.expand_output else in_chs) * cfg.expand_ratio) + groups = num_groups(cfg.group_size, mid_chs) + + if stride == 2: + self.shortcut = Downsample2d(in_chs, out_chs, pool_type=cfg.pool_type, bias=cfg.output_bias) + else: + self.shortcut = nn.Identity() + + assert cfg.stride_mode in ('pool', '1x1', 'dw') + stride_pool, stride_1, stride_2 = 1, 1, 1 + if cfg.stride_mode == 'pool': + # NOTE this is not described in paper, experiment to find faster option that doesn't stride in 1x1 + stride_pool, dilation_2 = stride, dilation[1] + # FIXME handle dilation of avg pool + elif cfg.stride_mode == '1x1': + # NOTE I don't like this option described in paper, 1x1 w/ stride throws info away + stride_1, dilation_2 = stride, dilation[1] + else: + stride_2, dilation_2 = stride, dilation[0] + + self.pre_norm = norm_act_layer(in_chs, apply_act=cfg.pre_norm_act) + if stride_pool > 1: + self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type) + else: + self.down = nn.Identity() + self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=stride_1) + self.norm1 = norm_act_layer(mid_chs) + + self.conv2_kxk = create_conv2d( + mid_chs, mid_chs, cfg.kernel_size, stride=stride_2, dilation=dilation_2, groups=groups) + + attn_kwargs = {} + if isinstance(cfg.attn_layer, str): + if cfg.attn_layer == 'se' or cfg.attn_layer == 'eca': + attn_kwargs['act_layer'] = cfg.attn_act_layer + attn_kwargs['rd_channels'] = int(cfg.attn_ratio * (out_chs if cfg.expand_output else mid_chs)) + + # two different orderings for SE and norm2 (due to some weights and trials using SE before norm2) + if cfg.attn_early: + self.se_early = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs) + self.norm2 = norm_act_layer(mid_chs) + self.se = 
None + else: + self.se_early = None + self.norm2 = norm_act_layer(mid_chs) + self.se = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs) + + self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=cfg.output_bias) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def init_weights(self, scheme=''): + named_apply(partial(_init_conv, scheme=scheme), self) + + def forward(self, x): + shortcut = self.shortcut(x) + x = self.pre_norm(x) + x = self.down(x) + + # 1x1 expansion conv & norm-act + x = self.conv1_1x1(x) + x = self.norm1(x) + + # depthwise / grouped 3x3 conv w/ SE (or other) channel attention & norm-act + x = self.conv2_kxk(x) + if self.se_early is not None: + x = self.se_early(x) + x = self.norm2(x) + if self.se is not None: + x = self.se(x) + + # 1x1 linear projection to output width + x = self.conv3_1x1(x) + x = self.drop_path(x) + shortcut + return x + + +class ConvNeXtBlock(nn.Module): + """ ConvNeXt Block + """ + + def __init__( + self, + in_chs: int, + out_chs: Optional[int] = None, + kernel_size: int = 7, + stride: int = 1, + dilation: Tuple[int, int] = (1, 1), + cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + conv_mlp: bool = True, + drop_path: float = 0. + ): + super().__init__() + out_chs = out_chs or in_chs + act_layer = get_act_layer(cfg.act_layer) + if conv_mlp: + norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) + mlp_layer = ConvMlp + else: + assert 'layernorm' in cfg.norm_layer + norm_layer = LayerNorm + mlp_layer = Mlp + self.use_conv_mlp = conv_mlp + + if stride == 2: + self.shortcut = Downsample2d(in_chs, out_chs) + elif in_chs != out_chs: + self.shortcut = nn.Conv2d(in_chs, out_chs, kernel_size=1, bias=cfg.output_bias) + else: + self.shortcut = nn.Identity() + + assert cfg.stride_mode in ('pool', 'dw') + stride_pool, stride_dw = 1, 1 + # FIXME handle dilation? + if cfg.stride_mode == 'pool': + stride_pool = stride + else: + stride_dw = stride + + if stride_pool == 2: + self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type) + else: + self.down = nn.Identity() + + self.conv_dw = create_conv2d( + in_chs, out_chs, kernel_size=kernel_size, stride=stride_dw, dilation=dilation[1], + depthwise=True, bias=cfg.output_bias) + self.norm = norm_layer(out_chs) + self.mlp = mlp_layer(out_chs, int(cfg.expand_ratio * out_chs), bias=cfg.output_bias, act_layer=act_layer) + if conv_mlp: + self.ls = LayerScale2d(out_chs, cfg.init_values) if cfg.init_values else nn.Identity() + else: + self.ls = LayerScale(out_chs, cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x): + shortcut = self.shortcut(x) + x = self.down(x) + x = self.conv_dw(x) + if self.use_conv_mlp: + x = self.norm(x) + x = self.mlp(x) + x = self.ls(x) + else: + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = self.mlp(x) + x = self.ls(x) + x = x.permute(0, 3, 1, 2) + + x = self.drop_path(x) + shortcut + return x + + +def window_partition(x, window_size: List[int]): + B, H, W, C = x.shape + _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})') + _assert(W % window_size[1] == 0, '') + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse(windows, window_size: List[int], img_size: List[int]): + H, W = img_size + C = windows.shape[-1] + x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) + return x + + +def grid_partition(x, grid_size: List[int]): + B, H, W, C = x.shape + _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}') + _assert(W % grid_size[1] == 0, '') + x = x.view(B, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1], C) + windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, grid_size[0], grid_size[1], C) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def grid_reverse(windows, grid_size: List[int], img_size: List[int]): + H, W = img_size + C = windows.shape[-1] + x = windows.view(-1, H // grid_size[0], W // grid_size[1], grid_size[0], grid_size[1], C) + x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, H, W, C) + return x + + +def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size): + rel_pos_cls = None + if cfg.rel_pos_type == 'mlp': + rel_pos_cls = partial(RelPosMlp, window_size=window_size, hidden_dim=cfg.rel_pos_dim) + elif cfg.rel_pos_type == 'bias': + rel_pos_cls = partial(RelPosBias, window_size=window_size) + return rel_pos_cls + + +class PartitionAttention(nn.Module): + """ Grid or Block partition + Attn + FFN. + NxC tensor layout. + """ + + def __init__( + self, + dim: int, + partition_type: str = 'block', + cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path: float = 0., + ): + super().__init__() + norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) # NOTE this block is channels-last + act_layer = get_act_layer(cfg.act_layer) + + self.partition_block = partition_type == 'block' + self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size) + rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) + + self.norm1 = norm_layer(dim) + self.attn = AttentionCl( + dim, + dim, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * cfg.expand_ratio), + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def _partition_attn(self, x): + img_size = x.shape[1:3] + if self.partition_block: + partitioned = window_partition(x, self.partition_size) + else: + partitioned = grid_partition(x, self.partition_size) + + partitioned = self.attn(partitioned) + + if self.partition_block: + x = window_reverse(partitioned, self.partition_size, img_size) + else: + x = grid_reverse(partitioned, self.partition_size, img_size) + return x + + def forward(self, x): + x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class ParallelPartitionAttention(nn.Module): + """ Experimental. Grid and Block partition + single FFN + NxC tensor layout. + """ + + def __init__( + self, + dim: int, + cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path: float = 0., + ): + super().__init__() + assert dim % 2 == 0 + norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) # NOTE this block is channels-last + act_layer = get_act_layer(cfg.act_layer) + + assert cfg.window_size == cfg.grid_size + self.partition_size = to_2tuple(cfg.window_size) + rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) + + self.norm1 = norm_layer(dim) + self.attn_block = AttentionCl( + dim, + dim // 2, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.attn_grid = AttentionCl( + dim, + dim // 2, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * cfg.expand_ratio), + out_features=dim, + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def _partition_attn(self, x): + img_size = x.shape[1:3] + + partitioned_block = window_partition(x, self.partition_size) + partitioned_block = self.attn_block(partitioned_block) + x_window = window_reverse(partitioned_block, self.partition_size, img_size) + + partitioned_grid = grid_partition(x, self.partition_size) + partitioned_grid = self.attn_grid(partitioned_grid) + x_grid = grid_reverse(partitioned_grid, self.partition_size, img_size) + + return torch.cat([x_window, x_grid], dim=-1) + + def forward(self, x): + x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +def window_partition_nchw(x, window_size: List[int]): + B, C, H, W = x.shape + _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})') + _assert(W % window_size[1] == 0, '') + x = x.view(B, C, H // window_size[0], window_size[0], W // window_size[1], window_size[1]) + windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size[0], window_size[1]) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse_nchw(windows, window_size: List[int], img_size: List[int]): + H, W = img_size + C = windows.shape[1] + x = windows.view(-1, H // window_size[0], W // window_size[1], C, window_size[0], window_size[1]) + x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, C, H, W) + return x + + +def grid_partition_nchw(x, grid_size: List[int]): + B, C, H, W = x.shape + _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}') + _assert(W % grid_size[1] == 0, '') + x = x.view(B, C, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1]) + windows = x.permute(0, 3, 5, 1, 2, 4).contiguous().view(-1, C, grid_size[0], grid_size[1]) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def grid_reverse_nchw(windows, grid_size: List[int], img_size: List[int]): + H, W = img_size + C = windows.shape[1] + x = windows.view(-1, H // grid_size[0], W // grid_size[1], C, grid_size[0], grid_size[1]) + x = x.permute(0, 3, 4, 1, 5, 2).contiguous().view(-1, C, H, W) + return x + + +class PartitionAttention2d(nn.Module): + """ Grid or Block partition + Attn + FFN + '2D' NCHW tensor layout. + """ + + def __init__( + self, + dim: int, + partition_type: str = 'block', + cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path: float = 0., + ): + super().__init__() + norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) # NOTE this block is channels-last + act_layer = get_act_layer(cfg.act_layer) + + self.partition_block = partition_type == 'block' + self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size) + rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) + + self.norm1 = norm_layer(dim) + self.attn = Attention2d( + dim, + dim, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.ls1 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = ConvMlp( + in_features=dim, + hidden_features=int(dim * cfg.expand_ratio), + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def _partition_attn(self, x): + img_size = x.shape[-2:] + if self.partition_block: + partitioned = window_partition_nchw(x, self.partition_size) + else: + partitioned = grid_partition_nchw(x, self.partition_size) + + partitioned = self.attn(partitioned) + + if self.partition_block: + x = window_reverse_nchw(partitioned, self.partition_size, img_size) + else: + x = grid_reverse_nchw(partitioned, self.partition_size, img_size) + return x + + def forward(self, x): + x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class MaxxVitBlock(nn.Module): + """ + """ + + def __init__( + self, + dim: int, + dim_out: int, + stride: int = 1, + conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + use_nchw_attn: bool = False, # FIXME move to cfg? True is ~20-30% faster on TPU, 5-10% slower on GPU + drop_path: float = 0., + ): + super().__init__() + + conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock + self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) + + attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) + partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttention + self.nchw_attn = use_nchw_attn + self.attn_block = partition_layer(**attn_kwargs) + self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs) + + def init_weights(self, scheme=''): + named_apply(partial(_init_transformer, scheme=scheme), self.attn_block) + named_apply(partial(_init_transformer, scheme=scheme), self.attn_grid) + named_apply(partial(_init_conv, scheme=scheme), self.conv) + + def forward(self, x): + # NCHW format + x = self.conv(x) + + if not self.nchw_attn: + x = x.permute(0, 2, 3, 1) # to NHWC (channels-last) + x = self.attn_block(x) + x = self.attn_grid(x) + if not self.nchw_attn: + x = x.permute(0, 3, 1, 2) # back to NCHW + return x + + +class ParallelMaxxVitBlock(nn.Module): + """ + """ + + def __init__( + self, + dim, + dim_out, + stride=1, + num_conv=2, + conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path=0., + ): + super().__init__() + + conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock + if num_conv > 1: + convs = [conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)] + convs += [conv_cls(dim_out, dim_out, cfg=conv_cfg, drop_path=drop_path)] * (num_conv - 1) + self.conv = nn.Sequential(*convs) + else: + self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) + self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) + + def init_weights(self, scheme=''): + named_apply(partial(_init_transformer, scheme=scheme), self.attn) + named_apply(partial(_init_conv, scheme=scheme), self.conv) + + def forward(self, x): + x = self.conv(x) + x = x.permute(0, 2, 3, 1) + x = self.attn(x) + x = x.permute(0, 3, 1, 2) + return x + + +class MaxxVitStage(nn.Module): + def __init__( + self, + in_chs: int, + out_chs: int, + 
stride: int = 2, + depth: int = 4, + feat_size: Tuple[int, int] = (14, 14), + block_types: Union[str, Tuple[str]] = 'C', + transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + drop_path: Union[float, List[float]] = 0., + ): + super().__init__() + self.grad_checkpointing = False + + block_types = extend_tuple(block_types, depth) + blocks = [] + for i, t in enumerate(block_types): + block_stride = stride if i == 0 else 1 + assert t in ('C', 'T', 'M', 'PM') + if t == 'C': + conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock + blocks += [conv_cls( + in_chs, + out_chs, + stride=block_stride, + cfg=conv_cfg, + drop_path=drop_path[i], + )] + elif t == 'T': + rel_pos_cls = get_rel_pos_cls(transformer_cfg, feat_size) + blocks += [TransformerBlock2d( + in_chs, + out_chs, + stride=block_stride, + rel_pos_cls=rel_pos_cls, + cfg=transformer_cfg, + drop_path=drop_path[i], + )] + elif t == 'M': + blocks += [MaxxVitBlock( + in_chs, + out_chs, + stride=block_stride, + conv_cfg=conv_cfg, + transformer_cfg=transformer_cfg, + drop_path=drop_path[i], + )] + elif t == 'PM': + blocks += [ParallelMaxxVitBlock( + in_chs, + out_chs, + stride=block_stride, + conv_cfg=conv_cfg, + transformer_cfg=transformer_cfg, + drop_path=drop_path[i], + )] + in_chs = out_chs + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class Stem(nn.Module): + + def __init__( + self, + in_chs: int, + out_chs: int, + kernel_size: int = 3, + act_layer: str = 'gelu', + norm_layer: str = 'batchnorm2d', + norm_eps: float = 1e-5, + ): + super().__init__() + if not isinstance(out_chs, (list, tuple)): + out_chs = to_2tuple(out_chs) + + norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps) + self.out_chs = out_chs[-1] + self.stride = 2 + + self.conv1 = create_conv2d(in_chs, out_chs[0], kernel_size, stride=2) + self.norm1 = norm_act_layer(out_chs[0]) + self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1) + + def init_weights(self, scheme=''): + named_apply(partial(_init_conv, scheme=scheme), self) + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.conv2(x) + return x + + +class MaxxVit(nn.Module): + """ + """ + + def __init__( + self, + cfg: MaxxVitCfg, + img_size: Union[int, Tuple[int, int]] = 224, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + drop_rate: float = 0., + drop_path_rate: float = 0. 
+ ): + super().__init__() + img_size = to_2tuple(img_size) + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = cfg.embed_dim[-1] + self.embed_dim = cfg.embed_dim + self.drop_rate = drop_rate + self.grad_checkpointing = False + + self.stem = Stem( + in_chs=in_chans, + out_chs=cfg.stem_width, + act_layer=cfg.conv_cfg.act_layer, + norm_layer=cfg.conv_cfg.norm_layer, + norm_eps=cfg.conv_cfg.norm_eps, + ) + + stride = self.stem.stride + feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))]) + + num_stages = len(cfg.embed_dim) + assert len(cfg.depths) == num_stages + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] + in_chs = self.stem.out_chs + stages = [] + for i in range(num_stages): + stage_stride = 2 + out_chs = cfg.embed_dim[i] + feat_size = tuple([(r - 1) // stage_stride + 1 for r in feat_size]) + stages += [MaxxVitStage( + in_chs, + out_chs, + depth=cfg.depths[i], + block_types=cfg.block_type[i], + conv_cfg=cfg.conv_cfg, + transformer_cfg=cfg.transformer_cfg, + feat_size=feat_size, + drop_path=dpr[i], + )] + stride *= stage_stride + in_chs = out_chs + self.stages = nn.Sequential(*stages) + + final_norm_layer = get_norm_layer(cfg.transformer_cfg.norm_layer) + self.norm = final_norm_layer(self.num_features, eps=cfg.transformer_cfg.norm_eps) + + # Classifier head + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + + # Weight init (default PyTorch init works well for AdamW if scheme not set) + assert cfg.weight_init in ('', 'normal', 'trunc_normal', 'xavier_normal', 'vit_eff') + if cfg.weight_init: + named_apply(partial(self._init_weights, scheme=cfg.weight_init), self) + + def _init_weights(self, module, name, scheme=''): + if hasattr(module, 'init_weights'): + try: + module.init_weights(scheme=scheme) + except TypeError: + module.init_weights() + + @torch.jit.ignore + def no_weight_decay(self): + return { + k for k, _ in self.named_parameters() + if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', # stem and embed + blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is None: + global_pool = self.head.global_pool.pool_type + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs): + return build_model_with_cfg( + MaxxVit, variant, pretrained, + model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], + feature_cfg=dict(flatten_sequential=True), + **kwargs) + + +@register_model +def coatnet_pico_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_pico_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def 
coatnet_nano_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_nano_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_0_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_0_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_1_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_1_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_2_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_2_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_bn_0_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_bn_0_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_rmlp_nano_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_rmlp_nano_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_rmlp_0_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_rmlp_0_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_rmlp_1_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_rmlp_1_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_nano_cc_224(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_nano_cc_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnext_nano_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('coatnext_nano_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_0_224(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_0_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_1_224(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_1_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_2_224(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_2_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_3_224(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_3_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_4_224(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_4_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_5_224(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_5_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_pico_rw_256(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_pico_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_nano_rw_256(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_nano_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_tiny_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_rw_256(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_pm_256(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxxvit_nano_rw_256(pretrained=False, **kwargs): + return _create_maxxvit('maxxvit_nano_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_224(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_tiny_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_small_224(pretrained=False, **kwargs): + return 
_create_maxxvit('maxvit_small_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_base_224(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_base_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_large_224(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_large_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_xlarge_224(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_xlarge_224', pretrained=pretrained, **kwargs) \ No newline at end of file diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py new file mode 100644 index 00000000..002225c6 --- /dev/null +++ b/timm/models/mvitv2.py @@ -0,0 +1,998 @@ +""" Multi-Scale Vision Transformer v2 + +@inproceedings{li2021improved, + title={MViTv2: Improved multiscale vision transformers for classification and detection}, + author={Li, Yanghao and Wu, Chao-Yuan and Fan, Haoqi and Mangalam, Karttikeya and Xiong, Bo and Malik, Jitendra and Feichtenhofer, Christoph}, + booktitle={CVPR}, + year={2022} +} + +Code adapted from original Apache 2.0 licensed impl at https://github.com/facebookresearch/mvit +Original copyright below. + +Modifications and timm support by / Copyright 2022, Ross Wightman +""" +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. All Rights Reserved. +import operator +from collections import OrderedDict +from dataclasses import dataclass +from functools import partial, reduce +from typing import Union, List, Tuple, Optional + +import torch +import torch.utils.checkpoint as checkpoint +from torch import nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_notrace_function +from .helpers import build_model_with_cfg +from .layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc', + 'fixed_input_size': True, + **kwargs + } + + +default_cfgs = dict( + mvitv2_tiny=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_T_in1k.pyth'), + mvitv2_small=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_S_in1k.pyth'), + mvitv2_base=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in1k.pyth'), + mvitv2_large=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in1k.pyth'), + + mvitv2_base_in21k=_cfg( + url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in21k.pyth', + num_classes=19168), + mvitv2_large_in21k=_cfg( + url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in21k.pyth', + num_classes=19168), + mvitv2_huge_in21k=_cfg( + url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_H_in21k.pyth', + num_classes=19168), +) + + +@dataclass +class MultiScaleVitCfg: + depths: Tuple[int, ...] = (2, 3, 16, 3) + embed_dim: Union[int, Tuple[int, ...]] = 96 + num_heads: Union[int, Tuple[int, ...]] = 1 + mlp_ratio: float = 4. 
+ pool_first: bool = False + expand_attn: bool = True + qkv_bias: bool = True + use_cls_token: bool = False + use_abs_pos: bool = False + residual_pooling: bool = True + mode: str = 'conv' + kernel_qkv: Tuple[int, int] = (3, 3) + stride_q: Optional[Tuple[Tuple[int, int]]] = ((1, 1), (2, 2), (2, 2), (2, 2)) + stride_kv: Optional[Tuple[Tuple[int, int]]] = None + stride_kv_adaptive: Optional[Tuple[int, int]] = (4, 4) + patch_kernel: Tuple[int, int] = (7, 7) + patch_stride: Tuple[int, int] = (4, 4) + patch_padding: Tuple[int, int] = (3, 3) + pool_type: str = 'max' + rel_pos_type: str = 'spatial' + act_layer: Union[str, Tuple[str, str]] = 'gelu' + norm_layer: Union[str, Tuple[str, str]] = 'layernorm' + norm_eps: float = 1e-6 + + def __post_init__(self): + num_stages = len(self.depths) + if not isinstance(self.embed_dim, (tuple, list)): + self.embed_dim = tuple(self.embed_dim * 2 ** i for i in range(num_stages)) + assert len(self.embed_dim) == num_stages + + if not isinstance(self.num_heads, (tuple, list)): + self.num_heads = tuple(self.num_heads * 2 ** i for i in range(num_stages)) + assert len(self.num_heads) == num_stages + + if self.stride_kv_adaptive is not None and self.stride_kv is None: + _stride_kv = self.stride_kv_adaptive + pool_kv_stride = [] + for i in range(num_stages): + if min(self.stride_q[i]) > 1: + _stride_kv = [ + max(_stride_kv[d] // self.stride_q[i][d], 1) + for d in range(len(_stride_kv)) + ] + pool_kv_stride.append(tuple(_stride_kv)) + self.stride_kv = tuple(pool_kv_stride) + + +model_cfgs = dict( + mvitv2_tiny=MultiScaleVitCfg( + depths=(1, 2, 5, 2), + ), + mvitv2_small=MultiScaleVitCfg( + depths=(1, 2, 11, 2), + ), + mvitv2_base=MultiScaleVitCfg( + depths=(2, 3, 16, 3), + ), + mvitv2_large=MultiScaleVitCfg( + depths=(2, 6, 36, 4), + embed_dim=144, + num_heads=2, + expand_attn=False, + ), + + mvitv2_base_in21k=MultiScaleVitCfg( + depths=(2, 3, 16, 3), + ), + mvitv2_large_in21k=MultiScaleVitCfg( + depths=(2, 6, 36, 4), + embed_dim=144, + num_heads=2, + expand_attn=False, + ), +) + + +def prod(iterable): + return reduce(operator.mul, iterable, 1) + + +class PatchEmbed(nn.Module): + """ + PatchEmbed. 
+ """ + + def __init__( + self, + dim_in=3, + dim_out=768, + kernel=(7, 7), + stride=(4, 4), + padding=(3, 3), + ): + super().__init__() + + self.proj = nn.Conv2d( + dim_in, + dim_out, + kernel_size=kernel, + stride=stride, + padding=padding, + ) + + def forward(self, x) -> Tuple[torch.Tensor, List[int]]: + x = self.proj(x) + # B C H W -> B HW C + return x.flatten(2).transpose(1, 2), x.shape[-2:] + + +@register_notrace_function +def reshape_pre_pool( + x, + feat_size: List[int], + has_cls_token: bool = True +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + H, W = feat_size + if has_cls_token: + cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :] + else: + cls_tok = None + x = x.reshape(-1, H, W, x.shape[-1]).permute(0, 3, 1, 2).contiguous() + return x, cls_tok + + +@register_notrace_function +def reshape_post_pool( + x, + num_heads: int, + cls_tok: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, List[int]]: + feat_size = [x.shape[2], x.shape[3]] + L_pooled = x.shape[2] * x.shape[3] + x = x.reshape(-1, num_heads, x.shape[1], L_pooled).transpose(2, 3) + if cls_tok is not None: + x = torch.cat((cls_tok, x), dim=2) + return x, feat_size + + +@register_notrace_function +def cal_rel_pos_type( + attn: torch.Tensor, + q: torch.Tensor, + has_cls_token: bool, + q_size: List[int], + k_size: List[int], + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, +): + """ + Spatial Relative Positional Embeddings. + """ + sp_idx = 1 if has_cls_token else 0 + q_h, q_w = q_size + k_h, k_w = k_size + + # Scale up rel pos if shapes for q and k are different. + q_h_ratio = max(k_h / q_h, 1.0) + k_h_ratio = max(q_h / k_h, 1.0) + dist_h = torch.arange(q_h)[:, None] * q_h_ratio - torch.arange(k_h)[None, :] * k_h_ratio + dist_h += (k_h - 1) * k_h_ratio + q_w_ratio = max(k_w / q_w, 1.0) + k_w_ratio = max(q_w / k_w, 1.0) + dist_w = torch.arange(q_w)[:, None] * q_w_ratio - torch.arange(k_w)[None, :] * k_w_ratio + dist_w += (k_w - 1) * k_w_ratio + + Rh = rel_pos_h[dist_h.long()] + Rw = rel_pos_w[dist_w.long()] + + B, n_head, q_N, dim = q.shape + + r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim) + rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, Rh) + rel_w = torch.einsum("byhwc,wkc->byhwk", r_q, Rw) + + attn[:, :, sp_idx:, sp_idx:] = ( + attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w) + + rel_h[:, :, :, :, :, None] + + rel_w[:, :, :, :, None, :] + ).view(B, -1, q_h * q_w, k_h * k_w) + + return attn + + +class MultiScaleAttentionPoolFirst(nn.Module): + def __init__( + self, + dim, + dim_out, + feat_size, + num_heads=8, + qkv_bias=True, + mode="conv", + kernel_q=(1, 1), + kernel_kv=(1, 1), + stride_q=(1, 1), + stride_kv=(1, 1), + has_cls_token=True, + rel_pos_type='spatial', + residual_pooling=True, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.num_heads = num_heads + self.dim_out = dim_out + self.head_dim = dim_out // num_heads + self.scale = self.head_dim ** -0.5 + self.has_cls_token = has_cls_token + padding_q = tuple([int(q // 2) for q in kernel_q]) + padding_kv = tuple([int(kv // 2) for kv in kernel_kv]) + + self.q = nn.Linear(dim, dim_out, bias=qkv_bias) + self.k = nn.Linear(dim, dim_out, bias=qkv_bias) + self.v = nn.Linear(dim, dim_out, bias=qkv_bias) + self.proj = nn.Linear(dim_out, dim_out) + + # Skip pooling with kernel and stride size of (1, 1, 1). 
+ if prod(kernel_q) == 1 and prod(stride_q) == 1: + kernel_q = None + if prod(kernel_kv) == 1 and prod(stride_kv) == 1: + kernel_kv = None + self.mode = mode + self.unshared = mode == 'conv_unshared' + self.pool_q, self.pool_k, self.pool_v = None, None, None + self.norm_q, self.norm_k, self.norm_v = None, None, None + if mode in ("avg", "max"): + pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d + if kernel_q: + self.pool_q = pool_op(kernel_q, stride_q, padding_q) + if kernel_kv: + self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv) + self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv) + elif mode == "conv" or mode == "conv_unshared": + dim_conv = dim // num_heads if mode == "conv" else dim + if kernel_q: + self.pool_q = nn.Conv2d( + dim_conv, + dim_conv, + kernel_q, + stride=stride_q, + padding=padding_q, + groups=dim_conv, + bias=False, + ) + self.norm_q = norm_layer(dim_conv) + if kernel_kv: + self.pool_k = nn.Conv2d( + dim_conv, + dim_conv, + kernel_kv, + stride=stride_kv, + padding=padding_kv, + groups=dim_conv, + bias=False, + ) + self.norm_k = norm_layer(dim_conv) + self.pool_v = nn.Conv2d( + dim_conv, + dim_conv, + kernel_kv, + stride=stride_kv, + padding=padding_kv, + groups=dim_conv, + bias=False, + ) + self.norm_v = norm_layer(dim_conv) + else: + raise NotImplementedError(f"Unsupported model {mode}") + + # relative pos embedding + self.rel_pos_type = rel_pos_type + if self.rel_pos_type == 'spatial': + assert feat_size[0] == feat_size[1] + size = feat_size[0] + q_size = size // stride_q[1] if len(stride_q) > 0 else size + kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size + rel_sp_dim = 2 * max(q_size, kv_size) - 1 + + self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim)) + trunc_normal_tf_(self.rel_pos_h, std=0.02) + trunc_normal_tf_(self.rel_pos_w, std=0.02) + + self.residual_pooling = residual_pooling + + def forward(self, x, feat_size: List[int]): + B, N, _ = x.shape + + fold_dim = 1 if self.unshared else self.num_heads + x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3) + q = k = v = x + + if self.pool_q is not None: + q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token) + q = self.pool_q(q) + q, q_size = reshape_post_pool(q, self.num_heads, q_tok) + else: + q_size = feat_size + if self.norm_q is not None: + q = self.norm_q(q) + + if self.pool_k is not None: + k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token) + k = self.pool_k(k) + k, k_size = reshape_post_pool(k, self.num_heads, k_tok) + else: + k_size = feat_size + if self.norm_k is not None: + k = self.norm_k(k) + + if self.pool_v is not None: + v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token) + v = self.pool_v(v) + v, v_size = reshape_post_pool(v, self.num_heads, v_tok) + else: + v_size = feat_size + if self.norm_v is not None: + v = self.norm_v(v) + + q_N = q_size[0] * q_size[1] + int(self.has_cls_token) + q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1) + q = self.q(q).reshape(B, q_N, self.num_heads, -1).permute(0, 2, 1, 3) + + k_N = k_size[0] * k_size[1] + int(self.has_cls_token) + k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1) + k = self.k(k).reshape(B, k_N, self.num_heads, -1).permute(0, 2, 1, 3) + + v_N = v_size[0] * v_size[1] + int(self.has_cls_token) + v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1) + v = self.v(v).reshape(B, v_N, self.num_heads, -1).permute(0, 2, 1, 3) + + attn = (q * self.scale) @ k.transpose(-2, -1) + if self.rel_pos_type == 'spatial': + attn = 
cal_rel_pos_type( + attn, + q, + self.has_cls_token, + q_size, + k_size, + self.rel_pos_h, + self.rel_pos_w, + ) + attn = attn.softmax(dim=-1) + x = attn @ v + + if self.residual_pooling: + x = x + q + + x = x.transpose(1, 2).reshape(B, -1, self.dim_out) + x = self.proj(x) + + return x, q_size + + +class MultiScaleAttention(nn.Module): + def __init__( + self, + dim, + dim_out, + feat_size, + num_heads=8, + qkv_bias=True, + mode="conv", + kernel_q=(1, 1), + kernel_kv=(1, 1), + stride_q=(1, 1), + stride_kv=(1, 1), + has_cls_token=True, + rel_pos_type='spatial', + residual_pooling=True, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.num_heads = num_heads + self.dim_out = dim_out + self.head_dim = dim_out // num_heads + self.scale = self.head_dim ** -0.5 + self.has_cls_token = has_cls_token + padding_q = tuple([int(q // 2) for q in kernel_q]) + padding_kv = tuple([int(kv // 2) for kv in kernel_kv]) + + self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias) + self.proj = nn.Linear(dim_out, dim_out) + + # Skip pooling with kernel and stride size of (1, 1, 1). + if prod(kernel_q) == 1 and prod(stride_q) == 1: + kernel_q = None + if prod(kernel_kv) == 1 and prod(stride_kv) == 1: + kernel_kv = None + self.mode = mode + self.unshared = mode == 'conv_unshared' + self.norm_q, self.norm_k, self.norm_v = None, None, None + self.pool_q, self.pool_k, self.pool_v = None, None, None + if mode in ("avg", "max"): + pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d + if kernel_q: + self.pool_q = pool_op(kernel_q, stride_q, padding_q) + if kernel_kv: + self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv) + self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv) + elif mode == "conv" or mode == "conv_unshared": + dim_conv = dim_out // num_heads if mode == "conv" else dim_out + if kernel_q: + self.pool_q = nn.Conv2d( + dim_conv, + dim_conv, + kernel_q, + stride=stride_q, + padding=padding_q, + groups=dim_conv, + bias=False, + ) + self.norm_q = norm_layer(dim_conv) + if kernel_kv: + self.pool_k = nn.Conv2d( + dim_conv, + dim_conv, + kernel_kv, + stride=stride_kv, + padding=padding_kv, + groups=dim_conv, + bias=False, + ) + self.norm_k = norm_layer(dim_conv) + self.pool_v = nn.Conv2d( + dim_conv, + dim_conv, + kernel_kv, + stride=stride_kv, + padding=padding_kv, + groups=dim_conv, + bias=False, + ) + self.norm_v = norm_layer(dim_conv) + else: + raise NotImplementedError(f"Unsupported model {mode}") + + # relative pos embedding + self.rel_pos_type = rel_pos_type + if self.rel_pos_type == 'spatial': + assert feat_size[0] == feat_size[1] + size = feat_size[0] + q_size = size // stride_q[1] if len(stride_q) > 0 else size + kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size + rel_sp_dim = 2 * max(q_size, kv_size) - 1 + + self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim)) + trunc_normal_tf_(self.rel_pos_h, std=0.02) + trunc_normal_tf_(self.rel_pos_w, std=0.02) + + self.residual_pooling = residual_pooling + + def forward(self, x, feat_size: List[int]): + B, N, _ = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(dim=0) + + if self.pool_q is not None: + q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token) + q = self.pool_q(q) + q, q_size = reshape_post_pool(q, self.num_heads, q_tok) + else: + q_size = feat_size + if self.norm_q is not None: + q = self.norm_q(q) + + if self.pool_k is not None: + k, k_tok = reshape_pre_pool(k, 
feat_size, self.has_cls_token) + k = self.pool_k(k) + k, k_size = reshape_post_pool(k, self.num_heads, k_tok) + else: + k_size = feat_size + if self.norm_k is not None: + k = self.norm_k(k) + + if self.pool_v is not None: + v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token) + v = self.pool_v(v) + v, _ = reshape_post_pool(v, self.num_heads, v_tok) + if self.norm_v is not None: + v = self.norm_v(v) + + attn = (q * self.scale) @ k.transpose(-2, -1) + if self.rel_pos_type == 'spatial': + attn = cal_rel_pos_type( + attn, + q, + self.has_cls_token, + q_size, + k_size, + self.rel_pos_h, + self.rel_pos_w, + ) + attn = attn.softmax(dim=-1) + x = attn @ v + + if self.residual_pooling: + x = x + q + + x = x.transpose(1, 2).reshape(B, -1, self.dim_out) + x = self.proj(x) + + return x, q_size + + +class MultiScaleBlock(nn.Module): + def __init__( + self, + dim, + dim_out, + num_heads, + feat_size, + mlp_ratio=4.0, + qkv_bias=True, + drop_path=0.0, + norm_layer=nn.LayerNorm, + kernel_q=(1, 1), + kernel_kv=(1, 1), + stride_q=(1, 1), + stride_kv=(1, 1), + mode="conv", + has_cls_token=True, + expand_attn=False, + pool_first=False, + rel_pos_type='spatial', + residual_pooling=True, + ): + super().__init__() + proj_needed = dim != dim_out + self.dim = dim + self.dim_out = dim_out + self.has_cls_token = has_cls_token + + self.norm1 = norm_layer(dim) + + self.shortcut_proj_attn = nn.Linear(dim, dim_out) if proj_needed and expand_attn else None + if stride_q and prod(stride_q) > 1: + kernel_skip = [s + 1 if s > 1 else s for s in stride_q] + stride_skip = stride_q + padding_skip = [int(skip // 2) for skip in kernel_skip] + self.shortcut_pool_attn = nn.MaxPool2d(kernel_skip, stride_skip, padding_skip) + else: + self.shortcut_pool_attn = None + + att_dim = dim_out if expand_attn else dim + attn_layer = MultiScaleAttentionPoolFirst if pool_first else MultiScaleAttention + self.attn = attn_layer( + dim, + att_dim, + num_heads=num_heads, + feat_size=feat_size, + qkv_bias=qkv_bias, + kernel_q=kernel_q, + kernel_kv=kernel_kv, + stride_q=stride_q, + stride_kv=stride_kv, + norm_layer=norm_layer, + has_cls_token=has_cls_token, + mode=mode, + rel_pos_type=rel_pos_type, + residual_pooling=residual_pooling, + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(att_dim) + mlp_dim_out = dim_out + self.shortcut_proj_mlp = nn.Linear(dim, dim_out) if proj_needed and not expand_attn else None + self.mlp = Mlp( + in_features=att_dim, + hidden_features=int(att_dim * mlp_ratio), + out_features=mlp_dim_out, + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def _shortcut_pool(self, x, feat_size: List[int]): + if self.shortcut_pool_attn is None: + return x + if self.has_cls_token: + cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :] + else: + cls_tok = None + B, L, C = x.shape + H, W = feat_size + x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous() + x = self.shortcut_pool_attn(x) + x = x.reshape(B, C, -1).transpose(1, 2) + if cls_tok is not None: + x = torch.cat((cls_tok, x), dim=2) + return x + + def forward(self, x, feat_size: List[int]): + x_norm = self.norm1(x) + # NOTE as per the original impl, this seems odd, but shortcut uses un-normalized input if no proj + x_shortcut = x if self.shortcut_proj_attn is None else self.shortcut_proj_attn(x_norm) + x_shortcut = self._shortcut_pool(x_shortcut, feat_size) + x, feat_size_new = self.attn(x_norm, feat_size) + x = x_shortcut + self.drop_path1(x) + + x_norm = self.norm2(x) + x_shortcut = 
x if self.shortcut_proj_mlp is None else self.shortcut_proj_mlp(x_norm) + x = x_shortcut + self.drop_path2(self.mlp(x_norm)) + return x, feat_size_new + + +class MultiScaleVitStage(nn.Module): + + def __init__( + self, + dim, + dim_out, + depth, + num_heads, + feat_size, + mlp_ratio=4.0, + qkv_bias=True, + mode="conv", + kernel_q=(1, 1), + kernel_kv=(1, 1), + stride_q=(1, 1), + stride_kv=(1, 1), + has_cls_token=True, + expand_attn=False, + pool_first=False, + rel_pos_type='spatial', + residual_pooling=True, + norm_layer=nn.LayerNorm, + drop_path=0.0, + ): + super().__init__() + self.grad_checkpointing = False + + self.blocks = nn.ModuleList() + if expand_attn: + out_dims = (dim_out,) * depth + else: + out_dims = (dim,) * (depth - 1) + (dim_out,) + + for i in range(depth): + attention_block = MultiScaleBlock( + dim=dim, + dim_out=out_dims[i], + num_heads=num_heads, + feat_size=feat_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + kernel_q=kernel_q, + kernel_kv=kernel_kv, + stride_q=stride_q if i == 0 else (1, 1), + stride_kv=stride_kv, + mode=mode, + has_cls_token=has_cls_token, + pool_first=pool_first, + rel_pos_type=rel_pos_type, + residual_pooling=residual_pooling, + expand_attn=expand_attn, + norm_layer=norm_layer, + drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path, + ) + dim = out_dims[i] + self.blocks.append(attention_block) + if i == 0: + feat_size = tuple([size // stride for size, stride in zip(feat_size, stride_q)]) + + self.feat_size = feat_size + + def forward(self, x, feat_size: List[int]): + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x, feat_size = checkpoint.checkpoint(blk, x, feat_size) + else: + x, feat_size = blk(x, feat_size) + return x, feat_size + + +class MultiScaleVit(nn.Module): + """ + Improved Multiscale Vision Transformers for Classification and Detection + Yanghao Li*, Chao-Yuan Wu*, Haoqi Fan, Karttikeya Mangalam, Bo Xiong, Jitendra Malik, + Christoph Feichtenhofer* + https://arxiv.org/abs/2112.01526 + + Multiscale Vision Transformers + Haoqi Fan*, Bo Xiong*, Karttikeya Mangalam*, Yanghao Li*, Zhicheng Yan, Jitendra Malik, + Christoph Feichtenhofer* + https://arxiv.org/abs/2104.11227 + """ + + def __init__( + self, + cfg: MultiScaleVitCfg, + img_size: Tuple[int, int] = (224, 224), + in_chans: int = 3, + global_pool: str = 'avg', + num_classes: int = 1000, + drop_path_rate: float = 0., + drop_rate: float = 0., + ): + super().__init__() + img_size = to_2tuple(img_size) + norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) + self.num_classes = num_classes + self.drop_rate = drop_rate + self.global_pool = global_pool + self.depths = tuple(cfg.depths) + self.expand_attn = cfg.expand_attn + + embed_dim = cfg.embed_dim[0] + self.patch_embed = PatchEmbed( + dim_in=in_chans, + dim_out=embed_dim, + kernel=cfg.patch_kernel, + stride=cfg.patch_stride, + padding=cfg.patch_padding, + ) + patch_dims = (img_size[0] // cfg.patch_stride[0], img_size[1] // cfg.patch_stride[1]) + num_patches = prod(patch_dims) + + if cfg.use_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.num_prefix_tokens = 1 + pos_embed_dim = num_patches + 1 + else: + self.num_prefix_tokens = 0 + self.cls_token = None + pos_embed_dim = num_patches + + if cfg.use_abs_pos: + self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_dim, embed_dim)) + else: + self.pos_embed = None + + num_stages = len(cfg.embed_dim) + feat_size = patch_dims + dpr = [x.tolist() for x in torch.linspace(0, 
drop_path_rate, sum(cfg.depths)).split(cfg.depths)] + self.stages = nn.ModuleList() + for i in range(num_stages): + if cfg.expand_attn: + dim_out = cfg.embed_dim[i] + else: + dim_out = cfg.embed_dim[min(i + 1, num_stages - 1)] + stage = MultiScaleVitStage( + dim=embed_dim, + dim_out=dim_out, + depth=cfg.depths[i], + num_heads=cfg.num_heads[i], + feat_size=feat_size, + mlp_ratio=cfg.mlp_ratio, + qkv_bias=cfg.qkv_bias, + mode=cfg.mode, + pool_first=cfg.pool_first, + expand_attn=cfg.expand_attn, + kernel_q=cfg.kernel_qkv, + kernel_kv=cfg.kernel_qkv, + stride_q=cfg.stride_q[i], + stride_kv=cfg.stride_kv[i], + has_cls_token=cfg.use_cls_token, + rel_pos_type=cfg.rel_pos_type, + residual_pooling=cfg.residual_pooling, + norm_layer=norm_layer, + drop_path=dpr[i], + ) + embed_dim = dim_out + feat_size = stage.feat_size + self.stages.append(stage) + + self.num_features = embed_dim + self.norm = norm_layer(embed_dim) + self.head = nn.Sequential(OrderedDict([ + ('drop', nn.Dropout(self.drop_rate)), + ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()) + ])) + + if self.pos_embed is not None: + trunc_normal_tf_(self.pos_embed, std=0.02) + if self.cls_token is not None: + trunc_normal_tf_(self.cls_token, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_tf_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {k for k, _ in self.named_parameters() + if any(n in k for n in ["pos_embed", "rel_pos_h", "rel_pos_w", "cls_token"])} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^patch_embed', # stem and embed + blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool + self.head = nn.Sequential(OrderedDict([ + ('drop', nn.Dropout(self.drop_rate)), + ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()) + ])) + + def forward_features(self, x): + x, feat_size = self.patch_embed(x) + B, N, C = x.shape + + if self.cls_token is not None: + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + if self.pos_embed is not None: + x = x + self.pos_embed + + for stage in self.stages: + x, feat_size = stage(x, feat_size) + + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + if self.global_pool == 'avg': + x = x[:, self.num_prefix_tokens:].mean(1) + else: + x = x[:, 0] + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + if 'stages.0.blocks.0.norm1.weight' in state_dict: + return state_dict + + import re + if 'model_state' in state_dict: + state_dict = state_dict['model_state'] + + depths = getattr(model, 'depths', None) + expand_attn = getattr(model, 'expand_attn', True) + assert depths is not None, 'model requires depth attribute to remap checkpoints' + depth_map = {} + block_idx = 0 + for stage_idx, d in enumerate(depths): + 
depth_map.update({i: (stage_idx, i - block_idx) for i in range(block_idx, block_idx + d)}) + block_idx += d + + out_dict = {} + for k, v in state_dict.items(): + k = re.sub( + r'blocks\.(\d+)', + lambda x: f'stages.{depth_map[int(x.group(1))][0]}.blocks.{depth_map[int(x.group(1))][1]}', + k) + + if expand_attn: + k = re.sub(r'stages\.(\d+).blocks\.(\d+).proj', f'stages.\\1.blocks.\\2.shortcut_proj_attn', k) + else: + k = re.sub(r'stages\.(\d+).blocks\.(\d+).proj', f'stages.\\1.blocks.\\2.shortcut_proj_mlp', k) + if 'head' in k: + k = k.replace('head.projection', 'head.fc') + out_dict[k] = v + + # for k, v in state_dict.items(): + # if model.pos_embed is not None and k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: + # # To resize pos embedding when using model at different size from pretrained weights + # v = resize_pos_embed( + # v, + # model.pos_embed, + # 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1), + # model.patch_embed.grid_size + # ) + + return out_dict + + +def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs): + return build_model_with_cfg( + MultiScaleVit, variant, pretrained, + model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(flatten_sequential=True), + **kwargs) + + +@register_model +def mvitv2_tiny(pretrained=False, **kwargs): + return _create_mvitv2('mvitv2_tiny', pretrained=pretrained, **kwargs) + + +@register_model +def mvitv2_small(pretrained=False, **kwargs): + return _create_mvitv2('mvitv2_small', pretrained=pretrained, **kwargs) + + +@register_model +def mvitv2_base(pretrained=False, **kwargs): + return _create_mvitv2('mvitv2_base', pretrained=pretrained, **kwargs) + + +@register_model +def mvitv2_large(pretrained=False, **kwargs): + return _create_mvitv2('mvitv2_large', pretrained=pretrained, **kwargs) + + +# @register_model +# def mvitv2_base_in21k(pretrained=False, **kwargs): +# return _create_mvitv2('mvitv2_base_in21k', pretrained=pretrained, **kwargs) +# +# +# @register_model +# def mvitv2_large_in21k(pretrained=False, **kwargs): +# return _create_mvitv2('mvitv2_large_in21k', pretrained=pretrained, **kwargs) +# +# +# @register_model +# def mvitv2_huge_in21k(pretrained=False, **kwargs): +# return _create_mvitv2('mvitv2_huge_in21k', pretrained=pretrained, **kwargs) diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py new file mode 100644 index 00000000..dd3cf690 --- /dev/null +++ b/timm/models/pvt_v2.py @@ -0,0 +1,476 @@ +""" Pyramid Vision Transformer v2 + +@misc{wang2021pvtv2, + title={PVTv2: Improved Baselines with Pyramid Vision Transformer}, + author={Wenhai Wang and Enze Xie and Xiang Li and Deng-Ping Fan and Kaitao Song and Ding Liang and + Tong Lu and Ping Luo and Ling Shao}, + year={2021}, + eprint={2106.13797}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} + +Based on Apache 2.0 licensed code at https://github.com/whai362/PVT + +Modifications and timm support by / Copyright 2022, Ross Wightman +""" + +import math +from functools import partial +from typing import Tuple, List, Callable, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import DropPath, to_2tuple, to_ntuple, trunc_normal_ +from .registry import register_model + +__all__ = ['PyramidVisionTransformerV2'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': False, + **kwargs + } + + +default_cfgs = { + 'pvt_v2_b0': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b0.pth'), + 'pvt_v2_b1': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b1.pth'), + 'pvt_v2_b2': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b2.pth'), + 'pvt_v2_b3': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b3.pth'), + 'pvt_v2_b4': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b4.pth'), + 'pvt_v2_b5': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b5.pth'), + 'pvt_v2_b2_li': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b2_li.pth') +} + + +class MlpWithDepthwiseConv(nn.Module): + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, + drop=0., extra_relu=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.relu = nn.ReLU() if extra_relu else nn.Identity() + self.dwconv = nn.Conv2d(hidden_features, hidden_features, 3, 1, 1, bias=True, groups=hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x, feat_size: List[int]): + x = self.fc1(x) + B, N, C = x.shape + x = x.transpose(1, 2).view(B, C, feat_size[0], feat_size[1]) + x = self.relu(x) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + sr_ratio=1, + linear_attn=False, + qkv_bias=True, + attn_drop=0., + proj_drop=0. + ): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
+ + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if not linear_attn: + self.pool = None + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + self.act = None + else: + self.pool = nn.AdaptiveAvgPool2d(7) + self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) + self.norm = nn.LayerNorm(dim) + self.act = nn.GELU() + + def forward(self, x, feat_size: List[int]): + B, N, C = x.shape + H, W = feat_size + q = self.q(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + + if self.pool is not None: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + x_ = self.act(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + else: + if self.sr is not None: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + else: + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + k, v = kv.unbind(0) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__( + self, dim, num_heads, mlp_ratio=4., sr_ratio=1, linear_attn=False, qkv_bias=False, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + sr_ratio=sr_ratio, + linear_attn=linear_attn, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp = MlpWithDepthwiseConv( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + extra_relu=linear_attn + ) + + def forward(self, x, feat_size: List[int]): + x = x + self.drop_path(self.attn(self.norm1(x), feat_size)) + x = x + self.drop_path(self.mlp(self.norm2(x), feat_size)) + + return x + + +class OverlapPatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + patch_size = to_2tuple(patch_size) + assert max(patch_size) > stride, "Set larger patch_size than stride" + self.patch_size = patch_size + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2)) + self.norm = nn.LayerNorm(embed_dim) + + def forward(self, x): + x = self.proj(x) + feat_size = x.shape[-2:] + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x, feat_size + + +class PyramidVisionTransformerStage(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + depth: int, + downsample: bool = True, + num_heads: int = 8, + sr_ratio: int = 1, + linear_attn: bool = False, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + drop: float = 0., + attn_drop: float = 0., + drop_path: Union[List[float], float] = 0.0, + norm_layer: Callable = nn.LayerNorm, + ): + super().__init__() + self.grad_checkpointing = False + + if downsample: + self.downsample = OverlapPatchEmbed( + patch_size=3, + stride=2, + in_chans=dim, + embed_dim=dim_out) + else: + assert dim == dim_out + self.downsample = None + + self.blocks = nn.ModuleList([Block( + dim=dim_out, + num_heads=num_heads, + sr_ratio=sr_ratio, + linear_attn=linear_attn, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) for i in range(depth)]) + + self.norm = norm_layer(dim_out) + + def forward(self, x, feat_size: List[int]) -> Tuple[torch.Tensor, List[int]]: + if self.downsample is not None: + x, feat_size = self.downsample(x) + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint.checkpoint(blk, x, feat_size) + else: + x = blk(x, feat_size) + x = self.norm(x) + x = x.reshape(x.shape[0], feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2).contiguous() + return x, feat_size + + +class PyramidVisionTransformerV2(nn.Module): + def __init__( + self, + img_size=None, + in_chans=3, + num_classes=1000, + global_pool='avg', + depths=(3, 4, 6, 3), + embed_dims=(64, 128, 256, 512), + num_heads=(1, 2, 4, 8), + sr_ratios=(8, 4, 2, 1), + mlp_ratios=(8., 8., 4., 4.), + qkv_bias=True, + linear=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.num_classes = num_classes + assert global_pool in ('avg', '') + self.global_pool = global_pool + self.depths = depths + num_stages = len(depths) + mlp_ratios = to_ntuple(num_stages)(mlp_ratios) + num_heads = to_ntuple(num_stages)(num_heads) + sr_ratios = to_ntuple(num_stages)(sr_ratios) + assert(len(embed_dims)) == num_stages + + self.patch_embed = OverlapPatchEmbed( + patch_size=7, + stride=4, + in_chans=in_chans, + embed_dim=embed_dims[0]) + + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + cur = 0 + prev_dim = embed_dims[0] + self.stages = nn.ModuleList() + for i in 
range(num_stages): + self.stages.append(PyramidVisionTransformerStage( + dim=prev_dim, + dim_out=embed_dims[i], + depth=depths[i], + downsample=i > 0, + num_heads=num_heads[i], + sr_ratio=sr_ratios[i], + mlp_ratio=mlp_ratios[i], + linear_attn=linear, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer + )) + prev_dim = embed_dims[i] + cur += depths[i] + + # classification head + self.num_features = embed_dims[-1] + self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def freeze_patch_emb(self): + self.patch_embed.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^patch_embed', # stem and embed + blocks=r'^stages\.(\d+)' + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('avg', '') + self.global_pool = global_pool + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x, feat_size = self.patch_embed(x) + for stage in self.stages: + x, feat_size = stage(x, feat_size=feat_size) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + x = x.mean(dim=(-1, -2)) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _checkpoint_filter_fn(state_dict, model): + """ Remap original checkpoints -> timm """ + if 'patch_embed.proj.weight' in state_dict: + return state_dict # non-original checkpoint, no remapping needed + + out_dict = {} + import re + for k, v in state_dict.items(): + if k.startswith('patch_embed'): + k = k.replace('patch_embed1', 'patch_embed') + k = k.replace('patch_embed2', 'stages.1.downsample') + k = k.replace('patch_embed3', 'stages.2.downsample') + k = k.replace('patch_embed4', 'stages.3.downsample') + k = k.replace('dwconv.dwconv', 'dwconv') + k = re.sub(r'block(\d+).(\d+)', lambda x: f'stages.{int(x.group(1)) - 1}.blocks.{x.group(2)}', k) + k = re.sub(r'^norm(\d+)', lambda x: f'stages.{int(x.group(1)) - 1}.norm', k) + out_dict[k] = v + return out_dict + + +def _create_pvt2(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + model = build_model_with_cfg( + PyramidVisionTransformerV2, variant, pretrained, + pretrained_filter_fn=_checkpoint_filter_fn, + **kwargs + ) + return model + + +@register_model +def pvt_v2_b0(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(2, 2, 2, 2), embed_dims=(32, 64, 160, 256), num_heads=(1, 2, 5, 8), + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return _create_pvt2('pvt_v2_b0', 
pretrained=pretrained, **model_kwargs) + + +@register_model +def pvt_v2_b1(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(2, 2, 2, 2), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return _create_pvt2('pvt_v2_b1', pretrained=pretrained, **model_kwargs) + + +@register_model +def pvt_v2_b2(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return _create_pvt2('pvt_v2_b2', pretrained=pretrained, **model_kwargs) + + +@register_model +def pvt_v2_b3(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(3, 4, 18, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return _create_pvt2('pvt_v2_b3', pretrained=pretrained, **model_kwargs) + + +@register_model +def pvt_v2_b4(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(3, 8, 27, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return _create_pvt2('pvt_v2_b4', pretrained=pretrained, **model_kwargs) + + +@register_model +def pvt_v2_b5(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(3, 6, 40, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), + mlp_ratios=(4, 4, 4, 4), norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs) + return _create_pvt2('pvt_v2_b5', pretrained=pretrained, **model_kwargs) + + +@register_model +def pvt_v2_b2_li(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), + norm_layer=partial(nn.LayerNorm, eps=1e-6), linear=True, **kwargs) + return _create_pvt2('pvt_v2_b2_li', pretrained=pretrained, **model_kwargs) + diff --git a/timm/version.py b/timm/version.py index 6ab9c039..1a6fcf61 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.8.1.dev0' +__version__ = '0.8.2.dev0'
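Once this patch is applied, the newly registered model families (mvitv2_*, pvt_v2_*, and the coatnet/coatnext/maxvit/maxxvit defs added in maxxvit.py) are reachable through the usual timm factory functions. A minimal smoke-test sketch, assuming only the public timm.list_models / timm.create_model API and no pretrained weights; model names are taken from the registrations above:

    import torch
    import timm

    # the new registrations should show up in wildcard listings
    print(timm.list_models('mvitv2*'))
    print(timm.list_models('pvt_v2*'))
    print(timm.list_models('coatnet*') + timm.list_models('maxvit*'))

    # build one model from each new family (random init) and run a dummy forward pass
    for name in ('mvitv2_tiny', 'pvt_v2_b0'):
        model = timm.create_model(name, pretrained=False, num_classes=1000)
        model.eval()
        with torch.no_grad():
            out = model(torch.randn(1, 3, 224, 224))
        print(name, tuple(out.shape))  # expected: (1, 1000)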