From f2c4d6f963588e3e4ba3dd22e00c4b640cc1aa5c Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 18 Jan 2023 11:04:30 -0800 Subject: [PATCH] Update metaformers.py --- timm/models/metaformers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index dda526c2..5d3f7160 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -654,7 +654,7 @@ def checkpoint_filter_fn(state_dict, model): import re out_dict = {} for k, v in state_dict.items(): - k = k.replace('proj', 'conv') + #k = k.replace('proj', 'conv') k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k) k = k.replace('network.1', 'downsample_layers.1') k = k.replace('network.3', 'downsample_layers.2') @@ -663,6 +663,7 @@ def checkpoint_filter_fn(state_dict, model): k = k.replace('network.4', 'network.2') k = k.replace('network.6', 'network.3') k = k.replace('network', 'stages') + k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k) k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k) k = k.replace('stages.0.downsample', 'patch_embed')