diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index c2c96e6c..248568fc 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -96,8 +96,8 @@ class MlpMixer(nn.Module): mlp_layer=Mlp, norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, - drop=0., - drop_path=0., + drop_rate=0., + drop_path_rate=0., nlhb=False, ): super().__init__() @@ -108,7 +108,7 @@ class MlpMixer(nn.Module): self.blocks = nn.Sequential(*[ MixerBlock( hidden_dim, self.stem.num_patches, tokens_dim, channels_dim, - mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, drop=drop, drop_path=drop_path) + mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate) for _ in range(num_blocks)]) self.norm = norm_layer(hidden_dim) self.head = nn.Linear(hidden_dim, self.num_classes) # zero init