Add maxvit_rmlp_nano_rw_256 model def & weights, make window/grid size dynamic wrt img_size by default

pull/804/merge
Ross Wightman 2 years ago
parent e6a4361306
commit 7f1b223c02
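In practice, "window/grid size dynamic wrt img_size" means both attention partition sizes are derived from the input resolution instead of being hard-coded per config. A minimal sketch of the rule this commit introduces (mirroring the cfg_window_size helper further down; partition_stride defaults to 32):

    # window and grid partitions both default to img_size // partition_stride
    img_size, partition_stride = (256, 256), 32
    window_size = grid_size = (img_size[0] // partition_stride, img_size[1] // partition_stride)
    # -> (8, 8) at 256x256 input, (7, 7) at 224x224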

@@ -39,7 +39,7 @@ Hacked together by / Copyright 2022, Ross Wightman
 import math
 from collections import OrderedDict
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from functools import partial
 from typing import Callable, Optional, Union, Tuple, List
@@ -112,6 +112,9 @@ default_cfgs = {
         input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_tiny_rw_224': _cfg(url=''),
     'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_rmlp_nano_rw_256': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
 
     'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
@@ -136,14 +139,23 @@ class MaxxVitTransformerCfg:
     pool_type: str = 'avg2'
     rel_pos_type: str = 'bias'
     rel_pos_dim: int = 512  # for relative position types w/ MLP
-    window_size: Tuple[int, int] = (7, 7)
-    grid_size: Tuple[int, int] = (7, 7)
+    partition_stride: int = 32
+    window_size: Optional[Tuple[int, int]] = None
+    grid_size: Optional[Tuple[int, int]] = None
     init_values: Optional[float] = None
     act_layer: str = 'gelu'
     norm_layer: str = 'layernorm2d'
     norm_layer_cl: str = 'layernorm'
     norm_eps: float = 1e-6
 
+    def __post_init__(self):
+        if self.grid_size is not None:
+            self.grid_size = to_2tuple(self.grid_size)
+        if self.window_size is not None:
+            self.window_size = to_2tuple(self.window_size)
+            if self.grid_size is None:
+                self.grid_size = self.window_size
+
 
 @dataclass
 class MaxxVitConvCfg:
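The new __post_init__ lets callers pass a plain int and get a normalized 2-tuple back, with grid_size falling back to window_size, while unset fields stay None for later resolution. A small usage sketch, assuming the import path below (it may differ across timm versions):

    from timm.models.maxxvit import MaxxVitTransformerCfg  # assumed import path

    cfg = MaxxVitTransformerCfg(window_size=8)   # int normalized by __post_init__
    assert cfg.window_size == (8, 8)
    assert cfg.grid_size == (8, 8)               # grid falls back to window when unset

    cfg = MaxxVitTransformerCfg()                # defaults: both stay None until resolved from img_size
    assert cfg.window_size is None and cfg.grid_size is None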
@@ -249,7 +261,7 @@ def _rw_max_cfg(
         conv_norm_layer='',
         transformer_norm_layer='layernorm2d',
         transformer_norm_layer_cl='layernorm',
-        window_size=7,
+        window_size=None,
         dim_head=32,
         rel_pos_type='bias',
         rel_pos_dim=512,
@@ -274,8 +286,7 @@ def _rw_max_cfg(
             expand_first=False,
             pool_type=pool_type,
             dim_head=dim_head,
-            window_size=to_2tuple(window_size),
-            grid_size=to_2tuple(window_size),
+            window_size=window_size,
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
             rel_pos_type=rel_pos_type,
@@ -291,7 +302,7 @@ def _next_cfg(
         conv_norm_layer_cl='layernorm',
         transformer_norm_layer='layernorm2d',
         transformer_norm_layer_cl='layernorm',
-        window_size=7,
+        window_size=None,
         rel_pos_type='bias',
         rel_pos_dim=512,
 ):
@@ -308,8 +319,7 @@ def _next_cfg(
         transformer_cfg=MaxxVitTransformerCfg(
             expand_first=False,
             pool_type=pool_type,
-            window_size=to_2tuple(window_size),
-            grid_size=to_2tuple(window_size),
+            window_size=window_size,
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
             rel_pos_type=rel_pos_type,
@@ -462,14 +472,14 @@ model_cfgs = dict(
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(24, 32),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
     maxvit_nano_rw_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(1, 2, 3, 1),
         block_type=('M',) * 4,
         stem_width=(32, 64),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
     maxvit_tiny_rw_224=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
@@ -483,14 +493,21 @@ model_cfgs = dict(
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(32, 64),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
+    maxvit_rmlp_nano_rw_256=MaxxVitCfg(
+        embed_dim=(64, 128, 256, 512),
+        depths=(1, 2, 3, 1),
+        block_type=('M',) * 4,
+        stem_width=(32, 64),
+        **_rw_max_cfg(rel_pos_type='mlp'),
+    ),
     maxvit_tiny_pm_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
         block_type=('PM',) * 4,
         stem_width=(32, 64),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
+
     maxxvit_nano_rw_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
@@ -498,7 +515,7 @@ model_cfgs = dict(
         block_type=('M',) * 4,
         stem_width=(32, 64),
         weight_init='normal',
-        **_next_cfg(window_size=8),
+        **_next_cfg(),
     ),
 
     # Trying to be like the MaxViT paper configs
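Dropping window_size=8 from these 256-input configs is behavior-preserving: with the default partition_stride of 32, the dynamic rule resolves to the same size these models used before.

    256 // 32  # == 8, i.e. the same (8, 8) windows/grids as the old fixed window_size=8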
@@ -1437,6 +1454,15 @@ class Stem(nn.Module):
         return x
 
 
+def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
+    if cfg.window_size is not None:
+        assert cfg.grid_size
+        return cfg
+    partition_size = img_size[0] // cfg.partition_stride, img_size[1] // cfg.partition_stride
+    cfg = replace(cfg, window_size=partition_size, grid_size=partition_size)
+    return cfg
+
+
 class MaxxVit(nn.Module):
     """ CoaTNet + MaxVit base model.
@@ -1455,6 +1481,7 @@ class MaxxVit(nn.Module):
     ):
         super().__init__()
         img_size = to_2tuple(img_size)
+        transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.num_features = cfg.embed_dim[-1]
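MaxxVit.__init__ now resolves the transformer cfg once from img_size via cfg_window_size (which relies on the dataclasses.replace import added at the top of the file to build a per-model copy). A worked sketch of that resolution, assuming the cfg_window_size and MaxxVitTransformerCfg definitions from this diff are in scope:

    cfg = MaxxVitTransformerCfg()                  # window_size=None, grid_size=None
    r256 = cfg_window_size(cfg, (256, 256))        # -> window_size=(8, 8), grid_size=(8, 8)
    r224 = cfg_window_size(cfg, (224, 224))        # -> window_size=(7, 7), grid_size=(7, 7)

    pinned = MaxxVitTransformerCfg(window_size=8)  # explicit size short-circuits the helper
    assert cfg_window_size(pinned, (384, 384)) is pinned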
@@ -1488,7 +1515,7 @@ class MaxxVit(nn.Module):
                 depth=cfg.depths[i],
                 block_types=cfg.block_type[i],
                 conv_cfg=cfg.conv_cfg,
-                transformer_cfg=cfg.transformer_cfg,
+                transformer_cfg=transformer_cfg,
                 feat_size=feat_size,
                 drop_path=dpr[i],
             )]
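Since every stage now receives the img_size-resolved transformer_cfg, building the same architecture at a different resolution changes the partition sizes automatically. A hedged usage sketch (kwarg pass-through of img_size via create_model is assumed to behave as for other timm models):

    import timm

    m256 = timm.create_model('maxvit_rmlp_nano_rw_256', pretrained=False)                # 256x256 -> 8x8 partitions
    m224 = timm.create_model('maxvit_rmlp_nano_rw_256', pretrained=False, img_size=224)  # 224x224 -> 7x7 partitions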
@@ -1671,6 +1698,11 @@ def maxvit_tiny_rw_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs)
 
 
+@register_model
+def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def maxvit_tiny_pm_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)
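For completeness, the newly registered entry point can be created with the pretrained weights referenced in the default_cfg above; a minimal sketch:

    import timm
    import torch

    model = timm.create_model('maxvit_rmlp_nano_rw_256', pretrained=True).eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 256, 256))  # input_size from the default_cfg
    print(out.shape)  # torch.Size([1, 1000]) with the default classifier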
