diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py
index 1ca8ad7e..c31185fa 100644
--- a/timm/models/metaformers.py
+++ b/timm/models/metaformers.py
@@ -706,13 +706,16 @@ class MetaFormer(nn.Module):
 
         self.stages = nn.Sequential(*stages)
         self.norm = self.output_norm(self.num_features)
-
+        '''
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
 
         if head_dropout > 0.0:
             self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout)
         else:
             self.head = self.head_fn(self.num_features, self.num_classes)
+
+        '''
+        self.reset_classifier(self.num_classes, global_pool)
 
         self.apply(self._init_weights)
 
@@ -742,9 +745,9 @@ class MetaFormer(nn.Module):
         else:
             self.norm = self.output_norm(self.num_features)
         if self.head_dropout > 0.0:
-            self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout)
+            self.head = self.head_fn(self.num_features, num_classes, head_dropout=self.head_dropout)
         else:
-            self.head = self.head_fn(self.num_features, self.num_classes)
+            self.head = self.head_fn(self.num_features, num_classes)
 
     def forward_head(self, x, pre_logits: bool = False):
         if pre_logits:
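The change has two parts: `__init__` now delegates head construction to `reset_classifier()` (the old inline pooling/head setup is commented out with a `'''` block), and `reset_classifier()` is fixed to build the head from its `num_classes` argument instead of the stale `self.num_classes`. Below is a minimal, hypothetical sketch of that pattern; `TinyClassifier` is a stand-in (not the actual `MetaFormer` class) and uses plain `torch.nn` layers rather than timm's head/pooling helpers.

```python
import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in showing the pattern from the diff:
    __init__ delegates head construction to reset_classifier()."""

    def __init__(self, num_features=64, num_classes=10, head_dropout=0.0, global_pool='avg'):
        super().__init__()
        self.num_features = num_features
        self.head_dropout = head_dropout
        self.stem = nn.Conv2d(3, num_features, 3, stride=2, padding=1)
        # Single source of truth for the head: both initial construction and
        # later re-targeting to a new class count go through reset_classifier().
        self.reset_classifier(num_classes, global_pool)

    def reset_classifier(self, num_classes, global_pool='avg'):
        # Use the num_classes argument (the bug fixed in the second hunk was
        # reading self.num_classes here instead).
        self.num_classes = num_classes
        self.global_pool = nn.AdaptiveAvgPool2d(1) if global_pool else nn.Identity()
        layers = [nn.Flatten(1)]
        if self.head_dropout > 0.0:
            layers.append(nn.Dropout(self.head_dropout))
        layers.append(nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)
        x = self.global_pool(x)
        return self.head(x)


# Usage: build for 10 classes, then re-target the head to 5 classes.
model = TinyClassifier(num_classes=10)
out = model(torch.randn(1, 3, 32, 32))   # shape: (1, 10)
model.reset_classifier(5)
out = model(torch.randn(1, 3, 32, 32))   # shape: (1, 5)
```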