diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 679d7ed6..361936d4 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -39,184 +39,6 @@ from ._registry import register_model __all__ = ['MetaFormer'] -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 1.0, 'interpolation': 'bicubic', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'classifier': 'head', 'first_conv': 'patch_embed.conv', - **kwargs - } - - -cfgs_v2 = generate_default_cfgs({ - 'poolformerv1_s12.sail_in1k': _cfg( - url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s12.pth.tar', - crop_pct=0.9), - 'poolformerv1_s24.sail_in1k': _cfg( - url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s24.pth.tar', - crop_pct=0.9), - 'poolformerv1_s36.sail_in1k': _cfg( - url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s36.pth.tar', - crop_pct=0.9), - 'poolformerv1_m36.sail_in1k': _cfg( - url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m36.pth.tar', - crop_pct=0.95), - 'poolformerv1_m48.sail_in1k': _cfg( - url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m48.pth.tar', - crop_pct=0.95), - - 'identityformer_s12.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s12.pth'), - 'identityformer_s24.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s24.pth'), - 'identityformer_s36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s36.pth'), - 'identityformer_m36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m36.pth'), - 'identityformer_m48.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m48.pth'), - - - 'randformer_s12.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s12.pth'), - 'randformer_s24.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s24.pth'), - 'randformer_s36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s36.pth'), - 'randformer_m36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m36.pth'), - 'randformer_m48.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m48.pth'), - - 'poolformerv2_s12.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s12.pth'), - 'poolformerv2_s24.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s24.pth'), - 'poolformerv2_s36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s36.pth'), - 'poolformerv2_m36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m36.pth'), - 'poolformerv2_m48.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m48.pth'), - - - - 'convformer_s18.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth'), - 'convformer_s18.sail_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth', - input_size=(3, 384, 384)), - 'convformer_s18.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth'), - 'convformer_s18.sail_in22k_ft_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth', - input_size=(3, 384, 384)), - 'convformer_s18.sail_in22k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth', - num_classes=21841), - - 'convformer_s36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth'), - 'convformer_s36.sail_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth', - input_size=(3, 384, 384)), - 'convformer_s36.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth'), - 'convformer_s36.sail_in22k_ft_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth', - input_size=(3, 384, 384)), - 'convformer_s36.sail_in22k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth', - num_classes=21841), - - 'convformer_m36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth'), - 'convformer_m36.sail_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth', - input_size=(3, 384, 384)), - 'convformer_m36.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth'), - 'convformer_m36.sail_in22k_ft_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth', - input_size=(3, 384, 384)), - 'convformer_m36.sail_in22k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth', - num_classes=21841), - - 'convformer_b36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth'), - 'convformer_b36.sail_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth', - input_size=(3, 384, 384)), - 'convformer_b36.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth'), - 'convformer_b36_384.sail_in22k_ft_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth', - input_size=(3, 384, 384)), - 'convformer_b36.sail_in22k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth', - num_classes=21841), - - - 'caformer_s18.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth'), - 'caformer_s18.sail_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth', - input_size=(3, 384, 384)), - 'caformer_s18.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth'), - 'caformer_s18.sail_in22k_ft_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth', - input_size=(3, 384, 384)), - 'caformer_s18.sail_in22k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth', - num_classes=21841), - - 'caformer_s36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth'), - 'caformer_s36.sail_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth', - input_size=(3, 384, 384)), - 'caformer_s36.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth'), - 'caformer_s36.sail_in22k_ft_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth', - input_size=(3, 384, 384)), - 'caformer_s36.sail_in22k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth', - num_classes=21841), - - 'caformer_m36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth'), - 'caformer_m36.sail_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth', - input_size=(3, 384, 384)), - 'caformer_m36.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth'), - 'caformer_m36.sail_in22k_ft_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth', - input_size=(3, 384, 384)), - 'caformer_m36.sail_in22k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth', - num_classes=21841), - - 'caformer_b36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth'), - 'caformer_b36.sail_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth', - input_size=(3, 384, 384)), - 'caformer_b36.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth'), - 'caformer_b36.sail_in22k_ft_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth', - input_size=(3, 384, 384)), - 'caformer_b36.sail_in22k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth', - num_classes=21841), -}) - class Downsampling(nn.Module): """ Downsampling implemented by a layer of convolution. @@ -238,7 +60,10 @@ class Downsampling(nn.Module): x = self.pre_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) x = self.conv(x) + + x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + print(x[0][0][0][0]) return x ''' class Downsampling(nn.Module): @@ -760,16 +585,14 @@ class MetaFormer(nn.Module): self.stages = nn.Sequential(*stages) self.norm = self.output_norm(self.num_features) - ''' + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) if head_dropout > 0.0: self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout) else: self.head = self.head_fn(self.num_features, self.num_classes) - - ''' - self.reset_classifier(self.num_classes, global_pool) + self.apply(self._init_weights) @@ -816,7 +639,9 @@ class MetaFormer(nn.Module): def forward_features(self, x): x = self.patch_embed(x) - x = self.stages(x) + #x = self.stages(x) + for i, stage in enumerate(self.stages): + x = stage(x) return x @@ -860,6 +685,185 @@ def _create_metaformer(variant, pretrained=False, **kwargs): return model + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 1.0, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'classifier': 'head', 'first_conv': 'patch_embed.conv', + **kwargs + } + +default_cfgs = generate_default_cfgs({ + 'poolformerv1_s12.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s12.pth.tar', + crop_pct=0.9), + 'poolformerv1_s24.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s24.pth.tar', + crop_pct=0.9), + 'poolformerv1_s36.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s36.pth.tar', + crop_pct=0.9), + 'poolformerv1_m36.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m36.pth.tar', + crop_pct=0.95), + 'poolformerv1_m48.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m48.pth.tar', + crop_pct=0.95), + + 'identityformer_s12.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s12.pth'), + 'identityformer_s24.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s24.pth'), + 'identityformer_s36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s36.pth'), + 'identityformer_m36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m36.pth'), + 'identityformer_m48.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m48.pth'), + + + 'randformer_s12.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s12.pth'), + 'randformer_s24.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s24.pth'), + 'randformer_s36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s36.pth'), + 'randformer_m36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m36.pth'), + 'randformer_m48.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m48.pth'), + + 'poolformerv2_s12.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s12.pth'), + 'poolformerv2_s24.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s24.pth'), + 'poolformerv2_s36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s36.pth'), + 'poolformerv2_m36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m36.pth'), + 'poolformerv2_m48.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m48.pth'), + + + + 'convformer_s18.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth'), + 'convformer_s18.sail_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth', + input_size=(3, 384, 384)), + 'convformer_s18.sail_in22k_ft_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth'), + 'convformer_s18.sail_in22k_ft_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'convformer_s18.sail_in22k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth', + num_classes=21841), + + 'convformer_s36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth'), + 'convformer_s36.sail_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth', + input_size=(3, 384, 384)), + 'convformer_s36.sail_in22k_ft_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth'), + 'convformer_s36.sail_in22k_ft_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'convformer_s36.sail_in22k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth', + num_classes=21841), + + 'convformer_m36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth'), + 'convformer_m36.sail_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth', + input_size=(3, 384, 384)), + 'convformer_m36.sail_in22k_ft_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth'), + 'convformer_m36.sail_in22k_ft_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'convformer_m36.sail_in22k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth', + num_classes=21841), + + 'convformer_b36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth'), + 'convformer_b36.sail_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth', + input_size=(3, 384, 384)), + 'convformer_b36.sail_in22k_ft_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth'), + 'convformer_b36_384.sail_in22k_ft_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'convformer_b36.sail_in22k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth', + num_classes=21841), + + + 'caformer_s18.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth'), + 'caformer_s18.sail_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth', + input_size=(3, 384, 384)), + 'caformer_s18.sail_in22k_ft_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth'), + 'caformer_s18.sail_in22k_ft_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'caformer_s18.sail_in22k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth', + num_classes=21841), + + 'caformer_s36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth'), + 'caformer_s36.sail_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth', + input_size=(3, 384, 384)), + 'caformer_s36.sail_in22k_ft_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth'), + 'caformer_s36.sail_in22k_ft_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'caformer_s36.sail_in22k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth', + num_classes=21841), + + 'caformer_m36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth'), + 'caformer_m36.sail_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth', + input_size=(3, 384, 384)), + 'caformer_m36.sail_in22k_ft_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth'), + 'caformer_m36.sail_in22k_ft_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'caformer_m36.sail_in22k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth', + num_classes=21841), + + 'caformer_b36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth'), + 'caformer_b36.sail_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth', + input_size=(3, 384, 384)), + 'caformer_b36.sail_in22k_ft_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth'), + 'caformer_b36.sail_in22k_ft_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'caformer_b36.sail_in22k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth', + num_classes=21841), +}) + + @register_model def poolformerv1_s12(pretrained=False, **kwargs): model_kwargs = dict( @@ -952,7 +956,7 @@ def identityformer_s12(pretrained=False, **kwargs): @register_model def identityformer_s24(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[4, 4, 12, 4], dims=[64, 128, 320, 512], token_mixers=nn.Identity, @@ -962,7 +966,7 @@ def identityformer_s24(pretrained=False, **kwargs): @register_model def identityformer_s36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[6, 6, 18, 6], dims=[64, 128, 320, 512], token_mixers=nn.Identity, @@ -972,7 +976,7 @@ def identityformer_s36(pretrained=False, **kwargs): @register_model def identityformer_m36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[6, 6, 18, 6], dims=[96, 192, 384, 768], token_mixers=nn.Identity, @@ -982,7 +986,7 @@ def identityformer_m36(pretrained=False, **kwargs): @register_model def identityformer_m48(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[8, 8, 24, 8], dims=[96, 192, 384, 768], token_mixers=nn.Identity, @@ -992,7 +996,7 @@ def identityformer_m48(pretrained=False, **kwargs): @register_model def randformer_s12(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[2, 2, 6, 2], dims=[64, 128, 320, 512], token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], @@ -1002,7 +1006,7 @@ def randformer_s12(pretrained=False, **kwargs): @register_model def randformer_s24(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[4, 4, 12, 4], dims=[64, 128, 320, 512], token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], @@ -1012,7 +1016,7 @@ def randformer_s24(pretrained=False, **kwargs): @register_model def randformer_s36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[6, 6, 18, 6], dims=[64, 128, 320, 512], token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], @@ -1022,7 +1026,7 @@ def randformer_s36(pretrained=False, **kwargs): @register_model def randformer_m36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[6, 6, 18, 6], dims=[96, 192, 384, 768], token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], @@ -1032,7 +1036,7 @@ def randformer_m36(pretrained=False, **kwargs): @register_model def randformer_m48(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[8, 8, 24, 8], dims=[96, 192, 384, 768], token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], @@ -1042,7 +1046,7 @@ def randformer_m48(pretrained=False, **kwargs): @register_model def poolformerv2_s12(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[2, 2, 6, 2], dims=[64, 128, 320, 512], token_mixers=Pooling, @@ -1052,7 +1056,7 @@ def poolformerv2_s12(pretrained=False, **kwargs): @register_model def poolformerv2_s24(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[4, 4, 12, 4], dims=[64, 128, 320, 512], token_mixers=Pooling, @@ -1064,7 +1068,7 @@ def poolformerv2_s24(pretrained=False, **kwargs): @register_model def poolformerv2_s36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[6, 6, 18, 6], dims=[64, 128, 320, 512], token_mixers=Pooling, @@ -1076,7 +1080,7 @@ def poolformerv2_s36(pretrained=False, **kwargs): @register_model def poolformerv2_m36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[6, 6, 18, 6], dims=[96, 192, 384, 768], token_mixers=Pooling, @@ -1087,7 +1091,7 @@ def poolformerv2_m36(pretrained=False, **kwargs): @register_model def poolformerv2_m48(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[8, 8, 24, 8], dims=[96, 192, 384, 768], token_mixers=Pooling, @@ -1099,7 +1103,7 @@ def poolformerv2_m48(pretrained=False, **kwargs): @register_model def convformer_s18(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[3, 3, 9, 3], dims=[64, 128, 320, 512], token_mixers=SepConv, @@ -1112,7 +1116,7 @@ def convformer_s18(pretrained=False, **kwargs): @register_model def convformer_s36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[3, 12, 18, 3], dims=[64, 128, 320, 512], token_mixers=SepConv, @@ -1123,7 +1127,7 @@ def convformer_s36(pretrained=False, **kwargs): @register_model def convformer_m36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[3, 12, 18, 3], dims=[96, 192, 384, 576], token_mixers=SepConv, @@ -1135,7 +1139,7 @@ def convformer_m36(pretrained=False, **kwargs): @register_model def convformer_b36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[3, 12, 18, 3], dims=[128, 256, 512, 768], token_mixers=SepConv, @@ -1148,7 +1152,7 @@ def convformer_b36(pretrained=False, **kwargs): @register_model def caformer_s18(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[3, 3, 9, 3], dims=[64, 128, 320, 512], token_mixers=[SepConv, SepConv, Attention, Attention], @@ -1160,7 +1164,7 @@ def caformer_s18(pretrained=False, **kwargs): @register_model def caformer_s36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[3, 12, 18, 3], dims=[64, 128, 320, 512], token_mixers=[SepConv, SepConv, Attention, Attention], @@ -1171,7 +1175,7 @@ def caformer_s36(pretrained=False, **kwargs): @register_model def caformer_m36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[3, 12, 18, 3], dims=[96, 192, 384, 576], token_mixers=[SepConv, SepConv, Attention, Attention], @@ -1182,7 +1186,7 @@ def caformer_m36(pretrained=False, **kwargs): @register_model def caformer_b36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[3, 12, 18, 3], dims=[128, 256, 512, 768], token_mixers=[SepConv, SepConv, Attention, Attention],