Adjust arg order for recent vit model args, add a few comments

pull/1245/head
Ross Wightman 2 years ago
parent 41dc49a337
commit f5ca4141f7

@@ -325,8 +325,8 @@ class VisionTransformer(nn.Module):
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', class_token=True,
fc_norm=None, embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',
embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
"""
Args:
img_size (int, tuple): input image size
@@ -340,12 +340,12 @@ class VisionTransformer(nn.Module):
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
init_values (float): layer-scale init values
class_token (bool): use class token
fc_norm (Optional[bool]): pre-fc norm after pool; if None, enabled when global_pool == 'avg' (default: None)
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
weight_init (str): weight init scheme
class_token (bool): use class token
fc_norm (Optional[bool]): pre-fc norm after pool; if None, enabled when global_pool == 'avg' (default: None)
embed_layer (nn.Module): patch embedding layer
norm_layer (nn.Module): normalization layer
act_layer (nn.Module): MLP activation layer
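Since this hunk only reorders keyword defaults in VisionTransformer.__init__, callers that pass arguments by keyword are unaffected; only positional callers past init_values would need updating. A minimal usage sketch, assuming this diff is against timm's vision_transformer.py and timm plus torch are installed (the specific values below are illustrative):

import torch
from timm.models.vision_transformer import VisionTransformer

model = VisionTransformer(
    img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,
    class_token=True,   # now listed before the dropout args
    fc_norm=None,       # None -> pre-fc norm only when global_pool == 'avg'
    drop_rate=0., drop_path_rate=0.1,
)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # (1, 1000) with the default num_classes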

@@ -240,13 +240,19 @@ class ResPostRelPosBlock(nn.Module):
class VisionTransformerRelPos(nn.Module):
""" Vision Transformer w/ Relative Position Bias
Differing from classic vit, this impl
* uses relative position index (swin v1 / beit) or relative log coord + mlp (swin v2) pos embed
* defaults to no class token (can be enabled)
* defaults to global avg pool for head (can be changed)
* layer-scale (residual branch gain) enabled
"""
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip', class_token=False,
rel_pos_type='mlp', shared_rel_pos=False, fc_norm=False,
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-5,
class_token=False, rel_pos_type='mlp', shared_rel_pos=False, fc_norm=False,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip',
embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock):
"""
Args:
@@ -254,21 +260,21 @@ class VisionTransformerRelPos(nn.Module):
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
global_pool (str): type of global pooling for final sequence (default: 'token')
global_pool (str): type of global pooling for final sequence (default: 'avg')
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
init_values (float): layer-scale init values
class_token (bool): use class token (default: False)
rel_pos_type (str): type of relative position
shared_rel_pos (bool): share relative pos across all blocks
fc_norm (bool): use pre classifier norm instead of pre-pool
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
weight_init (str): weight init scheme
class_token (bool): use class token (default: False)
rel_pos_type (str): type of relative position
shared_rel_pos (bool): share relative pos across all blocks
fc_norm (bool): use pre classifier norm
embed_layer (nn.Module): patch embedding layer
norm_layer (nn.Module): normalization layer
act_layer (nn.Module): MLP activation layer
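The relpos signature change also flips init_values to 1e-5, so layer-scale is on by default and the per-model entrypoints below can drop it. A minimal construction sketch, assuming this is timm's vision_transformer_relpos.py (the values shown are the documented defaults, written out for clarity):

import torch
from timm.models.vision_transformer_relpos import VisionTransformerRelPos

model = VisionTransformerRelPos(
    img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,
    global_pool='avg',    # avg-pool head by default
    class_token=False,    # no class token by default
    rel_pos_type='mlp',   # relative log-coord + MLP pos embed (swin v2 style)
    init_values=1e-5,     # layer-scale init, now the class-level default
)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # (1, 1000) with the default num_classes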
@@ -384,11 +390,10 @@ def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs):
@register_model
def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token
""" ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token
"""
model_kwargs = dict(
patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5,
block_fn=ResPostRelPosBlock, **kwargs)
patch_size=32, embed_dim=896, depth=12, num_heads=14, block_fn=ResPostRelPosBlock, **kwargs)
model = _create_vision_transformer_relpos(
'vit_relpos_base_patch32_plus_rpn_256', pretrained=pretrained, **model_kwargs)
return model
@@ -398,7 +403,7 @@ def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs):
def vit_relpos_base_patch16_plus_240(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16+) w/ relative log-coord position, no class token
"""
model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, **kwargs)
model = _create_vision_transformer_relpos('vit_relpos_base_patch16_plus_240', pretrained=pretrained, **model_kwargs)
return model
@@ -408,8 +413,7 @@ def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ relative log-coord position, no class token
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5,
fc_norm=True, **kwargs)
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, **kwargs)
model = _create_vision_transformer_relpos('vit_relpos_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@@ -419,7 +423,6 @@ def vit_relpos_base_patch16_rpn_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5,
block_fn=ResPostRelPosBlock, **kwargs)
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs)
model = _create_vision_transformer_relpos('vit_relpos_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
return model
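Because init_values=1e-5 is now the VisionTransformerRelPos default, the entrypoints above no longer pass it in model_kwargs; overrides still flow through **kwargs. A short create_model sketch, assuming timm is installed with these models registered as shown in the diff:

import torch
import timm

model = timm.create_model('vit_relpos_base_patch16_224', pretrained=False)
print(model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])

# kwargs still pass through to the constructor, e.g. to disable layer-scale:
model_no_ls = timm.create_model('vit_relpos_base_patch16_224', pretrained=False, init_values=None)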
