@@ -39,7 +39,7 @@ Hacked together by / Copyright 2022, Ross Wightman
 import math
 from collections import OrderedDict
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from functools import partial
 from typing import Callable, Optional, Union, Tuple, List
@@ -108,10 +108,13 @@ default_cfgs = {
     # Experimental configs
     'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_nano_rw_256': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-3e790ce3.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_tiny_rw_224': _cfg(url=''),
     'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_rmlp_nano_rw_256': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
@@ -129,21 +132,30 @@ class MaxxVitTransformerCfg:
     dim_head: int = 32
     expand_ratio: float = 4.0
     expand_first: bool = True
-    shortcut_bias: bool = True,
+    shortcut_bias: bool = True
     attn_bias: bool = True
     attn_drop: float = 0.
     proj_drop: float = 0.
     pool_type: str = 'avg2'
     rel_pos_type: str = 'bias'
     rel_pos_dim: int = 512  # for relative position types w/ MLP
-    window_size: Tuple[int, int] = (7, 7)
-    grid_size: Tuple[int, int] = (7, 7)
+    partition_stride: int = 32
+    window_size: Optional[Tuple[int, int]] = None
+    grid_size: Optional[Tuple[int, int]] = None
     init_values: Optional[float] = None
     act_layer: str = 'gelu'
     norm_layer: str = 'layernorm2d'
     norm_layer_cl: str = 'layernorm'
     norm_eps: float = 1e-6
+
+    def __post_init__(self):
+        if self.grid_size is not None:
+            self.grid_size = to_2tuple(self.grid_size)
+        if self.window_size is not None:
+            self.window_size = to_2tuple(self.window_size)
+            if self.grid_size is None:
+                self.grid_size = self.window_size


 @dataclass
 class MaxxVitConvCfg:
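# A minimal sketch (not part of the patch) of how the new __post_init__ above is
# expected to resolve partition sizes; the values below are illustrative only.
cfg = MaxxVitTransformerCfg(window_size=8)   # scalar is expanded to a 2-tuple
assert cfg.window_size == (8, 8)
assert cfg.grid_size == (8, 8)               # grid_size defaults to window_size

cfg = MaxxVitTransformerCfg()                # both left as None
assert cfg.window_size is None and cfg.grid_size is None   # resolved later from the image size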
@@ -249,7 +261,7 @@ def _rw_max_cfg(
         conv_norm_layer='',
         transformer_norm_layer='layernorm2d',
         transformer_norm_layer_cl='layernorm',
-        window_size=7,
+        window_size=None,
         dim_head=32,
         rel_pos_type='bias',
         rel_pos_dim=512,
@@ -259,8 +271,6 @@ def _rw_max_cfg(
     # - mbconv expansion calculated from input instead of output chs
     # - mbconv shortcut and final 1x1 conv did not have a bias
     # - mbconv uses silu in timm, not gelu
-    # - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat)
-    # - default to avg pool for mbconv downsample instead of 1x1 or dw conv
     # - expansion in attention block done via output proj, not input proj
     return dict(
         conv_cfg=MaxxVitConvCfg(
@@ -276,8 +286,7 @@ def _rw_max_cfg(
             expand_first=False,
             pool_type=pool_type,
             dim_head=dim_head,
-            window_size=to_2tuple(window_size),
-            grid_size=to_2tuple(window_size),
+            window_size=window_size,
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
             rel_pos_type=rel_pos_type,
@@ -293,7 +302,7 @@ def _next_cfg(
         conv_norm_layer_cl='layernorm',
         transformer_norm_layer='layernorm2d',
         transformer_norm_layer_cl='layernorm',
-        window_size=7,
+        window_size=None,
         rel_pos_type='bias',
         rel_pos_dim=512,
 ):
@@ -310,8 +319,7 @@ def _next_cfg(
         transformer_cfg=MaxxVitTransformerCfg(
             expand_first=False,
             pool_type=pool_type,
-            window_size=to_2tuple(window_size),
-            grid_size=to_2tuple(window_size),
+            window_size=window_size,
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
             rel_pos_type=rel_pos_type,
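# Sketch (not part of the patch): window_size now defaults to None in both factory
# helpers and is passed through untouched; an explicit value can still be pinned and
# is normalized by MaxxVitTransformerCfg.__post_init__ rather than by to_2tuple() here.
kwargs = _rw_max_cfg(window_size=8)
assert kwargs['transformer_cfg'].window_size == (8, 8)
assert kwargs['transformer_cfg'].grid_size == (8, 8)

kwargs = _next_cfg()                         # None -> resolved at model build time
assert kwargs['transformer_cfg'].window_size is None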
@@ -411,18 +419,19 @@ model_cfgs = dict(
             rel_pos_dim=384,  # was supposed to be 512, woops
         ),
     ),
-    coatnext_nano_rw_224=MaxxVitCfg(
+    coatnet_nano_cc_224=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(3, 4, 6, 3),
         stem_width=(32, 64),
-        **_next_cfg(),
+        block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
+        **_rw_coat_cfg(),
     ),
-    coatnet_nano_cc_224=MaxxVitCfg(
+    coatnext_nano_rw_224=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(3, 4, 6, 3),
         stem_width=(32, 64),
-        block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
-        **_rw_coat_cfg(),
+        weight_init='normal',
+        **_next_cfg(),
     ),

     # Trying to be like the CoAtNet paper configs
@@ -463,14 +472,14 @@ model_cfgs = dict(
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(24, 32),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
     maxvit_nano_rw_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(1, 2, 3, 1),
         block_type=('M',) * 4,
         stem_width=(32, 64),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
     maxvit_tiny_rw_224=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
@@ -484,21 +493,29 @@ model_cfgs = dict(
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(32, 64),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
+    maxvit_rmlp_nano_rw_256=MaxxVitCfg(
+        embed_dim=(64, 128, 256, 512),
+        depths=(1, 2, 3, 1),
+        block_type=('M',) * 4,
+        stem_width=(32, 64),
+        **_rw_max_cfg(rel_pos_type='mlp'),
+    ),
     maxvit_tiny_pm_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
         block_type=('PM',) * 4,
         stem_width=(32, 64),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
     maxxvit_nano_rw_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(1, 2, 3, 1),
         block_type=('M',) * 4,
         stem_width=(32, 64),
-        **_next_cfg(window_size=8),
+        weight_init='normal',
+        **_next_cfg(),
     ),

     # Trying to be like the MaxViT paper configs
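# Sketch (not part of the patch): dropping window_size=8 from the 256px configs above
# is equivalent for 256x256 inputs, since the default partition_stride=32 now derives
# the same partition size from the image resolution at model build time:
img_size = (256, 256)
partition_size = img_size[0] // 32, img_size[1] // 32
assert partition_size == (8, 8)   # matches the previously hard-coded window/grid size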
@@ -651,7 +668,11 @@ class LayerScale2d(nn.Module):
 class Downsample2d(nn.Module):
-    """ A downsample pooling module for Coat that handles 2d <-> 1d conversion
+    """ A downsample pooling module supporting several maxpool and avgpool modes
+    * 'max' - MaxPool2d w/ kernel_size 3, stride 2, padding 1
+    * 'max2' - MaxPool2d w/ kernel_size = stride = 2
+    * 'avg' - AvgPool2d w/ kernel_size 3, stride 2, padding 1
+    * 'avg2' - AvgPool2d w/ kernel_size = stride = 2
     """

     def __init__(
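# Illustrative sketch (not part of the patch) of the four pool_type modes listed in the
# docstring above; the helper name below is hypothetical, not the module's actual code.
import torch.nn as nn

def _downsample_pool(pool_type: str) -> nn.Module:
    if pool_type == 'max':
        return nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    if pool_type == 'max2':
        return nn.MaxPool2d(kernel_size=2, stride=2)    # kernel_size == stride == 2
    if pool_type == 'avg':
        return nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
    return nn.AvgPool2d(kernel_size=2, stride=2)        # 'avg2'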
@@ -710,6 +731,11 @@ def _init_transformer(module, name, scheme=''):
 class TransformerBlock2d(nn.Module):
     """ Transformer block with 2D downsampling
+    '2D' NCHW tensor layout
+
+    Some gains can be seen on GPU using a 1D / CL block, BUT w/ the need to switch back/forth to NCHW
+    for spatial pooling, the benefit is minimal so ended up using just this variant for CoAt configs.
+    This impl was faster on TPU w/ PT XLA than the 1D experiment.
     """

     def __init__(
@@ -1011,9 +1037,9 @@ def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size):
     return rel_pos_cls


-class PartitionAttention(nn.Module):
+class PartitionAttentionCl(nn.Module):
     """ Grid or Block partition + Attn + FFN.
-    NxC tensor layout.
+    NxC 'channels last' tensor layout.
     """

     def __init__(
@@ -1183,6 +1209,7 @@ def grid_reverse_nchw(windows, grid_size: List[int], img_size: List[int]):
 class PartitionAttention2d(nn.Module):
     """ Grid or Block partition + Attn + FFN
+    '2D' NCHW tensor layout.
     """
@@ -1245,7 +1272,7 @@ class PartitionAttention2d(nn.Module):
 class MaxxVitBlock(nn.Module):
-    """
+    """ MaxVit conv, window partition + FFN, grid partition + FFN
     """

     def __init__(
@@ -1264,7 +1291,7 @@ class MaxxVitBlock(nn.Module):
         self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)

         attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
-        partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttention
+        partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl
         self.nchw_attn = use_nchw_attn
         self.attn_block = partition_layer(**attn_kwargs)
         self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs)
@@ -1288,7 +1315,8 @@ class MaxxVitBlock(nn.Module):
 class ParallelMaxxVitBlock(nn.Module):
-    """
+    """ MaxVit block with parallel cat(window + grid), one FF
+    Experimental timm block.
     """

     def __init__(
@@ -1426,8 +1454,19 @@ class Stem(nn.Module):
         return x


+def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
+    if cfg.window_size is not None:
+        assert cfg.grid_size
+        return cfg
+    partition_size = img_size[0] // cfg.partition_stride, img_size[1] // cfg.partition_stride
+    cfg = replace(cfg, window_size=partition_size, grid_size=partition_size)
+    return cfg
+
+
 class MaxxVit(nn.Module):
-    """
+    """ CoaTNet + MaxVit base model.
+
+    Highly configurable for different block compositions, tensor layouts, pooling types.
     """

     def __init__(
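# Sketch (not part of the patch): cfg_window_size() above fills in window/grid size from
# the image size when the config leaves them as None, returning a new cfg via
# dataclasses.replace() so the original config object is not mutated.
base = MaxxVitTransformerCfg()                       # window_size=None, grid_size=None
resolved = cfg_window_size(base, img_size=(224, 224))
assert resolved.window_size == (7, 7)                # 224 // partition_stride (32)
assert resolved.grid_size == (7, 7)
assert base.window_size is None                      # original left untouched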
@@ -1442,6 +1481,7 @@ class MaxxVit(nn.Module):
     ):
         super().__init__()
         img_size = to_2tuple(img_size)
+        transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.num_features = cfg.embed_dim[-1]
@@ -1475,7 +1515,7 @@ class MaxxVit(nn.Module):
                 depth=cfg.depths[i],
                 block_types=cfg.block_type[i],
                 conv_cfg=cfg.conv_cfg,
-                transformer_cfg=cfg.transformer_cfg,
+                transformer_cfg=transformer_cfg,
                 feat_size=feat_size,
                 drop_path=dpr[i],
             )]
@@ -1658,6 +1698,11 @@ def maxvit_tiny_rw_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs)


 @register_model
+def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)
+
+
+@register_model
 def maxvit_tiny_pm_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)
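# Usage sketch (not part of the patch): once registered, the new variant is available
# through the standard timm factory; its pretrained weights correspond to the
# maxvit_rmlp_nano_rw_256 URL added to default_cfgs above.
import timm

model = timm.create_model('maxvit_rmlp_nano_rw_256', pretrained=False)
model.eval()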