diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py
index b44d593b..a819421c 100644
--- a/timm/models/metaformers.py
+++ b/timm/models/metaformers.py
@@ -586,13 +586,22 @@ class MetaFormer(nn.Module):
 
         # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
         # otherwise pool -> norm -> fc, similar to ConvNeXt
+        # drop removed - if using a single fc layer, models have no dropout
+        # if using MlpHead, dropout is handled by MlpHead
+        if num_classes > 0:
+            if self.drop_rate > 0.0:
+                head = self.head_fn(dims[-1], num_classes, head_dropout=self.drop_rate)
+            else:
+                head = self.head_fn(dims[-1], num_classes)
+        else:
+            head = nn.Identity()
+
         self.norm_pre = output_norm(self.num_features) if head_norm_first else nn.Identity()
         self.head = nn.Sequential(OrderedDict([
             ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
             ('norm', nn.Identity() if head_norm_first else output_norm(self.num_features)),
             ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
-            ('drop', nn.Dropout(self.drop_rate)),
-            ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
+            ('fc', head)]))
 
         self.apply(self._init_weights)
 
@@ -608,20 +617,26 @@ class MetaFormer(nn.Module):
 
     @torch.jit.ignore
     def get_classifier(self):
-        return self.head.fc2
+        return self.head.fc
 
     def reset_classifier(self, num_classes=0, global_pool=None):
         if global_pool is not None:
             self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
             self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
-        self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        if num_classes > 0:
+            if self.drop_rate > 0.0:
+                head = self.head_fn(self.num_features, num_classes, head_dropout=self.drop_rate)
+            else:
+                head = self.head_fn(self.num_features, num_classes)
+        else:
+            head = nn.Identity()
+        self.head.fc = head
 
     def forward_head(self, x, pre_logits: bool = False):
         # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
         x = self.head.global_pool(x)
         x = self.head.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         x = self.head.flatten(x)
-        x = self.head.drop(x)
         return x if pre_logits else self.head.fc(x)
 
     def forward_features(self, x):
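
For reference, a minimal standalone sketch of the head-construction branch this diff adds in both `__init__` and `reset_classifier` (the `build_head` helper name is hypothetical; the `head_fn(dim, num_classes, head_dropout=...)` signature is assumed from its use above, matching `MlpHead`):

    import torch.nn as nn

    def build_head(head_fn, num_features, num_classes, drop_rate=0.0):
        # No classes -> identity head (feature extraction only).
        if num_classes <= 0:
            return nn.Identity()
        # Dropout requested -> forward it to the head module, which applies
        # it internally (as MlpHead does via its head_dropout argument).
        if drop_rate > 0.0:
            return head_fn(num_features, num_classes, head_dropout=drop_rate)
        # Otherwise build without a dropout argument, so a plain nn.Linear
        # head also works; this is why single-fc models get no dropout.
        return head_fn(num_features, num_classes)

    # e.g. build_head(nn.Linear, 512, 1000), or with an MlpHead-style head_fn:
    # build_head(MlpHead, 512, 1000, drop_rate=0.1)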