diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 27d29a0a..672599a7 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -654,7 +654,7 @@ def checkpoint_filter_fn(state_dict, model): import re out_dict = {} for k, v in state_dict.items(): - k = k.replace('patch_embed.proj', 'patch_embed.conv') + k = k.replace('downsample.proj', 'downsample.conv') k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k) k = k.replace('network.1', 'downsample_layers.1') k = k.replace('network.3', 'downsample_layers.2')