|
|
@ -109,6 +109,8 @@ default_cfgs = {
|
|
|
|
'vit_giant_patch14_224': _cfg(url=''),
|
|
|
|
'vit_giant_patch14_224': _cfg(url=''),
|
|
|
|
'vit_gigantic_patch14_224': _cfg(url=''),
|
|
|
|
'vit_gigantic_patch14_224': _cfg(url=''),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'vit_base2_patch32_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
|
|
|
|
|
|
|
|
|
|
|
|
# patch models, imagenet21k (weights from official Google JAX impl)
|
|
|
|
# patch models, imagenet21k (weights from official Google JAX impl)
|
|
|
|
'vit_tiny_patch16_224_in21k': _cfg(
|
|
|
|
'vit_tiny_patch16_224_in21k': _cfg(
|
|
|
|
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
|
|
|
|
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
|
|
|
@ -202,6 +204,7 @@ default_cfgs = {
|
|
|
|
class Attention(nn.Module):
|
|
|
|
class Attention(nn.Module):
|
|
|
|
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
|
|
|
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
|
|
|
self.num_heads = num_heads
|
|
|
|
self.num_heads = num_heads
|
|
|
|
head_dim = dim // num_heads
|
|
|
|
head_dim = dim // num_heads
|
|
|
|
self.scale = head_dim ** -0.5
|
|
|
|
self.scale = head_dim ** -0.5
|
|
|
@ -634,6 +637,16 @@ def vit_base_patch32_224(pretrained=False, **kwargs):
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def vit_base2_patch32_256(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" ViT-Base (ViT-B/32)
|
|
|
|
|
|
|
|
# FIXME experiment
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, **kwargs)
|
|
|
|
|
|
|
|
model = _create_vision_transformer('vit_base2_patch32_256', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def vit_base_patch32_384(pretrained=False, **kwargs):
|
|
|
|
def vit_base_patch32_384(pretrained=False, **kwargs):
|
|
|
|
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
|
|
|
|
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
|
|
|
|