diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py
index 87018b47..733d05c9 100644
--- a/timm/models/metaformers.py
+++ b/timm/models/metaformers.py
@@ -536,6 +536,7 @@ class Mlp(nn.Module):
         self.drop2 = nn.Dropout(drop_probs[1])
 
     def forward(self, x):
+        print(x)
         x = self.fc1(x)
         x = self.act(x)
         x = self.drop1(x)
@@ -612,9 +613,8 @@ class MetaFormerBlock(nn.Module):
             if res_scale_init_value else nn.Identity()
 
     def forward(self, x):
-        #B, C, H, W = x.shape
-        #x = x.view(B, H, W, C)
-        x = x.permute(0, 2, 3, 1)
+        B, C, H, W = x.shape
+        x = x.view(B, H, W, C)
         x = self.res_scale1(x) + \
             self.layer_scale1(
                 self.drop_path1(
@@ -627,8 +627,7 @@ class MetaFormerBlock(nn.Module):
                     self.mlp(self.norm2(x))
                 )
             )
-        #x = x.view(B, C, H, W)
-        x = x.permute(0, 3, 1, 2)
+        x = x.view(B, C, H, W)
         return x
 
 class MetaFormer(nn.Module):
@@ -814,31 +813,25 @@ class MetaFormer(nn.Module):
 
         x = self.norm(x) # (B, H, W, C) -> (B, C)
         x = self.head(x)
-        #x=self.head(self.norm(x.mean([2, 3])))
         return x
 
     def forward_features(self, x):
         x = self.patch_embed(x)
-        print('timm')
-        #x = self.stages(x)
-        '''
-        for i, stage in enumerate(self.stages):
-            x=stage(x)
-            #print(x[0][0][0][0])
-        '''
+        x = self.stages(x)
+
         return x
 
     def forward(self, x):
         x = self.forward_features(x)
         x = self.forward_head(x)
+        print('timm')
         return x
 
 
 def checkpoint_filter_fn(state_dict, model):
     import re
     out_dict = {}
     for k, v in state_dict.items():
-        '''
         k = k.replace('proj', 'conv')
         k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k)
         k = k.replace('network.1', 'downsample_layers.1')
@@ -848,7 +841,6 @@ def checkpoint_filter_fn(state_dict, model):
         k = k.replace('network.4', 'network.2')
         k = k.replace('network.6', 'network.3')
         k = k.replace('network', 'stages')
-        '''
         k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
         k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
         k = k.replace('stages.0.downsample', 'patch_embed')