|
|
@ -437,8 +437,9 @@ class MetaFormerBlock(nn.Module):
|
|
|
|
if res_scale_init_value else nn.Identity()
|
|
|
|
if res_scale_init_value else nn.Identity()
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
#B, C, H, W = x.shape
|
|
|
|
x = x.view(B, H, W, C)
|
|
|
|
#x = x.view(B, H, W, C)
|
|
|
|
|
|
|
|
x = x.permute(0, 2, 3, 1)
|
|
|
|
x = self.res_scale1(x) + \
|
|
|
|
x = self.res_scale1(x) + \
|
|
|
|
self.layer_scale1(
|
|
|
|
self.layer_scale1(
|
|
|
|
self.drop_path1(
|
|
|
|
self.drop_path1(
|
|
|
@ -451,7 +452,8 @@ class MetaFormerBlock(nn.Module):
|
|
|
|
self.mlp(self.norm2(x))
|
|
|
|
self.mlp(self.norm2(x))
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
x = x.view(B, C, H, W)
|
|
|
|
#x = x.view(B, C, H, W)
|
|
|
|
|
|
|
|
x = x.permute(0, 3, 1, 2)
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
class MetaFormer(nn.Module):
|
|
|
|
class MetaFormer(nn.Module):
|
|
|
@ -630,11 +632,12 @@ class MetaFormer(nn.Module):
|
|
|
|
if pre_logits:
|
|
|
|
if pre_logits:
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
x = self.global_pool(x)
|
|
|
|
#x = self.global_pool(x)
|
|
|
|
x = x.squeeze()
|
|
|
|
#x = x.squeeze()
|
|
|
|
x = self.norm(x)
|
|
|
|
#x = self.norm(x)
|
|
|
|
# (B, H, W, C) -> (B, C)
|
|
|
|
# (B, H, W, C) -> (B, C)
|
|
|
|
x = self.head(x)
|
|
|
|
#x = self.head(x)
|
|
|
|
|
|
|
|
x=self.head(self.norm(x.mean([2, 3])))
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
def forward_features(self, x):
|
|
|
@ -655,6 +658,7 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
import re
|
|
|
|
import re
|
|
|
|
out_dict = {}
|
|
|
|
out_dict = {}
|
|
|
|
for k, v in state_dict.items():
|
|
|
|
for k, v in state_dict.items():
|
|
|
|
|
|
|
|
'''
|
|
|
|
k = k.replace('proj', 'conv')
|
|
|
|
k = k.replace('proj', 'conv')
|
|
|
|
k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k)
|
|
|
|
k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k)
|
|
|
|
k = k.replace('network.1', 'downsample_layers.1')
|
|
|
|
k = k.replace('network.1', 'downsample_layers.1')
|
|
|
@ -664,6 +668,7 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
k = k.replace('network.4', 'network.2')
|
|
|
|
k = k.replace('network.4', 'network.2')
|
|
|
|
k = k.replace('network.6', 'network.3')
|
|
|
|
k = k.replace('network.6', 'network.3')
|
|
|
|
k = k.replace('network', 'stages')
|
|
|
|
k = k.replace('network', 'stages')
|
|
|
|
|
|
|
|
'''
|
|
|
|
k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
|
|
|
|
k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
|
|
|
|
k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
|
|
|
|
k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
|
|
|
|
k = k.replace('stages.0.downsample', 'patch_embed')
|
|
|
|
k = k.replace('stages.0.downsample', 'patch_embed')
|
|
|
|