|
|
|
@ -568,7 +568,7 @@ class MetaFormer(nn.Module):
|
|
|
|
|
if not isinstance(downsample_layers, (list, tuple)):
|
|
|
|
|
downsample_layers = [downsample_layers] * num_stage
|
|
|
|
|
down_dims = [in_chans] + dims
|
|
|
|
|
self.downsample_layers = nn.ModuleList(
|
|
|
|
|
downsample_layers = nn.ModuleList(
|
|
|
|
|
[downsample_layers[i](down_dims[i], down_dims[i+1]) for i in range(num_stage)]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -604,7 +604,7 @@ class MetaFormer(nn.Module):
|
|
|
|
|
self.stages.append(stage)
|
|
|
|
|
cur += depths[i]
|
|
|
|
|
|
|
|
|
|
self.stages = nn.Sequential(*stages)
|
|
|
|
|
self.stages = nn.Sequential(zip(*downsample_layers, *stages))
|
|
|
|
|
self.norm = output_norm(dims[-1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|