From 511a8e8c96dcbac7014aec8355f38a658ef40e49 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 14 Jun 2021 17:01:12 -0700
Subject: [PATCH] Add official ResMLP weights.

---
 timm/models/mlp_mixer.py | 146 +++++++++++++++++++++++++++++++++++----
 1 file changed, 134 insertions(+), 12 deletions(-)

diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py
index 637e00ea..db3a1be5 100644
--- a/timm/models/mlp_mixer.py
+++ b/timm/models/mlp_mixer.py
@@ -14,8 +14,9 @@ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2
     year={2021}
 }
 
-Also supporting preliminary (not verified) implementations of ResMlp, gMLP, and possibly more...
+Also supporting ResMlp, and a preliminary (not verified) implementation of gMLP
 
+Code: https://github.com/facebookresearch/deit
 Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
 @misc{touvron2021resmlp,
     title={ResMLP: Feedforward networks for image classification with data-efficient training},
@@ -94,11 +95,36 @@ default_cfgs = dict(
     gmixer_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     gmixer_24_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
 
-    resmlp_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    resmlp_12_224=_cfg(
+        url='https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     resmlp_24_224=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=0.89),
-    resmlp_36_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+        url='https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth',
+        #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    resmlp_36_224=_cfg(
+        url='https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    resmlp_big_24_224=_cfg(
+        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+
+    resmlp_12_distilled_224=_cfg(
+        url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    resmlp_24_distilled_224=_cfg(
+        url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    resmlp_36_distilled_224=_cfg(
+        url='https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    resmlp_big_24_distilled_224=_cfg(
+        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+
+    resmlp_big_24_224_in22ft1k=_cfg(
+        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
 
     gmlp_ti16_224=_cfg(),
     gmlp_s16_224=_cfg(),
@@ -266,7 +292,7 @@ class MlpMixer(nn.Module):
         return x
 
 
-def _init_weights(module: nn.Module, name: str, head_bias: float = 0.):
+def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False):
     """ Mixer weight initialization (trying to match Flax defaults)
     """
     if isinstance(module, nn.Linear):
@@ -274,12 +300,19 @@ def _init_weights(module: nn.Module, name: str, head_bias: float = 0.):
             nn.init.zeros_(module.weight)
             nn.init.constant_(module.bias, head_bias)
         else:
-            nn.init.xavier_uniform_(module.weight)
-            if module.bias is not None:
-                if 'mlp' in name:
-                    nn.init.normal_(module.bias, std=1e-6)
-                else:
+            if flax:
+                # Flax defaults
+                lecun_normal_(module.weight)
+                if module.bias is not None:
                     nn.init.zeros_(module.bias)
+            else:
+                # like MLP init in vit (my original init)
+                nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    if 'mlp' in name:
+                        nn.init.normal_(module.bias, std=1e-6)
+                    else:
+                        nn.init.zeros_(module.bias)
     elif isinstance(module, nn.Conv2d):
         lecun_normal_(module.weight)
         if module.bias is not None:
@@ -293,6 +326,23 @@ def _init_weights(module: nn.Module, name: str, head_bias: float = 0.):
         module.init_weights()
 
 
+def checkpoint_filter_fn(state_dict, model):
+    """ Remap checkpoints if needed """
+    if 'patch_embed.proj.weight' in state_dict:
+        # Remap FB ResMlp models -> timm
+        out_dict = {}
+        for k, v in state_dict.items():
+            k = k.replace('patch_embed.', 'stem.')
+            k = k.replace('attn.', 'linear_tokens.')
+            k = k.replace('mlp.', 'mlp_channels.')
+            k = k.replace('gamma_', 'ls')
+            if k.endswith('.alpha') or k.endswith('.beta'):
+                v = v.reshape(1, 1, -1)
+            out_dict[k] = v
+        return out_dict
+    return state_dict
+
+
 def _create_mixer(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for MLP-Mixer models.')
@@ -300,6 +350,7 @@ def _create_mixer(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         MlpMixer, variant, pretrained,
         default_cfg=default_cfgs[variant],
+        pretrained_filter_fn=checkpoint_filter_fn,
         **kwargs)
     return model
 
@@ -458,11 +509,82 @@ def resmlp_36_224(pretrained=False, **kwargs):
     """
     model_args = dict(
         patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4,
-        block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
+        block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
     model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args)
     return model
 
 
+@register_model
+def resmlp_big_24_224(pretrained=False, **kwargs):
+    """ ResMLP-B-24
+    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+    """
+    model_args = dict(
+        patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
+        block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
+    model = _create_mixer('resmlp_big_24_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def resmlp_12_distilled_224(pretrained=False, **kwargs):
+    """ ResMLP-12, distilled
+    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+    """
+    model_args = dict(
+        patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
+    model = _create_mixer('resmlp_12_distilled_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def resmlp_24_distilled_224(pretrained=False, **kwargs):
+    """ ResMLP-24, distilled
+    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+    """
+    model_args = dict(
+        patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
+        block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
+    model = _create_mixer('resmlp_24_distilled_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def resmlp_36_distilled_224(pretrained=False, **kwargs):
+    """ ResMLP-36, distilled
+    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+    """
+    model_args = dict(
+        patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4,
+        block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
+    model = _create_mixer('resmlp_36_distilled_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def resmlp_big_24_distilled_224(pretrained=False, **kwargs):
+    """ ResMLP-B-24, distilled
+    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+    """
+    model_args = dict(
+        patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
+        block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
+    model = _create_mixer('resmlp_big_24_distilled_224', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def resmlp_big_24_224_in22ft1k(pretrained=False, **kwargs):
+    """ ResMLP-B-24, ImageNet-22k pretrained, fine-tuned on ImageNet-1k
    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+    """
+    model_args = dict(
+        patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
+        block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
+    model = _create_mixer('resmlp_big_24_224_in22ft1k', pretrained=pretrained, **model_args)
+    return model
+
+
 @register_model
 def gmlp_ti16_224(pretrained=False, **kwargs):
     """ gMLP-Tiny
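
Reviewer aside, not part of the patch: a minimal sketch of the key remapping that the new
checkpoint_filter_fn performs on a Facebook/DeiT-style ResMLP state dict. The input key names
and tensor shapes below are illustrative stand-ins inferred from the replace rules in the
function, and the import assumes the patched timm.models.mlp_mixer is on the path:

    import torch
    from timm.models.mlp_mixer import checkpoint_filter_fn

    # Illustrative FB/DeiT-style ResMLP keys (shapes are placeholders)
    fb_sd = {
        'patch_embed.proj.weight': torch.zeros(384, 3, 16, 16),
        'blocks.0.attn.weight': torch.zeros(196, 196),
        'blocks.0.mlp.fc1.weight': torch.zeros(1536, 384),
        'blocks.0.norm1.alpha': torch.zeros(384),  # Affine scale, gets reshaped
    }
    remapped = checkpoint_filter_fn(fb_sd, model=None)  # model arg is unused by the remap
    print(sorted(remapped))
    # ['blocks.0.linear_tokens.weight', 'blocks.0.mlp_channels.fc1.weight',
    #  'blocks.0.norm1.alpha', 'stem.proj.weight']
    print(remapped['blocks.0.norm1.alpha'].shape)  # torch.Size([1, 1, 384])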
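
And a quick smoke test for one of the newly registered variants, again not part of the patch;
it assumes the patch is applied and that the dl.fbaipublicfiles.com weight download is reachable:

    import timm
    import torch

    model = timm.create_model('resmlp_24_distilled_224', pretrained=True)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000]) with the default 1000-class head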