diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py
index 98767036..ae5b6f49 100644
--- a/timm/models/mlp_mixer.py
+++ b/timm/models/mlp_mixer.py
@@ -307,11 +307,26 @@ class MlpMixer(nn.Module):
         )
         #self.sigmoid = nn.Sigmoid()
         self.sm = nn.Softmax(dim=1)
-        self.init_weights(nlhb=nlhb)
-
-    def init_weights(self, nlhb=False):
-        head_bias = -math.log(self.num_classes) if nlhb else 0.
-        named_apply(partial(_init_weights, head_bias=head_bias), module=self)  # depth-first
+        # self.init_weights(nlhb=nlhb)
+        # The nlhb head-bias init above is disabled; every nn.Linear gets Xavier init instead.
+        self.apply(self.init_weights)
+
+    # def init_weights(self, nlhb=False):
+    #     head_bias = -math.log(self.num_classes) if nlhb else 0.
+    #     named_apply(partial(_init_weights, head_bias=head_bias), module=self)  # depth-first
+    @staticmethod
+    def init_weights(m):
+        """Initialise a single module in place; meant to be passed to ``nn.Module.apply``.
+
+        Args:
+            m: Module visited by ``apply``. Only ``nn.Linear`` instances are
+                touched; every other module type is left unchanged.
+        """
+        if isinstance(m, nn.Linear):
+            nn.init.xavier_uniform_(m.weight)
+            # Linear layers built with bias=False have no bias tensor to fill.
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0.01)
 
     def get_classifier(self):
         return self.head