From ff4a38e2c3d64353995b78732d3e3dec7e3df5dc Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 17 Aug 2022 12:06:05 -0700 Subject: [PATCH 01/21] Add PyramidVisionTransformerV2 --- timm/models/__init__.py | 1 + timm/models/pvt_v2.py | 476 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 477 insertions(+) create mode 100644 timm/models/pvt_v2.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 195e451b..65b1d955 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -32,6 +32,7 @@ from .nfnet import * from .pit import * from .pnasnet import * from .poolformer import * +from .pvt_v2 import * from .regnet import * from .res2net import * from .resnest import * diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py new file mode 100644 index 00000000..551a8325 --- /dev/null +++ b/timm/models/pvt_v2.py @@ -0,0 +1,476 @@ +""" Pyramid Vision Transformer v2 + +@misc{wang2021pvtv2, + title={PVTv2: Improved Baselines with Pyramid Vision Transformer}, + author={Wenhai Wang and Enze Xie and Xiang Li and Deng-Ping Fan and Kaitao Song and Ding Liang and + Tong Lu and Ping Luo and Ling Shao}, + year={2021}, + eprint={2106.13797}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} + +Based on Apache 2.0 licensed code at https://github.com/whai362/PVT + +Modifications and timm support by / Copyright 2022, Ross Wightman +""" + +import math +from functools import partial +from typing import Tuple, List, Callable, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import DropPath, to_2tuple, to_ntuple, trunc_normal_ +from .registry import register_model + +__all__ = ['PyramidVisionTransformerV2'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.conv', 'classifier': 'head', 'fixed_input_size': False, + **kwargs + } + + +default_cfgs = { + 'pvt_v2_b0': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b0.pth'), + 'pvt_v2_b1': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b1.pth'), + 'pvt_v2_b2': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b2.pth'), + 'pvt_v2_b3': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b3.pth'), + 'pvt_v2_b4': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b4.pth'), + 'pvt_v2_b5': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b5.pth'), + 'pvt_v2_b2_li': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b2_li.pth') +} + + +class MlpWithDepthwiseConv(nn.Module): + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, + drop=0., extra_relu=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.relu = nn.ReLU() if extra_relu else nn.Identity() + self.dwconv = nn.Conv2d(hidden_features, hidden_features, 3, 1, 1, bias=True, groups=hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x, feat_size: List[int]): + x = self.fc1(x) + B, N, C = x.shape + x = 
x.transpose(1, 2).view(B, C, feat_size[0], feat_size[1]) + x = self.relu(x) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + sr_ratio=1, + linear_attn=False, + qkv_bias=True, + attn_drop=0., + proj_drop=0. + ): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if not linear_attn: + self.pool = None + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + self.act = None + else: + self.pool = nn.AdaptiveAvgPool2d(7) + self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) + self.norm = nn.LayerNorm(dim) + self.act = nn.GELU() + + def forward(self, x, feat_size: List[int]): + B, N, C = x.shape + H, W = feat_size + q = self.q(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + + if self.pool is not None: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + x_ = self.act(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + else: + if self.sr is not None: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + else: + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + k, v = kv.unbind(0) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__( + self, dim, num_heads, mlp_ratio=4., sr_ratio=1, linear_attn=False, qkv_bias=False, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + sr_ratio=sr_ratio, + linear_attn=linear_attn, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp = MlpWithDepthwiseConv( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + extra_relu=linear_attn + ) + + def forward(self, x, feat_size: List[int]): + x = x + self.drop_path(self.attn(self.norm1(x), feat_size)) + x = x + self.drop_path(self.mlp(self.norm2(x), feat_size)) + + return x + + +class OverlapPatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + patch_size = to_2tuple(patch_size) + assert max(patch_size) > stride, "Set larger patch_size than stride" + self.patch_size = patch_size + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2)) + self.norm = nn.LayerNorm(embed_dim) + + def forward(self, x): + x = self.proj(x) + feat_size = x.shape[-2:] + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x, feat_size + + +class PyramidVisionTransformerStage(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + depth: int, + downsample: bool = True, + num_heads: int = 8, + sr_ratio: int = 1, + linear_attn: bool = False, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + drop: float = 0., + attn_drop: float = 0., + drop_path: Union[List[float], float] = 0.0, + norm_layer: Callable = nn.LayerNorm, + ): + super().__init__() + self.grad_checkpointing = False + + if downsample: + self.downsample = OverlapPatchEmbed( + patch_size=3, + stride=2, + in_chans=dim, + embed_dim=dim_out) + else: + assert dim == dim_out + self.downsample = None + + self.blocks = nn.ModuleList([Block( + dim=dim_out, + num_heads=num_heads, + sr_ratio=sr_ratio, + linear_attn=linear_attn, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) for i in range(depth)]) + + self.norm = norm_layer(dim_out) + + def forward(self, x, feat_size: List[int]) -> Tuple[torch.Tensor, List[int]]: + if self.downsample is not None: + x, feat_size = self.downsample(x) + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint.checkpoint(blk, x, feat_size) + else: + x = blk(x, feat_size) + x = self.norm(x) + x = x.reshape(x.shape[0], feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2).contiguous() + return x, feat_size + + +class PyramidVisionTransformerV2(nn.Module): + def __init__( + self, + img_size=None, + in_chans=3, + num_classes=1000, + global_pool='avg', + depths=(3, 4, 6, 3), + embed_dims=(64, 128, 256, 512), + num_heads=(1, 2, 4, 8), + sr_ratios=(8, 4, 2, 1), + mlp_ratios=(8., 8., 4., 4.), + qkv_bias=True, + linear=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.num_classes = num_classes + assert global_pool in ('avg', '') + self.global_pool = global_pool + self.img_size = to_2tuple(img_size) if img_size is not None else None + self.depths = depths + num_stages = len(depths) + mlp_ratios = to_ntuple(num_stages)(mlp_ratios) + num_heads = to_ntuple(num_stages)(num_heads) + sr_ratios = to_ntuple(num_stages)(sr_ratios) + assert(len(embed_dims)) == num_stages + + self.patch_embed = OverlapPatchEmbed( + patch_size=7, + stride=4, + in_chans=in_chans, + embed_dim=embed_dims[0]) + + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + cur = 0 + 
prev_dim = embed_dims[0] + self.stages = nn.ModuleList() + for i in range(num_stages): + self.stages.append(PyramidVisionTransformerStage( + dim=prev_dim, + dim_out=embed_dims[i], + depth=depths[i], + downsample=i > 0, + num_heads=num_heads[i], + sr_ratio=sr_ratios[i], + mlp_ratio=mlp_ratios[i], + linear_attn=linear, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer + )) + prev_dim = embed_dims[i] + cur += depths[i] + + # classification head + self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def freeze_patch_emb(self): + self.patch_embed.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^patch_embed', # stem and embed + blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('avg', '') + self.global_pool = global_pool + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x, feat_size = self.patch_embed(x) + for stage in self.stages: + x, feat_size = stage(x, feat_size=feat_size) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + x = x.mean(dim=(-1, -2)) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _checkpoint_filter_fn(state_dict, model): + """ Remap original checkpoints -> timm """ + if 'patch_embed.proj.weight' in state_dict: + return state_dict # non-original checkpoint, no remapping needed + + out_dict = {} + import re + for k, v in state_dict.items(): + if k.startswith('patch_embed'): + k = k.replace('patch_embed1', 'patch_embed') + k = k.replace('patch_embed2', 'stages.1.downsample') + k = k.replace('patch_embed3', 'stages.2.downsample') + k = k.replace('patch_embed4', 'stages.3.downsample') + k = k.replace('dwconv.dwconv', 'dwconv') + k = re.sub(r'block(\d+).(\d+)', lambda x: f'stages.{int(x.group(1)) - 1}.blocks.{x.group(2)}', k) + k = re.sub(r'^norm(\d+)', lambda x: f'stages.{int(x.group(1)) - 1}.norm', k) + out_dict[k] = v + return out_dict + + +def _create_pvt2(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + model = build_model_with_cfg( + PyramidVisionTransformerV2, variant, pretrained, + pretrained_filter_fn=_checkpoint_filter_fn, + **kwargs + ) + return model + + +@register_model +def pvt_v2_b0(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(2, 2, 2, 2), embed_dims=(32, 64, 160, 256), num_heads=(1, 2, 5, 8), + 
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return _create_pvt2('pvt_v2_b0', pretrained=pretrained, **model_kwargs) + + +@register_model +def pvt_v2_b1(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(2, 2, 2, 2), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return _create_pvt2('pvt_v2_b1', pretrained=pretrained, **model_kwargs) + + +@register_model +def pvt_v2_b2(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return _create_pvt2('pvt_v2_b2', pretrained=pretrained, **model_kwargs) + + +@register_model +def pvt_v2_b3(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(3, 4, 18, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return _create_pvt2('pvt_v2_b3', pretrained=pretrained, **model_kwargs) + + +@register_model +def pvt_v2_b4(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(3, 8, 27, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return _create_pvt2('pvt_v2_b4', pretrained=pretrained, **model_kwargs) + + +@register_model +def pvt_v2_b5(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(3, 6, 40, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), + mlp_ratios=(4, 4, 4, 4), norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs) + return _create_pvt2('pvt_v2_b5', pretrained=pretrained, **model_kwargs) + + +@register_model +def pvt_v2_b2_li(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), + norm_layer=partial(nn.LayerNorm, eps=1e-6), linear=True, **kwargs) + return _create_pvt2('pvt_v2_b2_li', pretrained=pretrained, **model_kwargs) + From fba6ecd39b357f405d09cd343989d07163c4a0e4 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 17 Aug 2022 14:08:53 -0700 Subject: [PATCH 02/21] Add EfficientFormer --- timm/models/__init__.py | 1 + timm/models/efficientformer.py | 552 +++++++++++++++++++++++++++++++++ 2 files changed, 553 insertions(+) create mode 100644 timm/models/efficientformer.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 65b1d955..5e19358c 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -13,6 +13,7 @@ from .densenet import * from .dla import * from .dpn import * from .edgenext import * +from .efficientformer import * from .efficientnet import * from .ghostnet import * from .gluon_resnet import * diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py new file mode 100644 index 00000000..0a54f7fe --- /dev/null +++ b/timm/models/efficientformer.py @@ -0,0 +1,552 @@ +""" EfficientFormer + +@article{li2022efficientformer, + title={EfficientFormer: Vision Transformers at MobileNet Speed}, + author={Li, Yanyu and Yuan, Geng and Wen, Yang and Hu, Eric and Evangelidis, Georgios and Tulyakov, + Sergey and Wang, Yanzhi and Ren, Jian}, + journal={arXiv preprint arXiv:2206.01191}, + year={2022} +} + +Based on Apache 2.0 licensed code at https://github.com/snap-research/EfficientFormer, Copyright (c) 2022 Snap Inc. 
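+
+Usage sketch (illustrative only; it assumes the efficientformer_* variants registered
+at the bottom of this file and timm's standard create_model factory):
+
+    import torch
+    import timm
+
+    model = timm.create_model('efficientformer_l1', pretrained=False)
+    model.eval()
+    with torch.no_grad():
+        logits = model(torch.randn(1, 3, 224, 224))  # (1, 1000) class logits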
+ +Modifications and timm support by / Copyright 2022, Ross Wightman +""" +from typing import Dict + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import DropPath, trunc_normal_, to_2tuple, Mlp +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .95, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv1', 'classifier': 'head', + **kwargs + } + + +default_cfgs = dict( + efficientformer_l1=_cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l1_1000d_224-5b08fab0.pth", + ), + efficientformer_l3=_cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l3_300d_224-6816624f.pth", + ), + efficientformer_l7=_cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l7_300d_224-e957ab75.pth", + ), +) + +EfficientFormer_width = { + 'l1': (48, 96, 224, 448), + 'l3': (64, 128, 320, 512), + 'l7': (96, 192, 384, 768), +} + +EfficientFormer_depth = { + 'l1': (3, 2, 6, 4), + 'l3': (4, 4, 12, 6), + 'l7': (6, 6, 18, 8), +} + + +class Attention(torch.nn.Module): + attention_bias_cache: Dict[str, torch.Tensor] + + def __init__( + self, + dim=384, + key_dim=32, + num_heads=8, + attn_ratio=4, + resolution=7 + ): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.key_attn_dim = key_dim * num_heads + self.val_dim = int(attn_ratio * key_dim) + self.val_attn_dim = self.val_dim * num_heads + self.attn_ratio = attn_ratio + + self.qkv = nn.Linear(dim, self.key_attn_dim * 2 + self.val_attn_dim) + self.proj = nn.Linear(self.val_attn_dim, dim) + + resolution = to_2tuple(resolution) + pos = torch.stack(torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1) + rel_pos = (pos[..., :, None] - pos[..., None, :]).abs() + rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1] + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1])) + self.register_buffer('attention_bias_idxs', torch.LongTensor(rel_pos)) + self.attention_bias_cache = {} # per-device attention_biases cache (data-parallel compat) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.attention_bias_cache: + self.attention_bias_cache = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.attention_bias_cache: + self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.attention_bias_cache[device_key] + + def forward(self, x): # x (B,N,C) + B, N, C = x.shape + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + q, k, v = qkv.split([self.key_dim, self.key_dim, self.val_dim], dim=3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn + self.get_attention_biases(x.device) + + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.val_attn_dim) + x = self.proj(x) + return x + + +class Stem4(nn.Sequential): + def 
__init__(self, in_chs, out_chs, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super().__init__() + self.stride = 4 + + self.add_module('conv1', nn.Conv2d(in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1)) + self.add_module('norm1', norm_layer(out_chs // 2)) + self.add_module('act1', act_layer()) + self.add_module('conv2', nn.Conv2d(out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1)) + self.add_module('norm2', norm_layer(out_chs)) + self.add_module('act2', act_layer()) + + +class Downsample(nn.Module): + """ + Downsampling via strided conv w/ norm + Input: tensor in shape [B, C, H, W] + Output: tensor in shape [B, C, H/stride, W/stride] + """ + + def __init__(self, in_chs, out_chs, kernel_size=3, stride=2, padding=None, norm_layer=nn.BatchNorm2d): + super().__init__() + if padding is None: + padding = kernel_size // 2 + self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, padding=padding) + self.norm = norm_layer(out_chs) + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class Flat(nn.Module): + + def __init__(self, ): + super().__init__() + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) + return x + + +class Pooling(nn.Module): + """ + Implementation of pooling for PoolFormer + --pool_size: pooling size + """ + + def __init__(self, pool_size=3): + super().__init__() + self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False) + + def forward(self, x): + return self.pool(x) - x + + +class ConvMlpWithNorm(nn.Module): + """ + Implementation of MLP with 1*1 convolutions. + Input: tensor with shape [B, C, H, W] + """ + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=nn.BatchNorm2d, + drop=0. + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.norm1 = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.norm2 = norm_layer(out_features) if norm_layer is not None else nn.Identity() + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.norm1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.norm2(x) + x = self.drop(x) + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class MetaBlock1d(nn.Module): + + def __init__( + self, + dim, + mlp_ratio=4., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + drop=0., + drop_path=0., + layer_scale_init_value=1e-5 + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.token_mixer = Attention(dim) + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.ls1 = LayerScale(dim, layer_scale_init_value) + self.ls2 = LayerScale(dim, layer_scale_init_value) + + def forward(self, x): + x = x + self.drop_path(self.ls1(self.token_mixer(self.norm1(x)))) + x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class LayerScale2d(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + gamma = self.gamma.view(1, -1, 1, 1) + return x.mul_(gamma) if self.inplace else x * gamma + + +class MetaBlock2d(nn.Module): + + def __init__( + self, + dim, + pool_size=3, + mlp_ratio=4., + act_layer=nn.GELU, + norm_layer=nn.BatchNorm2d, + drop=0., + drop_path=0., + layer_scale_init_value=1e-5 + ): + super().__init__() + self.token_mixer = Pooling(pool_size=pool_size) + self.mlp = ConvMlpWithNorm( + dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, norm_layer=norm_layer, drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.ls1 = LayerScale2d(dim, layer_scale_init_value) + self.ls2 = LayerScale2d(dim, layer_scale_init_value) + + def forward(self, x): + x = x + self.drop_path(self.ls1(self.token_mixer(x))) + x = x + self.drop_path(self.ls2(self.mlp(x))) + return x + + +class EfficientFormerStage(nn.Module): + + def __init__( + self, + dim, + dim_out, + depth, + downsample=True, + num_vit=1, + pool_size=3, + mlp_ratio=4., + act_layer=nn.GELU, + norm_layer=nn.BatchNorm2d, + norm_layer_cl=nn.LayerNorm, + drop=.0, + drop_path=0., + layer_scale_init_value=1e-5, +): + super().__init__() + self.grad_checkpointing = False + + if downsample: + self.downsample = Downsample(in_chs=dim, out_chs=dim_out, norm_layer=norm_layer) + dim = dim_out + else: + assert dim == dim_out + self.downsample = nn.Identity() + + blocks = [] + if num_vit and num_vit >= depth: + blocks.append(Flat()) + + for block_idx in range(depth): + remain_idx = depth - block_idx - 1 + if num_vit and num_vit > remain_idx: + blocks.append( + MetaBlock1d( + dim, + mlp_ratio=mlp_ratio, + act_layer=act_layer, + norm_layer=norm_layer_cl, + drop=drop, + drop_path=drop_path[block_idx], + layer_scale_init_value=layer_scale_init_value, + )) + else: + blocks.append( + MetaBlock2d( + dim, + pool_size=pool_size, + mlp_ratio=mlp_ratio, + act_layer=act_layer, + norm_layer=norm_layer, + drop=drop, + drop_path=drop_path[block_idx], + layer_scale_init_value=layer_scale_init_value, + )) + if num_vit and num_vit == remain_idx: + blocks.append(Flat()) + + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + x = self.downsample(x) + x = self.blocks(x) + return x + + +class EfficientFormer(nn.Module): + + def __init__( + self, + depths, + embed_dims=None, + in_chans=3, + num_classes=1000, + global_pool='avg', + downsamples=None, + num_vit=0, + mlp_ratios=4, + pool_size=3, + layer_scale_init_value=1e-5, + act_layer=nn.GELU, + norm_layer=nn.BatchNorm2d, + norm_layer_cl=nn.LayerNorm, + drop_rate=0., + drop_path_rate=0., + **kwargs + ): + super().__init__() + self.num_classes = num_classes + self.global_pool = global_pool + + self.stem = Stem4(in_chans, embed_dims[0], norm_layer=norm_layer) + prev_dim = embed_dims[0] + + # stochastic depth decay rule + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + downsamples = downsamples or (False,) + (True,) * (len(depths) - 1) + stages = [] + for i in range(len(depths)): + stage = EfficientFormerStage( 
+ prev_dim, + embed_dims[i], + depths[i], + downsample=downsamples[i], + num_vit=num_vit if i == 3 else 0, + pool_size=pool_size, + mlp_ratio=mlp_ratios, + act_layer=act_layer, + norm_layer_cl=norm_layer_cl, + norm_layer=norm_layer, + drop=drop_rate, + drop_path=dpr[i], + layer_scale_init_value=layer_scale_init_value, + ) + prev_dim = embed_dims[i] + stages.append(stage) + + self.stages = nn.Sequential(*stages) + + # Classifier head + self.num_features = embed_dims[-1] + self.norm = norm_layer_cl(self.num_features) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + # assuming model is always distilled (valid for current checkpoints, will split def if that changes) + self.head_dist = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() + self.distilled_training = False # must set this True to train w/ distillation token + + self.apply(self._init_weights) + + # init for classification + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def no_weight_decay(self): + return {k for k, _ in self.named_parameters() if 'attention_biases' in k} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', # stem and embed + blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head, self.head_dist + + def reset_classifier(self, num_classes, global_pool=None, distillation=None): + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + if self.dist: + self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + @torch.jit.ignore + def set_distilled_training(self, enable=True): + self.distilled_training = enable + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool == 'avg': + x = x.mean(dim=1) + if pre_logits: + return x + x, x_dist = self.head(x), self.head_dist(x) + if self.distilled_training and self.training and not torch.jit.is_scripting(): + # only return separate classification predictions when training in distilled mode + return x, x_dist + else: + # during standard train/finetune, inference average the classifier predictions + return (x + x_dist) / 2 + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _checkpoint_filter_fn(state_dict, model): + """ Remap original checkpoints -> timm """ + if 'stem.0.weight' in state_dict: + return state_dict # non-original checkpoint, no remapping needed + + out_dict = {} + import re + stage_idx = 0 + for k, v in state_dict.items(): + if k.startswith('patch_embed'): + k = k.replace('patch_embed.0', 'stem.conv1') + k = k.replace('patch_embed.1', 'stem.norm1') + k = k.replace('patch_embed.3', 'stem.conv2') + k = k.replace('patch_embed.4', 'stem.norm2') + + if re.match(r'network\.(\d+)\.proj\.weight', k): + stage_idx += 1 + k = re.sub(r'network.(\d+).(\d+)', f'stages.{stage_idx}.blocks.\\2', k) + k = re.sub(r'network.(\d+).proj', 
f'stages.{stage_idx}.downsample.conv', k) + k = re.sub(r'network.(\d+).norm', f'stages.{stage_idx}.downsample.norm', k) + + k = re.sub(r'layer_scale_([0-9])', r'ls\1.gamma', k) + k = k.replace('dist_head', 'head_dist') + out_dict[k] = v + return out_dict + + +def _create_efficientformer(variant, pretrained=False, **kwargs): + model = build_model_with_cfg( + EfficientFormer, variant, pretrained, + pretrained_filter_fn=_checkpoint_filter_fn, + **kwargs) + return model + + +@register_model +def efficientformer_l1(pretrained=False, **kwargs): + model_kwargs = dict( + depths=EfficientFormer_depth['l1'], + embed_dims=EfficientFormer_width['l1'], + num_vit=1, + **kwargs) + return _create_efficientformer('efficientformer_l1', pretrained=pretrained, **model_kwargs) + + +@register_model +def efficientformer_l3(pretrained=False, **kwargs): + model_kwargs = dict( + depths=EfficientFormer_depth['l3'], + embed_dims=EfficientFormer_width['l3'], + num_vit=4, + **kwargs) + return _create_efficientformer('efficientformer_l3', pretrained=pretrained, **model_kwargs) + + +@register_model +def efficientformer_l7(pretrained=False, **kwargs): + model_kwargs = dict( + depths=EfficientFormer_depth['l7'], + embed_dims=EfficientFormer_width['l7'], + num_vit=8, + **kwargs) + return _create_efficientformer('efficientformer_l7', pretrained=pretrained, **model_kwargs) + From c486aa71f815188d9d86aa6711dd175a0bb7f955 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 17 Aug 2022 14:29:18 -0700 Subject: [PATCH 03/21] Add GCViT --- timm/models/__init__.py | 1 + timm/models/gcvit.py | 575 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 576 insertions(+) create mode 100644 timm/models/gcvit.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 5e19358c..936846c3 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -15,6 +15,7 @@ from .dpn import * from .edgenext import * from .efficientformer import * from .efficientnet import * +from .gcvit import * from .ghostnet import * from .gluon_resnet import * from .gluon_xception import * diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py new file mode 100644 index 00000000..3e2cd96a --- /dev/null +++ b/timm/models/gcvit.py @@ -0,0 +1,575 @@ +""" Global Context ViT + +From scratch implementation of GCViT in the style of timm swin_transformer_v2_cr.py + +Global Context Vision Transformers -https://arxiv.org/abs/2206.09959 + +@article{hatamizadeh2022global, + title={Global Context Vision Transformers}, + author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo}, + journal={arXiv preprint arXiv:2206.09959}, + year={2022} +} + +Free of any code related to NVIDIA GCVit impl at https://github.com/NVlabs/GCVit. +The license for this code release is Apache 2.0 with no commercial restrictions. + +However, weight files adapted from NVIDIA GCVit impl ARE under a non-commercial share-alike license +(https://creativecommons.org/licenses/by-nc-sa/4.0/) until I have a chance to train new ones... 
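+
+Usage sketch (illustrative only; it assumes the gcvit_* variants registered at the
+end of this file and timm's standard create_model factory):
+
+    import torch
+    import timm
+
+    # the default_cfgs in this file set fixed_input_size=True, so use the native 224x224 input
+    model = timm.create_model('gcvit_tiny', pretrained=False).eval()
+    with torch.no_grad():
+        feats = model.forward_features(torch.randn(1, 3, 224, 224))  # NCHW features, e.g. (1, 512, 7, 7)
+        logits = model(torch.randn(1, 3, 224, 224))                  # (1, 1000) class logits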
+ +Hacked together by / Copyright 2022, Ross Wightman +""" +import math +from functools import partial +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_notrace_function +from .helpers import build_model_with_cfg, named_apply +from .layers import trunc_normal_tf_, DropPath, to_2tuple, Mlp, get_attn, get_act_layer, get_norm_layer, \ + ClassifierHead, LayerNorm2d, _assert +from .registry import register_model +from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move to common location + +__all__ = ['GlobalContextVit'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + 'fixed_input_size': True, + **kwargs + } + + +default_cfgs = { + 'gcvit_xxtiny': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xxtiny_224_nvidia-d1d86009.pth'), + 'gcvit_xtiny': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xtiny_224_nvidia-274b92b7.pth'), + 'gcvit_tiny': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_tiny_224_nvidia-ac783954.pth'), + 'gcvit_small': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_small_224_nvidia-4e98afa2.pth'), + 'gcvit_base': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_base_224_nvidia-f009139b.pth'), +} + + +class MbConvBlock(nn.Module): + """ A depthwise separable / fused mbconv style residual block with SE, `no norm. 
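+
+    Structure, as implemented below: 3x3 depthwise conv -> act -> SE (or other attention
+    layer) -> 1x1 pointwise conv, with a residual shortcut and no normalization layers.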
+ """ + def __init__( + self, + in_chs, + out_chs=None, + expand_ratio=1.0, + attn_layer='se', + bias=False, + act_layer=nn.GELU, + ): + super().__init__() + attn_kwargs = dict(act_layer=act_layer) + if isinstance(attn_layer, str) and attn_layer == 'se' or attn_layer == 'eca': + attn_kwargs['rd_ratio'] = 0.25 + attn_kwargs['bias'] = False + attn_layer = get_attn(attn_layer) + out_chs = out_chs or in_chs + mid_chs = int(expand_ratio * in_chs) + + self.conv_dw = nn.Conv2d(in_chs, mid_chs, 3, 1, 1, groups=in_chs, bias=bias) + self.act = act_layer() + self.se = attn_layer(mid_chs, **attn_kwargs) + self.conv_pw = nn.Conv2d(mid_chs, out_chs, 1, 1, 0, bias=bias) + + def forward(self, x): + shortcut = x + x = self.conv_dw(x) + x = self.act(x) + x = self.se(x) + x = self.conv_pw(x) + x = x + shortcut + return x + + +class Downsample2d(nn.Module): + def __init__( + self, + dim, + dim_out=None, + reduction='conv', + act_layer=nn.GELU, + norm_layer=LayerNorm2d, + ): + super().__init__() + dim_out = dim_out or dim + + self.norm1 = norm_layer(dim) if norm_layer is not None else nn.Identity() + self.conv_block = MbConvBlock(dim, act_layer=act_layer) + assert reduction in ('conv', 'max', 'avg') + if reduction == 'conv': + self.reduction = nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False) + elif reduction == 'max': + assert dim == dim_out + self.reduction = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + else: + assert dim == dim_out + self.reduction = nn.AvgPool2d(kernel_size=2) + self.norm2 = norm_layer(dim_out) if norm_layer is not None else nn.Identity() + + def forward(self, x): + x = self.norm1(x) + x = self.conv_block(x) + x = self.reduction(x) + x = self.norm2(x) + return x + + +class FeatureBlock(nn.Module): + def __init__( + self, + dim, + levels=0, + reduction='max', + act_layer=nn.GELU, + ): + super().__init__() + reductions = levels + levels = max(1, levels) + if reduction == 'avg': + pool_fn = partial(nn.AvgPool2d, kernel_size=2) + else: + pool_fn = partial(nn.MaxPool2d, kernel_size=3, stride=2, padding=1) + self.blocks = nn.Sequential() + for i in range(levels): + self.blocks.add_module(f'conv{i+1}', MbConvBlock(dim, act_layer=act_layer)) + if reductions: + self.blocks.add_module(f'pool{i+1}', pool_fn()) + reductions -= 1 + + def forward(self, x): + return self.blocks(x) + + +class Stem(nn.Module): + def __init__( + self, + in_chs: int = 3, + out_chs: int = 96, + act_layer: str = 'gelu', + norm_layer: str = 'layernorm2d', # NOTE norm for NCHW + ): + super().__init__() + act_layer = get_act_layer(act_layer) + norm_layer = get_norm_layer(norm_layer) + self.conv1 = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1) + self.down = Downsample2d(out_chs, act_layer=act_layer, norm_layer=norm_layer) + + def forward(self, x): + x = self.conv1(x) + x = self.down(x) + return x + + +class WindowAttentionGlobal(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + window_size: Tuple[int, int], + use_global: bool = True, + qkv_bias: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + ): + super().__init__() + window_size = to_2tuple(window_size) + self.window_size = window_size + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.use_global = use_global + + self.rel_pos = RelPosBias(window_size=window_size, num_heads=num_heads) + if self.use_global: + self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias) + else: + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = 
nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, q_global: Optional[torch.Tensor] = None): + B, N, C = x.shape + if self.use_global: + _assert(q_global is not None, 'q_global must be passed in global mode') + _assert(x.shape[-1] == q_global.shape[-1], 'x and q_global seq lengths should be equal') + + kv = self.qkv(x) + kv = kv.reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + k, v = kv.unbind(0) + + q = q_global.repeat(B // q_global.shape[0], 1, 1, 1) + q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + else: + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q = q * self.scale + + attn = (q @ k.transpose(-2, -1)) + attn = self.rel_pos(attn) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def window_partition(x, window_size: Tuple[int, int]): + B, H, W, C = x.shape + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]): + H, W = img_size + B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) + x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class GlobalContextVitBlock(nn.Module): + def __init__( + self, + dim: int, + feat_size: Tuple[int, int], + num_heads: int, + window_size: int = 7, + mlp_ratio: float = 4., + use_global: bool = True, + qkv_bias: bool = True, + layer_scale: Optional[float] = None, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + attn_layer: Callable = WindowAttentionGlobal, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + ): + super().__init__() + feat_size = to_2tuple(feat_size) + window_size = to_2tuple(window_size) + self.window_size = window_size + self.num_windows = int((feat_size[0] // window_size[0]) * (feat_size[1] // window_size[1])) + + self.norm1 = norm_layer(dim) + self.attn = attn_layer( + dim, + num_heads=num_heads, + window_size=window_size, + use_global=use_global, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) + self.ls1 = LayerScale(dim, layer_scale) if layer_scale is not None else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop) + self.ls2 = LayerScale(dim, layer_scale) if layer_scale is not None else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def _window_attn(self, x, q_global: Optional[torch.Tensor] = None): + B, H, W, C = x.shape + x_win = window_partition(x, self.window_size) + x_win = x_win.view(-1, self.window_size[0] * self.window_size[1], C) + attn_win = self.attn(x_win, q_global) + x = window_reverse(attn_win, self.window_size, (H, W)) + return x + + def forward(self, x, q_global: Optional[torch.Tensor] = None): + x = x + self.drop_path1(self.ls1(self._window_attn(self.norm1(x), q_global))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class GlobalContextVitStage(nn.Module): + def __init__( + self, + dim, + depth: int, + num_heads: int, + feat_size: Tuple[int, int], + window_size: int, + downsample: bool = True, + global_norm: bool = False, + stage_norm: bool = False, + mlp_ratio: float = 4., + qkv_bias: bool = True, + layer_scale: Optional[float] = None, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: Union[List[float], float] = 0.0, + act_layer: str = 'gelu', + norm_layer: str = 'layernorm2d', + norm_layer_cl: str = 'layernorm', + ): + super().__init__() + act_layer = get_act_layer(act_layer) + norm_layer = get_norm_layer(norm_layer) + norm_layer_cl = get_norm_layer(norm_layer_cl) + + if downsample: + self.downsample = Downsample2d( + dim=dim, + dim_out=dim * 2, + norm_layer=norm_layer, + ) + dim = dim * 2 + feat_size = (feat_size[0] // 2, feat_size[1] // 2) + else: + self.downsample = nn.Identity() + self.feat_size = feat_size + + feat_levels = int(math.log2(min(feat_size) / window_size)) + self.global_block = FeatureBlock(dim, feat_levels) + self.global_norm = norm_layer_cl(dim) if global_norm else nn.Identity() + + self.blocks = nn.ModuleList([ + GlobalContextVitBlock( + dim=dim, + num_heads=num_heads, + feat_size=feat_size, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + use_global=(i % 2 != 0), + layer_scale=layer_scale, + proj_drop=proj_drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + act_layer=act_layer, + norm_layer=norm_layer_cl, + ) + for i in range(depth) + ]) + self.norm = norm_layer_cl(dim) if stage_norm else nn.Identity() + self.dim = dim + self.feat_size = feat_size + self.grad_checkpointing = False + + def forward(self, x): + # input NCHW, downsample & global block are 2d conv + pooling + x = self.downsample(x) + global_query = self.global_block(x) + + # reshape NCHW --> NHWC for transformer blocks + x = x.permute(0, 2, 3, 1) + global_query = self.global_norm(global_query.permute(0, 2, 3, 1)) + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x, global_query) + x = self.norm(x) + x = x.permute(0, 3, 1, 2).contiguous() # back to NCHW + return x + + +class GlobalContextVit(nn.Module): + def __init__( + self, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + img_size: Tuple[int, int] = 224, + window_size: Tuple[int, ...] = (7, 7, 14, 7), + embed_dim: int = 64, + depths: Tuple[int, ...] = (3, 4, 19, 5), + num_heads: Tuple[int, ...] 
= (2, 4, 8, 16), + mlp_ratio: float = 3.0, + qkv_bias: bool = True, + layer_scale: Optional[float] = None, + drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + weight_init='vit', + act_layer: str = 'gelu', + norm_layer: str = 'layernorm2d', + norm_layer_cl: str = 'layernorm', + ): + super().__init__() + img_size = to_2tuple(img_size) + feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4 + self.global_pool = global_pool + self.num_classes = num_classes + num_stages = len(depths) + self.num_features = int(embed_dim * 2 ** (num_stages - 1)) + + self.stem = Stem( + in_chs=in_chans, out_chs=embed_dim, act_layer=act_layer, norm_layer=norm_layer) + + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + stages = [] + for i in range(num_stages): + last_stage = i == num_stages - 1 + stage_scale = 2 ** max(i - 1, 0) + stages.append(GlobalContextVitStage( + dim=embed_dim * stage_scale, + depth=depths[i], + num_heads=num_heads[i], + feat_size=(feat_size[0] // stage_scale, feat_size[1] // stage_scale), + window_size=window_size[i], + downsample=i != 0, + stage_norm=last_stage, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + layer_scale=layer_scale, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + act_layer=act_layer, + norm_layer=norm_layer, + norm_layer_cl=norm_layer_cl, + )) + self.stages = nn.Sequential(*stages) + + # Classifier head + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + + if weight_init: + named_apply(partial(self._init_weights, scheme=weight_init), self) + + def _init_weights(self, module, name, scheme='vit'): + # note Conv2d left as default init + if scheme == 'vit': + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + else: + if isinstance(module, nn.Linear): + trunc_normal_tf_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + @torch.jit.ignore + def no_weight_decay(self): + return { + k for k, _ in self.named_parameters() + if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])} + + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', # stem and embed + blocks=(r'^stages\.(\d+)', None) + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.stem(x) + x = self.stages(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_gcvit(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + model = build_model_with_cfg(GlobalContextVit, variant, pretrained, **kwargs) + return model + + +@register_model +def gcvit_xxtiny(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(2, 2, 6, 2), + num_heads=(2, 4, 8, 16), + **kwargs) + return _create_gcvit('gcvit_xxtiny', pretrained=pretrained, **model_kwargs) + + +@register_model +def gcvit_xtiny(pretrained=False, **kwargs): + 
model_kwargs = dict( + depths=(3, 4, 6, 5), + num_heads=(2, 4, 8, 16), + **kwargs) + return _create_gcvit('gcvit_xtiny', pretrained=pretrained, **model_kwargs) + + +@register_model +def gcvit_tiny(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(3, 4, 19, 5), + num_heads=(2, 4, 8, 16), + **kwargs) + return _create_gcvit('gcvit_tiny', pretrained=pretrained, **model_kwargs) + + +@register_model +def gcvit_small(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(3, 4, 19, 5), + num_heads=(3, 6, 12, 24), + window_size=(7, 7, 14, 7), + embed_dim=96, + mlp_ratio=2, + layer_scale=1e-5, + **kwargs) + return _create_gcvit('gcvit_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def gcvit_base(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(3, 4, 19, 5), + num_heads=(4, 8, 16, 32), + window_size=(7, 7, 14, 7), + embed_dim=128, + mlp_ratio=2, + layer_scale=1e-5, + **kwargs) + return _create_gcvit('gcvit_base', pretrained=pretrained, **model_kwargs) From 43aa84e861b2df12f0ec92db493b4049b753311c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 17 Aug 2022 14:32:58 -0700 Subject: [PATCH 04/21] Add 'fast' layer norm that doesn't cast to float32, support APEX LN impl for slight speed gain, update norm and act factories, tweak SE for ability to disable bias (needed by GCVit) --- timm/models/layers/__init__.py | 4 +- timm/models/layers/create_act.py | 8 +++- timm/models/layers/create_norm.py | 56 +++++++++++++++++++++++ timm/models/layers/fast_norm.py | 68 ++++++++++++++++++++++++++++ timm/models/layers/norm.py | 51 +++++++++++++++++++-- timm/models/layers/norm_act.py | 23 ++++++++-- timm/models/layers/squeeze_excite.py | 6 +-- 7 files changed, 201 insertions(+), 15 deletions(-) create mode 100644 timm/models/layers/create_norm.py create mode 100644 timm/models/layers/fast_norm.py diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index b9eeec0f..071da7bc 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -11,11 +11,13 @@ from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct from .create_act import create_act_layer, get_act_layer, get_act_fn from .create_attn import get_attn, create_attn from .create_conv2d import create_conv2d +from .create_norm import get_norm_layer, create_norm_layer from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a +from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d from .gather_excite import GatherExcite from .global_context import GlobalContext @@ -25,7 +27,7 @@ from .linear import Linear from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp from .non_local_attn import NonLocalAttn, BatNonLocalAttn -from .norm import GroupNorm, GroupNorm1, LayerNorm2d +from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm from .padding import get_padding, get_same_padding, pad_same from .patch_embed import PatchEmbed diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py index 
e38f2e03..a3044a3d 100644 --- a/timm/models/layers/create_act.py +++ b/timm/models/layers/create_act.py @@ -145,4 +145,10 @@ def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs): act_layer = get_act_layer(name) if act_layer is None: return None - return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs) + if inplace is None: + return act_layer(**kwargs) + try: + return act_layer(inplace=inplace, **kwargs) + except TypeError: + # recover if act layer doesn't have inplace arg + return act_layer(**kwargs) diff --git a/timm/models/layers/create_norm.py b/timm/models/layers/create_norm.py new file mode 100644 index 00000000..b9efae8c --- /dev/null +++ b/timm/models/layers/create_norm.py @@ -0,0 +1,56 @@ +""" Norm Layer Factory + +Create norm modules by string (to mirror create_act and creat_norm-act fns) + +Copyright 2022 Ross Wightman +""" +import types +import functools + +import torch.nn as nn + +from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d + +_NORM_MAP = dict( + batchnorm=nn.BatchNorm2d, + batchnorm2d=nn.BatchNorm2d, + batchnorm1d=nn.BatchNorm1d, + groupnorm=GroupNorm, + groupnorm1=GroupNorm1, + layernorm=LayerNorm, + layernorm2d=LayerNorm2d, +) +_NORM_TYPES = {m for n, m in _NORM_MAP.items()} + + +def create_norm_layer(layer_name, num_features, act_layer=None, apply_act=True, **kwargs): + layer = get_norm_layer(layer_name, act_layer=act_layer) + layer_instance = layer(num_features, apply_act=apply_act, **kwargs) + return layer_instance + + +def get_norm_layer(norm_layer): + assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) + norm_kwargs = {} + + # unbind partial fn, so args can be rebound later + if isinstance(norm_layer, functools.partial): + norm_kwargs.update(norm_layer.keywords) + norm_layer = norm_layer.func + + if isinstance(norm_layer, str): + layer_name = norm_layer.replace('_', '') + norm_layer = _NORM_MAP.get(layer_name, None) + elif norm_layer in _NORM_TYPES: + norm_layer = norm_layer + elif isinstance(norm_layer, types.FunctionType): + # if function type, assume it is a lambda/fn that creates a norm layer + norm_layer = norm_layer + else: + type_name = norm_layer.__name__.lower().replace('_', '') + norm_layer = _NORM_MAP.get(type_name, None) + assert norm_layer is not None, f"No equivalent norm layer for {type_name}" + + if norm_kwargs: + norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args + return norm_layer diff --git a/timm/models/layers/fast_norm.py b/timm/models/layers/fast_norm.py new file mode 100644 index 00000000..9a34a15e --- /dev/null +++ b/timm/models/layers/fast_norm.py @@ -0,0 +1,68 @@ +from typing import List, Optional + +import torch +from torch.nn import functional as F + +try: + from apex.normalization.fused_layer_norm import fused_layer_norm_affine + has_apex = True +except ImportError: + has_apex = False + + +# fast (ie lower precision LN) can be disabled with this flag if issues crop up +_USE_FAST_NORM = False # defaulting to False for now + + +def is_fast_norm(): + return _USE_FAST_NORM + + +def set_fast_norm(enable=True): + global _USE_FAST_NORM + _USE_FAST_NORM = enable + + +def fast_group_norm( + x: torch.Tensor, + num_groups: int, + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + eps: float = 1e-5 +) -> torch.Tensor: + if torch.jit.is_scripting(): + # currently cannot use is_autocast_enabled within torchscript + return F.group_norm(x, num_groups, weight, bias, eps) + + if 
torch.is_autocast_enabled(): + # normally native AMP casts GN inputs to float32 + # here we use the low precision autocast dtype + dt = torch.get_autocast_gpu_dtype() + x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) + + with torch.cuda.amp.autocast(enabled=False): + return F.group_norm(x, num_groups, weight, bias, eps) + + +def fast_layer_norm( + x: torch.Tensor, + normalized_shape: List[int], + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + eps: float = 1e-5 +) -> torch.Tensor: + if torch.jit.is_scripting(): + # currently cannot use is_autocast_enabled within torchscript + return F.layer_norm(x, normalized_shape, weight, bias, eps) + + if has_apex: + return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps) + + if torch.is_autocast_enabled(): + # normally native AMP casts LN inputs to float32 + # apex LN does not, this is behaving like Apex + dt = torch.get_autocast_gpu_dtype() + x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) + + with torch.cuda.amp.autocast(enabled=False): + return F.layer_norm(x, normalized_shape, weight, bias, eps) diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index 1677dbfa..2ff8fc08 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -1,17 +1,24 @@ """ Normalization layers and wrappers """ + import torch import torch.nn as nn import torch.nn.functional as F +from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm + class GroupNorm(nn.GroupNorm): def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True): # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN super().__init__(num_groups, num_channels, eps=eps, affine=affine) + self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) def forward(self, x): - return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + if self.fast_norm: + return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + else: + return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) class GroupNorm1(nn.GroupNorm): @@ -21,22 +28,48 @@ class GroupNorm1(nn.GroupNorm): def __init__(self, num_channels, **kwargs): super().__init__(1, num_channels, **kwargs) + self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.fast_norm: + return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + else: + return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + + +class LayerNorm(nn.LayerNorm): + """ LayerNorm w/ fast norm option + """ + def __init__(self, num_channels, eps=1e-6, affine=True): + super().__init__(num_channels, eps=eps, elementwise_affine=affine) + self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self._fast_norm: + x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + else: + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x class LayerNorm2d(nn.LayerNorm): """ LayerNorm for channels of '2D' spatial NCHW tensors """ def __init__(self, num_channels, eps=1e-6, affine=True): super().__init__(num_channels, eps=eps, elementwise_affine=affine) + self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) def forward(self, x: torch.Tensor) -> torch.Tensor: - 
return F.layer_norm( - x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + x = x.permute(0, 2, 3, 1) + if self._fast_norm: + x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + else: + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = x.permute(0, 3, 1, 2) + return x def _is_contiguous(tensor: torch.Tensor) -> bool: # jit is oh so lovely :/ - # if torch.jit.is_tracing(): - # return True if torch.jit.is_scripting(): return tensor.is_contiguous() else: @@ -51,6 +84,14 @@ def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, ep return x +def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float): + u = x.mean(dim=1, keepdim=True) + s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0) + x = (x - u) * torch.rsqrt(s + eps) + x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1) + return x + + class LayerNormExp2d(nn.LayerNorm): """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py index ea5b7883..be1edead 100644 --- a/timm/models/layers/norm_act.py +++ b/timm/models/layers/norm_act.py @@ -6,8 +6,9 @@ import torch from torch import nn as nn from torch.nn import functional as F -from .trace_utils import _assert from .create_act import get_act_layer +from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm +from .trace_utils import _assert class BatchNormAct2d(nn.BatchNorm2d): @@ -177,9 +178,13 @@ class GroupNormAct(nn.GroupNorm): self.act = act_layer(**act_args) else: self.act = nn.Identity() + self._fast_norm = is_fast_norm() def forward(self, x): - x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + if self._fast_norm: + x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + else: + x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) x = self.drop(x) x = self.act(x) return x @@ -197,9 +202,13 @@ class LayerNormAct(nn.LayerNorm): self.act = act_layer(**act_args) else: self.act = nn.Identity() + self._fast_norm = is_fast_norm() def forward(self, x): - x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + if self._fast_norm: + x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + else: + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) x = self.drop(x) x = self.act(x) return x @@ -219,8 +228,12 @@ class LayerNormAct2d(nn.LayerNorm): self.act = nn.Identity() def forward(self, x): - x = F.layer_norm( - x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + x = x.permute(0, 2, 3, 1) + if self._fast_norm: + x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + else: + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = x.permute(0, 3, 1, 2) x = self.drop(x) x = self.act(x) return x diff --git a/timm/models/layers/squeeze_excite.py b/timm/models/layers/squeeze_excite.py index e5da29ef..2e41d956 100644 --- a/timm/models/layers/squeeze_excite.py +++ b/timm/models/layers/squeeze_excite.py @@ -27,15 +27,15 @@ class SEModule(nn.Module): """ def __init__( self, channels, rd_ratio=1. 
/ 16, rd_channels=None, rd_divisor=8, add_maxpool=False, - act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'): + bias=True, act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'): super(SEModule, self).__init__() self.add_maxpool = add_maxpool if not rd_channels: rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) - self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True) + self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias) self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() self.act = create_act_layer(act_layer, inplace=True) - self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True) + self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias) self.gate = create_act_layer(gate_layer) def forward(self, x): From 6e559e9b5fb3db657f68727a90adf89603172cfe Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 17 Aug 2022 15:12:31 -0700 Subject: [PATCH 05/21] Add MViT (Multi-Scale) V2 --- timm/models/__init__.py | 1 + timm/models/mvitv2.py | 993 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 994 insertions(+) create mode 100644 timm/models/mvitv2.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 936846c3..b93d3f94 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -28,6 +28,7 @@ from .levit import * from .mlp_mixer import * from .mobilenetv3 import * from .mobilevit import * +from .mvitv2 import * from .nasnet import * from .nest import * from .nfnet import * diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py new file mode 100644 index 00000000..4c7cd044 --- /dev/null +++ b/timm/models/mvitv2.py @@ -0,0 +1,993 @@ +""" Multi-Scale Vision Transformer v2 + +@inproceedings{li2021improved, + title={MViTv2: Improved multiscale vision transformers for classification and detection}, + author={Li, Yanghao and Wu, Chao-Yuan and Fan, Haoqi and Mangalam, Karttikeya and Xiong, Bo and Malik, Jitendra and Feichtenhofer, Christoph}, + booktitle={CVPR}, + year={2022} +} + +Code adapted from original Apache 2.0 licensed impl at https://github.com/facebookresearch/mvit +Original copyright below. + +Modifications and timm support by / Copyright 2022, Ross Wightman +""" +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. All Rights Reserved. 
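+# NOTE: minimal usage sketch (model names are registered via @register_model at the bottom
+# of this file; original facebookresearch checkpoints are remapped by checkpoint_filter_fn):
+#
+#   import timm, torch
+#   model = timm.create_model('mvitv2_tiny', pretrained=False)
+#   logits = model(torch.randn(1, 3, 224, 224))  # fixed 224x224 input -> (1, 1000) logits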
+import operator +from collections import OrderedDict +from dataclasses import dataclass +from functools import partial, reduce +from typing import Union, List, Tuple, Optional + +import torch +import torch.utils.checkpoint as checkpoint +from torch import nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': True, + **kwargs + } + + +default_cfgs = dict( + mvitv2_tiny=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_T_in1k.pyth'), + mvitv2_small=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_S_in1k.pyth'), + mvitv2_base=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in1k.pyth'), + mvitv2_large=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in1k.pyth'), + + mvitv2_base_in21k=_cfg( + url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in21k.pyth', + num_classes=19168), + mvitv2_large_in21k=_cfg( + url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in21k.pyth', + num_classes=19168), + mvitv2_huge_in21k=_cfg( + url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_H_in21k.pyth', + num_classes=19168), +) + + +@dataclass +class MultiScaleVitCfg: + depths: Tuple[int, ...] = (2, 3, 16, 3) + embed_dim: Union[int, Tuple[int, ...]] = 96 + num_heads: Union[int, Tuple[int, ...]] = 1 + mlp_ratio: float = 4. 
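+    # NOTE: pool_first selects MultiScaleAttentionPoolFirst, which pools q/k/v before the
+    # q/k/v linear projections; the default MultiScaleAttention pools the projected q/k/v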
+ pool_first: bool = False + expand_attn: bool = True + qkv_bias: bool = True + use_cls_token: bool = False + use_abs_pos: bool = False + residual_pooling: bool = True + mode: str = 'conv' + kernel_qkv: Tuple[int, int] = (3, 3) + stride_q: Optional[Tuple[Tuple[int, int]]] = ((1, 1), (2, 2), (2, 2), (2, 2)) + stride_kv: Optional[Tuple[Tuple[int, int]]] = None + stride_kv_adaptive: Optional[Tuple[int, int]] = (4, 4) + patch_kernel: Tuple[int, int] = (7, 7) + patch_stride: Tuple[int, int] = (4, 4) + patch_padding: Tuple[int, int] = (3, 3) + pool_type: str = 'max' + rel_pos_type: str = 'spatial' + act_layer: Union[str, Tuple[str, str]] = 'gelu' + norm_layer: Union[str, Tuple[str, str]] = 'layernorm' + norm_eps: float = 1e-6 + + def __post_init__(self): + num_stages = len(self.depths) + if not isinstance(self.embed_dim, (tuple, list)): + self.embed_dim = tuple(self.embed_dim * 2 ** i for i in range(num_stages)) + assert len(self.embed_dim) == num_stages + + if not isinstance(self.num_heads, (tuple, list)): + self.num_heads = tuple(self.num_heads * 2 ** i for i in range(num_stages)) + assert len(self.num_heads) == num_stages + + if self.stride_kv_adaptive is not None and self.stride_kv is None: + _stride_kv = self.stride_kv_adaptive + pool_kv_stride = [] + for i in range(num_stages): + if min(self.stride_q[i]) > 1: + _stride_kv = [ + max(_stride_kv[d] // self.stride_q[i][d], 1) + for d in range(len(_stride_kv)) + ] + pool_kv_stride.append(tuple(_stride_kv)) + self.stride_kv = tuple(pool_kv_stride) + + +model_cfgs = dict( + mvitv2_tiny=MultiScaleVitCfg( + depths=(1, 2, 5, 2), + ), + mvitv2_small=MultiScaleVitCfg( + depths=(1, 2, 11, 2), + ), + mvitv2_base=MultiScaleVitCfg( + depths=(2, 3, 16, 3), + ), + mvitv2_large=MultiScaleVitCfg( + depths=(2, 6, 36, 4), + embed_dim=144, + num_heads=2, + expand_attn=False, + ), + + mvitv2_base_in21k=MultiScaleVitCfg( + depths=(2, 3, 16, 3), + ), + mvitv2_large_in21k=MultiScaleVitCfg( + depths=(2, 6, 36, 4), + embed_dim=144, + num_heads=2, + expand_attn=False, + ), +) + + +def prod(iterable): + return reduce(operator.mul, iterable, 1) + + +class PatchEmbed(nn.Module): + """ + PatchEmbed. + """ + + def __init__( + self, + dim_in=3, + dim_out=768, + kernel=(7, 7), + stride=(4, 4), + padding=(3, 3), + ): + super().__init__() + + self.proj = nn.Conv2d( + dim_in, + dim_out, + kernel_size=kernel, + stride=stride, + padding=padding, + ) + + def forward(self, x) -> Tuple[torch.Tensor, List[int]]: + x = self.proj(x) + # B C H W -> B HW C + return x.flatten(2).transpose(1, 2), x.shape[-2:] + + +def reshape_pre_pool( + x, + feat_size: List[int], + has_cls_token: bool = True +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + H, W = feat_size + if has_cls_token: + cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :] + else: + cls_tok = None + x = x.reshape(-1, H, W, x.shape[-1]).permute(0, 3, 1, 2).contiguous() + return x, cls_tok + + +def reshape_post_pool( + x, + num_heads: int, + cls_tok: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, List[int]]: + feat_size = [x.shape[2], x.shape[3]] + L_pooled = x.shape[2] * x.shape[3] + x = x.reshape(-1, num_heads, x.shape[1], L_pooled).transpose(2, 3) + if cls_tok is not None: + x = torch.cat((cls_tok, x), dim=2) + return x, feat_size + + +def cal_rel_pos_type( + attn: torch.Tensor, + q: torch.Tensor, + has_cls_token: bool, + q_size: List[int], + k_size: List[int], + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, +): + """ + Spatial Relative Positional Embeddings. 
+ """ + sp_idx = 1 if has_cls_token else 0 + q_h, q_w = q_size + k_h, k_w = k_size + + # Scale up rel pos if shapes for q and k are different. + q_h_ratio = max(k_h / q_h, 1.0) + k_h_ratio = max(q_h / k_h, 1.0) + dist_h = torch.arange(q_h)[:, None] * q_h_ratio - torch.arange(k_h)[None, :] * k_h_ratio + dist_h += (k_h - 1) * k_h_ratio + q_w_ratio = max(k_w / q_w, 1.0) + k_w_ratio = max(q_w / k_w, 1.0) + dist_w = torch.arange(q_w)[:, None] * q_w_ratio - torch.arange(k_w)[None, :] * k_w_ratio + dist_w += (k_w - 1) * k_w_ratio + + Rh = rel_pos_h[dist_h.long()] + Rw = rel_pos_w[dist_w.long()] + + B, n_head, q_N, dim = q.shape + + r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim) + rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, Rh) + rel_w = torch.einsum("byhwc,wkc->byhwk", r_q, Rw) + + attn[:, :, sp_idx:, sp_idx:] = ( + attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w) + + rel_h[:, :, :, :, :, None] + + rel_w[:, :, :, :, None, :] + ).view(B, -1, q_h * q_w, k_h * k_w) + + return attn + + +class MultiScaleAttentionPoolFirst(nn.Module): + def __init__( + self, + dim, + dim_out, + feat_size, + num_heads=8, + qkv_bias=True, + mode="conv", + kernel_q=(1, 1), + kernel_kv=(1, 1), + stride_q=(1, 1), + stride_kv=(1, 1), + has_cls_token=True, + rel_pos_type='spatial', + residual_pooling=True, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.num_heads = num_heads + self.dim_out = dim_out + self.head_dim = dim_out // num_heads + self.scale = self.head_dim ** -0.5 + self.has_cls_token = has_cls_token + padding_q = tuple([int(q // 2) for q in kernel_q]) + padding_kv = tuple([int(kv // 2) for kv in kernel_kv]) + + self.q = nn.Linear(dim, dim_out, bias=qkv_bias) + self.k = nn.Linear(dim, dim_out, bias=qkv_bias) + self.v = nn.Linear(dim, dim_out, bias=qkv_bias) + self.proj = nn.Linear(dim_out, dim_out) + + # Skip pooling with kernel and stride size of (1, 1, 1). 
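+        # (pooling is only dropped when both kernel and stride are 1x1; the default
+        #  kernel_qkv=(3, 3) with stride (1, 1) keeps a depthwise pool conv, no downsample)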
+ if prod(kernel_q) == 1 and prod(stride_q) == 1: + kernel_q = None + if prod(kernel_kv) == 1 and prod(stride_kv) == 1: + kernel_kv = None + self.mode = mode + self.unshared = mode == 'conv_unshared' + self.pool_q, self.pool_k, self.pool_v = None, None, None + self.norm_q, self.norm_k, self.norm_v = None, None, None + if mode in ("avg", "max"): + pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d + if kernel_q: + self.pool_q = pool_op(kernel_q, stride_q, padding_q) + if kernel_kv: + self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv) + self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv) + elif mode == "conv" or mode == "conv_unshared": + dim_conv = dim // num_heads if mode == "conv" else dim + if kernel_q: + self.pool_q = nn.Conv2d( + dim_conv, + dim_conv, + kernel_q, + stride=stride_q, + padding=padding_q, + groups=dim_conv, + bias=False, + ) + self.norm_q = norm_layer(dim_conv) + if kernel_kv: + self.pool_k = nn.Conv2d( + dim_conv, + dim_conv, + kernel_kv, + stride=stride_kv, + padding=padding_kv, + groups=dim_conv, + bias=False, + ) + self.norm_k = norm_layer(dim_conv) + self.pool_v = nn.Conv2d( + dim_conv, + dim_conv, + kernel_kv, + stride=stride_kv, + padding=padding_kv, + groups=dim_conv, + bias=False, + ) + self.norm_v = norm_layer(dim_conv) + else: + raise NotImplementedError(f"Unsupported model {mode}") + + # relative pos embedding + self.rel_pos_type = rel_pos_type + if self.rel_pos_type == 'spatial': + assert feat_size[0] == feat_size[1] + size = feat_size[0] + q_size = size // stride_q[1] if len(stride_q) > 0 else size + kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size + rel_sp_dim = 2 * max(q_size, kv_size) - 1 + + self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim)) + trunc_normal_tf_(self.rel_pos_h, std=0.02) + trunc_normal_tf_(self.rel_pos_w, std=0.02) + + self.residual_pooling = residual_pooling + + def forward(self, x, feat_size: List[int]): + B, N, _ = x.shape + + fold_dim = 1 if self.unshared else self.num_heads + x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3) + q = k = v = x + + if self.pool_q is not None: + q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token) + q = self.pool_q(q) + q, q_size = reshape_post_pool(q, self.num_heads, q_tok) + else: + q_size = feat_size + if self.norm_q is not None: + q = self.norm_q(q) + + if self.pool_k is not None: + k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token) + k = self.pool_k(k) + k, k_size = reshape_post_pool(k, self.num_heads, k_tok) + else: + k_size = feat_size + if self.norm_k is not None: + k = self.norm_k(k) + + if self.pool_v is not None: + v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token) + v = self.pool_v(v) + v, v_size = reshape_post_pool(v, self.num_heads, v_tok) + else: + v_size = feat_size + if self.norm_v is not None: + v = self.norm_v(v) + + q_N = q_size[0] * q_size[1] + int(self.has_cls_token) + q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1) + q = self.q(q).reshape(B, q_N, self.num_heads, -1).permute(0, 2, 1, 3) + + k_N = k_size[0] * k_size[1] + int(self.has_cls_token) + k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1) + k = self.k(k).reshape(B, k_N, self.num_heads, -1).permute(0, 2, 1, 3) + + v_N = v_size[0] * v_size[1] + int(self.has_cls_token) + v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1) + v = self.v(v).reshape(B, v_N, self.num_heads, -1).permute(0, 2, 1, 3) + + attn = (q * self.scale) @ k.transpose(-2, -1) + if self.rel_pos_type == 'spatial': + attn = 
cal_rel_pos_type( + attn, + q, + self.has_cls_token, + q_size, + k_size, + self.rel_pos_h, + self.rel_pos_w, + ) + attn = attn.softmax(dim=-1) + x = attn @ v + + if self.residual_pooling: + x = x + q + + x = x.transpose(1, 2).reshape(B, -1, self.dim_out) + x = self.proj(x) + + return x, q_size + + +class MultiScaleAttention(nn.Module): + def __init__( + self, + dim, + dim_out, + feat_size, + num_heads=8, + qkv_bias=True, + mode="conv", + kernel_q=(1, 1), + kernel_kv=(1, 1), + stride_q=(1, 1), + stride_kv=(1, 1), + has_cls_token=True, + rel_pos_type='spatial', + residual_pooling=True, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.num_heads = num_heads + self.dim_out = dim_out + self.head_dim = dim_out // num_heads + self.scale = self.head_dim ** -0.5 + self.has_cls_token = has_cls_token + padding_q = tuple([int(q // 2) for q in kernel_q]) + padding_kv = tuple([int(kv // 2) for kv in kernel_kv]) + + self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias) + self.proj = nn.Linear(dim_out, dim_out) + + # Skip pooling with kernel and stride size of (1, 1, 1). + if prod(kernel_q) == 1 and prod(stride_q) == 1: + kernel_q = None + if prod(kernel_kv) == 1 and prod(stride_kv) == 1: + kernel_kv = None + self.mode = mode + self.unshared = mode == 'conv_unshared' + self.norm_q, self.norm_k, self.norm_v = None, None, None + self.pool_q, self.pool_k, self.pool_v = None, None, None + if mode in ("avg", "max"): + pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d + if kernel_q: + self.pool_q = pool_op(kernel_q, stride_q, padding_q) + if kernel_kv: + self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv) + self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv) + elif mode == "conv" or mode == "conv_unshared": + dim_conv = dim_out // num_heads if mode == "conv" else dim_out + if kernel_q: + self.pool_q = nn.Conv2d( + dim_conv, + dim_conv, + kernel_q, + stride=stride_q, + padding=padding_q, + groups=dim_conv, + bias=False, + ) + self.norm_q = norm_layer(dim_conv) + if kernel_kv: + self.pool_k = nn.Conv2d( + dim_conv, + dim_conv, + kernel_kv, + stride=stride_kv, + padding=padding_kv, + groups=dim_conv, + bias=False, + ) + self.norm_k = norm_layer(dim_conv) + self.pool_v = nn.Conv2d( + dim_conv, + dim_conv, + kernel_kv, + stride=stride_kv, + padding=padding_kv, + groups=dim_conv, + bias=False, + ) + self.norm_v = norm_layer(dim_conv) + else: + raise NotImplementedError(f"Unsupported model {mode}") + + # relative pos embedding + self.rel_pos_type = rel_pos_type + if self.rel_pos_type == 'spatial': + assert feat_size[0] == feat_size[1] + size = feat_size[0] + q_size = size // stride_q[1] if len(stride_q) > 0 else size + kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size + rel_sp_dim = 2 * max(q_size, kv_size) - 1 + + self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim)) + trunc_normal_tf_(self.rel_pos_h, std=0.02) + trunc_normal_tf_(self.rel_pos_w, std=0.02) + + self.residual_pooling = residual_pooling + + def forward(self, x, feat_size: List[int]): + B, N, _ = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(dim=0) + + if self.pool_q is not None: + q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token) + q = self.pool_q(q) + q, q_size = reshape_post_pool(q, self.num_heads, q_tok) + else: + q_size = feat_size + if self.norm_q is not None: + q = self.norm_q(q) + + if self.pool_k is not None: + k, k_tok = reshape_pre_pool(k, 
feat_size, self.has_cls_token) + k = self.pool_k(k) + k, k_size = reshape_post_pool(k, self.num_heads, k_tok) + else: + k_size = feat_size + if self.norm_k is not None: + k = self.norm_k(k) + + if self.pool_v is not None: + v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token) + v = self.pool_v(v) + v, _ = reshape_post_pool(v, self.num_heads, v_tok) + if self.norm_v is not None: + v = self.norm_v(v) + + attn = (q * self.scale) @ k.transpose(-2, -1) + if self.rel_pos_type == 'spatial': + attn = cal_rel_pos_type( + attn, + q, + self.has_cls_token, + q_size, + k_size, + self.rel_pos_h, + self.rel_pos_w, + ) + attn = attn.softmax(dim=-1) + x = attn @ v + + if self.residual_pooling: + x = x + q + + x = x.transpose(1, 2).reshape(B, -1, self.dim_out) + x = self.proj(x) + + return x, q_size + + +class MultiScaleBlock(nn.Module): + def __init__( + self, + dim, + dim_out, + num_heads, + feat_size, + mlp_ratio=4.0, + qkv_bias=True, + drop_path=0.0, + norm_layer=nn.LayerNorm, + kernel_q=(1, 1), + kernel_kv=(1, 1), + stride_q=(1, 1), + stride_kv=(1, 1), + mode="conv", + has_cls_token=True, + expand_attn=False, + pool_first=False, + rel_pos_type='spatial', + residual_pooling=True, + ): + super().__init__() + proj_needed = dim != dim_out + self.dim = dim + self.dim_out = dim_out + self.has_cls_token = has_cls_token + + self.norm1 = norm_layer(dim) + + self.shortcut_proj_attn = nn.Linear(dim, dim_out) if proj_needed and expand_attn else None + if stride_q and prod(stride_q) > 1: + kernel_skip = [s + 1 if s > 1 else s for s in stride_q] + stride_skip = stride_q + padding_skip = [int(skip // 2) for skip in kernel_skip] + self.shortcut_pool_attn = nn.MaxPool2d(kernel_skip, stride_skip, padding_skip) + else: + self.shortcut_pool_attn = None + + att_dim = dim_out if expand_attn else dim + attn_layer = MultiScaleAttentionPoolFirst if pool_first else MultiScaleAttention + self.attn = attn_layer( + dim, + att_dim, + num_heads=num_heads, + feat_size=feat_size, + qkv_bias=qkv_bias, + kernel_q=kernel_q, + kernel_kv=kernel_kv, + stride_q=stride_q, + stride_kv=stride_kv, + norm_layer=norm_layer, + has_cls_token=has_cls_token, + mode=mode, + rel_pos_type=rel_pos_type, + residual_pooling=residual_pooling, + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(att_dim) + mlp_dim_out = dim_out + self.shortcut_proj_mlp = nn.Linear(dim, dim_out) if proj_needed and not expand_attn else None + self.mlp = Mlp( + in_features=att_dim, + hidden_features=int(att_dim * mlp_ratio), + out_features=mlp_dim_out, + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def _shortcut_pool(self, x, feat_size: List[int]): + if self.shortcut_pool_attn is None: + return x + if self.has_cls_token: + cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :] + else: + cls_tok = None + B, L, C = x.shape + H, W = feat_size + x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous() + x = self.shortcut_pool_attn(x) + x = x.reshape(B, C, -1).transpose(1, 2) + if cls_tok is not None: + x = torch.cat((cls_tok, x), dim=2) + return x + + def forward(self, x, feat_size: List[int]): + x_norm = self.norm1(x) + # NOTE as per the original impl, this seems odd, but shortcut uses un-normalized input if no proj + x_shortcut = x if self.shortcut_proj_attn is None else self.shortcut_proj_attn(x_norm) + x_shortcut = self._shortcut_pool(x_shortcut, feat_size) + x, feat_size_new = self.attn(x_norm, feat_size) + x = x_shortcut + self.drop_path1(x) + + x_norm = self.norm2(x) + x_shortcut = 
x if self.shortcut_proj_mlp is None else self.shortcut_proj_mlp(x_norm) + x = x_shortcut + self.drop_path2(self.mlp(x_norm)) + return x, feat_size_new + + +class MultiScaleVitStage(nn.Module): + + def __init__( + self, + dim, + dim_out, + depth, + num_heads, + feat_size, + mlp_ratio=4.0, + qkv_bias=True, + mode="conv", + kernel_q=(1, 1), + kernel_kv=(1, 1), + stride_q=(1, 1), + stride_kv=(1, 1), + has_cls_token=True, + expand_attn=False, + pool_first=False, + rel_pos_type='spatial', + residual_pooling=True, + norm_layer=nn.LayerNorm, + drop_path=0.0, + ): + super().__init__() + self.grad_checkpointing = False + + self.blocks = nn.ModuleList() + if expand_attn: + out_dims = (dim_out,) * depth + else: + out_dims = (dim,) * (depth - 1) + (dim_out,) + + for i in range(depth): + attention_block = MultiScaleBlock( + dim=dim, + dim_out=out_dims[i], + num_heads=num_heads, + feat_size=feat_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + kernel_q=kernel_q, + kernel_kv=kernel_kv, + stride_q=stride_q if i == 0 else (1, 1), + stride_kv=stride_kv, + mode=mode, + has_cls_token=has_cls_token, + pool_first=pool_first, + rel_pos_type=rel_pos_type, + residual_pooling=residual_pooling, + expand_attn=expand_attn, + norm_layer=norm_layer, + drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path, + ) + dim = out_dims[i] + self.blocks.append(attention_block) + if i == 0: + feat_size = tuple([size // stride for size, stride in zip(feat_size, stride_q)]) + + self.feat_size = feat_size + + def forward(self, x, feat_size: List[int]): + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x, feat_size = checkpoint.checkpoint(blk, x, feat_size) + else: + x, feat_size = blk(x, feat_size) + return x, feat_size + + +class MultiScaleVit(nn.Module): + """ + Improved Multiscale Vision Transformers for Classification and Detection + Yanghao Li*, Chao-Yuan Wu*, Haoqi Fan, Karttikeya Mangalam, Bo Xiong, Jitendra Malik, + Christoph Feichtenhofer* + https://arxiv.org/abs/2112.01526 + + Multiscale Vision Transformers + Haoqi Fan*, Bo Xiong*, Karttikeya Mangalam*, Yanghao Li*, Zhicheng Yan, Jitendra Malik, + Christoph Feichtenhofer* + https://arxiv.org/abs/2104.11227 + """ + + def __init__( + self, + cfg: MultiScaleVitCfg, + img_size: Tuple[int, int] = (224, 224), + in_chans: int = 3, + global_pool: str = 'avg', + num_classes: int = 1000, + drop_path_rate: float = 0., + drop_rate: float = 0., + ): + super().__init__() + img_size = to_2tuple(img_size) + norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) + self.num_classes = num_classes + self.drop_rate = drop_rate + self.global_pool = global_pool + self.depths = tuple(cfg.depths) + self.expand_attn = cfg.expand_attn + + embed_dim = cfg.embed_dim[0] + self.patch_embed = PatchEmbed( + dim_in=in_chans, + dim_out=embed_dim, + kernel=cfg.patch_kernel, + stride=cfg.patch_stride, + padding=cfg.patch_padding, + ) + patch_dims = (img_size[0] // cfg.patch_stride[0], img_size[1] // cfg.patch_stride[1]) + num_patches = prod(patch_dims) + + if cfg.use_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.num_prefix_tokens = 1 + pos_embed_dim = num_patches + 1 + else: + self.num_prefix_tokens = 0 + self.cls_token = None + pos_embed_dim = num_patches + + if cfg.use_abs_pos: + self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_dim, embed_dim)) + else: + self.pos_embed = None + + num_stages = len(cfg.embed_dim) + feat_size = patch_dims + dpr = [x.tolist() for x in torch.linspace(0, 
drop_path_rate, sum(cfg.depths)).split(cfg.depths)] + self.stages = nn.ModuleList() + for i in range(num_stages): + if cfg.expand_attn: + dim_out = cfg.embed_dim[i] + else: + dim_out = cfg.embed_dim[min(i + 1, num_stages - 1)] + stage = MultiScaleVitStage( + dim=embed_dim, + dim_out=dim_out, + depth=cfg.depths[i], + num_heads=cfg.num_heads[i], + feat_size=feat_size, + mlp_ratio=cfg.mlp_ratio, + qkv_bias=cfg.qkv_bias, + mode=cfg.mode, + pool_first=cfg.pool_first, + expand_attn=cfg.expand_attn, + kernel_q=cfg.kernel_qkv, + kernel_kv=cfg.kernel_qkv, + stride_q=cfg.stride_q[i], + stride_kv=cfg.stride_kv[i], + has_cls_token=cfg.use_cls_token, + rel_pos_type=cfg.rel_pos_type, + residual_pooling=cfg.residual_pooling, + norm_layer=norm_layer, + drop_path=dpr[i], + ) + embed_dim = dim_out + feat_size = stage.feat_size + self.stages.append(stage) + + self.num_features = embed_dim + self.norm = norm_layer(embed_dim) + self.head = nn.Sequential(OrderedDict([ + ('drop', nn.Dropout(self.drop_rate)), + ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()) + ])) + + if self.pos_embed is not None: + trunc_normal_tf_(self.pos_embed, std=0.02) + if self.cls_token is not None: + trunc_normal_tf_(self.cls_token, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_tf_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {k for k, _ in self.named_parameters() + if any(n in k for n in ["pos_embed", "rel_pos_h", "rel_pos_w", "cls_token"])} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', # stem and embed + blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool + self.head = nn.Sequential(OrderedDict([ + ('drop', nn.Dropout(self.drop_rate)), + ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()) + ])) + + def forward_features(self, x): + x, feat_size = self.patch_embed(x) + B, N, C = x.shape + + if self.cls_token is not None: + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + if self.pos_embed is not None: + x = x + self.pos_embed + + for stage in self.stages: + x, feat_size = stage(x, feat_size) + + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + if self.global_pool == 'avg': + x = x[:, self.num_prefix_tokens:].mean(1) + else: + x = x[:, 0] + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + if 'stages.0.blocks.0.norm1.weight' in state_dict: + return state_dict + + import re + if 'model_state' in state_dict: + state_dict = state_dict['model_state'] + + depths = getattr(model, 'depths', None) + expand_attn = getattr(model, 'expand_attn', True) + assert depths is not None, 'model requires depth attribute to remap checkpoints' + depth_map = {} + block_idx = 0 + for stage_idx, d in enumerate(depths): + 
depth_map.update({i: (stage_idx, i - block_idx) for i in range(block_idx, block_idx + d)}) + block_idx += d + + out_dict = {} + for k, v in state_dict.items(): + k = re.sub( + r'blocks\.(\d+)', + lambda x: f'stages.{depth_map[int(x.group(1))][0]}.blocks.{depth_map[int(x.group(1))][1]}', + k) + + if expand_attn: + k = re.sub(r'stages\.(\d+).blocks\.(\d+).proj', f'stages.\\1.blocks.\\2.shortcut_proj_attn', k) + else: + k = re.sub(r'stages\.(\d+).blocks\.(\d+).proj', f'stages.\\1.blocks.\\2.shortcut_proj_mlp', k) + if 'head' in k: + k = k.replace('head.projection', 'head.fc') + out_dict[k] = v + + # for k, v in state_dict.items(): + # if model.pos_embed is not None and k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: + # # To resize pos embedding when using model at different size from pretrained weights + # v = resize_pos_embed( + # v, + # model.pos_embed, + # 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1), + # model.patch_embed.grid_size + # ) + + return out_dict + + +def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs): + return build_model_with_cfg( + MultiScaleVit, variant, pretrained, + model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(flatten_sequential=True), + **kwargs) + + +@register_model +def mvitv2_tiny(pretrained=False, **kwargs): + return _create_mvitv2('mvitv2_tiny', pretrained=pretrained, **kwargs) + + +@register_model +def mvitv2_small(pretrained=False, **kwargs): + return _create_mvitv2('mvitv2_small', pretrained=pretrained, **kwargs) + + +@register_model +def mvitv2_base(pretrained=False, **kwargs): + return _create_mvitv2('mvitv2_base', pretrained=pretrained, **kwargs) + + +@register_model +def mvitv2_large(pretrained=False, **kwargs): + return _create_mvitv2('mvitv2_large', pretrained=pretrained, **kwargs) + + +# @register_model +# def mvitv2_base_in21k(pretrained=False, **kwargs): +# return _create_mvitv2('mvitv2_base_in21k', pretrained=pretrained, **kwargs) +# +# +# @register_model +# def mvitv2_large_in21k(pretrained=False, **kwargs): +# return _create_mvitv2('mvitv2_large_in21k', pretrained=pretrained, **kwargs) +# +# +# @register_model +# def mvitv2_huge_in21k(pretrained=False, **kwargs): +# return _create_mvitv2('mvitv2_huge_in21k', pretrained=pretrained, **kwargs) From f332fc2db760e3f3d49ad09c246fe09869ac4f2f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 18 Aug 2022 16:19:46 -0700 Subject: [PATCH 06/21] Fix some test failures, torchscript issues --- tests/test_models.py | 2 +- timm/models/efficientformer.py | 2 +- timm/models/gcvit.py | 3 +-- timm/models/pvt_v2.py | 4 ++-- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 94744483..0f9b8c0b 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -27,7 +27,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*', - 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*'] + 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*'] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index 
0a54f7fe..2da323cf 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -26,7 +26,7 @@ from .registry import register_model def _cfg(url='', **kwargs): return { 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True, 'crop_pct': .95, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'stem.conv1', 'classifier': 'head', diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index 3e2cd96a..c134b7c2 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -209,8 +209,7 @@ class WindowAttentionGlobal(nn.Module): def forward(self, x, q_global: Optional[torch.Tensor] = None): B, N, C = x.shape - if self.use_global: - _assert(q_global is not None, 'q_global must be passed in global mode') + if self.use_global and q_global is not None: _assert(x.shape[-1] == q_global.shape[-1], 'x and q_global seq lengths should be equal') kv = self.qkv(x) diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index 551a8325..1f698fbc 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -286,7 +286,6 @@ class PyramidVisionTransformerV2(nn.Module): self.num_classes = num_classes assert global_pool in ('avg', '') self.global_pool = global_pool - self.img_size = to_2tuple(img_size) if img_size is not None else None self.depths = depths num_stages = len(depths) mlp_ratios = to_ntuple(num_stages)(mlp_ratios) @@ -324,7 +323,8 @@ class PyramidVisionTransformerV2(nn.Module): cur += depths[i] # classification head - self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + self.num_features = embed_dims[-1] + self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) From ca52108c2b89694e82713e1b4ca5c28902bc5435 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 19 Aug 2022 10:20:51 -0700 Subject: [PATCH 07/21] Fix some model support functions --- timm/models/efficientformer.py | 5 ++--- timm/models/gcvit.py | 13 ++++++++++++- timm/models/mvitv2.py | 4 ++-- timm/models/pvt_v2.py | 2 +- 4 files changed, 17 insertions(+), 7 deletions(-) diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index 2da323cf..814b6957 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -449,13 +449,12 @@ class EfficientFormer(nn.Module): def get_classifier(self): return self.head, self.head_dist - def reset_classifier(self, num_classes, global_pool=None, distillation=None): + def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - if self.dist: - self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() @torch.jit.ignore def set_distilled_training(self, enable=True): diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index c134b7c2..e7eccea8 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -427,6 +427,7 @@ class GlobalContextVit(nn.Module): feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4 self.global_pool = global_pool self.num_classes = num_classes + self.drop_rate = drop_rate num_stages = len(depths) 
self.num_features = int(embed_dim * 2 ** (num_stages - 1)) @@ -491,7 +492,7 @@ class GlobalContextVit(nn.Module): def group_matcher(self, coarse=False): matcher = dict( stem=r'^stem', # stem and embed - blocks=(r'^stages\.(\d+)', None) + blocks=r'^stages\.(\d+)' ) return matcher @@ -500,6 +501,16 @@ class GlobalContextVit(nn.Module): for s in self.stages: s.grad_checkpointing = enable + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is None: + global_pool = self.head.global_pool.pool_type + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.stem(x) x = self.stages(x) diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index 4c7cd044..fc29f113 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -850,7 +850,7 @@ class MultiScaleVit(nn.Module): @torch.jit.ignore def group_matcher(self, coarse=False): matcher = dict( - stem=r'^stem', # stem and embed + stem=r'^patch_embed', # stem and embed blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))] ) return matcher @@ -862,7 +862,7 @@ class MultiScaleVit(nn.Module): @torch.jit.ignore def get_classifier(self): - return self.head + return self.head.fc def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index 1f698fbc..ce4cbf56 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -351,7 +351,7 @@ class PyramidVisionTransformerV2(nn.Module): def group_matcher(self, coarse=False): matcher = dict( stem=r'^patch_embed', # stem and embed - blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))] + blocks=r'^stages\.(\d+)' ) return matcher From 8c9696c9df93d54ac17e0afadf5ef687f329fb8f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 22 Aug 2022 17:40:31 -0700 Subject: [PATCH 08/21] More model and test fixes --- tests/test_models.py | 4 +++- timm/models/gcvit.py | 31 +++++++++++++++------------ timm/models/layers/create_norm_act.py | 3 +++ timm/models/layers/norm_act.py | 1 + timm/models/mvitv2.py | 7 +++++- timm/models/pvt_v2.py | 2 +- 6 files changed, 31 insertions(+), 17 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 0f9b8c0b..5daee76d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -27,7 +27,9 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*', - 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*'] + 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*', + 'coatne?t_*', 'max?vit_*', +] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index e7eccea8..bad40bd6 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -43,7 +43,7 @@ def _cfg(url='', **kwargs): 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stem.conv', 'classifier': 
'head.fc', + 'first_conv': 'stem.conv1', 'classifier': 'head.fc', 'fixed_input_size': True, **kwargs } @@ -106,7 +106,7 @@ class Downsample2d(nn.Module): dim_out=None, reduction='conv', act_layer=nn.GELU, - norm_layer=LayerNorm2d, + norm_layer=LayerNorm2d, # NOTE in NCHW ): super().__init__() dim_out = dim_out or dim @@ -163,12 +163,10 @@ class Stem(nn.Module): self, in_chs: int = 3, out_chs: int = 96, - act_layer: str = 'gelu', - norm_layer: str = 'layernorm2d', # NOTE norm for NCHW + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm2d, # NOTE stem in NCHW ): super().__init__() - act_layer = get_act_layer(act_layer) - norm_layer = get_norm_layer(norm_layer) self.conv1 = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1) self.down = Downsample2d(out_chs, act_layer=act_layer, norm_layer=norm_layer) @@ -333,15 +331,11 @@ class GlobalContextVitStage(nn.Module): proj_drop: float = 0., attn_drop: float = 0., drop_path: Union[List[float], float] = 0.0, - act_layer: str = 'gelu', - norm_layer: str = 'layernorm2d', - norm_layer_cl: str = 'layernorm', + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + norm_layer_cl: Callable = LayerNorm2d, ): super().__init__() - act_layer = get_act_layer(act_layer) - norm_layer = get_norm_layer(norm_layer) - norm_layer_cl = get_norm_layer(norm_layer_cl) - if downsample: self.downsample = Downsample2d( dim=dim, @@ -421,8 +415,13 @@ class GlobalContextVit(nn.Module): act_layer: str = 'gelu', norm_layer: str = 'layernorm2d', norm_layer_cl: str = 'layernorm', + norm_eps: float = 1e-5, ): super().__init__() + act_layer = get_act_layer(act_layer) + norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps) + norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps) + img_size = to_2tuple(img_size) feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4 self.global_pool = global_pool @@ -432,7 +431,11 @@ class GlobalContextVit(nn.Module): self.num_features = int(embed_dim * 2 ** (num_stages - 1)) self.stem = Stem( - in_chs=in_chans, out_chs=embed_dim, act_layer=act_layer, norm_layer=norm_layer) + in_chs=in_chans, + out_chs=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer + ) dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] stages = [] diff --git a/timm/models/layers/create_norm_act.py b/timm/models/layers/create_norm_act.py index cd15c2f8..78dd9a51 100644 --- a/timm/models/layers/create_norm_act.py +++ b/timm/models/layers/create_norm_act.py @@ -18,6 +18,7 @@ _NORM_ACT_MAP = dict( batchnorm=BatchNormAct2d, batchnorm2d=BatchNormAct2d, groupnorm=GroupNormAct, + groupnorm1=functools.partial(GroupNormAct, num_groups=1), layernorm=LayerNormAct, layernorm2d=LayerNormAct2d, evonormb0=EvoNorm2dB0, @@ -72,6 +73,8 @@ def get_norm_act_layer(norm_layer, act_layer=None): norm_act_layer = BatchNormAct2d elif type_name.startswith('groupnorm'): norm_act_layer = GroupNormAct + elif type_name.startswith('groupnorm1'): + norm_act_layer = functools.partial(GroupNormAct, num_groups=1) elif type_name.startswith('layernorm2d'): norm_act_layer = LayerNormAct2d elif type_name.startswith('layernorm'): diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py index be1edead..dc077160 100644 --- a/timm/models/layers/norm_act.py +++ b/timm/models/layers/norm_act.py @@ -226,6 +226,7 @@ class LayerNormAct2d(nn.LayerNorm): self.act = act_layer(**act_args) else: self.act = nn.Identity() + self._fast_norm = is_fast_norm() def forward(self, x): x = x.permute(0, 2, 
3, 1) diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index fc29f113..002225c6 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -24,6 +24,7 @@ import torch.utils.checkpoint as checkpoint from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_notrace_function from .helpers import build_model_with_cfg from .layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple from .registry import register_model @@ -35,7 +36,8 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': True, + 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc', + 'fixed_input_size': True, **kwargs } @@ -169,6 +171,7 @@ class PatchEmbed(nn.Module): return x.flatten(2).transpose(1, 2), x.shape[-2:] +@register_notrace_function def reshape_pre_pool( x, feat_size: List[int], @@ -183,6 +186,7 @@ def reshape_pre_pool( return x, cls_tok +@register_notrace_function def reshape_post_pool( x, num_heads: int, @@ -196,6 +200,7 @@ def reshape_post_pool( return x, feat_size +@register_notrace_function def cal_rel_pos_type( attn: torch.Tensor, q: torch.Tensor, diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index ce4cbf56..dd3cf690 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -36,7 +36,7 @@ def _cfg(url='', **kwargs): 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.9, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.conv', 'classifier': 'head', 'fixed_input_size': False, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': False, **kwargs } From ffaf97f813f306ce6f59e195ee72d6f9720fc27d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 22 Aug 2022 17:42:10 -0700 Subject: [PATCH 09/21] MaxxVit! A very configurable MaxVit and CoAtNet impl with lots of goodies.. 
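
Variants are described declaratively: MaxxVitCfg holds the per-stage embed_dim / depths /
block_type along with a MaxxVitConvCfg and a MaxxVitTransformerCfg, and the _rw_coat_cfg /
_rw_max_cfg / _next_cfg helpers just bundle the sub-config differences of the timm 'rw'
variants. Most default_cfgs URLs are still empty while pretraining runs finish. As a rough,
hypothetical sketch, an extra model_cfgs entry would be composed along these lines (the
name is illustrative only):

    model_cfgs['coatnet_rw_example_224'] = MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 3, 5, 2),
        stem_width=(32, 64),
        **_rw_max_cfg(pool_type='avg2', conv_output_bias=True, conv_attn_ratio=0.25),
    )
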
--- timm/models/__init__.py | 1 + timm/models/layers/__init__.py | 2 +- timm/models/layers/helpers.py | 12 + timm/models/maxxvit.py | 1692 ++++++++++++++++++++++++++++++++ 4 files changed, 1706 insertions(+), 1 deletion(-) create mode 100644 timm/models/maxxvit.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index b93d3f94..51a38d0c 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -25,6 +25,7 @@ from .inception_resnet_v2 import * from .inception_v3 import * from .inception_v4 import * from .levit import * +from .maxxvit import * from .mlp_mixer import * from .mobilenetv3 import * from .mobilevit import * diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 071da7bc..21c641b6 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -21,7 +21,7 @@ from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_ from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d from .gather_excite import GatherExcite from .global_context import GlobalContext -from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible +from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple from .inplace_abn import InplaceAbn from .linear import Linear from .mixed_conv2d import MixedConv2d diff --git a/timm/models/layers/helpers.py b/timm/models/layers/helpers.py index 4a10ca0c..2fa296bc 100644 --- a/timm/models/layers/helpers.py +++ b/timm/models/layers/helpers.py @@ -29,3 +29,15 @@ def make_divisible(v, divisor=8, min_value=None, round_limit=.9): if new_v < round_limit * v: new_v += divisor return new_v + + +def extend_tuple(x, n): + # pdas a tuple to specified n by padding with last value + if not isinstance(x, (tuple, list)): + x = (x,) + else: + x = tuple(x) + pad_n = n - len(x) + if pad_n <= 0: + return x[:n] + return x + (x[-1],) * pad_n diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py new file mode 100644 index 00000000..8b1fe0a6 --- /dev/null +++ b/timm/models/maxxvit.py @@ -0,0 +1,1692 @@ +""" MaxVit and CoAtNet Vision Transformer - CNN Hybrids in PyTorch + +This is a from-scratch implementation of both CoAtNet and MaxVit in PyTorch. + +99% of the implementation was done from papers, however last minute some adjustments were made +based on the (as yet unfinished?) public code release https://github.com/google-research/maxvit + +There are multiple sets of models defined for both architectures. Typically, names with a + `_rw` suffix are my own original configs prior to referencing https://github.com/google-research/maxvit. +These configs work well and appear to be a bit faster / lower resource than the paper. + +The models without extra prefix / suffix' (coatnet_0_224, maxvit_tiny_224, etc), are intended to +match paper, BUT, without any official pretrained weights it's difficult to confirm a 100% match. + +# FIXME / WARNING +This impl remains a WIP, some configs and models may vanish or change... 
+ +Papers: + +MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697 +@article{tu2022maxvit, + title={MaxViT: Multi-Axis Vision Transformer}, + author={Tu, Zhengzhong and Talebi, Hossein and Zhang, Han and Yang, Feng and Milanfar, Peyman and Bovik, Alan and Li, Yinxiao}, + journal={ECCV}, + year={2022}, +} + +CoAtNet: Marrying Convolution and Attention for All Data Sizes - https://arxiv.org/abs/2106.04803 +@article{DBLP:journals/corr/abs-2106-04803, + author = {Zihang Dai and Hanxiao Liu and Quoc V. Le and Mingxing Tan}, + title = {CoAtNet: Marrying Convolution and Attention for All Data Sizes}, + journal = {CoRR}, + volume = {abs/2106.04803}, + year = {2021} +} + +Hacked together by / Copyright 2022, Ross Wightman +""" + +import math +from collections import OrderedDict +from dataclasses import dataclass +from functools import partial +from typing import Callable, Optional, Union, Tuple, List + +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg, checkpoint_seq, named_apply +from .fx_features import register_notrace_function +from .layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm2d, LayerNorm +from .layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d +from .layers import to_2tuple, extend_tuple, make_divisible, _assert +from .registry import register_model +from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move these to common location + +__all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.95, 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + 'first_conv': 'stem.conv1', 'classifier': 'head.fc', + 'fixed_input_size': True, + **kwargs + } + + +default_cfgs = { + # Fiddling with configs / defaults / still pretraining + 'coatnet_pico_rw_224': _cfg(url=''), + 'coatnet_nano_rw_224': _cfg( + url='', + crop_pct=0.9), + 'coatnet_0_rw_224': _cfg( + url=''), + 'coatnet_1_rw_224': _cfg( + url='' + ), + 'coatnet_2_rw_224': _cfg(url=''), + + # Highly experimental configs + 'coatnet_bn_0_rw_224': _cfg( + url='', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + crop_pct=0.95), + 'coatnet_rmlp_nano_rw_224': _cfg( + url='', + crop_pct=0.9), + 'coatnet_rmlp_0_rw_224': _cfg(url=''), + 'coatnet_rmlp_1_rw_224': _cfg( + url=''), + 'coatnext_nano_rw_224': _cfg(url=''), + + # Trying to be like the CoAtNet paper configs + 'coatnet_0_224': _cfg(url=''), + 'coatnet_1_224': _cfg(url=''), + 'coatnet_2_224': _cfg(url=''), + 'coatnet_3_224': _cfg(url=''), + 'coatnet_4_224': _cfg(url=''), + 'coatnet_5_224': _cfg(url=''), + + # Experimental configs + 'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256)), + 'maxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256)), + 'maxvit_tiny_rw_224': _cfg(url=''), + 'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256)), + 'maxvit_tiny_cm_256': _cfg(url='', input_size=(3, 256, 256)), + 'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256)), + + # Trying to be like the MaxViT paper configs + 'maxvit_tiny_224': _cfg(url=''), + 'maxvit_small_224': _cfg(url=''), + 'maxvit_base_224': _cfg(url=''), + 'maxvit_large_224': _cfg(url=''), + 'maxvit_xlarge_224': _cfg(url=''), +} + + +@dataclass +class 
MaxxVitTransformerCfg: + dim_head: int = 32 + expand_ratio: float = 4.0 + expand_first: bool = True + shortcut_bias: bool = True, + attn_bias: bool = True + attn_drop: float = 0. + proj_drop: float = 0. + pool_type: str = 'avg' + rel_pos_type: str = 'bias' + rel_pos_dim: int = 512 # for relative position types w/ MLP + window_size: Tuple[int, int] = (7, 7) + grid_size: Tuple[int, int] = (7, 7) + init_values: Optional[float] = None + act_layer: str = 'gelu' + norm_layer: str = 'layernorm2d' + norm_layer_cl: str = 'layernorm' + norm_eps: float = 1e-6 + + +@dataclass +class MaxxVitConvCfg: + block_type: str = 'mbconv' + expand_ratio: float = 4.0 + expand_output: bool = True # calculate expansion channels from output (vs input chs) + kernel_size: int = 3 + group_size: int = 1 # 1 == depthwise + pre_norm_act: bool = False # activation after pre-norm + output_bias: bool = True # bias for shortcut + final 1x1 projection conv + stride_mode: str = 'dw' # stride done via one of 'pool', '1x1', 'dw' + pool_type: str = 'avg' + downsample_pool_type: str = 'avg2' + attn_early: bool = False # apply attn between conv2 and norm2, instead of after norm2 + attn_layer: str = 'se' + attn_act_layer: str = 'silu' + attn_ratio: float = 0.25 + init_values: Optional[float] = 1e-5 # for ConvNeXt block + act_layer: str = 'gelu' + norm_layer: str = '' + norm_layer_cl: str = '' + norm_eps: Optional[float] = None + + def __post_init__(self): + # mbconv vs convnext blocks have different defaults, set in post_init to avoid explicit config args + assert self.block_type in ('mbconv', 'convnext') + use_mbconv = self.block_type == 'mbconv' + if not self.norm_layer: + self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d' + if not self.norm_layer_cl and not use_mbconv: + self.norm_layer_cl = 'layernorm' + if self.norm_eps is None: + self.norm_eps = 1e-5 if use_mbconv else 1e-6 + self.downsample_pool_type = self.downsample_pool_type or self.pool_type + + +@dataclass +class MaxxVitCfg: + embed_dim: Tuple[int, ...] = (96, 192, 384, 768) + depths: Tuple[int, ...] = (2, 3, 5, 2) + block_type: Tuple[Union[str, Tuple[str, ...]], ...] 
= ('C', 'C', 'T', 'T') + stem_width: Union[int, Tuple[int, int]] = 64 + stem_bias: bool = True + conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg() + transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg() + weight_init: str = 'vit_eff' + + +def _rw_coat_cfg( + stride_mode='pool', + pool_type='avg2', + conv_output_bias=False, + conv_attn_early=False, + conv_norm_layer='', + transformer_shortcut_bias=True, + transformer_norm_layer='layernorm2d', + transformer_norm_layer_cl='layernorm', + rel_pos_type='bias', + rel_pos_dim=512, +): + # 'RW' timm variant models were created and trained before seeing https://github.com/google-research/maxvit + # Common differences for initial timm models: + # - pre-norm layer in MZBConv included an activation after norm + # - mbconv expansion calculated from input instead of output chs + # - mbconv shortcut and final 1x1 conv did not have a bias + # - SE act layer was relu, not silu + # - mbconv uses silu in timm, not gelu + # - expansion in attention block done via output proj, not input proj + # Variable differences (evolved over training initial models): + # - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat) + # - SE attention was between conv2 and norm/act + # - default to avg pool for mbconv downsample instead of 1x1 or dw conv + # - transformer block shortcut has no bias + return dict( + conv_cfg=MaxxVitConvCfg( + stride_mode=stride_mode, + pool_type=pool_type, + pre_norm_act=True, + expand_output=False, + output_bias=conv_output_bias, + attn_early=conv_attn_early, + attn_act_layer='relu', + act_layer='silu', + norm_layer=conv_norm_layer, + ), + transformer_cfg=MaxxVitTransformerCfg( + expand_first=False, + shortcut_bias=transformer_shortcut_bias, + pool_type=pool_type, + norm_layer=transformer_norm_layer, + norm_layer_cl=transformer_norm_layer_cl, + rel_pos_type=rel_pos_type, + rel_pos_dim=rel_pos_dim, + ), + ) + + +def _rw_max_cfg( + stride_mode='dw', + pool_type='avg', + conv_output_bias=False, + conv_attn_ratio=1 / 16, + conv_norm_layer='', + transformer_norm_layer='layernorm2d', + transformer_norm_layer_cl='layernorm', + window_size=7, + dim_head=32, + rel_pos_type='bias', + rel_pos_dim=512, +): + # 'RW' timm variant models were created and trained before seeing https://github.com/google-research/maxvit + # Differences of initial timm models: + # - mbconv expansion calculated from input instead of output chs + # - mbconv shortcut and final 1x1 conv did not have a bias + # - mbconv uses silu in timm, not gelu + # - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat) + # - default to avg pool for mbconv downsample instead of 1x1 or dw conv + # - expansion in attention block done via output proj, not input proj + return dict( + conv_cfg=MaxxVitConvCfg( + stride_mode=stride_mode, + pool_type=pool_type, + expand_output=False, + output_bias=conv_output_bias, + attn_ratio=conv_attn_ratio, + act_layer='silu', + norm_layer=conv_norm_layer, + ), + transformer_cfg=MaxxVitTransformerCfg( + expand_first=False, + pool_type=pool_type, + dim_head=dim_head, + window_size=to_2tuple(window_size), + grid_size=to_2tuple(window_size), + norm_layer=transformer_norm_layer, + norm_layer_cl=transformer_norm_layer_cl, + rel_pos_type=rel_pos_type, + rel_pos_dim=rel_pos_dim, + ), + ) + + +def _next_cfg( + stride_mode='dw', + pool_type='avg2', + conv_norm_layer='layernorm2d', + conv_norm_layer_cl='layernorm', + transformer_norm_layer='layernorm2d', + transformer_norm_layer_cl='layernorm', + window_size=7, + 
rel_pos_type='bias', + rel_pos_dim=512, +): + # For experimental models with convnext instead of mbconv + return dict( + conv_cfg=MaxxVitConvCfg( + block_type='convnext', + stride_mode=stride_mode, + pool_type=pool_type, + expand_output=False, + norm_layer=conv_norm_layer, + norm_layer_cl=conv_norm_layer_cl, + ), + transformer_cfg=MaxxVitTransformerCfg( + expand_first=False, + pool_type=pool_type, + window_size=to_2tuple(window_size), + grid_size=to_2tuple(window_size), + norm_layer=transformer_norm_layer, + norm_layer_cl=transformer_norm_layer_cl, + rel_pos_type=rel_pos_type, + rel_pos_dim=rel_pos_dim, + ), + ) + + +model_cfgs = dict( + # Fiddling with configs / defaults / still pretraining + coatnet_pico_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 3, 5, 2), + stem_width=(32, 64), + **_rw_max_cfg( # using newer max defaults here + pool_type='avg2', + conv_output_bias=True, + conv_attn_ratio=0.25, + ), + ), + coatnet_nano_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + **_rw_max_cfg( # using newer max defaults here + stride_mode='pool', + pool_type='avg2', + conv_output_bias=True, + conv_attn_ratio=0.25, + ), + ), + coatnet_0_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 7, 2), # deeper than paper '0' model + stem_width=(32, 64), + **_rw_coat_cfg( + conv_attn_early=True, + transformer_shortcut_bias=False, + ), + ), + coatnet_1_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + stem_width=(32, 64), + **_rw_coat_cfg( + stride_mode='dw', + conv_attn_early=True, + transformer_shortcut_bias=False, + ) + ), + coatnet_2_rw_224=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 6, 14, 2), + stem_width=(64, 128), + **_rw_coat_cfg(stride_mode='dw'), + ), + + # Highly experimental configs + coatnet_bn_0_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 7, 2), # deeper than paper '0' model + stem_width=(32, 64), + **_rw_coat_cfg( + stride_mode='dw', + conv_attn_early=True, + transformer_shortcut_bias=False, + transformer_norm_layer='batchnorm2d', + ) + ), + coatnet_rmlp_nano_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + **_rw_max_cfg( + pool_type='avg2', + conv_output_bias=True, + conv_attn_ratio=0.25, + rel_pos_type='mlp', + rel_pos_dim=384, + ), + ), + coatnet_rmlp_0_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 7, 2), # deeper than paper '0' model + stem_width=(32, 64), + **_rw_coat_cfg( + stride_mode='dw', + rel_pos_type='mlp', + ), + ), + coatnet_rmlp_1_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + stem_width=(32, 64), + **_rw_coat_cfg( + pool_type='max', + conv_attn_early=True, + transformer_shortcut_bias=False, + rel_pos_type='mlp', + rel_pos_dim=384, # was supposed to be 512, woops + ), + ), + coatnext_nano_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + **_next_cfg(), + ), + coatnet_nano_cc_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + block_type=('C', 'C', ('C', 'T'), ('C', 'T')), + **_rw_coat_cfg(), + ), + + # Trying to be like the CoAtNet paper configs + coatnet_0_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 5, 2), + stem_width=64, + ), + coatnet_1_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + stem_width=64, + ), + coatnet_2_224=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 
6, 14, 2), + stem_width=128, + ), + coatnet_3_224=MaxxVitCfg( + embed_dim=(192, 384, 768, 1536), + depths=(2, 6, 14, 2), + stem_width=192, + ), + coatnet_4_224=MaxxVitCfg( + embed_dim=(192, 384, 768, 1536), + depths=(2, 12, 28, 2), + stem_width=192, + ), + coatnet_5_224=MaxxVitCfg( + embed_dim=(256, 512, 1280, 2048), + depths=(2, 12, 28, 2), + stem_width=192, + ), + + # Experimental MaxVit configs + maxvit_pico_rw_256=MaxxVitCfg( + embed_dim=(32, 64, 128, 256), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(24, 32), + **_rw_max_cfg(window_size=8), + ), + maxvit_nano_rw_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(1, 2, 3, 1), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(window_size=8), + ), + maxvit_tiny_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(), + ), + maxvit_tiny_rw_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(window_size=8), + ), + maxvit_tiny_cm_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('CM',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(window_size=8), + ), + maxxvit_nano_rw_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(1, 2, 3, 1), + block_type=('M',) * 4, + stem_width=(32, 64), + **_next_cfg(window_size=8), + ), + + # Trying to be like the MaxViT paper configs + maxvit_tiny_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=64, + ), + maxvit_small_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=64, + ), + maxvit_base_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + block_type=('M',) * 4, + stem_width=64, + ), + maxvit_large_224=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 6, 14, 2), + block_type=('M',) * 4, + stem_width=128, + ), + maxvit_xlarge_224=MaxxVitCfg( + embed_dim=(192, 384, 768, 1536), + depths=(2, 6, 14, 2), + block_type=('M',) * 4, + stem_width=192, + ), + +) + + +class Attention2d(nn.Module): + """ multi-head attention for 2D NCHW tensors""" + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + dim_head: int = 32, + bias: bool = True, + expand_first: bool = True, + rel_pos_cls: Callable = None, + attn_drop: float = 0., + proj_drop: float = 0. 
+ ): + super().__init__() + dim_out = dim_out or dim + dim_attn = dim_out if expand_first else dim + self.num_heads = dim_attn // dim_head + self.dim_head = dim_head + self.scale = dim_head ** -0.5 + + self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias) + self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + B, C, H, W = x.shape + + q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2) + + attn = (q.transpose(-2, -1) @ k) * self.scale + if self.rel_pos is not None: + attn = self.rel_pos(attn) + elif shared_rel_pos is not None: + attn = attn + shared_rel_pos + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class AttentionCl(nn.Module): + """ Channels-last multi-head attention (B, ..., C) """ + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + dim_head: int = 32, + bias: bool = True, + expand_first: bool = True, + rel_pos_cls: Callable = None, + attn_drop: float = 0., + proj_drop: float = 0. + ): + super().__init__() + dim_out = dim_out or dim + dim_attn = dim_out if expand_first and dim_out > dim else dim + assert dim_attn % dim_head == 0, 'attn dim should be divisible by head_dim' + self.num_heads = dim_attn // dim_head + self.dim_head = dim_head + self.scale = dim_head ** -0.5 + + self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias) + self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim_attn, dim_out, bias=bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + B = x.shape[0] + restore_shape = x.shape[:-1] + + q, k, v = self.qkv(x).view(B, -1, self.num_heads, self.dim_head * 3).transpose(1, 2).chunk(3, dim=3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + if self.rel_pos is not None: + attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos) + elif shared_rel_pos is not None: + attn = attn + shared_rel_pos + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(restore_shape + (-1,)) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + gamma = self.gamma + return x.mul_(gamma) if self.inplace else x * gamma + + +class LayerScale2d(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + gamma = self.gamma.view(1, -1, 1, 1) + return x.mul_(gamma) if self.inplace else x * gamma + + +class Downsample2d(nn.Module): + """ A downsample pooling module for Coat that handles 2d <-> 1d conversion + """ + + def __init__( + self, + dim: int, + dim_out: int, + pool_type: str = 'avg2', + bias: bool = True, + ): + super().__init__() + assert pool_type in ('max', 'avg', 'avg2') + if pool_type == 'max': + self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + elif pool_type == 'avg': + self.pool = 
nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False) + else: + self.pool = nn.AvgPool2d(2) + + if dim != dim_out: + self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) + else: + self.expand = nn.Identity() + + def forward(self, x): + x = self.pool(x) # spatial downsample + x = self.expand(x) # expand chs + return x + + +def _init_transformer(module, name, scheme=''): + if isinstance(module, (nn.Conv2d, nn.Linear)): + if scheme == 'normal': + nn.init.normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 'trunc_normal': + trunc_normal_tf_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 'xavier_normal': + nn.init.xavier_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + else: + # vit like + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + + +class TransformerBlock2d(nn.Module): + """ Transformer block with 2D downsampling + '2D' NCHW tensor layout + """ + + def __init__( + self, + dim: int, + dim_out: int, + stride: int = 1, + rel_pos_cls: Callable = None, + cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path: float = 0., + ): + super().__init__() + norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) + act_layer = get_act_layer(cfg.act_layer) + + if stride == 2: + self.shortcut = Downsample2d(dim, dim_out, pool_type=cfg.pool_type, bias=cfg.shortcut_bias) + self.norm1 = nn.Sequential(OrderedDict([ + ('norm', norm_layer(dim)), + ('down', Downsample2d(dim, dim, pool_type=cfg.pool_type)), + ])) + else: + assert dim == dim_out + self.shortcut = nn.Identity() + self.norm1 = norm_layer(dim) + + self.attn = Attention2d( + dim, + dim_out, + dim_head=cfg.dim_head, + expand_first=cfg.expand_first, + bias=cfg.attn_bias, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop + ) + self.ls1 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = ConvMlp( + in_features=dim_out, + hidden_features=int(dim_out * cfg.expand_ratio), + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def init_weights(self, scheme=''): + named_apply(partial(_init_transformer, scheme=scheme), self) + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + x = self.shortcut(x) + self.drop_path1(self.ls1(self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +def _init_conv(module, name, scheme=''): + if isinstance(module, nn.Conv2d): + if scheme == 'normal': + nn.init.normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 'trunc_normal': + trunc_normal_tf_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 'xavier_normal': + nn.init.xavier_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + else: + # efficientnet like + fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels + fan_out //= module.groups + nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out)) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def num_groups(group_size, channels): + if not group_size: # 0 or None + return 1 # normal conv with 1 group + else: + # NOTE group_size == 1 -> depthwise conv + assert channels % group_size == 0 + return channels // group_size + + +class MbConvBlock(nn.Module): + """ Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand) + """ + def __init__( + self, + in_chs: int, + out_chs: int, + stride: int = 1, + dilation: Tuple[int, int] = (1, 1), + cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + drop_path: float = 0. + ): + super(MbConvBlock, self).__init__() + norm_act_layer = partial(get_norm_act_layer(cfg.norm_layer, cfg.act_layer), eps=cfg.norm_eps) + mid_chs = make_divisible((out_chs if cfg.expand_output else in_chs) * cfg.expand_ratio) + groups = num_groups(cfg.group_size, mid_chs) + + if stride == 2: + self.shortcut = Downsample2d(in_chs, out_chs, pool_type=cfg.pool_type, bias=cfg.output_bias) + else: + self.shortcut = nn.Identity() + + assert cfg.stride_mode in ('pool', '1x1', 'dw') + stride_pool, stride_1, stride_2 = 1, 1, 1 + if cfg.stride_mode == 'pool': + # NOTE this is not described in paper, experiment to find faster option that doesn't stride in 1x1 + stride_pool, dilation_2 = stride, dilation[1] + # FIXME handle dilation of avg pool + elif cfg.stride_mode == '1x1': + # NOTE I don't like this option described in paper, 1x1 w/ stride throws info away + stride_1, dilation_2 = stride, dilation[1] + else: + stride_2, dilation_2 = stride, dilation[0] + + self.pre_norm = norm_act_layer(in_chs, apply_act=cfg.pre_norm_act) + if stride_pool > 1: + self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type) + else: + self.down = nn.Identity() + self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=stride_1) + self.norm1 = norm_act_layer(mid_chs) + + self.conv2_kxk = create_conv2d( + mid_chs, mid_chs, cfg.kernel_size, stride=stride_2, dilation=dilation_2, groups=groups) + + attn_kwargs = {} + if isinstance(cfg.attn_layer, str): + if cfg.attn_layer == 'se' or cfg.attn_layer == 'eca': + attn_kwargs['act_layer'] = cfg.attn_act_layer + attn_kwargs['rd_channels'] = int(cfg.attn_ratio * (out_chs if cfg.expand_output else mid_chs)) + + # two different orderings for SE and norm2 (due to some weights and trials using SE before norm2) + if cfg.attn_early: + self.se_early = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs) + self.norm2 = norm_act_layer(mid_chs) + self.se = 
None + else: + self.se_early = None + self.norm2 = norm_act_layer(mid_chs) + self.se = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs) + + self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=cfg.output_bias) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def init_weights(self, scheme=''): + named_apply(partial(_init_conv, scheme=scheme), self) + + def forward(self, x): + shortcut = self.shortcut(x) + x = self.pre_norm(x) + x = self.down(x) + + # 1x1 expansion conv & norm-act + x = self.conv1_1x1(x) + x = self.norm1(x) + + # depthwise / grouped 3x3 conv w/ SE (or other) channel attention & norm-act + x = self.conv2_kxk(x) + if self.se_early is not None: + x = self.se_early(x) + x = self.norm2(x) + if self.se is not None: + x = self.se(x) + + # 1x1 linear projection to output width + x = self.conv3_1x1(x) + x = self.drop_path(x) + shortcut + return x + + +class ConvNeXtBlock(nn.Module): + """ ConvNeXt Block + """ + + def __init__( + self, + in_chs: int, + out_chs: Optional[int] = None, + kernel_size: int = 7, + stride: int = 1, + dilation: Tuple[int, int] = (1, 1), + cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + conv_mlp: bool = True, + drop_path: float = 0. + ): + super().__init__() + out_chs = out_chs or in_chs + act_layer = get_act_layer(cfg.act_layer) + if conv_mlp: + norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) + mlp_layer = ConvMlp + else: + assert 'layernorm' in cfg.norm_layer + norm_layer = LayerNorm + mlp_layer = Mlp + self.use_conv_mlp = conv_mlp + + if stride == 2: + self.shortcut = Downsample2d(in_chs, out_chs) + elif in_chs != out_chs: + self.shortcut = nn.Conv2d(in_chs, out_chs, kernel_size=1, bias=cfg.output_bias) + else: + self.shortcut = nn.Identity() + + assert cfg.stride_mode in ('pool', 'dw') + stride_pool, stride_dw = 1, 1 + # FIXME handle dilation? + if cfg.stride_mode == 'pool': + stride_pool = stride + else: + stride_dw = stride + + if stride_pool == 2: + self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type) + else: + self.down = nn.Identity() + + self.conv_dw = create_conv2d( + in_chs, out_chs, kernel_size=kernel_size, stride=stride_dw, dilation=dilation[1], + depthwise=True, bias=cfg.output_bias) + self.norm = norm_layer(out_chs) + self.mlp = mlp_layer(out_chs, int(cfg.expand_ratio * out_chs), bias=cfg.output_bias, act_layer=act_layer) + if conv_mlp: + self.ls = LayerScale2d(out_chs, cfg.init_values) if cfg.init_values else nn.Identity() + else: + self.ls = LayerScale(out_chs, cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x): + shortcut = self.shortcut(x) + x = self.down(x) + x = self.conv_dw(x) + if self.use_conv_mlp: + x = self.norm(x) + x = self.mlp(x) + x = self.ls(x) + else: + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = self.mlp(x) + x = self.ls(x) + x = x.permute(0, 3, 1, 2) + + x = self.drop_path(x) + shortcut + return x + + +def window_partition(x, window_size: List[int]): + B, H, W, C = x.shape + _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})') + _assert(W % window_size[1] == 0, '') + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse(windows, window_size: List[int], img_size: List[int]): + H, W = img_size + C = windows.shape[-1] + x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) + return x + + +def grid_partition(x, grid_size: List[int]): + B, H, W, C = x.shape + _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}') + _assert(W % grid_size[1] == 0, '') + x = x.view(B, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1], C) + windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, grid_size[0], grid_size[1], C) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def grid_reverse(windows, grid_size: List[int], img_size: List[int]): + H, W = img_size + C = windows.shape[-1] + x = windows.view(-1, H // grid_size[0], W // grid_size[1], grid_size[0], grid_size[1], C) + x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, H, W, C) + return x + + +def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size): + rel_pos_cls = None + if cfg.rel_pos_type == 'mlp': + rel_pos_cls = partial(RelPosMlp, window_size=window_size, hidden_dim=cfg.rel_pos_dim) + elif cfg.rel_pos_type == 'bias': + rel_pos_cls = partial(RelPosBias, window_size=window_size) + return rel_pos_cls + + +class PartitionAttention(nn.Module): + """ Grid or Block partition + Attn + FFN. + NxC tensor layout. + """ + + def __init__( + self, + dim: int, + partition_type: str = 'block', + cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path: float = 0., + ): + super().__init__() + norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) # NOTE this block is channels-last + act_layer = get_act_layer(cfg.act_layer) + + self.partition_block = partition_type == 'block' + self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size) + rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) + + self.norm1 = norm_layer(dim) + self.attn = AttentionCl( + dim, + dim, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * cfg.expand_ratio), + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def _partition_attn(self, x): + C = x.shape[-1] + img_size = x.shape[1:3] + if self.partition_block: + partitioned = window_partition(x, self.partition_size) + else: + partitioned = grid_partition(x, self.partition_size) + + partitioned = self.attn(partitioned) + + if self.partition_block: + x = window_reverse(partitioned, self.partition_size, img_size) + else: + x = grid_reverse(partitioned, self.partition_size, img_size) + return x + + def forward(self, x): + x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class CombinedPartitionAttention(nn.Module): + """ Experimental. Grid and Block partition + single FFN + NxC tensor layout. + """ + + def __init__( + self, + dim: int, + cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path: float = 0., + ): + super().__init__() + assert dim % 2 == 0 + norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) # NOTE this block is channels-last + act_layer = get_act_layer(cfg.act_layer) + + assert cfg.window_size == cfg.grid_size + self.partition_size = to_2tuple(cfg.window_size) + rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) + + self.norm1 = norm_layer(dim) + self.attn_block = AttentionCl( + dim, + dim // 2, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.attn_grid = AttentionCl( + dim, + dim // 2, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * cfg.expand_ratio), + out_features=dim, + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def _partition_attn(self, x): + img_size = x.shape[1:3] + + partitioned_block = window_partition(x, self.partition_size) + partitioned_block = self.attn_block(partitioned_block) + x_window = window_reverse(partitioned_block, self.partition_size, img_size) + + partitioned_grid = grid_partition(x, self.partition_size) + partitioned_grid = self.attn_grid(partitioned_grid) + x_grid = grid_reverse(partitioned_grid, self.partition_size, img_size) + + return torch.cat([x_window, x_grid], dim=-1) + + def forward(self, x): + x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +def window_partition_nchw(x, window_size: List[int]): + B, C, H, W = x.shape + _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})') + _assert(W % window_size[1] == 0, '') + x = x.view(B, C, H // window_size[0], window_size[0], W // window_size[1], window_size[1]) + windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size[0], window_size[1]) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse_nchw(windows, window_size: List[int], img_size: List[int]): + H, W = img_size + C = windows.shape[1] + x = windows.view(-1, H // window_size[0], W // window_size[1], C, window_size[0], window_size[1]) + x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, C, H, W) + return x + + +def grid_partition_nchw(x, grid_size: List[int]): + B, C, H, W = x.shape + _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}') + _assert(W % grid_size[1] == 0, '') + x = x.view(B, C, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1]) + windows = x.permute(0, 3, 5, 1, 2, 4).contiguous().view(-1, C, grid_size[0], grid_size[1]) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def grid_reverse_nchw(windows, grid_size: List[int], img_size: List[int]): + H, W = img_size + C = windows.shape[1] + x = windows.view(-1, H // grid_size[0], W // grid_size[1], C, grid_size[0], grid_size[1]) + x = x.permute(0, 3, 4, 1, 5, 2).contiguous().view(-1, C, H, W) + return x + + +class PartitionAttention2d(nn.Module): + """ Grid or Block partition + Attn + FFN + '2D' NCHW tensor layout. + """ + + def __init__( + self, + dim: int, + partition_type: str = 'block', + cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path: float = 0., + ): + super().__init__() + norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) # NOTE this block is channels-last + act_layer = get_act_layer(cfg.act_layer) + + self.partition_block = partition_type == 'block' + self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size) + rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) + + self.norm1 = norm_layer(dim) + self.attn = Attention2d( + dim, + dim, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.ls1 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = ConvMlp( + in_features=dim, + hidden_features=int(dim * cfg.expand_ratio), + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def _partition_attn(self, x): + img_size = x.shape[-2:] + if self.partition_block: + partitioned = window_partition_nchw(x, self.partition_size) + else: + partitioned = grid_partition_nchw(x, self.partition_size) + + partitioned = self.attn(partitioned) + + if self.partition_block: + x = window_reverse_nchw(partitioned, self.partition_size, img_size) + else: + x = grid_reverse_nchw(partitioned, self.partition_size, img_size) + return x + + def forward(self, x): + x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class MaxxVitBlock(nn.Module): + """ + """ + + def __init__( + self, + dim: int, + dim_out: int, + stride: int = 1, + conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + use_nchw_attn: bool = False, # FIXME move to cfg? True is ~20-30% faster on TPU, 5-10% slower on GPU + drop_path: float = 0., + ): + super().__init__() + + conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock + self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) + + attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) + partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttention + self.nchw_attn = use_nchw_attn + self.attn_block = partition_layer(**attn_kwargs) + self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs) + + def init_weights(self, scheme=''): + named_apply(partial(_init_transformer, scheme=scheme), self.attn_block) + named_apply(partial(_init_transformer, scheme=scheme), self.attn_grid) + named_apply(partial(_init_conv, scheme=scheme), self.conv) + + def forward(self, x): + # NCHW format + x = self.conv(x) + + if not self.nchw_attn: + x = x.permute(0, 2, 3, 1) # to NHWC (channels-last) + x = self.attn_block(x) + x = self.attn_grid(x) + if not self.nchw_attn: + x = x.permute(0, 3, 1, 2) # back to NCHW + return x + + +class CombinedMaxxVitBlock(nn.Module): + """ + """ + + def __init__( + self, + dim, + dim_out, + stride=1, + num_conv=2, + conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path=0., + ): + super().__init__() + + conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock + if num_conv > 1: + convs = [conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)] + convs += [conv_cls(dim_out, dim_out, cfg=conv_cfg, drop_path=drop_path)] * (num_conv - 1) + self.conv = nn.Sequential(*convs) + else: + self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) + self.attn = CombinedPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) + + def init_weights(self, scheme=''): + named_apply(partial(_init_transformer, scheme=scheme), self.attn) + named_apply(partial(_init_conv, scheme=scheme), self.conv) + + def forward(self, x): + x = self.conv(x) + x = x.permute(0, 2, 3, 1) + x = self.attn(x) + x = x.permute(0, 3, 1, 2) + return x + + +class MaxxVitStage(nn.Module): + def __init__( + self, + in_chs: int, + out_chs: int, + 
stride: int = 2, + depth: int = 4, + feat_size: Tuple[int, int] = (14, 14), + block_types: Union[str, Tuple[str]] = 'C', + transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + drop_path: Union[float, List[float]] = 0., + ): + super().__init__() + self.grad_checkpointing = False + + block_types = extend_tuple(block_types, depth) + blocks = [] + for i, t in enumerate(block_types): + block_stride = stride if i == 0 else 1 + assert t in ('C', 'T', 'M', 'CM') + if t == 'C': + conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock + blocks += [conv_cls( + in_chs, + out_chs, + stride=block_stride, + cfg=conv_cfg, + drop_path=drop_path[i], + )] + elif t == 'T': + rel_pos_cls = get_rel_pos_cls(transformer_cfg, feat_size) + blocks += [TransformerBlock2d( + in_chs, + out_chs, + stride=block_stride, + rel_pos_cls=rel_pos_cls, + cfg=transformer_cfg, + drop_path=drop_path[i], + )] + elif t == 'M': + blocks += [MaxxVitBlock( + in_chs, + out_chs, + stride=block_stride, + conv_cfg=conv_cfg, + transformer_cfg=transformer_cfg, + drop_path=drop_path[i], + )] + elif t == 'CM': + blocks += [CombinedMaxxVitBlock( + in_chs, + out_chs, + stride=block_stride, + conv_cfg=conv_cfg, + transformer_cfg=transformer_cfg, + drop_path=drop_path[i], + )] + in_chs = out_chs + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class Stem(nn.Module): + + def __init__( + self, + in_chs: int, + out_chs: int, + kernel_size: int = 3, + act_layer: str = 'gelu', + norm_layer: str = 'batchnorm2d', + norm_eps: float = 1e-5, + ): + super().__init__() + if not isinstance(out_chs, (list, tuple)): + out_chs = to_2tuple(out_chs) + + norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps) + self.out_chs = out_chs[-1] + self.stride = 2 + + self.conv1 = create_conv2d(in_chs, out_chs[0], kernel_size, stride=2) + self.norm1 = norm_act_layer(out_chs[0]) + self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1) + + def init_weights(self, scheme=''): + named_apply(partial(_init_conv, scheme=scheme), self) + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.conv2(x) + return x + + +class MaxxVit(nn.Module): + """ + """ + + def __init__( + self, + cfg: MaxxVitCfg, + img_size: Union[int, Tuple[int, int]] = 224, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + drop_rate: float = 0., + drop_path_rate: float = 0. 
+ ): + super().__init__() + img_size = to_2tuple(img_size) + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = cfg.embed_dim[-1] + self.embed_dim = cfg.embed_dim + self.drop_rate = drop_rate + self.grad_checkpointing = False + + self.stem = Stem( + in_chs=in_chans, + out_chs=cfg.stem_width, + act_layer=cfg.conv_cfg.act_layer, + norm_layer=cfg.conv_cfg.norm_layer, + norm_eps=cfg.conv_cfg.norm_eps, + ) + + stride = self.stem.stride + feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))]) + + num_stages = len(cfg.embed_dim) + assert len(cfg.depths) == num_stages + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] + in_chs = self.stem.out_chs + stages = [] + for i in range(num_stages): + stage_stride = 2 + out_chs = cfg.embed_dim[i] + feat_size = tuple([(r - 1) // stage_stride + 1 for r in feat_size]) + stages += [MaxxVitStage( + in_chs, + out_chs, + depth=cfg.depths[i], + block_types=cfg.block_type[i], + conv_cfg=cfg.conv_cfg, + transformer_cfg=cfg.transformer_cfg, + feat_size=feat_size, + drop_path=dpr[i], + )] + stride *= stage_stride + in_chs = out_chs + self.stages = nn.Sequential(*stages) + + final_norm_layer = get_norm_layer(cfg.transformer_cfg.norm_layer) + self.norm = final_norm_layer(self.num_features, eps=cfg.transformer_cfg.norm_eps) + + # Classifier head + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + + # Weight init (default PyTorch init works well for AdamW if scheme not set) + assert cfg.weight_init in ('', 'normal', 'trunc_normal', 'xavier_normal', 'vit_eff') + if cfg.weight_init: + named_apply(partial(self._init_weights, scheme=cfg.weight_init), self) + + def _init_weights(self, module, name, scheme=''): + if hasattr(module, 'init_weights'): + try: + module.init_weights(scheme=scheme) + except TypeError: + module.init_weights() + + @torch.jit.ignore + def no_weight_decay(self): + return { + k for k, _ in self.named_parameters() + if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', # stem and embed + blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is None: + global_pool = self.head.global_pool.pool_type + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_coatnet(variant, cfg_variant=None, pretrained=False, **kwargs): + return build_model_with_cfg( + MaxxVit, variant, pretrained, + model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], + feature_cfg=dict(flatten_sequential=True), + **kwargs) + + +@register_model +def coatnet_pico_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_pico_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def 
coatnet_nano_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_nano_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_0_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_0_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_1_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_1_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_2_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_2_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_bn_0_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_bn_0_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_rmlp_nano_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_rmlp_nano_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_rmlp_0_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_rmlp_0_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_rmlp_1_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_rmlp_1_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_nano_cc_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_nano_cc_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnext_nano_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnext_nano_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_0_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_0_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_1_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_1_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_2_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_2_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_3_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_3_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_4_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_4_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_5_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_5_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_pico_rw_256(pretrained=False, **kwargs): + return _create_coatnet('maxvit_pico_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_nano_rw_256(pretrained=False, **kwargs): + return _create_coatnet('maxvit_nano_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_rw_224(pretrained=False, **kwargs): + return _create_coatnet('maxvit_tiny_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_rw_256(pretrained=False, **kwargs): + return _create_coatnet('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_cm_256(pretrained=False, **kwargs): + return _create_coatnet('maxvit_tiny_cm_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxxvit_nano_rw_256(pretrained=False, **kwargs): + return _create_coatnet('maxxvit_nano_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_224(pretrained=False, **kwargs): + return _create_coatnet('maxvit_tiny_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_small_224(pretrained=False, **kwargs): + return 
_create_coatnet('maxvit_small_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_base_224(pretrained=False, **kwargs): + return _create_coatnet('maxvit_base_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_large_224(pretrained=False, **kwargs): + return _create_coatnet('maxvit_large_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_xlarge_224(pretrained=False, **kwargs): + return _create_coatnet('maxvit_xlarge_224', pretrained=pretrained, **kwargs) \ No newline at end of file From e939ed19b99385e8fb450bcbcbd247048718fe1b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 22 Aug 2022 17:44:51 -0700 Subject: [PATCH 10/21] Rename internal creation fn for maxvit, has not been just coatnet for a while... --- timm/models/maxxvit.py | 58 +++++++++++++++++++++--------------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 8b1fe0a6..d963bfb7 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -1544,7 +1544,7 @@ class MaxxVit(nn.Module): return x -def _create_coatnet(variant, cfg_variant=None, pretrained=False, **kwargs): +def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs): return build_model_with_cfg( MaxxVit, variant, pretrained, model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], @@ -1554,139 +1554,139 @@ def _create_coatnet(variant, cfg_variant=None, pretrained=False, **kwargs): @register_model def coatnet_pico_rw_224(pretrained=False, **kwargs): - return _create_coatnet('coatnet_pico_rw_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('coatnet_pico_rw_224', pretrained=pretrained, **kwargs) @register_model def coatnet_nano_rw_224(pretrained=False, **kwargs): - return _create_coatnet('coatnet_nano_rw_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('coatnet_nano_rw_224', pretrained=pretrained, **kwargs) @register_model def coatnet_0_rw_224(pretrained=False, **kwargs): - return _create_coatnet('coatnet_0_rw_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('coatnet_0_rw_224', pretrained=pretrained, **kwargs) @register_model def coatnet_1_rw_224(pretrained=False, **kwargs): - return _create_coatnet('coatnet_1_rw_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('coatnet_1_rw_224', pretrained=pretrained, **kwargs) @register_model def coatnet_2_rw_224(pretrained=False, **kwargs): - return _create_coatnet('coatnet_2_rw_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('coatnet_2_rw_224', pretrained=pretrained, **kwargs) @register_model def coatnet_bn_0_rw_224(pretrained=False, **kwargs): - return _create_coatnet('coatnet_bn_0_rw_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('coatnet_bn_0_rw_224', pretrained=pretrained, **kwargs) @register_model def coatnet_rmlp_nano_rw_224(pretrained=False, **kwargs): - return _create_coatnet('coatnet_rmlp_nano_rw_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('coatnet_rmlp_nano_rw_224', pretrained=pretrained, **kwargs) @register_model def coatnet_rmlp_0_rw_224(pretrained=False, **kwargs): - return _create_coatnet('coatnet_rmlp_0_rw_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('coatnet_rmlp_0_rw_224', pretrained=pretrained, **kwargs) @register_model def coatnet_rmlp_1_rw_224(pretrained=False, **kwargs): - return _create_coatnet('coatnet_rmlp_1_rw_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('coatnet_rmlp_1_rw_224', 
pretrained=pretrained, **kwargs) @register_model def coatnet_nano_cc_224(pretrained=False, **kwargs): - return _create_coatnet('coatnet_nano_cc_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('coatnet_nano_cc_224', pretrained=pretrained, **kwargs) @register_model def coatnext_nano_rw_224(pretrained=False, **kwargs): - return _create_coatnet('coatnext_nano_rw_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('coatnext_nano_rw_224', pretrained=pretrained, **kwargs) @register_model def coatnet_0_224(pretrained=False, **kwargs): - return _create_coatnet('coatnet_0_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('coatnet_0_224', pretrained=pretrained, **kwargs) @register_model def coatnet_1_224(pretrained=False, **kwargs): - return _create_coatnet('coatnet_1_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('coatnet_1_224', pretrained=pretrained, **kwargs) @register_model def coatnet_2_224(pretrained=False, **kwargs): - return _create_coatnet('coatnet_2_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('coatnet_2_224', pretrained=pretrained, **kwargs) @register_model def coatnet_3_224(pretrained=False, **kwargs): - return _create_coatnet('coatnet_3_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('coatnet_3_224', pretrained=pretrained, **kwargs) @register_model def coatnet_4_224(pretrained=False, **kwargs): - return _create_coatnet('coatnet_4_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('coatnet_4_224', pretrained=pretrained, **kwargs) @register_model def coatnet_5_224(pretrained=False, **kwargs): - return _create_coatnet('coatnet_5_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('coatnet_5_224', pretrained=pretrained, **kwargs) @register_model def maxvit_pico_rw_256(pretrained=False, **kwargs): - return _create_coatnet('maxvit_pico_rw_256', pretrained=pretrained, **kwargs) + return _create_maxxvit('maxvit_pico_rw_256', pretrained=pretrained, **kwargs) @register_model def maxvit_nano_rw_256(pretrained=False, **kwargs): - return _create_coatnet('maxvit_nano_rw_256', pretrained=pretrained, **kwargs) + return _create_maxxvit('maxvit_nano_rw_256', pretrained=pretrained, **kwargs) @register_model def maxvit_tiny_rw_224(pretrained=False, **kwargs): - return _create_coatnet('maxvit_tiny_rw_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('maxvit_tiny_rw_224', pretrained=pretrained, **kwargs) @register_model def maxvit_tiny_rw_256(pretrained=False, **kwargs): - return _create_coatnet('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs) + return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs) @register_model def maxvit_tiny_cm_256(pretrained=False, **kwargs): - return _create_coatnet('maxvit_tiny_cm_256', pretrained=pretrained, **kwargs) + return _create_maxxvit('maxvit_tiny_cm_256', pretrained=pretrained, **kwargs) @register_model def maxxvit_nano_rw_256(pretrained=False, **kwargs): - return _create_coatnet('maxxvit_nano_rw_256', pretrained=pretrained, **kwargs) + return _create_maxxvit('maxxvit_nano_rw_256', pretrained=pretrained, **kwargs) @register_model def maxvit_tiny_224(pretrained=False, **kwargs): - return _create_coatnet('maxvit_tiny_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('maxvit_tiny_224', pretrained=pretrained, **kwargs) @register_model def maxvit_small_224(pretrained=False, **kwargs): - return _create_coatnet('maxvit_small_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('maxvit_small_224', 
pretrained=pretrained, **kwargs) @register_model def maxvit_base_224(pretrained=False, **kwargs): - return _create_coatnet('maxvit_base_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('maxvit_base_224', pretrained=pretrained, **kwargs) @register_model def maxvit_large_224(pretrained=False, **kwargs): - return _create_coatnet('maxvit_large_224', pretrained=pretrained, **kwargs) + return _create_maxxvit('maxvit_large_224', pretrained=pretrained, **kwargs) @register_model def maxvit_xlarge_224(pretrained=False, **kwargs): - return _create_coatnet('maxvit_xlarge_224', pretrained=pretrained, **kwargs) \ No newline at end of file + return _create_maxxvit('maxvit_xlarge_224', pretrained=pretrained, **kwargs) \ No newline at end of file From cac0a4570a96b88cac1b864ea538bf717d73eeb6 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 23 Aug 2022 13:38:26 -0700 Subject: [PATCH 11/21] More test fixes, pool size for 256x256 maxvit models --- tests/test_models.py | 2 +- timm/models/efficientformer.py | 2 +- timm/models/maxxvit.py | 13 +++++++------ 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 5daee76d..175137e2 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -28,7 +28,7 @@ NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*', 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*', - 'coatne?t_*', 'max?vit_*', + 'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', ] NUM_NON_STD = len(NON_STD_FILTERS) diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index 814b6957..4749d93a 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -29,7 +29,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True, 'crop_pct': .95, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stem.conv1', 'classifier': 'head', + 'first_conv': 'stem.conv1', 'classifier': ('head', 'head_dist'), **kwargs } diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index d963bfb7..57ebce74 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -94,6 +94,7 @@ default_cfgs = { 'coatnet_rmlp_0_rw_224': _cfg(url=''), 'coatnet_rmlp_1_rw_224': _cfg( url=''), + 'coatnet_nano_cc_224': _cfg(url=''), 'coatnext_nano_rw_224': _cfg(url=''), # Trying to be like the CoAtNet paper configs @@ -105,12 +106,12 @@ default_cfgs = { 'coatnet_5_224': _cfg(url=''), # Experimental configs - 'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256)), - 'maxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256)), + 'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), 'maxvit_tiny_rw_224': _cfg(url=''), - 'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256)), - 'maxvit_tiny_cm_256': _cfg(url='', input_size=(3, 256, 256)), - 'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256)), + 'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_tiny_cm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), # Trying to be 
like the MaxViT paper configs 'maxvit_tiny_224': _cfg(url=''), @@ -1052,7 +1053,6 @@ class PartitionAttention(nn.Module): self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def _partition_attn(self, x): - C = x.shape[-1] img_size = x.shape[1:3] if self.partition_block: partitioned = window_partition(x, self.partition_size) @@ -1415,6 +1415,7 @@ class Stem(nn.Module): self.norm1 = norm_act_layer(out_chs[0]) self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1) + @torch.jit.ignore def init_weights(self, scheme=''): named_apply(partial(_init_conv, scheme=scheme), self) From 837c68263b0af0c34eaadbd301bf89d64786e3bc Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 23 Aug 2022 15:17:12 -0700 Subject: [PATCH 12/21] For ConvNeXt, use timm internal LayerNorm for fast_norm in non conv_mlp mode --- timm/models/convnext.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/timm/models/convnext.py b/timm/models/convnext.py index ba63a453..15000b40 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -19,7 +19,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import named_apply, build_model_with_cfg, checkpoint_seq -from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d,\ +from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \ create_conv2d, get_act_layer, make_divisible, to_ntuple from .registry import register_model @@ -161,7 +161,7 @@ class ConvNeXtBlock(nn.Module): out_chs = out_chs or in_chs act_layer = get_act_layer(act_layer) if not norm_layer: - norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6) + norm_layer = LayerNorm2d if conv_mlp else LayerNorm mlp_layer = ConvMlp if conv_mlp else Mlp self.use_conv_mlp = conv_mlp @@ -291,8 +291,8 @@ class ConvNeXt(nn.Module): assert output_stride in (8, 16, 32) kernel_sizes = to_ntuple(4)(kernel_sizes) if norm_layer is None: - norm_layer = partial(LayerNorm2d, eps=1e-6) - norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6) + norm_layer = LayerNorm2d + norm_layer_cl = norm_layer if conv_mlp else LayerNorm else: assert conv_mlp,\ 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input' From b2e8426fca45bcc0f7f9b99e1d131f5823037717 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 24 Aug 2022 11:01:20 -0700 Subject: [PATCH 13/21] Make k=stride=2 ('avg2') pooling default for coatnet/maxvit. Add weight links. Rename 'combined' partition to 'parallel'. 
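As a rough sketch (not part of the diff below), the Downsample2d pool_type choices after this change map onto plain torch.nn modules as follows; the 1x64x56x56 input is just an assumed example for the shape check:

import torch
import torch.nn as nn

x = torch.randn(1, 64, 56, 56)  # any NCHW feature map with even H, W

# 'avg' / 'max': overlapping 3x3 window, stride 2, padding 1
avg = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
max3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

# 'avg2' / 'max2': kernel_size == stride == 2, no padding, no overlap --
# the k=stride=2 pooling this commit makes the default for coatnet/maxvit
avg2 = nn.AvgPool2d(2)
max2 = nn.MaxPool2d(2)

for pool in (avg, max3, avg2, max2):
    print(type(pool).__name__, tuple(pool(x).shape))  # each halves H and W -> (1, 64, 28, 28)

Both kernel_size == stride == 2 forms halve H and W with no padding and no window overlap, which is what the subject line's k=stride=2 ('avg2') default refers to.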
--- timm/models/maxxvit.py | 54 +++++++++++++++++++++--------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 57ebce74..898e1685 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -74,26 +74,26 @@ default_cfgs = { # Fiddling with configs / defaults / still pretraining 'coatnet_pico_rw_224': _cfg(url=''), 'coatnet_nano_rw_224': _cfg( - url='', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth', crop_pct=0.9), 'coatnet_0_rw_224': _cfg( - url=''), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'), 'coatnet_1_rw_224': _cfg( - url='' + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth' ), 'coatnet_2_rw_224': _cfg(url=''), # Highly experimental configs 'coatnet_bn_0_rw_224': _cfg( - url='', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=0.95), 'coatnet_rmlp_nano_rw_224': _cfg( - url='', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth', crop_pct=0.9), 'coatnet_rmlp_0_rw_224': _cfg(url=''), 'coatnet_rmlp_1_rw_224': _cfg( - url=''), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'), 'coatnet_nano_cc_224': _cfg(url=''), 'coatnext_nano_rw_224': _cfg(url=''), @@ -107,10 +107,12 @@ default_cfgs = { # Experimental configs 'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_nano_rw_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-3e790ce3.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), 'maxvit_tiny_rw_224': _cfg(url=''), 'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_tiny_cm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), 'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), # Trying to be like the MaxViT paper configs @@ -131,7 +133,7 @@ class MaxxVitTransformerCfg: attn_bias: bool = True attn_drop: float = 0. proj_drop: float = 0. 
- pool_type: str = 'avg' + pool_type: str = 'avg2' rel_pos_type: str = 'bias' rel_pos_dim: int = 512 # for relative position types w/ MLP window_size: Tuple[int, int] = (7, 7) @@ -153,7 +155,7 @@ class MaxxVitConvCfg: pre_norm_act: bool = False # activation after pre-norm output_bias: bool = True # bias for shortcut + final 1x1 projection conv stride_mode: str = 'dw' # stride done via one of 'pool', '1x1', 'dw' - pool_type: str = 'avg' + pool_type: str = 'avg2' downsample_pool_type: str = 'avg2' attn_early: bool = False # apply attn between conv2 and norm2, instead of after norm2 attn_layer: str = 'se' @@ -241,7 +243,7 @@ def _rw_coat_cfg( def _rw_max_cfg( stride_mode='dw', - pool_type='avg', + pool_type='avg2', conv_output_bias=False, conv_attn_ratio=1 / 16, conv_norm_layer='', @@ -325,7 +327,6 @@ model_cfgs = dict( depths=(2, 3, 5, 2), stem_width=(32, 64), **_rw_max_cfg( # using newer max defaults here - pool_type='avg2', conv_output_bias=True, conv_attn_ratio=0.25, ), @@ -336,7 +337,6 @@ model_cfgs = dict( stem_width=(32, 64), **_rw_max_cfg( # using newer max defaults here stride_mode='pool', - pool_type='avg2', conv_output_bias=True, conv_attn_ratio=0.25, ), @@ -384,7 +384,6 @@ model_cfgs = dict( depths=(3, 4, 6, 3), stem_width=(32, 64), **_rw_max_cfg( - pool_type='avg2', conv_output_bias=True, conv_attn_ratio=0.25, rel_pos_type='mlp', @@ -487,10 +486,10 @@ model_cfgs = dict( stem_width=(32, 64), **_rw_max_cfg(window_size=8), ), - maxvit_tiny_cm_256=MaxxVitCfg( + maxvit_tiny_pm_256=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(2, 2, 5, 2), - block_type=('CM',) * 4, + block_type=('PM',) * 4, stem_width=(32, 64), **_rw_max_cfg(window_size=8), ), @@ -663,13 +662,15 @@ class Downsample2d(nn.Module): bias: bool = True, ): super().__init__() - assert pool_type in ('max', 'avg', 'avg2') + assert pool_type in ('max', 'max2', 'avg', 'avg2') if pool_type == 'max': self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + elif pool_type == 'max2': + self.pool = nn.MaxPool2d(2) # kernel_size == stride == 2 elif pool_type == 'avg': self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False) else: - self.pool = nn.AvgPool2d(2) + self.pool = nn.AvgPool2d(2) # kernel_size == stride == 2 if dim != dim_out: self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) @@ -1073,7 +1074,7 @@ class PartitionAttention(nn.Module): return x -class CombinedPartitionAttention(nn.Module): +class ParallelPartitionAttention(nn.Module): """ Experimental. Grid and Block partition + single FFN NxC tensor layout. 
""" @@ -1286,7 +1287,7 @@ class MaxxVitBlock(nn.Module): return x -class CombinedMaxxVitBlock(nn.Module): +class ParallelMaxxVitBlock(nn.Module): """ """ @@ -1309,7 +1310,7 @@ class CombinedMaxxVitBlock(nn.Module): self.conv = nn.Sequential(*convs) else: self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) - self.attn = CombinedPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) + self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) def init_weights(self, scheme=''): named_apply(partial(_init_transformer, scheme=scheme), self.attn) @@ -1343,7 +1344,7 @@ class MaxxVitStage(nn.Module): blocks = [] for i, t in enumerate(block_types): block_stride = stride if i == 0 else 1 - assert t in ('C', 'T', 'M', 'CM') + assert t in ('C', 'T', 'M', 'PM') if t == 'C': conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock blocks += [conv_cls( @@ -1372,8 +1373,8 @@ class MaxxVitStage(nn.Module): transformer_cfg=transformer_cfg, drop_path=drop_path[i], )] - elif t == 'CM': - blocks += [CombinedMaxxVitBlock( + elif t == 'PM': + blocks += [ParallelMaxxVitBlock( in_chs, out_chs, stride=block_stride, @@ -1415,7 +1416,6 @@ class Stem(nn.Module): self.norm1 = norm_act_layer(out_chs[0]) self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1) - @torch.jit.ignore def init_weights(self, scheme=''): named_apply(partial(_init_conv, scheme=scheme), self) @@ -1659,8 +1659,8 @@ def maxvit_tiny_rw_256(pretrained=False, **kwargs): @register_model -def maxvit_tiny_cm_256(pretrained=False, **kwargs): - return _create_maxxvit('maxvit_tiny_cm_256', pretrained=pretrained, **kwargs) +def maxvit_tiny_pm_256(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs) @register_model From 527f9a4cb22f784e80d8b046ba083000feae081e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 24 Aug 2022 12:42:11 -0700 Subject: [PATCH 14/21] Updated to correct maxvit_nano weights... --- timm/models/maxxvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 898e1685..7f4ebf59 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -108,7 +108,7 @@ default_cfgs = { # Experimental configs 'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), 'maxvit_nano_rw_256': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-3e790ce3.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth', input_size=(3, 256, 256), pool_size=(8, 8)), 'maxvit_tiny_rw_224': _cfg(url=''), 'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), From 1d8d6f6072659e905d91a2b297d53e927853457d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 25 Aug 2022 15:00:35 -0700 Subject: [PATCH 15/21] Fix two default args in DenseNet blocks... 
fix #1427 --- timm/models/densenet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/densenet.py b/timm/models/densenet.py index a46b86ad..1afdfd7b 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -115,7 +115,7 @@ class DenseBlock(nn.ModuleDict): _version = 2 def __init__( - self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU, + self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=BatchNormAct2d, drop_rate=0., memory_efficient=False): super(DenseBlock, self).__init__() for i in range(num_layers): @@ -138,7 +138,7 @@ class DenseBlock(nn.ModuleDict): class DenseTransition(nn.Sequential): - def __init__(self, num_input_features, num_output_features, norm_layer=nn.BatchNorm2d, aa_layer=None): + def __init__(self, num_input_features, num_output_features, norm_layer=BatchNormAct2d, aa_layer=None): super(DenseTransition, self).__init__() self.add_module('norm', norm_layer(num_input_features)) self.add_module('conv', nn.Conv2d( From 7c2660576d565b7441922265456ec8b050608da3 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 25 Aug 2022 15:30:59 -0700 Subject: [PATCH 16/21] Tweak init for convnext block using maxxvit/coatnext. --- timm/models/maxxvit.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 7f4ebf59..82840523 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -259,8 +259,6 @@ def _rw_max_cfg( # - mbconv expansion calculated from input instead of output chs # - mbconv shortcut and final 1x1 conv did not have a bias # - mbconv uses silu in timm, not gelu - # - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat) - # - default to avg pool for mbconv downsample instead of 1x1 or dw conv # - expansion in attention block done via output proj, not input proj return dict( conv_cfg=MaxxVitConvCfg( @@ -411,18 +409,19 @@ model_cfgs = dict( rel_pos_dim=384, # was supposed to be 512, woops ), ), - coatnext_nano_rw_224=MaxxVitCfg( + coatnet_nano_cc_224=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(3, 4, 6, 3), stem_width=(32, 64), - **_next_cfg(), + block_type=('C', 'C', ('C', 'T'), ('C', 'T')), + **_rw_coat_cfg(), ), - coatnet_nano_cc_224=MaxxVitCfg( + coatnext_nano_rw_224=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(3, 4, 6, 3), stem_width=(32, 64), - block_type=('C', 'C', ('C', 'T'), ('C', 'T')), - **_rw_coat_cfg(), + weight_init='normal', + **_next_cfg(), ), # Trying to be like the CoAtNet paper configs @@ -498,6 +497,7 @@ model_cfgs = dict( depths=(1, 2, 3, 1), block_type=('M',) * 4, stem_width=(32, 64), + weight_init='normal', **_next_cfg(window_size=8), ), From a54008bd97057149e6c110ddc8887481508ee595 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 25 Aug 2022 15:56:56 -0700 Subject: [PATCH 17/21] Update README.md for merge --- README.md | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/README.md b/README.md index 019fdae2..fdf83853 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,25 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before ## What's New +### Aug 26, 2022 +* CoAtNet (https://arxiv.org/abs/2106.04803) and MaxVit (https://arxiv.org/abs/2204.01697) `timm` original models + * both found in [`maxxvit.py`](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/maxxvit.py) model def, contains numerous experiments outside scope of original papers + * an unfinished 
Tensorflow version from MaxVit authors can be found https://github.com/google-research/maxvit +* Initial CoAtNet and MaxVit timm pretrained weights (working on more): + * `coatnet_nano_rw_224` - 81.7 @ 224 (T) + * `coatnet_rmlp_nano_rw_224` - 82.0 @ 224, 82.8 @ 320 (T) + * `coatnet_0_rw_224` - 82.4 (T) -- NOTE timm '0' coatnets have 2 more 3rd stage blocks + * `coatnet_bn_0_rw_224` - 82.4 (T) + * `maxvit_nano_rw_256` - 82.9 @ 256 (T) + * `coatnet_rmlp_1_rw_224` - 83.4 @ 224, 84 @ 320 (T) + * `coatnet_1_rw_224` - 83.6 @ 224 (G) +* GCVit (weights adapted from https://github.com/NVlabs/GCVit, code 100% `timm` re-write for license purposes) +* MViT-V2 (multi-scale vit, adapted from https://github.com/facebookresearch/mvit) +* EfficientFormer (adapted from https://github.com/snap-research/EfficientFormer) +* PyramidVisionTransformer-V2 (adapted from https://github.com/whai362/PVT) +* 'Fast Norm' support for LayerNorm and GroupNorm that avoids float32 upcast w/ AMP (uses APEX LN if available for further boost) + + ### Aug 15, 2022 * ConvNeXt atto weights added * `convnext_atto` - 75.7 @ 224, 77.0 @ 288 @@ -229,6 +248,7 @@ A full version of the list below with source links can be found in the [document * Bottleneck Transformers - https://arxiv.org/abs/2101.11605 * CaiT (Class-Attention in Image Transformers) - https://arxiv.org/abs/2103.17239 * CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399 +* CoAtNet (Convolution and Attention) - https://arxiv.org/abs/2106.04803 * ConvNeXt - https://arxiv.org/abs/2201.03545 * ConViT (Soft Convolutional Inductive Biases Vision Transformers)- https://arxiv.org/abs/2103.10697 * CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929 @@ -238,6 +258,7 @@ A full version of the list below with source links can be found in the [document * DLA - https://arxiv.org/abs/1707.06484 * DPN (Dual-Path Network) - https://arxiv.org/abs/1707.01629 * EdgeNeXt - https://arxiv.org/abs/2206.10589 +* EfficientFormer - https://arxiv.org/abs/2206.01191 * EfficientNet (MBConvNet Family) * EfficientNet NoisyStudent (B0-B7, L2) - https://arxiv.org/abs/1911.04252 * EfficientNet AdvProp (B0-B8) - https://arxiv.org/abs/1911.09665 @@ -259,6 +280,7 @@ A full version of the list below with source links can be found in the [document * Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261 * Lambda Networks - https://arxiv.org/abs/2102.08602 * LeViT (Vision Transformer in ConvNet's Clothing) - https://arxiv.org/abs/2104.01136 +* MaxViT (Multi-Axis Vision Transformer) - https://arxiv.org/abs/2204.01697 * MLP-Mixer - https://arxiv.org/abs/2105.01601 * MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244 * FBNet-V3 - https://arxiv.org/abs/2006.02049 @@ -266,6 +288,7 @@ A full version of the list below with source links can be found in the [document * LCNet - https://arxiv.org/abs/2109.15099 * MobileViT - https://arxiv.org/abs/2110.02178 * MobileViT-V2 - https://arxiv.org/abs/2206.02680 +* MViT-V2 (Improved Multiscale Vision Transformer) - https://arxiv.org/abs/2112.01526 * NASNet-A - https://arxiv.org/abs/1707.07012 * NesT - https://arxiv.org/abs/2105.12723 * NFNet-F - https://arxiv.org/abs/2102.06171 @@ -273,6 +296,7 @@ A full version of the list below with source links can be found in the [document * PNasNet - https://arxiv.org/abs/1712.00559 * PoolFormer (MetaFormer) - https://arxiv.org/abs/2111.11418 * Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302 +* PVT-V2 (Improved 
Pyramid Vision Transformer) - https://arxiv.org/abs/2106.13797 * RegNet - https://arxiv.org/abs/2003.13678 * RegNetZ - https://arxiv.org/abs/2103.06877 * RepVGG - https://arxiv.org/abs/2101.03697 From 99ee61e245e4a7652fd25043d4102fef7e480131 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 25 Aug 2022 15:58:57 -0700 Subject: [PATCH 18/21] Add T/G legend to README.md maxvit list --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index fdf83853..9912d40d 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before * `maxvit_nano_rw_256` - 82.9 @ 256 (T) * `coatnet_rmlp_1_rw_224` - 83.4 @ 224, 84 @ 320 (T) * `coatnet_1_rw_224` - 83.6 @ 224 (G) + * (T) = TPU trained with `bits_and_tpu` branch training code, (G) = GPU trained * GCVit (weights adapted from https://github.com/NVlabs/GCVit, code 100% `timm` re-write for license purposes) * MViT-V2 (multi-scale vit, adapted from https://github.com/facebookresearch/mvit) * EfficientFormer (adapted from https://github.com/snap-research/EfficientFormer) From 48e1df8b3789bda43e4942513c896677e1b1e8dc Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 25 Aug 2022 16:29:34 -0700 Subject: [PATCH 19/21] Add norm/norm_act header comments --- timm/models/layers/fast_norm.py | 10 ++++++++++ timm/models/layers/norm.py | 4 ++++ timm/models/layers/norm_act.py | 12 ++++++++++++ 3 files changed, 26 insertions(+) diff --git a/timm/models/layers/fast_norm.py b/timm/models/layers/fast_norm.py index 9a34a15e..fb35e47d 100644 --- a/timm/models/layers/fast_norm.py +++ b/timm/models/layers/fast_norm.py @@ -1,3 +1,11 @@ +""" 'Fast' Normalization Functions + +For GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32. + +Additionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast) + +Hacked together by / Copyright 2022 Ross Wightman +""" from typing import List, Optional import torch @@ -37,6 +45,7 @@ def fast_group_norm( if torch.is_autocast_enabled(): # normally native AMP casts GN inputs to float32 # here we use the low precision autocast dtype + # FIXME what to do re CPU autocast? dt = torch.get_autocast_gpu_dtype() x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) @@ -62,6 +71,7 @@ def fast_layer_norm( # normally native AMP casts LN inputs to float32 # apex LN does not, this is behaving like Apex dt = torch.get_autocast_gpu_dtype() + # FIXME what to do re CPU autocast? x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) with torch.cuda.amp.autocast(enabled=False): diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index 2ff8fc08..42445a49 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -1,4 +1,8 @@ """ Normalization layers and wrappers + +Norm layer definitions that support fast norm and consistent channel arg order (always first arg). 
+ +Hacked together by / Copyright 2022 Ross Wightman """ import torch diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py index dc077160..ff075fbc 100644 --- a/timm/models/layers/norm_act.py +++ b/timm/models/layers/norm_act.py @@ -1,4 +1,16 @@ """ Normalization + Activation Layers + +Provides Norm+Act fns for standard PyTorch norm layers such as +* BatchNorm +* GroupNorm +* LayerNorm + +This allows swapping with alternative layers that are natively both norm + act such as +* EvoNorm (evo_norm.py) +* FilterResponseNorm (filter_response_norm.py) +* InplaceABN (inplace_abn.py) + +Hacked together by / Copyright 2022 Ross Wightman """ from typing import Union, List, Optional, Any From 769ab4b98a51f18ce8b1fa37deab3a8f1f9032aa Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 25 Aug 2022 16:29:52 -0700 Subject: [PATCH 20/21] Clean up no_grad for trunc normal weight inits --- timm/models/layers/weight_init.py | 48 ++++++++++++++++--------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/timm/models/layers/weight_init.py b/timm/models/layers/weight_init.py index 4a160931..943e4f4c 100644 --- a/timm/models/layers/weight_init.py +++ b/timm/models/layers/weight_init.py @@ -5,7 +5,7 @@ import warnings from torch.nn.init import _calculate_fan_in_and_fan_out -def _no_grad_trunc_normal_(tensor, mean, std, a, b): +def _trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x): @@ -17,28 +17,27 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b): "The distribution of values may be incorrect.", stacklevel=2) - with torch.no_grad(): - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. 
+ tensor.uniform_(2 * l - 1, 2 * u - 1) - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.)) - tensor.add_(mean) + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - return tensor + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): @@ -64,7 +63,8 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): >>> w = torch.empty(3, 5) >>> nn.init.trunc_normal_(w) """ - return _no_grad_trunc_normal_(tensor, mean, std, a, b) + with torch.no_grad(): + return _trunc_normal_(tensor, mean, std, a, b) def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.): @@ -90,8 +90,8 @@ def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.): >>> w = torch.empty(3, 5) >>> nn.init.trunc_normal_(w) """ - _no_grad_trunc_normal_(tensor, 0, 1.0, a, b) with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) tensor.mul_(std).add_(mean) return tensor @@ -111,10 +111,12 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): # constant is stddev of standard normal truncated to (-2, 2) trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978) elif distribution == "normal": - tensor.normal_(std=math.sqrt(variance)) + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) elif distribution == "uniform": bound = math.sqrt(3 * variance) - tensor.uniform_(-bound, bound) + with torch.no_grad(): + tensor.uniform_(-bound, bound) else: raise ValueError(f"invalid distribution {distribution}") From ff6a919cf5f0a325236cf57c07548f779123173f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 25 Aug 2022 17:00:54 -0700 Subject: [PATCH 21/21] Add --fast-norm arg to benchmark.py, train.py, validate.py --- benchmark.py | 8 ++++++-- timm/models/__init__.py | 1 + train.py | 6 +++++- validate.py | 6 +++++- 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/benchmark.py b/benchmark.py index 4679a009..4a89441b 100755 --- a/benchmark.py +++ b/benchmark.py @@ -19,7 +19,7 @@ import torch.nn as nn import torch.nn.parallel from timm.data import resolve_data_config -from timm.models import create_model, is_model, list_models +from timm.models import create_model, is_model, list_models, set_fast_norm from timm.optim import create_optimizer_v2 from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry @@ -109,7 +109,8 @@ scripting_group.add_argument('--torchscript', dest='torchscript', action='store_ help='convert model torchscript for inference') scripting_group.add_argument('--aot-autograd', default=False, action='store_true', help="Enable AOT Autograd support. 
(It's recommended to use this option with `--fuser nvfuser` together)") - +scripting_group.add_argument('--fast-norm', default=False, action='store_true', + help='enable experimental fast-norm') # train optimizer parameters parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', @@ -598,6 +599,9 @@ def main(): model_cfgs = [] model_names = [] + if args.fast_norm: + set_fast_norm() + if args.model_list: args.model = '' with open(args.model_list) as f: diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 51a38d0c..5ff79595 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -69,5 +69,6 @@ from .helpers import load_checkpoint, resume_checkpoint, model_parameters from .layers import TestTimePoolHead, apply_test_time_pool from .layers import convert_splitbn_model, convert_sync_batchnorm from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit +from .layers import set_fast_norm from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\ is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value diff --git a/train.py b/train.py index e5d40566..ee137217 100755 --- a/train.py +++ b/train.py @@ -33,7 +33,7 @@ from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \ LabelSmoothingCrossEntropy from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \ - convert_splitbn_model, convert_sync_batchnorm, model_parameters + convert_splitbn_model, convert_sync_batchnorm, model_parameters, set_fast_norm from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.scheduler import create_scheduler from timm.utils import ApexScaler, NativeScaler @@ -135,6 +135,8 @@ scripting_group.add_argument('--aot-autograd', default=False, action='store_true help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)") group.add_argument('--fuser', default='', type=str, help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") +group.add_argument('--fast-norm', default=False, action='store_true', + help='enable experimental fast-norm') group.add_argument('--grad-checkpointing', action='store_true', default=False, help='Enable gradient checkpointing through model blocks/stages') @@ -395,6 +397,8 @@ def main(): if args.fuser: utils.set_jit_fuser(args.fuser) + if args.fast_norm: + set_fast_norm() model = create_model( args.model, diff --git a/validate.py b/validate.py index a4d41868..6244f052 100755 --- a/validate.py +++ b/validate.py @@ -20,7 +20,7 @@ import torch.nn.parallel from collections import OrderedDict from contextlib import suppress -from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models +from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\ decay_batch_step, check_batch_size_retry @@ -117,6 +117,8 @@ scripting_group.add_argument('--aot-autograd', default=False, action='store_true help="Enable AOT Autograd support. 
(It's recommended to use this option with `--fuser nvfuser` together)") parser.add_argument('--fuser', default='', type=str, help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") +parser.add_argument('--fast-norm', default=False, action='store_true', + help='enable experimental fast-norm') parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', help='Output csv file for validation results (summary)') parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME', @@ -150,6 +152,8 @@ def validate(args): if args.fuser: set_jit_fuser(args.fuser) + if args.fast_norm: + set_fast_norm() # create model model = create_model(