diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 01857167..1ca8ad7e 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -648,6 +648,7 @@ class MetaFormer(nn.Module): self.head_fn = head_fn self.num_features = dims[-1] self.head_dropout = head_dropout + self.output_norm = output_norm if not isinstance(depths, (list, tuple)): depths = [depths] # it means the model has only one stage @@ -704,7 +705,7 @@ class MetaFormer(nn.Module): self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) - self.norm = output_norm(dims[-1]) + self.norm = self.output_norm(self.num_features) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) @@ -739,6 +740,7 @@ class MetaFormer(nn.Module): self.head = nn.Identity() self.norm = nn.Identity() else: + self.norm = self.output_norm(self.num_features) if self.head_dropout > 0.0: self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout) else: