diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 89651e02..921a33cc 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -26,7 +26,7 @@ from functools import partial import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import trunc_normal_, DropPath +from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d from timm.layers.helpers import to_2tuple from ._builder import build_model_with_cfg from ._features import FeatureInfo @@ -640,6 +640,7 @@ class MetaFormer(nn.Module): res_scale_init_values=[None, None, 1.0, 1.0], output_norm=partial(nn.LayerNorm, eps=1e-6), head_fn=nn.Linear, + global_pool = 'avg', **kwargs, ): super().__init__() @@ -705,7 +706,7 @@ class MetaFormer(nn.Module): self.stages = nn.Sequential(*stages) self.norm = output_norm(dims[-1]) - + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) if head_dropout > 0.0: self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout) @@ -731,7 +732,7 @@ class MetaFormer(nn.Module): def reset_classifier(self, num_classes=0, global_pool=None): if num_classes == 0: - self.head= nn.Identity() + self.head = nn.Identity() self.norm = nn.Identity() else: if self.head_dropout > 0.0: @@ -743,7 +744,7 @@ class MetaFormer(nn.Module): if pre_logits: return x - x = x.mean([-1,-2]) # TODO use adaptive pool instead of mean + x = self.global_pool(x) x = self.norm(x) # (B, H, W, C) -> (B, C) x = self.head(x)