|
|
|
@ -648,6 +648,7 @@ class MetaFormer(nn.Module):
|
|
|
|
|
self.head_fn = head_fn
|
|
|
|
|
self.num_features = dims[-1]
|
|
|
|
|
self.head_dropout = head_dropout
|
|
|
|
|
self.output_norm = output_norm
|
|
|
|
|
|
|
|
|
|
if not isinstance(depths, (list, tuple)):
|
|
|
|
|
depths = [depths] # it means the model has only one stage
|
|
|
|
@ -704,7 +705,7 @@ class MetaFormer(nn.Module):
|
|
|
|
|
self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]
|
|
|
|
|
|
|
|
|
|
self.stages = nn.Sequential(*stages)
|
|
|
|
|
self.norm = output_norm(dims[-1])
|
|
|
|
|
self.norm = self.output_norm(self.num_features)
|
|
|
|
|
|
|
|
|
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
|
|
|
|
|
|
|
|
@ -739,6 +740,7 @@ class MetaFormer(nn.Module):
|
|
|
|
|
self.head = nn.Identity()
|
|
|
|
|
self.norm = nn.Identity()
|
|
|
|
|
else:
|
|
|
|
|
self.norm = self.output_norm(self.num_features)
|
|
|
|
|
if self.head_dropout > 0.0:
|
|
|
|
|
self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout)
|
|
|
|
|
else:
|
|
|
|
|