Add maxvit_rmlp_nano_rw_256 model def & weights, make window/grid size dynamic wrt img_size by default

pull/804/merge
Ross Wightman 2 years ago
parent e6a4361306
commit 7f1b223c02

@ -39,7 +39,7 @@ Hacked together by / Copyright 2022, Ross Wightman
import math import math
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass from dataclasses import dataclass, replace
from functools import partial from functools import partial
from typing import Callable, Optional, Union, Tuple, List from typing import Callable, Optional, Union, Tuple, List
@ -112,6 +112,9 @@ default_cfgs = {
input_size=(3, 256, 256), pool_size=(8, 8)), input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_tiny_rw_224': _cfg(url=''), 'maxvit_tiny_rw_224': _cfg(url=''),
'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), 'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_nano_rw_256': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), 'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), 'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
@ -136,14 +139,23 @@ class MaxxVitTransformerCfg:
pool_type: str = 'avg2' pool_type: str = 'avg2'
rel_pos_type: str = 'bias' rel_pos_type: str = 'bias'
rel_pos_dim: int = 512 # for relative position types w/ MLP rel_pos_dim: int = 512 # for relative position types w/ MLP
window_size: Tuple[int, int] = (7, 7) partition_stride: int = 32
grid_size: Tuple[int, int] = (7, 7) window_size: Optional[Tuple[int, int]] = None
grid_size: Optional[Tuple[int, int]] = None
init_values: Optional[float] = None init_values: Optional[float] = None
act_layer: str = 'gelu' act_layer: str = 'gelu'
norm_layer: str = 'layernorm2d' norm_layer: str = 'layernorm2d'
norm_layer_cl: str = 'layernorm' norm_layer_cl: str = 'layernorm'
norm_eps: float = 1e-6 norm_eps: float = 1e-6
def __post_init__(self):
    # Normalize user-supplied partition sizes to 2-tuples after dataclass init.
    # NOTE(review): nesting reconstructed from the diff rendering — the
    # grid_size-defaults-to-window_size branch belongs under the window_size
    # check, so grid_size only inherits window_size when one was given.
    if self.grid_size is not None:
        self.grid_size = to_2tuple(self.grid_size)
    if self.window_size is not None:
        self.window_size = to_2tuple(self.window_size)
        # grid partition defaults to the window partition when unset
        if self.grid_size is None:
            self.grid_size = self.window_size
