diff --git a/timm/models/__init__.py b/timm/models/__init__.py index b93d3f94..51a38d0c 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -25,6 +25,7 @@ from .inception_resnet_v2 import * from .inception_v3 import * from .inception_v4 import * from .levit import * +from .maxxvit import * from .mlp_mixer import * from .mobilenetv3 import * from .mobilevit import * diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 071da7bc..21c641b6 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -21,7 +21,7 @@ from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_ from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d from .gather_excite import GatherExcite from .global_context import GlobalContext -from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible +from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple from .inplace_abn import InplaceAbn from .linear import Linear from .mixed_conv2d import MixedConv2d diff --git a/timm/models/layers/helpers.py b/timm/models/layers/helpers.py index 4a10ca0c..2fa296bc 100644 --- a/timm/models/layers/helpers.py +++ b/timm/models/layers/helpers.py @@ -29,3 +29,15 @@ def make_divisible(v, divisor=8, min_value=None, round_limit=.9): if new_v < round_limit * v: new_v += divisor return new_v + + +def extend_tuple(x, n): + # pdas a tuple to specified n by padding with last value + if not isinstance(x, (tuple, list)): + x = (x,) + else: + x = tuple(x) + pad_n = n - len(x) + if pad_n <= 0: + return x[:n] + return x + (x[-1],) * pad_n diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py new file mode 100644 index 00000000..8b1fe0a6 --- /dev/null +++ b/timm/models/maxxvit.py @@ -0,0 +1,1692 @@ +""" MaxVit and CoAtNet Vision Transformer - CNN Hybrids in PyTorch + +This is a from-scratch implementation of both CoAtNet and MaxVit in PyTorch. + +99% of the implementation was done from papers, however last minute some adjustments were made +based on the (as yet unfinished?) public code release https://github.com/google-research/maxvit + +There are multiple sets of models defined for both architectures. Typically, names with a + `_rw` suffix are my own original configs prior to referencing https://github.com/google-research/maxvit. +These configs work well and appear to be a bit faster / lower resource than the paper. + +The models without extra prefix / suffix' (coatnet_0_224, maxvit_tiny_224, etc), are intended to +match paper, BUT, without any official pretrained weights it's difficult to confirm a 100% match. + +# FIXME / WARNING +This impl remains a WIP, some configs and models may vanish or change... + +Papers: + +MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697 +@article{tu2022maxvit, + title={MaxViT: Multi-Axis Vision Transformer}, + author={Tu, Zhengzhong and Talebi, Hossein and Zhang, Han and Yang, Feng and Milanfar, Peyman and Bovik, Alan and Li, Yinxiao}, + journal={ECCV}, + year={2022}, +} + +CoAtNet: Marrying Convolution and Attention for All Data Sizes - https://arxiv.org/abs/2106.04803 +@article{DBLP:journals/corr/abs-2106-04803, + author = {Zihang Dai and Hanxiao Liu and Quoc V. Le and Mingxing Tan}, + title = {CoAtNet: Marrying Convolution and Attention for All Data Sizes}, + journal = {CoRR}, + volume = {abs/2106.04803}, + year = {2021} +} + +Hacked together by / Copyright 2022, Ross Wightman +""" + +import math +from collections import OrderedDict +from dataclasses import dataclass +from functools import partial +from typing import Callable, Optional, Union, Tuple, List + +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg, checkpoint_seq, named_apply +from .fx_features import register_notrace_function +from .layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm2d, LayerNorm +from .layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d +from .layers import to_2tuple, extend_tuple, make_divisible, _assert +from .registry import register_model +from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move these to common location + +__all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.95, 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + 'first_conv': 'stem.conv1', 'classifier': 'head.fc', + 'fixed_input_size': True, + **kwargs + } + + +default_cfgs = { + # Fiddling with configs / defaults / still pretraining + 'coatnet_pico_rw_224': _cfg(url=''), + 'coatnet_nano_rw_224': _cfg( + url='', + crop_pct=0.9), + 'coatnet_0_rw_224': _cfg( + url=''), + 'coatnet_1_rw_224': _cfg( + url='' + ), + 'coatnet_2_rw_224': _cfg(url=''), + + # Highly experimental configs + 'coatnet_bn_0_rw_224': _cfg( + url='', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + crop_pct=0.95), + 'coatnet_rmlp_nano_rw_224': _cfg( + url='', + crop_pct=0.9), + 'coatnet_rmlp_0_rw_224': _cfg(url=''), + 'coatnet_rmlp_1_rw_224': _cfg( + url=''), + 'coatnext_nano_rw_224': _cfg(url=''), + + # Trying to be like the CoAtNet paper configs + 'coatnet_0_224': _cfg(url=''), + 'coatnet_1_224': _cfg(url=''), + 'coatnet_2_224': _cfg(url=''), + 'coatnet_3_224': _cfg(url=''), + 'coatnet_4_224': _cfg(url=''), + 'coatnet_5_224': _cfg(url=''), + + # Experimental configs + 'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256)), + 'maxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256)), + 'maxvit_tiny_rw_224': _cfg(url=''), + 'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256)), + 'maxvit_tiny_cm_256': _cfg(url='', input_size=(3, 256, 256)), + 'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256)), + + # Trying to be like the MaxViT paper configs + 'maxvit_tiny_224': _cfg(url=''), + 'maxvit_small_224': _cfg(url=''), + 'maxvit_base_224': _cfg(url=''), + 'maxvit_large_224': _cfg(url=''), + 'maxvit_xlarge_224': _cfg(url=''), +} + + +@dataclass +class MaxxVitTransformerCfg: + dim_head: int = 32 + expand_ratio: float = 4.0 + expand_first: bool = True + shortcut_bias: bool = True, + attn_bias: bool = True + attn_drop: float = 0. + proj_drop: float = 0. + pool_type: str = 'avg' + rel_pos_type: str = 'bias' + rel_pos_dim: int = 512 # for relative position types w/ MLP + window_size: Tuple[int, int] = (7, 7) + grid_size: Tuple[int, int] = (7, 7) + init_values: Optional[float] = None + act_layer: str = 'gelu' + norm_layer: str = 'layernorm2d' + norm_layer_cl: str = 'layernorm' + norm_eps: float = 1e-6 + + +@dataclass +class MaxxVitConvCfg: + block_type: str = 'mbconv' + expand_ratio: float = 4.0 + expand_output: bool = True # calculate expansion channels from output (vs input chs) + kernel_size: int = 3 + group_size: int = 1 # 1 == depthwise + pre_norm_act: bool = False # activation after pre-norm + output_bias: bool = True # bias for shortcut + final 1x1 projection conv + stride_mode: str = 'dw' # stride done via one of 'pool', '1x1', 'dw' + pool_type: str = 'avg' + downsample_pool_type: str = 'avg2' + attn_early: bool = False # apply attn between conv2 and norm2, instead of after norm2 + attn_layer: str = 'se' + attn_act_layer: str = 'silu' + attn_ratio: float = 0.25 + init_values: Optional[float] = 1e-5 # for ConvNeXt block + act_layer: str = 'gelu' + norm_layer: str = '' + norm_layer_cl: str = '' + norm_eps: Optional[float] = None + + def __post_init__(self): + # mbconv vs convnext blocks have different defaults, set in post_init to avoid explicit config args + assert self.block_type in ('mbconv', 'convnext') + use_mbconv = self.block_type == 'mbconv' + if not self.norm_layer: + self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d' + if not self.norm_layer_cl and not use_mbconv: + self.norm_layer_cl = 'layernorm' + if self.norm_eps is None: + self.norm_eps = 1e-5 if use_mbconv else 1e-6 + self.downsample_pool_type = self.downsample_pool_type or self.pool_type + + +@dataclass +class MaxxVitCfg: + embed_dim: Tuple[int, ...] = (96, 192, 384, 768) + depths: Tuple[int, ...] = (2, 3, 5, 2) + block_type: Tuple[Union[str, Tuple[str, ...]], ...] = ('C', 'C', 'T', 'T') + stem_width: Union[int, Tuple[int, int]] = 64 + stem_bias: bool = True + conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg() + transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg() + weight_init: str = 'vit_eff' + + +def _rw_coat_cfg( + stride_mode='pool', + pool_type='avg2', + conv_output_bias=False, + conv_attn_early=False, + conv_norm_layer='', + transformer_shortcut_bias=True, + transformer_norm_layer='layernorm2d', + transformer_norm_layer_cl='layernorm', + rel_pos_type='bias', + rel_pos_dim=512, +): + # 'RW' timm variant models were created and trained before seeing https://github.com/google-research/maxvit + # Common differences for initial timm models: + # - pre-norm layer in MZBConv included an activation after norm + # - mbconv expansion calculated from input instead of output chs + # - mbconv shortcut and final 1x1 conv did not have a bias + # - SE act layer was relu, not silu + # - mbconv uses silu in timm, not gelu + # - expansion in attention block done via output proj, not input proj + # Variable differences (evolved over training initial models): + # - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat) + # - SE attention was between conv2 and norm/act + # - default to avg pool for mbconv downsample instead of 1x1 or dw conv + # - transformer block shortcut has no bias + return dict( + conv_cfg=MaxxVitConvCfg( + stride_mode=stride_mode, + pool_type=pool_type, + pre_norm_act=True, + expand_output=False, + output_bias=conv_output_bias, + attn_early=conv_attn_early, + attn_act_layer='relu', + act_layer='silu', + norm_layer=conv_norm_layer, + ), + transformer_cfg=MaxxVitTransformerCfg( + expand_first=False, + shortcut_bias=transformer_shortcut_bias, + pool_type=pool_type, + norm_layer=transformer_norm_layer, + norm_layer_cl=transformer_norm_layer_cl, + rel_pos_type=rel_pos_type, + rel_pos_dim=rel_pos_dim, + ), + ) + + +def _rw_max_cfg( + stride_mode='dw', + pool_type='avg', + conv_output_bias=False, + conv_attn_ratio=1 / 16, + conv_norm_layer='', + transformer_norm_layer='layernorm2d', + transformer_norm_layer_cl='layernorm', + window_size=7, + dim_head=32, + rel_pos_type='bias', + rel_pos_dim=512, +): + # 'RW' timm variant models were created and trained before seeing https://github.com/google-research/maxvit + # Differences of initial timm models: + # - mbconv expansion calculated from input instead of output chs + # - mbconv shortcut and final 1x1 conv did not have a bias + # - mbconv uses silu in timm, not gelu + # - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat) + # - default to avg pool for mbconv downsample instead of 1x1 or dw conv + # - expansion in attention block done via output proj, not input proj + return dict( + conv_cfg=MaxxVitConvCfg( + stride_mode=stride_mode, + pool_type=pool_type, + expand_output=False, + output_bias=conv_output_bias, + attn_ratio=conv_attn_ratio, + act_layer='silu', + norm_layer=conv_norm_layer, + ), + transformer_cfg=MaxxVitTransformerCfg( + expand_first=False, + pool_type=pool_type, + dim_head=dim_head, + window_size=to_2tuple(window_size), + grid_size=to_2tuple(window_size), + norm_layer=transformer_norm_layer, + norm_layer_cl=transformer_norm_layer_cl, + rel_pos_type=rel_pos_type, + rel_pos_dim=rel_pos_dim, + ), + ) + + +def _next_cfg( + stride_mode='dw', + pool_type='avg2', + conv_norm_layer='layernorm2d', + conv_norm_layer_cl='layernorm', + transformer_norm_layer='layernorm2d', + transformer_norm_layer_cl='layernorm', + window_size=7, + rel_pos_type='bias', + rel_pos_dim=512, +): + # For experimental models with convnext instead of mbconv + return dict( + conv_cfg=MaxxVitConvCfg( + block_type='convnext', + stride_mode=stride_mode, + pool_type=pool_type, + expand_output=False, + norm_layer=conv_norm_layer, + norm_layer_cl=conv_norm_layer_cl, + ), + transformer_cfg=MaxxVitTransformerCfg( + expand_first=False, + pool_type=pool_type, + window_size=to_2tuple(window_size), + grid_size=to_2tuple(window_size), + norm_layer=transformer_norm_layer, + norm_layer_cl=transformer_norm_layer_cl, + rel_pos_type=rel_pos_type, + rel_pos_dim=rel_pos_dim, + ), + ) + + +model_cfgs = dict( + # Fiddling with configs / defaults / still pretraining + coatnet_pico_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 3, 5, 2), + stem_width=(32, 64), + **_rw_max_cfg( # using newer max defaults here + pool_type='avg2', + conv_output_bias=True, + conv_attn_ratio=0.25, + ), + ), + coatnet_nano_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + **_rw_max_cfg( # using newer max defaults here + stride_mode='pool', + pool_type='avg2', + conv_output_bias=True, + conv_attn_ratio=0.25, + ), + ), + coatnet_0_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 7, 2), # deeper than paper '0' model + stem_width=(32, 64), + **_rw_coat_cfg( + conv_attn_early=True, + transformer_shortcut_bias=False, + ), + ), + coatnet_1_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + stem_width=(32, 64), + **_rw_coat_cfg( + stride_mode='dw', + conv_attn_early=True, + transformer_shortcut_bias=False, + ) + ), + coatnet_2_rw_224=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 6, 14, 2), + stem_width=(64, 128), + **_rw_coat_cfg(stride_mode='dw'), + ), + + # Highly experimental configs + coatnet_bn_0_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 7, 2), # deeper than paper '0' model + stem_width=(32, 64), + **_rw_coat_cfg( + stride_mode='dw', + conv_attn_early=True, + transformer_shortcut_bias=False, + transformer_norm_layer='batchnorm2d', + ) + ), + coatnet_rmlp_nano_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + **_rw_max_cfg( + pool_type='avg2', + conv_output_bias=True, + conv_attn_ratio=0.25, + rel_pos_type='mlp', + rel_pos_dim=384, + ), + ), + coatnet_rmlp_0_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 7, 2), # deeper than paper '0' model + stem_width=(32, 64), + **_rw_coat_cfg( + stride_mode='dw', + rel_pos_type='mlp', + ), + ), + coatnet_rmlp_1_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + stem_width=(32, 64), + **_rw_coat_cfg( + pool_type='max', + conv_attn_early=True, + transformer_shortcut_bias=False, + rel_pos_type='mlp', + rel_pos_dim=384, # was supposed to be 512, woops + ), + ), + coatnext_nano_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + **_next_cfg(), + ), + coatnet_nano_cc_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + block_type=('C', 'C', ('C', 'T'), ('C', 'T')), + **_rw_coat_cfg(), + ), + + # Trying to be like the CoAtNet paper configs + coatnet_0_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 5, 2), + stem_width=64, + ), + coatnet_1_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + stem_width=64, + ), + coatnet_2_224=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 6, 14, 2), + stem_width=128, + ), + coatnet_3_224=MaxxVitCfg( + embed_dim=(192, 384, 768, 1536), + depths=(2, 6, 14, 2), + stem_width=192, + ), + coatnet_4_224=MaxxVitCfg( + embed_dim=(192, 384, 768, 1536), + depths=(2, 12, 28, 2), + stem_width=192, + ), + coatnet_5_224=MaxxVitCfg( + embed_dim=(256, 512, 1280, 2048), + depths=(2, 12, 28, 2), + stem_width=192, + ), + + # Experimental MaxVit configs + maxvit_pico_rw_256=MaxxVitCfg( + embed_dim=(32, 64, 128, 256), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(24, 32), + **_rw_max_cfg(window_size=8), + ), + maxvit_nano_rw_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(1, 2, 3, 1), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(window_size=8), + ), + maxvit_tiny_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(), + ), + maxvit_tiny_rw_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(window_size=8), + ), + maxvit_tiny_cm_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('CM',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(window_size=8), + ), + maxxvit_nano_rw_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(1, 2, 3, 1), + block_type=('M',) * 4, + stem_width=(32, 64), + **_next_cfg(window_size=8), + ), + + # Trying to be like the MaxViT paper configs + maxvit_tiny_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=64, + ), + maxvit_small_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=64, + ), + maxvit_base_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + block_type=('M',) * 4, + stem_width=64, + ), + maxvit_large_224=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 6, 14, 2), + block_type=('M',) * 4, + stem_width=128, + ), + maxvit_xlarge_224=MaxxVitCfg( + embed_dim=(192, 384, 768, 1536), + depths=(2, 6, 14, 2), + block_type=('M',) * 4, + stem_width=192, + ), + +) + + +class Attention2d(nn.Module): + """ multi-head attention for 2D NCHW tensors""" + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + dim_head: int = 32, + bias: bool = True, + expand_first: bool = True, + rel_pos_cls: Callable = None, + attn_drop: float = 0., + proj_drop: float = 0. + ): + super().__init__() + dim_out = dim_out or dim + dim_attn = dim_out if expand_first else dim + self.num_heads = dim_attn // dim_head + self.dim_head = dim_head + self.scale = dim_head ** -0.5 + + self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias) + self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + B, C, H, W = x.shape + + q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2) + + attn = (q.transpose(-2, -1) @ k) * self.scale + if self.rel_pos is not None: + attn = self.rel_pos(attn) + elif shared_rel_pos is not None: + attn = attn + shared_rel_pos + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class AttentionCl(nn.Module): + """ Channels-last multi-head attention (B, ..., C) """ + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + dim_head: int = 32, + bias: bool = True, + expand_first: bool = True, + rel_pos_cls: Callable = None, + attn_drop: float = 0., + proj_drop: float = 0. + ): + super().__init__() + dim_out = dim_out or dim + dim_attn = dim_out if expand_first and dim_out > dim else dim + assert dim_attn % dim_head == 0, 'attn dim should be divisible by head_dim' + self.num_heads = dim_attn // dim_head + self.dim_head = dim_head + self.scale = dim_head ** -0.5 + + self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias) + self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim_attn, dim_out, bias=bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + B = x.shape[0] + restore_shape = x.shape[:-1] + + q, k, v = self.qkv(x).view(B, -1, self.num_heads, self.dim_head * 3).transpose(1, 2).chunk(3, dim=3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + if self.rel_pos is not None: + attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos) + elif shared_rel_pos is not None: + attn = attn + shared_rel_pos + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(restore_shape + (-1,)) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + gamma = self.gamma + return x.mul_(gamma) if self.inplace else x * gamma + + +class LayerScale2d(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + gamma = self.gamma.view(1, -1, 1, 1) + return x.mul_(gamma) if self.inplace else x * gamma + + +class Downsample2d(nn.Module): + """ A downsample pooling module for Coat that handles 2d <-> 1d conversion + """ + + def __init__( + self, + dim: int, + dim_out: int, + pool_type: str = 'avg2', + bias: bool = True, + ): + super().__init__() + assert pool_type in ('max', 'avg', 'avg2') + if pool_type == 'max': + self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + elif pool_type == 'avg': + self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False) + else: + self.pool = nn.AvgPool2d(2) + + if dim != dim_out: + self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) + else: + self.expand = nn.Identity() + + def forward(self, x): + x = self.pool(x) # spatial downsample + x = self.expand(x) # expand chs + return x + + +def _init_transformer(module, name, scheme=''): + if isinstance(module, (nn.Conv2d, nn.Linear)): + if scheme == 'normal': + nn.init.normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 'trunc_normal': + trunc_normal_tf_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 'xavier_normal': + nn.init.xavier_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + else: + # vit like + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + + +class TransformerBlock2d(nn.Module): + """ Transformer block with 2D downsampling + '2D' NCHW tensor layout + """ + + def __init__( + self, + dim: int, + dim_out: int, + stride: int = 1, + rel_pos_cls: Callable = None, + cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path: float = 0., + ): + super().__init__() + norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) + act_layer = get_act_layer(cfg.act_layer) + + if stride == 2: + self.shortcut = Downsample2d(dim, dim_out, pool_type=cfg.pool_type, bias=cfg.shortcut_bias) + self.norm1 = nn.Sequential(OrderedDict([ + ('norm', norm_layer(dim)), + ('down', Downsample2d(dim, dim, pool_type=cfg.pool_type)), + ])) + else: + assert dim == dim_out + self.shortcut = nn.Identity() + self.norm1 = norm_layer(dim) + + self.attn = Attention2d( + dim, + dim_out, + dim_head=cfg.dim_head, + expand_first=cfg.expand_first, + bias=cfg.attn_bias, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop + ) + self.ls1 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = ConvMlp( + in_features=dim_out, + hidden_features=int(dim_out * cfg.expand_ratio), + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def init_weights(self, scheme=''): + named_apply(partial(_init_transformer, scheme=scheme), self) + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + x = self.shortcut(x) + self.drop_path1(self.ls1(self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +def _init_conv(module, name, scheme=''): + if isinstance(module, nn.Conv2d): + if scheme == 'normal': + nn.init.normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 'trunc_normal': + trunc_normal_tf_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 'xavier_normal': + nn.init.xavier_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + else: + # efficientnet like + fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels + fan_out //= module.groups + nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out)) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def num_groups(group_size, channels): + if not group_size: # 0 or None + return 1 # normal conv with 1 group + else: + # NOTE group_size == 1 -> depthwise conv + assert channels % group_size == 0 + return channels // group_size + + +class MbConvBlock(nn.Module): + """ Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand) + """ + def __init__( + self, + in_chs: int, + out_chs: int, + stride: int = 1, + dilation: Tuple[int, int] = (1, 1), + cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + drop_path: float = 0. + ): + super(MbConvBlock, self).__init__() + norm_act_layer = partial(get_norm_act_layer(cfg.norm_layer, cfg.act_layer), eps=cfg.norm_eps) + mid_chs = make_divisible((out_chs if cfg.expand_output else in_chs) * cfg.expand_ratio) + groups = num_groups(cfg.group_size, mid_chs) + + if stride == 2: + self.shortcut = Downsample2d(in_chs, out_chs, pool_type=cfg.pool_type, bias=cfg.output_bias) + else: + self.shortcut = nn.Identity() + + assert cfg.stride_mode in ('pool', '1x1', 'dw') + stride_pool, stride_1, stride_2 = 1, 1, 1 + if cfg.stride_mode == 'pool': + # NOTE this is not described in paper, experiment to find faster option that doesn't stride in 1x1 + stride_pool, dilation_2 = stride, dilation[1] + # FIXME handle dilation of avg pool + elif cfg.stride_mode == '1x1': + # NOTE I don't like this option described in paper, 1x1 w/ stride throws info away + stride_1, dilation_2 = stride, dilation[1] + else: + stride_2, dilation_2 = stride, dilation[0] + + self.pre_norm = norm_act_layer(in_chs, apply_act=cfg.pre_norm_act) + if stride_pool > 1: + self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type) + else: + self.down = nn.Identity() + self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=stride_1) + self.norm1 = norm_act_layer(mid_chs) + + self.conv2_kxk = create_conv2d( + mid_chs, mid_chs, cfg.kernel_size, stride=stride_2, dilation=dilation_2, groups=groups) + + attn_kwargs = {} + if isinstance(cfg.attn_layer, str): + if cfg.attn_layer == 'se' or cfg.attn_layer == 'eca': + attn_kwargs['act_layer'] = cfg.attn_act_layer + attn_kwargs['rd_channels'] = int(cfg.attn_ratio * (out_chs if cfg.expand_output else mid_chs)) + + # two different orderings for SE and norm2 (due to some weights and trials using SE before norm2) + if cfg.attn_early: + self.se_early = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs) + self.norm2 = norm_act_layer(mid_chs) + self.se = None + else: + self.se_early = None + self.norm2 = norm_act_layer(mid_chs) + self.se = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs) + + self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=cfg.output_bias) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def init_weights(self, scheme=''): + named_apply(partial(_init_conv, scheme=scheme), self) + + def forward(self, x): + shortcut = self.shortcut(x) + x = self.pre_norm(x) + x = self.down(x) + + # 1x1 expansion conv & norm-act + x = self.conv1_1x1(x) + x = self.norm1(x) + + # depthwise / grouped 3x3 conv w/ SE (or other) channel attention & norm-act + x = self.conv2_kxk(x) + if self.se_early is not None: + x = self.se_early(x) + x = self.norm2(x) + if self.se is not None: + x = self.se(x) + + # 1x1 linear projection to output width + x = self.conv3_1x1(x) + x = self.drop_path(x) + shortcut + return x + + +class ConvNeXtBlock(nn.Module): + """ ConvNeXt Block + """ + + def __init__( + self, + in_chs: int, + out_chs: Optional[int] = None, + kernel_size: int = 7, + stride: int = 1, + dilation: Tuple[int, int] = (1, 1), + cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + conv_mlp: bool = True, + drop_path: float = 0. + ): + super().__init__() + out_chs = out_chs or in_chs + act_layer = get_act_layer(cfg.act_layer) + if conv_mlp: + norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) + mlp_layer = ConvMlp + else: + assert 'layernorm' in cfg.norm_layer + norm_layer = LayerNorm + mlp_layer = Mlp + self.use_conv_mlp = conv_mlp + + if stride == 2: + self.shortcut = Downsample2d(in_chs, out_chs) + elif in_chs != out_chs: + self.shortcut = nn.Conv2d(in_chs, out_chs, kernel_size=1, bias=cfg.output_bias) + else: + self.shortcut = nn.Identity() + + assert cfg.stride_mode in ('pool', 'dw') + stride_pool, stride_dw = 1, 1 + # FIXME handle dilation? + if cfg.stride_mode == 'pool': + stride_pool = stride + else: + stride_dw = stride + + if stride_pool == 2: + self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type) + else: + self.down = nn.Identity() + + self.conv_dw = create_conv2d( + in_chs, out_chs, kernel_size=kernel_size, stride=stride_dw, dilation=dilation[1], + depthwise=True, bias=cfg.output_bias) + self.norm = norm_layer(out_chs) + self.mlp = mlp_layer(out_chs, int(cfg.expand_ratio * out_chs), bias=cfg.output_bias, act_layer=act_layer) + if conv_mlp: + self.ls = LayerScale2d(out_chs, cfg.init_values) if cfg.init_values else nn.Identity() + else: + self.ls = LayerScale(out_chs, cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = self.shortcut(x) + x = self.down(x) + x = self.conv_dw(x) + if self.use_conv_mlp: + x = self.norm(x) + x = self.mlp(x) + x = self.ls(x) + else: + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = self.mlp(x) + x = self.ls(x) + x = x.permute(0, 3, 1, 2) + + x = self.drop_path(x) + shortcut + return x + + +def window_partition(x, window_size: List[int]): + B, H, W, C = x.shape + _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})') + _assert(W % window_size[1] == 0, '') + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse(windows, window_size: List[int], img_size: List[int]): + H, W = img_size + C = windows.shape[-1] + x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) + return x + + +def grid_partition(x, grid_size: List[int]): + B, H, W, C = x.shape + _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}') + _assert(W % grid_size[1] == 0, '') + x = x.view(B, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1], C) + windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, grid_size[0], grid_size[1], C) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def grid_reverse(windows, grid_size: List[int], img_size: List[int]): + H, W = img_size + C = windows.shape[-1] + x = windows.view(-1, H // grid_size[0], W // grid_size[1], grid_size[0], grid_size[1], C) + x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, H, W, C) + return x + + +def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size): + rel_pos_cls = None + if cfg.rel_pos_type == 'mlp': + rel_pos_cls = partial(RelPosMlp, window_size=window_size, hidden_dim=cfg.rel_pos_dim) + elif cfg.rel_pos_type == 'bias': + rel_pos_cls = partial(RelPosBias, window_size=window_size) + return rel_pos_cls + + +class PartitionAttention(nn.Module): + """ Grid or Block partition + Attn + FFN. + NxC tensor layout. + """ + + def __init__( + self, + dim: int, + partition_type: str = 'block', + cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path: float = 0., + ): + super().__init__() + norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) # NOTE this block is channels-last + act_layer = get_act_layer(cfg.act_layer) + + self.partition_block = partition_type == 'block' + self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size) + rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) + + self.norm1 = norm_layer(dim) + self.attn = AttentionCl( + dim, + dim, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * cfg.expand_ratio), + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def _partition_attn(self, x): + C = x.shape[-1] + img_size = x.shape[1:3] + if self.partition_block: + partitioned = window_partition(x, self.partition_size) + else: + partitioned = grid_partition(x, self.partition_size) + + partitioned = self.attn(partitioned) + + if self.partition_block: + x = window_reverse(partitioned, self.partition_size, img_size) + else: + x = grid_reverse(partitioned, self.partition_size, img_size) + return x + + def forward(self, x): + x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class CombinedPartitionAttention(nn.Module): + """ Experimental. Grid and Block partition + single FFN + NxC tensor layout. + """ + + def __init__( + self, + dim: int, + cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path: float = 0., + ): + super().__init__() + assert dim % 2 == 0 + norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) # NOTE this block is channels-last + act_layer = get_act_layer(cfg.act_layer) + + assert cfg.window_size == cfg.grid_size + self.partition_size = to_2tuple(cfg.window_size) + rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) + + self.norm1 = norm_layer(dim) + self.attn_block = AttentionCl( + dim, + dim // 2, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.attn_grid = AttentionCl( + dim, + dim // 2, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * cfg.expand_ratio), + out_features=dim, + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def _partition_attn(self, x): + img_size = x.shape[1:3] + + partitioned_block = window_partition(x, self.partition_size) + partitioned_block = self.attn_block(partitioned_block) + x_window = window_reverse(partitioned_block, self.partition_size, img_size) + + partitioned_grid = grid_partition(x, self.partition_size) + partitioned_grid = self.attn_grid(partitioned_grid) + x_grid = grid_reverse(partitioned_grid, self.partition_size, img_size) + + return torch.cat([x_window, x_grid], dim=-1) + + def forward(self, x): + x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +def window_partition_nchw(x, window_size: List[int]): + B, C, H, W = x.shape + _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})') + _assert(W % window_size[1] == 0, '') + x = x.view(B, C, H // window_size[0], window_size[0], W // window_size[1], window_size[1]) + windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size[0], window_size[1]) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse_nchw(windows, window_size: List[int], img_size: List[int]): + H, W = img_size + C = windows.shape[1] + x = windows.view(-1, H // window_size[0], W // window_size[1], C, window_size[0], window_size[1]) + x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, C, H, W) + return x + + +def grid_partition_nchw(x, grid_size: List[int]): + B, C, H, W = x.shape + _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}') + _assert(W % grid_size[1] == 0, '') + x = x.view(B, C, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1]) + windows = x.permute(0, 3, 5, 1, 2, 4).contiguous().view(-1, C, grid_size[0], grid_size[1]) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def grid_reverse_nchw(windows, grid_size: List[int], img_size: List[int]): + H, W = img_size + C = windows.shape[1] + x = windows.view(-1, H // grid_size[0], W // grid_size[1], C, grid_size[0], grid_size[1]) + x = x.permute(0, 3, 4, 1, 5, 2).contiguous().view(-1, C, H, W) + return x + + +class PartitionAttention2d(nn.Module): + """ Grid or Block partition + Attn + FFN + '2D' NCHW tensor layout. + """ + + def __init__( + self, + dim: int, + partition_type: str = 'block', + cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path: float = 0., + ): + super().__init__() + norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) # NOTE this block is channels-last + act_layer = get_act_layer(cfg.act_layer) + + self.partition_block = partition_type == 'block' + self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size) + rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) + + self.norm1 = norm_layer(dim) + self.attn = Attention2d( + dim, + dim, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.ls1 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = ConvMlp( + in_features=dim, + hidden_features=int(dim * cfg.expand_ratio), + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def _partition_attn(self, x): + img_size = x.shape[-2:] + if self.partition_block: + partitioned = window_partition_nchw(x, self.partition_size) + else: + partitioned = grid_partition_nchw(x, self.partition_size) + + partitioned = self.attn(partitioned) + + if self.partition_block: + x = window_reverse_nchw(partitioned, self.partition_size, img_size) + else: + x = grid_reverse_nchw(partitioned, self.partition_size, img_size) + return x + + def forward(self, x): + x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class MaxxVitBlock(nn.Module): + """ + """ + + def __init__( + self, + dim: int, + dim_out: int, + stride: int = 1, + conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + use_nchw_attn: bool = False, # FIXME move to cfg? True is ~20-30% faster on TPU, 5-10% slower on GPU + drop_path: float = 0., + ): + super().__init__() + + conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock + self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) + + attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) + partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttention + self.nchw_attn = use_nchw_attn + self.attn_block = partition_layer(**attn_kwargs) + self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs) + + def init_weights(self, scheme=''): + named_apply(partial(_init_transformer, scheme=scheme), self.attn_block) + named_apply(partial(_init_transformer, scheme=scheme), self.attn_grid) + named_apply(partial(_init_conv, scheme=scheme), self.conv) + + def forward(self, x): + # NCHW format + x = self.conv(x) + + if not self.nchw_attn: + x = x.permute(0, 2, 3, 1) # to NHWC (channels-last) + x = self.attn_block(x) + x = self.attn_grid(x) + if not self.nchw_attn: + x = x.permute(0, 3, 1, 2) # back to NCHW + return x + + +class CombinedMaxxVitBlock(nn.Module): + """ + """ + + def __init__( + self, + dim, + dim_out, + stride=1, + num_conv=2, + conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path=0., + ): + super().__init__() + + conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock + if num_conv > 1: + convs = [conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)] + convs += [conv_cls(dim_out, dim_out, cfg=conv_cfg, drop_path=drop_path)] * (num_conv - 1) + self.conv = nn.Sequential(*convs) + else: + self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) + self.attn = CombinedPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) + + def init_weights(self, scheme=''): + named_apply(partial(_init_transformer, scheme=scheme), self.attn) + named_apply(partial(_init_conv, scheme=scheme), self.conv) + + def forward(self, x): + x = self.conv(x) + x = x.permute(0, 2, 3, 1) + x = self.attn(x) + x = x.permute(0, 3, 1, 2) + return x + + +class MaxxVitStage(nn.Module): + def __init__( + self, + in_chs: int, + out_chs: int, + stride: int = 2, + depth: int = 4, + feat_size: Tuple[int, int] = (14, 14), + block_types: Union[str, Tuple[str]] = 'C', + transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + drop_path: Union[float, List[float]] = 0., + ): + super().__init__() + self.grad_checkpointing = False + + block_types = extend_tuple(block_types, depth) + blocks = [] + for i, t in enumerate(block_types): + block_stride = stride if i == 0 else 1 + assert t in ('C', 'T', 'M', 'CM') + if t == 'C': + conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock + blocks += [conv_cls( + in_chs, + out_chs, + stride=block_stride, + cfg=conv_cfg, + drop_path=drop_path[i], + )] + elif t == 'T': + rel_pos_cls = get_rel_pos_cls(transformer_cfg, feat_size) + blocks += [TransformerBlock2d( + in_chs, + out_chs, + stride=block_stride, + rel_pos_cls=rel_pos_cls, + cfg=transformer_cfg, + drop_path=drop_path[i], + )] + elif t == 'M': + blocks += [MaxxVitBlock( + in_chs, + out_chs, + stride=block_stride, + conv_cfg=conv_cfg, + transformer_cfg=transformer_cfg, + drop_path=drop_path[i], + )] + elif t == 'CM': + blocks += [CombinedMaxxVitBlock( + in_chs, + out_chs, + stride=block_stride, + conv_cfg=conv_cfg, + transformer_cfg=transformer_cfg, + drop_path=drop_path[i], + )] + in_chs = out_chs + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class Stem(nn.Module): + + def __init__( + self, + in_chs: int, + out_chs: int, + kernel_size: int = 3, + act_layer: str = 'gelu', + norm_layer: str = 'batchnorm2d', + norm_eps: float = 1e-5, + ): + super().__init__() + if not isinstance(out_chs, (list, tuple)): + out_chs = to_2tuple(out_chs) + + norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps) + self.out_chs = out_chs[-1] + self.stride = 2 + + self.conv1 = create_conv2d(in_chs, out_chs[0], kernel_size, stride=2) + self.norm1 = norm_act_layer(out_chs[0]) + self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1) + + def init_weights(self, scheme=''): + named_apply(partial(_init_conv, scheme=scheme), self) + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.conv2(x) + return x + + +class MaxxVit(nn.Module): + """ + """ + + def __init__( + self, + cfg: MaxxVitCfg, + img_size: Union[int, Tuple[int, int]] = 224, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + drop_rate: float = 0., + drop_path_rate: float = 0. + ): + super().__init__() + img_size = to_2tuple(img_size) + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = cfg.embed_dim[-1] + self.embed_dim = cfg.embed_dim + self.drop_rate = drop_rate + self.grad_checkpointing = False + + self.stem = Stem( + in_chs=in_chans, + out_chs=cfg.stem_width, + act_layer=cfg.conv_cfg.act_layer, + norm_layer=cfg.conv_cfg.norm_layer, + norm_eps=cfg.conv_cfg.norm_eps, + ) + + stride = self.stem.stride + feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))]) + + num_stages = len(cfg.embed_dim) + assert len(cfg.depths) == num_stages + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] + in_chs = self.stem.out_chs + stages = [] + for i in range(num_stages): + stage_stride = 2 + out_chs = cfg.embed_dim[i] + feat_size = tuple([(r - 1) // stage_stride + 1 for r in feat_size]) + stages += [MaxxVitStage( + in_chs, + out_chs, + depth=cfg.depths[i], + block_types=cfg.block_type[i], + conv_cfg=cfg.conv_cfg, + transformer_cfg=cfg.transformer_cfg, + feat_size=feat_size, + drop_path=dpr[i], + )] + stride *= stage_stride + in_chs = out_chs + self.stages = nn.Sequential(*stages) + + final_norm_layer = get_norm_layer(cfg.transformer_cfg.norm_layer) + self.norm = final_norm_layer(self.num_features, eps=cfg.transformer_cfg.norm_eps) + + # Classifier head + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + + # Weight init (default PyTorch init works well for AdamW if scheme not set) + assert cfg.weight_init in ('', 'normal', 'trunc_normal', 'xavier_normal', 'vit_eff') + if cfg.weight_init: + named_apply(partial(self._init_weights, scheme=cfg.weight_init), self) + + def _init_weights(self, module, name, scheme=''): + if hasattr(module, 'init_weights'): + try: + module.init_weights(scheme=scheme) + except TypeError: + module.init_weights() + + @torch.jit.ignore + def no_weight_decay(self): + return { + k for k, _ in self.named_parameters() + if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', # stem and embed + blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is None: + global_pool = self.head.global_pool.pool_type + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_coatnet(variant, cfg_variant=None, pretrained=False, **kwargs): + return build_model_with_cfg( + MaxxVit, variant, pretrained, + model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], + feature_cfg=dict(flatten_sequential=True), + **kwargs) + + +@register_model +def coatnet_pico_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_pico_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_nano_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_nano_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_0_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_0_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_1_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_1_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_2_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_2_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_bn_0_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_bn_0_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_rmlp_nano_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_rmlp_nano_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_rmlp_0_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_rmlp_0_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_rmlp_1_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_rmlp_1_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_nano_cc_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_nano_cc_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnext_nano_rw_224(pretrained=False, **kwargs): + return _create_coatnet('coatnext_nano_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_0_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_0_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_1_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_1_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_2_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_2_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_3_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_3_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_4_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_4_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_5_224(pretrained=False, **kwargs): + return _create_coatnet('coatnet_5_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_pico_rw_256(pretrained=False, **kwargs): + return _create_coatnet('maxvit_pico_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_nano_rw_256(pretrained=False, **kwargs): + return _create_coatnet('maxvit_nano_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_rw_224(pretrained=False, **kwargs): + return _create_coatnet('maxvit_tiny_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_rw_256(pretrained=False, **kwargs): + return _create_coatnet('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_cm_256(pretrained=False, **kwargs): + return _create_coatnet('maxvit_tiny_cm_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxxvit_nano_rw_256(pretrained=False, **kwargs): + return _create_coatnet('maxxvit_nano_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_224(pretrained=False, **kwargs): + return _create_coatnet('maxvit_tiny_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_small_224(pretrained=False, **kwargs): + return _create_coatnet('maxvit_small_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_base_224(pretrained=False, **kwargs): + return _create_coatnet('maxvit_base_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_large_224(pretrained=False, **kwargs): + return _create_coatnet('maxvit_large_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_xlarge_224(pretrained=False, **kwargs): + return _create_coatnet('maxvit_xlarge_224', pretrained=pretrained, **kwargs) \ No newline at end of file