|
|
@ -39,7 +39,7 @@ Hacked together by / Copyright 2022, Ross Wightman
|
|
|
|
|
|
|
|
|
|
|
|
import math
|
|
|
|
import math
|
|
|
|
from collections import OrderedDict
|
|
|
|
from collections import OrderedDict
|
|
|
|
from dataclasses import dataclass
|
|
|
|
from dataclasses import dataclass, replace
|
|
|
|
from functools import partial
|
|
|
|
from functools import partial
|
|
|
|
from typing import Callable, Optional, Union, Tuple, List
|
|
|
|
from typing import Callable, Optional, Union, Tuple, List
|
|
|
|
|
|
|
|
|
|
|
@ -112,6 +112,9 @@ default_cfgs = {
|
|
|
|
input_size=(3, 256, 256), pool_size=(8, 8)),
|
|
|
|
input_size=(3, 256, 256), pool_size=(8, 8)),
|
|
|
|
'maxvit_tiny_rw_224': _cfg(url=''),
|
|
|
|
'maxvit_tiny_rw_224': _cfg(url=''),
|
|
|
|
'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
|
|
|
|
'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
|
|
|
|
|
|
|
|
'maxvit_rmlp_nano_rw_256': _cfg(
|
|
|
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
|
|
|
|
|
|
|
|
input_size=(3, 256, 256), pool_size=(8, 8)),
|
|
|
|
'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
|
|
|
|
'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
|
|
|
|
'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
|
|
|
|
'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
|
|
|
|
|
|
|
|
|
|
|
@ -136,14 +139,23 @@ class MaxxVitTransformerCfg:
|
|
|
|
pool_type: str = 'avg2'
|
|
|
|
pool_type: str = 'avg2'
|
|
|
|
rel_pos_type: str = 'bias'
|
|
|
|
rel_pos_type: str = 'bias'
|
|
|
|
rel_pos_dim: int = 512 # for relative position types w/ MLP
|
|
|
|
rel_pos_dim: int = 512 # for relative position types w/ MLP
|
|
|
|
window_size: Tuple[int, int] = (7, 7)
|
|
|
|
partition_stride: int = 32
|
|
|
|
grid_size: Tuple[int, int] = (7, 7)
|
|
|
|
window_size: Optional[Tuple[int, int]] = None
|
|
|
|
|
|
|
|
grid_size: Optional[Tuple[int, int]] = None
|
|
|
|
init_values: Optional[float] = None
|
|
|
|
init_values: Optional[float] = None
|
|
|
|
act_layer: str = 'gelu'
|
|
|
|
act_layer: str = 'gelu'
|
|
|
|
norm_layer: str = 'layernorm2d'
|
|
|
|
norm_layer: str = 'layernorm2d'
|
|
|
|
norm_layer_cl: str = 'layernorm'
|
|
|
|
norm_layer_cl: str = 'layernorm'
|
|
|
|
norm_eps: float = 1e-6
|
|
|
|
norm_eps: float = 1e-6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __post_init__(self):
|
|
|
|
|
|
|
|
if self.grid_size is not None:
|
|
|
|
|
|
|
|
self.grid_size = to_2tuple(self.grid_size)
|
|
|
|
|
|
|
|
if self.window_size is not None:
|
|
|
|
|
|
|
|
self.window_size = to_2tuple(self.window_size)
|
|
|
|
|
|
|
|
if self.grid_size is None:
|
|
|
|
|
|
|
|
self.grid_size = self.window_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
@dataclass
|
|
|
|
class MaxxVitConvCfg:
|
|
|
|
class MaxxVitConvCfg:
|
|
|
@ -249,7 +261,7 @@ def _rw_max_cfg(
|
|
|
|
conv_norm_layer='',
|
|
|
|
conv_norm_layer='',
|
|
|
|
transformer_norm_layer='layernorm2d',
|
|
|
|
transformer_norm_layer='layernorm2d',
|
|
|
|
transformer_norm_layer_cl='layernorm',
|
|
|
|
transformer_norm_layer_cl='layernorm',
|
|
|
|
window_size=7,
|
|
|
|
window_size=None,
|
|
|
|
dim_head=32,
|
|
|
|
dim_head=32,
|
|
|
|
rel_pos_type='bias',
|
|
|
|
rel_pos_type='bias',
|
|
|
|
rel_pos_dim=512,
|
|
|
|
rel_pos_dim=512,
|
|
|
@ -274,8 +286,7 @@ def _rw_max_cfg(
|
|
|
|
expand_first=False,
|
|
|
|
expand_first=False,
|
|
|
|
pool_type=pool_type,
|
|
|
|
pool_type=pool_type,
|
|
|
|
dim_head=dim_head,
|
|
|
|
dim_head=dim_head,
|
|
|
|
window_size=to_2tuple(window_size),
|
|
|
|
window_size=window_size,
|
|
|
|
grid_size=to_2tuple(window_size),
|
|
|
|
|
|
|
|
norm_layer=transformer_norm_layer,
|
|
|
|
norm_layer=transformer_norm_layer,
|
|
|
|
norm_layer_cl=transformer_norm_layer_cl,
|
|
|
|
norm_layer_cl=transformer_norm_layer_cl,
|
|
|
|
rel_pos_type=rel_pos_type,
|
|
|
|
rel_pos_type=rel_pos_type,
|
|
|
@ -291,7 +302,7 @@ def _next_cfg(
|
|
|
|
conv_norm_layer_cl='layernorm',
|
|
|
|
conv_norm_layer_cl='layernorm',
|
|
|
|
transformer_norm_layer='layernorm2d',
|
|
|
|
transformer_norm_layer='layernorm2d',
|
|
|
|
transformer_norm_layer_cl='layernorm',
|
|
|
|
transformer_norm_layer_cl='layernorm',
|
|
|
|
window_size=7,
|
|
|
|
window_size=None,
|
|
|
|
rel_pos_type='bias',
|
|
|
|
rel_pos_type='bias',
|
|
|
|
rel_pos_dim=512,
|
|
|
|
rel_pos_dim=512,
|
|
|
|
):
|
|
|
|
):
|
|
|
@ -308,8 +319,7 @@ def _next_cfg(
|
|
|
|
transformer_cfg=MaxxVitTransformerCfg(
|
|
|
|
transformer_cfg=MaxxVitTransformerCfg(
|
|
|
|
expand_first=False,
|
|
|
|
expand_first=False,
|
|
|
|
pool_type=pool_type,
|
|
|
|
pool_type=pool_type,
|
|
|
|
window_size=to_2tuple(window_size),
|
|
|
|
window_size=window_size,
|
|
|
|
grid_size=to_2tuple(window_size),
|
|
|
|
|
|
|
|
norm_layer=transformer_norm_layer,
|
|
|
|
norm_layer=transformer_norm_layer,
|
|
|
|
norm_layer_cl=transformer_norm_layer_cl,
|
|
|
|
norm_layer_cl=transformer_norm_layer_cl,
|
|
|
|
rel_pos_type=rel_pos_type,
|
|
|
|
rel_pos_type=rel_pos_type,
|
|
|
@ -462,14 +472,14 @@ model_cfgs = dict(
|
|
|
|
depths=(2, 2, 5, 2),
|
|
|
|
depths=(2, 2, 5, 2),
|
|
|
|
block_type=('M',) * 4,
|
|
|
|
block_type=('M',) * 4,
|
|
|
|
stem_width=(24, 32),
|
|
|
|
stem_width=(24, 32),
|
|
|
|
**_rw_max_cfg(window_size=8),
|
|
|
|
**_rw_max_cfg(),
|
|
|
|
),
|
|
|
|
),
|
|
|
|
maxvit_nano_rw_256=MaxxVitCfg(
|
|
|
|
maxvit_nano_rw_256=MaxxVitCfg(
|
|
|
|
embed_dim=(64, 128, 256, 512),
|
|
|
|
embed_dim=(64, 128, 256, 512),
|
|
|
|
depths=(1, 2, 3, 1),
|
|
|
|
depths=(1, 2, 3, 1),
|
|
|
|
block_type=('M',) * 4,
|
|
|
|
block_type=('M',) * 4,
|
|
|
|
stem_width=(32, 64),
|
|
|
|
stem_width=(32, 64),
|
|
|
|
**_rw_max_cfg(window_size=8),
|
|
|
|
**_rw_max_cfg(),
|
|
|
|
),
|
|
|
|
),
|
|
|
|
maxvit_tiny_rw_224=MaxxVitCfg(
|
|
|
|
maxvit_tiny_rw_224=MaxxVitCfg(
|
|
|
|
embed_dim=(64, 128, 256, 512),
|
|
|
|
embed_dim=(64, 128, 256, 512),
|
|
|
@ -483,14 +493,21 @@ model_cfgs = dict(
|
|
|
|
depths=(2, 2, 5, 2),
|
|
|
|
depths=(2, 2, 5, 2),
|
|
|
|
block_type=('M',) * 4,
|
|
|
|
block_type=('M',) * 4,
|
|
|
|
stem_width=(32, 64),
|
|
|
|
stem_width=(32, 64),
|
|
|
|
**_rw_max_cfg(window_size=8),
|
|
|
|
**_rw_max_cfg(),
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
maxvit_rmlp_nano_rw_256=MaxxVitCfg(
|
|
|
|
|
|
|
|
embed_dim=(64, 128, 256, 512),
|
|
|
|
|
|
|
|
depths=(1, 2, 3, 1),
|
|
|
|
|
|
|
|
block_type=('M',) * 4,
|
|
|
|
|
|
|
|
stem_width=(32, 64),
|
|
|
|
|
|
|
|
**_rw_max_cfg(rel_pos_type='mlp'),
|
|
|
|
),
|
|
|
|
),
|
|
|
|
maxvit_tiny_pm_256=MaxxVitCfg(
|
|
|
|
maxvit_tiny_pm_256=MaxxVitCfg(
|
|
|
|
embed_dim=(64, 128, 256, 512),
|
|
|
|
embed_dim=(64, 128, 256, 512),
|
|
|
|
depths=(2, 2, 5, 2),
|
|
|
|
depths=(2, 2, 5, 2),
|
|
|
|
block_type=('PM',) * 4,
|
|
|
|
block_type=('PM',) * 4,
|
|
|
|
stem_width=(32, 64),
|
|
|
|
stem_width=(32, 64),
|
|
|
|
**_rw_max_cfg(window_size=8),
|
|
|
|
**_rw_max_cfg(),
|
|
|
|
),
|
|
|
|
),
|
|
|
|
maxxvit_nano_rw_256=MaxxVitCfg(
|
|
|
|
maxxvit_nano_rw_256=MaxxVitCfg(
|
|
|
|
embed_dim=(64, 128, 256, 512),
|
|
|
|
embed_dim=(64, 128, 256, 512),
|
|
|
@ -498,7 +515,7 @@ model_cfgs = dict(
|
|
|
|
block_type=('M',) * 4,
|
|
|
|
block_type=('M',) * 4,
|
|
|
|
stem_width=(32, 64),
|
|
|
|
stem_width=(32, 64),
|
|
|
|
weight_init='normal',
|
|
|
|
weight_init='normal',
|
|
|
|
**_next_cfg(window_size=8),
|
|
|
|
**_next_cfg(),
|
|
|
|
),
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
|
|
# Trying to be like the MaxViT paper configs
|
|
|
|
# Trying to be like the MaxViT paper configs
|
|
|
@ -1437,6 +1454,15 @@ class Stem(nn.Module):
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
|
|
|
|
|
|
|
|
if cfg.window_size is not None:
|
|
|
|
|
|
|
|
assert cfg.grid_size
|
|
|
|
|
|
|
|
return cfg
|
|
|
|
|
|
|
|
partition_size = img_size[0] // cfg.partition_stride, img_size[1] // cfg.partition_stride
|
|
|
|
|
|
|
|
cfg = replace(cfg, window_size=partition_size, grid_size=partition_size)
|
|
|
|
|
|
|
|
return cfg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MaxxVit(nn.Module):
|
|
|
|
class MaxxVit(nn.Module):
|
|
|
|
""" CoaTNet + MaxVit base model.
|
|
|
|
""" CoaTNet + MaxVit base model.
|
|
|
|
|
|
|
|
|
|
|
@ -1455,6 +1481,7 @@ class MaxxVit(nn.Module):
|
|
|
|
):
|
|
|
|
):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
img_size = to_2tuple(img_size)
|
|
|
|
img_size = to_2tuple(img_size)
|
|
|
|
|
|
|
|
transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.global_pool = global_pool
|
|
|
|
self.global_pool = global_pool
|
|
|
|
self.num_features = cfg.embed_dim[-1]
|
|
|
|
self.num_features = cfg.embed_dim[-1]
|
|
|
@ -1488,7 +1515,7 @@ class MaxxVit(nn.Module):
|
|
|
|
depth=cfg.depths[i],
|
|
|
|
depth=cfg.depths[i],
|
|
|
|
block_types=cfg.block_type[i],
|
|
|
|
block_types=cfg.block_type[i],
|
|
|
|
conv_cfg=cfg.conv_cfg,
|
|
|
|
conv_cfg=cfg.conv_cfg,
|
|
|
|
transformer_cfg=cfg.transformer_cfg,
|
|
|
|
transformer_cfg=transformer_cfg,
|
|
|
|
feat_size=feat_size,
|
|
|
|
feat_size=feat_size,
|
|
|
|
drop_path=dpr[i],
|
|
|
|
drop_path=dpr[i],
|
|
|
|
)]
|
|
|
|
)]
|
|
|
@ -1671,6 +1698,11 @@ def maxvit_tiny_rw_256(pretrained=False, **kwargs):
|
|
|
|
return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs)
|
|
|
|
return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def maxvit_tiny_pm_256(pretrained=False, **kwargs):
|
|
|
|
def maxvit_tiny_pm_256(pretrained=False, **kwargs):
|
|
|
|
return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)
|
|
|
|
return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)
|
|
|
|