@ -668,7 +668,7 @@ class MetaFormer(nn.Module):
)
stages.append(stage)
cur += depths[i]
self.feature_info += [dict(num_chs=dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
self.norm = output_norm(dims[-1])