Refactor vision_transformer entrpy fns, add pos embedding resize support for fine tuning, add some deit models for testing

pull/323/head
Ross Wightman 3 years ago
parent 9d5d4b8df6
commit 55f7dfa9ea

@ -14,7 +14,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
torch._C._jit_set_profiling_mode(False)
# transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = ['vit_*', 'deit_*']
NON_STD_FILTERS = ['vit_*']
# exclude models that cause specific test failures
if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
@ -29,7 +29,7 @@ MAX_FWD_FEAT_SIZE = 448
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-2]))
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-1]))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward(model_name, batch_size):
"""Run a single forward pass with each model"""

@ -17,11 +17,15 @@ paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.128
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import math
import logging
from functools import partial
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .layers import DropPath, to_2tuple, trunc_normal_
@ -29,6 +33,8 @@ from .resnet import resnet26d, resnet50d
from .resnetv2 import ResNetV2, StdConv2dSame
from .registry import register_model
_logger = logging.getLogger(__name__)
def _cfg(url='', **kwargs):
return {
@ -94,7 +100,7 @@ default_cfgs = {
# hybrid models (weights ported from official Google JAX impl)
'vit_base_resnet50_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9),
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9),
'vit_base_resnet50_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
@ -106,15 +112,15 @@ default_cfgs = {
'vit_base_resnet50d_224': _cfg(),
# deit models (FB weights)
'deit_tiny_patch16_224': _cfg(
'vit_deit_tiny_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
'deit_small_patch16_224': _cfg(
'vit_deit_small_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
'deit_base_patch16_224': _cfg(
'vit_deit_base_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
'deit_base_patch16_384': _cfg(
url='', # no weights yet
input_size=(3, 384, 384)),
'vit_deit_base_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
input_size=(3, 384, 384), crop_pct=1.0),
}
@ -253,11 +259,12 @@ class VisionTransformer(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, representation_size=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
if hybrid_backbone is not None:
self.patch_embed = HybridEmbed(
@ -290,7 +297,7 @@ class VisionTransformer(nn.Module):
self.pre_logits = nn.Identity()
# Classifier head
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
@ -338,180 +345,196 @@ class VisionTransformer(nn.Module):
return x
def _conv_filter(state_dict, patch_size=16):
def resize_pos_embed(posemb, posemb_new):
# Rescale the grid of position embeddings when loading from state_dict
# Adapted from
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
_logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
ntok_new = posemb_new.shape[1]
if True:
posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
ntok_new -= 1
else:
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
gs_old = int(math.sqrt(len(posemb_grid)))
gs_new = int(math.sqrt(ntok_new))
_logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear')
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
state_dict['pos_embed'] = posemb
return state_dict
def checkpoint_filter_fn(state_dict, model):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
out_dict = {}
if 'model' in state_dict:
# for deit models
state_dict = state_dict['model']
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k:
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
# for old models that I trained prior to conv based patchification
v = v.reshape(model.patch_embed.proj.weight.shape)
elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
# to resize pos embedding when using model at different size from pretrained weights
v = resize_pos_embed(v, model.pos_embed)
out_dict[k] = v
return out_dict
def _create_vision_transformer(variant, pretrained=False, **kwargs):
default_cfg = default_cfgs[variant]
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-1]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
repr_size = kwargs.pop('representation_size', None)
if repr_size is not None and num_classes != default_num_classes:
# remove representation layer if fine-tuning
_logger.info("Removing representation layer for fine-tuning.")
repr_size = None
model = VisionTransformer(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(
model, num_classes=num_classes, in_chans=kwargs.get('in_chans', 3),
filter_fn=partial(checkpoint_filter_fn, model=model))
return model
@register_model
def vit_small_patch16_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.,
qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs)
if pretrained:
# NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
kwargs.setdefault('qk_scale', 768 ** -0.5)
model = VisionTransformer(patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., **kwargs)
model.default_cfg = default_cfgs['vit_small_patch16_224']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
model_kwargs.setdefault('qk_scale', 768 ** -0.5)
model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_224(pretrained=False, **kwargs):
model = VisionTransformer(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['vit_base_patch16_224']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch32_224(pretrained=False, **kwargs):
model = VisionTransformer(
img_size=224, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['vit_base_patch32_224']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_384(pretrained=False, **kwargs):
model = VisionTransformer(
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['vit_base_patch16_384']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch32_384(pretrained=False, **kwargs):
model = VisionTransformer(
img_size=384, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
model_kwargs = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['vit_base_patch32_384']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_patch16_224(pretrained=False, **kwargs):
model = VisionTransformer(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['vit_large_patch16_224']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_patch32_224(pretrained=False, **kwargs):
model = VisionTransformer(
img_size=224, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['vit_large_patch32_224']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_patch16_384(pretrained=False, **kwargs):
model = VisionTransformer(
img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['vit_large_patch16_384']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
num_classes = kwargs.pop('num_classes', 21843)
model = VisionTransformer(
patch_size=16, num_classes=num_classes, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['vit_base_patch16_224_in21k']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, representation_size=768, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_384_in21k(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, representation_size=768, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
num_classes = kwargs.pop('num_classes', 21843)
model = VisionTransformer(
img_size=224, num_classes=num_classes, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
qkv_bias=True, representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['vit_base_patch32_224_in21k']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model_kwargs = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, representation_size=768, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
num_classes = kwargs.pop('num_classes', 21843)
model = VisionTransformer(
patch_size=16, num_classes=num_classes, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
representation_size=1024, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['vit_large_patch16_224_in21k']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, representation_size=1024, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
# @register_model
# def vit_large_patch16_384_in21k(pretrained=False, **kwargs):
# model_kwargs = dict(
# patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, representation_size=1024, **kwargs)
# model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
# return model
@register_model
def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
num_classes = kwargs.get('num_classes', 21843)
model = VisionTransformer(
img_size=224, num_classes=num_classes, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
qkv_bias=True, representation_size=1024, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['vit_large_patch32_224_in21k']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model_kwargs = dict(
patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, representation_size=1024, **kwargs)
model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
num_classes = kwargs.pop('num_classes', 21843)
model = VisionTransformer(
img_size=224, patch_size=14, num_classes=num_classes, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
qkv_bias=True, representation_size=1280, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['vit_huge_patch14_224_in21k']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, representation_size=1280, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
# create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
num_classes = kwargs.pop('num_classes', 21843)
backbone = ResNetV2(
layers=(3, 4, 9), preact=False, stem_type='same', conv_layer=StdConv2dSame, num_classes=0, global_pool='')
model = VisionTransformer(
img_size=224, num_classes=num_classes, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
hybrid_backbone=backbone, representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['vit_base_resnet50_224_in21k']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model_kwargs = dict(
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone,
representation_size=768, **kwargs)
model = _create_vision_transformer('vit_base_resnet50_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@ -520,12 +543,8 @@ def vit_base_resnet50_384(pretrained=False, **kwargs):
# create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
backbone = ResNetV2(
layers=(3, 4, 9), preact=False, stem_type='same', conv_layer=StdConv2dSame, num_classes=0, global_pool='')
model = VisionTransformer(
img_size=384, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone,
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['vit_base_resnet50_384']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_base_resnet50_384', pretrained=pretrained, **model_kwargs)
return model
@ -533,9 +552,8 @@ def vit_base_resnet50_384(pretrained=False, **kwargs):
def vit_small_resnet26d_224(pretrained=False, **kwargs):
pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing
backbone = resnet26d(pretrained=pretrained_backbone, features_only=True, out_indices=[4])
model = VisionTransformer(
img_size=224, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
model.default_cfg = default_cfgs['vit_small_resnet26d_224']
model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_resnet26d_224', pretrained=pretrained, **model_kwargs)
return model
@ -543,9 +561,8 @@ def vit_small_resnet26d_224(pretrained=False, **kwargs):
def vit_small_resnet50d_s3_224(pretrained=False, **kwargs):
pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing
backbone = resnet50d(pretrained=pretrained_backbone, features_only=True, out_indices=[3])
model = VisionTransformer(
img_size=224, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
model.default_cfg = default_cfgs['vit_small_resnet50d_s3_224']
model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_resnet50d_s3_224', pretrained=pretrained, **model_kwargs)
return model
@ -553,9 +570,8 @@ def vit_small_resnet50d_s3_224(pretrained=False, **kwargs):
def vit_base_resnet26d_224(pretrained=False, **kwargs):
pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing
backbone = resnet26d(pretrained=pretrained_backbone, features_only=True, out_indices=[4])
model = VisionTransformer(
img_size=224, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
model.default_cfg = default_cfgs['vit_base_resnet26d_224']
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_base_resnet26d_224', pretrained=pretrained, **model_kwargs)
return model
@ -563,55 +579,34 @@ def vit_base_resnet26d_224(pretrained=False, **kwargs):
def vit_base_resnet50d_224(pretrained=False, **kwargs):
pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing
backbone = resnet50d(pretrained=pretrained_backbone, features_only=True, out_indices=[4])
model = VisionTransformer(
img_size=224, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
model.default_cfg = default_cfgs['vit_base_resnet50d_224']
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_base_resnet50d_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
model = VisionTransformer(
patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['deit_tiny_patch16_224']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, **kwargs)
model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit_small_patch16_224(pretrained=False, **kwargs):
model = VisionTransformer(
patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['deit_small_patch16_224']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
def vit_deit_small_patch16_224(pretrained=False, **kwargs):
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs)
model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit_base_patch16_224(pretrained=False, **kwargs):
model = VisionTransformer(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['deit_base_patch16_224']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
def vit_deit_base_patch16_224(pretrained=False, **kwargs):
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit_base_patch16_384(pretrained=False, **kwargs):
model = VisionTransformer(
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['deit_base_patch16_384']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
def vit_deit_base_patch16_384(pretrained=False, **kwargs):
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
return model

Loading…
Cancel
Save