Unbreak gamma remap impacting beit checkpoint load, version bump to 0.6.4

pull/1345/head
Ross Wightman 2 years ago
parent 1ccce50d48
commit a8e34051c1

@ -10,6 +10,8 @@ Modifications copyright 2021, Ross Wightman
""" """
# Copyright (c) 2015-present, Facebook, Inc. # Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved. # All rights reserved.
from functools import partial
import torch import torch
from torch import nn as nn from torch import nn as nn
@ -177,7 +179,7 @@ def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
model_cls = VisionTransformerDistilled if distilled else VisionTransformer model_cls = VisionTransformerDistilled if distilled else VisionTransformer
model = build_model_with_cfg( model = build_model_with_cfg(
model_cls, variant, pretrained, model_cls, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn, pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True),
**kwargs) **kwargs)
return model return model

@ -626,7 +626,7 @@ def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
return posemb return posemb
def checkpoint_filter_fn(state_dict, model): def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
""" convert patch embedding weight from manual patchify + linear proj to conv""" """ convert patch embedding weight from manual patchify + linear proj to conv"""
import re import re
out_dict = {} out_dict = {}
@ -647,7 +647,7 @@ def checkpoint_filter_fn(state_dict, model):
getattr(model, 'num_prefix_tokens', 1), getattr(model, 'num_prefix_tokens', 1),
model.patch_embed.grid_size model.patch_embed.grid_size
) )
elif 'gamma_' in k: elif adapt_layer_scale and 'gamma_' in k:
# remap layer-scale gamma into sub-module (deit3 models) # remap layer-scale gamma into sub-module (deit3 models)
k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k) k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
elif 'pre_logits' in k: elif 'pre_logits' in k:

@ -1 +1 @@
__version__ = '0.6.3.dev0' __version__ = '0.6.4'

Loading…
Cancel
Save