|
|
|
@ -96,8 +96,8 @@ class MlpMixer(nn.Module):
|
|
|
|
|
mlp_layer=Mlp,
|
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
|
|
|
|
act_layer=nn.GELU,
|
|
|
|
|
drop=0.,
|
|
|
|
|
drop_path=0.,
|
|
|
|
|
drop_rate=0.,
|
|
|
|
|
drop_path_rate=0.,
|
|
|
|
|
nlhb=False,
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
@ -108,7 +108,7 @@ class MlpMixer(nn.Module):
|
|
|
|
|
self.blocks = nn.Sequential(*[
|
|
|
|
|
MixerBlock(
|
|
|
|
|
hidden_dim, self.stem.num_patches, tokens_dim, channels_dim,
|
|
|
|
|
mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, drop=drop, drop_path=drop_path)
|
|
|
|
|
mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate)
|
|
|
|
|
for _ in range(num_blocks)])
|
|
|
|
|
self.norm = norm_layer(hidden_dim)
|
|
|
|
|
self.head = nn.Linear(hidden_dim, self.num_classes) # zero init
|
|
|
|
|