Update metaformers.py

pull/1647/head
Fredo Guan 3 years ago
parent f400e8a3c9
commit 1b1b1d83b4

@ -706,7 +706,7 @@ class MetaFormer(nn.Module):
self.stages = nn.Sequential(*stages)
self.norm = self.output_norm(self.num_features)
'''
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
if head_dropout > 0.0:
@ -714,6 +714,9 @@ class MetaFormer(nn.Module):
else:
self.head = self.head_fn(self.num_features, self.num_classes)
'''
self.reset_classifier(self.num_classes, global_pool)
self.apply(self._init_weights)
def _init_weights(self, m):
@ -742,9 +745,9 @@ class MetaFormer(nn.Module):
else:
self.norm = self.output_norm(self.num_features)
if self.head_dropout > 0.0:
self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout)
self.head = self.head_fn(self.num_features, num_classes, head_dropout=self.head_dropout)
else:
self.head = self.head_fn(self.num_features, self.num_classes)
self.head = self.head_fn(self.num_features, num_classes)
def forward_head(self, x, pre_logits: bool = False):
if pre_logits:

Loading…
Cancel
Save