From cd36989a604c86b159331452894c7c33f3fde758 Mon Sep 17 00:00:00 2001
From: Fredo Guan
Date: Tue, 17 Jan 2023 20:57:51 -0800
Subject: [PATCH] Update metaformers.py

---
 timm/models/metaformers.py | 71 ++++++++++++++------------------------
 1 file changed, 26 insertions(+), 45 deletions(-)

diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py
index 9aef43a5..0efaedc6 100644
--- a/timm/models/metaformers.py
+++ b/timm/models/metaformers.py
@@ -60,10 +60,7 @@ class Downsampling(nn.Module):
         x = self.pre_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         x = self.conv(x)
-
-        x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        print(x[0][0][0][0])
         return x
 '''
 class Downsampling(nn.Module):
@@ -494,10 +491,11 @@ class MetaFormer(nn.Module):
         mlp_bias=False,
         norm_layers=partial(LayerNormGeneral, eps=1e-6, bias=False),
         drop_path_rate=0.,
-        head_dropout=0.0,
+        drop_rate=0.0,
         layer_scale_init_values=None,
         res_scale_init_values=[None, None, 1.0, 1.0],
         output_norm=partial(nn.LayerNorm, eps=1e-6),
+        head_norm_first=False,
         head_fn=nn.Linear,
         global_pool = 'avg',
         **kwargs,
@@ -506,9 +504,8 @@ class MetaFormer(nn.Module):
         self.num_classes = num_classes
         self.head_fn = head_fn
         self.num_features = dims[-1]
-        self.head_dropout = head_dropout
-        self.output_norm = output_norm
-
+        self.drop_rate = drop_rate
+
         if not isinstance(depths, (list, tuple)):
             depths = [depths] # it means the model has only one stage
         if not isinstance(dims, (list, tuple)):
@@ -586,15 +583,16 @@ class MetaFormer(nn.Module):
             self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]
 
         self.stages = nn.Sequential(*stages)
-        self.norm = self.output_norm(self.num_features)
-
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-
-        if head_dropout > 0.0:
-            self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout)
-        else:
-            self.head = self.head_fn(self.num_features, self.num_classes)
+        # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
+        # otherwise pool -> norm -> fc, similar to ConvNeXt
+        self.norm_pre = output_norm(self.num_features) if head_norm_first else nn.Identity()
+        self.head = nn.Sequential(OrderedDict([
+            ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
+            ('norm', nn.Identity() if head_norm_first else output_norm(self.num_features)),
+            ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
+            ('drop', nn.Dropout(self.drop_rate)),
+            ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
 
         self.apply(self._init_weights)
 
@@ -613,40 +611,23 @@ class MetaFormer(nn.Module):
         return self.head.fc2
 
     def reset_classifier(self, num_classes=0, global_pool=None):
-        if global_pool is not None:
-            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-
-
-        if num_classes == 0:
-            self.head = nn.Identity()
-            self.norm = nn.Identity()
-        else:
-            self.norm = self.output_norm(self.num_features)
-            if self.head_dropout > 0.0:
-                self.head = self.head_fn(self.num_features, num_classes, head_dropout=self.head_dropout)
-            else:
-                self.head = self.head_fn(self.num_features, num_classes)
+        self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+        self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
+        self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_head(self, x, pre_logits: bool = False):
-        if pre_logits:
-            return x
-
-        #x = self.global_pool(x)
-        #x = x.squeeze()
-        #x = self.norm(x)
-        # (B, H, W, C) -> (B, C)
-        #x = self.head(x)
-        x=self.head(self.norm(x.mean([2, 3])))
-        return x
+        # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
+        x = self.head.global_pool(x)
+        x = self.head.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        x = self.head.flatten(x)
+        x = self.head.drop(x)
+        return x if pre_logits else self.head.fc(x)
 
     def forward_features(self, x):
         x = self.patch_embed(x)
-        #x = self.stages(x)
-        for i, stage in enumerate(self.stages):
-            x = stage(x)
-
-
+        x = self.stages(x)
+        x = self.norm_pre(x)
         return x
 
     def forward(self, x):
@@ -658,7 +639,6 @@ def checkpoint_filter_fn(state_dict, model):
     import re
     out_dict = {}
     for k, v in state_dict.items():
-        '''
        k = k.replace('proj', 'conv')
         k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k)
         k = k.replace('network.1', 'downsample_layers.1')
@@ -668,10 +648,11 @@ def checkpoint_filter_fn(state_dict, model):
         k = k.replace('network.4', 'network.2')
         k = k.replace('network.6', 'network.3')
         k = k.replace('network', 'stages')
-        '''
         k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
         k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
         k = k.replace('stages.0.downsample', 'patch_embed')
+        k = re.sub(r'^head', 'head.fc', k)
+        k = re.sub(r'^norm', 'head.norm', k)
         out_dict[k] = v
     return out_dict
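---

Reviewer note (not part of the patch): the two `re.sub` calls added to
checkpoint_filter_fn are what keep pre-refactor checkpoints loadable after the
head is moved into an nn.Sequential. A minimal standalone sketch of the remap,
assuming a legacy state dict that still uses the flat `head` / `norm` names
(hypothetical keys, for illustration only):

    import re

    legacy_keys = ['norm.weight', 'norm.bias', 'head.weight', 'head.bias']
    for k in legacy_keys:
        k = re.sub(r'^head', 'head.fc', k)    # 'head.weight' -> 'head.fc.weight'
        k = re.sub(r'^norm', 'head.norm', k)  # 'norm.weight' -> 'head.norm.weight'
        print(k)

Because both patterns are anchored with `^`, a key rewritten by the first
substitution no longer starts with `norm`, so each legacy key is touched by at
most one of the two rules.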
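A similarly simplified sketch of the two head orderings the new head_norm_first
flag selects between (hypothetical shapes; the real module also carries
flatten/dropout and uses SelectAdaptivePool2d rather than a bare mean):

    import torch
    import torch.nn as nn

    B, C = 2, 512
    x = torch.randn(B, C, 7, 7)  # hypothetical NCHW feature map
    norm, fc = nn.LayerNorm(C), nn.Linear(C, 1000)

    # head_norm_first=True: norm_pre on the feature map, then pool, then fc
    y = norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # channels-last norm
    y = fc(y.mean((2, 3)))                               # global average pool -> fc

    # head_norm_first=False (default): pool first, then norm, then fc (ConvNeXt-style)
    z = fc(norm(x.mean((2, 3))))

    print(y.shape, z.shape)  # torch.Size([2, 1000]) in both cases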