diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 5a6dce6f..6f53264a 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -422,7 +422,8 @@ def resmlp_24_224(pretrained=False, **kwargs): Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 """ model_args = dict( - patch_size=16, num_blocks=24, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs) + patch_size=16, num_blocks=24, hidden_dim=384, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs) model = _create_mixer('resmlp_24_224', pretrained=pretrained, **model_args) return model @@ -433,7 +434,8 @@ def resmlp_36_224(pretrained=False, **kwargs): Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 """ model_args = dict( - patch_size=16, num_blocks=36, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs) + patch_size=16, num_blocks=36, hidden_dim=384, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs) model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args) return model