|
|
@ -60,10 +60,7 @@ class Downsampling(nn.Module):
|
|
|
|
x = self.pre_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
|
|
|
x = self.pre_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
|
|
|
|
|
|
|
|
|
|
|
x = self.conv(x)
|
|
|
|
x = self.conv(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
|
|
|
x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
|
|
|
print(x[0][0][0][0])
|
|
|
|
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
'''
|
|
|
|
'''
|
|
|
|
class Downsampling(nn.Module):
|
|
|
|
class Downsampling(nn.Module):
|
|
|
@ -494,10 +491,11 @@ class MetaFormer(nn.Module):
|
|
|
|
mlp_bias=False,
|
|
|
|
mlp_bias=False,
|
|
|
|
norm_layers=partial(LayerNormGeneral, eps=1e-6, bias=False),
|
|
|
|
norm_layers=partial(LayerNormGeneral, eps=1e-6, bias=False),
|
|
|
|
drop_path_rate=0.,
|
|
|
|
drop_path_rate=0.,
|
|
|
|
head_dropout=0.0,
|
|
|
|
drop_rate=0.0,
|
|
|
|
layer_scale_init_values=None,
|
|
|
|
layer_scale_init_values=None,
|
|
|
|
res_scale_init_values=[None, None, 1.0, 1.0],
|
|
|
|
res_scale_init_values=[None, None, 1.0, 1.0],
|
|
|
|
output_norm=partial(nn.LayerNorm, eps=1e-6),
|
|
|
|
output_norm=partial(nn.LayerNorm, eps=1e-6),
|
|
|
|
|
|
|
|
head_norm_first=False,
|
|
|
|
head_fn=nn.Linear,
|
|
|
|
head_fn=nn.Linear,
|
|
|
|
global_pool = 'avg',
|
|
|
|
global_pool = 'avg',
|
|
|
|
**kwargs,
|
|
|
|
**kwargs,
|
|
|
@ -506,9 +504,8 @@ class MetaFormer(nn.Module):
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.head_fn = head_fn
|
|
|
|
self.head_fn = head_fn
|
|
|
|
self.num_features = dims[-1]
|
|
|
|
self.num_features = dims[-1]
|
|
|
|
self.head_dropout = head_dropout
|
|
|
|
self.drop_rate = drop_rate
|
|
|
|
self.output_norm = output_norm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not isinstance(depths, (list, tuple)):
|
|
|
|
if not isinstance(depths, (list, tuple)):
|
|
|
|
depths = [depths] # it means the model has only one stage
|
|
|
|
depths = [depths] # it means the model has only one stage
|
|
|
|
if not isinstance(dims, (list, tuple)):
|
|
|
|
if not isinstance(dims, (list, tuple)):
|
|
|
@ -586,15 +583,16 @@ class MetaFormer(nn.Module):
|
|
|
|
self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]
|
|
|
|
self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]
|
|
|
|
|
|
|
|
|
|
|
|
self.stages = nn.Sequential(*stages)
|
|
|
|
self.stages = nn.Sequential(*stages)
|
|
|
|
self.norm = self.output_norm(self.num_features)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if head_dropout > 0.0:
|
|
|
|
|
|
|
|
self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
self.head = self.head_fn(self.num_features, self.num_classes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
|
|
|
|
|
|
|
|
# otherwise pool -> norm -> fc, similar to ConvNeXt
|
|
|
|
|
|
|
|
self.norm_pre = output_norm(self.num_features) if head_norm_first else nn.Identity()
|
|
|
|
|
|
|
|
self.head = nn.Sequential(OrderedDict([
|
|
|
|
|
|
|
|
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
|
|
|
|
|
|
|
|
('norm', nn.Identity() if head_norm_first else output_norm(self.num_features)),
|
|
|
|
|
|
|
|
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
|
|
|
|
|
|
|
|
('drop', nn.Dropout(self.drop_rate)),
|
|
|
|
|
|
|
|
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
|
|
|
|
|
|
|
|
|
|
|
|
self.apply(self._init_weights)
|
|
|
|
self.apply(self._init_weights)
|
|
|
|
|
|
|
|
|
|
|
@ -613,40 +611,23 @@ class MetaFormer(nn.Module):
|
|
|
|
return self.head.fc2
|
|
|
|
return self.head.fc2
|
|
|
|
|
|
|
|
|
|
|
|
def reset_classifier(self, num_classes=0, global_pool=None):
|
|
|
|
def reset_classifier(self, num_classes=0, global_pool=None):
|
|
|
|
|
|
|
|
|
|
|
|
if global_pool is not None:
|
|
|
|
if global_pool is not None:
|
|
|
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
|
|
|
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
|
|
|
|
|
|
|
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
|
|
|
|
|
|
|
|
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
|
|
|
if num_classes == 0:
|
|
|
|
|
|
|
|
self.head = nn.Identity()
|
|
|
|
|
|
|
|
self.norm = nn.Identity()
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
self.norm = self.output_norm(self.num_features)
|
|
|
|
|
|
|
|
if self.head_dropout > 0.0:
|
|
|
|
|
|
|
|
self.head = self.head_fn(self.num_features, num_classes, head_dropout=self.head_dropout)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
self.head = self.head_fn(self.num_features, num_classes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward_head(self, x, pre_logits: bool = False):
|
|
|
|
def forward_head(self, x, pre_logits: bool = False):
|
|
|
|
if pre_logits:
|
|
|
|
# NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
|
|
|
|
return x
|
|
|
|
x = self.head.global_pool(x)
|
|
|
|
|
|
|
|
x = self.head.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
|
|
|
#x = self.global_pool(x)
|
|
|
|
x = self.head.flatten(x)
|
|
|
|
#x = x.squeeze()
|
|
|
|
x = self.head.drop(x)
|
|
|
|
#x = self.norm(x)
|
|
|
|
return x if pre_logits else self.head.fc(x)
|
|
|
|
# (B, H, W, C) -> (B, C)
|
|
|
|
|
|
|
|
#x = self.head(x)
|
|
|
|
|
|
|
|
x=self.head(self.norm(x.mean([2, 3])))
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
def forward_features(self, x):
|
|
|
|
x = self.patch_embed(x)
|
|
|
|
x = self.patch_embed(x)
|
|
|
|
#x = self.stages(x)
|
|
|
|
x = self.stages(x)
|
|
|
|
for i, stage in enumerate(self.stages):
|
|
|
|
x = self.norm_pre(x)
|
|
|
|
x = stage(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
@ -658,7 +639,6 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
import re
|
|
|
|
import re
|
|
|
|
out_dict = {}
|
|
|
|
out_dict = {}
|
|
|
|
for k, v in state_dict.items():
|
|
|
|
for k, v in state_dict.items():
|
|
|
|
'''
|
|
|
|
|
|
|
|
k = k.replace('proj', 'conv')
|
|
|
|
k = k.replace('proj', 'conv')
|
|
|
|
k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k)
|
|
|
|
k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k)
|
|
|
|
k = k.replace('network.1', 'downsample_layers.1')
|
|
|
|
k = k.replace('network.1', 'downsample_layers.1')
|
|
|
@ -668,10 +648,11 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
k = k.replace('network.4', 'network.2')
|
|
|
|
k = k.replace('network.4', 'network.2')
|
|
|
|
k = k.replace('network.6', 'network.3')
|
|
|
|
k = k.replace('network.6', 'network.3')
|
|
|
|
k = k.replace('network', 'stages')
|
|
|
|
k = k.replace('network', 'stages')
|
|
|
|
'''
|
|
|
|
|
|
|
|
k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
|
|
|
|
k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
|
|
|
|
k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
|
|
|
|
k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
|
|
|
|
k = k.replace('stages.0.downsample', 'patch_embed')
|
|
|
|
k = k.replace('stages.0.downsample', 'patch_embed')
|
|
|
|
|
|
|
|
k = re.sub(r'^head', 'head.fc', k)
|
|
|
|
|
|
|
|
k = re.sub(r'^norm', 'head.norm', k)
|
|
|
|
out_dict[k] = v
|
|
|
|
out_dict[k] = v
|
|
|
|
return out_dict
|
|
|
|
return out_dict
|
|
|
|
|
|
|
|
|
|
|
|