From e94075577808dbb8224020547aa4d62f8bc7341a Mon Sep 17 00:00:00 2001
From: naman jain
Date: Fri, 22 Apr 2022 03:16:13 -0500
Subject: [PATCH] added new weight initialization - xavier_uniform

Replace the default trunc-normal/nlhb weight initialization of MlpMixer
with Xavier-uniform initialization for all nn.Linear layers. The
initializer is a @staticmethod so that self.apply(self.init_weights)
passes each submodule as the single argument (a plain instance method
would receive the bound self plus the module and raise TypeError).
Uses the in-place torch.nn.init.xavier_uniform_ (the non-underscore
name is deprecated) and guards against Linear layers created with
bias=False.
---
 timm/models/mlp_mixer.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py
index 98767036..ae5b6f49 100644
--- a/timm/models/mlp_mixer.py
+++ b/timm/models/mlp_mixer.py
@@ -307,11 +307,16 @@ class MlpMixer(nn.Module):
         )
         #self.sigmoid = nn.Sigmoid()
         self.sm = nn.Softmax(dim=1)
-        self.init_weights(nlhb=nlhb)
-
-    def init_weights(self, nlhb=False):
-        head_bias = -math.log(self.num_classes) if nlhb else 0.
-        named_apply(partial(_init_weights, head_bias=head_bias), module=self)  # depth-first
+        # Replaced trunc-normal init_weights(nlhb) with Xavier-uniform init.
+        self.apply(self.init_weights)
+
+    @staticmethod
+    def init_weights(m):
+        # Xavier/Glorot uniform for Linear weights; small positive bias.
+        if isinstance(m, nn.Linear):
+            torch.nn.init.xavier_uniform_(m.weight)
+            if m.bias is not None:
+                m.bias.data.fill_(0.01)
 
     def get_classifier(self):
         return self.head