Update metaformers.py

pull/1647/head
Fredo Guan 2 years ago
parent 01f671ed08
commit c7e1819ca5

@ -588,7 +588,7 @@ class MetaFormer(nn.Module):
if not isinstance(res_scale_init_values, (list, tuple)):
res_scale_init_values = [res_scale_init_values] * num_stage
self.stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
cur = 0
for i in range(num_stage):
stage = nn.Sequential(
@ -603,8 +603,11 @@ class MetaFormer(nn.Module):
)
self.stages.append(stage)
cur += depths[i]
self.stages = nn.Sequential(*stages)
self.norm = output_norm(dims[-1])
if head_dropout > 0.0:
self.head = head_fn(dims[-1], num_classes, head_dropout=head_dropout)

Loading…
Cancel
Save