@@ -37,7 +37,7 @@ def _cfg(url='', **kwargs):
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
         'crop_pct': .9, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': '', 'classifier': 'head',
+        'first_conv': 'patch_embed.proj', 'classifier': 'head',
         **kwargs
     }
 
@@ -48,7 +48,8 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
     ),
     'vit_base_patch16_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_base_p16_224-4e355ebd.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
     ),
     'vit_base_patch16_384': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
@@ -56,7 +57,9 @@ default_cfgs = {
     'vit_base_patch32_384': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
         input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
-    'vit_large_patch16_224': _cfg(),
+    'vit_large_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
     'vit_large_patch16_384': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
         input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
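
As a side note on the config changes above, the sketch below (not part of the patch) mirrors the `_cfg` helper from the first hunk to show how the per-model overrides, such as the new `vit_large_patch16_224` URL and 0.5 mean/std, are merged over the shared defaults. The `'url': url` entry and the exact ImageNet mean/std constants are assumptions based on timm's usual definitions.

# Illustrative only -- mirrors the _cfg helper from the hunk above so the merge
# behaviour can be checked in isolation.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)  # assumed timm.data constants
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

def _cfg(url='', **kwargs):
    return {
        'url': url,  # assumed: the weight URL is stored alongside the defaults
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }

# The new vit_large_patch16_224 entry: url, mean and std passed here override
# the defaults; crop_pct, interpolation, etc. are inherited unchanged.
cfg = _cfg(
    url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
    mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
assert cfg['mean'] == (0.5, 0.5, 0.5) and cfg['crop_pct'] == .9
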
@@ -206,7 +209,7 @@ class VisionTransformer(nn.Module):
                  drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
         super().__init__()
         self.num_classes = num_classes
-        self.embed_dim = embed_dim
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
 
         if hybrid_backbone is not None:
             self.patch_embed = HybridEmbed(
@@ -305,10 +308,9 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
 
 @register_model
 def vit_base_patch16_224(pretrained=False, **kwargs):
-    if pretrained:
-        # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
-        kwargs.setdefault('qk_scale', 768 ** -0.5)
-    model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
+    model = VisionTransformer(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_base_patch16_224']
     if pretrained:
         load_pretrained(
@@ -340,8 +342,12 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
 
 @register_model
 def vit_large_patch16_224(pretrained=False, **kwargs):
-    model = VisionTransformer(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
+    model = VisionTransformer(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_large_patch16_224']
+    if pretrained:
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
     return model
 
 
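
Finally, a minimal smoke-test sketch, assuming the patch is applied to timm/models/vision_transformer.py and that torch plus the registered factory functions are importable. The module path, tensor shape, and pretrained flag below are illustrative, not part of the diff.

# Hypothetical usage check for the updated factories (not part of the patch).
import torch
from timm.models.vision_transformer import vit_large_patch16_224  # assumed module path

# pretrained=True would now also call load_pretrained for the large model,
# pulling the jx_vit_large_p16_224 weights referenced in default_cfgs.
model = vit_large_patch16_224(pretrained=False)  # set True to download weights
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))

print(out.shape)  # expected torch.Size([1, 1000]) with the default num_classes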