@dataclass @dataclass
class MaxxVitConvCfg: class MaxxVitConvCfg:
@ -249,7 +261,7 @@ def _rw_max_cfg(
conv_norm_layer='', conv_norm_layer='',
transformer_norm_layer='layernorm2d', transformer_norm_layer='layernorm2d',
transformer_norm_layer_cl='layernorm', transformer_norm_layer_cl='layernorm',
window_size=7, window_size=None,
dim_head=32, dim_head=32,
rel_pos_type='bias', rel_pos_type='bias',
rel_pos_dim=512, rel_pos_dim=512,
@ -274,8 +286,7 @@ def _rw_max_cfg(
expand_first=False, expand_first=False,
pool_type=pool_type, pool_type=pool_type,
dim_head=dim_head, dim_head=dim_head,
window_size=to_2tuple(window_size), window_size=window_size,
grid_size=to_2tuple(window_size),
norm_layer=transformer_norm_layer, norm_layer=transformer_norm_layer,
norm_layer_cl=transformer_norm_layer_cl, norm_layer_cl=transformer_norm_layer_cl,
rel_pos_type=rel_pos_type, rel_pos_type=rel_pos_type,
@ -291,7 +302,7 @@ def _next_cfg(
conv_norm_layer_cl='layernorm', conv_norm_layer_cl='layernorm',
transformer_norm_layer='layernorm2d', transformer_norm_layer='layernorm2d',
transformer_norm_layer_cl='layernorm', transformer_norm_layer_cl='layernorm',
window_size=7, window_size=None,
rel_pos_type='bias', rel_pos_type='bias',
rel_pos_dim=512, rel_pos_dim=512,
): ):
@ -308,8 +319,7 @@ def _next_cfg(
transformer_cfg=MaxxVitTransformerCfg( transformer_cfg=MaxxVitTransformerCfg(
expand_first=False, expand_first=False,
pool_type=pool_type, pool_type=pool_type,
window_size=to_2tuple(window_size), window_size=window_size,
grid_size=to_2tuple(window_size),
norm_layer=transformer_norm_layer, norm_layer=transformer_norm_layer,
norm_layer_cl=transformer_norm_layer_cl, norm_layer_cl=transformer_norm_layer_cl,
rel_pos_type=rel_pos_type, rel_pos_type=rel_pos_type,
@ -462,14 +472,14 @@ model_cfgs = dict(
depths=(2, 2, 5, 2), depths=(2, 2, 5, 2),
block_type=('M',) * 4, block_type=('M',) * 4,
stem_width=(24, 32), stem_width=(24, 32),
**_rw_max_cfg(window_size=8), **_rw_max_cfg(),
), ),
maxvit_nano_rw_256=MaxxVitCfg( maxvit_nano_rw_256=MaxxVitCfg(
embed_dim=(64, 128, 256, 512), embed_dim=(64, 128, 256, 512),
depths=(1, 2, 3, 1), depths=(1, 2, 3, 1),
block_type=('M',) * 4, block_type=('M',) * 4,
stem_width=(32, 64), stem_width=(32, 64),
**_rw_max_cfg(window_size=8), **_rw_max_cfg(),
), ),
maxvit_tiny_rw_224=MaxxVitCfg( maxvit_tiny_rw_224=MaxxVitCfg(
embed_dim=(64, 128, 256, 512), embed_dim=(64, 128, 256, 512),
@ -483,14 +493,21 @@ model_cfgs = dict(
depths=(2, 2, 5, 2), depths=(2, 2, 5, 2),
block_type=('M',) * 4, block_type=('M',) * 4,
stem_width=(32, 64), stem_width=(32, 64),
**_rw_max_cfg(window_size=8), **_rw_max_cfg(),
),
maxvit_rmlp_nano_rw_256=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(1, 2, 3, 1),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(rel_pos_type='mlp'),
), ),
maxvit_tiny_pm_256=MaxxVitCfg( maxvit_tiny_pm_256=MaxxVitCfg(
embed_dim=(64, 128, 256, 512), embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2), depths=(2, 2, 5, 2),
block_type=('PM',) * 4, block_type=('PM',) * 4,
stem_width=(32, 64), stem_width=(32, 64),
**_rw_max_cfg(window_size=8), **_rw_max_cfg(),
), ),
maxxvit_nano_rw_256=MaxxVitCfg( maxxvit_nano_rw_256=MaxxVitCfg(
embed_dim=(64, 128, 256, 512), embed_dim=(64, 128, 256, 512),
@ -498,7 +515,7 @@ model_cfgs = dict(
block_type=('M',) * 4, block_type=('M',) * 4,
stem_width=(32, 64), stem_width=(32, 64),
weight_init='normal', weight_init='normal',
**_next_cfg(window_size=8), **_next_cfg(),
), ),
# Trying to be like the MaxViT paper configs # Trying to be like the MaxViT paper configs
@ -1437,6 +1454,15 @@ class Stem(nn.Module):
return x return x
def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
    """Return a transformer cfg with window/grid partition sizes resolved for img_size.

    A cfg that already carries an explicit window_size is returned as-is
    (its grid_size must have been filled in alongside it). Otherwise both
    partition sizes are derived by integer-dividing the image dims by
    cfg.partition_stride, and a modified copy of the cfg is returned.
    """
    if cfg.window_size is not None:
        # explicitly configured; grid_size is expected to accompany it
        assert cfg.grid_size
        return cfg
    stride = cfg.partition_stride
    partition_size = (img_size[0] // stride, img_size[1] // stride)
    return replace(cfg, window_size=partition_size, grid_size=partition_size)
class MaxxVit(nn.Module): class MaxxVit(nn.Module):
""" CoaTNet + MaxVit base model. """ CoaTNet + MaxVit base model.
@ -1455,6 +1481,7 @@ class MaxxVit(nn.Module):
): ):
super().__init__() super().__init__()
img_size = to_2tuple(img_size) img_size = to_2tuple(img_size)
transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
self.num_classes = num_classes self.num_classes = num_classes
self.global_pool = global_pool self.global_pool = global_pool
self.num_features = cfg.embed_dim[-1] self.num_features = cfg.embed_dim[-1]
@ -1488,7 +1515,7 @@ class MaxxVit(nn.Module):
depth=cfg.depths[i], depth=cfg.depths[i],
block_types=cfg.block_type[i], block_types=cfg.block_type[i],
conv_cfg=cfg.conv_cfg, conv_cfg=cfg.conv_cfg,
transformer_cfg=cfg.transformer_cfg, transformer_cfg=transformer_cfg,
feat_size=feat_size, feat_size=feat_size,
drop_path=dpr[i], drop_path=dpr[i],
)] )]
@ -1671,6 +1698,11 @@ def maxvit_tiny_rw_256(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs) return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs):
    """MaxViT nano variant w/ MLP relative-position bias (rel_pos_type='mlp'), 256x256 input."""
    return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)
@register_model @register_model
def maxvit_tiny_pm_256(pretrained=False, **kwargs): def maxvit_tiny_pm_256(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs) return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)

Loading…
Cancel
Save