added new wt initializtion- xavier_uniform

pull/1229/head
naman jain 3 years ago
parent 450b0d57f0
commit e940755778

@ -307,11 +307,16 @@ class MlpMixer(nn.Module):
) )
#self.sigmoid = nn.Sigmoid() #self.sigmoid = nn.Sigmoid()
self.sm = nn.Softmax(dim=1) self.sm = nn.Softmax(dim=1)
self.init_weights(nlhb=nlhb) # self.init_weights(nlhb=nlhb)
self.apply(self.init_weights)
def init_weights(self, nlhb=False):
head_bias = -math.log(self.num_classes) if nlhb else 0. # def init_weights(self, nlhb=False):
named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first # head_bias = -math.log(self.num_classes) if nlhb else 0.
# named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first
def init_weights(m):
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform(m.weight)
m.bias.data.fill_(0.01)
def get_classifier(self): def get_classifier(self):
return self.head return self.head

Loading…
Cancel
Save