Unbreak gamma remap impacting beit checkpoint load, version bump to 0.6.4

Branch: pull/1345/head
Author: Ross Wightman, 2 years ago
Parent: 1ccce50d48
Commit: a8e34051c1

timm/models/deit.py

@@ -10,6 +10,8 @@ Modifications copyright 2021, Ross Wightman
 """
 # Copyright (c) 2015-present, Facebook, Inc.
 # All rights reserved.
+from functools import partial
+
 import torch
 from torch import nn as nn
@@ -177,7 +179,7 @@ def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
     model_cls = VisionTransformerDistilled if distilled else VisionTransformer
     model = build_model_with_cfg(
         model_cls, variant, pretrained,
-        pretrained_filter_fn=checkpoint_filter_fn,
+        pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True),
         **kwargs)
     return model

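Note on the mechanism: build_model_with_cfg invokes pretrained_filter_fn with the loaded state dict and the model, so the new adapt_layer_scale flag has to be pre-bound with functools.partial rather than passed at call time. A minimal sketch of the pattern, using toy stand-ins rather than the actual timm implementation:

    from functools import partial

    # Stand-in with the same shape as timm's filter function.
    def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
        print('adapt_layer_scale =', adapt_layer_scale)
        return state_dict

    # The caller only ever passes (state_dict, model); the flag rides along.
    filter_fn = partial(checkpoint_filter_fn, adapt_layer_scale=True)
    filter_fn({'blocks.0.gamma_1': 0.1}, model=None)  # adapt_layer_scale = True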
timm/models/vision_transformer.py

@@ -626,7 +626,7 @@ def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
     return posemb


-def checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
     import re
     out_dict = {}
@@ -647,7 +647,7 @@ def checkpoint_filter_fn(state_dict, model):
                 getattr(model, 'num_prefix_tokens', 1),
                 model.patch_embed.grid_size
             )
-        elif 'gamma_' in k:
+        elif adapt_layer_scale and 'gamma_' in k:
             # remap layer-scale gamma into sub-module (deit3 models)
             k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
         elif 'pre_logits' in k:

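The gating is the actual fix: the old unconditional 'gamma_' in k branch rewrote any checkpoint key containing gamma_, and BEiT checkpoints name their layer-scale parameters gamma_1/gamma_2 as well, which broke their load (per the commit title). With adapt_layer_scale defaulting to False and only set True from _create_deit, the remap now applies just to deit3 weights. A self-contained sketch of the remap on illustrative key names:

    import re

    # Hypothetical deit3-style checkpoint keys; real state dicts are larger.
    keys = ['blocks.0.gamma_1', 'blocks.0.gamma_2', 'blocks.11.gamma_1']

    # deit3 modules in timm store layer scale as blocks.N.ls1.gamma / ls2.gamma,
    # so the per-block gamma_<i> tensors are renamed to match:
    remapped = [re.sub(r'gamma_([0-9])', r'ls\1.gamma', k) for k in keys]
    print(remapped)  # ['blocks.0.ls1.gamma', 'blocks.0.ls2.gamma', 'blocks.11.ls1.gamma']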
timm/version.py

@@ -1 +1 @@
-__version__ = '0.6.3.dev0'
+__version__ = '0.6.4'
