diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py
index de2c9fb8..495f682b 100644
--- a/timm/models/maxxvit.py
+++ b/timm/models/maxxvit.py
@@ -39,7 +39,7 @@ Hacked together by / Copyright 2022, Ross Wightman
 
 import math
 from collections import OrderedDict
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from functools import partial
 from typing import Callable, Optional, Union, Tuple, List
 
@@ -112,6 +112,9 @@ default_cfgs = {
         input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_tiny_rw_224': _cfg(url=''),
     'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_rmlp_nano_rw_256': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
 
     'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
@@ -136,14 +139,23 @@ class MaxxVitTransformerCfg:
     pool_type: str = 'avg2'
     rel_pos_type: str = 'bias'
     rel_pos_dim: int = 512  # for relative position types w/ MLP
-    window_size: Tuple[int, int] = (7, 7)
-    grid_size: Tuple[int, int] = (7, 7)
+    partition_stride: int = 32
+    window_size: Optional[Tuple[int, int]] = None
+    grid_size: Optional[Tuple[int, int]] = None
     init_values: Optional[float] = None
     act_layer: str = 'gelu'
     norm_layer: str = 'layernorm2d'
     norm_layer_cl: str = 'layernorm'
     norm_eps: float = 1e-6
 
+    def __post_init__(self):
+        if self.grid_size is not None:
+            self.grid_size = to_2tuple(self.grid_size)
+        if self.window_size is not None:
+            self.window_size = to_2tuple(self.window_size)
+            if self.grid_size is None:
+                self.grid_size = self.window_size
+
 
 @dataclass
 class MaxxVitConvCfg:
@@ -249,7 +261,7 @@ def _rw_max_cfg(
         conv_norm_layer='',
         transformer_norm_layer='layernorm2d',
         transformer_norm_layer_cl='layernorm',
-        window_size=7,
+        window_size=None,
         dim_head=32,
         rel_pos_type='bias',
         rel_pos_dim=512,
@@ -274,8 +286,7 @@ def _rw_max_cfg(
             expand_first=False,
             pool_type=pool_type,
             dim_head=dim_head,
-            window_size=to_2tuple(window_size),
-            grid_size=to_2tuple(window_size),
+            window_size=window_size,
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
             rel_pos_type=rel_pos_type,
@@ -291,7 +302,7 @@ def _next_cfg(
         conv_norm_layer_cl='layernorm',
         transformer_norm_layer='layernorm2d',
         transformer_norm_layer_cl='layernorm',
-        window_size=7,
+        window_size=None,
        rel_pos_type='bias',
         rel_pos_dim=512,
 ):
@@ -308,8 +319,7 @@ def _next_cfg(
         transformer_cfg=MaxxVitTransformerCfg(
             expand_first=False,
             pool_type=pool_type,
-            window_size=to_2tuple(window_size),
-            grid_size=to_2tuple(window_size),
+            window_size=window_size,
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
             rel_pos_type=rel_pos_type,
@@ -462,14 +472,14 @@ model_cfgs = dict(
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(24, 32),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
     maxvit_nano_rw_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(1, 2, 3, 1),
         block_type=('M',) * 4,
         stem_width=(32, 64),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
     maxvit_tiny_rw_224=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
@@ -483,14 +493,21 @@
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(32, 64),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
+    ),
+    maxvit_rmlp_nano_rw_256=MaxxVitCfg(
+        embed_dim=(64, 128, 256, 512),
+        depths=(1, 2, 3, 1),
+        block_type=('M',) * 4,
+        stem_width=(32, 64),
+        **_rw_max_cfg(rel_pos_type='mlp'),
     ),
     maxvit_tiny_pm_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
         block_type=('PM',) * 4,
         stem_width=(32, 64),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
     maxxvit_nano_rw_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
@@ -498,7 +515,7 @@ model_cfgs = dict(
         block_type=('M',) * 4,
         stem_width=(32, 64),
         weight_init='normal',
-        **_next_cfg(window_size=8),
+        **_next_cfg(),
     ),
 
     # Trying to be like the MaxViT paper configs
@@ -1437,6 +1454,15 @@ class Stem(nn.Module):
         return x
 
 
+def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
+    if cfg.window_size is not None:
+        assert cfg.grid_size
+        return cfg
+    partition_size = img_size[0] // cfg.partition_stride, img_size[1] // cfg.partition_stride
+    cfg = replace(cfg, window_size=partition_size, grid_size=partition_size)
+    return cfg
+
+
 class MaxxVit(nn.Module):
     """ CoaTNet + MaxVit base model.
 
@@ -1455,6 +1481,7 @@ class MaxxVit(nn.Module):
     ):
         super().__init__()
         img_size = to_2tuple(img_size)
+        transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.num_features = cfg.embed_dim[-1]
@@ -1488,7 +1515,7 @@ class MaxxVit(nn.Module):
                 depth=cfg.depths[i],
                 block_types=cfg.block_type[i],
                 conv_cfg=cfg.conv_cfg,
-                transformer_cfg=cfg.transformer_cfg,
+                transformer_cfg=transformer_cfg,
                 feat_size=feat_size,
                 drop_path=dpr[i],
             )]
@@ -1671,6 +1698,11 @@ def maxvit_tiny_rw_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs)
 
 
+@register_model
+def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def maxvit_tiny_pm_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)
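
Usage sketch (illustrative, not part of the patch): the change above stops baking a fixed window/grid size into each config and instead derives the partition size from the input resolution at model build time (img_size // partition_stride) whenever window_size/grid_size are left as None. A minimal sketch of the resulting behaviour, assuming the patch is applied and importing helpers directly from the patched timm.models.maxxvit module; the values shown follow from the default partition_stride of 32:

    # Sketch only: MaxxVitTransformerCfg and cfg_window_size are as defined in this patch.
    from timm.models.maxxvit import MaxxVitTransformerCfg, cfg_window_size

    cfg = MaxxVitTransformerCfg()                     # window_size/grid_size default to None
    resolved = cfg_window_size(cfg, (256, 256))       # derive partition size from image size
    assert resolved.window_size == (8, 8)             # 256 // 32
    assert resolved.grid_size == (8, 8)

    cfg_fixed = MaxxVitTransformerCfg(window_size=7)  # an explicit size is still honored;
    assert cfg_fixed.window_size == (7, 7)            # __post_init__ normalizes to a 2-tuple
    assert cfg_fixed.grid_size == (7, 7)              # and grid_size defaults to window_size

The newly registered entry builds like any other timm model, e.g. timm.create_model('maxvit_rmlp_nano_rw_256', pretrained=False).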