From bed350f5e584241a753d22b94ab36146ad824c2e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 20 Jan 2023 14:45:25 -0800 Subject: [PATCH] Push all MaxxViT weights to HF hub, cleanup impl, add feature map extraction support and prompote to 'std' architecture. Fix norm head for proper embedding / feat map output. Add new in12k + ft 1k weights. --- tests/test_models.py | 5 +- timm/models/maxxvit.py | 453 +++++++++++++++++++++++++---------------- 2 files changed, 286 insertions(+), 172 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 3e91d9a8..4ad18477 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -27,8 +27,9 @@ NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*', 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*', - 'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*', 'flexivit*' + 'eva_*', 'flexivit*' ] +#'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', ' NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures @@ -53,7 +54,7 @@ MAX_JIT_SIZE = 320 TARGET_FFEAT_SIZE = 96 MAX_FFEAT_SIZE = 256 TARGET_FWD_FX_SIZE = 128 -MAX_FWD_FX_SIZE = 224 +MAX_FWD_FX_SIZE = 256 TARGET_BWD_FX_SIZE = 128 MAX_BWD_FX_SIZE = 224 diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index dd424078..e730fa30 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -12,9 +12,6 @@ These configs work well and appear to be a bit faster / lower resource than the The models without extra prefix / suffix' (coatnet_0_224, maxvit_tiny_224, etc), are intended to match paper, BUT, without any official pretrained weights it's difficult to confirm a 100% match. -# FIXME / WARNING -This impl remains a WIP, some configs and models may vanish or change... - Papers: MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697 @@ -76,6 +73,8 @@ class MaxxVitTransformerCfg: partition_ratio: int = 32 window_size: Optional[Tuple[int, int]] = None grid_size: Optional[Tuple[int, int]] = None + no_block_attn: bool = False # disable window block attention for maxvit (ie only grid) + use_nchw_attn: bool = False # for MaxViT variants (not used for CoAt), keep tensors in NCHW order init_values: Optional[float] = None act_layer: str = 'gelu' norm_layer: str = 'layernorm2d' @@ -889,19 +888,17 @@ class MaxxVitBlock(nn.Module): stride: int = 1, conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), - use_nchw_attn: bool = False, # FIXME move to cfg? True is ~20-30% faster on TPU, 5-10% slower on GPU - use_block_attn: bool = True, # FIXME for testing ConvNeXt conv w/o block attention drop_path: float = 0., ): super().__init__() + self.nchw_attn = transformer_cfg.use_nchw_attn conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) - partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl - self.nchw_attn = use_nchw_attn - self.attn_block = partition_layer(**attn_kwargs) if use_block_attn else None + partition_layer = PartitionAttention2d if self.nchw_attn else PartitionAttentionCl + self.attn_block = None if transformer_cfg.no_block_attn else partition_layer(**attn_kwargs) self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs) def init_weights(self, scheme=''): @@ -1084,26 +1081,48 @@ class NormMlpHead(nn.Module): hidden_size=None, pool_type='avg', drop_rate=0., - norm_layer=nn.LayerNorm, - act_layer=nn.Tanh, + norm_layer='layernorm2d', + act_layer='tanh', ): super().__init__() self.drop_rate = drop_rate + self.in_features = in_features + self.hidden_size = hidden_size self.num_features = in_features + self.use_conv = not pool_type + norm_layer = get_norm_layer(norm_layer) + act_layer = get_act_layer(act_layer) + linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) self.norm = norm_layer(in_features) self.flatten = nn.Flatten(1) if pool_type else nn.Identity() if hidden_size: self.pre_logits = nn.Sequential(OrderedDict([ - ('fc', nn.Linear(in_features, hidden_size)), + ('fc', linear_layer(in_features, hidden_size)), ('act', act_layer()), ])) self.num_features = hidden_size else: self.pre_logits = nn.Identity() self.drop = nn.Dropout(self.drop_rate) - self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def reset(self, num_classes, global_pool=None): + if global_pool is not None: + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() + self.use_conv = self.global_pool.is_identity() + linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear + if self.hidden_size: + if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or + (isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)): + with torch.no_grad(): + new_fc = linear_layer(self.in_features, self.hidden_size) + new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape)) + new_fc.bias.copy_(self.pre_logits.fc.bias) + self.pre_logits.fc = new_fc + self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward(self, x, pre_logits: bool = False): x = self.global_pool(x) @@ -1163,6 +1182,7 @@ class MaxxVit(nn.Module): self.num_features = self.embed_dim = cfg.embed_dim[-1] self.drop_rate = drop_rate self.grad_checkpointing = False + self.feature_info = [] self.stem = Stem( in_chs=in_chans, @@ -1173,8 +1193,8 @@ class MaxxVit(nn.Module): norm_layer=cfg.conv_cfg.norm_layer, norm_eps=cfg.conv_cfg.norm_eps, ) - stride = self.stem.stride + self.feature_info += [dict(num_chs=self.stem.out_chs, reduction=2, module='stem')] feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))]) num_stages = len(cfg.embed_dim) @@ -1198,15 +1218,17 @@ class MaxxVit(nn.Module): )] stride *= stage_stride in_chs = out_chs + self.feature_info += [dict(num_chs=out_chs, reduction=stride, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps) - if cfg.head_hidden_size: + self.head_hidden_size = cfg.head_hidden_size + if self.head_hidden_size: self.norm = nn.Identity() self.head = NormMlpHead( self.num_features, num_classes, - hidden_size=cfg.head_hidden_size, + hidden_size=self.head_hidden_size, pool_type=global_pool, drop_rate=drop_rate, norm_layer=final_norm_layer, @@ -1253,9 +1275,7 @@ class MaxxVit(nn.Module): def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes - if global_pool is None: - global_pool = self.head.global_pool.pool_type - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + self.head.reset(num_classes, global_pool) def forward_features(self, x): x = self.stem(x) @@ -1376,6 +1396,7 @@ def _next_cfg( transformer_norm_layer='layernorm2d', transformer_norm_layer_cl='layernorm', window_size=None, + no_block_attn=False, init_values=1e-6, rel_pos_type='mlp', # MLP by default for maxxvit rel_pos_dim=512, @@ -1396,6 +1417,7 @@ def _next_cfg( expand_first=False, pool_type=pool_type, window_size=window_size, + no_block_attn=no_block_attn, # enabled for MaxxViT-V2 init_values=init_values[1], norm_layer=transformer_norm_layer, norm_layer_cl=transformer_norm_layer_cl, @@ -1422,8 +1444,8 @@ def _tf_cfg(): model_cfgs = dict( - # Fiddling with configs / defaults / still pretraining - coatnet_pico_rw_224=MaxxVitCfg( + # timm specific CoAtNet configs + coatnet_pico_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(2, 3, 5, 2), stem_width=(32, 64), @@ -1432,7 +1454,7 @@ model_cfgs = dict( conv_attn_ratio=0.25, ), ), - coatnet_nano_rw_224=MaxxVitCfg( + coatnet_nano_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(3, 4, 6, 3), stem_width=(32, 64), @@ -1442,7 +1464,7 @@ model_cfgs = dict( conv_attn_ratio=0.25, ), ), - coatnet_0_rw_224=MaxxVitCfg( + coatnet_0_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 3, 7, 2), # deeper than paper '0' model stem_width=(32, 64), @@ -1451,7 +1473,7 @@ model_cfgs = dict( transformer_shortcut_bias=False, ), ), - coatnet_1_rw_224=MaxxVitCfg( + coatnet_1_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 6, 14, 2), stem_width=(32, 64), @@ -1461,7 +1483,7 @@ model_cfgs = dict( transformer_shortcut_bias=False, ) ), - coatnet_2_rw_224=MaxxVitCfg( + coatnet_2_rw=MaxxVitCfg( embed_dim=(128, 256, 512, 1024), depths=(2, 6, 14, 2), stem_width=(64, 128), @@ -1471,7 +1493,7 @@ model_cfgs = dict( #init_values=1e-6, ), ), - coatnet_3_rw_224=MaxxVitCfg( + coatnet_3_rw=MaxxVitCfg( embed_dim=(192, 384, 768, 1536), depths=(2, 6, 14, 2), stem_width=(96, 192), @@ -1482,8 +1504,8 @@ model_cfgs = dict( ), ), - # Highly experimental configs - coatnet_bn_0_rw_224=MaxxVitCfg( + # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos) + coatnet_bn_0_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 3, 7, 2), # deeper than paper '0' model stem_width=(32, 64), @@ -1494,7 +1516,7 @@ model_cfgs = dict( transformer_norm_layer='batchnorm2d', ) ), - coatnet_rmlp_nano_rw_224=MaxxVitCfg( + coatnet_rmlp_nano_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(3, 4, 6, 3), stem_width=(32, 64), @@ -1505,7 +1527,7 @@ model_cfgs = dict( rel_pos_dim=384, ), ), - coatnet_rmlp_0_rw_224=MaxxVitCfg( + coatnet_rmlp_0_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 3, 7, 2), # deeper than paper '0' model stem_width=(32, 64), @@ -1514,7 +1536,7 @@ model_cfgs = dict( rel_pos_type='mlp', ), ), - coatnet_rmlp_1_rw_224=MaxxVitCfg( + coatnet_rmlp_1_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 6, 14, 2), stem_width=(32, 64), @@ -1526,7 +1548,7 @@ model_cfgs = dict( rel_pos_dim=384, # was supposed to be 512, woops ), ), - coatnet_rmlp_1_rw2_224=MaxxVitCfg( + coatnet_rmlp_1_rw2=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 6, 14, 2), stem_width=(32, 64), @@ -1536,7 +1558,7 @@ model_cfgs = dict( rel_pos_dim=512, # was supposed to be 512, woops ), ), - coatnet_rmlp_2_rw_224=MaxxVitCfg( + coatnet_rmlp_2_rw=MaxxVitCfg( embed_dim=(128, 256, 512, 1024), depths=(2, 6, 14, 2), stem_width=(64, 128), @@ -1547,7 +1569,7 @@ model_cfgs = dict( rel_pos_type='mlp' ), ), - coatnet_rmlp_3_rw_224=MaxxVitCfg( + coatnet_rmlp_3_rw=MaxxVitCfg( embed_dim=(192, 384, 768, 1536), depths=(2, 6, 14, 2), stem_width=(96, 192), @@ -1559,14 +1581,14 @@ model_cfgs = dict( ), ), - coatnet_nano_cc_224=MaxxVitCfg( + coatnet_nano_cc=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(3, 4, 6, 3), stem_width=(32, 64), block_type=('C', 'C', ('C', 'T'), ('C', 'T')), **_rw_coat_cfg(), ), - coatnext_nano_rw_224=MaxxVitCfg( + coatnext_nano_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(3, 4, 6, 3), stem_width=(32, 64), @@ -1578,89 +1600,95 @@ model_cfgs = dict( ), # Trying to be like the CoAtNet paper configs - coatnet_0_224=MaxxVitCfg( + coatnet_0=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 3, 5, 2), stem_width=64, + head_hidden_size=768, ), - coatnet_1_224=MaxxVitCfg( + coatnet_1=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 6, 14, 2), stem_width=64, + head_hidden_size=768, ), - coatnet_2_224=MaxxVitCfg( + coatnet_2=MaxxVitCfg( embed_dim=(128, 256, 512, 1024), depths=(2, 6, 14, 2), stem_width=128, + head_hidden_size=1024, ), - coatnet_3_224=MaxxVitCfg( + coatnet_3=MaxxVitCfg( embed_dim=(192, 384, 768, 1536), depths=(2, 6, 14, 2), stem_width=192, + head_hidden_size=1536, ), - coatnet_4_224=MaxxVitCfg( + coatnet_4=MaxxVitCfg( embed_dim=(192, 384, 768, 1536), depths=(2, 12, 28, 2), stem_width=192, + head_hidden_size=1536, ), - coatnet_5_224=MaxxVitCfg( + coatnet_5=MaxxVitCfg( embed_dim=(256, 512, 1280, 2048), depths=(2, 12, 28, 2), stem_width=192, + head_hidden_size=2048, ), # Experimental MaxVit configs - maxvit_pico_rw_256=MaxxVitCfg( + maxvit_pico_rw=MaxxVitCfg( embed_dim=(32, 64, 128, 256), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(24, 32), **_rw_max_cfg(), ), - maxvit_nano_rw_256=MaxxVitCfg( + maxvit_nano_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(1, 2, 3, 1), block_type=('M',) * 4, stem_width=(32, 64), **_rw_max_cfg(), ), - maxvit_tiny_rw_224=MaxxVitCfg( + maxvit_tiny_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(32, 64), **_rw_max_cfg(), ), - maxvit_tiny_rw_256=MaxxVitCfg( + maxvit_tiny_pm=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(2, 2, 5, 2), - block_type=('M',) * 4, + block_type=('PM',) * 4, stem_width=(32, 64), **_rw_max_cfg(), ), - maxvit_rmlp_pico_rw_256=MaxxVitCfg( + maxvit_rmlp_pico_rw=MaxxVitCfg( embed_dim=(32, 64, 128, 256), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(24, 32), **_rw_max_cfg(rel_pos_type='mlp'), ), - maxvit_rmlp_nano_rw_256=MaxxVitCfg( + maxvit_rmlp_nano_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(1, 2, 3, 1), block_type=('M',) * 4, stem_width=(32, 64), **_rw_max_cfg(rel_pos_type='mlp'), ), - maxvit_rmlp_tiny_rw_256=MaxxVitCfg( + maxvit_rmlp_tiny_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(32, 64), **_rw_max_cfg(rel_pos_type='mlp'), ), - maxvit_rmlp_small_rw_224=MaxxVitCfg( + maxvit_rmlp_small_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 2, 5, 2), block_type=('M',) * 4, @@ -1670,27 +1698,7 @@ model_cfgs = dict( init_values=1e-6, ), ), - maxvit_rmlp_small_rw_256=MaxxVitCfg( - embed_dim=(96, 192, 384, 768), - depths=(2, 2, 5, 2), - block_type=('M',) * 4, - stem_width=(32, 64), - **_rw_max_cfg( - rel_pos_type='mlp', - init_values=1e-6, - ), - ), - maxvit_rmlp_base_rw_224=MaxxVitCfg( - embed_dim=(96, 192, 384, 768), - depths=(2, 6, 14, 2), - block_type=('M',) * 4, - stem_width=(32, 64), - head_hidden_size=768, - **_rw_max_cfg( - rel_pos_type='mlp', - ), - ), - maxvit_rmlp_base_rw_384=MaxxVitCfg( + maxvit_rmlp_base_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 6, 14, 2), block_type=('M',) * 4, @@ -1701,15 +1709,7 @@ model_cfgs = dict( ), ), - maxvit_tiny_pm_256=MaxxVitCfg( - embed_dim=(64, 128, 256, 512), - depths=(2, 2, 5, 2), - block_type=('PM',) * 4, - stem_width=(32, 64), - **_rw_max_cfg(), - ), - - maxxvit_rmlp_nano_rw_256=MaxxVitCfg( + maxxvit_rmlp_nano_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(1, 2, 3, 1), block_type=('M',) * 4, @@ -1717,33 +1717,50 @@ model_cfgs = dict( weight_init='normal', **_next_cfg(), ), - maxxvit_rmlp_tiny_rw_256=MaxxVitCfg( + maxxvit_rmlp_tiny_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(32, 64), **_next_cfg(), ), - maxxvit_rmlp_small_rw_256=MaxxVitCfg( + maxxvit_rmlp_small_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(48, 96), **_next_cfg(), ), - maxxvit_rmlp_base_rw_224=MaxxVitCfg( + + maxxvitv2_nano_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), - depths=(2, 6, 14, 2), + depths=(1, 2, 3, 1), block_type=('M',) * 4, stem_width=(48, 96), - **_next_cfg(), + weight_init='normal', + **_next_cfg( + no_block_attn=True, + rel_pos_type='bias', + ), ), - maxxvit_rmlp_large_rw_224=MaxxVitCfg( + maxxvitv2_rmlp_base_rw=MaxxVitCfg( embed_dim=(128, 256, 512, 1024), depths=(2, 6, 12, 2), block_type=('M',) * 4, stem_width=(64, 128), - **_next_cfg(), + **_next_cfg( + no_block_attn=True, + ), + ), + maxxvitv2_rmlp_large_rw=MaxxVitCfg( + embed_dim=(160, 320, 640, 1280), + depths=(2, 6, 16, 2), + block_type=('M',) * 4, + stem_width=(80, 160), + head_hidden_size=1280, + **_next_cfg( + no_block_attn=True, + ), ), # Trying to be like the MaxViT paper configs @@ -1795,11 +1812,29 @@ model_cfgs = dict( ) +def checkpoint_filter_fn(state_dict, model: nn.Module): + model_state_dict = model.state_dict() + out_dict = {} + for k, v in state_dict.items(): + if k in model_state_dict and v.ndim != model_state_dict[k].ndim and v.numel() == model_state_dict[k].numel(): + # adapt between conv2d / linear layers + assert v.ndim in (2, 4) + v = v.reshape(model_state_dict[k].shape) + out_dict[k] = v + return out_dict + + def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs): + if cfg_variant is None: + if variant in model_cfgs: + cfg_variant = variant + else: + cfg_variant = '_'.join(variant.split('_')[:-1]) return build_model_with_cfg( MaxxVit, variant, pretrained, - model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], + model_cfg=model_cfgs[cfg_variant], feature_cfg=dict(flatten_sequential=True), + pretrained_filter_fn=checkpoint_filter_fn, **kwargs) @@ -1815,155 +1850,218 @@ def _cfg(url='', **kwargs): default_cfgs = generate_default_cfgs({ - # Fiddling with configs / defaults / still pretraining - 'coatnet_pico_rw_224': _cfg(url=''), - 'coatnet_nano_rw_224': _cfg( + # timm specific CoAtNet configs, ImageNet-1k pretrain, fixed rel-pos + 'coatnet_pico_rw_224.untrained': _cfg(url=''), + 'coatnet_nano_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth', crop_pct=0.9), - 'coatnet_0_rw_224': _cfg( + 'coatnet_0_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'), - 'coatnet_1_rw_224': _cfg( + 'coatnet_1_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth' ), - 'coatnet_2_rw_224': _cfg(url=''), - 'coatnet_3_rw_224': _cfg(url=''), - # Highly experimental configs - 'coatnet_bn_0_rw_224': _cfg( + # timm specific CoAtNet configs, ImageNet-12k pretrain w/ 1k fine-tune, fixed rel-pos + 'coatnet_2_rw_224.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/'), + #'coatnet_3_rw_224.untrained': _cfg(url=''), + + # Experimental CoAtNet configs w/ ImageNet-12k pretrain -> 1k fine-tune (different norm layers, MLP rel-pos) + 'coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + + # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos) + 'coatnet_bn_0_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=0.95), - 'coatnet_rmlp_nano_rw_224': _cfg( + 'coatnet_rmlp_nano_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth', crop_pct=0.9), - 'coatnet_rmlp_0_rw_224': _cfg(url=''), - 'coatnet_rmlp_1_rw_224': _cfg( + 'coatnet_rmlp_0_rw_224.untrained': _cfg(url=''), + 'coatnet_rmlp_1_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'), - 'coatnet_rmlp_1_rw2_224': _cfg(url=''), - 'coatnet_rmlp_2_rw_224': _cfg( + 'coatnet_rmlp_2_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'), - 'coatnet_rmlp_3_rw_224': _cfg(url=''), - 'coatnet_nano_cc_224': _cfg(url=''), - 'coatnext_nano_rw_224': _cfg( + 'coatnet_rmlp_3_rw_224.untrained': _cfg(url=''), + 'coatnet_nano_cc_224.untrained': _cfg(url=''), + 'coatnext_nano_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth', crop_pct=0.9), - # Trying to be like the CoAtNet paper configs - 'coatnet_0_224': _cfg(url=''), - 'coatnet_1_224': _cfg(url=''), - 'coatnet_2_224': _cfg(url=''), - 'coatnet_3_224': _cfg(url=''), - 'coatnet_4_224': _cfg(url=''), - 'coatnet_5_224': _cfg(url=''), - - # Experimental configs - 'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_nano_rw_256': _cfg( + # ImagenNet-12k pretrain CoAtNet + 'coatnet_2_rw_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), + 'coatnet_3_rw_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), + 'coatnet_rmlp_1_rw2_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), + 'coatnet_rmlp_2_rw_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), + + # Trying to be like the CoAtNet paper configs (will adapt if 'tf' weights are ever released) + 'coatnet_0_224.untrained': _cfg(url=''), + 'coatnet_1_224.untrained': _cfg(url=''), + 'coatnet_2_224.untrained': _cfg(url=''), + 'coatnet_3_224.untrained': _cfg(url=''), + 'coatnet_4_224.untrained': _cfg(url=''), + 'coatnet_5_224.untrained': _cfg(url=''), + + # timm specific MaxVit configs, ImageNet-1k pretrain or untrained + 'maxvit_pico_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_nano_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_tiny_rw_224': _cfg( + 'maxvit_tiny_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'), - 'maxvit_tiny_rw_256': _cfg( + 'maxvit_tiny_rw_256.untrained': _cfg( url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_pico_rw_256': _cfg( + 'maxvit_tiny_pm_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + + # timm specific MaxVit w/ MLP rel-pos, ImageNet-1k pretrain + 'maxvit_rmlp_pico_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_nano_rw_256': _cfg( + 'maxvit_rmlp_nano_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_tiny_rw_256': _cfg( + 'maxvit_rmlp_tiny_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_small_rw_224': _cfg( + 'maxvit_rmlp_small_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth', crop_pct=0.9, ), - 'maxvit_rmlp_small_rw_256': _cfg( + 'maxvit_rmlp_small_rw_256.untrained': _cfg( url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_base_rw_224': _cfg( - url='', - ), - 'maxvit_rmlp_base_rw_384': _cfg( - url='', - input_size=(3, 384, 384), pool_size=(12, 12)), - 'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + # timm specific MaxVit w/ ImageNet-12k pretrain and 1k fine-tune + 'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + ), + 'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + + # timm specific MaxVit w/ ImageNet-12k pretrain + 'maxvit_rmlp_base_rw_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821, + ), - 'maxxvit_rmlp_nano_rw_256': _cfg( + # timm MaxxViT configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks) + 'maxxvit_rmlp_nano_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxxvit_rmlp_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxxvit_rmlp_small_rw_256': _cfg( + 'maxxvit_rmlp_tiny_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxxvit_rmlp_small_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxxvit_rmlp_base_rw_224': _cfg(url=''), - 'maxxvit_rmlp_large_rw_224': _cfg(url=''), + # timm MaxxViT-V2 configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks, more width, no block attn) + 'maxxvitv2_nano_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'maxxvitv2_rmlp_large_rw_224.untrained': _cfg(url=''), + + 'maxxvitv2_rmlp_base_rw_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), # MaxViT models ported from official Tensorflow impl 'maxvit_tiny_tf_224.in1k': _cfg( - hf_hub_id='timm/maxvit_tiny_tf_224.in1k', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'maxvit_tiny_tf_384.in1k': _cfg( - hf_hub_id='timm/maxvit_tiny_tf_384.in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_tiny_tf_512.in1k': _cfg( - hf_hub_id='timm/maxvit_tiny_tf_512.in1k', - input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'), 'maxvit_small_tf_224.in1k': _cfg( - hf_hub_id='timm/maxvit_small_tf_224.in1k', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'maxvit_small_tf_384.in1k': _cfg( - hf_hub_id='timm/maxvit_small_tf_384.in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_small_tf_512.in1k': _cfg( - hf_hub_id='timm/maxvit_small_tf_512.in1k', - input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'), 'maxvit_base_tf_224.in1k': _cfg( - hf_hub_id='timm/maxvit_base_tf_224.in1k', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'maxvit_base_tf_384.in1k': _cfg( - hf_hub_id='timm/maxvit_base_tf_384.in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_base_tf_512.in1k': _cfg( - hf_hub_id='timm/maxvit_base_tf_512.in1k', - input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'), 'maxvit_large_tf_224.in1k': _cfg( - hf_hub_id='timm/maxvit_large_tf_224.in1k', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'maxvit_large_tf_384.in1k': _cfg( - hf_hub_id='timm/maxvit_large_tf_384.in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_large_tf_512.in1k': _cfg( - hf_hub_id='timm/maxvit_large_tf_512.in1k', - input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'), 'maxvit_base_tf_224.in21k': _cfg( url=''), 'maxvit_base_tf_384.in21k_ft_in1k': _cfg( - hf_hub_id='timm/maxvit_base_tf_384.in21k_ft_in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_base_tf_512.in21k_ft_in1k': _cfg( - hf_hub_id='timm/maxvit_base_tf_512.in21k_ft_in1k', - input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'), 'maxvit_large_tf_224.in21k': _cfg( url=''), 'maxvit_large_tf_384.in21k_ft_in1k': _cfg( - hf_hub_id='timm/maxvit_large_tf_384.in21k_ft_in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_large_tf_512.in21k_ft_in1k': _cfg( - hf_hub_id='timm/maxvit_large_tf_512.in21k_ft_in1k', + hf_hub_id='timm/', input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), 'maxvit_xlarge_tf_224.in21k': _cfg( url=''), 'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg( - hf_hub_id='timm/maxvit_xlarge_tf_384.in21k_ft_in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg( - hf_hub_id='timm/maxvit_xlarge_tf_512.in21k_ft_in1k', - input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'), }) @@ -2027,6 +2125,11 @@ def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs): return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs) +@register_model +def coatnet_rmlp_2_rw_384(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_rmlp_2_rw_384', pretrained=pretrained, **kwargs) + + @register_model def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs): return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs) @@ -2148,13 +2251,23 @@ def maxxvit_rmlp_small_rw_256(pretrained=False, **kwargs): @register_model -def maxxvit_rmlp_base_rw_224(pretrained=False, **kwargs): - return _create_maxxvit('maxxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs) +def maxxvitv2_nano_rw_256(pretrained=False, **kwargs): + return _create_maxxvit('maxxvitv2_nano_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxxvitv2_rmlp_base_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('maxxvitv2_rmlp_base_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxxvitv2_rmlp_base_rw_384(pretrained=False, **kwargs): + return _create_maxxvit('maxxvitv2_rmlp_base_rw_384', pretrained=pretrained, **kwargs) @register_model -def maxxvit_rmlp_large_rw_224(pretrained=False, **kwargs): - return _create_maxxvit('maxxvit_rmlp_large_rw_224', pretrained=pretrained, **kwargs) +def maxxvitv2_rmlp_large_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('maxxvitv2_rmlp_large_rw_224', pretrained=pretrained, **kwargs) @register_model