diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index e44fb40c..a178cd62 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -2,4 +2,4 @@ blank_issues_enabled: false contact_links: - name: Community Discussions url: https://github.com/rwightman/pytorch-image-models/discussions - about: Issues are for features and bugs. Questions can be asked in Discussions. + about: Hparam request in issues will be ignored! Issues are for features and bugs. Questions can be asked in Discussions. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index d751adc8..ea51beb7 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -1,8 +1,7 @@ --- name: Feature request -about: Suggest an idea for this project. Issues are for reporting bugs or requesting - features, the discussion forum is available for asking questions or seeking help - from the community. +about: Suggest an idea for this project. Hparam requests, training help are not feature requests. + The discussion forum is available for asking questions or seeking help from the community. title: "[FEATURE] Feature title..." labels: enhancement assignees: '' diff --git a/README.md b/README.md index 019fdae2..b5945b19 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,30 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before ## What's New +### Aug 29, 2022 +* MaxVit window size scales with img_size by default. Add new RelPosMlp MaxViT weight that leverages this: + * `maxvit_rmlp_nano_rw_256` - 83.0 @ 256, 83.6 @ 320 (T) + +### Aug 26, 2022 +* CoAtNet (https://arxiv.org/abs/2106.04803) and MaxVit (https://arxiv.org/abs/2204.01697) `timm` original models + * both found in [`maxxvit.py`](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/maxxvit.py) model def, contains numerous experiments outside scope of original papers + * an unfinished Tensorflow version from MaxVit authors can be found https://github.com/google-research/maxvit +* Initial CoAtNet and MaxVit timm pretrained weights (working on more): + * `coatnet_nano_rw_224` - 81.7 @ 224 (T) + * `coatnet_rmlp_nano_rw_224` - 82.0 @ 224, 82.8 @ 320 (T) + * `coatnet_0_rw_224` - 82.4 (T) -- NOTE timm '0' coatnets have 2 more 3rd stage blocks + * `coatnet_bn_0_rw_224` - 82.4 (T) + * `maxvit_nano_rw_256` - 82.9 @ 256 (T) + * `coatnet_rmlp_1_rw_224` - 83.4 @ 224, 84 @ 320 (T) + * `coatnet_1_rw_224` - 83.6 @ 224 (G) + * (T) = TPU trained with `bits_and_tpu` branch training code, (G) = GPU trained +* GCVit (weights adapted from https://github.com/NVlabs/GCVit, code 100% `timm` re-write for license purposes) +* MViT-V2 (multi-scale vit, adapted from https://github.com/facebookresearch/mvit) +* EfficientFormer (adapted from https://github.com/snap-research/EfficientFormer) +* PyramidVisionTransformer-V2 (adapted from https://github.com/whai362/PVT) +* 'Fast Norm' support for LayerNorm and GroupNorm that avoids float32 upcast w/ AMP (uses APEX LN if available for further boost) + + ### Aug 15, 2022 * ConvNeXt atto weights added * `convnext_atto` - 75.7 @ 224, 77.0 @ 288 @@ -229,6 +253,7 @@ A full version of the list below with source links can be found in the [document * Bottleneck Transformers - https://arxiv.org/abs/2101.11605 * CaiT (Class-Attention in Image Transformers) - https://arxiv.org/abs/2103.17239 * CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399 +* CoAtNet (Convolution and Attention) - https://arxiv.org/abs/2106.04803 * ConvNeXt - https://arxiv.org/abs/2201.03545 * ConViT (Soft Convolutional Inductive Biases Vision Transformers)- https://arxiv.org/abs/2103.10697 * CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929 @@ -238,6 +263,7 @@ A full version of the list below with source links can be found in the [document * DLA - https://arxiv.org/abs/1707.06484 * DPN (Dual-Path Network) - https://arxiv.org/abs/1707.01629 * EdgeNeXt - https://arxiv.org/abs/2206.10589 +* EfficientFormer - https://arxiv.org/abs/2206.01191 * EfficientNet (MBConvNet Family) * EfficientNet NoisyStudent (B0-B7, L2) - https://arxiv.org/abs/1911.04252 * EfficientNet AdvProp (B0-B8) - https://arxiv.org/abs/1911.09665 @@ -250,6 +276,7 @@ A full version of the list below with source links can be found in the [document * MobileNet-V2 - https://arxiv.org/abs/1801.04381 * Single-Path NAS - https://arxiv.org/abs/1904.02877 * TinyNet - https://arxiv.org/abs/2010.14819 +* GCViT (Global Context Vision Transformer) - https://arxiv.org/abs/2206.09959 * GhostNet - https://arxiv.org/abs/1911.11907 * gMLP - https://arxiv.org/abs/2105.08050 * GPU-Efficient Networks - https://arxiv.org/abs/2006.14090 @@ -259,6 +286,7 @@ A full version of the list below with source links can be found in the [document * Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261 * Lambda Networks - https://arxiv.org/abs/2102.08602 * LeViT (Vision Transformer in ConvNet's Clothing) - https://arxiv.org/abs/2104.01136 +* MaxViT (Multi-Axis Vision Transformer) - https://arxiv.org/abs/2204.01697 * MLP-Mixer - https://arxiv.org/abs/2105.01601 * MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244 * FBNet-V3 - https://arxiv.org/abs/2006.02049 @@ -266,6 +294,7 @@ A full version of the list below with source links can be found in the [document * LCNet - https://arxiv.org/abs/2109.15099 * MobileViT - https://arxiv.org/abs/2110.02178 * MobileViT-V2 - https://arxiv.org/abs/2206.02680 +* MViT-V2 (Improved Multiscale Vision Transformer) - https://arxiv.org/abs/2112.01526 * NASNet-A - https://arxiv.org/abs/1707.07012 * NesT - https://arxiv.org/abs/2105.12723 * NFNet-F - https://arxiv.org/abs/2102.06171 @@ -273,6 +302,7 @@ A full version of the list below with source links can be found in the [document * PNasNet - https://arxiv.org/abs/1712.00559 * PoolFormer (MetaFormer) - https://arxiv.org/abs/2111.11418 * Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302 +* PVT-V2 (Improved Pyramid Vision Transformer) - https://arxiv.org/abs/2106.13797 * RegNet - https://arxiv.org/abs/2003.13678 * RegNetZ - https://arxiv.org/abs/2103.06877 * RepVGG - https://arxiv.org/abs/2101.03697 diff --git a/benchmark.py b/benchmark.py index 4679a009..4a89441b 100755 --- a/benchmark.py +++ b/benchmark.py @@ -19,7 +19,7 @@ import torch.nn as nn import torch.nn.parallel from timm.data import resolve_data_config -from timm.models import create_model, is_model, list_models +from timm.models import create_model, is_model, list_models, set_fast_norm from timm.optim import create_optimizer_v2 from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry @@ -109,7 +109,8 @@ scripting_group.add_argument('--torchscript', dest='torchscript', action='store_ help='convert model torchscript for inference') scripting_group.add_argument('--aot-autograd', default=False, action='store_true', help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)") - +scripting_group.add_argument('--fast-norm', default=False, action='store_true', + help='enable experimental fast-norm') # train optimizer parameters parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', @@ -598,6 +599,9 @@ def main(): model_cfgs = [] model_names = [] + if args.fast_norm: + set_fast_norm() + if args.model_list: args.model = '' with open(args.model_list) as f: diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 51a38d0c..5ff79595 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -69,5 +69,6 @@ from .helpers import load_checkpoint, resume_checkpoint, model_parameters from .layers import TestTimePoolHead, apply_test_time_pool from .layers import convert_splitbn_model, convert_sync_batchnorm from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit +from .layers import set_fast_norm from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\ is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value diff --git a/timm/models/densenet.py b/timm/models/densenet.py index a46b86ad..1afdfd7b 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -115,7 +115,7 @@ class DenseBlock(nn.ModuleDict): _version = 2 def __init__( - self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU, + self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=BatchNormAct2d, drop_rate=0., memory_efficient=False): super(DenseBlock, self).__init__() for i in range(num_layers): @@ -138,7 +138,7 @@ class DenseBlock(nn.ModuleDict): class DenseTransition(nn.Sequential): - def __init__(self, num_input_features, num_output_features, norm_layer=nn.BatchNorm2d, aa_layer=None): + def __init__(self, num_input_features, num_output_features, norm_layer=BatchNormAct2d, aa_layer=None): super(DenseTransition, self).__init__() self.add_module('norm', norm_layer(num_input_features)) self.add_module('conv', nn.Conv2d( diff --git a/timm/models/layers/fast_norm.py b/timm/models/layers/fast_norm.py index 9a34a15e..fb35e47d 100644 --- a/timm/models/layers/fast_norm.py +++ b/timm/models/layers/fast_norm.py @@ -1,3 +1,11 @@ +""" 'Fast' Normalization Functions + +For GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32. + +Additionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast) + +Hacked together by / Copyright 2022 Ross Wightman +""" from typing import List, Optional import torch @@ -37,6 +45,7 @@ def fast_group_norm( if torch.is_autocast_enabled(): # normally native AMP casts GN inputs to float32 # here we use the low precision autocast dtype + # FIXME what to do re CPU autocast? dt = torch.get_autocast_gpu_dtype() x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) @@ -62,6 +71,7 @@ def fast_layer_norm( # normally native AMP casts LN inputs to float32 # apex LN does not, this is behaving like Apex dt = torch.get_autocast_gpu_dtype() + # FIXME what to do re CPU autocast? x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) with torch.cuda.amp.autocast(enabled=False): diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index 2ff8fc08..42445a49 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -1,4 +1,8 @@ """ Normalization layers and wrappers + +Norm layer definitions that support fast norm and consistent channel arg order (always first arg). + +Hacked together by / Copyright 2022 Ross Wightman """ import torch diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py index 3cd9fb36..e5bd0e78 100644 --- a/timm/models/layers/norm_act.py +++ b/timm/models/layers/norm_act.py @@ -1,4 +1,16 @@ """ Normalization + Activation Layers + +Provides Norm+Act fns for standard PyTorch norm layers such as +* BatchNorm +* GroupNorm +* LayerNorm + +This allows swapping with alternative layers that are natively both norm + act such as +* EvoNorm (evo_norm.py) +* FilterResponseNorm (filter_response_norm.py) +* InplaceABN (inplace_abn.py) + +Hacked together by / Copyright 2022 Ross Wightman """ from typing import Union, List, Optional, Any diff --git a/timm/models/layers/weight_init.py b/timm/models/layers/weight_init.py index 4a160931..943e4f4c 100644 --- a/timm/models/layers/weight_init.py +++ b/timm/models/layers/weight_init.py @@ -5,7 +5,7 @@ import warnings from torch.nn.init import _calculate_fan_in_and_fan_out -def _no_grad_trunc_normal_(tensor, mean, std, a, b): +def _trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x): @@ -17,28 +17,27 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b): "The distribution of values may be incorrect.", stacklevel=2) - with torch.no_grad(): - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.)) - tensor.add_(mean) + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - return tensor + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): @@ -64,7 +63,8 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): >>> w = torch.empty(3, 5) >>> nn.init.trunc_normal_(w) """ - return _no_grad_trunc_normal_(tensor, mean, std, a, b) + with torch.no_grad(): + return _trunc_normal_(tensor, mean, std, a, b) def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.): @@ -90,8 +90,8 @@ def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.): >>> w = torch.empty(3, 5) >>> nn.init.trunc_normal_(w) """ - _no_grad_trunc_normal_(tensor, 0, 1.0, a, b) with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) tensor.mul_(std).add_(mean) return tensor @@ -111,10 +111,12 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): # constant is stddev of standard normal truncated to (-2, 2) trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978) elif distribution == "normal": - tensor.normal_(std=math.sqrt(variance)) + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) elif distribution == "uniform": bound = math.sqrt(3 * variance) - tensor.uniform_(-bound, bound) + with torch.no_grad(): + tensor.uniform_(-bound, bound) else: raise ValueError(f"invalid distribution {distribution}") diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 898e1685..495f682b 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -39,7 +39,7 @@ Hacked together by / Copyright 2022, Ross Wightman import math from collections import OrderedDict -from dataclasses import dataclass +from dataclasses import dataclass, replace from functools import partial from typing import Callable, Optional, Union, Tuple, List @@ -108,10 +108,13 @@ default_cfgs = { # Experimental configs 'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), 'maxvit_nano_rw_256': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-3e790ce3.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth', input_size=(3, 256, 256), pool_size=(8, 8)), 'maxvit_tiny_rw_224': _cfg(url=''), 'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_rmlp_nano_rw_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), 'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), 'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), @@ -129,21 +132,30 @@ class MaxxVitTransformerCfg: dim_head: int = 32 expand_ratio: float = 4.0 expand_first: bool = True - shortcut_bias: bool = True, + shortcut_bias: bool = True attn_bias: bool = True attn_drop: float = 0. proj_drop: float = 0. pool_type: str = 'avg2' rel_pos_type: str = 'bias' rel_pos_dim: int = 512 # for relative position types w/ MLP - window_size: Tuple[int, int] = (7, 7) - grid_size: Tuple[int, int] = (7, 7) + partition_stride: int = 32 + window_size: Optional[Tuple[int, int]] = None + grid_size: Optional[Tuple[int, int]] = None init_values: Optional[float] = None act_layer: str = 'gelu' norm_layer: str = 'layernorm2d' norm_layer_cl: str = 'layernorm' norm_eps: float = 1e-6 + def __post_init__(self): + if self.grid_size is not None: + self.grid_size = to_2tuple(self.grid_size) + if self.window_size is not None: + self.window_size = to_2tuple(self.window_size) + if self.grid_size is None: + self.grid_size = self.window_size + @dataclass class MaxxVitConvCfg: @@ -249,7 +261,7 @@ def _rw_max_cfg( conv_norm_layer='', transformer_norm_layer='layernorm2d', transformer_norm_layer_cl='layernorm', - window_size=7, + window_size=None, dim_head=32, rel_pos_type='bias', rel_pos_dim=512, @@ -259,8 +271,6 @@ def _rw_max_cfg( # - mbconv expansion calculated from input instead of output chs # - mbconv shortcut and final 1x1 conv did not have a bias # - mbconv uses silu in timm, not gelu - # - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat) - # - default to avg pool for mbconv downsample instead of 1x1 or dw conv # - expansion in attention block done via output proj, not input proj return dict( conv_cfg=MaxxVitConvCfg( @@ -276,8 +286,7 @@ def _rw_max_cfg( expand_first=False, pool_type=pool_type, dim_head=dim_head, - window_size=to_2tuple(window_size), - grid_size=to_2tuple(window_size), + window_size=window_size, norm_layer=transformer_norm_layer, norm_layer_cl=transformer_norm_layer_cl, rel_pos_type=rel_pos_type, @@ -293,7 +302,7 @@ def _next_cfg( conv_norm_layer_cl='layernorm', transformer_norm_layer='layernorm2d', transformer_norm_layer_cl='layernorm', - window_size=7, + window_size=None, rel_pos_type='bias', rel_pos_dim=512, ): @@ -310,8 +319,7 @@ def _next_cfg( transformer_cfg=MaxxVitTransformerCfg( expand_first=False, pool_type=pool_type, - window_size=to_2tuple(window_size), - grid_size=to_2tuple(window_size), + window_size=window_size, norm_layer=transformer_norm_layer, norm_layer_cl=transformer_norm_layer_cl, rel_pos_type=rel_pos_type, @@ -411,18 +419,19 @@ model_cfgs = dict( rel_pos_dim=384, # was supposed to be 512, woops ), ), - coatnext_nano_rw_224=MaxxVitCfg( + coatnet_nano_cc_224=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(3, 4, 6, 3), stem_width=(32, 64), - **_next_cfg(), + block_type=('C', 'C', ('C', 'T'), ('C', 'T')), + **_rw_coat_cfg(), ), - coatnet_nano_cc_224=MaxxVitCfg( + coatnext_nano_rw_224=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(3, 4, 6, 3), stem_width=(32, 64), - block_type=('C', 'C', ('C', 'T'), ('C', 'T')), - **_rw_coat_cfg(), + weight_init='normal', + **_next_cfg(), ), # Trying to be like the CoAtNet paper configs @@ -463,14 +472,14 @@ model_cfgs = dict( depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(24, 32), - **_rw_max_cfg(window_size=8), + **_rw_max_cfg(), ), maxvit_nano_rw_256=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(1, 2, 3, 1), block_type=('M',) * 4, stem_width=(32, 64), - **_rw_max_cfg(window_size=8), + **_rw_max_cfg(), ), maxvit_tiny_rw_224=MaxxVitCfg( embed_dim=(64, 128, 256, 512), @@ -484,21 +493,29 @@ model_cfgs = dict( depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(32, 64), - **_rw_max_cfg(window_size=8), + **_rw_max_cfg(), + ), + maxvit_rmlp_nano_rw_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(1, 2, 3, 1), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(rel_pos_type='mlp'), ), maxvit_tiny_pm_256=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(2, 2, 5, 2), block_type=('PM',) * 4, stem_width=(32, 64), - **_rw_max_cfg(window_size=8), + **_rw_max_cfg(), ), maxxvit_nano_rw_256=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(1, 2, 3, 1), block_type=('M',) * 4, stem_width=(32, 64), - **_next_cfg(window_size=8), + weight_init='normal', + **_next_cfg(), ), # Trying to be like the MaxViT paper configs @@ -651,7 +668,11 @@ class LayerScale2d(nn.Module): class Downsample2d(nn.Module): - """ A downsample pooling module for Coat that handles 2d <-> 1d conversion + """ A downsample pooling module supporting several maxpool and avgpool modes + * 'max' - MaxPool2d w/ kernel_size 3, stride 2, padding 1 + * 'max2' - MaxPool2d w/ kernel_size = stride = 2 + * 'avg' - AvgPool2d w/ kernel_size 3, stride 2, padding 1 + * 'avg2' - AvgPool2d w/ kernel_size = stride = 2 """ def __init__( @@ -710,6 +731,11 @@ def _init_transformer(module, name, scheme=''): class TransformerBlock2d(nn.Module): """ Transformer block with 2D downsampling '2D' NCHW tensor layout + + Some gains can be seen on GPU using a 1D / CL block, BUT w/ the need to switch back/forth to NCHW + for spatial pooling, the benefit is minimal so ended up using just this variant for CoAt configs. + + This impl was faster on TPU w/ PT XLA than the 1D experiment. """ def __init__( @@ -1011,9 +1037,9 @@ def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size): return rel_pos_cls -class PartitionAttention(nn.Module): +class PartitionAttentionCl(nn.Module): """ Grid or Block partition + Attn + FFN. - NxC tensor layout. + NxC 'channels last' tensor layout. """ def __init__( @@ -1183,6 +1209,7 @@ def grid_reverse_nchw(windows, grid_size: List[int], img_size: List[int]): class PartitionAttention2d(nn.Module): """ Grid or Block partition + Attn + FFN + '2D' NCHW tensor layout. """ @@ -1245,7 +1272,7 @@ class PartitionAttention2d(nn.Module): class MaxxVitBlock(nn.Module): - """ + """ MaxVit conv, window partition + FFN , grid partition + FFN """ def __init__( @@ -1264,7 +1291,7 @@ class MaxxVitBlock(nn.Module): self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) - partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttention + partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl self.nchw_attn = use_nchw_attn self.attn_block = partition_layer(**attn_kwargs) self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs) @@ -1288,7 +1315,8 @@ class MaxxVitBlock(nn.Module): class ParallelMaxxVitBlock(nn.Module): - """ + """ MaxVit block with parallel cat(window + grid), one FF + Experimental timm block. """ def __init__( @@ -1426,8 +1454,19 @@ class Stem(nn.Module): return x +def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]): + if cfg.window_size is not None: + assert cfg.grid_size + return cfg + partition_size = img_size[0] // cfg.partition_stride, img_size[1] // cfg.partition_stride + cfg = replace(cfg, window_size=partition_size, grid_size=partition_size) + return cfg + + class MaxxVit(nn.Module): - """ + """ CoaTNet + MaxVit base model. + + Highly configurable for different block compositions, tensor layouts, pooling types. """ def __init__( @@ -1442,6 +1481,7 @@ class MaxxVit(nn.Module): ): super().__init__() img_size = to_2tuple(img_size) + transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size) self.num_classes = num_classes self.global_pool = global_pool self.num_features = cfg.embed_dim[-1] @@ -1475,7 +1515,7 @@ class MaxxVit(nn.Module): depth=cfg.depths[i], block_types=cfg.block_type[i], conv_cfg=cfg.conv_cfg, - transformer_cfg=cfg.transformer_cfg, + transformer_cfg=transformer_cfg, feat_size=feat_size, drop_path=dpr[i], )] @@ -1658,6 +1698,11 @@ def maxvit_tiny_rw_256(pretrained=False, **kwargs): return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs) +@register_model +def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs) + + @register_model def maxvit_tiny_pm_256(pretrained=False, **kwargs): return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs) diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index 002225c6..c5aaa09e 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -57,6 +57,8 @@ default_cfgs = dict( mvitv2_huge_in21k=_cfg( url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_H_in21k.pyth', num_classes=19168), + + mvitv2_small_cls=_cfg(url=''), ) @@ -135,6 +137,11 @@ model_cfgs = dict( num_heads=2, expand_attn=False, ), + + mvitv2_small_cls=MultiScaleVitCfg( + depths=(1, 2, 11, 2), + use_cls_token=True, + ), ) @@ -641,7 +648,7 @@ class MultiScaleBlock(nn.Module): if self.shortcut_pool_attn is None: return x if self.has_cls_token: - cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :] + cls_tok, x = x[:, :1, :], x[:, 1:, :] else: cls_tok = None B, L, C = x.shape @@ -650,7 +657,7 @@ class MultiScaleBlock(nn.Module): x = self.shortcut_pool_attn(x) x = x.reshape(B, C, -1).transpose(1, 2) if cls_tok is not None: - x = torch.cat((cls_tok, x), dim=2) + x = torch.cat((cls_tok, x), dim=1) return x def forward(self, x, feat_size: List[int]): @@ -996,3 +1003,8 @@ def mvitv2_large(pretrained=False, **kwargs): # @register_model # def mvitv2_huge_in21k(pretrained=False, **kwargs): # return _create_mvitv2('mvitv2_huge_in21k', pretrained=pretrained, **kwargs) + + +@register_model +def mvitv2_small_cls(pretrained=False, **kwargs): + return _create_mvitv2('mvitv2_small_cls', pretrained=pretrained, **kwargs)