@@ -644,6 +644,9 @@ class MetaFormer(nn.Module):
     ):
         super().__init__()
         self.num_classes = num_classes
+        self.head_fn = head_fn
+        self.num_features = dims[-1]
+        self.head_dropout = head_dropout
 
         if not isinstance(depths, (list, tuple)):
             depths = [depths]  # it means the model has only one stage
@@ -705,9 +708,9 @@ class MetaFormer(nn.Module):
 
         if head_dropout > 0.0:
-            self.head = head_fn(dims[-1], num_classes, head_dropout=head_dropout)
+            self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout)
         else:
-            self.head = head_fn(dims[-1], num_classes)
+            self.head = self.head_fn(self.num_features, self.num_classes)
 
         self.apply(self._init_weights)
@@ -723,20 +726,18 @@ class MetaFormer(nn.Module):
 
     @torch.jit.ignore
     def get_classifier(self):
-        return self.head.fc
+        return self.head.fc2
 
     def reset_classifier(self, num_classes=0, global_pool=None):
-        if global_pool is not None:
-            self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-            self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
         if num_classes == 0:
-            self.head.norm = nn.Identity()
-            self.head.fc = nn.Identity()
+            self.head = nn.Identity()
+            self.norm = nn.Identity()
         else:
-            if not self.head_norm_first:
-                norm_layer = type(self.stem[-1])  # obtain type from stem norm
-                self.head.norm = norm_layer(self.num_features)
-            self.head.fc = nn.Linear(self.num_features, num_classes)
+            if self.head_dropout > 0.0:
+                self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout)
+            else:
+                self.head = self.head_fn(self.num_features, self.num_classes)
 
     def forward_head(self, x, pre_logits: bool = False):
         if pre_logits:
@@ -756,7 +757,6 @@ class MetaFormer(nn.Module):
 
     def forward(self, x):
         x = self.forward_features(x)
-        print(x.shape)
         x = self.forward_head(x)
         return x
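
For reference, the refactor above assumes head_fn is any callable with the signature head_fn(dim, num_classes, head_dropout=...) whose classifier layer is exposed as fc2 (hence get_classifier returning self.head.fc2). A minimal sketch of a compatible head, with hypothetical layer names and sizes chosen for illustration, not necessarily the PR's exact implementation:

    import torch.nn as nn

    class MlpHead(nn.Module):
        # Hypothetical two-layer head matching the head_fn call signature
        # used in __init__ and reset_classifier above:
        #     head_fn(dim, num_classes, head_dropout=...)
        def __init__(self, dim, num_classes=1000, mlp_ratio=4, head_dropout=0.0):
            super().__init__()
            hidden_features = int(mlp_ratio * dim)
            self.fc1 = nn.Linear(dim, hidden_features)
            self.act = nn.GELU()
            self.norm = nn.LayerNorm(hidden_features)
            self.head_dropout = nn.Dropout(head_dropout)
            self.fc2 = nn.Linear(hidden_features, num_classes)  # returned by get_classifier()

        def forward(self, x):
            x = self.fc1(x)
            x = self.act(x)
            x = self.norm(x)
            x = self.head_dropout(x)
            return self.fc2(x)

    # Usage mirroring the constructor logic above:
    # head = MlpHead(dims[-1], num_classes, head_dropout=0.1)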