Merge pull request #1415 from rwightman/more_vit

More ViT and ViT-CNN Hybrid architecture
Ross Wightman 2 years ago committed by GitHub
commit 4f72bae43b

@ -21,6 +21,26 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
## What's New
### Aug 26, 2022
* CoAtNet (https://arxiv.org/abs/2106.04803) and MaxVit (https://arxiv.org/abs/2204.01697) `timm` original models
* both found in [`maxxvit.py`](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/maxxvit.py) model def, contains numerous experiments outside scope of original papers
* an unfinished TensorFlow version from the MaxVit authors can be found at https://github.com/google-research/maxvit
* Initial CoAtNet and MaxVit timm pretrained weights (working on more):
* `coatnet_nano_rw_224` - 81.7 @ 224 (T)
* `coatnet_rmlp_nano_rw_224` - 82.0 @ 224, 82.8 @ 320 (T)
* `coatnet_0_rw_224` - 82.4 (T) -- NOTE timm '0' coatnets have 2 more 3rd stage blocks
* `coatnet_bn_0_rw_224` - 82.4 (T)
* `maxvit_nano_rw_256` - 82.9 @ 256 (T)
* `coatnet_rmlp_1_rw_224` - 83.4 @ 224, 84 @ 320 (T)
* `coatnet_1_rw_224` - 83.6 @ 224 (G)
* (T) = TPU trained with `bits_and_tpu` branch training code, (G) = GPU trained
* GCVit (weights adapted from https://github.com/NVlabs/GCVit, code 100% `timm` re-write for license purposes)
* MViT-V2 (multi-scale vit, adapted from https://github.com/facebookresearch/mvit)
* EfficientFormer (adapted from https://github.com/snap-research/EfficientFormer)
* PyramidVisionTransformer-V2 (adapted from https://github.com/whai362/PVT)
* 'Fast Norm' support for LayerNorm and GroupNorm that avoids float32 upcast w/ AMP (uses APEX LN if available for further boost)
### Aug 15, 2022
* ConvNeXt atto weights added
* `convnext_atto` - 75.7 @ 224, 77.0 @ 288
@ -229,6 +249,7 @@ A full version of the list below with source links can be found in the [document
* Bottleneck Transformers - https://arxiv.org/abs/2101.11605
* CaiT (Class-Attention in Image Transformers) - https://arxiv.org/abs/2103.17239
* CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399
* CoAtNet (Convolution and Attention) - https://arxiv.org/abs/2106.04803
* ConvNeXt - https://arxiv.org/abs/2201.03545
* ConViT (Soft Convolutional Inductive Biases Vision Transformers) - https://arxiv.org/abs/2103.10697
* CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929
@ -238,6 +259,7 @@ A full version of the list below with source links can be found in the [document
* DLA - https://arxiv.org/abs/1707.06484
* DPN (Dual-Path Network) - https://arxiv.org/abs/1707.01629
* EdgeNeXt - https://arxiv.org/abs/2206.10589
* EfficientFormer - https://arxiv.org/abs/2206.01191
* EfficientNet (MBConvNet Family)
* EfficientNet NoisyStudent (B0-B7, L2) - https://arxiv.org/abs/1911.04252
* EfficientNet AdvProp (B0-B8) - https://arxiv.org/abs/1911.09665
@ -259,6 +281,7 @@ A full version of the list below with source links can be found in the [document
* Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261
* Lambda Networks - https://arxiv.org/abs/2102.08602
* LeViT (Vision Transformer in ConvNet's Clothing) - https://arxiv.org/abs/2104.01136
* MaxViT (Multi-Axis Vision Transformer) - https://arxiv.org/abs/2204.01697
* MLP-Mixer - https://arxiv.org/abs/2105.01601
* MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244
* FBNet-V3 - https://arxiv.org/abs/2006.02049
@ -266,6 +289,7 @@ A full version of the list below with source links can be found in the [document
* LCNet - https://arxiv.org/abs/2109.15099
* MobileViT - https://arxiv.org/abs/2110.02178
* MobileViT-V2 - https://arxiv.org/abs/2206.02680
* MViT-V2 (Improved Multiscale Vision Transformer) - https://arxiv.org/abs/2112.01526
* NASNet-A - https://arxiv.org/abs/1707.07012
* NesT - https://arxiv.org/abs/2105.12723
* NFNet-F - https://arxiv.org/abs/2102.06171
@ -273,6 +297,7 @@ A full version of the list below with source links can be found in the [document
* PNasNet - https://arxiv.org/abs/1712.00559
* PoolFormer (MetaFormer) - https://arxiv.org/abs/2111.11418
* Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302
* PVT-V2 (Improved Pyramid Vision Transformer) - https://arxiv.org/abs/2106.13797
* RegNet - https://arxiv.org/abs/2003.13678
* RegNetZ - https://arxiv.org/abs/2103.06877
* RepVGG - https://arxiv.org/abs/2101.03697

@ -19,7 +19,7 @@ import torch.nn as nn
import torch.nn.parallel
from timm.data import resolve_data_config
from timm.models import create_model, is_model, list_models
from timm.models import create_model, is_model, list_models, set_fast_norm
from timm.optim import create_optimizer_v2
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry
@ -109,7 +109,8 @@ scripting_group.add_argument('--torchscript', dest='torchscript', action='store_
help='convert model to torchscript for inference')
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
scripting_group.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
# train optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
@ -598,6 +599,9 @@ def main():
model_cfgs = []
model_names = []
if args.fast_norm:
set_fast_norm()
if args.model_list:
args.model = ''
with open(args.model_list) as f:
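The flag wiring in this hunk can be sketched in isolation. This is a minimal sketch, not the benchmark script itself: `set_fast_norm` below is a hypothetical stand-in that only records the switch (the real function is imported from `timm.models`), and `build_parser` / `apply_args` are illustrative helper names.

```python
import argparse

_fast_norm_enabled = False  # hypothetical module-level state for the stand-in


def set_fast_norm(enable=True):
    """Hypothetical stand-in for timm's set_fast_norm; it only records the flag."""
    global _fast_norm_enabled
    _fast_norm_enabled = enable


def build_parser():
    parser = argparse.ArgumentParser()
    # argparse exposes '--fast-norm' as the attribute 'fast_norm'
    parser.add_argument('--fast-norm', default=False, action='store_true',
                        help='enable experimental fast-norm')
    return parser


def apply_args(argv):
    args = build_parser().parse_args(argv)
    if args.fast_norm:
        set_fast_norm()
    return args


args_off = apply_args([])
state_off = _fast_norm_enabled
args_on = apply_args(['--fast-norm'])
state_on = _fast_norm_enabled
```

Because the flag uses `store_true` with a `False` default, the switch is only touched when `--fast-norm` is passed explicitly.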

@ -27,7 +27,9 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*']
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*',
]
NUM_NON_STD = len(NON_STD_FILTERS)
# exclude models that cause specific test failures
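A small sketch of why the new patterns are needed, assuming the filters are applied as glob patterns in the style of Python's `fnmatch` (an assumption about how the test suite consumes them). The helper name `is_non_std` is illustrative; the model names come from the README hunk in this commit.

```python
from fnmatch import fnmatchcase

# Patterns added to NON_STD_FILTERS in this commit
NEW_FILTERS = ['pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
               'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*']


def is_non_std(name, filters=NEW_FILTERS):
    """Return True if the model name matches any glob pattern."""
    return any(fnmatchcase(name, pat) for pat in filters)


names = ['coatnet_nano_rw_224', 'maxvit_nano_rw_256', 'efficientformer_l1',
         'gcvit_tiny', 'convnext_atto']
flags = {n: is_non_std(n) for n in names}

# The pre-existing 'coat_*' pattern does not cover CoAtNet names,
# because 'coatnet...' has no underscore directly after 'coat'.
old_pattern_hits_coatnet = fnmatchcase('coatnet_nano_rw_224', 'coat_*')
```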

@ -13,7 +13,9 @@ from .densenet import *
from .dla import *
from .dpn import *
from .edgenext import *
from .efficientformer import *
from .efficientnet import *
from .gcvit import *
from .ghostnet import *
from .gluon_resnet import *
from .gluon_xception import *
@ -23,15 +25,18 @@ from .inception_resnet_v2 import *
from .inception_v3 import *
from .inception_v4 import *
from .levit import *
from .maxxvit import *
from .mlp_mixer import *
from .mobilenetv3 import *
from .mobilevit import *
from .mvitv2 import *
from .nasnet import *
from .nest import *
from .nfnet import *
from .pit import *
from .pnasnet import *
from .poolformer import *
from .pvt_v2 import *
from .regnet import *
from .res2net import *
from .resnest import *
@ -64,5 +69,6 @@ from .helpers import load_checkpoint, resume_checkpoint, model_parameters
from .layers import TestTimePoolHead, apply_test_time_pool
from .layers import convert_splitbn_model, convert_sync_batchnorm
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
from .layers import set_fast_norm
from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value

@ -19,7 +19,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d,\
from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \
create_conv2d, get_act_layer, make_divisible, to_ntuple
from .registry import register_model
@ -161,7 +161,7 @@ class ConvNeXtBlock(nn.Module):
out_chs = out_chs or in_chs
act_layer = get_act_layer(act_layer)
if not norm_layer:
norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
norm_layer = LayerNorm2d if conv_mlp else LayerNorm
mlp_layer = ConvMlp if conv_mlp else Mlp
self.use_conv_mlp = conv_mlp
@ -291,8 +291,8 @@ class ConvNeXt(nn.Module):
assert output_stride in (8, 16, 32)
kernel_sizes = to_ntuple(4)(kernel_sizes)
if norm_layer is None:
norm_layer = partial(LayerNorm2d, eps=1e-6)
norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
norm_layer = LayerNorm2d
norm_layer_cl = norm_layer if conv_mlp else LayerNorm
else:
assert conv_mlp,\
'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'

@ -115,7 +115,7 @@ class DenseBlock(nn.ModuleDict):
_version = 2
def __init__(
self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU,
self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=BatchNormAct2d,
drop_rate=0., memory_efficient=False):
super(DenseBlock, self).__init__()
for i in range(num_layers):
@ -138,7 +138,7 @@ class DenseBlock(nn.ModuleDict):
class DenseTransition(nn.Sequential):
def __init__(self, num_input_features, num_output_features, norm_layer=nn.BatchNorm2d, aa_layer=None):
def __init__(self, num_input_features, num_output_features, norm_layer=BatchNormAct2d, aa_layer=None):
super(DenseTransition, self).__init__()
self.add_module('norm', norm_layer(num_input_features))
self.add_module('conv', nn.Conv2d(

@ -0,0 +1,551 @@
""" EfficientFormer
@article{li2022efficientformer,
title={EfficientFormer: Vision Transformers at MobileNet Speed},
author={Li, Yanyu and Yuan, Geng and Wen, Yang and Hu, Eric and Evangelidis, Georgios and Tulyakov,
Sergey and Wang, Yanzhi and Ren, Jian},
journal={arXiv preprint arXiv:2206.01191},
year={2022}
}
Based on Apache 2.0 licensed code at https://github.com/snap-research/EfficientFormer, Copyright (c) 2022 Snap Inc.
Modifications and timm support by / Copyright 2022, Ross Wightman
"""
from typing import Dict
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import DropPath, trunc_normal_, to_2tuple, Mlp
from .registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True,
'crop_pct': .95, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv1', 'classifier': ('head', 'head_dist'),
**kwargs
}
default_cfgs = dict(
efficientformer_l1=_cfg(
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l1_1000d_224-5b08fab0.pth",
),
efficientformer_l3=_cfg(
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l3_300d_224-6816624f.pth",
),
efficientformer_l7=_cfg(
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l7_300d_224-e957ab75.pth",
),
)
EfficientFormer_width = {
'l1': (48, 96, 224, 448),
'l3': (64, 128, 320, 512),
'l7': (96, 192, 384, 768),
}
EfficientFormer_depth = {
'l1': (3, 2, 6, 4),
'l3': (4, 4, 12, 6),
'l7': (6, 6, 18, 8),
}
class Attention(torch.nn.Module):
attention_bias_cache: Dict[str, torch.Tensor]
def __init__(
self,
dim=384,
key_dim=32,
num_heads=8,
attn_ratio=4,
resolution=7
):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.key_attn_dim = key_dim * num_heads
self.val_dim = int(attn_ratio * key_dim)
self.val_attn_dim = self.val_dim * num_heads
self.attn_ratio = attn_ratio
self.qkv = nn.Linear(dim, self.key_attn_dim * 2 + self.val_attn_dim)
self.proj = nn.Linear(self.val_attn_dim, dim)
resolution = to_2tuple(resolution)
pos = torch.stack(torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
self.register_buffer('attention_bias_idxs', torch.LongTensor(rel_pos))
self.attention_bias_cache = {} # per-device attention_biases cache (data-parallel compat)
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and self.attention_bias_cache:
self.attention_bias_cache = {} # clear ab cache
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
if self.training:
return self.attention_biases[:, self.attention_bias_idxs]
else:
device_key = str(device)
if device_key not in self.attention_bias_cache:
self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
return self.attention_bias_cache[device_key]
def forward(self, x): # x (B,N,C)
B, N, C = x.shape
qkv = self.qkv(x)
qkv = qkv.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
q, k, v = qkv.split([self.key_dim, self.key_dim, self.val_dim], dim=3)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.val_attn_dim)
x = self.proj(x)
return x
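The index buffer built in `Attention.__init__` can be reproduced without tensors. The sketch below is a plain-Python reading of the `meshgrid` / `abs` / `rel_pos[0] * resolution[1] + rel_pos[1]` lines above; the function name `attention_bias_idxs` is illustrative and returns nested lists rather than a `LongTensor`.

```python
def attention_bias_idxs(resolution):
    """Pairwise index |dy| * W + |dx| over all positions of an H x W grid."""
    h, w = resolution
    # Row-major list of (row, col) positions, matching the flattened meshgrid
    pos = [(y, x) for y in range(h) for x in range(w)]
    return [[abs(ya - yb) * w + abs(xa - xb) for (yb, xb) in pos]
            for (ya, xa) in pos]


idxs_2x2 = attention_bias_idxs((2, 2))
max_idx_7x7 = max(max(row) for row in attention_bias_idxs((7, 7)))
```

The largest index is `(H - 1) * W + (W - 1) = H * W - 1`, which is why a bias table of `resolution[0] * resolution[1]` entries per head is enough.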
class Stem4(nn.Sequential):
def __init__(self, in_chs, out_chs, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super().__init__()
self.stride = 4
self.add_module('conv1', nn.Conv2d(in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1))
self.add_module('norm1', norm_layer(out_chs // 2))
self.add_module('act1', act_layer())
self.add_module('conv2', nn.Conv2d(out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1))
self.add_module('norm2', norm_layer(out_chs))
self.add_module('act2', act_layer())
class Downsample(nn.Module):
"""
Downsampling via strided conv w/ norm
Input: tensor in shape [B, in_chs, H, W]
Output: tensor in shape [B, out_chs, H/stride, W/stride]
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=2, padding=None, norm_layer=nn.BatchNorm2d):
super().__init__()
if padding is None:
padding = kernel_size // 2
self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, padding=padding)
self.norm = norm_layer(out_chs)
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
return x
class Flat(nn.Module):
def __init__(self, ):
super().__init__()
def forward(self, x):
x = x.flatten(2).transpose(1, 2)
return x
class Pooling(nn.Module):
"""
Implementation of pooling for PoolFormer
--pool_size: pooling size
"""
def __init__(self, pool_size=3):
super().__init__()
self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)
def forward(self, x):
return self.pool(x) - x
class ConvMlpWithNorm(nn.Module):
"""
Implementation of MLP with 1x1 convolutions.
Input: tensor with shape [B, C, H, W]
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=nn.BatchNorm2d,
drop=0.
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
self.norm1 = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
self.act = act_layer()
self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
self.norm2 = norm_layer(out_features) if norm_layer is not None else nn.Identity()
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.norm1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.norm2(x)
x = self.drop(x)
return x
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class MetaBlock1d(nn.Module):
def __init__(
self,
dim,
mlp_ratio=4.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
drop=0.,
drop_path=0.,
layer_scale_init_value=1e-5
):
super().__init__()
self.norm1 = norm_layer(dim)
self.token_mixer = Attention(dim)
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.ls1 = LayerScale(dim, layer_scale_init_value)
self.ls2 = LayerScale(dim, layer_scale_init_value)
def forward(self, x):
x = x + self.drop_path(self.ls1(self.token_mixer(self.norm1(x))))
x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
return x
class LayerScale2d(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
gamma = self.gamma.view(1, -1, 1, 1)
return x.mul_(gamma) if self.inplace else x * gamma
class MetaBlock2d(nn.Module):
def __init__(
self,
dim,
pool_size=3,
mlp_ratio=4.,
act_layer=nn.GELU,
norm_layer=nn.BatchNorm2d,
drop=0.,
drop_path=0.,
layer_scale_init_value=1e-5
):
super().__init__()
self.token_mixer = Pooling(pool_size=pool_size)
self.mlp = ConvMlpWithNorm(
dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, norm_layer=norm_layer, drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.ls1 = LayerScale2d(dim, layer_scale_init_value)
self.ls2 = LayerScale2d(dim, layer_scale_init_value)
def forward(self, x):
x = x + self.drop_path(self.ls1(self.token_mixer(x)))
x = x + self.drop_path(self.ls2(self.mlp(x)))
return x
class EfficientFormerStage(nn.Module):
def __init__(
self,
dim,
dim_out,
depth,
downsample=True,
num_vit=1,
pool_size=3,
mlp_ratio=4.,
act_layer=nn.GELU,
norm_layer=nn.BatchNorm2d,
norm_layer_cl=nn.LayerNorm,
drop=.0,
drop_path=0.,
layer_scale_init_value=1e-5,
):
super().__init__()
self.grad_checkpointing = False
if downsample:
self.downsample = Downsample(in_chs=dim, out_chs=dim_out, norm_layer=norm_layer)
dim = dim_out
else:
assert dim == dim_out
self.downsample = nn.Identity()
blocks = []
if num_vit and num_vit >= depth:
blocks.append(Flat())
for block_idx in range(depth):
remain_idx = depth - block_idx - 1
if num_vit and num_vit > remain_idx:
blocks.append(
MetaBlock1d(
dim,
mlp_ratio=mlp_ratio,
act_layer=act_layer,
norm_layer=norm_layer_cl,
drop=drop,
drop_path=drop_path[block_idx],
layer_scale_init_value=layer_scale_init_value,
))
else:
blocks.append(
MetaBlock2d(
dim,
pool_size=pool_size,
mlp_ratio=mlp_ratio,
act_layer=act_layer,
norm_layer=norm_layer,
drop=drop,
drop_path=drop_path[block_idx],
layer_scale_init_value=layer_scale_init_value,
))
if num_vit and num_vit == remain_idx:
blocks.append(Flat())
self.blocks = nn.Sequential(*blocks)
def forward(self, x):
x = self.downsample(x)
x = self.blocks(x)
return x
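How `num_vit` decides the block mix in a stage can be traced with names instead of modules. This is a sketch of the loop above under one assumption: the `Flat` check after `MetaBlock2d` is placed in the 2D branch (indentation is lost in this listing, and the result is the same either way because the two conditions are mutually exclusive). `stage_layout` is an illustrative name.

```python
def stage_layout(depth, num_vit):
    """Return the ordered block type names for one stage."""
    blocks = []
    if num_vit and num_vit >= depth:
        blocks.append('Flat')  # the whole stage is attention blocks
    for block_idx in range(depth):
        remain_idx = depth - block_idx - 1
        if num_vit and num_vit > remain_idx:
            blocks.append('MetaBlock1d')
        else:
            blocks.append('MetaBlock2d')
            if num_vit and num_vit == remain_idx:
                blocks.append('Flat')  # switch from NCHW maps to token sequences
    return blocks


l1_last = stage_layout(depth=4, num_vit=1)   # last stage of efficientformer_l1
l3_last = stage_layout(depth=6, num_vit=4)   # last stage of efficientformer_l3
l7_last = stage_layout(depth=8, num_vit=8)   # last stage of efficientformer_l7
conv_only = stage_layout(depth=3, num_vit=0)
```

In every case exactly one `Flat` sits directly before the first `MetaBlock1d`, so the attention blocks always receive `(B, N, C)` input.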
class EfficientFormer(nn.Module):
def __init__(
self,
depths,
embed_dims=None,
in_chans=3,
num_classes=1000,
global_pool='avg',
downsamples=None,
num_vit=0,
mlp_ratios=4,
pool_size=3,
layer_scale_init_value=1e-5,
act_layer=nn.GELU,
norm_layer=nn.BatchNorm2d,
norm_layer_cl=nn.LayerNorm,
drop_rate=0.,
drop_path_rate=0.,
**kwargs
):
super().__init__()
self.num_classes = num_classes
self.global_pool = global_pool
self.stem = Stem4(in_chans, embed_dims[0], norm_layer=norm_layer)
prev_dim = embed_dims[0]
# stochastic depth decay rule
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
downsamples = downsamples or (False,) + (True,) * (len(depths) - 1)
stages = []
for i in range(len(depths)):
stage = EfficientFormerStage(
prev_dim,
embed_dims[i],
depths[i],
downsample=downsamples[i],
num_vit=num_vit if i == 3 else 0,
pool_size=pool_size,
mlp_ratio=mlp_ratios,
act_layer=act_layer,
norm_layer_cl=norm_layer_cl,
norm_layer=norm_layer,
drop=drop_rate,
drop_path=dpr[i],
layer_scale_init_value=layer_scale_init_value,
)
prev_dim = embed_dims[i]
stages.append(stage)
self.stages = nn.Sequential(*stages)
# Classifier head
self.num_features = embed_dims[-1]
self.norm = norm_layer_cl(self.num_features)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
# assuming model is always distilled (valid for current checkpoints, will split def if that changes)
self.head_dist = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
self.distilled_training = False # must set this True to train w/ distillation token
self.apply(self._init_weights)
# init for classification
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
@torch.jit.ignore
def no_weight_decay(self):
return {k for k, _ in self.named_parameters() if 'attention_biases' in k}
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem', # stem and embed
blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for s in self.stages:
s.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head, self.head_dist
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
@torch.jit.ignore
def set_distilled_training(self, enable=True):
self.distilled_training = enable
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean(dim=1)
if pre_logits:
return x
x, x_dist = self.head(x), self.head_dist(x)
if self.distilled_training and self.training and not torch.jit.is_scripting():
# only return separate classification predictions when training in distilled mode
return x, x_dist
else:
# during standard train/finetune and at inference, average the classifier predictions
return (x + x_dist) / 2
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
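The stochastic depth decay rule in the constructor (`torch.linspace(...).split(depths)`) assigns each block a linearly increasing drop-path rate and then groups the rates per stage. A plain-Python sketch of the same arithmetic follows; `drop_path_schedule` is an illustrative name and the `0.1` rate is an example value, not a default from this file.

```python
def drop_path_schedule(depths, drop_path_rate):
    """Linearly spaced rates from 0 to drop_path_rate, split per stage."""
    total = sum(depths)
    if total > 1:
        rates = [drop_path_rate * i / (total - 1) for i in range(total)]
    else:
        rates = [0.0] * total
    per_stage, start = [], 0
    for depth in depths:
        per_stage.append(rates[start:start + depth])
        start += depth
    return per_stage


dpr = drop_path_schedule((3, 2, 6, 4), 0.1)  # 'l1' depths from this file
flat = [r for stage in dpr for r in stage]
```

The first block always gets rate 0 and the last block gets the full `drop_path_rate`, so deeper blocks are dropped more often.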
def _checkpoint_filter_fn(state_dict, model):
""" Remap original checkpoints -> timm """
if 'stem.0.weight' in state_dict:
return state_dict # non-original checkpoint, no remapping needed
out_dict = {}
import re
stage_idx = 0
for k, v in state_dict.items():
if k.startswith('patch_embed'):
k = k.replace('patch_embed.0', 'stem.conv1')
k = k.replace('patch_embed.1', 'stem.norm1')
k = k.replace('patch_embed.3', 'stem.conv2')
k = k.replace('patch_embed.4', 'stem.norm2')
if re.match(r'network\.(\d+)\.proj\.weight', k):
stage_idx += 1
k = re.sub(r'network.(\d+).(\d+)', f'stages.{stage_idx}.blocks.\\2', k)
k = re.sub(r'network.(\d+).proj', f'stages.{stage_idx}.downsample.conv', k)
k = re.sub(r'network.(\d+).norm', f'stages.{stage_idx}.downsample.norm', k)
k = re.sub(r'layer_scale_([0-9])', r'ls\1.gamma', k)
k = k.replace('dist_head', 'head_dist')
out_dict[k] = v
return out_dict
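The remapping can be exercised on key names alone. The sketch below applies the same substitutions as `_checkpoint_filter_fn` to a list of strings; the input keys are illustrative examples inferred from the patterns in the function, not taken from an actual checkpoint, and `remap_keys` is a hypothetical helper name.

```python
import re


def remap_keys(keys):
    """Apply the original -> timm key renames to an ordered list of key names."""
    out = []
    stage_idx = 0
    for k in keys:
        if k.startswith('patch_embed'):
            k = k.replace('patch_embed.0', 'stem.conv1')
            k = k.replace('patch_embed.1', 'stem.norm1')
            k = k.replace('patch_embed.3', 'stem.conv2')
            k = k.replace('patch_embed.4', 'stem.norm2')
        # A downsample conv weight marks the start of the next stage
        if re.match(r'network\.(\d+)\.proj\.weight', k):
            stage_idx += 1
        k = re.sub(r'network.(\d+).(\d+)', f'stages.{stage_idx}.blocks.\\2', k)
        k = re.sub(r'network.(\d+).proj', f'stages.{stage_idx}.downsample.conv', k)
        k = re.sub(r'network.(\d+).norm', f'stages.{stage_idx}.downsample.norm', k)
        k = re.sub(r'layer_scale_([0-9])', r'ls\1.gamma', k)
        k = k.replace('dist_head', 'head_dist')
        out.append(k)
    return out


example_keys = [
    'patch_embed.0.weight',
    'patch_embed.4.bias',
    'network.0.0.layer_scale_1',
    'network.1.proj.weight',
    'network.1.norm.weight',
    'network.2.0.mlp.fc1.weight',
    'dist_head.weight',
]
remapped = remap_keys(example_keys)
```

Note that the stage counter depends on iteration order: each `network.N.proj.weight` key has to be seen before the other keys of that stage for the stage index to be right.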
def _create_efficientformer(variant, pretrained=False, **kwargs):
model = build_model_with_cfg(
EfficientFormer, variant, pretrained,
pretrained_filter_fn=_checkpoint_filter_fn,
**kwargs)
return model
@register_model
def efficientformer_l1(pretrained=False, **kwargs):
model_kwargs = dict(
depths=EfficientFormer_depth['l1'],
embed_dims=EfficientFormer_width['l1'],
num_vit=1,
**kwargs)
return _create_efficientformer('efficientformer_l1', pretrained=pretrained, **model_kwargs)
@register_model
def efficientformer_l3(pretrained=False, **kwargs):
model_kwargs = dict(
depths=EfficientFormer_depth['l3'],
embed_dims=EfficientFormer_width['l3'],
num_vit=4,
**kwargs)
return _create_efficientformer('efficientformer_l3', pretrained=pretrained, **model_kwargs)
@register_model
def efficientformer_l7(pretrained=False, **kwargs):
model_kwargs = dict(
depths=EfficientFormer_depth['l7'],
embed_dims=EfficientFormer_width['l7'],
num_vit=8,
**kwargs)
return _create_efficientformer('efficientformer_l7', pretrained=pretrained, **model_kwargs)

@ -0,0 +1,588 @@
""" Global Context ViT
From scratch implementation of GCViT in the style of timm swin_transformer_v2_cr.py
Global Context Vision Transformers - https://arxiv.org/abs/2206.09959
@article{hatamizadeh2022global,
title={Global Context Vision Transformers},
author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
journal={arXiv preprint arXiv:2206.09959},
year={2022}
}
Free of any code related to NVIDIA GCVit impl at https://github.com/NVlabs/GCVit.
The license for this code release is Apache 2.0 with no commercial restrictions.
However, weight files adapted from NVIDIA GCVit impl ARE under a non-commercial share-alike license
(https://creativecommons.org/licenses/by-nc-sa/4.0/) until I have a chance to train new ones...
Hacked together by / Copyright 2022, Ross Wightman
"""
import math
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, named_apply
from .layers import trunc_normal_tf_, DropPath, to_2tuple, Mlp, get_attn, get_act_layer, get_norm_layer, \
ClassifierHead, LayerNorm2d, _assert
from .registry import register_model
from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move to common location
__all__ = ['GlobalContextVit']
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv1', 'classifier': 'head.fc',
'fixed_input_size': True,
**kwargs
}
default_cfgs = {
'gcvit_xxtiny': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xxtiny_224_nvidia-d1d86009.pth'),
'gcvit_xtiny': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xtiny_224_nvidia-274b92b7.pth'),
'gcvit_tiny': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_tiny_224_nvidia-ac783954.pth'),
'gcvit_small': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_small_224_nvidia-4e98afa2.pth'),
'gcvit_base': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_base_224_nvidia-f009139b.pth'),
}
class MbConvBlock(nn.Module):
""" A depthwise separable / fused mbconv style residual block with SE, no norm.
"""
def __init__(
self,
in_chs,
out_chs=None,
expand_ratio=1.0,
attn_layer='se',
bias=False,
act_layer=nn.GELU,
):
super().__init__()
attn_kwargs = dict(act_layer=act_layer)
if isinstance(attn_layer, str) and attn_layer in ('se', 'eca'):
attn_kwargs['rd_ratio'] = 0.25
attn_kwargs['bias'] = False
attn_layer = get_attn(attn_layer)
out_chs = out_chs or in_chs
mid_chs = int(expand_ratio * in_chs)
self.conv_dw = nn.Conv2d(in_chs, mid_chs, 3, 1, 1, groups=in_chs, bias=bias)
self.act = act_layer()
self.se = attn_layer(mid_chs, **attn_kwargs)
self.conv_pw = nn.Conv2d(mid_chs, out_chs, 1, 1, 0, bias=bias)
def forward(self, x):
shortcut = x
x = self.conv_dw(x)
x = self.act(x)
x = self.se(x)
x = self.conv_pw(x)
x = x + shortcut
return x
class Downsample2d(nn.Module):
def __init__(
self,
dim,
dim_out=None,
reduction='conv',
act_layer=nn.GELU,
norm_layer=LayerNorm2d, # NOTE in NCHW
):
super().__init__()
dim_out = dim_out or dim
self.norm1 = norm_layer(dim) if norm_layer is not None else nn.Identity()
self.conv_block = MbConvBlock(dim, act_layer=act_layer)
assert reduction in ('conv', 'max', 'avg')
if reduction == 'conv':
self.reduction = nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False)
elif reduction == 'max':
assert dim == dim_out
self.reduction = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
else:
assert dim == dim_out
self.reduction = nn.AvgPool2d(kernel_size=2)
self.norm2 = norm_layer(dim_out) if norm_layer is not None else nn.Identity()
def forward(self, x):
x = self.norm1(x)
x = self.conv_block(x)
x = self.reduction(x)
x = self.norm2(x)
return x
class FeatureBlock(nn.Module):
def __init__(
self,
dim,
levels=0,
reduction='max',
act_layer=nn.GELU,
):
super().__init__()
reductions = levels
levels = max(1, levels)
if reduction == 'avg':
pool_fn = partial(nn.AvgPool2d, kernel_size=2)
else:
pool_fn = partial(nn.MaxPool2d, kernel_size=3, stride=2, padding=1)
self.blocks = nn.Sequential()
for i in range(levels):
self.blocks.add_module(f'conv{i+1}', MbConvBlock(dim, act_layer=act_layer))
if reductions:
self.blocks.add_module(f'pool{i+1}', pool_fn())
reductions -= 1
def forward(self, x):
return self.blocks(x)
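The `levels` / `reductions` bookkeeping in `FeatureBlock` is easy to misread: `levels=0` still yields one conv block, just without pooling. A name-only sketch of the loop (the helper `feature_block_layout` is illustrative):

```python
def feature_block_layout(levels):
    """Return the module names FeatureBlock would register for a given levels value."""
    reductions = levels
    levels = max(1, levels)  # always at least one conv block
    names = []
    for i in range(levels):
        names.append(f'conv{i + 1}')
        if reductions:
            names.append(f'pool{i + 1}')
            reductions -= 1
    return names


layout_0 = feature_block_layout(0)
layout_2 = feature_block_layout(2)
```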
class Stem(nn.Module):
def __init__(
self,
in_chs: int = 3,
out_chs: int = 96,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm2d, # NOTE stem in NCHW
):
super().__init__()
self.conv1 = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1)
self.down = Downsample2d(out_chs, act_layer=act_layer, norm_layer=norm_layer)
def forward(self, x):
x = self.conv1(x)
x = self.down(x)
return x
class WindowAttentionGlobal(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
window_size: Tuple[int, int],
use_global: bool = True,
qkv_bias: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
):
super().__init__()
window_size = to_2tuple(window_size)
self.window_size = window_size
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.use_global = use_global
self.rel_pos = RelPosBias(window_size=window_size, num_heads=num_heads)
if self.use_global:
self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)
else:
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, q_global: Optional[torch.Tensor] = None):
B, N, C = x.shape
if self.use_global and q_global is not None:
_assert(x.shape[-1] == q_global.shape[-1], 'x and q_global channel dims should be equal')
kv = self.qkv(x)
kv = kv.reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
k, v = kv.unbind(0)
q = q_global.repeat(B // q_global.shape[0], 1, 1, 1)
q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
else:
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
attn = self.rel_pos(attn)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def window_partition(x, window_size: Tuple[int, int]):
B, H, W, C = x.shape
x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
return windows
@register_notrace_function # reason: int argument is a Proxy
def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
H, W = img_size
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
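The two window helpers are inverses of each other. The sketch below shows the same index arithmetic on a single 2D grid of Python lists (no batch or channel dimension, and no tensors), so it illustrates the layout only and is not a drop-in replacement. All function names are illustrative.

```python
def window_partition_2d(grid, window_size):
    """Split an H x W grid (list of rows) into non-overlapping windows, row-major."""
    wh, ww = window_size
    h, w = len(grid), len(grid[0])
    windows = []
    for wy in range(h // wh):
        for wx in range(w // ww):
            rows = grid[wy * wh:(wy + 1) * wh]
            windows.append([row[wx * ww:(wx + 1) * ww] for row in rows])
    return windows


def window_reverse_2d(windows, window_size, img_size):
    """Reassemble row-major windows into an H x W grid."""
    wh, ww = window_size
    h, w = img_size
    cols = w // ww
    grid = [[None] * w for _ in range(h)]
    for idx, win in enumerate(windows):
        wy, wx = divmod(idx, cols)
        for dy in range(wh):
            for dx in range(ww):
                grid[wy * wh + dy][wx * ww + dx] = win[dy][dx]
    return grid


def batch_from_windows(num_windows, window_size, img_size):
    """Same batch-size recovery as in window_reverse above."""
    h, w = img_size
    return int(num_windows / (h * w / window_size[0] / window_size[1]))


grid = [[r * 4 + c for c in range(4)] for r in range(4)]
windows = window_partition_2d(grid, (2, 2))
restored = window_reverse_2d(windows, (2, 2), (4, 4))
batch = batch_from_windows(8, (2, 2), (4, 4))
```

Both directions assume the feature map size is an exact multiple of the window size; otherwise trailing rows and columns would be silently dropped in this sketch.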
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class GlobalContextVitBlock(nn.Module):
def __init__(
self,
dim: int,
feat_size: Tuple[int, int],
num_heads: int,
window_size: int = 7,
mlp_ratio: float = 4.,
use_global: bool = True,
qkv_bias: bool = True,
layer_scale: Optional[float] = None,
proj_drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
attn_layer: Callable = WindowAttentionGlobal,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
):
super().__init__()
feat_size = to_2tuple(feat_size)
window_size = to_2tuple(window_size)
self.window_size = window_size
self.num_windows = int((feat_size[0] // window_size[0]) * (feat_size[1] // window_size[1]))
self.norm1 = norm_layer(dim)
self.attn = attn_layer(
dim,
num_heads=num_heads,
window_size=window_size,
use_global=use_global,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
)
self.ls1 = LayerScale(dim, layer_scale) if layer_scale is not None else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop)
self.ls2 = LayerScale(dim, layer_scale) if layer_scale is not None else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def _window_attn(self, x, q_global: Optional[torch.Tensor] = None):
B, H, W, C = x.shape
x_win = window_partition(x, self.window_size)
x_win = x_win.view(-1, self.window_size[0] * self.window_size[1], C)
attn_win = self.attn(x_win, q_global)
x = window_reverse(attn_win, self.window_size, (H, W))
return x
def forward(self, x, q_global: Optional[torch.Tensor] = None):
x = x + self.drop_path1(self.ls1(self._window_attn(self.norm1(x), q_global)))
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
return x
class GlobalContextVitStage(nn.Module):
def __init__(
self,
dim,
depth: int,
num_heads: int,
feat_size: Tuple[int, int],
window_size: int,
downsample: bool = True,
global_norm: bool = False,
stage_norm: bool = False,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
layer_scale: Optional[float] = None,
proj_drop: float = 0.,
attn_drop: float = 0.,
drop_path: Union[List[float], float] = 0.0,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
norm_layer_cl: Callable = LayerNorm2d,
):
super().__init__()
if downsample:
self.downsample = Downsample2d(
dim=dim,
dim_out=dim * 2,
norm_layer=norm_layer,
)
dim = dim * 2
feat_size = (feat_size[0] // 2, feat_size[1] // 2)
else:
self.downsample = nn.Identity()
self.feat_size = feat_size
feat_levels = int(math.log2(min(feat_size) / window_size))
self.global_block = FeatureBlock(dim, feat_levels)
self.global_norm = norm_layer_cl(dim) if global_norm else nn.Identity()
self.blocks = nn.ModuleList([
GlobalContextVitBlock(
dim=dim,
num_heads=num_heads,
feat_size=feat_size,
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
use_global=(i % 2 != 0),
layer_scale=layer_scale,
proj_drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
act_layer=act_layer,
norm_layer=norm_layer_cl,
)
for i in range(depth)
])
self.norm = norm_layer_cl(dim) if stage_norm else nn.Identity()
self.dim = dim
self.feat_size = feat_size
self.grad_checkpointing = False
def forward(self, x):
# input NCHW, downsample & global block are 2d conv + pooling
x = self.downsample(x)
global_query = self.global_block(x)
# reshape NCHW --> NHWC for transformer blocks
x = x.permute(0, 2, 3, 1)
global_query = self.global_norm(global_query.permute(0, 2, 3, 1))
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint.checkpoint(blk, x, global_query)  # pass the global query so checkpointed blocks match the non-checkpointed path
else:
x = blk(x, global_query)
x = self.norm(x)
x = x.permute(0, 3, 1, 2).contiguous() # back to NCHW
return x
class GlobalContextVit(nn.Module):
def __init__(
self,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
img_size: Tuple[int, int] = 224,
window_size: Tuple[int, ...] = (7, 7, 14, 7),
embed_dim: int = 64,
depths: Tuple[int, ...] = (3, 4, 19, 5),
num_heads: Tuple[int, ...] = (2, 4, 8, 16),
mlp_ratio: float = 3.0,
qkv_bias: bool = True,
layer_scale: Optional[float] = None,
drop_rate: float = 0.,
proj_drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
weight_init='vit',
act_layer: str = 'gelu',
norm_layer: str = 'layernorm2d',
norm_layer_cl: str = 'layernorm',
norm_eps: float = 1e-5,
):
super().__init__()
act_layer = get_act_layer(act_layer)
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
img_size = to_2tuple(img_size)
feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4
self.global_pool = global_pool
self.num_classes = num_classes
self.drop_rate = drop_rate
num_stages = len(depths)
self.num_features = int(embed_dim * 2 ** (num_stages - 1))
self.stem = Stem(
in_chs=in_chans,
out_chs=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer
)
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
stages = []
for i in range(num_stages):
last_stage = i == num_stages - 1
stage_scale = 2 ** max(i - 1, 0)
stages.append(GlobalContextVitStage(
dim=embed_dim * stage_scale,
depth=depths[i],
num_heads=num_heads[i],
feat_size=(feat_size[0] // stage_scale, feat_size[1] // stage_scale),
window_size=window_size[i],
downsample=i != 0,
stage_norm=last_stage,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
layer_scale=layer_scale,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
act_layer=act_layer,
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl,
))
self.stages = nn.Sequential(*stages)
# Classifier head
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
if weight_init:
named_apply(partial(self._init_weights, scheme=weight_init), self)
def _init_weights(self, module, name, scheme='vit'):
# note Conv2d left as default init
if scheme == 'vit':
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
if 'mlp' in name:
nn.init.normal_(module.bias, std=1e-6)
else:
nn.init.zeros_(module.bias)
else:
if isinstance(module, nn.Linear):
trunc_normal_tf_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
@torch.jit.ignore
def no_weight_decay(self):
return {
k for k, _ in self.named_parameters()
if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem', # stem and embed
blocks=r'^stages\.(\d+)'
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for s in self.stages:
s.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is None:
global_pool = self.head.global_pool.pool_type
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.stem(x)
x = self.stages(x)
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.forward_features(x)
x = self.forward_head(x)
return x
def _create_gcvit(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(GlobalContextVit, variant, pretrained, **kwargs)
return model
@register_model
def gcvit_xxtiny(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(2, 2, 6, 2),
num_heads=(2, 4, 8, 16),
**kwargs)
return _create_gcvit('gcvit_xxtiny', pretrained=pretrained, **model_kwargs)
@register_model
def gcvit_xtiny(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(3, 4, 6, 5),
num_heads=(2, 4, 8, 16),
**kwargs)
return _create_gcvit('gcvit_xtiny', pretrained=pretrained, **model_kwargs)
@register_model
def gcvit_tiny(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(3, 4, 19, 5),
num_heads=(2, 4, 8, 16),
**kwargs)
return _create_gcvit('gcvit_tiny', pretrained=pretrained, **model_kwargs)
@register_model
def gcvit_small(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(3, 4, 19, 5),
num_heads=(3, 6, 12, 24),
window_size=(7, 7, 14, 7),
embed_dim=96,
mlp_ratio=2,
layer_scale=1e-5,
**kwargs)
return _create_gcvit('gcvit_small', pretrained=pretrained, **model_kwargs)
@register_model
def gcvit_base(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(3, 4, 19, 5),
num_heads=(4, 8, 16, 32),
window_size=(7, 7, 14, 7),
embed_dim=128,
mlp_ratio=2,
layer_scale=1e-5,
**kwargs)
return _create_gcvit('gcvit_base', pretrained=pretrained, **model_kwargs)
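The stage arithmetic in `GlobalContextVit.__init__` and `GlobalContextVitStage.__init__` is easier to follow with numbers plugged in. The sketch below is a plain-Python trace of that arithmetic only, assuming square inputs and the default `img_size=224`, `embed_dim=64`, `window_size=(7, 7, 14, 7)`; the helper name `gcvit_stage_geometry` is hypothetical and not part of `timm`.

```python
import math


def gcvit_stage_geometry(img_size=224, embed_dim=64, window_size=(7, 7, 14, 7)):
    """Return (dim, feat_size, num_windows, feat_levels) for each stage."""
    feat = img_size // 4  # the stem reduces resolution by 4
    rows = []
    for i, win in enumerate(window_size):
        stage_scale = 2 ** max(i - 1, 0)
        dim = embed_dim * stage_scale   # dim handed to the stage constructor
        size = feat // stage_scale      # feat_size handed to the stage constructor
        if i != 0:
            # every stage except the first downsamples on entry:
            # channels double, spatial size halves
            dim, size = dim * 2, size // 2
        num_windows = (size // win) ** 2
        feat_levels = int(math.log2(size / win))
        rows.append((dim, size, num_windows, feat_levels))
    return rows


for row in gcvit_stage_geometry():
    print(row)
```

The final stage ends at 512 channels, which agrees with `num_features = embed_dim * 2 ** (num_stages - 1)`.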

@ -11,21 +11,23 @@ from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import get_attn, create_attn
from .create_conv2d import create_conv2d
from .create_norm import get_norm_layer, create_norm_layer
from .create_norm_act import get_norm_act_layer, create_norm_act_layer
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm2d
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed

@ -145,4 +145,10 @@ def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):
act_layer = get_act_layer(name)
if act_layer is None:
return None
return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs)
if inplace is None:
return act_layer(**kwargs)
try:
return act_layer(inplace=inplace, **kwargs)
except TypeError:
# recover if act layer doesn't have inplace arg
return act_layer(**kwargs)
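The `try` / `except TypeError` fallback added to `create_act_layer` can be exercised without PyTorch. The sketch below uses two stand-in classes (`FakeReLU`, `FakeGELU`) and a hypothetical `build_act` helper that mirrors only the control flow of the hunk above; none of these names exist in `timm`.

```python
class FakeReLU:
    """Stand-in for an activation that accepts an `inplace` argument."""
    def __init__(self, inplace=False):
        self.inplace = inplace


class FakeGELU:
    """Stand-in for an activation whose constructor has no `inplace` argument."""
    def __init__(self):
        self.inplace = None


def build_act(act_cls, inplace=None, **kwargs):
    if inplace is None:
        return act_cls(**kwargs)
    try:
        return act_cls(inplace=inplace, **kwargs)
    except TypeError:
        # constructor rejected `inplace`, retry without it
        return act_cls(**kwargs)


relu = build_act(FakeReLU, inplace=True)
gelu = build_act(FakeGELU, inplace=True)
print(relu.inplace, gelu.inplace)
```

One caveat of this pattern: a `TypeError` raised inside the constructor for an unrelated reason is also caught, and the retry then runs without `inplace`.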

@ -0,0 +1,56 @@
""" Norm Layer Factory
Create norm modules by string (to mirror create_act and create_norm_act fns)
Copyright 2022 Ross Wightman
"""
import types
import functools
import torch.nn as nn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
_NORM_MAP = dict(
batchnorm=nn.BatchNorm2d,
batchnorm2d=nn.BatchNorm2d,
batchnorm1d=nn.BatchNorm1d,
groupnorm=GroupNorm,
groupnorm1=GroupNorm1,
layernorm=LayerNorm,
layernorm2d=LayerNorm2d,
)
_NORM_TYPES = {m for n, m in _NORM_MAP.items()}
def create_norm_layer(layer_name, num_features, **kwargs):
    # get_norm_layer() only takes the layer spec, and plain norm layers have no act_layer / apply_act args
    layer = get_norm_layer(layer_name)
    layer_instance = layer(num_features, **kwargs)
    return layer_instance
def get_norm_layer(norm_layer):
    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
    norm_kwargs = {}
    # unbind partial fn, so args can be rebound later
    if isinstance(norm_layer, functools.partial):
        norm_kwargs.update(norm_layer.keywords)
        norm_layer = norm_layer.func
    if isinstance(norm_layer, str):
        layer_name = norm_layer.replace('_', '')
        norm_layer = _NORM_MAP.get(layer_name, None)
    elif norm_layer in _NORM_TYPES:
        norm_layer = norm_layer
    elif isinstance(norm_layer, types.FunctionType):
        # if function type, assume it is a lambda/fn that creates a norm layer
        norm_layer = norm_layer
    else:
        type_name = norm_layer.__name__.lower().replace('_', '')
        norm_layer = _NORM_MAP.get(type_name, None)
        assert norm_layer is not None, f"No equivalent norm layer for {type_name}"
    if norm_kwargs:
        norm_layer = functools.partial(norm_layer, **norm_kwargs)  # bind/rebind args
    return norm_layer
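The unbind/rebind handling of `functools.partial` in `get_norm_layer` can be illustrated in isolation. This is a minimal sketch, assuming a toy class `ToyNorm`, a one-entry `NORM_MAP`, and a hypothetical `resolve` function that keeps only the partial and string branches; it is not the `timm` factory itself.

```python
import functools


class ToyNorm:
    """Stand-in for a norm layer class with a keyword argument."""
    def __init__(self, num_features, eps=1e-5):
        self.num_features = num_features
        self.eps = eps


NORM_MAP = {'toynorm': ToyNorm}


def resolve(norm_layer):
    kwargs = {}
    if isinstance(norm_layer, functools.partial):
        # unbind: remember the bound keywords, continue with the bare callable
        kwargs.update(norm_layer.keywords)
        norm_layer = norm_layer.func
    if isinstance(norm_layer, str):
        # string lookup ignores underscores, so 'toy_norm' matches 'toynorm'
        norm_layer = NORM_MAP[norm_layer.replace('_', '')]
    if kwargs:
        # rebind the remembered keywords onto the resolved callable
        norm_layer = functools.partial(norm_layer, **kwargs)
    return norm_layer


layer = resolve(functools.partial(ToyNorm, eps=1e-6))
print(layer(8).eps)
print(resolve('toy_norm') is ToyNorm)
```

Keeping the keywords separate from the callable is what lets the lookup compare the bare class against the known types before the arguments are bound again.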

@ -18,6 +18,7 @@ _NORM_ACT_MAP = dict(
batchnorm=BatchNormAct2d,
batchnorm2d=BatchNormAct2d,
groupnorm=GroupNormAct,
groupnorm1=functools.partial(GroupNormAct, num_groups=1),
layernorm=LayerNormAct,
layernorm2d=LayerNormAct2d,
evonormb0=EvoNorm2dB0,
@ -72,6 +73,8 @@ def get_norm_act_layer(norm_layer, act_layer=None):
norm_act_layer = BatchNormAct2d
elif type_name.startswith('groupnorm1'):
# must be tested before the more general 'groupnorm' prefix, otherwise this branch is unreachable
norm_act_layer = functools.partial(GroupNormAct, num_groups=1)
elif type_name.startswith('groupnorm'):
norm_act_layer = GroupNormAct
elif type_name.startswith('layernorm2d'):
norm_act_layer = LayerNormAct2d
elif type_name.startswith('layernorm'):

@ -0,0 +1,78 @@
""" 'Fast' Normalization Functions
For GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32.
Additionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast)
Hacked together by / Copyright 2022 Ross Wightman
"""
from typing import List, Optional
import torch
from torch.nn import functional as F
try:
from apex.normalization.fused_layer_norm import fused_layer_norm_affine
has_apex = True
except ImportError:
has_apex = False
# fast (ie lower precision LN) can be disabled with this flag if issues crop up
_USE_FAST_NORM = False # defaulting to False for now
def is_fast_norm():
return _USE_FAST_NORM
def set_fast_norm(enable=True):
global _USE_FAST_NORM
_USE_FAST_NORM = enable
def fast_group_norm(
        x: torch.Tensor,
        num_groups: int,
        weight: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        eps: float = 1e-5
) -> torch.Tensor:
    if torch.jit.is_scripting():
        # currently cannot use is_autocast_enabled within torchscript
        return F.group_norm(x, num_groups, weight, bias, eps)
    if torch.is_autocast_enabled():
        # normally native AMP casts GN inputs to float32
        # here we use the low precision autocast dtype
        # FIXME what to do re CPU autocast?
        dt = torch.get_autocast_gpu_dtype()
        x = x.to(dt)
        # weight and bias are Optional, only cast them when present
        weight = weight.to(dt) if weight is not None else None
        bias = bias.to(dt) if bias is not None else None
    with torch.cuda.amp.autocast(enabled=False):
        return F.group_norm(x, num_groups, weight, bias, eps)
def fast_layer_norm(
        x: torch.Tensor,
        normalized_shape: List[int],
        weight: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        eps: float = 1e-5
) -> torch.Tensor:
    if torch.jit.is_scripting():
        # currently cannot use is_autocast_enabled within torchscript
        return F.layer_norm(x, normalized_shape, weight, bias, eps)
    if has_apex:
        return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)
    if torch.is_autocast_enabled():
        # normally native AMP casts LN inputs to float32
        # apex LN does not, this is behaving like Apex
        dt = torch.get_autocast_gpu_dtype()
        # FIXME what to do re CPU autocast?
        x = x.to(dt)
        # weight and bias are Optional, only cast them when present
        weight = weight.to(dt) if weight is not None else None
        bias = bias.to(dt) if bias is not None else None
    with torch.cuda.amp.autocast(enabled=False):
        return F.layer_norm(x, normalized_shape, weight, bias, eps)

@ -29,3 +29,15 @@ def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
if new_v < round_limit * v:
new_v += divisor
return new_v
def extend_tuple(x, n):
    # pads a tuple to specified n by padding with last value
    if not isinstance(x, (tuple, list)):
        x = (x,)
    else:
        x = tuple(x)
    pad_n = n - len(x)
    if pad_n <= 0:
        return x[:n]
    return x + (x[-1],) * pad_n
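A short usage sketch for `extend_tuple`, with the function restated so the snippet runs standalone. It covers the three cases the code distinguishes: a scalar, a sequence shorter than `n`, and a sequence longer than `n`.

```python
def extend_tuple(x, n):
    # pads a tuple to length n by repeating the last value, truncates if longer
    if not isinstance(x, (tuple, list)):
        x = (x,)
    else:
        x = tuple(x)
    pad_n = n - len(x)
    if pad_n <= 0:
        return x[:n]
    return x + (x[-1],) * pad_n


print(extend_tuple(3, 4))          # (3, 3, 3, 3)
print(extend_tuple((1, 2), 4))     # (1, 2, 2, 2)
print(extend_tuple([1, 2, 3], 2))  # (1, 2)
```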

@ -1,17 +1,28 @@
""" Normalization layers and wrappers
Norm layer definitions that support fast norm and consistent channel arg order (always first arg).
Hacked together by / Copyright 2022 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
class GroupNorm(nn.GroupNorm):
def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
# NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
super().__init__(num_groups, num_channels, eps=eps, affine=affine)
self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
def forward(self, x):
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
if self.fast_norm:
return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
else:
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
class GroupNorm1(nn.GroupNorm):
@ -21,22 +32,48 @@ class GroupNorm1(nn.GroupNorm):
def __init__(self, num_channels, **kwargs):
super().__init__(1, num_channels, **kwargs)
self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.fast_norm:
return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
else:
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
class LayerNorm(nn.LayerNorm):
""" LayerNorm w/ fast norm option
"""
def __init__(self, num_channels, eps=1e-6, affine=True):
super().__init__(num_channels, eps=eps, elementwise_affine=affine)
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self._fast_norm:
x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
else:
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
return x
class LayerNorm2d(nn.LayerNorm):
""" LayerNorm for channels of '2D' spatial NCHW tensors """
def __init__(self, num_channels, eps=1e-6, affine=True):
super().__init__(num_channels, eps=eps, elementwise_affine=affine)
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.layer_norm(
x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
x = x.permute(0, 2, 3, 1)
if self._fast_norm:
x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
else:
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
x = x.permute(0, 3, 1, 2)
return x
def _is_contiguous(tensor: torch.Tensor) -> bool:
# jit is oh so lovely :/
# if torch.jit.is_tracing():
# return True
if torch.jit.is_scripting():
return tensor.is_contiguous()
else:
@ -51,6 +88,14 @@ def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, ep
return x
def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
u = x.mean(dim=1, keepdim=True)
s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0)
x = (x - u) * torch.rsqrt(s + eps)
x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)
return x
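`_layer_norm_cf_sqm` computes the variance as E[x^2] - E[x]^2 and clamps it at zero, rather than using the two-pass form E[(x - u)^2]. The plain-Python sketch below compares the two on a single vector, leaving out the affine `weight` / `bias` step; the function names are hypothetical and scalar lists stand in for the channel dimension.

```python
import math


def layer_norm_sqm(values, eps=1e-6):
    """Normalize using var = E[x^2] - E[x]^2, clamped at 0 as in the code above."""
    n = len(values)
    u = sum(values) / n
    s = max(sum(v * v for v in values) / n - u * u, 0.0)
    return [(v - u) / math.sqrt(s + eps) for v in values]


def layer_norm_two_pass(values, eps=1e-6):
    """Normalize using var = E[(x - u)^2]."""
    n = len(values)
    u = sum(values) / n
    s = sum((v - u) ** 2 for v in values) / n
    return [(v - u) / math.sqrt(s + eps) for v in values]


x = [1.0, 2.0, 3.0, 4.0]
print(layer_norm_sqm(x))
print(layer_norm_two_pass(x))
```

The clamp matters because the subtraction can produce a slightly negative value from rounding when the inputs are nearly constant, and a negative value under the square root would give NaN.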
class LayerNormExp2d(nn.LayerNorm):
""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).

@ -1,4 +1,16 @@
""" Normalization + Activation Layers
Provides Norm+Act fns for standard PyTorch norm layers such as
* BatchNorm
* GroupNorm
* LayerNorm
This allows swapping with alternative layers that are natively both norm + act such as
* EvoNorm (evo_norm.py)
* FilterResponseNorm (filter_response_norm.py)
* InplaceABN (inplace_abn.py)
Hacked together by / Copyright 2022 Ross Wightman
"""
from typing import Union, List, Optional, Any
@ -6,8 +18,9 @@ import torch
from torch import nn as nn
from torch.nn import functional as F
from .trace_utils import _assert
from .create_act import get_act_layer
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
from .trace_utils import _assert
class BatchNormAct2d(nn.BatchNorm2d):
@ -177,9 +190,13 @@ class GroupNormAct(nn.GroupNorm):
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
self._fast_norm = is_fast_norm()
def forward(self, x):
x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
if self._fast_norm:
x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
else:
x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
x = self.drop(x)
x = self.act(x)
return x
@ -197,9 +214,13 @@ class LayerNormAct(nn.LayerNorm):
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
self._fast_norm = is_fast_norm()
def forward(self, x):
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
if self._fast_norm:
x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
else:
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
x = self.drop(x)
x = self.act(x)
return x
@ -217,10 +238,15 @@ class LayerNormAct2d(nn.LayerNorm):
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
self._fast_norm = is_fast_norm()
def forward(self, x):
x = F.layer_norm(
x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
x = x.permute(0, 2, 3, 1)
if self._fast_norm:
x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
else:
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
x = x.permute(0, 3, 1, 2)
x = self.drop(x)
x = self.act(x)
return x

@ -27,15 +27,15 @@ class SEModule(nn.Module):
"""
def __init__(
self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False,
act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
bias=True, act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
super(SEModule, self).__init__()
self.add_maxpool = add_maxpool
if not rd_channels:
rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True)
self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias)
self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity()
self.act = create_act_layer(act_layer, inplace=True)
self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True)
self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias)
self.gate = create_act_layer(gate_layer)
def forward(self, x):

@ -5,7 +5,7 @@ import warnings
from torch.nn.init import _calculate_fan_in_and_fan_out
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
@ -17,28 +17,27 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
@ -64,7 +63,8 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
with torch.no_grad():
return _trunc_normal_(tensor, mean, std, a, b)
def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
@ -90,8 +90,8 @@ def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
_no_grad_trunc_normal_(tensor, 0, 1.0, a, b)
with torch.no_grad():
_trunc_normal_(tensor, 0, 1.0, a, b)
tensor.mul_(std).add_(mean)
return tensor
@ -111,10 +111,12 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978)
elif distribution == "normal":
tensor.normal_(std=math.sqrt(variance))
with torch.no_grad():
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
tensor.uniform_(-bound, bound)
with torch.no_grad():
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
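The constant `.87962566103423978` in `variance_scaling_` is described in the comment as the stddev of a standard normal truncated to (-2, 2). It can be checked with the closed form for a symmetric truncation, Var = 1 - 2a·φ(a) / (2Φ(a) - 1). The sketch below uses only the standard library; `truncated_std` is a hypothetical helper name.

```python
import math


def truncated_std(a=2.0):
    """Std of a standard normal truncated to (-a, a)."""
    pdf = math.exp(-0.5 * a * a) / math.sqrt(2.0 * math.pi)  # phi(a)
    mass = math.erf(a / math.sqrt(2.0))                      # P(-a < Z < a) = 2*Phi(a) - 1
    variance = 1.0 - 2.0 * a * pdf / mass
    return math.sqrt(variance)


print(truncated_std(2.0))
```

Dividing the target std by this constant compensates for the variance lost to truncation, so the sampled weights end up with the requested variance.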

File diff suppressed because it is too large

@ -0,0 +1,998 @@
""" Multi-Scale Vision Transformer v2
@inproceedings{li2021improved,
title={MViTv2: Improved multiscale vision transformers for classification and detection},
author={Li, Yanghao and Wu, Chao-Yuan and Fan, Haoqi and Mangalam, Karttikeya and Xiong, Bo and Malik, Jitendra and Feichtenhofer, Christoph},
booktitle={CVPR},
year={2022}
}
Code adapted from original Apache 2.0 licensed impl at https://github.com/facebookresearch/mvit
Original copyright below.
Modifications and timm support by / Copyright 2022, Ross Wightman
"""
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
import operator
from collections import OrderedDict
from dataclasses import dataclass
from functools import partial, reduce
from typing import Union, List, Tuple, Optional
import torch
import torch.utils.checkpoint as checkpoint
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg
from .layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
from .registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
'fixed_input_size': True,
**kwargs
}
default_cfgs = dict(
mvitv2_tiny=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_T_in1k.pyth'),
mvitv2_small=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_S_in1k.pyth'),
mvitv2_base=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in1k.pyth'),
mvitv2_large=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in1k.pyth'),
mvitv2_base_in21k=_cfg(
url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in21k.pyth',
num_classes=19168),
mvitv2_large_in21k=_cfg(
url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in21k.pyth',
num_classes=19168),
mvitv2_huge_in21k=_cfg(
url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_H_in21k.pyth',
num_classes=19168),
)
@dataclass
class MultiScaleVitCfg:
depths: Tuple[int, ...] = (2, 3, 16, 3)
embed_dim: Union[int, Tuple[int, ...]] = 96
num_heads: Union[int, Tuple[int, ...]] = 1
mlp_ratio: float = 4.
pool_first: bool = False
expand_attn: bool = True
qkv_bias: bool = True
use_cls_token: bool = False
use_abs_pos: bool = False
residual_pooling: bool = True
mode: str = 'conv'
kernel_qkv: Tuple[int, int] = (3, 3)
stride_q: Optional[Tuple[Tuple[int, int]]] = ((1, 1), (2, 2), (2, 2), (2, 2))
stride_kv: Optional[Tuple[Tuple[int, int]]] = None
stride_kv_adaptive: Optional[Tuple[int, int]] = (4, 4)
patch_kernel: Tuple[int, int] = (7, 7)
patch_stride: Tuple[int, int] = (4, 4)
patch_padding: Tuple[int, int] = (3, 3)
pool_type: str = 'max'
rel_pos_type: str = 'spatial'
act_layer: Union[str, Tuple[str, str]] = 'gelu'
norm_layer: Union[str, Tuple[str, str]] = 'layernorm'
norm_eps: float = 1e-6
def __post_init__(self):
num_stages = len(self.depths)
if not isinstance(self.embed_dim, (tuple, list)):
self.embed_dim = tuple(self.embed_dim * 2 ** i for i in range(num_stages))
assert len(self.embed_dim) == num_stages
if not isinstance(self.num_heads, (tuple, list)):
self.num_heads = tuple(self.num_heads * 2 ** i for i in range(num_stages))
assert len(self.num_heads) == num_stages
if self.stride_kv_adaptive is not None and self.stride_kv is None:
_stride_kv = self.stride_kv_adaptive
pool_kv_stride = []
for i in range(num_stages):
if min(self.stride_q[i]) > 1:
_stride_kv = [
max(_stride_kv[d] // self.stride_q[i][d], 1)
for d in range(len(_stride_kv))
]
pool_kv_stride.append(tuple(_stride_kv))
self.stride_kv = tuple(pool_kv_stride)
model_cfgs = dict(
mvitv2_tiny=MultiScaleVitCfg(
depths=(1, 2, 5, 2),
),
mvitv2_small=MultiScaleVitCfg(
depths=(1, 2, 11, 2),
),
mvitv2_base=MultiScaleVitCfg(
depths=(2, 3, 16, 3),
),
mvitv2_large=MultiScaleVitCfg(
depths=(2, 6, 36, 4),
embed_dim=144,
num_heads=2,
expand_attn=False,
),
mvitv2_base_in21k=MultiScaleVitCfg(
depths=(2, 3, 16, 3),
),
mvitv2_large_in21k=MultiScaleVitCfg(
depths=(2, 6, 36, 4),
embed_dim=144,
num_heads=2,
expand_attn=False,
),
)
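The adaptive key/value stride derived in `MultiScaleVitCfg.__post_init__` can be traced by hand. The sketch below reproduces that loop as a standalone function, assuming the `append` sits at loop level (the indentation is not visible in this view, so that reading is an assumption); `adaptive_stride_kv` is a hypothetical name.

```python
def adaptive_stride_kv(stride_q, stride_kv_adaptive=(4, 4)):
    """Shrink the kv pooling stride each time a stage pools q with stride > 1."""
    stride = list(stride_kv_adaptive)
    out = []
    for sq in stride_q:
        if min(sq) > 1:
            # q resolution drops at this stage, so kv needs less pooling
            stride = [max(s // q, 1) for s, q in zip(stride, sq)]
        out.append(tuple(stride))
    return tuple(out)


print(adaptive_stride_kv(((1, 1), (2, 2), (2, 2), (2, 2))))
# ((4, 4), (2, 2), (1, 1), (1, 1))
```

With the default `stride_q`, the kv stride starts at 4 and halves per stage until it bottoms out at 1, so the pooled key/value resolution stays roughly constant across stages.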
def prod(iterable):
return reduce(operator.mul, iterable, 1)
class PatchEmbed(nn.Module):
"""
PatchEmbed.
"""
def __init__(
self,
dim_in=3,
dim_out=768,
kernel=(7, 7),
stride=(4, 4),
padding=(3, 3),
):
super().__init__()
self.proj = nn.Conv2d(
dim_in,
dim_out,
kernel_size=kernel,
stride=stride,
padding=padding,
)
def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
x = self.proj(x)
# B C H W -> B HW C
return x.flatten(2).transpose(1, 2), x.shape[-2:]
@register_notrace_function
def reshape_pre_pool(
x,
feat_size: List[int],
has_cls_token: bool = True
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
H, W = feat_size
if has_cls_token:
cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :]
else:
cls_tok = None
x = x.reshape(-1, H, W, x.shape[-1]).permute(0, 3, 1, 2).contiguous()
return x, cls_tok
@register_notrace_function
def reshape_post_pool(
x,
num_heads: int,
cls_tok: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, List[int]]:
feat_size = [x.shape[2], x.shape[3]]
L_pooled = x.shape[2] * x.shape[3]
x = x.reshape(-1, num_heads, x.shape[1], L_pooled).transpose(2, 3)
if cls_tok is not None:
x = torch.cat((cls_tok, x), dim=2)
return x, feat_size
@register_notrace_function
def cal_rel_pos_type(
attn: torch.Tensor,
q: torch.Tensor,
has_cls_token: bool,
q_size: List[int],
k_size: List[int],
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
):
"""
Spatial Relative Positional Embeddings.
"""
sp_idx = 1 if has_cls_token else 0
q_h, q_w = q_size
k_h, k_w = k_size
# Scale up rel pos if shapes for q and k are different.
q_h_ratio = max(k_h / q_h, 1.0)
k_h_ratio = max(q_h / k_h, 1.0)
dist_h = torch.arange(q_h)[:, None] * q_h_ratio - torch.arange(k_h)[None, :] * k_h_ratio
dist_h += (k_h - 1) * k_h_ratio
q_w_ratio = max(k_w / q_w, 1.0)
k_w_ratio = max(q_w / k_w, 1.0)
dist_w = torch.arange(q_w)[:, None] * q_w_ratio - torch.arange(k_w)[None, :] * k_w_ratio
dist_w += (k_w - 1) * k_w_ratio
Rh = rel_pos_h[dist_h.long()]
Rw = rel_pos_w[dist_w.long()]
B, n_head, q_N, dim = q.shape
r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim)
rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, Rh)
rel_w = torch.einsum("byhwc,wkc->byhwk", r_q, Rw)
attn[:, :, sp_idx:, sp_idx:] = (
attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w)
+ rel_h[:, :, :, :, :, None]
+ rel_w[:, :, :, :, None, :]
).view(B, -1, q_h * q_w, k_h * k_w)
return attn
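The index arithmetic that `cal_rel_pos_type` uses to look up `rel_pos_h` and `rel_pos_w` is shown below for one axis, in plain Python. `rel_pos_index` is a hypothetical helper; `int()` stands in for `.long()`, which is equivalent here because all values are non-negative.

```python
def rel_pos_index(q_len, k_len):
    """Row index into the relative position table for each (query, key) pair."""
    # scale whichever axis is shorter so both span the same range
    q_ratio = max(k_len / q_len, 1.0)
    k_ratio = max(q_len / k_len, 1.0)
    offset = (k_len - 1) * k_ratio  # shifts the smallest distance to 0
    return [
        [int(i * q_ratio - j * k_ratio + offset) for j in range(k_len)]
        for i in range(q_len)
    ]


print(rel_pos_index(4, 2))  # [[2, 0], [3, 1], [4, 2], [5, 3]]
print(rel_pos_index(2, 2))  # [[1, 0], [2, 1]]
```

When keys are pooled more aggressively than queries (`k_len < q_len`), each key step counts as several query steps, which is why the indices in a row differ by `k_ratio` rather than 1.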
class MultiScaleAttentionPoolFirst(nn.Module):
def __init__(
self,
dim,
dim_out,
feat_size,
num_heads=8,
qkv_bias=True,
mode="conv",
kernel_q=(1, 1),
kernel_kv=(1, 1),
stride_q=(1, 1),
stride_kv=(1, 1),
has_cls_token=True,
rel_pos_type='spatial',
residual_pooling=True,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.num_heads = num_heads
self.dim_out = dim_out
self.head_dim = dim_out // num_heads
self.scale = self.head_dim ** -0.5
self.has_cls_token = has_cls_token
padding_q = tuple([int(q // 2) for q in kernel_q])
padding_kv = tuple([int(kv // 2) for kv in kernel_kv])
self.q = nn.Linear(dim, dim_out, bias=qkv_bias)
self.k = nn.Linear(dim, dim_out, bias=qkv_bias)
self.v = nn.Linear(dim, dim_out, bias=qkv_bias)
self.proj = nn.Linear(dim_out, dim_out)
# Skip pooling when both kernel and stride are (1, 1).
if prod(kernel_q) == 1 and prod(stride_q) == 1:
kernel_q = None
if prod(kernel_kv) == 1 and prod(stride_kv) == 1:
kernel_kv = None
self.mode = mode
self.unshared = mode == 'conv_unshared'
self.pool_q, self.pool_k, self.pool_v = None, None, None
self.norm_q, self.norm_k, self.norm_v = None, None, None
if mode in ("avg", "max"):
pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d
if kernel_q:
self.pool_q = pool_op(kernel_q, stride_q, padding_q)
if kernel_kv:
self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv)
self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv)
elif mode == "conv" or mode == "conv_unshared":
dim_conv = dim // num_heads if mode == "conv" else dim
if kernel_q:
self.pool_q = nn.Conv2d(
dim_conv,
dim_conv,
kernel_q,
stride=stride_q,
padding=padding_q,
groups=dim_conv,
bias=False,
)
self.norm_q = norm_layer(dim_conv)
if kernel_kv:
self.pool_k = nn.Conv2d(
dim_conv,
dim_conv,
kernel_kv,
stride=stride_kv,
padding=padding_kv,
groups=dim_conv,
bias=False,
)
self.norm_k = norm_layer(dim_conv)
self.pool_v = nn.Conv2d(
dim_conv,
dim_conv,
kernel_kv,
stride=stride_kv,
padding=padding_kv,
groups=dim_conv,
bias=False,
)
self.norm_v = norm_layer(dim_conv)
else:
raise NotImplementedError(f"Unsupported mode {mode}")
# relative pos embedding
self.rel_pos_type = rel_pos_type
if self.rel_pos_type == 'spatial':
assert feat_size[0] == feat_size[1]
size = feat_size[0]
q_size = size // stride_q[1] if len(stride_q) > 0 else size
kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
rel_sp_dim = 2 * max(q_size, kv_size) - 1
self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
trunc_normal_tf_(self.rel_pos_h, std=0.02)
trunc_normal_tf_(self.rel_pos_w, std=0.02)
self.residual_pooling = residual_pooling
def forward(self, x, feat_size: List[int]):
B, N, _ = x.shape
fold_dim = 1 if self.unshared else self.num_heads
x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
q = k = v = x
if self.pool_q is not None:
q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token)
q = self.pool_q(q)
q, q_size = reshape_post_pool(q, self.num_heads, q_tok)
else:
q_size = feat_size
if self.norm_q is not None:
q = self.norm_q(q)
if self.pool_k is not None:
k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token)
k = self.pool_k(k)
k, k_size = reshape_post_pool(k, self.num_heads, k_tok)
else:
k_size = feat_size
if self.norm_k is not None:
k = self.norm_k(k)
if self.pool_v is not None:
v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token)
v = self.pool_v(v)
v, v_size = reshape_post_pool(v, self.num_heads, v_tok)
else:
v_size = feat_size
if self.norm_v is not None:
v = self.norm_v(v)
q_N = q_size[0] * q_size[1] + int(self.has_cls_token)
q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1)
q = self.q(q).reshape(B, q_N, self.num_heads, -1).permute(0, 2, 1, 3)
k_N = k_size[0] * k_size[1] + int(self.has_cls_token)
k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1)
k = self.k(k).reshape(B, k_N, self.num_heads, -1).permute(0, 2, 1, 3)
v_N = v_size[0] * v_size[1] + int(self.has_cls_token)
v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1)
v = self.v(v).reshape(B, v_N, self.num_heads, -1).permute(0, 2, 1, 3)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.rel_pos_type == 'spatial':
attn = cal_rel_pos_type(
attn,
q,
self.has_cls_token,
q_size,
k_size,
self.rel_pos_h,
self.rel_pos_w,
)
attn = attn.softmax(dim=-1)
x = attn @ v
if self.residual_pooling:
x = x + q
x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
x = self.proj(x)
return x, q_size
class MultiScaleAttention(nn.Module):
def __init__(
self,
dim,
dim_out,
feat_size,
num_heads=8,
qkv_bias=True,
mode="conv",
kernel_q=(1, 1),
kernel_kv=(1, 1),
stride_q=(1, 1),
stride_kv=(1, 1),
has_cls_token=True,
rel_pos_type='spatial',
residual_pooling=True,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.num_heads = num_heads
self.dim_out = dim_out
self.head_dim = dim_out // num_heads
self.scale = self.head_dim ** -0.5
self.has_cls_token = has_cls_token
padding_q = tuple([int(q // 2) for q in kernel_q])
padding_kv = tuple([int(kv // 2) for kv in kernel_kv])
self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias)
self.proj = nn.Linear(dim_out, dim_out)
# Skip pooling when both kernel and stride are (1, 1).
if prod(kernel_q) == 1 and prod(stride_q) == 1:
kernel_q = None
if prod(kernel_kv) == 1 and prod(stride_kv) == 1:
kernel_kv = None
self.mode = mode
self.unshared = mode == 'conv_unshared'
self.norm_q, self.norm_k, self.norm_v = None, None, None
self.pool_q, self.pool_k, self.pool_v = None, None, None
if mode in ("avg", "max"):
pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d
if kernel_q:
self.pool_q = pool_op(kernel_q, stride_q, padding_q)
if kernel_kv:
self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv)
self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv)
elif mode == "conv" or mode == "conv_unshared":
dim_conv = dim_out // num_heads if mode == "conv" else dim_out
if kernel_q:
self.pool_q = nn.Conv2d(
dim_conv,
dim_conv,
kernel_q,
stride=stride_q,
padding=padding_q,
groups=dim_conv,
bias=False,
)
self.norm_q = norm_layer(dim_conv)
if kernel_kv:
self.pool_k = nn.Conv2d(
dim_conv,
dim_conv,
kernel_kv,
stride=stride_kv,
padding=padding_kv,
groups=dim_conv,
bias=False,
)
self.norm_k = norm_layer(dim_conv)
self.pool_v = nn.Conv2d(
dim_conv,
dim_conv,
kernel_kv,
stride=stride_kv,
padding=padding_kv,
groups=dim_conv,
bias=False,
)
self.norm_v = norm_layer(dim_conv)
else:
raise NotImplementedError(f"Unsupported mode {mode}")
# relative pos embedding
self.rel_pos_type = rel_pos_type
if self.rel_pos_type == 'spatial':
assert feat_size[0] == feat_size[1]
size = feat_size[0]
q_size = size // stride_q[1] if len(stride_q) > 0 else size
kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
rel_sp_dim = 2 * max(q_size, kv_size) - 1
self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
trunc_normal_tf_(self.rel_pos_h, std=0.02)
trunc_normal_tf_(self.rel_pos_w, std=0.02)
self.residual_pooling = residual_pooling
def forward(self, x, feat_size: List[int]):
B, N, _ = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(dim=0)
if self.pool_q is not None:
q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token)
q = self.pool_q(q)
q, q_size = reshape_post_pool(q, self.num_heads, q_tok)
else:
q_size = feat_size
if self.norm_q is not None:
q = self.norm_q(q)
if self.pool_k is not None:
k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token)
k = self.pool_k(k)
k, k_size = reshape_post_pool(k, self.num_heads, k_tok)
else:
k_size = feat_size
if self.norm_k is not None:
k = self.norm_k(k)
if self.pool_v is not None:
v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token)
v = self.pool_v(v)
v, _ = reshape_post_pool(v, self.num_heads, v_tok)
if self.norm_v is not None:
v = self.norm_v(v)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.rel_pos_type == 'spatial':
attn = cal_rel_pos_type(
attn,
q,
self.has_cls_token,
q_size,
k_size,
self.rel_pos_h,
self.rel_pos_w,
)
attn = attn.softmax(dim=-1)
x = attn @ v
if self.residual_pooling:
x = x + q
x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
x = self.proj(x)
return x, q_size
class MultiScaleBlock(nn.Module):
def __init__(
self,
dim,
dim_out,
num_heads,
feat_size,
mlp_ratio=4.0,
qkv_bias=True,
drop_path=0.0,
norm_layer=nn.LayerNorm,
kernel_q=(1, 1),
kernel_kv=(1, 1),
stride_q=(1, 1),
stride_kv=(1, 1),
mode="conv",
has_cls_token=True,
expand_attn=False,
pool_first=False,
rel_pos_type='spatial',
residual_pooling=True,
):
super().__init__()
proj_needed = dim != dim_out
self.dim = dim
self.dim_out = dim_out
self.has_cls_token = has_cls_token
self.norm1 = norm_layer(dim)
self.shortcut_proj_attn = nn.Linear(dim, dim_out) if proj_needed and expand_attn else None
if stride_q and prod(stride_q) > 1:
kernel_skip = [s + 1 if s > 1 else s for s in stride_q]
stride_skip = stride_q
padding_skip = [int(skip // 2) for skip in kernel_skip]
self.shortcut_pool_attn = nn.MaxPool2d(kernel_skip, stride_skip, padding_skip)
else:
self.shortcut_pool_attn = None
att_dim = dim_out if expand_attn else dim
attn_layer = MultiScaleAttentionPoolFirst if pool_first else MultiScaleAttention
self.attn = attn_layer(
dim,
att_dim,
num_heads=num_heads,
feat_size=feat_size,
qkv_bias=qkv_bias,
kernel_q=kernel_q,
kernel_kv=kernel_kv,
stride_q=stride_q,
stride_kv=stride_kv,
norm_layer=norm_layer,
has_cls_token=has_cls_token,
mode=mode,
rel_pos_type=rel_pos_type,
residual_pooling=residual_pooling,
)
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(att_dim)
mlp_dim_out = dim_out
self.shortcut_proj_mlp = nn.Linear(dim, dim_out) if proj_needed and not expand_attn else None
self.mlp = Mlp(
in_features=att_dim,
hidden_features=int(att_dim * mlp_ratio),
out_features=mlp_dim_out,
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
def _shortcut_pool(self, x, feat_size: List[int]):
if self.shortcut_pool_attn is None:
return x
if self.has_cls_token:
# x is (B, N, C) here, so the class token is split off along dim 1
cls_tok, x = x[:, :1, :], x[:, 1:, :]
else:
cls_tok = None
B, L, C = x.shape
H, W = feat_size
x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
x = self.shortcut_pool_attn(x)
x = x.reshape(B, C, -1).transpose(1, 2)
if cls_tok is not None:
x = torch.cat((cls_tok, x), dim=1)
return x
def forward(self, x, feat_size: List[int]):
x_norm = self.norm1(x)
# NOTE as per the original impl, this seems odd, but shortcut uses un-normalized input if no proj
x_shortcut = x if self.shortcut_proj_attn is None else self.shortcut_proj_attn(x_norm)
x_shortcut = self._shortcut_pool(x_shortcut, feat_size)
x, feat_size_new = self.attn(x_norm, feat_size)
x = x_shortcut + self.drop_path1(x)
x_norm = self.norm2(x)
x_shortcut = x if self.shortcut_proj_mlp is None else self.shortcut_proj_mlp(x_norm)
x = x_shortcut + self.drop_path2(self.mlp(x_norm))
return x, feat_size_new
class MultiScaleVitStage(nn.Module):
def __init__(
self,
dim,
dim_out,
depth,
num_heads,
feat_size,
mlp_ratio=4.0,
qkv_bias=True,
mode="conv",
kernel_q=(1, 1),
kernel_kv=(1, 1),
stride_q=(1, 1),
stride_kv=(1, 1),
has_cls_token=True,
expand_attn=False,
pool_first=False,
rel_pos_type='spatial',
residual_pooling=True,
norm_layer=nn.LayerNorm,
drop_path=0.0,
):
super().__init__()
self.grad_checkpointing = False
self.blocks = nn.ModuleList()
if expand_attn:
out_dims = (dim_out,) * depth
else:
out_dims = (dim,) * (depth - 1) + (dim_out,)
for i in range(depth):
attention_block = MultiScaleBlock(
dim=dim,
dim_out=out_dims[i],
num_heads=num_heads,
feat_size=feat_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
kernel_q=kernel_q,
kernel_kv=kernel_kv,
stride_q=stride_q if i == 0 else (1, 1),
stride_kv=stride_kv,
mode=mode,
has_cls_token=has_cls_token,
pool_first=pool_first,
rel_pos_type=rel_pos_type,
residual_pooling=residual_pooling,
expand_attn=expand_attn,
norm_layer=norm_layer,
drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
)
dim = out_dims[i]
self.blocks.append(attention_block)
if i == 0:
feat_size = tuple([size // stride for size, stride in zip(feat_size, stride_q)])
self.feat_size = feat_size
def forward(self, x, feat_size: List[int]):
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x, feat_size = checkpoint.checkpoint(blk, x, feat_size)
else:
x, feat_size = blk(x, feat_size)
return x, feat_size
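A small sketch of how the token grid shrinks from stage to stage, following the `feat_size` bookkeeping in `MultiScaleVitStage` and `MultiScaleVit`. The helper `stage_feat_sizes` and the config values passed to it are assumptions for illustration; the real values come from `MultiScaleVitCfg`, which is not shown here.

```python
def stage_feat_sizes(img_size, patch_stride, stride_q_per_stage):
    """Return the (H, W) token grid after each stage (hypothetical helper)."""
    # Grid produced by the patch embedding.
    feat_size = tuple(s // p for s, p in zip(img_size, patch_stride))
    sizes = []
    for stride_q in stride_q_per_stage:
        # Only the first block of a stage applies stride_q; later blocks use (1, 1).
        feat_size = tuple(s // st for s, st in zip(feat_size, stride_q))
        sizes.append(feat_size)
    return sizes


# Assumed config: patch stride 4, query stride 2 at the start of stages 1-3.
sizes = stage_feat_sizes((224, 224), (4, 4), [(1, 1), (2, 2), (2, 2), (2, 2)])
print(sizes)
```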
class MultiScaleVit(nn.Module):
"""
Improved Multiscale Vision Transformers for Classification and Detection
Yanghao Li*, Chao-Yuan Wu*, Haoqi Fan, Karttikeya Mangalam, Bo Xiong, Jitendra Malik,
Christoph Feichtenhofer*
https://arxiv.org/abs/2112.01526
Multiscale Vision Transformers
Haoqi Fan*, Bo Xiong*, Karttikeya Mangalam*, Yanghao Li*, Zhicheng Yan, Jitendra Malik,
Christoph Feichtenhofer*
https://arxiv.org/abs/2104.11227
"""
def __init__(
self,
cfg: MultiScaleVitCfg,
img_size: Tuple[int, int] = (224, 224),
in_chans: int = 3,
global_pool: str = 'avg',
num_classes: int = 1000,
drop_path_rate: float = 0.,
drop_rate: float = 0.,
):
super().__init__()
img_size = to_2tuple(img_size)
norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
self.num_classes = num_classes
self.drop_rate = drop_rate
self.global_pool = global_pool
self.depths = tuple(cfg.depths)
self.expand_attn = cfg.expand_attn
embed_dim = cfg.embed_dim[0]
self.patch_embed = PatchEmbed(
dim_in=in_chans,
dim_out=embed_dim,
kernel=cfg.patch_kernel,
stride=cfg.patch_stride,
padding=cfg.patch_padding,
)
patch_dims = (img_size[0] // cfg.patch_stride[0], img_size[1] // cfg.patch_stride[1])
num_patches = prod(patch_dims)
if cfg.use_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.num_prefix_tokens = 1
pos_embed_dim = num_patches + 1
else:
self.num_prefix_tokens = 0
self.cls_token = None
pos_embed_dim = num_patches
if cfg.use_abs_pos:
self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_dim, embed_dim))
else:
self.pos_embed = None
num_stages = len(cfg.embed_dim)
feat_size = patch_dims
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
self.stages = nn.ModuleList()
for i in range(num_stages):
if cfg.expand_attn:
dim_out = cfg.embed_dim[i]
else:
dim_out = cfg.embed_dim[min(i + 1, num_stages - 1)]
stage = MultiScaleVitStage(
dim=embed_dim,
dim_out=dim_out,
depth=cfg.depths[i],
num_heads=cfg.num_heads[i],
feat_size=feat_size,
mlp_ratio=cfg.mlp_ratio,
qkv_bias=cfg.qkv_bias,
mode=cfg.mode,
pool_first=cfg.pool_first,
expand_attn=cfg.expand_attn,
kernel_q=cfg.kernel_qkv,
kernel_kv=cfg.kernel_qkv,
stride_q=cfg.stride_q[i],
stride_kv=cfg.stride_kv[i],
has_cls_token=cfg.use_cls_token,
rel_pos_type=cfg.rel_pos_type,
residual_pooling=cfg.residual_pooling,
norm_layer=norm_layer,
drop_path=dpr[i],
)
embed_dim = dim_out
feat_size = stage.feat_size
self.stages.append(stage)
self.num_features = embed_dim
self.norm = norm_layer(embed_dim)
self.head = nn.Sequential(OrderedDict([
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
]))
if self.pos_embed is not None:
trunc_normal_tf_(self.pos_embed, std=0.02)
if self.cls_token is not None:
trunc_normal_tf_(self.cls_token, std=0.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_tf_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0.0)
@torch.jit.ignore
def no_weight_decay(self):
return {k for k, _ in self.named_parameters()
if any(n in k for n in ["pos_embed", "rel_pos_h", "rel_pos_w", "cls_token"])}
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^patch_embed', # stem and embed
blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for s in self.stages:
s.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
self.head = nn.Sequential(OrderedDict([
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
]))
def forward_features(self, x):
x, feat_size = self.patch_embed(x)
B, N, C = x.shape
if self.cls_token is not None:
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
for stage in self.stages:
x, feat_size = stage(x, feat_size)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
if self.global_pool == 'avg':
x = x[:, self.num_prefix_tokens:].mean(1)
else:
x = x[:, 0]
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def checkpoint_filter_fn(state_dict, model):
if 'stages.0.blocks.0.norm1.weight' in state_dict:
return state_dict
import re
if 'model_state' in state_dict:
state_dict = state_dict['model_state']
depths = getattr(model, 'depths', None)
expand_attn = getattr(model, 'expand_attn', True)
assert depths is not None, 'model requires a `depths` attribute to remap checkpoints'
depth_map = {}
block_idx = 0
for stage_idx, d in enumerate(depths):
depth_map.update({i: (stage_idx, i - block_idx) for i in range(block_idx, block_idx + d)})
block_idx += d
out_dict = {}
for k, v in state_dict.items():
k = re.sub(
r'blocks\.(\d+)',
lambda x: f'stages.{depth_map[int(x.group(1))][0]}.blocks.{depth_map[int(x.group(1))][1]}',
k)
if expand_attn:
k = re.sub(r'stages\.(\d+)\.blocks\.(\d+)\.proj', r'stages.\1.blocks.\2.shortcut_proj_attn', k)
else:
k = re.sub(r'stages\.(\d+)\.blocks\.(\d+)\.proj', r'stages.\1.blocks.\2.shortcut_proj_mlp', k)
if 'head' in k:
k = k.replace('head.projection', 'head.fc')
out_dict[k] = v
# for k, v in state_dict.items():
# if model.pos_embed is not None and k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
# # To resize pos embedding when using model at different size from pretrained weights
# v = resize_pos_embed(
# v,
# model.pos_embed,
# 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
# model.patch_embed.grid_size
# )
return out_dict
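A worked example of the flat-to-staged block remapping performed by `checkpoint_filter_fn` above. The helpers `build_depth_map` and `remap_key`, the depths tuple, and the sample keys are illustrative assumptions, not values taken from a real checkpoint.

```python
import re


def build_depth_map(depths):
    """Map a flat block index to (stage index, block index within that stage)."""
    depth_map = {}
    block_idx = 0
    for stage_idx, d in enumerate(depths):
        for i in range(block_idx, block_idx + d):
            depth_map[i] = (stage_idx, i - block_idx)
        block_idx += d
    return depth_map


def remap_key(key, depth_map):
    """Rewrite 'blocks.N' as 'stages.S.blocks.M' using the depth map."""
    def repl(m):
        stage, blk = depth_map[int(m.group(1))]
        return f'stages.{stage}.blocks.{blk}'
    return re.sub(r'blocks\.(\d+)', repl, key)


depth_map = build_depth_map((1, 2, 11, 2))  # hypothetical per-stage depths
new_key = remap_key('blocks.3.attn.qkv.weight', depth_map)
last_key = remap_key('blocks.15.norm1.weight', depth_map)
print(new_key, last_key)
```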
def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs):
return build_model_with_cfg(
MultiScaleVit, variant, pretrained,
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True),
**kwargs)
@register_model
def mvitv2_tiny(pretrained=False, **kwargs):
return _create_mvitv2('mvitv2_tiny', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_small(pretrained=False, **kwargs):
return _create_mvitv2('mvitv2_small', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_base(pretrained=False, **kwargs):
return _create_mvitv2('mvitv2_base', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_large(pretrained=False, **kwargs):
return _create_mvitv2('mvitv2_large', pretrained=pretrained, **kwargs)
# @register_model
# def mvitv2_base_in21k(pretrained=False, **kwargs):
# return _create_mvitv2('mvitv2_base_in21k', pretrained=pretrained, **kwargs)
#
#
# @register_model
# def mvitv2_large_in21k(pretrained=False, **kwargs):
# return _create_mvitv2('mvitv2_large_in21k', pretrained=pretrained, **kwargs)
#
#
# @register_model
# def mvitv2_huge_in21k(pretrained=False, **kwargs):
# return _create_mvitv2('mvitv2_huge_in21k', pretrained=pretrained, **kwargs)

@ -0,0 +1,476 @@
""" Pyramid Vision Transformer v2
@misc{wang2021pvtv2,
title={PVTv2: Improved Baselines with Pyramid Vision Transformer},
author={Wenhai Wang and Enze Xie and Xiang Li and Deng-Ping Fan and Kaitao Song and Ding Liang and
Tong Lu and Ping Luo and Ling Shao},
year={2021},
eprint={2106.13797},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
Based on Apache 2.0 licensed code at https://github.com/whai362/PVT
Modifications and timm support by / Copyright 2022, Ross Wightman
"""
import math
from functools import partial
from typing import Tuple, List, Callable, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import DropPath, to_2tuple, to_ntuple, trunc_normal_
from .registry import register_model
__all__ = ['PyramidVisionTransformerV2']
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': False,
**kwargs
}
default_cfgs = {
'pvt_v2_b0': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b0.pth'),
'pvt_v2_b1': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b1.pth'),
'pvt_v2_b2': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b2.pth'),
'pvt_v2_b3': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b3.pth'),
'pvt_v2_b4': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b4.pth'),
'pvt_v2_b5': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b5.pth'),
'pvt_v2_b2_li': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b2_li.pth')
}
class MlpWithDepthwiseConv(nn.Module):
def __init__(
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
drop=0., extra_relu=False):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.relu = nn.ReLU() if extra_relu else nn.Identity()
self.dwconv = nn.Conv2d(hidden_features, hidden_features, 3, 1, 1, bias=True, groups=hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x, feat_size: List[int]):
x = self.fc1(x)
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, feat_size[0], feat_size[1])
x = self.relu(x)
x = self.dwconv(x)
x = x.flatten(2).transpose(1, 2)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
sr_ratio=1,
linear_attn=False,
qkv_bias=True,
attn_drop=0.,
proj_drop=0.
):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divisible by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
if not linear_attn:
self.pool = None
if sr_ratio > 1:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = nn.LayerNorm(dim)
else:
self.sr = None
self.norm = None
self.act = None
else:
self.pool = nn.AdaptiveAvgPool2d(7)
self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
self.norm = nn.LayerNorm(dim)
self.act = nn.GELU()
def forward(self, x, feat_size: List[int]):
B, N, C = x.shape
H, W = feat_size
q = self.q(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
if self.pool is not None:
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1)
x_ = self.norm(x_)
x_ = self.act(x_)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
else:
if self.sr is not None:
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
x_ = self.norm(x_)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
else:
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
k, v = kv.unbind(0)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
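To make the cost saving of the `Attention` module above concrete, the sketch below counts how many key/value tokens each query attends to. The helper `kv_token_count` is hypothetical; it assumes the spatial-reduction conv uses `kernel_size == stride == sr_ratio` with no padding and that linear attention pools to a fixed 7x7 grid, as in the constructor above.

```python
def kv_token_count(feat_size, sr_ratio=1, linear_attn=False, pool_size=7):
    """Number of key/value tokens for a given feature grid (hypothetical helper)."""
    H, W = feat_size
    if linear_attn:
        # AdaptiveAvgPool2d(7) always yields a 7x7 grid, whatever the input size.
        return pool_size * pool_size
    if sr_ratio > 1:
        # A conv with kernel == stride == sr_ratio and no padding divides each side.
        return (H // sr_ratio) * (W // sr_ratio)
    return H * W


full = kv_token_count((56, 56))                       # plain attention
reduced = kv_token_count((56, 56), sr_ratio=8)        # spatial reduction
linear = kv_token_count((56, 56), linear_attn=True)   # pooled "linear" variant
print(full, reduced, linear)
```

With a 56x56 grid the attention matrix shrinks from 3136x3136 to 3136x49 in both reduced variants.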
class Block(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., sr_ratio=1, linear_attn=False, qkv_bias=False,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
sr_ratio=sr_ratio,
linear_attn=linear_attn,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = MlpWithDepthwiseConv(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
extra_relu=linear_attn
)
def forward(self, x, feat_size: List[int]):
x = x + self.drop_path(self.attn(self.norm1(x), feat_size))
x = x + self.drop_path(self.mlp(self.norm2(x), feat_size))
return x
class OverlapPatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
super().__init__()
patch_size = to_2tuple(patch_size)
assert max(patch_size) > stride, "patch_size must be larger than stride"
self.patch_size = patch_size
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=stride,
padding=(patch_size[0] // 2, patch_size[1] // 2))
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x):
x = self.proj(x)
feat_size = x.shape[-2:]
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, feat_size
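A sketch of the grid sizes produced by `OverlapPatchEmbed`, using the standard convolution output-size formula (dilation 1). It assumes the stem settings (`patch_size=7`, `stride=4`) and per-stage downsample settings (`patch_size=3`, `stride=2`) used elsewhere in this file; the helper names are hypothetical.

```python
def conv_out_size(size, kernel, stride, padding):
    """Output length of a conv along one axis, dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def pvt_grid_sizes(img_size=224, num_stages=4):
    """Side length of the token grid at each stage (hypothetical helper)."""
    # Stem: 7x7 kernel, stride 4, padding = kernel // 2.
    size = conv_out_size(img_size, kernel=7, stride=4, padding=7 // 2)
    sizes = [size]
    for _ in range(num_stages - 1):
        # Stage downsample: 3x3 kernel, stride 2, padding = kernel // 2.
        size = conv_out_size(size, kernel=3, stride=2, padding=3 // 2)
        sizes.append(size)
    return sizes


grid = pvt_grid_sizes(224)
print(grid)
```

Because the padding is `kernel // 2`, overlapping patches still reduce the grid by exactly the stride for these input sizes.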
class PyramidVisionTransformerStage(nn.Module):
def __init__(
self,
dim: int,
dim_out: int,
depth: int,
downsample: bool = True,
num_heads: int = 8,
sr_ratio: int = 1,
linear_attn: bool = False,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop: float = 0.,
attn_drop: float = 0.,
drop_path: Union[List[float], float] = 0.0,
norm_layer: Callable = nn.LayerNorm,
):
super().__init__()
self.grad_checkpointing = False
if downsample:
self.downsample = OverlapPatchEmbed(
patch_size=3,
stride=2,
in_chans=dim,
embed_dim=dim_out)
else:
assert dim == dim_out
self.downsample = None
self.blocks = nn.ModuleList([Block(
dim=dim_out,
num_heads=num_heads,
sr_ratio=sr_ratio,
linear_attn=linear_attn,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
) for i in range(depth)])
self.norm = norm_layer(dim_out)
def forward(self, x, feat_size: List[int]) -> Tuple[torch.Tensor, List[int]]:
if self.downsample is not None:
x, feat_size = self.downsample(x)
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint.checkpoint(blk, x, feat_size)
else:
x = blk(x, feat_size)
x = self.norm(x)
x = x.reshape(x.shape[0], feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2).contiguous()
return x, feat_size
class PyramidVisionTransformerV2(nn.Module):
def __init__(
self,
img_size=None,
in_chans=3,
num_classes=1000,
global_pool='avg',
depths=(3, 4, 6, 3),
embed_dims=(64, 128, 256, 512),
num_heads=(1, 2, 4, 8),
sr_ratios=(8, 4, 2, 1),
mlp_ratios=(8., 8., 4., 4.),
qkv_bias=True,
linear=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.num_classes = num_classes
assert global_pool in ('avg', '')
self.global_pool = global_pool
self.depths = depths
num_stages = len(depths)
mlp_ratios = to_ntuple(num_stages)(mlp_ratios)
num_heads = to_ntuple(num_stages)(num_heads)
sr_ratios = to_ntuple(num_stages)(sr_ratios)
assert len(embed_dims) == num_stages
self.patch_embed = OverlapPatchEmbed(
patch_size=7,
stride=4,
in_chans=in_chans,
embed_dim=embed_dims[0])
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
cur = 0
prev_dim = embed_dims[0]
self.stages = nn.ModuleList()
for i in range(num_stages):
self.stages.append(PyramidVisionTransformerStage(
dim=prev_dim,
dim_out=embed_dims[i],
depth=depths[i],
downsample=i > 0,
num_heads=num_heads[i],
sr_ratio=sr_ratios[i],
mlp_ratio=mlp_ratios[i],
linear_attn=linear,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer
))
prev_dim = embed_dims[i]
cur += depths[i]
# classification head
self.num_features = embed_dims[-1]
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def freeze_patch_emb(self):
# setting a plain attribute on the module does not freeze its parameters
self.patch_embed.requires_grad_(False)
@torch.jit.ignore
def no_weight_decay(self):
return {}
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^patch_embed', # stem and embed
blocks=r'^stages\.(\d+)'
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for s in self.stages:
s.grad_checkpointing = enable
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('avg', '')
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x, feat_size = self.patch_embed(x)
for stage in self.stages:
x, feat_size = stage(x, feat_size=feat_size)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x.mean(dim=(-1, -2))
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
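The constructor above builds a per-block stochastic depth schedule with `torch.linspace(0, drop_path_rate, sum(depths)).split(depths)`. The pure-Python sketch below reproduces that schedule without torch so the values can be inspected; `drop_path_schedule` is a hypothetical name.

```python
def drop_path_schedule(drop_path_rate, depths):
    """Linearly increasing drop-path rates, grouped per stage (hypothetical helper)."""
    total = sum(depths)
    if total > 1:
        rates = [drop_path_rate * i / (total - 1) for i in range(total)]
    else:
        rates = [0.0] * total
    per_stage, start = [], 0
    for d in depths:
        # Each stage gets a contiguous slice of the global schedule.
        per_stage.append(rates[start:start + d])
        start += d
    return per_stage


dpr = drop_path_schedule(0.1, (2, 2, 2, 2))
print(dpr)
```

The first block is never dropped and the last block uses the full `drop_path_rate`, so deeper blocks are regularized more strongly.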
def _checkpoint_filter_fn(state_dict, model):
""" Remap original checkpoints -> timm """
if 'patch_embed.proj.weight' in state_dict:
return state_dict # non-original checkpoint, no remapping needed
out_dict = {}
import re
for k, v in state_dict.items():
if k.startswith('patch_embed'):
k = k.replace('patch_embed1', 'patch_embed')
k = k.replace('patch_embed2', 'stages.1.downsample')
k = k.replace('patch_embed3', 'stages.2.downsample')
k = k.replace('patch_embed4', 'stages.3.downsample')
k = k.replace('dwconv.dwconv', 'dwconv')
k = re.sub(r'block(\d+).(\d+)', lambda x: f'stages.{int(x.group(1)) - 1}.blocks.{x.group(2)}', k)
k = re.sub(r'^norm(\d+)', lambda x: f'stages.{int(x.group(1)) - 1}.norm', k)
out_dict[k] = v
return out_dict
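A worked example of the key renaming done by `_checkpoint_filter_fn` above, applied to a few sample keys. The function `remap_pvt_key` is a hypothetical single-key version of the same logic, and the sample key names are inferred from the replacement patterns rather than read from an actual checkpoint.

```python
import re


def remap_pvt_key(k):
    """Rename one original-style PVT-v2 key to the staged layout (hypothetical helper)."""
    if k.startswith('patch_embed'):
        k = k.replace('patch_embed1', 'patch_embed')
        k = k.replace('patch_embed2', 'stages.1.downsample')
        k = k.replace('patch_embed3', 'stages.2.downsample')
        k = k.replace('patch_embed4', 'stages.3.downsample')
    k = k.replace('dwconv.dwconv', 'dwconv')
    # Original stages are 1-indexed ('block1', 'norm1'); the staged layout is 0-indexed.
    k = re.sub(r'block(\d+)\.(\d+)', lambda m: f'stages.{int(m.group(1)) - 1}.blocks.{m.group(2)}', k)
    k = re.sub(r'^norm(\d+)', lambda m: f'stages.{int(m.group(1)) - 1}.norm', k)
    return k


samples = [
    'patch_embed1.proj.weight',
    'patch_embed2.proj.weight',
    'block2.1.mlp.dwconv.dwconv.weight',
    'block1.0.norm1.weight',
    'norm3.weight',
]
remapped = [remap_pvt_key(k) for k in samples]
print(remapped)
```

Note that `block1.0.norm1.weight` keeps its inner `norm1`, because the stage-norm pattern is anchored at the start of the key.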
def _create_pvt2(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
PyramidVisionTransformerV2, variant, pretrained,
pretrained_filter_fn=_checkpoint_filter_fn,
**kwargs
)
return model


@register_model
def pvt_v2_b0(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(2, 2, 2, 2), embed_dims=(32, 64, 160, 256), num_heads=(1, 2, 5, 8),
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return _create_pvt2('pvt_v2_b0', pretrained=pretrained, **model_kwargs)


@register_model
def pvt_v2_b1(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(2, 2, 2, 2), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8),
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return _create_pvt2('pvt_v2_b1', pretrained=pretrained, **model_kwargs)


@register_model
def pvt_v2_b2(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8),
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return _create_pvt2('pvt_v2_b2', pretrained=pretrained, **model_kwargs)


@register_model
def pvt_v2_b3(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(3, 4, 18, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8),
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return _create_pvt2('pvt_v2_b3', pretrained=pretrained, **model_kwargs)


@register_model
def pvt_v2_b4(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(3, 8, 27, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8),
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return _create_pvt2('pvt_v2_b4', pretrained=pretrained, **model_kwargs)


@register_model
def pvt_v2_b5(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(3, 6, 40, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8),
        mlp_ratios=(4, 4, 4, 4), norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    return _create_pvt2('pvt_v2_b5', pretrained=pretrained, **model_kwargs)


@register_model
def pvt_v2_b2_li(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8),
        norm_layer=partial(nn.LayerNorm, eps=1e-6), linear=True, **kwargs)
    return _create_pvt2('pvt_v2_b2_li', pretrained=pretrained, **model_kwargs)
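
The checkpoint remapping above is easiest to follow on a toy state dict. The sketch below is a standalone illustration, not the code from this commit: `remap_pvt_keys` and the example keys are hypothetical names chosen for the demo, and the `block` pattern escapes the dot (`\.`), a small deviation from the diff that should behave the same on keys of this shape.

```python
import re


def remap_pvt_keys(state_dict):
    """Rename original-style PVT-v2 keys to a staged layout (illustrative only)."""
    if 'patch_embed.proj.weight' in state_dict:
        return state_dict  # already in the target layout, nothing to do
    out = {}
    for k, v in state_dict.items():
        if k.startswith('patch_embed'):
            # The first patch embed stays at the top level, the rest become stage downsamples.
            k = k.replace('patch_embed1', 'patch_embed')
            k = k.replace('patch_embed2', 'stages.1.downsample')
            k = k.replace('patch_embed3', 'stages.2.downsample')
            k = k.replace('patch_embed4', 'stages.3.downsample')
        k = k.replace('dwconv.dwconv', 'dwconv')
        # block<N>.<i> (1-based stage index) -> stages.<N-1>.blocks.<i>
        k = re.sub(r'block(\d+)\.(\d+)', lambda m: f'stages.{int(m.group(1)) - 1}.blocks.{m.group(2)}', k)
        # A top-level norm<N> becomes the norm of stage N-1; per-block norms are left alone.
        k = re.sub(r'^norm(\d+)', lambda m: f'stages.{int(m.group(1)) - 1}.norm', k)
        out[k] = v
    return out


# Hypothetical keys; the values are placeholders instead of tensors.
toy = {
    'patch_embed1.proj.weight': 0,
    'patch_embed3.norm.bias': 1,
    'block1.0.mlp.dwconv.dwconv.weight': 2,
    'block3.17.norm1.weight': 3,
    'norm4.weight': 4,
    'head.weight': 5,
}
remapped = remap_pvt_keys(toy)
for old, new in zip(toy, remapped):
    print(f'{old} -> {new}')
```

Because the `^norm` pattern is anchored at the start of the key, only the per-stage output norms move; `norm1`/`norm2` inside a block keep their names.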

@ -33,7 +33,7 @@ from timm.data import create_dataset, create_loader, resolve_data_config, Mixup,
from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \
    LabelSmoothingCrossEntropy
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
    convert_splitbn_model, convert_sync_batchnorm, model_parameters
    convert_splitbn_model, convert_sync_batchnorm, model_parameters, set_fast_norm
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
@ -135,6 +135,8 @@ scripting_group.add_argument('--aot-autograd', default=False, action='store_true
    help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
group.add_argument('--fuser', default='', type=str,
    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
group.add_argument('--fast-norm', default=False, action='store_true',
    help='enable experimental fast-norm')
group.add_argument('--grad-checkpointing', action='store_true', default=False,
    help='Enable gradient checkpointing through model blocks/stages')
@ -395,6 +397,8 @@ def main():
    if args.fuser:
        utils.set_jit_fuser(args.fuser)
    if args.fast_norm:
        set_fast_norm()
    model = create_model(
        args.model,

@ -20,7 +20,7 @@ import torch.nn.parallel
from collections import OrderedDict
from contextlib import suppress
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\
    decay_batch_step, check_batch_size_retry
@ -117,6 +117,8 @@ scripting_group.add_argument('--aot-autograd', default=False, action='store_true
    help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
parser.add_argument('--fuser', default='', type=str,
    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--fast-norm', default=False, action='store_true',
    help='enable experimental fast-norm')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
    help='Output csv file for validation results (summary)')
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
@ -150,6 +152,8 @@ def validate(args):
    if args.fuser:
        set_jit_fuser(args.fuser)
    if args.fast_norm:
        set_fast_norm()
    # create model
    model = create_model(

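
Both scripts wire `--fast-norm` the same way: a `store_true` flag that, when set, flips a global switch before the model is created. A minimal sketch of that pattern follows. The `set_fast_norm` and `is_fast_norm` functions here are local stand-ins written for this example, an assumption about the general shape of such a toggle rather than timm's implementation, and `configure` is a hypothetical helper.

```python
import argparse

# Hypothetical module-level switch; norm layers would consult it at call time.
_use_fast_norm = False


def set_fast_norm(enable=True):
    """Stand-in toggle (not the library function): record the requested mode."""
    global _use_fast_norm
    _use_fast_norm = enable


def is_fast_norm():
    """Report whether the fast path has been requested."""
    return _use_fast_norm


def build_parser():
    parser = argparse.ArgumentParser(description='fast-norm flag sketch')
    # Same flag shape as in the diff: off by default, on when passed.
    parser.add_argument('--fast-norm', default=False, action='store_true',
                        help='enable experimental fast-norm')
    return parser


def configure(argv):
    """Parse arguments and apply the toggle before any model would be built."""
    args = build_parser().parse_args(argv)
    if args.fast_norm:
        set_fast_norm()
    return args


if __name__ == '__main__':
    args = configure(['--fast-norm'])
    print(args.fast_norm, is_fast_norm())
```

Applying the toggle before model construction matters for a switch of this kind when layers read it during setup; if they read it on every forward call instead, the ordering is less critical.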