Fix pos_embed scaling for ViT and num_classes != 1000 for pretrained distilled deit and pit models. Fix #426 and fix #433

pull/533/head
Ross Wightman 4 years ago
parent a760a4c3f4
commit 7953e5d11a

@ -198,15 +198,19 @@ def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filte
_logger.warning( _logger.warning(
f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
classifier_name = default_cfg.get('classifier', None) classifiers = default_cfg.get('classifier', None)
label_offset = default_cfg.get('label_offset', 0) label_offset = default_cfg.get('label_offset', 0)
if classifier_name is not None: if classifiers is not None:
if isinstance(classifiers, str):
classifiers = (classifiers,)
if num_classes != default_cfg['num_classes']: if num_classes != default_cfg['num_classes']:
for classifier_name in classifiers:
# completely discard fully connected if model num_classes doesn't match pretrained weights # completely discard fully connected if model num_classes doesn't match pretrained weights
del state_dict[classifier_name + '.weight'] del state_dict[classifier_name + '.weight']
del state_dict[classifier_name + '.bias'] del state_dict[classifier_name + '.bias']
strict = False strict = False
elif label_offset > 0: elif label_offset > 0:
for classifier_name in classifiers:
# special case for pretrained weights with an extra background class in pretrained weights # special case for pretrained weights with an extra background class in pretrained weights
classifier_weight = state_dict[classifier_name + '.weight'] classifier_weight = state_dict[classifier_name + '.weight']
state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]

@ -49,14 +49,17 @@ default_cfgs = {
'pit_b_224': _cfg( 'pit_b_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_820.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_820.pth'),
'pit_ti_distilled_224': _cfg( 'pit_ti_distilled_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth',
classifier=('head', 'head_dist')),
'pit_xs_distilled_224': _cfg( 'pit_xs_distilled_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth',
classifier=('head', 'head_dist')),
'pit_s_distilled_224': _cfg( 'pit_s_distilled_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth',
classifier=('head', 'head_dist')),
'pit_b_distilled_224': _cfg( 'pit_b_distilled_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth',
classifier=('head', 'head_dist')),
} }

@ -123,14 +123,17 @@ default_cfgs = {
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
input_size=(3, 384, 384), crop_pct=1.0), input_size=(3, 384, 384), crop_pct=1.0),
'vit_deit_tiny_distilled_patch16_224': _cfg( 'vit_deit_tiny_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'), url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
classifier=('head', 'head_dist')),
'vit_deit_small_distilled_patch16_224': _cfg( 'vit_deit_small_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth'), url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
classifier=('head', 'head_dist')),
'vit_deit_base_distilled_patch16_224': _cfg( 'vit_deit_base_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', ), url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
classifier=('head', 'head_dist')),
'vit_deit_base_distilled_patch16_384': _cfg( 'vit_deit_base_distilled_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
input_size=(3, 384, 384), crop_pct=1.0), input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
} }
@ -302,6 +305,7 @@ class VisionTransformer(nn.Module):
super().__init__() super().__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 2 if distilled else 1
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
if hybrid_backbone is not None: if hybrid_backbone is not None:
@ -313,9 +317,8 @@ class VisionTransformer(nn.Module):
num_patches = self.patch_embed.num_patches num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) if distilled else None self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
num_tokens = 2 if distilled else 1 self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + num_tokens, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate) self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
@ -382,10 +385,10 @@ class VisionTransformer(nn.Module):
x = self.pos_drop(x + self.pos_embed) x = self.pos_drop(x + self.pos_embed)
x = self.blocks(x) x = self.blocks(x)
x = self.norm(x) x = self.norm(x)
if self.dist_token is not None: if self.dist_token is None:
return x[:, 0], x[:, 1]
else:
return self.pre_logits(x[:, 0]) return self.pre_logits(x[:, 0])
else:
return x[:, 0], x[:, 1]
def forward(self, x): def forward(self, x):
x = self.forward_features(x) x = self.forward_features(x)
@ -401,15 +404,13 @@ class VisionTransformer(nn.Module):
return x return x
def resize_pos_embed(posemb, posemb_new, token='class'): def resize_pos_embed(posemb, posemb_new, num_tokens=1):
# Rescale the grid of position embeddings when loading from state_dict. Adapted from # Rescale the grid of position embeddings when loading from state_dict. Adapted from
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
_logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
ntok_new = posemb_new.shape[1] ntok_new = posemb_new.shape[1]
if token: if num_tokens:
assert token in ('class', 'distill') posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
token_idx = 2 if token == 'distill' else 1
posemb_tok, posemb_grid = posemb[:, :token_idx], posemb[0, token_idx:]
ntok_new -= 1 ntok_new -= 1
else: else:
posemb_tok, posemb_grid = posemb[:, :0], posemb[0] posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
@ -436,7 +437,7 @@ def checkpoint_filter_fn(state_dict, model):
v = v.reshape(O, -1, H, W) v = v.reshape(O, -1, H, W)
elif k == 'pos_embed' and v.shape != model.pos_embed.shape: elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
# To resize pos embedding when using model at different size from pretrained weights # To resize pos embedding when using model at different size from pretrained weights
v = resize_pos_embed(v, model.pos_embed, token='distill' if model.dist_token is not None else 'class') v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1))
out_dict[k] = v out_dict[k] = v
return out_dict return out_dict

Loading…
Cancel
Save