|
|
|
@ -26,7 +26,7 @@ from functools import partial
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
from timm.layers import trunc_normal_, DropPath
|
|
|
|
|
from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d
|
|
|
|
|
from timm.layers.helpers import to_2tuple
|
|
|
|
|
from ._builder import build_model_with_cfg
|
|
|
|
|
from ._features import FeatureInfo
|
|
|
|
@ -640,6 +640,7 @@ class MetaFormer(nn.Module):
|
|
|
|
|
res_scale_init_values=[None, None, 1.0, 1.0],
|
|
|
|
|
output_norm=partial(nn.LayerNorm, eps=1e-6),
|
|
|
|
|
head_fn=nn.Linear,
|
|
|
|
|
global_pool = 'avg',
|
|
|
|
|
**kwargs,
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
@ -705,7 +706,7 @@ class MetaFormer(nn.Module):
|
|
|
|
|
self.stages = nn.Sequential(*stages)
|
|
|
|
|
self.norm = output_norm(dims[-1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
|
|
|
|
|
|
|
|
|
if head_dropout > 0.0:
|
|
|
|
|
self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout)
|
|
|
|
@ -731,7 +732,7 @@ class MetaFormer(nn.Module):
|
|
|
|
|
def reset_classifier(self, num_classes=0, global_pool=None):
|
|
|
|
|
|
|
|
|
|
if num_classes == 0:
|
|
|
|
|
self.head= nn.Identity()
|
|
|
|
|
self.head = nn.Identity()
|
|
|
|
|
self.norm = nn.Identity()
|
|
|
|
|
else:
|
|
|
|
|
if self.head_dropout > 0.0:
|
|
|
|
@ -743,7 +744,7 @@ class MetaFormer(nn.Module):
|
|
|
|
|
if pre_logits:
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
x = x.mean([-1,-2]) # TODO use adaptive pool instead of mean
|
|
|
|
|
x = self.global_pool(x)
|
|
|
|
|
x = self.norm(x)
|
|
|
|
|
# (B, H, W, C) -> (B, C)
|
|
|
|
|
x = self.head(x)
|
|
|
|
|