|
|
|
@ -588,7 +588,7 @@ class MetaFormer(nn.Module):
|
|
|
|
|
if not isinstance(res_scale_init_values, (list, tuple)):
|
|
|
|
|
res_scale_init_values = [res_scale_init_values] * num_stage
|
|
|
|
|
|
|
|
|
|
self.stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
|
|
|
|
|
stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
|
|
|
|
|
cur = 0
|
|
|
|
|
for i in range(num_stage):
|
|
|
|
|
stage = nn.Sequential(
|
|
|
|
@ -603,8 +603,11 @@ class MetaFormer(nn.Module):
|
|
|
|
|
)
|
|
|
|
|
self.stages.append(stage)
|
|
|
|
|
cur += depths[i]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.stages = nn.Sequential(*stages)
|
|
|
|
|
self.norm = output_norm(dims[-1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if head_dropout > 0.0:
|
|
|
|
|
self.head = head_fn(dims[-1], num_classes, head_dropout=head_dropout)
|
|
|
|
|