diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py
index 733d05c9..87018b47 100644
--- a/timm/models/metaformers.py
+++ b/timm/models/metaformers.py
@@ -536,7 +536,6 @@ class Mlp(nn.Module):
         self.drop2 = nn.Dropout(drop_probs[1])
 
     def forward(self, x):
-        print(x)
         x = self.fc1(x)
         x = self.act(x)
         x = self.drop1(x)
@@ -613,8 +612,9 @@ class MetaFormerBlock(nn.Module):
             if res_scale_init_value else nn.Identity()
 
     def forward(self, x):
-        B, C, H, W = x.shape
-        x = x.view(B, H, W, C)
+        #B, C, H, W = x.shape
+        #x = x.view(B, H, W, C)
+        x = x.permute(0, 2, 3, 1)
         x = self.res_scale1(x) + \
             self.layer_scale1(
                 self.drop_path1(
@@ -627,7 +627,8 @@ class MetaFormerBlock(nn.Module):
                     self.mlp(self.norm2(x))
                 )
             )
-        x = x.view(B, C, H, W)
+        #x = x.view(B, C, H, W)
+        x = x.permute(0, 3, 1, 2)
         return x
 
 class MetaFormer(nn.Module):
@@ -813,25 +814,31 @@ class MetaFormer(nn.Module):
         x = self.norm(x)
         # (B, H, W, C) -> (B, C)
         x = self.head(x)
+        #x=self.head(self.norm(x.mean([2, 3])))
         return x
 
     def forward_features(self, x):
         x = self.patch_embed(x)
-        x = self.stages(x)
-
+        print('timm')
+        #x = self.stages(x)
+        '''
+        for i, stage in enumerate(self.stages):
+            x=stage(x)
+            #print(x[0][0][0][0])
+        '''
         return x
 
     def forward(self, x):
         x = self.forward_features(x)
         x = self.forward_head(x)
-        print('timm')
         return x
 
 
 def checkpoint_filter_fn(state_dict, model):
     import re
     out_dict = {}
     for k, v in state_dict.items():
+        '''
         k = k.replace('proj', 'conv')
         k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k)
         k = k.replace('network.1', 'downsample_layers.1')
@@ -841,6 +848,7 @@ def checkpoint_filter_fn(state_dict, model):
         k = k.replace('network.4', 'network.2')
         k = k.replace('network.6', 'network.3')
         k = k.replace('network', 'stages')
+        '''
         k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
         k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
         k = k.replace('stages.0.downsample', 'patch_embed')