|
|
|
@ -586,13 +586,22 @@ class MetaFormer(nn.Module):
|
|
|
|
|
|
|
|
|
|
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
|
|
|
|
|
# otherwise pool -> norm -> fc, similar to ConvNeXt
|
|
|
|
|
# drop removed - if using single fc layer, models have no dropout
|
|
|
|
|
# if using MlpHead, dropout is handled by MlpHead
|
|
|
|
|
if num_classes > 0:
|
|
|
|
|
if self.drop_rate > 0.0:
|
|
|
|
|
head = self.head_fn(dims[-1], num_classes, head_dropout=self.drop_rate)
|
|
|
|
|
else:
|
|
|
|
|
head = self.head_fn(dims[-1], num_classes)
|
|
|
|
|
else:
|
|
|
|
|
head = nn.Identity()
|
|
|
|
|
|
|
|
|
|
self.norm_pre = output_norm(self.num_features) if head_norm_first else nn.Identity()
|
|
|
|
|
self.head = nn.Sequential(OrderedDict([
|
|
|
|
|
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
|
|
|
|
|
('norm', nn.Identity() if head_norm_first else output_norm(self.num_features)),
|
|
|
|
|
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
|
|
|
|
|
('drop', nn.Dropout(self.drop_rate)),
|
|
|
|
|
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
|
|
|
|
|
('fc', head)]))
|
|
|
|
|
|
|
|
|
|
self.apply(self._init_weights)
|
|
|
|
|
|
|
|
|
@ -608,20 +617,26 @@ class MetaFormer(nn.Module):
|
|
|
|
|
|
|
|
|
|
@torch.jit.ignore
def get_classifier(self):
    """Return the final classifier module (the ``fc`` stage of the head).

    The head is assembled as ``nn.Sequential(OrderedDict(..., ('fc', head)))``
    and every other access site in this class (``reset_classifier``,
    ``forward_head``) addresses the classifier as ``self.head.fc``, so that is
    what callers get back here. The stale duplicate ``return self.head.fc2``
    left over from a botched merge is removed — it made the real return
    unreachable and referenced an attribute the head does not define.
    """
    return self.head.fc
|
|
|
|
|
def reset_classifier(self, num_classes=0, global_pool=None):
    """Rebuild the classification head for a new number of classes.

    Args:
        num_classes: Output size of the new classifier; 0 installs ``nn.Identity``.
        global_pool: Optional new pooling type; when given, the head's pooling
            and flatten stages are rebuilt as well.
    """
    if global_pool is not None:
        self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        # Flatten only when a pooling mode is selected; '' pooling keeps the
        # spatial map, mirroring the construction in __init__.
        self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
    if num_classes > 0:
        # BUG FIX: the merged code called self.head_fn(dims[-1], ...), but
        # `dims` is a local variable of __init__ and does not exist here
        # (NameError at runtime). __init__ records the same value as
        # self.num_features (it is what the head is built from), so use that.
        # The stale pre-image line that assigned nn.Linear(...) directly to
        # self.head.fc is dropped — head_fn is the construction path __init__
        # uses, and the direct assignment was immediately overwritten anyway.
        if self.drop_rate > 0.0:
            head = self.head_fn(self.num_features, num_classes, head_dropout=self.drop_rate)
        else:
            head = self.head_fn(self.num_features, num_classes)
    else:
        head = nn.Identity()
    self.head.fc = head
|
|
|
|
|
def forward_head(self, x, pre_logits: bool = False):
    """Apply the pooled classification head; optionally stop before the fc."""
    # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
    head = self.head
    pooled = head.global_pool(x)
    # The norm stage runs channels-last, so shuttle the channel axis to the
    # back for the norm and restore NCHW afterwards.
    normed = head.norm(pooled.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    features = head.drop(head.flatten(normed))
    if pre_logits:
        return features
    return head.fc(features)
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
|