Update metaformers.py

pull/1647/head
Fredo Guan 2 years ago
parent 61e8414ad0
commit 199b443884

@ -26,7 +26,7 @@ from functools import partial
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath
from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d
from timm.layers.helpers import to_2tuple
from ._builder import build_model_with_cfg
from ._features import FeatureInfo
@ -640,6 +640,7 @@ class MetaFormer(nn.Module):
res_scale_init_values=[None, None, 1.0, 1.0],
output_norm=partial(nn.LayerNorm, eps=1e-6),
head_fn=nn.Linear,
global_pool = 'avg',
**kwargs,
):
super().__init__()
@ -705,7 +706,7 @@ class MetaFormer(nn.Module):
self.stages = nn.Sequential(*stages)
self.norm = output_norm(dims[-1])
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
if head_dropout > 0.0:
self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout)
@ -731,7 +732,7 @@ class MetaFormer(nn.Module):
def reset_classifier(self, num_classes=0, global_pool=None):
if num_classes == 0:
self.head= nn.Identity()
self.head = nn.Identity()
self.norm = nn.Identity()
else:
if self.head_dropout > 0.0:
@ -743,7 +744,7 @@ class MetaFormer(nn.Module):
if pre_logits:
return x
x = x.mean([-1,-2]) # TODO use adaptive pool instead of mean
x = self.global_pool(x)
x = self.norm(x)
# (B, H, W, C) -> (B, C)
x = self.head(x)

Loading…
Cancel
Save