@@ -543,7 +543,7 @@ class MetaFormer(nn.Module):
         dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

-        self.patch_embed = Downsampling(
+        self.stem = Downsampling(
             in_chans,
             dims[0],
             kernel_size=7,
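
Side note (not part of the diff): the rename only changes the attribute name; the Downsampling stem module itself is unchanged. Downstream code that reached for the old attribute has to follow the rename. A minimal sketch, assuming a MetaFormer-based timm model (the model name below is illustrative):

import timm

# Hypothetical example model; any MetaFormer-based architecture applies.
model = timm.create_model('poolformerv2_s12', pretrained=False)

stem = model.stem             # was: model.patch_embed before this change
print(stem.conv.kernel_size)  # (7, 7), matching kernel_size=7 in the constructor above
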
@@ -670,6 +670,7 @@ def checkpoint_filter_fn(state_dict, model):
         k = k.replace('patch_embed.proj', 'patch_embed.conv')
         k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
         k = k.replace('stages.0.downsample', 'patch_embed')
+        k = k.replace('patch_embed', 'stem')
         k = re.sub(r'^head', 'head.fc', k)
         k = re.sub(r'^norm', 'head.norm', k)
         out_dict[k] = v
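
For reference (not part of the diff), a minimal sketch of how the substitution chain above rewrites legacy checkpoint keys once the new 'patch_embed' -> 'stem' replacement is in place. The sample keys are illustrative, not taken from a real checkpoint:

import re

def remap_key(k):
    # Same substitution chain as in checkpoint_filter_fn above.
    k = k.replace('patch_embed.proj', 'patch_embed.conv')
    k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
    k = k.replace('stages.0.downsample', 'patch_embed')
    k = k.replace('patch_embed', 'stem')   # the line added in this diff
    k = re.sub(r'^head', 'head.fc', k)
    k = re.sub(r'^norm', 'head.norm', k)
    return k

print(remap_key('patch_embed.proj.weight'))          # -> stem.conv.weight
print(remap_key('stages.0.downsample.conv.weight'))  # -> stem.conv.weight
print(remap_key('norm.weight'))                      # -> head.norm.weight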