Merge remote-tracking branch 'origin/master' into bits_and_tpu

pull/804/merge^2
Ross Wightman 2 years ago
commit 87bfb055c0

@ -2,4 +2,4 @@ blank_issues_enabled: false
contact_links:
- name: Community Discussions
url: https://github.com/rwightman/pytorch-image-models/discussions
about: Issues are for features and bugs. Questions can be asked in Discussions.
about: Hparam requests in issues will be ignored! Issues are for features and bugs. Questions can be asked in Discussions.

@ -1,8 +1,7 @@
---
name: Feature request
about: Suggest an idea for this project. Issues are for reporting bugs or requesting
features, the discussion forum is available for asking questions or seeking help
from the community.
about: Suggest an idea for this project. Hparam requests and training help are not feature requests.
The discussion forum is available for asking questions or seeking help from the community.
title: "[FEATURE] Feature title..."
labels: enhancement
assignees: ''

@ -21,6 +21,30 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
## What's New
### Aug 29, 2022
* MaxVit window size scales with img_size by default (sketch of the scaling below). Add new RelPosMlp MaxViT weight that leverages this:
* `maxvit_rmlp_nano_rw_256` - 83.0 @ 256, 83.6 @ 320 (T)
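A rough sketch of the default scaling, mirroring the new `cfg_window_size` logic in `maxxvit.py` (`partition_stride` defaults to 32; `default_partition_size` here is just an illustrative stand-in):

```python
# sketch of how window/grid size are derived from img_size (mirrors cfg_window_size in maxxvit.py)
def default_partition_size(img_size=(256, 256), partition_stride=32):
    return img_size[0] // partition_stride, img_size[1] // partition_stride

print(default_partition_size((256, 256)))  # (8, 8)
print(default_partition_size((320, 320)))  # (10, 10) -- why the RelPosMlp weight also does well at 320
```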
### Aug 26, 2022
* CoAtNet (https://arxiv.org/abs/2106.04803) and MaxVit (https://arxiv.org/abs/2204.01697) `timm` original models
* both are defined in the [`maxxvit.py`](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/maxxvit.py) model def, which contains numerous experiments outside the scope of the original papers
* an unfinished Tensorflow version from the MaxVit authors can be found at https://github.com/google-research/maxvit
* Initial CoAtNet and MaxVit timm pretrained weights (working on more):
* `coatnet_nano_rw_224` - 81.7 @ 224 (T)
* `coatnet_rmlp_nano_rw_224` - 82.0 @ 224, 82.8 @ 320 (T)
* `coatnet_0_rw_224` - 82.4 (T) -- NOTE: timm '0' CoAtNets have 2 more 3rd-stage blocks
* `coatnet_bn_0_rw_224` - 82.4 (T)
* `maxvit_nano_rw_256` - 82.9 @ 256 (T)
* `coatnet_rmlp_1_rw_224` - 83.4 @ 224, 84 @ 320 (T)
* `coatnet_1_rw_224` - 83.6 @ 224 (G)
* (T) = TPU trained with `bits_and_tpu` branch training code, (G) = GPU trained
* GCViT (weights adapted from https://github.com/NVlabs/GCVit, code is a 100% `timm` re-write for license purposes)
* MViT-V2 (multi-scale vit, adapted from https://github.com/facebookresearch/mvit)
* EfficientFormer (adapted from https://github.com/snap-research/EfficientFormer)
* PyramidVisionTransformer-V2 (adapted from https://github.com/whai362/PVT)
* 'Fast Norm' support for LayerNorm and GroupNorm that avoids the float32 upcast w/ AMP (uses APEX LN if available for further boost); usage sketch below
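To try fast-norm outside the updated scripts, a minimal usage sketch (API per the `timm.models` export and `benchmark.py` changes below; `convnext_atto` is just an example LayerNorm model):

```python
# enable the experimental fast-norm path globally before building a model (sketch)
import timm
from timm.models import set_fast_norm

set_fast_norm()  # LayerNorm/GroupNorm skip the float32 AMP upcast; APEX fused LN used if installed
model = timm.create_model('convnext_atto', pretrained=False)
```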
### Aug 15, 2022
* ConvNeXt atto weights added
* `convnext_atto` - 75.7 @ 224, 77.0 @ 288
@ -229,6 +253,7 @@ A full version of the list below with source links can be found in the [document
* Bottleneck Transformers - https://arxiv.org/abs/2101.11605
* CaiT (Class-Attention in Image Transformers) - https://arxiv.org/abs/2103.17239
* CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399
* CoAtNet (Convolution and Attention) - https://arxiv.org/abs/2106.04803
* ConvNeXt - https://arxiv.org/abs/2201.03545
* ConViT (Soft Convolutional Inductive Biases Vision Transformers) - https://arxiv.org/abs/2103.10697
* CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929
@ -238,6 +263,7 @@ A full version of the list below with source links can be found in the [document
* DLA - https://arxiv.org/abs/1707.06484
* DPN (Dual-Path Network) - https://arxiv.org/abs/1707.01629
* EdgeNeXt - https://arxiv.org/abs/2206.10589
* EfficientFormer - https://arxiv.org/abs/2206.01191
* EfficientNet (MBConvNet Family)
* EfficientNet NoisyStudent (B0-B7, L2) - https://arxiv.org/abs/1911.04252
* EfficientNet AdvProp (B0-B8) - https://arxiv.org/abs/1911.09665
@ -250,6 +276,7 @@ A full version of the list below with source links can be found in the [document
* MobileNet-V2 - https://arxiv.org/abs/1801.04381
* Single-Path NAS - https://arxiv.org/abs/1904.02877
* TinyNet - https://arxiv.org/abs/2010.14819
* GCViT (Global Context Vision Transformer) - https://arxiv.org/abs/2206.09959
* GhostNet - https://arxiv.org/abs/1911.11907
* gMLP - https://arxiv.org/abs/2105.08050
* GPU-Efficient Networks - https://arxiv.org/abs/2006.14090
@ -259,6 +286,7 @@ A full version of the list below with source links can be found in the [document
* Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261
* Lambda Networks - https://arxiv.org/abs/2102.08602
* LeViT (Vision Transformer in ConvNet's Clothing) - https://arxiv.org/abs/2104.01136
* MaxViT (Multi-Axis Vision Transformer) - https://arxiv.org/abs/2204.01697
* MLP-Mixer - https://arxiv.org/abs/2105.01601
* MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244
* FBNet-V3 - https://arxiv.org/abs/2006.02049
@ -266,6 +294,7 @@ A full version of the list below with source links can be found in the [document
* LCNet - https://arxiv.org/abs/2109.15099
* MobileViT - https://arxiv.org/abs/2110.02178
* MobileViT-V2 - https://arxiv.org/abs/2206.02680
* MViT-V2 (Improved Multiscale Vision Transformer) - https://arxiv.org/abs/2112.01526
* NASNet-A - https://arxiv.org/abs/1707.07012
* NesT - https://arxiv.org/abs/2105.12723
* NFNet-F - https://arxiv.org/abs/2102.06171
@ -273,6 +302,7 @@ A full version of the list below with source links can be found in the [document
* PNasNet - https://arxiv.org/abs/1712.00559
* PoolFormer (MetaFormer) - https://arxiv.org/abs/2111.11418
* Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302
* PVT-V2 (Improved Pyramid Vision Transformer) - https://arxiv.org/abs/2106.13797
* RegNet - https://arxiv.org/abs/2003.13678
* RegNetZ - https://arxiv.org/abs/2103.06877
* RepVGG - https://arxiv.org/abs/2101.03697

@ -19,7 +19,7 @@ import torch.nn as nn
import torch.nn.parallel
from timm.data import resolve_data_config
from timm.models import create_model, is_model, list_models
from timm.models import create_model, is_model, list_models, set_fast_norm
from timm.optim import create_optimizer_v2
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry
@ -109,7 +109,8 @@ scripting_group.add_argument('--torchscript', dest='torchscript', action='store_
help='convert model torchscript for inference')
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
scripting_group.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
# train optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
@ -598,6 +599,9 @@ def main():
model_cfgs = []
model_names = []
if args.fast_norm:
set_fast_norm()
if args.model_list:
args.model = ''
with open(args.model_list) as f:

@ -69,5 +69,6 @@ from .helpers import load_checkpoint, resume_checkpoint, model_parameters
from .layers import TestTimePoolHead, apply_test_time_pool
from .layers import convert_splitbn_model, convert_sync_batchnorm
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
from .layers import set_fast_norm
from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value

@ -115,7 +115,7 @@ class DenseBlock(nn.ModuleDict):
_version = 2
def __init__(
self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU,
self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=BatchNormAct2d,
drop_rate=0., memory_efficient=False):
super(DenseBlock, self).__init__()
for i in range(num_layers):
@ -138,7 +138,7 @@ class DenseBlock(nn.ModuleDict):
class DenseTransition(nn.Sequential):
def __init__(self, num_input_features, num_output_features, norm_layer=nn.BatchNorm2d, aa_layer=None):
def __init__(self, num_input_features, num_output_features, norm_layer=BatchNormAct2d, aa_layer=None):
super(DenseTransition, self).__init__()
self.add_module('norm', norm_layer(num_input_features))
self.add_module('conv', nn.Conv2d(

@ -1,3 +1,11 @@
""" 'Fast' Normalization Functions
For GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32.
Additionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast)
Hacked together by / Copyright 2022 Ross Wightman
"""
from typing import List, Optional
import torch
@ -37,6 +45,7 @@ def fast_group_norm(
if torch.is_autocast_enabled():
# normally native AMP casts GN inputs to float32
# here we use the low precision autocast dtype
# FIXME what to do re CPU autocast?
dt = torch.get_autocast_gpu_dtype()
x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)
@ -62,6 +71,7 @@ def fast_layer_norm(
# normally native AMP casts LN inputs to float32
# apex LN does not, so this matches Apex behavior
dt = torch.get_autocast_gpu_dtype()
# FIXME what to do re CPU autocast?
x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)
with torch.cuda.amp.autocast(enabled=False):
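Pulled together, the pattern above looks roughly like this minimal sketch (CUDA autocast assumed; the real `fast_layer_norm` additionally dispatches to the APEX fused LN when available):

```python
import torch
import torch.nn.functional as F

def layer_norm_no_upcast(x, normalized_shape, weight, bias, eps=1e-6):
    # cast inputs/params to the active autocast dtype instead of letting AMP upcast to float32
    if torch.is_autocast_enabled():
        dt = torch.get_autocast_gpu_dtype()
        x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)
    # run the norm with autocast disabled so nothing is promoted back to float32
    with torch.cuda.amp.autocast(enabled=False):
        return F.layer_norm(x, normalized_shape, weight, bias, eps)
```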

@ -1,4 +1,8 @@
""" Normalization layers and wrappers
Norm layer definitions that support fast norm and consistent channel arg order (always first arg).
Hacked together by / Copyright 2022 Ross Wightman
"""
import torch

@ -1,4 +1,16 @@
""" Normalization + Activation Layers
Provides Norm+Act fns for standard PyTorch norm layers such as
* BatchNorm
* GroupNorm
* LayerNorm
This allows swapping with alternative layers that are natively both norm + act such as
* EvoNorm (evo_norm.py)
* FilterResponseNorm (filter_response_norm.py)
* InplaceABN (inplace_abn.py)
Hacked together by / Copyright 2022 Ross Wightman
"""
from typing import Union, List, Optional, Any

@ -5,7 +5,7 @@ import warnings
from torch.nn.init import _calculate_fan_in_and_fan_out
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
@ -17,28 +17,27 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
@ -64,7 +63,8 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
with torch.no_grad():
return _trunc_normal_(tensor, mean, std, a, b)
def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
@ -90,8 +90,8 @@ def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
_no_grad_trunc_normal_(tensor, 0, 1.0, a, b)
with torch.no_grad():
_trunc_normal_(tensor, 0, 1.0, a, b)
tensor.mul_(std).add_(mean)
return tensor
@ -111,10 +111,12 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978)
elif distribution == "normal":
tensor.normal_(std=math.sqrt(variance))
with torch.no_grad():
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
tensor.uniform_(-bound, bound)
with torch.no_grad():
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")

@ -39,7 +39,7 @@ Hacked together by / Copyright 2022, Ross Wightman
import math
from collections import OrderedDict
from dataclasses import dataclass
from dataclasses import dataclass, replace
from functools import partial
from typing import Callable, Optional, Union, Tuple, List
@ -108,10 +108,13 @@ default_cfgs = {
# Experimental configs
'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_nano_rw_256': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-3e790ce3.pth',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_tiny_rw_224': _cfg(url=''),
'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_nano_rw_256': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
@ -129,21 +132,30 @@ class MaxxVitTransformerCfg:
dim_head: int = 32
expand_ratio: float = 4.0
expand_first: bool = True
shortcut_bias: bool = True,
shortcut_bias: bool = True
attn_bias: bool = True
attn_drop: float = 0.
proj_drop: float = 0.
pool_type: str = 'avg2'
rel_pos_type: str = 'bias'
rel_pos_dim: int = 512 # for relative position types w/ MLP
window_size: Tuple[int, int] = (7, 7)
grid_size: Tuple[int, int] = (7, 7)
partition_stride: int = 32
window_size: Optional[Tuple[int, int]] = None
grid_size: Optional[Tuple[int, int]] = None
init_values: Optional[float] = None
act_layer: str = 'gelu'
norm_layer: str = 'layernorm2d'
norm_layer_cl: str = 'layernorm'
norm_eps: float = 1e-6
def __post_init__(self):
if self.grid_size is not None:
self.grid_size = to_2tuple(self.grid_size)
if self.window_size is not None:
self.window_size = to_2tuple(self.window_size)
if self.grid_size is None:
self.grid_size = self.window_size
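A small illustration of the `__post_init__` behaviour (assuming the config is importable from `timm.models.maxxvit`): an int size is normalized to a 2-tuple and `grid_size` falls back to `window_size`, while leaving both `None` defers sizing to the model via `partition_stride`:

```python
from timm.models.maxxvit import MaxxVitTransformerCfg

cfg = MaxxVitTransformerCfg(window_size=8)
print(cfg.window_size, cfg.grid_size)  # (8, 8) (8, 8)

cfg = MaxxVitTransformerCfg()
print(cfg.window_size, cfg.grid_size)  # None None -> later derived from img_size // partition_stride
```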
@dataclass
class MaxxVitConvCfg:
@ -249,7 +261,7 @@ def _rw_max_cfg(
conv_norm_layer='',
transformer_norm_layer='layernorm2d',
transformer_norm_layer_cl='layernorm',
window_size=7,
window_size=None,
dim_head=32,
rel_pos_type='bias',
rel_pos_dim=512,
@ -259,8 +271,6 @@ def _rw_max_cfg(
# - mbconv expansion calculated from input instead of output chs
# - mbconv shortcut and final 1x1 conv did not have a bias
# - mbconv uses silu in timm, not gelu
# - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat)
# - default to avg pool for mbconv downsample instead of 1x1 or dw conv
# - expansion in attention block done via output proj, not input proj
return dict(
conv_cfg=MaxxVitConvCfg(
@ -276,8 +286,7 @@ def _rw_max_cfg(
expand_first=False,
pool_type=pool_type,
dim_head=dim_head,
window_size=to_2tuple(window_size),
grid_size=to_2tuple(window_size),
window_size=window_size,
norm_layer=transformer_norm_layer,
norm_layer_cl=transformer_norm_layer_cl,
rel_pos_type=rel_pos_type,
@ -293,7 +302,7 @@ def _next_cfg(
conv_norm_layer_cl='layernorm',
transformer_norm_layer='layernorm2d',
transformer_norm_layer_cl='layernorm',
window_size=7,
window_size=None,
rel_pos_type='bias',
rel_pos_dim=512,
):
@ -310,8 +319,7 @@ def _next_cfg(
transformer_cfg=MaxxVitTransformerCfg(
expand_first=False,
pool_type=pool_type,
window_size=to_2tuple(window_size),
grid_size=to_2tuple(window_size),
window_size=window_size,
norm_layer=transformer_norm_layer,
norm_layer_cl=transformer_norm_layer_cl,
rel_pos_type=rel_pos_type,
@ -411,18 +419,19 @@ model_cfgs = dict(
rel_pos_dim=384, # was supposed to be 512, woops
),
),
coatnext_nano_rw_224=MaxxVitCfg(
coatnet_nano_cc_224=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(3, 4, 6, 3),
stem_width=(32, 64),
**_next_cfg(),
block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
**_rw_coat_cfg(),
),
coatnet_nano_cc_224=MaxxVitCfg(
coatnext_nano_rw_224=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(3, 4, 6, 3),
stem_width=(32, 64),
block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
**_rw_coat_cfg(),
weight_init='normal',
**_next_cfg(),
),
# Trying to be like the CoAtNet paper configs
@ -463,14 +472,14 @@ model_cfgs = dict(
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(24, 32),
**_rw_max_cfg(window_size=8),
**_rw_max_cfg(),
),
maxvit_nano_rw_256=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(1, 2, 3, 1),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(window_size=8),
**_rw_max_cfg(),
),
maxvit_tiny_rw_224=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
@ -484,21 +493,29 @@ model_cfgs = dict(
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(window_size=8),
**_rw_max_cfg(),
),
maxvit_rmlp_nano_rw_256=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(1, 2, 3, 1),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(rel_pos_type='mlp'),
),
maxvit_tiny_pm_256=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
block_type=('PM',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(window_size=8),
**_rw_max_cfg(),
),
maxxvit_nano_rw_256=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(1, 2, 3, 1),
block_type=('M',) * 4,
stem_width=(32, 64),
**_next_cfg(window_size=8),
weight_init='normal',
**_next_cfg(),
),
# Trying to be like the MaxViT paper configs
@ -651,7 +668,11 @@ class LayerScale2d(nn.Module):
class Downsample2d(nn.Module):
""" A downsample pooling module for Coat that handles 2d <-> 1d conversion
""" A downsample pooling module supporting several maxpool and avgpool modes
* 'max' - MaxPool2d w/ kernel_size 3, stride 2, padding 1
* 'max2' - MaxPool2d w/ kernel_size = stride = 2
* 'avg' - AvgPool2d w/ kernel_size 3, stride 2, padding 1
* 'avg2' - AvgPool2d w/ kernel_size = stride = 2
"""
def __init__(
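The four modes in the docstring map onto standard PyTorch pooling layers; a hypothetical helper sketching that mapping (not the actual `Downsample2d` implementation):

```python
import torch.nn as nn

def make_downsample_pool(pool_type: str = 'avg2') -> nn.Module:
    # illustrative only; mirrors the modes documented above
    if pool_type == 'max':
        return nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    if pool_type == 'max2':
        return nn.MaxPool2d(kernel_size=2, stride=2)
    if pool_type == 'avg':
        return nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
    if pool_type == 'avg2':
        return nn.AvgPool2d(kernel_size=2, stride=2)
    raise ValueError(f'unknown pool_type {pool_type}')
```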
@ -710,6 +731,11 @@ def _init_transformer(module, name, scheme=''):
class TransformerBlock2d(nn.Module):
""" Transformer block with 2D downsampling
'2D' NCHW tensor layout
Some gains can be seen on GPU using a 1D / CL block, but with the need to switch back/forth to NCHW for spatial pooling the benefit is minimal, so this variant ended up being used for the CoAt configs.
This impl was faster on TPU w/ PT XLA than the 1D experiment.
"""
def __init__(
@ -1011,9 +1037,9 @@ def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size):
return rel_pos_cls
class PartitionAttention(nn.Module):
class PartitionAttentionCl(nn.Module):
""" Grid or Block partition + Attn + FFN.
NxC tensor layout.
NxC 'channels last' tensor layout.
"""
def __init__(
@ -1183,6 +1209,7 @@ def grid_reverse_nchw(windows, grid_size: List[int], img_size: List[int]):
class PartitionAttention2d(nn.Module):
""" Grid or Block partition + Attn + FFN
'2D' NCHW tensor layout.
"""
@ -1245,7 +1272,7 @@ class PartitionAttention2d(nn.Module):
class MaxxVitBlock(nn.Module):
"""
""" MaxVit conv, window partition + FFN , grid partition + FFN
"""
def __init__(
@ -1264,7 +1291,7 @@ class MaxxVitBlock(nn.Module):
self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttention
partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl
self.nchw_attn = use_nchw_attn
self.attn_block = partition_layer(**attn_kwargs)
self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs)
@ -1288,7 +1315,8 @@ class MaxxVitBlock(nn.Module):
class ParallelMaxxVitBlock(nn.Module):
"""
""" MaxVit block with parallel cat(window + grid), one FF
Experimental timm block.
"""
def __init__(
@ -1426,8 +1454,19 @@ class Stem(nn.Module):
return x
def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
if cfg.window_size is not None:
assert cfg.grid_size
return cfg
partition_size = img_size[0] // cfg.partition_stride, img_size[1] // cfg.partition_stride
cfg = replace(cfg, window_size=partition_size, grid_size=partition_size)
return cfg
class MaxxVit(nn.Module):
"""
""" CoaTNet + MaxVit base model.
Highly configurable for different block compositions, tensor layouts, pooling types.
"""
def __init__(
@ -1442,6 +1481,7 @@ class MaxxVit(nn.Module):
):
super().__init__()
img_size = to_2tuple(img_size)
transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = cfg.embed_dim[-1]
@ -1475,7 +1515,7 @@ class MaxxVit(nn.Module):
depth=cfg.depths[i],
block_types=cfg.block_type[i],
conv_cfg=cfg.conv_cfg,
transformer_cfg=cfg.transformer_cfg,
transformer_cfg=transformer_cfg,
feat_size=feat_size,
drop_path=dpr[i],
)]
@ -1658,6 +1698,11 @@ def maxvit_tiny_rw_256(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)
@register_model
def maxvit_tiny_pm_256(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)
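With the registration above and the pretrained URL added to `default_cfgs`, the new weight can be pulled through the usual factory (sketch, assuming the release asset is reachable):

```python
import timm

model = timm.create_model('maxvit_rmlp_nano_rw_256', pretrained=True)
model.eval()
```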

@ -57,6 +57,8 @@ default_cfgs = dict(
mvitv2_huge_in21k=_cfg(
url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_H_in21k.pyth',
num_classes=19168),
mvitv2_small_cls=_cfg(url=''),
)
@ -135,6 +137,11 @@ model_cfgs = dict(
num_heads=2,
expand_attn=False,
),
mvitv2_small_cls=MultiScaleVitCfg(
depths=(1, 2, 11, 2),
use_cls_token=True,
),
)
@ -641,7 +648,7 @@ class MultiScaleBlock(nn.Module):
if self.shortcut_pool_attn is None:
return x
if self.has_cls_token:
cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :]
cls_tok, x = x[:, :1, :], x[:, 1:, :]
else:
cls_tok = None
B, L, C = x.shape
@ -650,7 +657,7 @@ class MultiScaleBlock(nn.Module):
x = self.shortcut_pool_attn(x)
x = x.reshape(B, C, -1).transpose(1, 2)
if cls_tok is not None:
x = torch.cat((cls_tok, x), dim=2)
x = torch.cat((cls_tok, x), dim=1)
return x
def forward(self, x, feat_size: List[int]):
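For reference, the shape logic behind the dim fix above: after flattening, tokens are `(B, L, C)`, so the class token is split off and re-attached along dim 1 (a standalone illustration, not the module code):

```python
import torch

B, L, C = 2, 1 + 49, 96                        # batch, cls token + 7x7 spatial tokens, channels
x = torch.randn(B, L, C)
cls_tok, spatial = x[:, :1, :], x[:, 1:, :]    # split on dim 1 -> (B, 1, C), (B, 49, C)
# ... spatial tokens get pooled / reshaped here ...
x = torch.cat((cls_tok, spatial), dim=1)       # re-attach on dim 1 -> (B, 50, C)
```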
@ -996,3 +1003,8 @@ def mvitv2_large(pretrained=False, **kwargs):
# @register_model
# def mvitv2_huge_in21k(pretrained=False, **kwargs):
# return _create_mvitv2('mvitv2_huge_in21k', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_small_cls(pretrained=False, **kwargs):
return _create_mvitv2('mvitv2_small_cls', pretrained=pretrained, **kwargs)
