From c7e1819ca5520e7180ba87902e923ffaf9a6cdcd Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Jan 2023 21:23:56 -0800 Subject: [PATCH] Update metaformers.py --- timm/models/metaformers.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 0a601783..ab1b0f0c 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -588,7 +588,7 @@ class MetaFormer(nn.Module): if not isinstance(res_scale_init_values, (list, tuple)): res_scale_init_values = [res_scale_init_values] * num_stage - self.stages = nn.ModuleList() # each stage consists of multiple metaformer blocks + stages = nn.ModuleList() # each stage consists of multiple metaformer blocks cur = 0 for i in range(num_stage): stage = nn.Sequential( @@ -603,8 +603,11 @@ class MetaFormer(nn.Module): ) self.stages.append(stage) cur += depths[i] - + + self.stages = nn.Sequential(*stages) self.norm = output_norm(dims[-1]) + + if head_dropout > 0.0: self.head = head_fn(dims[-1], num_classes, head_dropout=head_dropout)