Add vit_base2_patch32_256 for a model between base_patch16 and patch32 with a slightly larger img size and width

pull/1112/head v0.1-mvit-weights
Ross Wightman 2 years ago
parent cf4334391e
commit 07379c6d5d

@ -109,6 +109,8 @@ default_cfgs = {
'vit_giant_patch14_224': _cfg(url=''),
'vit_gigantic_patch14_224': _cfg(url=''),
'vit_base2_patch32_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
# patch models, imagenet21k (weights from official Google JAX impl)
'vit_tiny_patch16_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
@ -202,6 +204,7 @@ default_cfgs = {
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
@ -634,6 +637,16 @@ def vit_base_patch32_224(pretrained=False, **kwargs):
return model
@register_model
def vit_base2_patch32_256(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32)
# FIXME experiment
"""
model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, **kwargs)
model = _create_vision_transformer('vit_base2_patch32_256', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch32_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).

Loading…
Cancel
Save