|
|
@ -9,6 +9,12 @@ keep file sizes sane.
|
|
|
|
|
|
|
|
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
from copy import deepcopy
|
|
|
|
|
|
|
|
from functools import partial
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
from .layers import StdConv2dSame, StdConv2d, to_2tuple
|
|
|
|
from .layers import StdConv2dSame, StdConv2d, to_2tuple
|
|
|
|
from .resnet import resnet26d, resnet50d
|
|
|
|
from .resnet import resnet26d, resnet50d
|
|
|
@ -41,39 +47,14 @@ default_cfgs = {
|
|
|
|
|
|
|
|
|
|
|
|
# hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones)
|
|
|
|
# hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones)
|
|
|
|
'vit_tiny_r_s16_p8_224': _cfg(),
|
|
|
|
'vit_tiny_r_s16_p8_224': _cfg(),
|
|
|
|
'vit_tiny_r_s16_p8_384': _cfg(
|
|
|
|
'vit_small_r_s16_p8_224': _cfg(),
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'vit_small_r_s16_p8_224': _cfg(
|
|
|
|
|
|
|
|
crop_pct=1.0),
|
|
|
|
|
|
|
|
'vit_small_r_s16_p8_384': _cfg(
|
|
|
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'vit_small_r20_s16_p2_224': _cfg(),
|
|
|
|
'vit_small_r20_s16_p2_224': _cfg(),
|
|
|
|
'vit_small_r20_s16_p2_384': _cfg(
|
|
|
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'vit_small_r20_s16_224': _cfg(),
|
|
|
|
'vit_small_r20_s16_224': _cfg(),
|
|
|
|
'vit_small_r20_s16_384': _cfg(
|
|
|
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'vit_small_r26_s32_224': _cfg(),
|
|
|
|
'vit_small_r26_s32_224': _cfg(),
|
|
|
|
'vit_small_r26_s32_384': _cfg(
|
|
|
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'vit_base_r20_s16_224': _cfg(),
|
|
|
|
'vit_base_r20_s16_224': _cfg(),
|
|
|
|
'vit_base_r20_s16_384': _cfg(
|
|
|
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'vit_base_r26_s32_224': _cfg(),
|
|
|
|
'vit_base_r26_s32_224': _cfg(),
|
|
|
|
'vit_base_r26_s32_384': _cfg(
|
|
|
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'vit_base_r50_s16_224': _cfg(),
|
|
|
|
'vit_base_r50_s16_224': _cfg(),
|
|
|
|
|
|
|
|
|
|
|
|
'vit_large_r50_s32_224': _cfg(),
|
|
|
|
'vit_large_r50_s32_224': _cfg(),
|
|
|
|
'vit_large_r50_s32_384': _cfg(
|
|
|
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# hybrid models (using timm resnet backbones)
|
|
|
|
# hybrid models (using timm resnet backbones)
|
|
|
|
'vit_small_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
|
|
|
'vit_small_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
|
|
@ -83,6 +64,56 @@ default_cfgs = {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HybridEmbed(nn.Module):
|
|
|
|
|
|
|
|
""" CNN Feature Map Embedding
|
|
|
|
|
|
|
|
Extract feature map from CNN, flatten, project to embedding dim.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768):
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
assert isinstance(backbone, nn.Module)
|
|
|
|
|
|
|
|
img_size = to_2tuple(img_size)
|
|
|
|
|
|
|
|
patch_size = to_2tuple(patch_size)
|
|
|
|
|
|
|
|
self.img_size = img_size
|
|
|
|
|
|
|
|
self.patch_size = patch_size
|
|
|
|
|
|
|
|
self.backbone = backbone
|
|
|
|
|
|
|
|
if feature_size is None:
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
|
|
|
# NOTE Most reliable way of determining output dims is to run forward pass
|
|
|
|
|
|
|
|
training = backbone.training
|
|
|
|
|
|
|
|
if training:
|
|
|
|
|
|
|
|
backbone.eval()
|
|
|
|
|
|
|
|
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
|
|
|
|
|
|
|
|
if isinstance(o, (list, tuple)):
|
|
|
|
|
|
|
|
o = o[-1] # last feature if backbone outputs list/tuple of features
|
|
|
|
|
|
|
|
feature_size = o.shape[-2:]
|
|
|
|
|
|
|
|
feature_dim = o.shape[1]
|
|
|
|
|
|
|
|
backbone.train(training)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
feature_size = to_2tuple(feature_size)
|
|
|
|
|
|
|
|
if hasattr(self.backbone, 'feature_info'):
|
|
|
|
|
|
|
|
feature_dim = self.backbone.feature_info.channels()[-1]
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
feature_dim = self.backbone.num_features
|
|
|
|
|
|
|
|
assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
|
|
|
|
|
|
|
|
self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1]
|
|
|
|
|
|
|
|
self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
|
|
|
x = self.backbone(x)
|
|
|
|
|
|
|
|
if isinstance(x, (list, tuple)):
|
|
|
|
|
|
|
|
x = x[-1] # last feature if backbone outputs list/tuple of features
|
|
|
|
|
|
|
|
x = self.proj(x).flatten(2).transpose(1, 2)
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
default_cfg = deepcopy(default_cfgs[variant])
|
|
|
|
|
|
|
|
embed_layer = partial(HybridEmbed, backbone=backbone)
|
|
|
|
|
|
|
|
kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
|
|
|
|
|
|
|
|
return _create_vision_transformer(
|
|
|
|
|
|
|
|
variant, pretrained=pretrained, default_cfg=default_cfg, embed_layer=embed_layer, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _resnetv2(layers=(3, 4, 9), **kwargs):
|
|
|
|
def _resnetv2(layers=(3, 4, 9), **kwargs):
|
|
|
|
""" ResNet-V2 backbone helper"""
|
|
|
|
""" ResNet-V2 backbone helper"""
|
|
|
|
padding_same = kwargs.get('padding_same', True)
|
|
|
|
padding_same = kwargs.get('padding_same', True)
|
|
|
@ -108,9 +139,9 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
|
|
|
|
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
|
|
|
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
|
|
|
|
backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
|
|
|
|
model_kwargs = dict(
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
|
|
|
|
embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, representation_size=768, **kwargs)
|
|
|
|
model = _create_vision_transformer_hybrid(
|
|
|
|
model = _create_vision_transformer('vit_base_r50_s16_224_in21k', pretrained=pretrained, **model_kwargs)
|
|
|
|
'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -120,8 +151,9 @@ def vit_base_r50_s16_384(pretrained=False, **kwargs):
|
|
|
|
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
|
|
|
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
backbone = _resnetv2((3, 4, 9), **kwargs)
|
|
|
|
backbone = _resnetv2((3, 4, 9), **kwargs)
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
|
|
|
|
model = _create_vision_transformer('vit_base_r50_s16_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
model = _create_vision_transformer_hybrid(
|
|
|
|
|
|
|
|
'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -130,20 +162,9 @@ def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
|
|
|
|
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
|
|
|
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
backbone = _resnetv2(layers=(), **kwargs)
|
|
|
|
backbone = _resnetv2(layers=(), **kwargs)
|
|
|
|
model_kwargs = dict(
|
|
|
|
model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
|
|
|
patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs)
|
|
|
|
model = _create_vision_transformer_hybrid(
|
|
|
|
model = _create_vision_transformer('vit_tiny_r_s16_p8_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
backbone = _resnetv2(layers=(), **kwargs)
|
|
|
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
|
|
|
patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
|
|
|
model = _create_vision_transformer('vit_tiny_r_s16_p8_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -152,43 +173,21 @@ def vit_small_r_s16_p8_224(pretrained=False, **kwargs):
|
|
|
|
""" R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
|
|
|
""" R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
backbone = _resnetv2(layers=(), **kwargs)
|
|
|
|
backbone = _resnetv2(layers=(), **kwargs)
|
|
|
|
model_kwargs = dict(
|
|
|
|
model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
|
|
|
patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
|
|
|
|
model = _create_vision_transformer_hybrid(
|
|
|
|
model = _create_vision_transformer('vit_small_r_s16_p8_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
'vit_small_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def vit_small_r_s16_p8_384(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
backbone = _resnetv2(layers=(), **kwargs)
|
|
|
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
|
|
|
patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
|
|
|
model = _create_vision_transformer('vit_small_r_s16_p8_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def vit_small_r20_s16_p2_224(pretrained=False, **kwargs):
|
|
|
|
def vit_small_r20_s16_p2_224(pretrained=False, **kwargs):
|
|
|
|
""" R52+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224.
|
|
|
|
""" R52+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
backbone = _resnetv2((2, 4), **kwargs)
|
|
|
|
backbone = _resnetv2((2, 4), **kwargs)
|
|
|
|
model_kwargs = dict(
|
|
|
|
model_kwargs = dict(patch_size=2, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
|
|
|
patch_size=2, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
|
|
|
|
model = _create_vision_transformer_hybrid(
|
|
|
|
model = _create_vision_transformer('vit_small_r20_s16_p2_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
'vit_small_r20_s16_p2_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def vit_small_r20_s16_p2_384(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" R20+ViT-S/S16 w/ 2x2 Patch hybrid @ 384x384.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
backbone = _resnetv2((2, 4), **kwargs)
|
|
|
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
|
|
|
embed_dim=384, patch_size=2, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
|
|
|
model = _create_vision_transformer('vit_small_r20_s16_p2_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -197,18 +196,9 @@ def vit_small_r20_s16_224(pretrained=False, **kwargs):
|
|
|
|
""" R20+ViT-S/S16 hybrid.
|
|
|
|
""" R20+ViT-S/S16 hybrid.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
backbone = _resnetv2((2, 2, 2), **kwargs)
|
|
|
|
backbone = _resnetv2((2, 2, 2), **kwargs)
|
|
|
|
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
|
|
|
|
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
|
|
|
|
model = _create_vision_transformer('vit_small_r20_s16_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
model = _create_vision_transformer_hybrid(
|
|
|
|
return model
|
|
|
|
'vit_small_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def vit_small_r20_s16_384(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" R20+ViT-S/S16 hybrid @ 384x384.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
backbone = _resnetv2((2, 2, 2), **kwargs)
|
|
|
|
|
|
|
|
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
|
|
|
model = _create_vision_transformer('vit_small_r20_s16_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -217,18 +207,9 @@ def vit_small_r26_s32_224(pretrained=False, **kwargs):
|
|
|
|
""" R26+ViT-S/S32 hybrid.
|
|
|
|
""" R26+ViT-S/S32 hybrid.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
|
|
|
|
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
|
|
|
|
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
|
|
|
|
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
|
|
|
|
model = _create_vision_transformer('vit_small_r26_s32_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
model = _create_vision_transformer_hybrid(
|
|
|
|
return model
|
|
|
|
'vit_small_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def vit_small_r26_s32_384(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" R26+ViT-S/S32 hybrid @ 384x384.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
|
|
|
|
|
|
|
|
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
|
|
|
model = _create_vision_transformer('vit_small_r26_s32_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -237,18 +218,9 @@ def vit_base_r20_s16_224(pretrained=False, **kwargs):
|
|
|
|
""" R20+ViT-B/S16 hybrid.
|
|
|
|
""" R20+ViT-B/S16 hybrid.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
backbone = _resnetv2((2, 2, 2), **kwargs)
|
|
|
|
backbone = _resnetv2((2, 2, 2), **kwargs)
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
|
|
|
|
model = _create_vision_transformer('vit_base_r20_s16_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
model = _create_vision_transformer_hybrid(
|
|
|
|
return model
|
|
|
|
'vit_base_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def vit_base_r20_s16_384(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" R20+ViT-B/S16 hybrid.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
backbone = _resnetv2((2, 2, 2), **kwargs)
|
|
|
|
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
|
|
|
model = _create_vision_transformer('vit_base_r20_s16_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -257,18 +229,9 @@ def vit_base_r26_s32_224(pretrained=False, **kwargs):
|
|
|
|
""" R26+ViT-B/S32 hybrid.
|
|
|
|
""" R26+ViT-B/S32 hybrid.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
|
|
|
|
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
|
|
|
|
model = _create_vision_transformer('vit_base_r26_s32_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
model = _create_vision_transformer_hybrid(
|
|
|
|
return model
|
|
|
|
'vit_base_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def vit_base_r26_s32_384(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" R26+ViT-B/S32 hybrid.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
|
|
|
|
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
|
|
|
model = _create_vision_transformer('vit_base_r26_s32_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -277,8 +240,9 @@ def vit_base_r50_s16_224(pretrained=False, **kwargs):
|
|
|
|
""" R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
|
|
|
|
""" R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
backbone = _resnetv2((3, 4, 9), **kwargs)
|
|
|
|
backbone = _resnetv2((3, 4, 9), **kwargs)
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
|
|
|
|
model = _create_vision_transformer('vit_base_r50_s16_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
model = _create_vision_transformer_hybrid(
|
|
|
|
|
|
|
|
'vit_base_r50_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -287,29 +251,9 @@ def vit_large_r50_s32_224(pretrained=False, **kwargs):
|
|
|
|
""" R50+ViT-L/S32 hybrid.
|
|
|
|
""" R50+ViT-L/S32 hybrid.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
|
|
|
|
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
|
|
|
|
model = _create_vision_transformer('vit_large_r50_s32_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
model = _create_vision_transformer_hybrid(
|
|
|
|
return model
|
|
|
|
'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" R50+ViT-L/S32 hybrid.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
|
|
|
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
|
|
|
embed_dim=768, depth=12, num_heads=12, representation_size=768, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
|
|
|
model = _create_vision_transformer('vit_large_r50_s32_224_in21k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def vit_large_r50_s32_384(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" R50+ViT-L/S32 hybrid.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
|
|
|
|
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
|
|
|
model = _create_vision_transformer('vit_large_r50_s32_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -318,8 +262,9 @@ def vit_small_resnet26d_224(pretrained=False, **kwargs):
|
|
|
|
""" Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
|
|
|
|
""" Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
|
|
|
|
backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs)
|
|
|
|
model = _create_vision_transformer('vit_small_resnet26d_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
model = _create_vision_transformer_hybrid(
|
|
|
|
|
|
|
|
'vit_small_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -328,8 +273,9 @@ def vit_small_resnet50d_s16_224(pretrained=False, **kwargs):
|
|
|
|
""" Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.
|
|
|
|
""" Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3])
|
|
|
|
backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3])
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs)
|
|
|
|
model = _create_vision_transformer('vit_small_resnet50d_s16_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
model = _create_vision_transformer_hybrid(
|
|
|
|
|
|
|
|
'vit_small_resnet50d_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -338,8 +284,9 @@ def vit_base_resnet26d_224(pretrained=False, **kwargs):
|
|
|
|
""" Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights.
|
|
|
|
""" Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
|
|
|
|
backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
|
|
|
|
model = _create_vision_transformer('vit_base_resnet26d_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
model = _create_vision_transformer_hybrid(
|
|
|
|
|
|
|
|
'vit_base_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -348,6 +295,7 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs):
|
|
|
|
""" Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
|
|
|
|
""" Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
|
|
|
|
backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
|
|
|
|
model = _create_vision_transformer('vit_base_resnet50d_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
model = _create_vision_transformer_hybrid(
|
|
|
|
|
|
|
|
'vit_base_resnet50d_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
return model
|
|
|
|
return model
|