From c16e9650371d167dcb38669aa1280caba2c69dcd Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 24 Jan 2021 23:18:35 -0800 Subject: [PATCH] Add some ViT comments and fix a few minor issues. --- timm/models/vision_transformer.py | 145 ++++++++++++++++++++++-------- 1 file changed, 110 insertions(+), 35 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index a832cce3..90122090 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -100,10 +100,10 @@ default_cfgs = { # hybrid models (weights ported from official Google JAX impl) 'vit_base_resnet50_224_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', - num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9), + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='patch_embed.backbone.stem.conv'), 'vit_base_resnet50_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', - input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'), # hybrid models (my experiments) 'vit_small_resnet26d_224': _cfg(), @@ -256,11 +256,33 @@ class HybridEmbed(nn.Module): class VisionTransformer(nn.Module): - """ Vision Transformer with support for patch or hybrid CNN input stage + """ Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - + https://arxiv.org/abs/2010.11929 """ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module + norm_layer: (nn.Module): normalization layer + """ super().__init__() self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models @@ -346,8 +368,7 @@ class VisionTransformer(nn.Module): def resize_pos_embed(posemb, posemb_new): - # Rescale the grid of position embeddings when loading from state_dict - # Adapted from + # Rescale the grid of position embeddings when loading from state_dict. Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) ntok_new = posemb_new.shape[1] @@ -363,22 +384,21 @@ def resize_pos_embed(posemb, posemb_new): posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear') posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1) posemb = torch.cat([posemb_tok, posemb_grid], dim=1) - state_dict['pos_embed'] = posemb - return state_dict + return posemb def checkpoint_filter_fn(state_dict, model): """ convert patch embedding weight from manual patchify + linear proj to conv""" out_dict = {} if 'model' in state_dict: - # for deit models + # For deit models state_dict = state_dict['model'] for k, v in state_dict.items(): if 'patch_embed.proj.weight' in k and len(v.shape) < 4: - # for old models that I trained prior to conv based patchification + # For old models that I trained prior to conv based patchification v = v.reshape(model.patch_embed.proj.weight.shape) elif k == 'pos_embed' and v.shape != model.pos_embed.shape: - # to resize pos embedding when using model at different size from pretrained weights + # To resize pos embedding when using model at different size from pretrained weights v = resize_pos_embed(v, model.pos_embed) out_dict[k] = v return out_dict @@ -393,8 +413,9 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs): img_size = kwargs.pop('img_size', default_img_size) repr_size = kwargs.pop('representation_size', None) if repr_size is not None and num_classes != default_num_classes: - # remove representation layer if fine-tuning - _logger.info("Removing representation layer for fine-tuning.") + # Remove representation layer if fine-tuning. This may not always be the desired action, + # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? + _logger.warning("Removing representation layer for fine-tuning.") repr_size = None model = VisionTransformer(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs) @@ -409,6 +430,7 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs): @register_model def vit_small_patch16_224(pretrained=False, **kwargs): + """ My custom 'small' ViT model. Depth=8, heads=8= mlp_ratio=3.""" model_kwargs = dict( patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs) @@ -421,6 +443,9 @@ def vit_small_patch16_224(pretrained=False, **kwargs): @register_model def vit_base_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs) model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) return model @@ -428,6 +453,8 @@ def vit_base_patch16_224(pretrained=False, **kwargs): @register_model def vit_base_patch32_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. + """ model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs) model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs) return model @@ -435,6 +462,9 @@ def vit_base_patch32_224(pretrained=False, **kwargs): @register_model def vit_base_patch16_384(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs) model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs) return model @@ -442,6 +472,9 @@ def vit_base_patch16_384(pretrained=False, **kwargs): @register_model def vit_base_patch32_384(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ model_kwargs = dict( patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) @@ -451,6 +484,9 @@ def vit_base_patch32_384(pretrained=False, **kwargs): @register_model def vit_large_patch16_224(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs) model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) return model @@ -458,6 +494,8 @@ def vit_large_patch16_224(pretrained=False, **kwargs): @register_model def vit_large_patch32_224(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. + """ model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs) model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) return model @@ -465,21 +503,29 @@ def vit_large_patch32_224(pretrained=False, **kwargs): @register_model def vit_large_patch16_384(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs) model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_base_patch16_224_in21k(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, representation_size=768, **kwargs) - model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) +def vit_large_patch32_384(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs) + model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_base_patch16_384_in21k(pretrained=False, **kwargs): +def vit_base_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + """ model_kwargs = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, representation_size=768, **kwargs) model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) @@ -488,6 +534,9 @@ def vit_base_patch16_384_in21k(pretrained=False, **kwargs): @register_model def vit_base_patch32_224_in21k(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + """ model_kwargs = dict( patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, representation_size=768, **kwargs) model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) @@ -496,22 +545,20 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs): @register_model def vit_large_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + """ model_kwargs = dict( patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, representation_size=1024, **kwargs) model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model -# @register_model -# def vit_large_patch16_384_in21k(pretrained=False, **kwargs): -# model_kwargs = dict( -# patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, representation_size=1024, **kwargs) -# model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) -# return model - - @register_model def vit_large_patch32_224_in21k(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + """ model_kwargs = dict( patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, representation_size=1024, **kwargs) model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs) @@ -520,6 +567,10 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs): @register_model def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): + """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: converted weights not currently available, too large for github release hosting. + """ model_kwargs = dict( patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, representation_size=1280, **kwargs) model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs) @@ -528,9 +579,13 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): @register_model def vit_base_resnet50_224_in21k(pretrained=False, **kwargs): + """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + """ # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head backbone = ResNetV2( - layers=(3, 4, 9), preact=False, stem_type='same', conv_layer=StdConv2dSame, num_classes=0, global_pool='') + layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3), + preact=False, stem_type='same', conv_layer=StdConv2dSame) model_kwargs = dict( embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, representation_size=768, **kwargs) @@ -540,9 +595,13 @@ def vit_base_resnet50_224_in21k(pretrained=False, **kwargs): @register_model def vit_base_resnet50_384(pretrained=False, **kwargs): + """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head backbone = ResNetV2( - layers=(3, 4, 9), preact=False, stem_type='same', conv_layer=StdConv2dSame, num_classes=0, global_pool='') + layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3), + preact=False, stem_type='same', conv_layer=StdConv2dSame) model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs) model = _create_vision_transformer('vit_base_resnet50_384', pretrained=pretrained, **model_kwargs) return model @@ -550,8 +609,9 @@ def vit_base_resnet50_384(pretrained=False, **kwargs): @register_model def vit_small_resnet26d_224(pretrained=False, **kwargs): - pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing - backbone = resnet26d(pretrained=pretrained_backbone, features_only=True, out_indices=[4]) + """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights. + """ + backbone = resnet26d(pretrained=pretrained, features_only=True, out_indices=[4]) model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs) model = _create_vision_transformer('vit_small_resnet26d_224', pretrained=pretrained, **model_kwargs) return model @@ -559,8 +619,9 @@ def vit_small_resnet26d_224(pretrained=False, **kwargs): @register_model def vit_small_resnet50d_s3_224(pretrained=False, **kwargs): - pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing - backbone = resnet50d(pretrained=pretrained_backbone, features_only=True, out_indices=[3]) + """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights. + """ + backbone = resnet50d(pretrained=pretrained, features_only=True, out_indices=[3]) model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs) model = _create_vision_transformer('vit_small_resnet50d_s3_224', pretrained=pretrained, **model_kwargs) return model @@ -568,8 +629,9 @@ def vit_small_resnet50d_s3_224(pretrained=False, **kwargs): @register_model def vit_base_resnet26d_224(pretrained=False, **kwargs): - pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing - backbone = resnet26d(pretrained=pretrained_backbone, features_only=True, out_indices=[4]) + """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights. + """ + backbone = resnet26d(pretrained=pretrained, features_only=True, out_indices=[4]) model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs) model = _create_vision_transformer('vit_base_resnet26d_224', pretrained=pretrained, **model_kwargs) return model @@ -577,8 +639,9 @@ def vit_base_resnet26d_224(pretrained=False, **kwargs): @register_model def vit_base_resnet50d_224(pretrained=False, **kwargs): - pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing - backbone = resnet50d(pretrained=pretrained_backbone, features_only=True, out_indices=[4]) + """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights. + """ + backbone = resnet50d(pretrained=pretrained, features_only=True, out_indices=[4]) model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs) model = _create_vision_transformer('vit_base_resnet50d_224', pretrained=pretrained, **model_kwargs) return model @@ -586,6 +649,9 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs): @register_model def vit_deit_tiny_patch16_224(pretrained=False, **kwargs): + """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, **kwargs) model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) return model @@ -593,6 +659,9 @@ def vit_deit_tiny_patch16_224(pretrained=False, **kwargs): @register_model def vit_deit_small_patch16_224(pretrained=False, **kwargs): + """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs) model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs) return model @@ -600,6 +669,9 @@ def vit_deit_small_patch16_224(pretrained=False, **kwargs): @register_model def vit_deit_base_patch16_224(pretrained=False, **kwargs): + """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs) model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs) return model @@ -607,6 +679,9 @@ def vit_deit_base_patch16_224(pretrained=False, **kwargs): @register_model def vit_deit_base_patch16_384(pretrained=False, **kwargs): + """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs) model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs) return model