Update metaformers.py

pull/1647/head
Fredo Guan 2 years ago
parent 01f671ed08
commit c7e1819ca5

@ -588,7 +588,7 @@ class MetaFormer(nn.Module):
if not isinstance(res_scale_init_values, (list, tuple)): if not isinstance(res_scale_init_values, (list, tuple)):
res_scale_init_values = [res_scale_init_values] * num_stage res_scale_init_values = [res_scale_init_values] * num_stage
self.stages = nn.ModuleList() # each stage consists of multiple metaformer blocks stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
cur = 0 cur = 0
for i in range(num_stage): for i in range(num_stage):
stage = nn.Sequential( stage = nn.Sequential(
@ -603,8 +603,11 @@ class MetaFormer(nn.Module):
) )
self.stages.append(stage) self.stages.append(stage)
cur += depths[i] cur += depths[i]
self.stages = nn.Sequential(*stages)
self.norm = output_norm(dims[-1]) self.norm = output_norm(dims[-1])
if head_dropout > 0.0: if head_dropout > 0.0:
self.head = head_fn(dims[-1], num_classes, head_dropout=head_dropout) self.head = head_fn(dims[-1], num_classes, head_dropout=head_dropout)

Loading…
Cancel
Save