Update metaformers.py

pull/1647/head
Fredo Guan 2 years ago
parent 32bede4e27
commit 2209d0830e

@ -536,6 +536,7 @@ class Mlp(nn.Module):
self.drop2 = nn.Dropout(drop_probs[1]) self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x): def forward(self, x):
print(x)
x = self.fc1(x) x = self.fc1(x)
x = self.act(x) x = self.act(x)
x = self.drop1(x) x = self.drop1(x)
@ -612,9 +613,8 @@ class MetaFormerBlock(nn.Module):
if res_scale_init_value else nn.Identity() if res_scale_init_value else nn.Identity()
def forward(self, x): def forward(self, x):
#B, C, H, W = x.shape B, C, H, W = x.shape
#x = x.view(B, H, W, C) x = x.view(B, H, W, C)
x = x.permute(0, 2, 3, 1)
x = self.res_scale1(x) + \ x = self.res_scale1(x) + \
self.layer_scale1( self.layer_scale1(
self.drop_path1( self.drop_path1(
@ -627,8 +627,7 @@ class MetaFormerBlock(nn.Module):
self.mlp(self.norm2(x)) self.mlp(self.norm2(x))
) )
) )
#x = x.view(B, C, H, W) x = x.view(B, C, H, W)
x = x.permute(0, 3, 1, 2)
return x return x
class MetaFormer(nn.Module): class MetaFormer(nn.Module):
@ -814,31 +813,25 @@ class MetaFormer(nn.Module):
x = self.norm(x) x = self.norm(x)
# (B, H, W, C) -> (B, C) # (B, H, W, C) -> (B, C)
x = self.head(x) x = self.head(x)
#x=self.head(self.norm(x.mean([2, 3])))
return x return x
def forward_features(self, x): def forward_features(self, x):
x = self.patch_embed(x) x = self.patch_embed(x)
print('timm') x = self.stages(x)
#x = self.stages(x)
'''
for i, stage in enumerate(self.stages):
x=stage(x)
#print(x[0][0][0][0])
'''
return x return x
def forward(self, x): def forward(self, x):
x = self.forward_features(x) x = self.forward_features(x)
x = self.forward_head(x) x = self.forward_head(x)
print('timm')
return x return x
def checkpoint_filter_fn(state_dict, model): def checkpoint_filter_fn(state_dict, model):
import re import re
out_dict = {} out_dict = {}
for k, v in state_dict.items(): for k, v in state_dict.items():
'''
k = k.replace('proj', 'conv') k = k.replace('proj', 'conv')
k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k) k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k)
k = k.replace('network.1', 'downsample_layers.1') k = k.replace('network.1', 'downsample_layers.1')
@ -848,7 +841,6 @@ def checkpoint_filter_fn(state_dict, model):
k = k.replace('network.4', 'network.2') k = k.replace('network.4', 'network.2')
k = k.replace('network.6', 'network.3') k = k.replace('network.6', 'network.3')
k = k.replace('network', 'stages') k = k.replace('network', 'stages')
'''
k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k) k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k) k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
k = k.replace('stages.0.downsample', 'patch_embed') k = k.replace('stages.0.downsample', 'patch_embed')

Loading…
Cancel
Save