@ -601,10 +601,11 @@ class MetaFormer(nn.Module):
res_scale_init_value=res_scale_init_values[i],
) for j in range(depths[i])]
)
stages.append(downsample_layers[i])
stages.append(stage)
cur += depths[i]
self.stages = nn.Sequential(zip(*downsample_layers, *stages))
self.stages = nn.Sequential(*stages)
self.norm = output_norm(dims[-1])