diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 495f682b..f1df148b 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -110,8 +110,14 @@ default_cfgs = { 'maxvit_nano_rw_256': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_tiny_rw_224': _cfg(url=''), - 'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_tiny_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'), + 'maxvit_tiny_rw_256': _cfg( + url='', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_rmlp_pico_rw_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), 'maxvit_rmlp_nano_rw_256': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth', input_size=(3, 256, 256), pool_size=(8, 8)), @@ -139,7 +145,7 @@ class MaxxVitTransformerCfg: pool_type: str = 'avg2' rel_pos_type: str = 'bias' rel_pos_dim: int = 512 # for relative position types w/ MLP - partition_stride: int = 32 + partition_ratio: int = 32 window_size: Optional[Tuple[int, int]] = None grid_size: Optional[Tuple[int, int]] = None init_values: Optional[float] = None @@ -495,6 +501,13 @@ model_cfgs = dict( stem_width=(32, 64), **_rw_max_cfg(), ), + maxvit_rmlp_pico_rw_256=MaxxVitCfg( + embed_dim=(32, 64, 128, 256), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(24, 32), + **_rw_max_cfg(rel_pos_type='mlp'), + ), maxvit_rmlp_nano_rw_256=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(1, 2, 3, 1), @@ -1458,7 +1471,7 @@ def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]): if cfg.window_size is not None: assert cfg.grid_size return cfg - partition_size = img_size[0] // cfg.partition_stride, img_size[1] // cfg.partition_stride + partition_size = img_size[0] // cfg.partition_ratio, img_size[1] // cfg.partition_ratio cfg = replace(cfg, window_size=partition_size, grid_size=partition_size) return cfg @@ -1698,6 +1711,11 @@ def maxvit_tiny_rw_256(pretrained=False, **kwargs): return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs) +@register_model +def maxvit_rmlp_pico_rw_256(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_rmlp_pico_rw_256', pretrained=pretrained, **kwargs) + + @register_model def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs): return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)