Update metaformers.py

pull/1647/head
Fredo Guan 2 years ago
parent eaf54b66af
commit 49bf08ed22

@ -39,184 +39,6 @@ from ._registry import register_model
__all__ = ['MetaFormer']
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 1.0, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'classifier': 'head', 'first_conv': 'patch_embed.conv',
**kwargs
}
cfgs_v2 = generate_default_cfgs({
'poolformerv1_s12.sail_in1k': _cfg(
url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s12.pth.tar',
crop_pct=0.9),
'poolformerv1_s24.sail_in1k': _cfg(
url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s24.pth.tar',
crop_pct=0.9),
'poolformerv1_s36.sail_in1k': _cfg(
url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s36.pth.tar',
crop_pct=0.9),
'poolformerv1_m36.sail_in1k': _cfg(
url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m36.pth.tar',
crop_pct=0.95),
'poolformerv1_m48.sail_in1k': _cfg(
url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m48.pth.tar',
crop_pct=0.95),
'identityformer_s12.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s12.pth'),
'identityformer_s24.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s24.pth'),
'identityformer_s36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s36.pth'),
'identityformer_m36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m36.pth'),
'identityformer_m48.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m48.pth'),
'randformer_s12.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s12.pth'),
'randformer_s24.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s24.pth'),
'randformer_s36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s36.pth'),
'randformer_m36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m36.pth'),
'randformer_m48.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m48.pth'),
'poolformerv2_s12.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s12.pth'),
'poolformerv2_s24.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s24.pth'),
'poolformerv2_s36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s36.pth'),
'poolformerv2_m36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m36.pth'),
'poolformerv2_m48.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m48.pth'),
'convformer_s18.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth'),
'convformer_s18.sail_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth',
input_size=(3, 384, 384)),
'convformer_s18.sail_in22k_ft_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth'),
'convformer_s18.sail_in22k_ft_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth',
input_size=(3, 384, 384)),
'convformer_s18.sail_in22k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth',
num_classes=21841),
'convformer_s36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth'),
'convformer_s36.sail_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth',
input_size=(3, 384, 384)),
'convformer_s36.sail_in22k_ft_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth'),
'convformer_s36.sail_in22k_ft_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth',
input_size=(3, 384, 384)),
'convformer_s36.sail_in22k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth',
num_classes=21841),
'convformer_m36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth'),
'convformer_m36.sail_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth',
input_size=(3, 384, 384)),
'convformer_m36.sail_in22k_ft_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth'),
'convformer_m36.sail_in22k_ft_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth',
input_size=(3, 384, 384)),
'convformer_m36.sail_in22k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth',
num_classes=21841),
'convformer_b36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth'),
'convformer_b36.sail_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth',
input_size=(3, 384, 384)),
'convformer_b36.sail_in22k_ft_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth'),
'convformer_b36_384.sail_in22k_ft_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth',
input_size=(3, 384, 384)),
'convformer_b36.sail_in22k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth',
num_classes=21841),
'caformer_s18.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth'),
'caformer_s18.sail_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth',
input_size=(3, 384, 384)),
'caformer_s18.sail_in22k_ft_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth'),
'caformer_s18.sail_in22k_ft_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth',
input_size=(3, 384, 384)),
'caformer_s18.sail_in22k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth',
num_classes=21841),
'caformer_s36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth'),
'caformer_s36.sail_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth',
input_size=(3, 384, 384)),
'caformer_s36.sail_in22k_ft_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth'),
'caformer_s36.sail_in22k_ft_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth',
input_size=(3, 384, 384)),
'caformer_s36.sail_in22k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth',
num_classes=21841),
'caformer_m36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth'),
'caformer_m36.sail_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth',
input_size=(3, 384, 384)),
'caformer_m36.sail_in22k_ft_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth'),
'caformer_m36.sail_in22k_ft_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth',
input_size=(3, 384, 384)),
'caformer_m36.sail_in22k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth',
num_classes=21841),
'caformer_b36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth'),
'caformer_b36.sail_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth',
input_size=(3, 384, 384)),
'caformer_b36.sail_in22k_ft_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth'),
'caformer_b36.sail_in22k_ft_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth',
input_size=(3, 384, 384)),
'caformer_b36.sail_in22k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth',
num_classes=21841),
})
class Downsampling(nn.Module):
"""
Downsampling implemented by a layer of convolution.
@ -238,7 +60,10 @@ class Downsampling(nn.Module):
x = self.pre_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = self.conv(x)
x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(x[0][0][0][0])
return x
'''
class Downsampling(nn.Module):
@ -760,16 +585,14 @@ class MetaFormer(nn.Module):
self.stages = nn.Sequential(*stages)
self.norm = self.output_norm(self.num_features)
'''
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
if head_dropout > 0.0:
self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout)
else:
self.head = self.head_fn(self.num_features, self.num_classes)
'''
self.reset_classifier(self.num_classes, global_pool)
self.apply(self._init_weights)
@ -816,7 +639,9 @@ class MetaFormer(nn.Module):
def forward_features(self, x):
x = self.patch_embed(x)
x = self.stages(x)
#x = self.stages(x)
for i, stage in enumerate(self.stages):
x = stage(x)
return x
@ -860,6 +685,185 @@ def _create_metaformer(variant, pretrained=False, **kwargs):
return model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 1.0, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'classifier': 'head', 'first_conv': 'patch_embed.conv',
**kwargs
}
default_cfgs = generate_default_cfgs({
'poolformerv1_s12.sail_in1k': _cfg(
url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s12.pth.tar',
crop_pct=0.9),
'poolformerv1_s24.sail_in1k': _cfg(
url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s24.pth.tar',
crop_pct=0.9),
'poolformerv1_s36.sail_in1k': _cfg(
url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s36.pth.tar',
crop_pct=0.9),
'poolformerv1_m36.sail_in1k': _cfg(
url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m36.pth.tar',
crop_pct=0.95),
'poolformerv1_m48.sail_in1k': _cfg(
url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m48.pth.tar',
crop_pct=0.95),
'identityformer_s12.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s12.pth'),
'identityformer_s24.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s24.pth'),
'identityformer_s36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s36.pth'),
'identityformer_m36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m36.pth'),
'identityformer_m48.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m48.pth'),
'randformer_s12.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s12.pth'),
'randformer_s24.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s24.pth'),
'randformer_s36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s36.pth'),
'randformer_m36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m36.pth'),
'randformer_m48.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m48.pth'),
'poolformerv2_s12.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s12.pth'),
'poolformerv2_s24.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s24.pth'),
'poolformerv2_s36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s36.pth'),
'poolformerv2_m36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m36.pth'),
'poolformerv2_m48.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m48.pth'),
'convformer_s18.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth'),
'convformer_s18.sail_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth',
input_size=(3, 384, 384)),
'convformer_s18.sail_in22k_ft_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth'),
'convformer_s18.sail_in22k_ft_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth',
input_size=(3, 384, 384)),
'convformer_s18.sail_in22k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth',
num_classes=21841),
'convformer_s36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth'),
'convformer_s36.sail_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth',
input_size=(3, 384, 384)),
'convformer_s36.sail_in22k_ft_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth'),
'convformer_s36.sail_in22k_ft_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth',
input_size=(3, 384, 384)),
'convformer_s36.sail_in22k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth',
num_classes=21841),
'convformer_m36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth'),
'convformer_m36.sail_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth',
input_size=(3, 384, 384)),
'convformer_m36.sail_in22k_ft_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth'),
'convformer_m36.sail_in22k_ft_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth',
input_size=(3, 384, 384)),
'convformer_m36.sail_in22k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth',
num_classes=21841),
'convformer_b36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth'),
'convformer_b36.sail_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth',
input_size=(3, 384, 384)),
'convformer_b36.sail_in22k_ft_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth'),
'convformer_b36_384.sail_in22k_ft_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth',
input_size=(3, 384, 384)),
'convformer_b36.sail_in22k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth',
num_classes=21841),
'caformer_s18.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth'),
'caformer_s18.sail_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth',
input_size=(3, 384, 384)),
'caformer_s18.sail_in22k_ft_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth'),
'caformer_s18.sail_in22k_ft_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth',
input_size=(3, 384, 384)),
'caformer_s18.sail_in22k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth',
num_classes=21841),
'caformer_s36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth'),
'caformer_s36.sail_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth',
input_size=(3, 384, 384)),
'caformer_s36.sail_in22k_ft_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth'),
'caformer_s36.sail_in22k_ft_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth',
input_size=(3, 384, 384)),
'caformer_s36.sail_in22k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth',
num_classes=21841),
'caformer_m36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth'),
'caformer_m36.sail_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth',
input_size=(3, 384, 384)),
'caformer_m36.sail_in22k_ft_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth'),
'caformer_m36.sail_in22k_ft_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth',
input_size=(3, 384, 384)),
'caformer_m36.sail_in22k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth',
num_classes=21841),
'caformer_b36.sail_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth'),
'caformer_b36.sail_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth',
input_size=(3, 384, 384)),
'caformer_b36.sail_in22k_ft_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth'),
'caformer_b36.sail_in22k_ft_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth',
input_size=(3, 384, 384)),
'caformer_b36.sail_in22k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth',
num_classes=21841),
})
@register_model
def poolformerv1_s12(pretrained=False, **kwargs):
model_kwargs = dict(
@ -952,7 +956,7 @@ def identityformer_s12(pretrained=False, **kwargs):
@register_model
def identityformer_s24(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[4, 4, 12, 4],
dims=[64, 128, 320, 512],
token_mixers=nn.Identity,
@ -962,7 +966,7 @@ def identityformer_s24(pretrained=False, **kwargs):
@register_model
def identityformer_s36(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[6, 6, 18, 6],
dims=[64, 128, 320, 512],
token_mixers=nn.Identity,
@ -972,7 +976,7 @@ def identityformer_s36(pretrained=False, **kwargs):
@register_model
def identityformer_m36(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[6, 6, 18, 6],
dims=[96, 192, 384, 768],
token_mixers=nn.Identity,
@ -982,7 +986,7 @@ def identityformer_m36(pretrained=False, **kwargs):
@register_model
def identityformer_m48(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[8, 8, 24, 8],
dims=[96, 192, 384, 768],
token_mixers=nn.Identity,
@ -992,7 +996,7 @@ def identityformer_m48(pretrained=False, **kwargs):
@register_model
def randformer_s12(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[2, 2, 6, 2],
dims=[64, 128, 320, 512],
token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
@ -1002,7 +1006,7 @@ def randformer_s12(pretrained=False, **kwargs):
@register_model
def randformer_s24(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[4, 4, 12, 4],
dims=[64, 128, 320, 512],
token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
@ -1012,7 +1016,7 @@ def randformer_s24(pretrained=False, **kwargs):
@register_model
def randformer_s36(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[6, 6, 18, 6],
dims=[64, 128, 320, 512],
token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
@ -1022,7 +1026,7 @@ def randformer_s36(pretrained=False, **kwargs):
@register_model
def randformer_m36(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[6, 6, 18, 6],
dims=[96, 192, 384, 768],
token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
@ -1032,7 +1036,7 @@ def randformer_m36(pretrained=False, **kwargs):
@register_model
def randformer_m48(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[8, 8, 24, 8],
dims=[96, 192, 384, 768],
token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
@ -1042,7 +1046,7 @@ def randformer_m48(pretrained=False, **kwargs):
@register_model
def poolformerv2_s12(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[2, 2, 6, 2],
dims=[64, 128, 320, 512],
token_mixers=Pooling,
@ -1052,7 +1056,7 @@ def poolformerv2_s12(pretrained=False, **kwargs):
@register_model
def poolformerv2_s24(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[4, 4, 12, 4],
dims=[64, 128, 320, 512],
token_mixers=Pooling,
@ -1064,7 +1068,7 @@ def poolformerv2_s24(pretrained=False, **kwargs):
@register_model
def poolformerv2_s36(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[6, 6, 18, 6],
dims=[64, 128, 320, 512],
token_mixers=Pooling,
@ -1076,7 +1080,7 @@ def poolformerv2_s36(pretrained=False, **kwargs):
@register_model
def poolformerv2_m36(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[6, 6, 18, 6],
dims=[96, 192, 384, 768],
token_mixers=Pooling,
@ -1087,7 +1091,7 @@ def poolformerv2_m36(pretrained=False, **kwargs):
@register_model
def poolformerv2_m48(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[8, 8, 24, 8],
dims=[96, 192, 384, 768],
token_mixers=Pooling,
@ -1099,7 +1103,7 @@ def poolformerv2_m48(pretrained=False, **kwargs):
@register_model
def convformer_s18(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[3, 3, 9, 3],
dims=[64, 128, 320, 512],
token_mixers=SepConv,
@ -1112,7 +1116,7 @@ def convformer_s18(pretrained=False, **kwargs):
@register_model
def convformer_s36(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[3, 12, 18, 3],
dims=[64, 128, 320, 512],
token_mixers=SepConv,
@ -1123,7 +1127,7 @@ def convformer_s36(pretrained=False, **kwargs):
@register_model
def convformer_m36(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[3, 12, 18, 3],
dims=[96, 192, 384, 576],
token_mixers=SepConv,
@ -1135,7 +1139,7 @@ def convformer_m36(pretrained=False, **kwargs):
@register_model
def convformer_b36(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[3, 12, 18, 3],
dims=[128, 256, 512, 768],
token_mixers=SepConv,
@ -1148,7 +1152,7 @@ def convformer_b36(pretrained=False, **kwargs):
@register_model
def caformer_s18(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[3, 3, 9, 3],
dims=[64, 128, 320, 512],
token_mixers=[SepConv, SepConv, Attention, Attention],
@ -1160,7 +1164,7 @@ def caformer_s18(pretrained=False, **kwargs):
@register_model
def caformer_s36(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[3, 12, 18, 3],
dims=[64, 128, 320, 512],
token_mixers=[SepConv, SepConv, Attention, Attention],
@ -1171,7 +1175,7 @@ def caformer_s36(pretrained=False, **kwargs):
@register_model
def caformer_m36(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[3, 12, 18, 3],
dims=[96, 192, 384, 576],
token_mixers=[SepConv, SepConv, Attention, Attention],
@ -1182,7 +1186,7 @@ def caformer_m36(pretrained=False, **kwargs):
@register_model
def caformer_b36(pretrained=False, **kwargs):
model = MetaFormer(
model_kwargs = dict(
depths=[3, 12, 18, 3],
dims=[128, 256, 512, 768],
token_mixers=[SepConv, SepConv, Attention, Attention],

Loading…
Cancel
Save