@@ -79,23 +79,27 @@ default_cfgs = {

     # patch models, imagenet21k (weights ported from official JAX impl)
     'vit_base_patch16_224_in21k': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
     'vit_base_patch32_224_in21k': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
     'vit_large_patch16_224_in21k': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
     'vit_large_patch32_224_in21k': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
     'vit_huge_patch14_224_in21k': _cfg(
-        url='',
+        url='',  # FIXME I have weights for this but > 2GB limit for github release binaries
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),

     # hybrid models (weights ported from official JAX impl)
+    'vit_base_resnet50_224_in21k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
+        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9),
     'vit_base_resnet50_384': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
         input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),

     # hybrid models (my experiments)
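With the URLs above filled in, the in21k variants become loadable through the registered entrypoints. A minimal usage sketch (not part of the diff; it assumes this file ships as timm.models.vision_transformer in an installed timm build, with the 21843-class default and the 0.5/0.5/0.5 normalization taken from the _cfg entries above):

```python
import torch
from timm.models.vision_transformer import vit_base_patch16_224_in21k

# Build the ImageNet-21k model and pull the newly linked checkpoint.
model = vit_base_patch16_224_in21k(pretrained=True)
model.eval()

# Dummy 224x224 input; real images should be normalized with
# mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5) per the _cfg entry above.
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # expected: torch.Size([1, 21843])
```

The same pattern applies to the other *_in21k entrypoints; timm.create_model('vit_base_patch16_224_in21k', pretrained=True) resolves to the same registered function.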
@@ -269,6 +273,7 @@ class VisionTransformer(nn.Module):

         # Representation layer
         if representation_size:
+            self.num_features = representation_size
             self.pre_logits = nn.Sequential(OrderedDict([
                 ('fc', nn.Linear(embed_dim, representation_size)),
                 ('act', nn.Tanh())
@@ -315,12 +320,12 @@ class VisionTransformer(nn.Module):
         for blk in self.blocks:
             x = blk(x)

-        x = self.norm(x)
-        return x[:, 0]
+        x = self.norm(x)[:, 0]
+        x = self.pre_logits(x)
+        return x

     def forward(self, x):
         x = self.forward_features(x)
-        x = self.pre_logits(x)
         x = self.head(x)
         return x
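To illustrate what the reworked head path computes: after the final norm, only the class token is kept, passed through the optional pre-logits projection (Linear + Tanh when representation_size is set), and only then through the classifier. A standalone sketch with example shapes (dimensions chosen to match the base/16 in21k config; not part of the diff):

```python
from collections import OrderedDict

import torch
import torch.nn as nn

embed_dim, representation_size, num_classes = 768, 768, 21843

# Mirrors the representation layer built in __init__ when representation_size is set.
pre_logits = nn.Sequential(OrderedDict([
    ('fc', nn.Linear(embed_dim, representation_size)),
    ('act', nn.Tanh()),
]))
head = nn.Linear(representation_size, num_classes)

tokens = torch.randn(2, 197, embed_dim)  # (batch, 1 cls + 14*14 patches, dim) after self.norm
cls_token = tokens[:, 0]                 # forward_features now returns pre_logits(norm(x)[:, 0])
logits = head(pre_logits(cls_token))
print(logits.shape)                      # torch.Size([2, 21843])
```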
@@ -407,7 +412,7 @@ def vit_large_patch16_224(pretrained=False, **kwargs):
 @register_model
 def vit_large_patch32_224(pretrained=False, **kwargs):
     model = VisionTransformer(
         img_size=224, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_large_patch32_224']
     if pretrained:
@@ -418,7 +423,7 @@ def vit_large_patch32_224(pretrained=False, **kwargs):
 @register_model
 def vit_large_patch16_384(pretrained=False, **kwargs):
     model = VisionTransformer(
         img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_large_patch16_384']
     if pretrained:
@@ -426,22 +431,12 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
     return model


-@register_model
-def vit_large_patch32_384(pretrained=False, **kwargs):
-    model = VisionTransformer(
-        img_size=384, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
-        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
-    model.default_cfg = default_cfgs['vit_large_patch32_384']
-    if pretrained:
-        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
-    return model
-
-
 @register_model
 def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
+    num_classes = kwargs.get('num_classes', 21843)
     model = VisionTransformer(
-        patch_size=16, num_classes=21843, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
-        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+        patch_size=16, num_classes=num_classes, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_base_patch16_224_in21k']
     if pretrained:
         load_pretrained(
@@ -451,9 +446,10 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
 @register_model
 def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
+    num_classes = kwargs.get('num_classes', 21843)
     model = VisionTransformer(
-        img_size=224, num_classes=21843, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
-        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+        img_size=224, num_classes=num_classes, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+        qkv_bias=True, representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_base_patch32_224_in21k']
     if pretrained:
         load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
@@ -462,9 +458,10 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
 @register_model
 def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
+    num_classes = kwargs.get('num_classes', 21843)
     model = VisionTransformer(
-        patch_size=16, num_classes=21843, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
-        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+        patch_size=16, num_classes=num_classes, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+        representation_size=1024, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_large_patch16_224_in21k']
     if pretrained:
         load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
@@ -473,9 +470,10 @@ def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
 @register_model
 def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
+    num_classes = kwargs.get('num_classes', 21843)
     model = VisionTransformer(
-        img_size=224, num_classes=21843, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
-        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+        img_size=224, num_classes=num_classes, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
+        qkv_bias=True, representation_size=1024, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_large_patch32_224_in21k']
     if pretrained:
         load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
@@ -484,15 +482,31 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
 @register_model
 def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
+    num_classes = kwargs.get('num_classes', 21843)
     model = VisionTransformer(
-        img_size=224, patch_size=14, num_classes=21843, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
-        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+        img_size=224, patch_size=14, num_classes=num_classes, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
+        qkv_bias=True, representation_size=1280, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_huge_patch14_224_in21k']
     if pretrained:
         load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
     return model


+@register_model
+def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
+    # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
+    num_classes = kwargs.get('num_classes', 21843)
+    backbone = ResNetV2(
+        layers=(3, 4, 9), preact=False, stem_type='same', conv_layer=StdConv2dSame, num_classes=0, global_pool='')
+    model = VisionTransformer(
+        img_size=224, num_classes=num_classes, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        hybrid_backbone=backbone, representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = default_cfgs['vit_base_resnet50_224_in21k']
+    if pretrained:
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
+    return model
+
+
 @register_model
 def vit_base_resnet50_384(pretrained=False, **kwargs):
     # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
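The new vit_base_resnet50_224_in21k entrypoint follows the same construction pattern as the other in21k models, but feeds the transformer from a 3-stage ResNetV2 backbone (StdConv + GroupNorm, no head) instead of a plain patch projection. A construction sketch, under the same installed-timm assumption as the earlier example:

```python
import torch
from timm.models.vision_transformer import vit_base_resnet50_224_in21k

# Hybrid ResNet50 + ViT-Base with the 21843-class ImageNet-21k head by default;
# pretrained=True would fetch jx_vit_base_resnet50_224_in21k-6f7c7740.pth per the cfg above.
model = vit_base_resnet50_224_in21k(pretrained=False)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 21843])
```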