From 7953e5d11af1dbef49fd60d9aeaba8c1d740096c Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 31 Mar 2021 23:11:28 -0700
Subject: [PATCH] Fix pos_embed scaling for ViT and num_classes != 1000 for
 pretrained distilled deit and pit models. Fix #426 and fix #433

---
 timm/models/helpers.py            | 24 ++++++++++++++----------
 timm/models/pit.py                | 13 ++++++++-----
 timm/models/vision_transformer.py | 33 +++++++++++++++++++----------------
 3 files changed, 39 insertions(+), 31 deletions(-)

diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index 2f6e098b..e9ac7f00 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -198,20 +198,24 @@ def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filte
             _logger.warning(
                 f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
 
-    classifier_name = default_cfg.get('classifier', None)
+    classifiers = default_cfg.get('classifier', None)
     label_offset = default_cfg.get('label_offset', 0)
-    if classifier_name is not None:
+    if classifiers is not None:
+        if isinstance(classifiers, str):
+            classifiers = (classifiers,)
         if num_classes != default_cfg['num_classes']:
-            # completely discard fully connected if model num_classes doesn't match pretrained weights
-            del state_dict[classifier_name + '.weight']
-            del state_dict[classifier_name + '.bias']
+            for classifier_name in classifiers:
+                # completely discard fully connected if model num_classes doesn't match pretrained weights
+                del state_dict[classifier_name + '.weight']
+                del state_dict[classifier_name + '.bias']
             strict = False
         elif label_offset > 0:
-            # special case for pretrained weights with an extra background class in pretrained weights
-            classifier_weight = state_dict[classifier_name + '.weight']
-            state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
-            classifier_bias = state_dict[classifier_name + '.bias']
-            state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
+            for classifier_name in classifiers:
+                # special case for pretrained weights with an extra background class in pretrained weights
+                classifier_weight = state_dict[classifier_name + '.weight']
+                state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
+                classifier_bias = state_dict[classifier_name + '.bias']
+                state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
 
     model.load_state_dict(state_dict, strict=strict)
 
diff --git a/timm/models/pit.py b/timm/models/pit.py
index 2137bea8..1cee4d04 100644
--- a/timm/models/pit.py
+++ b/timm/models/pit.py
@@ -49,14 +49,17 @@ default_cfgs = {
     'pit_b_224': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_820.pth'),
     'pit_ti_distilled_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth'),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth',
+        classifier=('head', 'head_dist')),
     'pit_xs_distilled_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth'),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth',
+        classifier=('head', 'head_dist')),
     'pit_s_distilled_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth'),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth',
+        classifier=('head', 'head_dist')),
     'pit_b_distilled_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth'),
-
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth',
+        classifier=('head', 'head_dist')),
 }
 
 
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index c7c9027d..c05871b8 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -123,14 +123,17 @@ default_cfgs = {
         url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
         input_size=(3, 384, 384), crop_pct=1.0),
     'vit_deit_tiny_distilled_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'),
+        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
+        classifier=('head', 'head_dist')),
     'vit_deit_small_distilled_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth'),
+        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
+        classifier=('head', 'head_dist')),
     'vit_deit_base_distilled_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', ),
+        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
+        classifier=('head', 'head_dist')),
     'vit_deit_base_distilled_patch16_384': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
-        input_size=(3, 384, 384), crop_pct=1.0),
+        input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
 }
 
 
@@ -302,6 +305,7 @@ class VisionTransformer(nn.Module):
         super().__init__()
         self.num_classes = num_classes
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.num_tokens = 2 if distilled else 1
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
 
         if hybrid_backbone is not None:
@@ -313,9 +317,8 @@ class VisionTransformer(nn.Module):
         num_patches = self.patch_embed.num_patches
 
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) if distilled else None
-        num_tokens = 2 if distilled else 1
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + num_tokens, embed_dim))
+        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
         self.pos_drop = nn.Dropout(p=drop_rate)
 
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
@@ -382,10 +385,10 @@ class VisionTransformer(nn.Module):
         x = self.pos_drop(x + self.pos_embed)
         x = self.blocks(x)
         x = self.norm(x)
-        if self.dist_token is not None:
-            return x[:, 0], x[:, 1]
-        else:
+        if self.dist_token is None:
             return self.pre_logits(x[:, 0])
+        else:
+            return x[:, 0], x[:, 1]
 
     def forward(self, x):
         x = self.forward_features(x)
@@ -401,15 +404,13 @@ class VisionTransformer(nn.Module):
         return x
 
 
-def resize_pos_embed(posemb, posemb_new, token='class'):
+def resize_pos_embed(posemb, posemb_new, num_tokens=1):
     # Rescale the grid of position embeddings when loading from state_dict. Adapted from
     # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
     _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
     ntok_new = posemb_new.shape[1]
-    if token:
-        assert token in ('class', 'distill')
-        token_idx = 2 if token == 'distill' else 1
-        posemb_tok, posemb_grid = posemb[:, :token_idx], posemb[0, token_idx:]
+    if num_tokens:
+        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
         ntok_new -= 1
     else:
         posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
@@ -436,7 +437,7 @@ def checkpoint_filter_fn(state_dict, model):
             v = v.reshape(O, -1, H, W)
         elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
             # To resize pos embedding when using model at different size from pretrained weights
-            v = resize_pos_embed(v, model.pos_embed, token='distill' if model.dist_token is not None else 'class')
+            v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1))
         out_dict[k] = v
     return out_dict
 
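
Note: the following is not part of the patch. It is a minimal standalone sketch of
what the tuple-valued 'classifier' entry changes in the load_pretrained() hunk above.
The tensor shapes and the 10-class target are hypothetical, chosen only to make the
behaviour observable.

    import torch

    # Hypothetical pretrained weights for a distilled model: two classifier
    # heads of 1000 classes each over 192-dim features.
    state_dict = {
        'head.weight': torch.zeros(1000, 192), 'head.bias': torch.zeros(1000),
        'head_dist.weight': torch.zeros(1000, 192), 'head_dist.bias': torch.zeros(1000),
    }
    classifiers = ('head', 'head_dist')  # as now set in the distilled default_cfgs
    num_classes, default_num_classes = 10, 1000  # target head size differs from pretrained

    if isinstance(classifiers, str):
        classifiers = (classifiers,)  # plain-string cfgs keep working unchanged
    if num_classes != default_num_classes:
        # Discard every head so load_state_dict(strict=False) keeps the freshly
        # initialized 10-class heads; before the patch only the single 'head'
        # entry was deleted and the stale 1000-class 'head_dist' weights clashed.
        for classifier_name in classifiers:
            del state_dict[classifier_name + '.weight']
            del state_dict[classifier_name + '.bias']

    print(sorted(state_dict))  # [] -> no stale classifier weights remain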
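
Similarly, a condensed sketch of the pos_embed rescaling path with num_tokens=2,
mirroring resize_pos_embed() above. All shapes are hypothetical (192-dim embeddings,
a distilled checkpoint trained at 224px with a 14x14 patch grid, loaded at 384px with
a 24x24 grid), and the grid-size arithmetic is inlined rather than using the
function's ntok_new variable.

    import math
    import torch
    import torch.nn.functional as F

    num_tokens, embed_dim = 2, 192  # class + distillation token, as in the models above
    posemb = torch.randn(1, num_tokens + 14 * 14, embed_dim)      # pretrained, 224px
    posemb_new = torch.zeros(1, num_tokens + 24 * 24, embed_dim)  # target model, 384px

    # Split all prefix tokens off before bilinear-resizing the spatial grid,
    # as the patched function does when num_tokens is 2.
    posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
    gs_old = int(math.sqrt(len(posemb_grid)))                  # 14
    gs_new = int(math.sqrt(posemb_new.shape[1] - num_tokens))  # 24
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear')
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
    resized = torch.cat([posemb_tok, posemb_grid], dim=1)
    assert resized.shape == posemb_new.shape  # (1, 578, 192)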