From a8e34051c1de050421abd50fbc1201d125a50fe7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 23:07:43 -0700 Subject: [PATCH] Unbreak gamma remap impacting beit checkpoint load, version bump to 0.6.4 --- timm/models/deit.py | 4 +++- timm/models/vision_transformer.py | 4 ++-- timm/version.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/timm/models/deit.py b/timm/models/deit.py index a2f43b91..8cb36bd6 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -10,6 +10,8 @@ Modifications copyright 2021, Ross Wightman """ # Copyright (c) 2015-present, Facebook, Inc. # All rights reserved. +from functools import partial + import torch from torch import nn as nn @@ -177,7 +179,7 @@ def _create_deit(variant, pretrained=False, distilled=False, **kwargs): model_cls = VisionTransformerDistilled if distilled else VisionTransformer model = build_model_with_cfg( model_cls, variant, pretrained, - pretrained_filter_fn=checkpoint_filter_fn, + pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True), **kwargs) return model diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 022052d0..c92c22a3 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -626,7 +626,7 @@ def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()): return posemb -def checkpoint_filter_fn(state_dict, model): +def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False): """ convert patch embedding weight from manual patchify + linear proj to conv""" import re out_dict = {} @@ -647,7 +647,7 @@ def checkpoint_filter_fn(state_dict, model): getattr(model, 'num_prefix_tokens', 1), model.patch_embed.grid_size ) - elif 'gamma_' in k: + elif adapt_layer_scale and 'gamma_' in k: # remap layer-scale gamma into sub-module (deit3 models) k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k) elif 'pre_logits' in k: diff --git a/timm/version.py b/timm/version.py index 7165c7fa..02f8497c 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.6.3.dev0' +__version__ = '0.6.4'