Benchmark models listed in txt file. Add more hybrid vit variants for testing

pull/450/head
Ross Wightman 4 years ago
parent 2db2d87ff7
commit 0706d05d52

@ -45,6 +45,8 @@ _logger = logging.getLogger('validate')
parser = argparse.ArgumentParser(description='PyTorch Benchmark') parser = argparse.ArgumentParser(description='PyTorch Benchmark')
# benchmark specific args # benchmark specific args
parser.add_argument('--model-list', metavar='NAME', default='',
help='txt file based list of model names to benchmark')
parser.add_argument('--bench', default='both', type=str, parser.add_argument('--bench', default='both', type=str,
help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'inference'") help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'inference'")
parser.add_argument('--detail', action='store_true', default=False, parser.add_argument('--detail', action='store_true', default=False,
@ -357,7 +359,7 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
except RuntimeError as e: except RuntimeError as e:
torch.cuda.empty_cache() torch.cuda.empty_cache()
batch_size = decay_batch_exp(batch_size) batch_size = decay_batch_exp(batch_size)
print(f'Reducing batch size to {batch_size}') print(f'Error: {str(e)} while running benchmark. Reducing batch size to {batch_size} for retry.')
return results return results
@ -413,7 +415,12 @@ def main():
model_cfgs = [] model_cfgs = []
model_names = [] model_names = []
if args.model == 'all': if args.model_list:
args.model = ''
with open(args.model_list) as f:
model_names = [line.rstrip() for line in f]
model_cfgs = [(n, None) for n in model_names]
elif args.model == 'all':
# validate all models in a list of names with pretrained checkpoints # validate all models in a list of names with pretrained checkpoints
args.pretrained = True args.pretrained = True
model_names = list_models(pretrained=True, exclude_filters=['*in21k']) model_names = list_models(pretrained=True, exclude_filters=['*in21k'])
@ -429,6 +436,8 @@ def main():
results = [] results = []
try: try:
for m, _ in model_cfgs: for m, _ in model_cfgs:
if not m:
continue
args.model = m args.model = m
r = benchmark(args) r = benchmark(args)
results.append(r) results.append(r)

@ -103,48 +103,90 @@ default_cfgs = {
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='patch_embed.backbone.stem.conv'), num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='patch_embed.backbone.stem.conv'),
# hybrid in-1k models (weights ported from official Google JAX impl where they exist) # hybrid in-1k models (weights ported from official Google JAX impl where they exist)
'vit_tiny_r_s16_p8_224': _cfg(
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'),
'vit_tiny_r_s16_p8_224_in21k': _cfg(
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'),
'vit_tiny_r_s16_p8_384': _cfg(
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'),
'vit_small_r_s16_p8_224': _cfg( 'vit_small_r_s16_p8_224': _cfg(
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'), first_conv='patch_embed.backbone.stem.conv'),
'vit_small_r_s16_p8_224_in21k': _cfg(
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'),
'vit_small_r_s16_p8_384': _cfg(
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'),
'vit_small_r20_s16_p2_224': _cfg( 'vit_small_r20_s16_p2_224': _cfg(
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'),
'vit_small_r20_s16_p2_224_in21k': _cfg(
inum_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'), first_conv='patch_embed.backbone.stem.conv'),
'vit_small_r20_s16_p2_384': _cfg( 'vit_small_r20_s16_p2_384': _cfg(
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'), first_conv='patch_embed.backbone.stem.conv'),
'vit_small_r20_s16_224': _cfg( 'vit_small_r20_s16_224': _cfg(
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'),
'vit_small_r20_s16_224_in21k': _cfg(
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'), first_conv='patch_embed.backbone.stem.conv'),
'vit_small_r20_s16_384': _cfg( 'vit_small_r20_s16_384': _cfg(
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'), first_conv='patch_embed.backbone.stem.conv'),
'vit_small_r26_s32_224': _cfg( 'vit_small_r26_s32_224': _cfg(
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'),
'vit_small_r26_s32_224_in21k': _cfg(
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'), first_conv='patch_embed.backbone.stem.conv'),
'vit_small_r26_s32_384': _cfg( 'vit_small_r26_s32_384': _cfg(
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'), first_conv='patch_embed.backbone.stem.conv'),
'vit_base_r20_s16_224': _cfg( 'vit_base_r20_s16_224': _cfg(
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'),
'vit_base_r20_s16_224_in21k': _cfg(
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'), first_conv='patch_embed.backbone.stem.conv'),
'vit_base_r20_s16_384': _cfg( 'vit_base_r20_s16_384': _cfg(
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'), first_conv='patch_embed.backbone.stem.conv'),
'vit_base_r26_s32_224': _cfg( 'vit_base_r26_s32_224': _cfg(
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'),
'vit_base_r26_s32_224_in21k': _cfg(
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'), first_conv='patch_embed.backbone.stem.conv'),
'vit_base_r26_s32_384': _cfg( 'vit_base_r26_s32_384': _cfg(
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'), first_conv='patch_embed.backbone.stem.conv'),
'vit_base_r50_s16_224': _cfg( 'vit_base_r50_s16_224': _cfg(
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'), first_conv='patch_embed.backbone.stem.conv'),
'vit_base_r50_s16_384': _cfg( 'vit_base_r50_s16_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'), first_conv='patch_embed.backbone.stem.conv'),
'vit_large_r50_s32_224': _cfg( 'vit_large_r50_s32_224': _cfg(
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'),
'vit_large_r50_s32_224_in21k': _cfg(
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
first_conv='patch_embed.backbone.stem.conv'), first_conv='patch_embed.backbone.stem.conv'),
'vit_large_r50_s32_384': _cfg( 'vit_large_r50_s32_384': _cfg(
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
@ -159,8 +201,19 @@ default_cfgs = {
# deit models (FB weights) # deit models (FB weights)
'vit_deit_tiny_patch16_224': _cfg( 'vit_deit_tiny_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'), url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
'vit_deit_tiny_patch16_224_in21k': _cfg(num_classes=21843),
'vit_deit_tiny_patch16_224_in21k_norep': _cfg(num_classes=21843),
'vit_deit_tiny_patch16_384': _cfg(input_size=(3, 384, 384)),
'vit_deit_small_patch16_224': _cfg( 'vit_deit_small_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'), url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
'vit_deit_small_patch16_224_in21k': _cfg(num_classes=21843),
'vit_deit_small_patch16_384': _cfg(input_size=(3, 384, 384)),
'vit_deit_small_patch32_224': _cfg(),
'vit_deit_small_patch32_224_in21k': _cfg(num_classes=21843),
'vit_deit_small_patch32_384': _cfg(input_size=(3, 384, 384)),
'vit_deit_base_patch16_224': _cfg( 'vit_deit_base_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',), url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
'vit_deit_base_patch16_384': _cfg( 'vit_deit_base_patch16_384': _cfg(
@ -728,7 +781,29 @@ def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
backbone = _resnetv2(layers=(), **kwargs) backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict( model_kwargs = dict(
patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs) patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_r20_s16_p2_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_tiny_r_s16_p8_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_tiny_r_s16_p8_224_in21k(pretrained=False, **kwargs):
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
"""
backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(
patch_size=8, embed_dim=192, depth=12, num_heads=3, representation_size=192, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_tiny_r_s16_p8_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
"""
backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(
patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_tiny_r_s16_p8_384', pretrained=pretrained, **model_kwargs)
return model return model
@ -740,6 +815,29 @@ def vit_small_r_s16_p8_224(pretrained=False, **kwargs):
model_kwargs = dict( model_kwargs = dict(
patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_r_s16_p8_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_small_r_s16_p8_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r_s16_p8_224_in21k(pretrained=False, **kwargs):
""" R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224.
"""
backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(
patch_size=8, embed_dim=384, depth=12, num_heads=6, representation_size=384, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_r_s16_p8_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r_s16_p8_384(pretrained=False, **kwargs):
""" R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224.
"""
backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(
patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_r_s16_p8_384', pretrained=pretrained, **model_kwargs)
return model return model
@ -754,6 +852,17 @@ def vit_small_r20_s16_p2_224(pretrained=False, **kwargs):
return model return model
@register_model
def vit_small_r20_s16_p2_224_in21k(pretrained=False, **kwargs):
""" R52+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224.
"""
backbone = _resnetv2((2, 4), **kwargs)
model_kwargs = dict(
patch_size=2, embed_dim=384, depth=12, num_heads=6, representation_size=384, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_r20_s16_p2_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_small_r20_s16_p2_384(pretrained=False, **kwargs): def vit_small_r20_s16_p2_384(pretrained=False, **kwargs):
""" R20+ViT-S/S16 w/ 2x2 Patch hybrid @ 384x384. """ R20+ViT-S/S16 w/ 2x2 Patch hybrid @ 384x384.
@ -775,6 +884,16 @@ def vit_small_r20_s16_224(pretrained=False, **kwargs):
return model return model
@register_model
def vit_small_r20_s16_224_in21k(pretrained=False, **kwargs):
""" R20+ViT-S/S16 hybrid.
"""
backbone = _resnetv2((2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, representation_size=384, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_r20_s16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_small_r20_s16_384(pretrained=False, **kwargs): def vit_small_r20_s16_384(pretrained=False, **kwargs):
""" R20+ViT-S/S16 hybrid @ 384x384. """ R20+ViT-S/S16 hybrid @ 384x384.
@ -795,6 +914,17 @@ def vit_small_r26_s32_224(pretrained=False, **kwargs):
return model return model
@register_model
def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs):
""" R26+ViT-S/S32 hybrid.
"""
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(
embed_dim=384, depth=12, num_heads=6, representation_size=384, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_r26_s32_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_small_r26_s32_384(pretrained=False, **kwargs): def vit_small_r26_s32_384(pretrained=False, **kwargs):
""" R26+ViT-S/S32 hybrid @ 384x384. """ R26+ViT-S/S32 hybrid @ 384x384.
@ -810,12 +940,22 @@ def vit_base_r20_s16_224(pretrained=False, **kwargs):
""" R20+ViT-B/S16 hybrid. """ R20+ViT-B/S16 hybrid.
""" """
backbone = _resnetv2((2, 2, 2), **kwargs) backbone = _resnetv2((2, 2, 2), **kwargs)
model_kwargs = dict( model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, act_layer=nn.SiLU, **kwargs)
model = _create_vision_transformer('vit_base_r20_s16_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_base_r20_s16_224', pretrained=pretrained, **model_kwargs)
return model return model
@register_model
def vit_base_r20_s16_224_in21k(pretrained=False, **kwargs):
""" R20+ViT-B/S16 hybrid.
"""
backbone = _resnetv2((2, 2, 2), **kwargs)
model_kwargs = dict(
embed_dim=768, depth=12, num_heads=12, representation_size=768, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_base_r20_s16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_base_r20_s16_384(pretrained=False, **kwargs): def vit_base_r20_s16_384(pretrained=False, **kwargs):
""" R20+ViT-B/S16 hybrid. """ R20+ViT-B/S16 hybrid.
@ -836,6 +976,27 @@ def vit_base_r26_s32_224(pretrained=False, **kwargs):
return model return model
@register_model
def vit_base_r26_s32_224_in21k(pretrained=False, **kwargs):
""" R26+ViT-B/S32 hybrid.
"""
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(
embed_dim=768, depth=12, num_heads=12, representation_size=768, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_base_r26_s32_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_r26_s32_384(pretrained=False, **kwargs):
""" R26+ViT-B/S32 hybrid.
"""
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_base_r26_s32_384', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_base_r50_s16_224(pretrained=False, **kwargs): def vit_base_r50_s16_224(pretrained=False, **kwargs):
""" R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929). """ R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
@ -867,6 +1028,17 @@ def vit_large_r50_s32_224(pretrained=False, **kwargs):
return model return model
@register_model
def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs):
""" R50+ViT-L/S32 hybrid.
"""
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
model_kwargs = dict(
embed_dim=768, depth=12, num_heads=12, representation_size=768, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_large_r50_s32_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_large_r50_s32_384(pretrained=False, **kwargs): def vit_large_r50_s32_384(pretrained=False, **kwargs):
""" R50+ViT-L/S32 hybrid. """ R50+ViT-L/S32 hybrid.
@ -927,6 +1099,31 @@ def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
return model return model
@register_model
def vit_deit_tiny_patch16_224_in21k_norep(pretrained=False, **kwargs):
""" DeiT-tiny model"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('vit_deit_tiny_patch16_224_in21k_norep', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_deit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
""" DeiT-tiny model"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, representation_size=192, **kwargs)
model = _create_vision_transformer('vit_deit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_deit_tiny_patch16_384(pretrained=False, **kwargs):
""" DeiT-tiny model"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('vit_deit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_deit_small_patch16_224(pretrained=False, **kwargs): def vit_deit_small_patch16_224(pretrained=False, **kwargs):
""" DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
@ -937,6 +1134,48 @@ def vit_deit_small_patch16_224(pretrained=False, **kwargs):
return model return model
@register_model
def vit_deit_small_patch16_224_in21k(pretrained=False, **kwargs):
""" DeiT-small """
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, representation_size=384, **kwargs)
model = _create_vision_transformer('vit_deit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_deit_small_patch16_384(pretrained=False, **kwargs):
""" DeiT-small """
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_deit_small_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_deit_small_patch32_224(pretrained=False, **kwargs):
""" DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_deit_small_patch32_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_deit_small_patch32_224_in21k(pretrained=False, **kwargs):
""" DeiT-small """
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, representation_size=384, **kwargs)
model = _create_vision_transformer('vit_deit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_deit_small_patch32_384(pretrained=False, **kwargs):
""" DeiT-small """
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_deit_small_patch32_384', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_deit_base_patch16_224(pretrained=False, **kwargs): def vit_deit_base_patch16_224(pretrained=False, **kwargs):
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).

Loading…
Cancel
Save