```diff
@@ -307,11 +307,16 @@ class MlpMixer(nn.Module):
         )
+        #self.sigmoid = nn.Sigmoid()
+        self.sm = nn.Softmax(dim=1)
 
-        self.init_weights(nlhb=nlhb)
-
-    def init_weights(self, nlhb=False):
-        head_bias = -math.log(self.num_classes) if nlhb else 0.
-        named_apply(partial(_init_weights, head_bias=head_bias), module=self)  # depth-first
+        # self.init_weights(nlhb=nlhb)
+        self.apply(self.init_weights)
+
+    # def init_weights(self, nlhb=False):
+    #     head_bias = -math.log(self.num_classes) if nlhb else 0.
+    #     named_apply(partial(_init_weights, head_bias=head_bias), module=self)  # depth-first
+
+    def init_weights(m):
+        if isinstance(m, nn.Linear):
+            torch.nn.init.xavier_uniform(m.weight)
+            m.bias.data.fill_(0.01)
 
     def get_classifier(self):
         return self.head
```
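Two pitfalls in the new initialization code above are worth flagging. First, `init_weights(m)` is defined as an instance method without a `self` parameter, so in `self.apply(self.init_weights)` the binding consumes `m` and `Module.apply` then passes each submodule as an extra argument, raising a `TypeError` at construction time. Second, `torch.nn.init.xavier_uniform` is the deprecated spelling; the current in-place API is `xavier_uniform_`. Below is a minimal sketch of the same idea, assuming the intent is simply Xavier weights plus a 0.01 bias fill on every `nn.Linear`; the small `nn.Sequential` model at the bottom is only an illustration, not part of the original code:

```python
import torch.nn as nn


def init_weights(m):
    """Per-module initializer, intended to be passed to Module.apply()."""
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)  # in-place; xavier_uniform (no underscore) is deprecated
        if m.bias is not None:             # guard: nn.Linear(..., bias=False) has bias=None
            nn.init.constant_(m.bias, 0.01)


# Module.apply() walks the module tree depth-first and calls the function on each submodule.
model = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 4))
model.apply(init_weights)
```

Inside the class, the equivalent fix is to decorate `init_weights` with `@staticmethod` (or move it to module level), so that `self.apply(self.init_weights)` passes each submodule as `m`.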
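Separately, this hunk only shows `self.sm = nn.Softmax(dim=1)` being created, not where it is used in `forward()`. If the model is trained with `nn.CrossEntropyLoss`, the head should return raw logits: that loss applies `log_softmax` internally, and feeding it already-softmaxed outputs tends to slow or stall training. A softmax over `dim=1` (the class dimension of the `(batch, num_classes)` head output) is typically applied only at inference time, when probabilities are actually needed.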