Merge remote-tracking branch 'origin/master' into bits_and_tpu

pull/804/merge^2
Ross Wightman
commit 87bfb055c0

.github/ISSUE_TEMPLATE/config.yml

@@ -2,4 +2,4 @@ blank_issues_enabled: false
 contact_links:
   - name: Community Discussions
     url: https://github.com/rwightman/pytorch-image-models/discussions
-    about: Issues are for features and bugs. Questions can be asked in Discussions.
+    about: Hparam request in issues will be ignored! Issues are for features and bugs. Questions can be asked in Discussions.

.github/ISSUE_TEMPLATE/feature_request.md

@@ -1,8 +1,7 @@
 ---
 name: Feature request
-about: Suggest an idea for this project. Issues are for reporting bugs or requesting
-  features, the discussion forum is available for asking questions or seeking help
-  from the community.
+about: Suggest an idea for this project. Hparam requests, training help are not feature requests.
+  The discussion forum is available for asking questions or seeking help from the community.
 title: "[FEATURE] Feature title..."
 labels: enhancement
 assignees: ''

README.md

@@ -21,6 +21,30 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
 ## What's New
 
+### Aug 29, 2022
+* MaxVit window size scales with img_size by default. Add new RelPosMlp MaxViT weight that leverages this:
+  * `maxvit_rmlp_nano_rw_256` - 83.0 @ 256, 83.6 @ 320 (T)
+
+### Aug 26, 2022
+* CoAtNet (https://arxiv.org/abs/2106.04803) and MaxVit (https://arxiv.org/abs/2204.01697) `timm` original models
+  * both found in [`maxxvit.py`](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/maxxvit.py) model def, contains numerous experiments outside scope of original papers
+  * an unfinished Tensorflow version from MaxVit authors can be found https://github.com/google-research/maxvit
+* Initial CoAtNet and MaxVit timm pretrained weights (working on more):
+  * `coatnet_nano_rw_224` - 81.7 @ 224 (T)
+  * `coatnet_rmlp_nano_rw_224` - 82.0 @ 224, 82.8 @ 320 (T)
+  * `coatnet_0_rw_224` - 82.4 (T) -- NOTE timm '0' coatnets have 2 more 3rd stage blocks
+  * `coatnet_bn_0_rw_224` - 82.4 (T)
+  * `maxvit_nano_rw_256` - 82.9 @ 256 (T)
+  * `coatnet_rmlp_1_rw_224` - 83.4 @ 224, 84 @ 320 (T)
+  * `coatnet_1_rw_224` - 83.6 @ 224 (G)
+  * (T) = TPU trained with `bits_and_tpu` branch training code, (G) = GPU trained
+* GCVit (weights adapted from https://github.com/NVlabs/GCVit, code 100% `timm` re-write for license purposes)
+* MViT-V2 (multi-scale vit, adapted from https://github.com/facebookresearch/mvit)
+* EfficientFormer (adapted from https://github.com/snap-research/EfficientFormer)
+* PyramidVisionTransformer-V2 (adapted from https://github.com/whai362/PVT)
+* 'Fast Norm' support for LayerNorm and GroupNorm that avoids float32 upcast w/ AMP (uses APEX LN if available for further boost)
+
 ### Aug 15, 2022
 * ConvNeXt atto weights added
   * `convnext_atto` - 75.7 @ 224, 77.0 @ 288
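The new CoAtNet / MaxViT variants above are regular `timm` models created through the usual factory. A minimal sketch, assuming an install that includes this commit and that the released weight files are reachable (names as listed in the notes above):

```python
import torch
import timm

# timm-original CoAtNet / MaxViT variants live in timm/models/maxxvit.py
print(timm.list_models('coatnet*') + timm.list_models('maxvit*'))

# maxvit_rmlp_nano_rw_256 is one of the weights released with this change (256x256 default res)
model = timm.create_model('maxvit_rmlp_nano_rw_256', pretrained=True)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 256, 256))
print(logits.shape)  # torch.Size([1, 1000])
```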
@@ -229,6 +253,7 @@ A full version of the list below with source links can be found in the [document
 * Bottleneck Transformers - https://arxiv.org/abs/2101.11605
 * CaiT (Class-Attention in Image Transformers) - https://arxiv.org/abs/2103.17239
 * CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399
+* CoAtNet (Convolution and Attention) - https://arxiv.org/abs/2106.04803
 * ConvNeXt - https://arxiv.org/abs/2201.03545
 * ConViT (Soft Convolutional Inductive Biases Vision Transformers)- https://arxiv.org/abs/2103.10697
 * CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929

@@ -238,6 +263,7 @@ A full version of the list below with source links can be found in the [document
 * DLA - https://arxiv.org/abs/1707.06484
 * DPN (Dual-Path Network) - https://arxiv.org/abs/1707.01629
 * EdgeNeXt - https://arxiv.org/abs/2206.10589
+* EfficientFormer - https://arxiv.org/abs/2206.01191
 * EfficientNet (MBConvNet Family)
   * EfficientNet NoisyStudent (B0-B7, L2) - https://arxiv.org/abs/1911.04252
   * EfficientNet AdvProp (B0-B8) - https://arxiv.org/abs/1911.09665

@@ -250,6 +276,7 @@ A full version of the list below with source links can be found in the [document
   * MobileNet-V2 - https://arxiv.org/abs/1801.04381
   * Single-Path NAS - https://arxiv.org/abs/1904.02877
   * TinyNet - https://arxiv.org/abs/2010.14819
+* GCViT (Global Context Vision Transformer) - https://arxiv.org/abs/2206.09959
 * GhostNet - https://arxiv.org/abs/1911.11907
 * gMLP - https://arxiv.org/abs/2105.08050
 * GPU-Efficient Networks - https://arxiv.org/abs/2006.14090

@@ -259,6 +286,7 @@ A full version of the list below with source links can be found in the [document
 * Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261
 * Lambda Networks - https://arxiv.org/abs/2102.08602
 * LeViT (Vision Transformer in ConvNet's Clothing) - https://arxiv.org/abs/2104.01136
+* MaxViT (Multi-Axis Vision Transformer) - https://arxiv.org/abs/2204.01697
 * MLP-Mixer - https://arxiv.org/abs/2105.01601
 * MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244
   * FBNet-V3 - https://arxiv.org/abs/2006.02049

@@ -266,6 +294,7 @@ A full version of the list below with source links can be found in the [document
   * LCNet - https://arxiv.org/abs/2109.15099
 * MobileViT - https://arxiv.org/abs/2110.02178
 * MobileViT-V2 - https://arxiv.org/abs/2206.02680
+* MViT-V2 (Improved Multiscale Vision Transformer) - https://arxiv.org/abs/2112.01526
 * NASNet-A - https://arxiv.org/abs/1707.07012
 * NesT - https://arxiv.org/abs/2105.12723
 * NFNet-F - https://arxiv.org/abs/2102.06171

@@ -273,6 +302,7 @@ A full version of the list below with source links can be found in the [document
 * PNasNet - https://arxiv.org/abs/1712.00559
 * PoolFormer (MetaFormer) - https://arxiv.org/abs/2111.11418
 * Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302
+* PVT-V2 (Improved Pyramid Vision Transformer) - https://arxiv.org/abs/2106.13797
 * RegNet - https://arxiv.org/abs/2003.13678
 * RegNetZ - https://arxiv.org/abs/2103.06877
 * RepVGG - https://arxiv.org/abs/2101.03697

benchmark.py

@@ -19,7 +19,7 @@ import torch.nn as nn
 import torch.nn.parallel
 from timm.data import resolve_data_config
-from timm.models import create_model, is_model, list_models
+from timm.models import create_model, is_model, list_models, set_fast_norm
 from timm.optim import create_optimizer_v2
 from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry

@@ -109,7 +109,8 @@ scripting_group.add_argument('--torchscript', dest='torchscript', action='store_
                              help='convert model torchscript for inference')
 scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                              help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
+scripting_group.add_argument('--fast-norm', default=False, action='store_true',
+                             help='enable experimental fast-norm')

 # train optimizer parameters
 parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',

@@ -598,6 +599,9 @@ def main():
     model_cfgs = []
     model_names = []

+    if args.fast_norm:
+        set_fast_norm()
+
     if args.model_list:
         args.model = ''
         with open(args.model_list) as f:
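The new `--fast-norm` flag only flips a global switch before models are built. A rough script-level equivalent, assuming a CUDA device and a timm checkout that includes this commit (`set_fast_norm` is re-exported from `timm.models` as shown in the diff above):

```python
import torch
import timm
from timm.models import set_fast_norm

set_fast_norm()  # enable experimental fast-norm globally, before model creation

model = timm.create_model('maxvit_nano_rw_256', pretrained=False).cuda().eval()

# under AMP, timm LayerNorm/GroupNorm layers now skip the float32 upcast
with torch.no_grad(), torch.cuda.amp.autocast():
    out = model(torch.randn(8, 3, 256, 256, device='cuda'))
print(out.shape)
```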

timm/models/__init__.py

@@ -69,5 +69,6 @@ from .helpers import load_checkpoint, resume_checkpoint, model_parameters
 from .layers import TestTimePoolHead, apply_test_time_pool
 from .layers import convert_splitbn_model, convert_sync_batchnorm
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
+from .layers import set_fast_norm
 from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
     is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value

timm/models/densenet.py

@@ -115,7 +115,7 @@ class DenseBlock(nn.ModuleDict):
     _version = 2

     def __init__(
-            self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU,
+            self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=BatchNormAct2d,
             drop_rate=0., memory_efficient=False):
         super(DenseBlock, self).__init__()
         for i in range(num_layers):

@@ -138,7 +138,7 @@ class DenseBlock(nn.ModuleDict):
 class DenseTransition(nn.Sequential):
-    def __init__(self, num_input_features, num_output_features, norm_layer=nn.BatchNorm2d, aa_layer=None):
+    def __init__(self, num_input_features, num_output_features, norm_layer=BatchNormAct2d, aa_layer=None):
         super(DenseTransition, self).__init__()
         self.add_module('norm', norm_layer(num_input_features))
         self.add_module('conv', nn.Conv2d(

timm/models/layers/fast_norm.py

@@ -1,3 +1,11 @@
+""" 'Fast' Normalization Functions
+
+For GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32.
+
+Additionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast)
+
+Hacked together by / Copyright 2022 Ross Wightman
+"""
 from typing import List, Optional

 import torch

@@ -37,6 +45,7 @@ def fast_group_norm(
     if torch.is_autocast_enabled():
         # normally native AMP casts GN inputs to float32
         # here we use the low precision autocast dtype
+        # FIXME what to do re CPU autocast?
         dt = torch.get_autocast_gpu_dtype()
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)

@@ -62,6 +71,7 @@ def fast_layer_norm(
         # normally native AMP casts LN inputs to float32
         # apex LN does not, this is behaving like Apex
         dt = torch.get_autocast_gpu_dtype()
+        # FIXME what to do re CPU autocast?
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)

     with torch.cuda.amp.autocast(enabled=False):
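The practical effect of the bypass described in the docstring: under CUDA autocast, stock `F.layer_norm` is run in float32, while the fast path keeps the low-precision autocast dtype. A small check, assuming a CUDA device, no APEX installed (the fused APEX path takes over otherwise), and the `fast_layer_norm(x, normalized_shape, weight, bias)` argument order shown in the function above:

```python
import torch
import torch.nn.functional as F
from timm.models.layers.fast_norm import fast_layer_norm

x = torch.randn(2, 196, 384, device='cuda')
w = torch.ones(384, device='cuda')
b = torch.zeros(384, device='cuda')

with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    y_ref = F.layer_norm(x, (384,), w, b)      # native AMP upcasts LN to float32
    y_fast = fast_layer_norm(x, (384,), w, b)  # stays in the autocast dtype

print(y_ref.dtype, y_fast.dtype)  # expected without APEX: torch.float32 torch.bfloat16
```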

timm/models/layers/norm.py

@@ -1,4 +1,8 @@
 """ Normalization layers and wrappers
+
+Norm layer definitions that support fast norm and consistent channel arg order (always first arg).
+
+Hacked together by / Copyright 2022 Ross Wightman
 """
 import torch

timm/models/layers/norm_act.py

@@ -1,4 +1,16 @@
 """ Normalization + Activation Layers
+
+Provides Norm+Act fns for standard PyTorch norm layers such as
+* BatchNorm
+* GroupNorm
+* LayerNorm
+
+This allows swapping with alternative layers that are natively both norm + act such as
+* EvoNorm (evo_norm.py)
+* FilterResponseNorm (filter_response_norm.py)
+* InplaceABN (inplace_abn.py)
+
+Hacked together by / Copyright 2022 Ross Wightman
 """
 from typing import Union, List, Optional, Any
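The densenet.py change above is one instance of what this docstring describes: `norm_layer` now expects a combined norm + act module rather than a bare norm. A hedged sketch of the drop-in pattern, assuming `BatchNormAct2d` and `GroupNormAct` are exported from `timm.models.layers` as in current timm (extra kwargs such as drop layers omitted):

```python
import torch
import torch.nn as nn
from timm.models.layers import BatchNormAct2d, GroupNormAct

# norm + act fused in one module, channel count always the first positional arg
norm1 = BatchNormAct2d(64)                                  # BN + default ReLU
norm2 = GroupNormAct(64, num_groups=8, act_layer=nn.SiLU)   # GN + SiLU

x = torch.randn(2, 64, 56, 56)
print(norm1(x).shape, norm2(x).shape)
```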

timm/models/layers/weight_init.py

@@ -5,7 +5,7 @@ import warnings
 from torch.nn.init import _calculate_fan_in_and_fan_out


-def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+def _trunc_normal_(tensor, mean, std, a, b):
     # Cut & paste from PyTorch official master until it's in a few official releases - RW
     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
     def norm_cdf(x):

@@ -17,28 +17,27 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
                       "The distribution of values may be incorrect.",
                       stacklevel=2)

-    with torch.no_grad():
-        # Values are generated by using a truncated uniform distribution and
-        # then using the inverse CDF for the normal distribution.
-        # Get upper and lower cdf values
-        l = norm_cdf((a - mean) / std)
-        u = norm_cdf((b - mean) / std)
-
-        # Uniformly fill tensor with values from [l, u], then translate to
-        # [2l-1, 2u-1].
-        tensor.uniform_(2 * l - 1, 2 * u - 1)
-
-        # Use inverse cdf transform for normal distribution to get truncated
-        # standard normal
-        tensor.erfinv_()
-
-        # Transform to proper mean, std
-        tensor.mul_(std * math.sqrt(2.))
-        tensor.add_(mean)
-
-        # Clamp to ensure it's in the proper range
-        tensor.clamp_(min=a, max=b)
-        return tensor
+    # Values are generated by using a truncated uniform distribution and
+    # then using the inverse CDF for the normal distribution.
+    # Get upper and lower cdf values
+    l = norm_cdf((a - mean) / std)
+    u = norm_cdf((b - mean) / std)
+
+    # Uniformly fill tensor with values from [l, u], then translate to
+    # [2l-1, 2u-1].
+    tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+    # Use inverse cdf transform for normal distribution to get truncated
+    # standard normal
+    tensor.erfinv_()
+
+    # Transform to proper mean, std
+    tensor.mul_(std * math.sqrt(2.))
+    tensor.add_(mean)
+
+    # Clamp to ensure it's in the proper range
+    tensor.clamp_(min=a, max=b)
+    return tensor


 def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):

@@ -64,7 +63,8 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
         >>> w = torch.empty(3, 5)
         >>> nn.init.trunc_normal_(w)
     """
-    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+    with torch.no_grad():
+        return _trunc_normal_(tensor, mean, std, a, b)


 def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):

@@ -90,8 +90,8 @@ def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
         >>> w = torch.empty(3, 5)
         >>> nn.init.trunc_normal_(w)
     """
-    _no_grad_trunc_normal_(tensor, 0, 1.0, a, b)
     with torch.no_grad():
+        _trunc_normal_(tensor, 0, 1.0, a, b)
         tensor.mul_(std).add_(mean)
     return tensor

@@ -111,10 +111,12 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
         # constant is stddev of standard normal truncated to (-2, 2)
         trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978)
     elif distribution == "normal":
-        tensor.normal_(std=math.sqrt(variance))
+        with torch.no_grad():
+            tensor.normal_(std=math.sqrt(variance))
     elif distribution == "uniform":
         bound = math.sqrt(3 * variance)
-        tensor.uniform_(-bound, bound)
+        with torch.no_grad():
+            tensor.uniform_(-bound, bound)
     else:
         raise ValueError(f"invalid distribution {distribution}")
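The refactor above only moves `torch.no_grad()` from the inner helper out to its callers; the public behaviour of the init functions is unchanged. A quick sketch of the two entry points, assuming they remain exported from `timm.models.layers` as in current timm (a sanity check, not part of the diff):

```python
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_, trunc_normal_tf_

linear = nn.Linear(384, 1000)

# PyTorch-style: mean/std applied before clamping to [a, b]
trunc_normal_(linear.weight, std=.02)

# TF-style: sample in [a, b] from a unit normal first, then scale/shift,
# so the effective bounds scale with std
trunc_normal_tf_(linear.weight, std=.02)

print(linear.weight.min().item(), linear.weight.max().item())
print(linear.weight.requires_grad)  # still True; no_grad only wrapped the in-place fill
```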

timm/models/maxxvit.py

@@ -39,7 +39,7 @@ Hacked together by / Copyright 2022, Ross Wightman
 import math
 from collections import OrderedDict
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from functools import partial
 from typing import Callable, Optional, Union, Tuple, List
@@ -108,10 +108,13 @@ default_cfgs = {
     # Experimental configs
     'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_nano_rw_256': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-3e790ce3.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_tiny_rw_224': _cfg(url=''),
     'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_rmlp_nano_rw_256': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),

     'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
@@ -129,21 +132,30 @@ class MaxxVitTransformerCfg:
     dim_head: int = 32
     expand_ratio: float = 4.0
     expand_first: bool = True
-    shortcut_bias: bool = True,
+    shortcut_bias: bool = True
     attn_bias: bool = True
     attn_drop: float = 0.
     proj_drop: float = 0.
     pool_type: str = 'avg2'
     rel_pos_type: str = 'bias'
     rel_pos_dim: int = 512  # for relative position types w/ MLP
-    window_size: Tuple[int, int] = (7, 7)
-    grid_size: Tuple[int, int] = (7, 7)
+    partition_stride: int = 32
+    window_size: Optional[Tuple[int, int]] = None
+    grid_size: Optional[Tuple[int, int]] = None
     init_values: Optional[float] = None
     act_layer: str = 'gelu'
     norm_layer: str = 'layernorm2d'
     norm_layer_cl: str = 'layernorm'
     norm_eps: float = 1e-6

+    def __post_init__(self):
+        if self.grid_size is not None:
+            self.grid_size = to_2tuple(self.grid_size)
+        if self.window_size is not None:
+            self.window_size = to_2tuple(self.window_size)
+            if self.grid_size is None:
+                self.grid_size = self.window_size
+

 @dataclass
 class MaxxVitConvCfg:
@@ -249,7 +261,7 @@ def _rw_max_cfg(
         conv_norm_layer='',
         transformer_norm_layer='layernorm2d',
         transformer_norm_layer_cl='layernorm',
-        window_size=7,
+        window_size=None,
         dim_head=32,
         rel_pos_type='bias',
         rel_pos_dim=512,

@@ -259,8 +271,6 @@ def _rw_max_cfg(
     # - mbconv expansion calculated from input instead of output chs
     # - mbconv shortcut and final 1x1 conv did not have a bias
     # - mbconv uses silu in timm, not gelu
-    # - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat)
-    # - default to avg pool for mbconv downsample instead of 1x1 or dw conv
     # - expansion in attention block done via output proj, not input proj
     return dict(
         conv_cfg=MaxxVitConvCfg(

@@ -276,8 +286,7 @@ def _rw_max_cfg(
             expand_first=False,
             pool_type=pool_type,
             dim_head=dim_head,
-            window_size=to_2tuple(window_size),
-            grid_size=to_2tuple(window_size),
+            window_size=window_size,
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
             rel_pos_type=rel_pos_type,

@@ -293,7 +302,7 @@ def _next_cfg(
         conv_norm_layer_cl='layernorm',
         transformer_norm_layer='layernorm2d',
         transformer_norm_layer_cl='layernorm',
-        window_size=7,
+        window_size=None,
         rel_pos_type='bias',
         rel_pos_dim=512,
 ):

@@ -310,8 +319,7 @@ def _next_cfg(
         transformer_cfg=MaxxVitTransformerCfg(
             expand_first=False,
             pool_type=pool_type,
-            window_size=to_2tuple(window_size),
-            grid_size=to_2tuple(window_size),
+            window_size=window_size,
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
             rel_pos_type=rel_pos_type,
@@ -411,18 +419,19 @@ model_cfgs = dict(
             rel_pos_dim=384,  # was supposed to be 512, woops
         ),
     ),
-    coatnext_nano_rw_224=MaxxVitCfg(
+    coatnet_nano_cc_224=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(3, 4, 6, 3),
         stem_width=(32, 64),
-        **_next_cfg(),
+        block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
+        **_rw_coat_cfg(),
     ),
-    coatnet_nano_cc_224=MaxxVitCfg(
+    coatnext_nano_rw_224=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(3, 4, 6, 3),
         stem_width=(32, 64),
-        block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
-        **_rw_coat_cfg(),
+        weight_init='normal',
+        **_next_cfg(),
     ),

     # Trying to be like the CoAtNet paper configs
@@ -463,14 +472,14 @@ model_cfgs = dict(
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(24, 32),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
     maxvit_nano_rw_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(1, 2, 3, 1),
         block_type=('M',) * 4,
         stem_width=(32, 64),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
     maxvit_tiny_rw_224=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),

@@ -484,21 +493,29 @@ model_cfgs = dict(
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(32, 64),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
+    ),
+    maxvit_rmlp_nano_rw_256=MaxxVitCfg(
+        embed_dim=(64, 128, 256, 512),
+        depths=(1, 2, 3, 1),
+        block_type=('M',) * 4,
+        stem_width=(32, 64),
+        **_rw_max_cfg(rel_pos_type='mlp'),
     ),
     maxvit_tiny_pm_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
         block_type=('PM',) * 4,
         stem_width=(32, 64),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
     maxxvit_nano_rw_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(1, 2, 3, 1),
         block_type=('M',) * 4,
         stem_width=(32, 64),
-        **_next_cfg(window_size=8),
+        weight_init='normal',
+        **_next_cfg(),
     ),

     # Trying to be like the MaxViT paper configs
@@ -651,7 +668,11 @@ class LayerScale2d(nn.Module):

 class Downsample2d(nn.Module):
-    """ A downsample pooling module for Coat that handles 2d <-> 1d conversion
+    """ A downsample pooling module supporting several maxpool and avgpool modes
+    * 'max' - MaxPool2d w/ kernel_size 3, stride 2, padding 1
+    * 'max2' - MaxPool2d w/ kernel_size = stride = 2
+    * 'avg' - AvgPool2d w/ kernel_size 3, stride 2, padding 1
+    * 'avg2' - AvgPool2d w/ kernel_size = stride = 2
     """

     def __init__(

@@ -710,6 +731,11 @@ def _init_transformer(module, name, scheme=''):

 class TransformerBlock2d(nn.Module):
     """ Transformer block with 2D downsampling
     '2D' NCHW tensor layout
+
+    Some gains can be seen on GPU using a 1D / CL block, BUT w/ the need to switch back/forth to NCHW
+    for spatial pooling, the benefit is minimal so ended up using just this variant for CoAt configs.
+
+    This impl was faster on TPU w/ PT XLA than the 1D experiment.
     """

     def __init__(
@@ -1011,9 +1037,9 @@ def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size):
     return rel_pos_cls


-class PartitionAttention(nn.Module):
+class PartitionAttentionCl(nn.Module):
     """ Grid or Block partition + Attn + FFN.
-    NxC tensor layout.
+    NxC 'channels last' tensor layout.
     """

     def __init__(

@@ -1183,6 +1209,7 @@ def grid_reverse_nchw(windows, grid_size: List[int], img_size: List[int]):

 class PartitionAttention2d(nn.Module):
     """ Grid or Block partition + Attn + FFN
+
     '2D' NCHW tensor layout.
     """

@@ -1245,7 +1272,7 @@ class PartitionAttention2d(nn.Module):

 class MaxxVitBlock(nn.Module):
-    """
+    """ MaxVit conv, window partition + FFN , grid partition + FFN
     """

     def __init__(

@@ -1264,7 +1291,7 @@ class MaxxVitBlock(nn.Module):
         self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)

         attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
-        partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttention
+        partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl
         self.nchw_attn = use_nchw_attn
         self.attn_block = partition_layer(**attn_kwargs)
         self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs)

@@ -1288,7 +1315,8 @@ class MaxxVitBlock(nn.Module):

 class ParallelMaxxVitBlock(nn.Module):
-    """
+    """ MaxVit block with parallel cat(window + grid), one FF
+    Experimental timm block.
     """

     def __init__(
@@ -1426,8 +1454,19 @@ class Stem(nn.Module):
         return x


+def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
+    if cfg.window_size is not None:
+        assert cfg.grid_size
+        return cfg
+    partition_size = img_size[0] // cfg.partition_stride, img_size[1] // cfg.partition_stride
+    cfg = replace(cfg, window_size=partition_size, grid_size=partition_size)
+    return cfg
+
+
 class MaxxVit(nn.Module):
-    """
+    """ CoaTNet + MaxVit base model.
+
+    Highly configurable for different block compositions, tensor layouts, pooling types.
     """

     def __init__(
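This is the mechanism behind the Aug 29 note that MaxVit window size now scales with `img_size`: when a config leaves `window_size=None`, `cfg_window_size` derives the partition size from `partition_stride` (default 32). A worked example of that arithmetic, nothing more:

```python
# window/grid partition size = img_size // partition_stride, per dimension
partition_stride = 32
for img_size in (224, 256, 320, 384):
    print(img_size, img_size // partition_stride)
# 224 -> 7, 256 -> 8, 320 -> 10, 384 -> 12
```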
@@ -1442,6 +1481,7 @@ class MaxxVit(nn.Module):
     ):
         super().__init__()
         img_size = to_2tuple(img_size)
+        transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.num_features = cfg.embed_dim[-1]

@@ -1475,7 +1515,7 @@ class MaxxVit(nn.Module):
                 depth=cfg.depths[i],
                 block_types=cfg.block_type[i],
                 conv_cfg=cfg.conv_cfg,
-                transformer_cfg=cfg.transformer_cfg,
+                transformer_cfg=transformer_cfg,
                 feat_size=feat_size,
                 drop_path=dpr[i],
             )]
@@ -1658,6 +1698,11 @@ def maxvit_tiny_rw_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs)


+@register_model
+def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def maxvit_tiny_pm_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)

timm/models/mvitv2.py

@@ -57,6 +57,8 @@ default_cfgs = dict(
     mvitv2_huge_in21k=_cfg(
         url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_H_in21k.pyth',
         num_classes=19168),
+
+    mvitv2_small_cls=_cfg(url=''),
 )
@@ -135,6 +137,11 @@ model_cfgs = dict(
         num_heads=2,
         expand_attn=False,
     ),
+
+    mvitv2_small_cls=MultiScaleVitCfg(
+        depths=(1, 2, 11, 2),
+        use_cls_token=True,
+    ),
 )
@@ -641,7 +648,7 @@ class MultiScaleBlock(nn.Module):
         if self.shortcut_pool_attn is None:
             return x
         if self.has_cls_token:
-            cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :]
+            cls_tok, x = x[:, :1, :], x[:, 1:, :]
         else:
             cls_tok = None
         B, L, C = x.shape

@@ -650,7 +657,7 @@ class MultiScaleBlock(nn.Module):
         x = self.shortcut_pool_attn(x)
         x = x.reshape(B, C, -1).transpose(1, 2)
         if cls_tok is not None:
-            x = torch.cat((cls_tok, x), dim=1)
+            x = torch.cat((cls_tok, x), dim=1)
         return x

     def forward(self, x, feat_size: List[int]):
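The two one-line changes above fix class-token handling in the attention-shortcut pooling path: the tensor at that point is (B, L, C), so the token has to be split off and re-attached along dim 1, not dim 2. A shape-only sketch of the intended behaviour (standalone, not timm code):

```python
import torch

B, L, C = 2, 1 + 196, 96                 # class token + 14x14 patch tokens
x = torch.randn(B, L, C)

cls_tok, x = x[:, :1, :], x[:, 1:, :]    # split along the token dim (dim 1)
print(cls_tok.shape, x.shape)            # (2, 1, 96) (2, 196, 96)

# ... spatial pooling of the 196 patch tokens happens here ...

x = torch.cat((cls_tok, x), dim=1)       # re-attach along the token dim
print(x.shape)                           # (2, 197, 96)
```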
@@ -996,3 +1003,8 @@ def mvitv2_large(pretrained=False, **kwargs):
 # @register_model
 # def mvitv2_huge_in21k(pretrained=False, **kwargs):
 #     return _create_mvitv2('mvitv2_huge_in21k', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def mvitv2_small_cls(pretrained=False, **kwargs):
+    return _create_mvitv2('mvitv2_small_cls', pretrained=pretrained, **kwargs)
