Fix pos_embed scaling for ViT and num_classes != 1000 for pretrained distilled deit and pit models. Fix #426 and fix #433

pull/533/head
Ross Wightman 4 years ago
parent a760a4c3f4
commit 7953e5d11a

@ -198,20 +198,24 @@ def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filte
_logger.warning(
f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
classifier_name = default_cfg.get('classifier', None)
classifiers = default_cfg.get('classifier', None)
label_offset = default_cfg.get('label_offset', 0)
if classifier_name is not None:
if classifiers is not None:
if isinstance(classifiers, str):
classifiers = (classifiers,)
if num_classes != default_cfg['num_classes']:
# completely discard fully connected if model num_classes doesn't match pretrained weights
del state_dict[classifier_name + '.weight']
del state_dict[classifier_name + '.bias']
for classifier_name in classifiers:
# completely discard fully connected if model num_classes doesn't match pretrained weights
del state_dict[classifier_name + '.weight']
del state_dict[classifier_name + '.bias']
strict = False
elif label_offset > 0:
# special case for pretrained weights with an extra background class in pretrained weights
classifier_weight = state_dict[classifier_name + '.weight']
state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
classifier_bias = state_dict[classifier_name + '.bias']
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
for classifier_name in classifiers:
# special case for pretrained weights with an extra background class in pretrained weights
classifier_weight = state_dict[classifier_name + '.weight']
state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
classifier_bias = state_dict[classifier_name + '.bias']
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
model.load_state_dict(state_dict, strict=strict)

@ -49,14 +49,17 @@ default_cfgs = {
'pit_b_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_820.pth'),
'pit_ti_distilled_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth'),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth',
classifier=('head', 'head_dist')),
'pit_xs_distilled_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth'),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth',
classifier=('head', 'head_dist')),
'pit_s_distilled_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth'),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth',
classifier=('head', 'head_dist')),
'pit_b_distilled_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth'),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth',
classifier=('head', 'head_dist')),
}

@ -123,14 +123,17 @@ default_cfgs = {
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_deit_tiny_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'),
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
classifier=('head', 'head_dist')),
'vit_deit_small_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth'),
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
classifier=('head', 'head_dist')),
'vit_deit_base_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', ),
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
classifier=('head', 'head_dist')),
'vit_deit_base_distilled_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
input_size=(3, 384, 384), crop_pct=1.0),
input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
}
@ -302,6 +305,7 @@ class VisionTransformer(nn.Module):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 2 if distilled else 1
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
if hybrid_backbone is not None:
@ -313,9 +317,8 @@ class VisionTransformer(nn.Module):
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) if distilled else None
num_tokens = 2 if distilled else 1
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + num_tokens, embed_dim))
self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
@ -382,10 +385,10 @@ class VisionTransformer(nn.Module):
x = self.pos_drop(x + self.pos_embed)
x = self.blocks(x)
x = self.norm(x)
if self.dist_token is not None:
return x[:, 0], x[:, 1]
else:
if self.dist_token is None:
return self.pre_logits(x[:, 0])
else:
return x[:, 0], x[:, 1]
def forward(self, x):
x = self.forward_features(x)
@ -401,15 +404,13 @@ class VisionTransformer(nn.Module):
return x
def resize_pos_embed(posemb, posemb_new, token='class'):
def resize_pos_embed(posemb, posemb_new, num_tokens=1):
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
_logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
ntok_new = posemb_new.shape[1]
if token:
assert token in ('class', 'distill')
token_idx = 2 if token == 'distill' else 1
posemb_tok, posemb_grid = posemb[:, :token_idx], posemb[0, token_idx:]
if num_tokens:
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
ntok_new -= 1
else:
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
@ -436,7 +437,7 @@ def checkpoint_filter_fn(state_dict, model):
v = v.reshape(O, -1, H, W)
elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
# To resize pos embedding when using model at different size from pretrained weights
v = resize_pos_embed(v, model.pos_embed, token='distill' if model.dist_token is not None else 'class')
v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1))
out_dict[k] = v
return out_dict

Loading…
Cancel
Save