From f5ca4141f710d8b0b363f849abbf0182aebc5021 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 2 May 2022 22:41:38 -0700
Subject: [PATCH] Adjust arg order for recent vit model args, add a few comments

---
 timm/models/vision_transformer.py        |  8 +++---
 timm/models/vision_transformer_relpos.py | 35 +++++++++++++-----------
 2 files changed, 23 insertions(+), 20 deletions(-)

diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 33cc5db2..59fd7849 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -325,8 +325,8 @@ class VisionTransformer(nn.Module):
     def __init__(
             self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
             embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
-            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', class_token=True,
-            fc_norm=None, embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
+            class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',
+            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
         """
         Args:
             img_size (int, tuple): input image size
@@ -340,12 +340,12 @@ class VisionTransformer(nn.Module):
             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
             qkv_bias (bool): enable bias for qkv if True
             init_values: (float): layer-scale init values
+            class_token (bool): use class token
+            fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
             drop_rate (float): dropout rate
             attn_drop_rate (float): attention dropout rate
             drop_path_rate (float): stochastic depth rate
             weight_init (str): weight init scheme
-            class_token (bool): use class token
-            fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
             embed_layer (nn.Module): patch embedding layer
             norm_layer: (nn.Module): normalization layer
             act_layer: (nn.Module): MLP activation layer
diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py
index 056dba97..9ecfd473 100644
--- a/timm/models/vision_transformer_relpos.py
+++ b/timm/models/vision_transformer_relpos.py
@@ -240,13 +240,19 @@ class ResPostRelPosBlock(nn.Module):
 
 class VisionTransformerRelPos(nn.Module):
     """ Vision Transformer w/ Relative Position Bias
+
+    Differing from classic vit, this impl
+    * uses relative position index (swin v1 / beit) or relative log coord + mlp (swin v2) pos embed
+    * defaults to no class token (can be enabled)
+    * defaults to global avg pool for head (can be changed)
+    * layer-scale (residual branch gain) enabled
     """
 
     def __init__(
             self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
-            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
-            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip', class_token=False,
-            rel_pos_type='mlp', shared_rel_pos=False, fc_norm=False,
+            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-5,
+            class_token=False, rel_pos_type='mlp', shared_rel_pos=False, fc_norm=False,
+            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip',
             embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock):
         """
         Args:
@@ -254,21 +260,21 @@ class VisionTransformerRelPos(nn.Module):
             patch_size (int, tuple): patch size
             in_chans (int): number of input channels
             num_classes (int): number of classes for classification head
-            global_pool (str): type of global pooling for final sequence (default: 'token')
+            global_pool (str): type of global pooling for final sequence (default: 'avg')
             embed_dim (int): embedding dimension
             depth (int): depth of transformer
             num_heads (int): number of attention heads
             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
             qkv_bias (bool): enable bias for qkv if True
             init_values: (float): layer-scale init values
+            class_token (bool): use class token (default: False)
+            rel_pos_type (str): type of relative position
+            shared_rel_pos (bool): share relative pos across all blocks
+            fc_norm (bool): use pre classifier norm instead of pre-pool
             drop_rate (float): dropout rate
             attn_drop_rate (float): attention dropout rate
             drop_path_rate (float): stochastic depth rate
             weight_init (str): weight init scheme
-            class_token (bool): use class token (default: False)
-            rel_pos_type (str): type of relative position
-            shared_rel_pos (bool): share relative pos across all blocks
-            fc_norm (bool): use pre classifier norm
             embed_layer (nn.Module): patch embedding layer
             norm_layer: (nn.Module): normalization layer
             act_layer: (nn.Module): MLP activation layer
@@ -384,11 +390,10 @@ def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs):
 
 @register_model
 def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs):
-    """ ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token
+    """ ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token
     """
     model_kwargs = dict(
-        patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5,
-        block_fn=ResPostRelPosBlock, **kwargs)
+        patch_size=32, embed_dim=896, depth=12, num_heads=14, block_fn=ResPostRelPosBlock, **kwargs)
     model = _create_vision_transformer_relpos(
         'vit_relpos_base_patch32_plus_rpn_256', pretrained=pretrained, **model_kwargs)
     return model
@@ -398,7 +403,7 @@ def vit_relpos_base_patch16_plus_240(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16+) w/ relative log-coord position, no class token
     """
-    model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, **kwargs)
     model = _create_vision_transformer_relpos('vit_relpos_base_patch16_plus_240', pretrained=pretrained, **model_kwargs)
     return model
 
 
@@ -408,8 +413,7 @@ def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) w/ relative log-coord position, no class token
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5,
-        fc_norm=True, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, **kwargs)
     model = _create_vision_transformer_relpos('vit_relpos_base_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
 
@@ -419,7 +423,6 @@ def vit_relpos_base_patch16_rpn_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5,
-        block_fn=ResPostRelPosBlock, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs)
     model = _create_vision_transformer_relpos('vit_relpos_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
     return model
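
Note that this change reorders positional __init__ arguments: class_token, fc_norm, and the relative-position args now come before the drop_rate/attn_drop_rate/drop_path_rate/weight_init group, so any caller passing those settings positionally would silently pick up wrong values after updating. Below is a minimal sketch of the safe calling pattern, assuming a timm build that includes these patched files; the num_classes override is only illustrative:

    import torch
    import timm

    # Keyword arguments are unaffected by the reorder; create_model forwards
    # them to the model __init__ via **kwargs, never positionally.
    model = timm.create_model('vit_relpos_base_patch16_224', pretrained=False, num_classes=10)
    model.eval()

    # VisionTransformerRelPos now defaults init_values to 1e-5, which is why
    # the per-model kwargs in this patch no longer set it explicitly.
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 10])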