From 95ec7cf01668a8ae7869b78ff011fa8085a45253 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 12 Jan 2023 11:01:28 -0800 Subject: [PATCH] Update metaformers.py --- timm/models/metaformers.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 3aa6ff1f..ef43ded3 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -644,6 +644,9 @@ class MetaFormer(nn.Module): ): super().__init__() self.num_classes = num_classes + self.head_fn = head_fn + self.num_features = dims[-1] + self.head_dropout = head_dropout if not isinstance(depths, (list, tuple)): depths = [depths] # it means the model has only one stage @@ -705,9 +708,9 @@ class MetaFormer(nn.Module): if head_dropout > 0.0: - self.head = head_fn(dims[-1], num_classes, head_dropout=head_dropout) + self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout) else: - self.head = head_fn(dims[-1], num_classes) + self.head = self.head_fn(self.num_featuers, self.num_classes) self.apply(self._init_weights) @@ -723,20 +726,18 @@ class MetaFormer(nn.Module): @torch.jit.ignore def get_classifier(self): - return self.head.fc + return self.head.fc2 def reset_classifier(self, num_classes=0, global_pool=None): - if global_pool is not None: - self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity() + if num_classes == 0: - self.head.norm = nn.Identity() - self.head.fc = nn.Identity() + self.head= nn.Identity() + self.norm = nn.Identity() else: - if not self.head_norm_first: - norm_layer = type(self.stem[-1]) # obtain type from stem norm - self.head.norm = norm_layer(self.num_features) - self.head.fc = nn.Linear(self.num_features, num_classes) + if self.head_dropout > 0.0: + self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout) + else: + self.head = self.head_fn(self.num_featuers, self.num_classes) def forward_head(self, x, pre_logits: bool = False): if pre_logits: @@ -756,7 +757,6 @@ class MetaFormer(nn.Module): def forward(self, x): x = self.forward_features(x) - print(x.shape) x = self.forward_head(x) return x