From 4cfecf8acb528f1ee2d349323a263eec1e1a4927 Mon Sep 17 00:00:00 2001
From: Fredo Guan
Date: Sun, 8 Jan 2023 11:19:28 -0800
Subject: [PATCH] Update metaformers.py

---
 timm/models/metaformers.py | 23 ++++++++++++++++++-----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py
index 794fb62c..2ba67b0a 100644
--- a/timm/models/metaformers.py
+++ b/timm/models/metaformers.py
@@ -646,7 +646,10 @@ class MetaFormer(nn.Module):
             layer_scale_init_values = [layer_scale_init_values] * num_stage
         if not isinstance(res_scale_init_values, (list, tuple)):
             res_scale_init_values = [res_scale_init_values] * num_stage
-
+
+        self.grad_checkpointing = False
+        self.feature_info = []
+
         stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
         cur = 0
         for i in range(num_stage):
@@ -665,6 +668,7 @@
             )
             stages.append(stage)
             cur += depths[i]
+            self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]
         self.stages = nn.Sequential(*stages)
 
         self.norm = output_norm(dims[-1])
@@ -687,17 +691,26 @@
     @torch.jit.ignore
     def no_weight_decay(self):
         return {'norm'}
-
-    def forward_features(self, x):
-        x = self.stages(x)
+
+    def forward_head(self, x, pre_logits: bool = False):
+        # pool and norm come before the pre_logits branch so it returns (B, C) features
         x = x.mean([1,2]) # TODO use adaptive pool instead of mean
         x = self.norm(x) # (B, H, W, C) -> (B, C)
+        if pre_logits:
+            return x
+        x = self.head(x)
+        return x
+
+    def forward_features(self, x):
+        x = self.stages(x)
+        # pooling, norm, and classifier now live in forward_head
+
         return x
 
 
     def forward(self, x):
         x = self.forward_features(x)
-        x = self.head(x)
+        x = self.forward_head(x)
         return x
 
 def checkpoint_filter_fn(state_dict, model):
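
For context: the patch moves pooling, the final norm, and the classifier out of
forward_features into a timm-style forward_head, so the two stages can be called
separately and pre_logits=True can yield pooled, normed features without the
classifier. Below is a minimal standalone sketch of that convention; the class,
dims, and shapes are illustrative stand-ins, not code from metaformers.py:

    import torch
    import torch.nn as nn

    # Illustrative stand-in for the patched model; names are not from metaformers.py.
    class TinyChannelsLastNet(nn.Module):
        def __init__(self, dim=64, num_classes=10):
            super().__init__()
            self.stages = nn.Identity()    # placeholder for the MetaFormer stages
            self.norm = nn.LayerNorm(dim)  # applied after pooling, as in the patch
            self.head = nn.Linear(dim, num_classes)

        def forward_features(self, x):
            # stages keep the channels-last (B, H, W, C) layout
            return self.stages(x)

        def forward_head(self, x, pre_logits: bool = False):
            x = x.mean([1, 2])   # (B, H, W, C) -> (B, C) global average pool
            x = self.norm(x)
            if pre_logits:
                return x         # pooled and normed features, no classifier
            return self.head(x)

        def forward(self, x):
            return self.forward_head(self.forward_features(x))

    model = TinyChannelsLastNet().eval()
    x = torch.randn(2, 7, 7, 64)  # channels-last feature map
    print(model(x).shape)                                     # torch.Size([2, 10])
    feats = model.forward_features(x)
    print(model.forward_head(feats, pre_logits=True).shape)   # torch.Size([2, 64])

The per-stage feature_info entries (num_chs, reduction, module) are the metadata
timm's feature-extraction wrappers read when building a features_only model. The
self.grad_checkpointing flag is only introduced here; wiring it into the stage
loop is presumably left for a follow-up commit.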