|
|
|
@ -710,7 +710,7 @@ class MetaFormer(nn.Module):
|
|
|
|
|
if head_dropout > 0.0:
|
|
|
|
|
self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout)
|
|
|
|
|
else:
|
|
|
|
|
self.head = self.head_fn(self.num_featuers, self.num_classes)
|
|
|
|
|
self.head = self.head_fn(self.num_features, self.num_classes)
|
|
|
|
|
|
|
|
|
|
self.apply(self._init_weights)
|
|
|
|
|
|
|
|
|
@ -737,7 +737,7 @@ class MetaFormer(nn.Module):
|
|
|
|
|
if self.head_dropout > 0.0:
|
|
|
|
|
self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout)
|
|
|
|
|
else:
|
|
|
|
|
self.head = self.head_fn(self.num_featuers, self.num_classes)
|
|
|
|
|
self.head = self.head_fn(self.num_features, self.num_classes)
|
|
|
|
|
|
|
|
|
|
def forward_head(self, x, pre_logits: bool = False):
|
|
|
|
|
if pre_logits:
|
|
|
|
|