|
|
|
@ -240,13 +240,19 @@ class ResPostRelPosBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
class VisionTransformerRelPos(nn.Module):
|
|
|
|
|
""" Vision Transformer w/ Relative Position Bias
|
|
|
|
|
|
|
|
|
|
Differing from classic vit, this impl
|
|
|
|
|
* uses relative position index (swin v1 / beit) or relative log coord + mlp (swin v2) pos embed
|
|
|
|
|
* defaults to no class token (can be enabled)
|
|
|
|
|
* defaults to global avg pool for head (can be changed)
|
|
|
|
|
* layer-scale (residual branch gain) enabled
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
|
|
|
|
|
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
|
|
|
|
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip', class_token=False,
|
|
|
|
|
rel_pos_type='mlp', shared_rel_pos=False, fc_norm=False,
|
|
|
|
|
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-5,
|
|
|
|
|
class_token=False, rel_pos_type='mlp', shared_rel_pos=False, fc_norm=False,
|
|
|
|
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip',
|
|
|
|
|
embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock):
|
|
|
|
|
"""
|
|
|
|
|
Args:
|
|
|
|
@ -254,21 +260,21 @@ class VisionTransformerRelPos(nn.Module):
|
|
|
|
|
patch_size (int, tuple): patch size
|
|
|
|
|
in_chans (int): number of input channels
|
|
|
|
|
num_classes (int): number of classes for classification head
|
|
|
|
|
global_pool (str): type of global pooling for final sequence (default: 'token')
|
|
|
|
|
global_pool (str): type of global pooling for final sequence (default: 'avg')
|
|
|
|
|
embed_dim (int): embedding dimension
|
|
|
|
|
depth (int): depth of transformer
|
|
|
|
|
num_heads (int): number of attention heads
|
|
|
|
|
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
|
|
|
|
qkv_bias (bool): enable bias for qkv if True
|
|
|
|
|
init_values: (float): layer-scale init values
|
|
|
|
|
class_token (bool): use class token (default: False)
|
|
|
|
|
rel_pos_type (str): type of relative position
|
|
|
|
|
shared_rel_pos (bool): share relative pos across all blocks
|
|
|
|
|
fc_norm (bool): use pre classifier norm instead of pre-pool
|
|
|
|
|
drop_rate (float): dropout rate
|
|
|
|
|
attn_drop_rate (float): attention dropout rate
|
|
|
|
|
drop_path_rate (float): stochastic depth rate
|
|
|
|
|
weight_init (str): weight init scheme
|
|
|
|
|
class_token (bool): use class token (default: False)
|
|
|
|
|
rel_pos_type (str): type of relative position
|
|
|
|
|
shared_rel_pos (bool): share relative pos across all blocks
|
|
|
|
|
fc_norm (bool): use pre classifier norm
|
|
|
|
|
embed_layer (nn.Module): patch embedding layer
|
|
|
|
|
norm_layer: (nn.Module): normalization layer
|
|
|
|
|
act_layer: (nn.Module): MLP activation layer
|
|
|
|
@ -384,11 +390,10 @@ def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs):
|
|
|
|
|
""" ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token
|
|
|
|
|
""" ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token
|
|
|
|
|
"""
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5,
|
|
|
|
|
block_fn=ResPostRelPosBlock, **kwargs)
|
|
|
|
|
patch_size=32, embed_dim=896, depth=12, num_heads=14, block_fn=ResPostRelPosBlock, **kwargs)
|
|
|
|
|
model = _create_vision_transformer_relpos(
|
|
|
|
|
'vit_relpos_base_patch32_plus_rpn_256', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
@ -398,7 +403,7 @@ def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs):
|
|
|
|
|
def vit_relpos_base_patch16_plus_240(pretrained=False, **kwargs):
|
|
|
|
|
""" ViT-Base (ViT-B/16+) w/ relative log-coord position, no class token
|
|
|
|
|
"""
|
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
|
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, **kwargs)
|
|
|
|
|
model = _create_vision_transformer_relpos('vit_relpos_base_patch16_plus_240', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
@ -408,8 +413,7 @@ def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
|
|
|
|
|
""" ViT-Base (ViT-B/16) w/ relative log-coord position, no class token
|
|
|
|
|
"""
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5,
|
|
|
|
|
fc_norm=True, **kwargs)
|
|
|
|
|
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, **kwargs)
|
|
|
|
|
model = _create_vision_transformer_relpos('vit_relpos_base_patch16_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
@ -419,7 +423,6 @@ def vit_relpos_base_patch16_rpn_224(pretrained=False, **kwargs):
|
|
|
|
|
""" ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token
|
|
|
|
|
"""
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5,
|
|
|
|
|
block_fn=ResPostRelPosBlock, **kwargs)
|
|
|
|
|
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs)
|
|
|
|
|
model = _create_vision_transformer_relpos('vit_relpos_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|