Update metaformers.py

pull/1647/head
Fredo Guan 2 years ago
parent 61e8414ad0
commit 199b443884

@ -26,7 +26,7 @@ from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d
from timm.layers.helpers import to_2tuple from timm.layers.helpers import to_2tuple
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._features import FeatureInfo from ._features import FeatureInfo
@ -640,6 +640,7 @@ class MetaFormer(nn.Module):
res_scale_init_values=[None, None, 1.0, 1.0], res_scale_init_values=[None, None, 1.0, 1.0],
output_norm=partial(nn.LayerNorm, eps=1e-6), output_norm=partial(nn.LayerNorm, eps=1e-6),
head_fn=nn.Linear, head_fn=nn.Linear,
global_pool = 'avg',
**kwargs, **kwargs,
): ):
super().__init__() super().__init__()
@ -705,7 +706,7 @@ class MetaFormer(nn.Module):
self.stages = nn.Sequential(*stages) self.stages = nn.Sequential(*stages)
self.norm = output_norm(dims[-1]) self.norm = output_norm(dims[-1])
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
if head_dropout > 0.0: if head_dropout > 0.0:
self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout) self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout)
@ -731,7 +732,7 @@ class MetaFormer(nn.Module):
def reset_classifier(self, num_classes=0, global_pool=None): def reset_classifier(self, num_classes=0, global_pool=None):
if num_classes == 0: if num_classes == 0:
self.head= nn.Identity() self.head = nn.Identity()
self.norm = nn.Identity() self.norm = nn.Identity()
else: else:
if self.head_dropout > 0.0: if self.head_dropout > 0.0:
@ -743,7 +744,7 @@ class MetaFormer(nn.Module):
if pre_logits: if pre_logits:
return x return x
x = x.mean([-1,-2]) # TODO use adaptive pool instead of mean x = self.global_pool(x)
x = self.norm(x) x = self.norm(x)
# (B, H, W, C) -> (B, C) # (B, H, W, C) -> (B, C)
x = self.head(x) x = self.head(x)

Loading…
Cancel
Save