Expand scope of testing for non-std vision transformer / mlp models. Some related cleanup and create fn cleanup for all vision transformer and mlp models. More CoaT weights.

pull/637/head
Ross Wightman 4 years ago
parent 18bf520ad1
commit bfc72f75d3

@ -26,29 +26,41 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
EXCLUDE_FILTERS = [ EXCLUDE_FILTERS = [
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*',
'*resnetrs350*', '*resnetrs420*'] + NON_STD_FILTERS '*resnetrs350*', '*resnetrs420*']
else: else:
EXCLUDE_FILTERS = NON_STD_FILTERS EXCLUDE_FILTERS = []
MAX_FWD_SIZE = 384 TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
MAX_BWD_SIZE = 128 TARGET_BWD_SIZE = 128
MAX_BWD_SIZE = 384
MAX_FWD_FEAT_SIZE = 448 MAX_FWD_FEAT_SIZE = 448
def _get_input_size(model, target=None):
default_cfg = model.default_cfg
input_size = default_cfg['input_size']
if 'fixed_input_size' in default_cfg and default_cfg['fixed_input_size']:
return input_size
if 'min_input_size' in default_cfg:
if target and max(input_size) > target:
input_size = default_cfg['min_input_size']
else:
if target and max(input_size) > target:
input_size = tuple([min(x, target) for x in input_size])
return input_size
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-NUM_NON_STD])) @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('batch_size', [1])
def test_model_forward(model_name, batch_size): def test_model_forward(model_name, batch_size):
"""Run a single forward pass with each model""" """Run a single forward pass with each model"""
model = create_model(model_name, pretrained=False) model = create_model(model_name, pretrained=False)
model.eval() model.eval()
input_size = model.default_cfg['input_size'] input_size = _get_input_size(model, TARGET_FWD_SIZE)
if any([x > MAX_FWD_SIZE for x in input_size]): if max(input_size) > MAX_FWD_SIZE:
if is_model_default_key(model_name, 'fixed_input_size'): pytest.skip("Fixed input size model > limit.")
pytest.skip("Fixed input size model > limit.")
# cap forward test at max res 384 * 384 to keep resource down
input_size = tuple([min(x, MAX_FWD_SIZE) for x in input_size])
inputs = torch.randn((batch_size, *input_size)) inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs) outputs = model(inputs)
@ -63,20 +75,16 @@ def test_model_backward(model_name, batch_size):
"""Run a single forward pass with each model""" """Run a single forward pass with each model"""
model = create_model(model_name, pretrained=False, num_classes=42) model = create_model(model_name, pretrained=False, num_classes=42)
num_params = sum([x.numel() for x in model.parameters()]) num_params = sum([x.numel() for x in model.parameters()])
model.eval() model.train()
input_size = model.default_cfg['input_size'] input_size = _get_input_size(model, TARGET_BWD_SIZE)
if not is_model_default_key(model_name, 'fixed_input_size'): if max(input_size) > MAX_BWD_SIZE:
min_input_size = get_model_default_value(model_name, 'min_input_size') pytest.skip("Fixed input size model > limit.")
if min_input_size is not None:
input_size = min_input_size
else:
if any([x > MAX_BWD_SIZE for x in input_size]):
# cap backward test at 128 * 128 to keep resource usage down
input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
inputs = torch.randn((batch_size, *input_size)) inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs) outputs = model(inputs)
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
outputs.mean().backward() outputs.mean().backward()
for n, x in model.named_parameters(): for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}' assert x.grad is not None, f'No gradient for {n}'
@ -168,12 +176,9 @@ def test_model_forward_torchscript(model_name, batch_size):
model = create_model(model_name, pretrained=False) model = create_model(model_name, pretrained=False)
model.eval() model.eval()
if has_model_default_key(model_name, 'fixed_input_size'): input_size = _get_input_size(model, 128)
input_size = get_model_default_value(model_name, 'input_size') if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional
elif has_model_default_key(model_name, 'min_input_size'): pytest.skip("Fixed input size model > limit.")
input_size = get_model_default_value(model_name, 'min_input_size')
else:
input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already...
model = torch.jit.script(model) model = torch.jit.script(model)
outputs = model(torch.randn((batch_size, *input_size))) outputs = model(torch.randn((batch_size, *input_size)))
@ -184,7 +189,7 @@ def test_model_forward_torchscript(model_name, batch_size):
EXCLUDE_FEAT_FILTERS = [ EXCLUDE_FEAT_FILTERS = [
'*pruned*', # hopefully fix at some point '*pruned*', # hopefully fix at some point
] ] + NON_STD_FILTERS
if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
EXCLUDE_FEAT_FILTERS += ['*resnext101_32x32d', '*resnext101_32x16d'] EXCLUDE_FEAT_FILTERS += ['*resnext101_32x32d', '*resnext101_32x16d']
@ -200,12 +205,9 @@ def test_model_forward_features(model_name, batch_size):
expected_channels = model.feature_info.channels() expected_channels = model.feature_info.channels()
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6 assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
if has_model_default_key(model_name, 'fixed_input_size'): input_size = _get_input_size(model, 96) # jit compile is already a bit slow and we've tested normal res already...
input_size = get_model_default_value(model_name, 'input_size') if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional
elif has_model_default_key(model_name, 'min_input_size'): pytest.skip("Fixed input size model > limit.")
input_size = get_model_default_value(model_name, 'min_input_size')
else:
input_size = (3, 96, 96) # jit compile is already a bit slow and we've tested normal res already...
outputs = model(torch.randn((batch_size, *input_size))) outputs = model(torch.randn((batch_size, *input_size)))
assert len(expected_channels) == len(outputs) assert len(expected_channels) == len(outputs)

@ -16,8 +16,8 @@ from .hrnet import *
from .inception_resnet_v2 import * from .inception_resnet_v2 import *
from .inception_v3 import * from .inception_v3 import *
from .inception_v4 import * from .inception_v4 import *
from .levitc import *
from .levit import * from .levit import *
#from .levit import *
from .mlp_mixer import * from .mlp_mixer import *
from .mobilenetv3 import * from .mobilenetv3 import *
from .nasnet import * from .nasnet import *

@ -306,26 +306,15 @@ def checkpoint_filter_fn(state_dict, model=None):
return checkpoint_no_module return checkpoint_no_module
def _create_cait(variant, pretrained=False, default_cfg=None, **kwargs): def _create_cait(variant, pretrained=False, **kwargs):
if default_cfg is None:
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
if kwargs.get('features_only', None): if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.') raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg( model = build_model_with_cfg(
Cait, variant, pretrained, Cait, variant, pretrained,
default_cfg=default_cfg, default_cfg=default_cfgs[variant],
img_size=img_size,
num_classes=num_classes,
pretrained_filter_fn=checkpoint_filter_fn, pretrained_filter_fn=checkpoint_filter_fn,
**kwargs) **kwargs)
return model return model

@ -7,19 +7,19 @@ Official CoaT code at: https://github.com/mlpc-ucsd/CoaT
Modified from timm/models/vision_transformer.py Modified from timm/models/vision_transformer.py
""" """
from typing import Tuple, Dict, Any, Optional from copy import deepcopy
from functools import partial
from typing import Tuple, List
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import load_pretrained from .helpers import build_model_with_cfg, overlay_external_default_cfg
from timm.models.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model from .registry import register_model
from functools import partial
from torch import nn
__all__ = [ __all__ = [
"coat_tiny", "coat_tiny",
@ -34,7 +34,7 @@ def _cfg_coat(url='', **kwargs):
return { return {
'url': url, 'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed1.proj', 'classifier': 'head', 'first_conv': 'patch_embed1.proj', 'classifier': 'head',
**kwargs **kwargs
@ -42,15 +42,21 @@ def _cfg_coat(url='', **kwargs):
default_cfgs = { default_cfgs = {
'coat_tiny': _cfg_coat(), 'coat_tiny': _cfg_coat(
'coat_mini': _cfg_coat(), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_tiny-473c2a20.pth'
),
'coat_mini': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_mini-2c6baf49.pth'
),
'coat_lite_tiny': _cfg_coat( 'coat_lite_tiny': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_tiny-461b07a7.pth' url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_tiny-461b07a7.pth'
), ),
'coat_lite_mini': _cfg_coat( 'coat_lite_mini': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_mini-d7842000.pth' url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_mini-d7842000.pth'
), ),
'coat_lite_small': _cfg_coat(), 'coat_lite_small': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_small-fea1d5a1.pth'
),
} }
@ -120,11 +126,11 @@ class ConvRelPosEnc(nn.Module):
class FactorAtt_ConvRelPosEnc(nn.Module): class FactorAtt_ConvRelPosEnc(nn.Module):
""" Factorized attention with convolutional relative position encoding class. """ """ Factorized attention with convolutional relative position encoding class. """
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., shared_crpe=None): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., shared_crpe=None):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
head_dim = dim // num_heads head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5 self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used. self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used.
@ -190,9 +196,8 @@ class ConvPosEnc(nn.Module):
class SerialBlock(nn.Module): class SerialBlock(nn.Module):
""" Serial block class. """ Serial block class.
Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """ Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_cpe=None, shared_crpe=None):
shared_cpe=None, shared_crpe=None):
super().__init__() super().__init__()
# Conv-Attention. # Conv-Attention.
@ -200,8 +205,7 @@ class SerialBlock(nn.Module):
self.norm1 = norm_layer(dim) self.norm1 = norm_layer(dim)
self.factoratt_crpe = FactorAtt_ConvRelPosEnc( self.factoratt_crpe = FactorAtt_ConvRelPosEnc(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpe)
shared_crpe=shared_crpe)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
# MLP. # MLP.
@ -226,27 +230,24 @@ class SerialBlock(nn.Module):
class ParallelBlock(nn.Module): class ParallelBlock(nn.Module):
""" Parallel block class. """ """ Parallel block class. """
def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_crpes=None):
shared_cpes=None, shared_crpes=None):
super().__init__() super().__init__()
# Conv-Attention. # Conv-Attention.
self.cpes = shared_cpes
self.norm12 = norm_layer(dims[1]) self.norm12 = norm_layer(dims[1])
self.norm13 = norm_layer(dims[2]) self.norm13 = norm_layer(dims[2])
self.norm14 = norm_layer(dims[3]) self.norm14 = norm_layer(dims[3])
self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc( self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc(
dims[1], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, dims[1], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpes[1] shared_crpe=shared_crpes[1]
) )
self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc( self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc(
dims[2], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, dims[2], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpes[2] shared_crpe=shared_crpes[2]
) )
self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc( self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc(
dims[3], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, dims[3], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpes[3] shared_crpe=shared_crpes[3]
) )
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -262,15 +263,15 @@ class ParallelBlock(nn.Module):
self.mlp2 = self.mlp3 = self.mlp4 = Mlp( self.mlp2 = self.mlp3 = self.mlp4 = Mlp(
in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def upsample(self, x, factor, size): def upsample(self, x, factor: float, size: Tuple[int, int]):
""" Feature map up-sampling. """ """ Feature map up-sampling. """
return self.interpolate(x, scale_factor=factor, size=size) return self.interpolate(x, scale_factor=factor, size=size)
def downsample(self, x, factor, size): def downsample(self, x, factor: float, size: Tuple[int, int]):
""" Feature map down-sampling. """ """ Feature map down-sampling. """
return self.interpolate(x, scale_factor=1.0/factor, size=size) return self.interpolate(x, scale_factor=1.0/factor, size=size)
def interpolate(self, x, scale_factor, size): def interpolate(self, x, scale_factor: float, size: Tuple[int, int]):
""" Feature map interpolation. """ """ Feature map interpolation. """
B, N, C = x.shape B, N, C = x.shape
H, W = size H, W = size
@ -280,33 +281,28 @@ class ParallelBlock(nn.Module):
img_tokens = x[:, 1:, :] img_tokens = x[:, 1:, :]
img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W) img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W)
img_tokens = F.interpolate(img_tokens, scale_factor=scale_factor, mode='bilinear') img_tokens = F.interpolate(
img_tokens, scale_factor=scale_factor, recompute_scale_factor=False, mode='bilinear', align_corners=False)
img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2) img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2)
out = torch.cat((cls_token, img_tokens), dim=1) out = torch.cat((cls_token, img_tokens), dim=1)
return out return out
def forward(self, x1, x2, x3, x4, sizes): def forward(self, x1, x2, x3, x4, sizes: List[Tuple[int, int]]):
_, (H2, W2), (H3, W3), (H4, W4) = sizes _, S2, S3, S4 = sizes
# Conv-Attention.
x2 = self.cpes[1](x2, size=(H2, W2)) # Note: x1 is ignored.
x3 = self.cpes[2](x3, size=(H3, W3))
x4 = self.cpes[3](x4, size=(H4, W4))
cur2 = self.norm12(x2) cur2 = self.norm12(x2)
cur3 = self.norm13(x3) cur3 = self.norm13(x3)
cur4 = self.norm14(x4) cur4 = self.norm14(x4)
cur2 = self.factoratt_crpe2(cur2, size=(H2, W2)) cur2 = self.factoratt_crpe2(cur2, size=S2)
cur3 = self.factoratt_crpe3(cur3, size=(H3, W3)) cur3 = self.factoratt_crpe3(cur3, size=S3)
cur4 = self.factoratt_crpe4(cur4, size=(H4, W4)) cur4 = self.factoratt_crpe4(cur4, size=S4)
upsample3_2 = self.upsample(cur3, factor=2, size=(H3, W3)) upsample3_2 = self.upsample(cur3, factor=2., size=S3)
upsample4_3 = self.upsample(cur4, factor=2, size=(H4, W4)) upsample4_3 = self.upsample(cur4, factor=2., size=S4)
upsample4_2 = self.upsample(cur4, factor=4, size=(H4, W4)) upsample4_2 = self.upsample(cur4, factor=4., size=S4)
downsample2_3 = self.downsample(cur2, factor=2, size=(H2, W2)) downsample2_3 = self.downsample(cur2, factor=2., size=S2)
downsample3_4 = self.downsample(cur3, factor=2, size=(H3, W3)) downsample3_4 = self.downsample(cur3, factor=2., size=S3)
downsample2_4 = self.downsample(cur2, factor=4, size=(H2, W2)) downsample2_4 = self.downsample(cur2, factor=4., size=S2)
cur2 = cur2 + upsample3_2 + upsample4_2 cur2 = cur2 + upsample3_2 + upsample4_2
cur3 = cur3 + upsample4_3 + downsample2_3 cur3 = cur3 + upsample4_3 + downsample2_3
cur4 = cur4 + downsample3_4 + downsample2_4 cur4 = cur4 + downsample3_4 + downsample2_4
@ -330,11 +326,11 @@ class ParallelBlock(nn.Module):
class CoaT(nn.Module): class CoaT(nn.Module):
""" CoaT class. """ """ CoaT class. """
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[0, 0, 0, 0], def __init__(
serial_depths=[0, 0, 0, 0], parallel_depth=0, self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=(0, 0, 0, 0),
num_heads=0, mlp_ratios=[0, 0, 0, 0], qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., serial_depths=(0, 0, 0, 0), parallel_depth=0, num_heads=0, mlp_ratios=(0, 0, 0, 0), qkv_bias=True,
drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
return_interm_layers=False, out_features = None, crpe_window=None, **kwargs): return_interm_layers=False, out_features=None, crpe_window=None, **kwargs):
super().__init__() super().__init__()
crpe_window = crpe_window or {3: 2, 5: 3, 7: 3} crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
self.return_interm_layers = return_interm_layers self.return_interm_layers = return_interm_layers
@ -342,17 +338,18 @@ class CoaT(nn.Module):
self.num_classes = num_classes self.num_classes = num_classes
# Patch embeddings. # Patch embeddings.
img_size = to_2tuple(img_size)
self.patch_embed1 = PatchEmbed( self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dims[0], norm_layer=nn.LayerNorm) embed_dim=embed_dims[0], norm_layer=nn.LayerNorm)
self.patch_embed2 = PatchEmbed( self.patch_embed2 = PatchEmbed(
img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], img_size=[x // 4 for x in img_size], patch_size=2, in_chans=embed_dims[0],
embed_dim=embed_dims[1], norm_layer=nn.LayerNorm) embed_dim=embed_dims[1], norm_layer=nn.LayerNorm)
self.patch_embed3 = PatchEmbed( self.patch_embed3 = PatchEmbed(
img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], img_size=[x // 8 for x in img_size], patch_size=2, in_chans=embed_dims[1],
embed_dim=embed_dims[2], norm_layer=nn.LayerNorm) embed_dim=embed_dims[2], norm_layer=nn.LayerNorm)
self.patch_embed4 = PatchEmbed( self.patch_embed4 = PatchEmbed(
img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], img_size=[x // 16 for x in img_size], patch_size=2, in_chans=embed_dims[2],
embed_dim=embed_dims[3], norm_layer=nn.LayerNorm) embed_dim=embed_dims[3], norm_layer=nn.LayerNorm)
# Class tokens. # Class tokens.
@ -380,7 +377,7 @@ class CoaT(nn.Module):
# Serial blocks 1. # Serial blocks 1.
self.serial_blocks1 = nn.ModuleList([ self.serial_blocks1 = nn.ModuleList([
SerialBlock( SerialBlock(
dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe1, shared_crpe=self.crpe1 shared_cpe=self.cpe1, shared_crpe=self.crpe1
) )
@ -390,7 +387,7 @@ class CoaT(nn.Module):
# Serial blocks 2. # Serial blocks 2.
self.serial_blocks2 = nn.ModuleList([ self.serial_blocks2 = nn.ModuleList([
SerialBlock( SerialBlock(
dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe2, shared_crpe=self.crpe2 shared_cpe=self.cpe2, shared_crpe=self.crpe2
) )
@ -400,7 +397,7 @@ class CoaT(nn.Module):
# Serial blocks 3. # Serial blocks 3.
self.serial_blocks3 = nn.ModuleList([ self.serial_blocks3 = nn.ModuleList([
SerialBlock( SerialBlock(
dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe3, shared_crpe=self.crpe3 shared_cpe=self.cpe3, shared_crpe=self.crpe3
) )
@ -410,7 +407,7 @@ class CoaT(nn.Module):
# Serial blocks 4. # Serial blocks 4.
self.serial_blocks4 = nn.ModuleList([ self.serial_blocks4 = nn.ModuleList([
SerialBlock( SerialBlock(
dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe4, shared_crpe=self.crpe4 shared_cpe=self.cpe4, shared_crpe=self.crpe4
) )
@ -422,10 +419,9 @@ class CoaT(nn.Module):
if self.parallel_depth > 0: if self.parallel_depth > 0:
self.parallel_blocks = nn.ModuleList([ self.parallel_blocks = nn.ModuleList([
ParallelBlock( ParallelBlock(
dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, qk_scale=qk_scale, dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpes=[self.cpe1, self.cpe2, self.cpe3, self.cpe4], shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4)
shared_crpes=[self.crpe1, self.crpe2, self.crpe3, self.crpe4]
) )
for _ in range(parallel_depth)] for _ in range(parallel_depth)]
) )
@ -434,9 +430,11 @@ class CoaT(nn.Module):
# Classification head(s). # Classification head(s).
if not self.return_interm_layers: if not self.return_interm_layers:
self.norm1 = norm_layer(embed_dims[0]) if self.parallel_blocks is not None:
self.norm2 = norm_layer(embed_dims[1]) self.norm2 = norm_layer(embed_dims[1])
self.norm3 = norm_layer(embed_dims[2]) self.norm3 = norm_layer(embed_dims[2])
else:
self.norm2 = self.norm3 = None
self.norm4 = norm_layer(embed_dims[3]) self.norm4 = norm_layer(embed_dims[3])
if self.parallel_depth > 0: if self.parallel_depth > 0:
@ -546,6 +544,7 @@ class CoaT(nn.Module):
# Parallel blocks. # Parallel blocks.
for blk in self.parallel_blocks: for blk in self.parallel_blocks:
x2, x3, x4 = self.cpe2(x2, (H2, W2)), self.cpe3(x3, (H3, W3)), self.cpe4(x4, (H4, W4))
x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)]) x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)])
if not torch.jit.is_scripting() and self.return_interm_layers: if not torch.jit.is_scripting() and self.return_interm_layers:
@ -590,52 +589,70 @@ class CoaT(nn.Module):
return x return x
def checkpoint_filter_fn(state_dict, model):
out_dict = {}
for k, v in state_dict.items():
# original model had unused norm layers, removing them requires filtering pretrained checkpoints
if k.startswith('norm1') or \
(model.norm2 is None and k.startswith('norm2')) or \
(model.norm3 is None and k.startswith('norm3')):
continue
out_dict[k] = v
return out_dict
def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
CoaT, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model
@register_model @register_model
def coat_tiny(pretrained=False, **kwargs): def coat_tiny(pretrained=False, **kwargs):
model = CoaT( model_cfg = dict(
patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6,
num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
model.default_cfg = default_cfgs['coat_tiny'] model = _create_coat('coat_tiny', pretrained=pretrained, **model_cfg)
return model return model
@register_model @register_model
def coat_mini(pretrained=False, **kwargs): def coat_mini(pretrained=False, **kwargs):
model = CoaT( model_cfg = dict(
patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6, patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6,
num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
model.default_cfg = default_cfgs['coat_mini'] model = _create_coat('coat_mini', pretrained=pretrained, **model_cfg)
return model return model
@register_model @register_model
def coat_lite_tiny(pretrained=False, **kwargs): def coat_lite_tiny(pretrained=False, **kwargs):
model = CoaT( model_cfg = dict(
patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0, patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0,
num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
# FIXME use builder model = _create_coat('coat_lite_tiny', pretrained=pretrained, **model_cfg)
model.default_cfg = default_cfgs['coat_lite_tiny']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model return model
@register_model @register_model
def coat_lite_mini(pretrained=False, **kwargs): def coat_lite_mini(pretrained=False, **kwargs):
model = CoaT( model_cfg = dict(
patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0, patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0,
num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
# FIXME use builder model = _create_coat('coat_lite_mini', pretrained=pretrained, **model_cfg)
model.default_cfg = default_cfgs['coat_lite_mini']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model return model
@register_model @register_model
def coat_lite_small(pretrained=False, **kwargs): def coat_lite_small(pretrained=False, **kwargs):
model = CoaT( model_cfg = dict(
patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0, patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0,
num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
model.default_cfg = default_cfgs['coat_lite_small'] model = _create_coat('coat_lite_small', pretrained=pretrained, **model_cfg)
return model return model

@ -39,7 +39,7 @@ def _cfg(url='', **kwargs):
return { return {
'url': url, 'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
'first_conv': 'patch_embed.proj', 'classifier': 'head', 'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs **kwargs
} }
@ -317,6 +317,9 @@ class ConViT(nn.Module):
def _create_convit(variant, pretrained=False, **kwargs): def _create_convit(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
return build_model_with_cfg( return build_model_with_cfg(
ConViT, variant, pretrained, ConViT, variant, pretrained,
default_cfg=default_cfgs[variant], default_cfg=default_cfgs[variant],

@ -44,7 +44,7 @@ def load_state_dict(checkpoint_path, use_ema=False):
raise FileNotFoundError() raise FileNotFoundError()
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): def load_checkpoint(model, checkpoint_path, use_ema=False, strict=False):
state_dict = load_state_dict(checkpoint_path, use_ema) state_dict = load_state_dict(checkpoint_path, use_ema)
model.load_state_dict(state_dict, strict=strict) model.load_state_dict(state_dict, strict=strict)
@ -378,7 +378,11 @@ def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
# Overlay default cfg values from `external_default_cfg` if it exists in kwargs # Overlay default cfg values from `external_default_cfg` if it exists in kwargs
overlay_external_default_cfg(default_cfg, kwargs) overlay_external_default_cfg(default_cfg, kwargs)
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg) default_kwarg_names = ('num_classes', 'global_pool', 'in_chans')
if default_cfg.get('fixed_input_size', False):
# if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
default_kwarg_names += ('img_size',)
set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg)
# Filter keyword args for task specific model variants (some 'features only' models, etc.) # Filter keyword args for task specific model variants (some 'features only' models, etc.)
filter_kwargs(kwargs, names=kwargs_filter) filter_kwargs(kwargs, names=kwargs_filter)

@ -1,3 +1,22 @@
""" LeViT
Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference`
- https://arxiv.org/abs/2104.01136
@article{graham2021levit,
title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze},
journal={arXiv preprint arXiv:22104.01136},
year={2021}
}
Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright bellow.
This version combines both conv/linear models and fixes torchscript compatibility.
Modifications by/coyright Copyright 2021 Ross Wightman
"""
# Copyright (c) 2015-present, Facebook, Inc. # Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved. # All rights reserved.
@ -5,10 +24,15 @@
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# Copyright 2020 Ross Wightman, Apache-2.0 License # Copyright 2020 Ross Wightman, Apache-2.0 License
import itertools import itertools
from copy import deepcopy
from functools import partial
import torch import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import to_ntuple
from .vision_transformer import trunc_normal_ from .vision_transformer import trunc_normal_
from .registry import register_model from .registry import register_model
@ -19,70 +43,113 @@ def _cfg(url='', **kwargs):
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head', 'first_conv': 'patch_embed.0.c', 'classifier': ('head.l', 'head_dist.l'),
**kwargs **kwargs
} }
specification = { default_cfgs = dict(
'levit_128s': { levit_128s=_cfg(
'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0, url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'}, ),
'levit_128': { levit_128=_cfg(
'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0, url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'}, ),
'levit_192': { levit_192=_cfg(
'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0, url='https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'}, ),
'levit_256': { levit_256=_cfg(
'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0, url='https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'}, ),
'levit_384': { levit_384=_cfg(
'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1, url='https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'}, ),
} )
model_cfgs = dict(
levit_128s=dict(
embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)),
levit_128=dict(
embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)),
levit_192=dict(
embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)),
levit_256=dict(
embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)),
levit_384=dict(
embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)),
)
__all__ = ['Levit'] __all__ = ['Levit']
@register_model @register_model
def levit_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): def levit_128s(pretrained=False, fuse=False,distillation=True, use_conv=False, **kwargs):
return model_factory(**specification['levit_128s'], num_classes=num_classes, return create_levit(
distillation=distillation, pretrained=pretrained, fuse=fuse) 'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_128(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
return create_levit(
'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_192(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
return create_levit(
'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model @register_model
def levit_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): def levit_256(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
return model_factory(**specification['levit_128'], num_classes=num_classes, return create_levit(
distillation=distillation, pretrained=pretrained, fuse=fuse) 'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model @register_model
def levit_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): def levit_384(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
return model_factory(**specification['levit_192'], num_classes=num_classes, return create_levit(
distillation=distillation, pretrained=pretrained, fuse=fuse) 'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model @register_model
def levit_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): def levit_c_128s(pretrained=False, fuse=False, distillation=True, use_conv=True,**kwargs):
return model_factory(**specification['levit_256'], num_classes=num_classes, return create_levit(
distillation=distillation, pretrained=pretrained, fuse=fuse) 'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model @register_model
def levit_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): def levit_c_128(pretrained=False, fuse=False,distillation=True, use_conv=True, **kwargs):
return model_factory(**specification['levit_384'], num_classes=num_classes, return create_levit(
distillation=distillation, pretrained=pretrained, fuse=fuse) 'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
class ConvNorm(torch.nn.Sequential): @register_model
def levit_c_192(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_256(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_384(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
class ConvNorm(nn.Sequential):
def __init__( def __init__(
self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000): self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000):
super().__init__() super().__init__()
self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) self.add_module('c', nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = torch.nn.BatchNorm2d(b) bn = nn.BatchNorm2d(b)
torch.nn.init.constant_(bn.weight, bn_weight_init) nn.init.constant_(bn.weight, bn_weight_init)
torch.nn.init.constant_(bn.bias, 0) nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn) self.add_module('bn', bn)
@torch.no_grad() @torch.no_grad()
@ -91,7 +158,7 @@ class ConvNorm(torch.nn.Sequential):
w = bn.weight / (bn.running_var + bn.eps) ** 0.5 w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = c.weight * w[:, None, None, None] w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
m = torch.nn.Conv2d( m = nn.Conv2d(
w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
m.weight.data.copy_(w) m.weight.data.copy_(w)
@ -99,13 +166,13 @@ class ConvNorm(torch.nn.Sequential):
return m return m
class LinearNorm(torch.nn.Sequential): class LinearNorm(nn.Sequential):
def __init__(self, a, b, bn_weight_init=1, resolution=-100000): def __init__(self, a, b, bn_weight_init=1, resolution=-100000):
super().__init__() super().__init__()
self.add_module('c', torch.nn.Linear(a, b, bias=False)) self.add_module('c', nn.Linear(a, b, bias=False))
bn = torch.nn.BatchNorm1d(b) bn = nn.BatchNorm1d(b)
torch.nn.init.constant_(bn.weight, bn_weight_init) nn.init.constant_(bn.weight, bn_weight_init)
torch.nn.init.constant_(bn.bias, 0) nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn) self.add_module('bn', bn)
@torch.no_grad() @torch.no_grad()
@ -114,25 +181,24 @@ class LinearNorm(torch.nn.Sequential):
w = bn.weight / (bn.running_var + bn.eps) ** 0.5 w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = l.weight * w[:, None] w = l.weight * w[:, None]
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
m = torch.nn.Linear(w.size(1), w.size(0)) m = nn.Linear(w.size(1), w.size(0))
m.weight.data.copy_(w) m.weight.data.copy_(w)
m.bias.data.copy_(b) m.bias.data.copy_(b)
return m return m
def forward(self, x): def forward(self, x):
l, bn = self._modules.values() x = self.c(x)
x = l(x) return self.bn(x.flatten(0, 1)).reshape_as(x)
return bn(x.flatten(0, 1)).reshape_as(x)
class NormLinear(torch.nn.Sequential): class NormLinear(nn.Sequential):
def __init__(self, a, b, bias=True, std=0.02): def __init__(self, a, b, bias=True, std=0.02):
super().__init__() super().__init__()
self.add_module('bn', torch.nn.BatchNorm1d(a)) self.add_module('bn', nn.BatchNorm1d(a))
l = torch.nn.Linear(a, b, bias=bias) l = nn.Linear(a, b, bias=bias)
trunc_normal_(l.weight, std=std) trunc_normal_(l.weight, std=std)
if bias: if bias:
torch.nn.init.constant_(l.bias, 0) nn.init.constant_(l.bias, 0)
self.add_module('l', l) self.add_module('l', l)
@torch.no_grad() @torch.no_grad()
@ -145,24 +211,24 @@ class NormLinear(torch.nn.Sequential):
b = b @ self.l.weight.T b = b @ self.l.weight.T
else: else:
b = (l.weight @ b[:, None]).view(-1) + self.l.bias b = (l.weight @ b[:, None]).view(-1) + self.l.bias
m = torch.nn.Linear(w.size(1), w.size(0)) m = nn.Linear(w.size(1), w.size(0))
m.weight.data.copy_(w) m.weight.data.copy_(w)
m.bias.data.copy_(b) m.bias.data.copy_(b)
return m return m
def b16(n, activation, resolution=224): def stem_b16(in_chs, out_chs, activation, resolution=224):
return torch.nn.Sequential( return nn.Sequential(
ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution), ConvNorm(in_chs, out_chs // 8, 3, 2, 1, resolution=resolution),
activation(), activation(),
ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), ConvNorm(out_chs // 8, out_chs // 4, 3, 2, 1, resolution=resolution // 2),
activation(), activation(),
ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), ConvNorm(out_chs // 4, out_chs // 2, 3, 2, 1, resolution=resolution // 4),
activation(), activation(),
ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8)) ConvNorm(out_chs // 2, out_chs, 3, 2, 1, resolution=resolution // 8))
class Residual(torch.nn.Module): class Residual(nn.Module):
def __init__(self, m, drop): def __init__(self, m, drop):
super().__init__() super().__init__()
self.m = m self.m = m
@ -176,10 +242,23 @@ class Residual(torch.nn.Module):
return x + self.m(x) return x + self.m(x)
class Attention(torch.nn.Module): class Subsample(nn.Module):
def __init__(self, stride, resolution):
super().__init__()
self.stride = stride
self.resolution = resolution
def forward(self, x):
B, N, C = x.shape
x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride]
return x.reshape(B, -1, C)
class Attention(nn.Module):
def __init__( def __init__(
self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14): self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
self.scale = key_dim ** -0.5 self.scale = key_dim ** -0.5
self.key_dim = key_dim self.key_dim = key_dim
@ -187,11 +266,13 @@ class Attention(torch.nn.Module):
self.d = int(attn_ratio * key_dim) self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio self.attn_ratio = attn_ratio
self.use_conv = use_conv
ln_layer = ConvNorm if self.use_conv else LinearNorm
h = self.dh + nh_kd * 2 h = self.dh + nh_kd * 2
self.qkv = LinearNorm(dim, h, resolution=resolution) self.qkv = ln_layer(dim, h, resolution=resolution)
self.proj = torch.nn.Sequential( self.proj = nn.Sequential(
act_layer(), act_layer(),
LinearNorm(self.dh, dim, bn_weight_init=0, resolution=resolution)) ln_layer(self.dh, dim, bn_weight_init=0, resolution=resolution))
points = list(itertools.product(range(resolution), range(resolution))) points = list(itertools.product(range(resolution), range(resolution)))
N = len(points) N = len(points)
@ -203,68 +284,68 @@ class Attention(torch.nn.Module):
if offset not in attention_offsets: if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets) attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset]) idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N)) self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
self.ab = None
@torch.no_grad() @torch.no_grad()
def train(self, mode=True): def train(self, mode=True):
super().train(mode) super().train(mode)
if mode and hasattr(self, 'ab'): self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
del self.ab
def forward(self, x): # x (B,C,H,W)
if self.use_conv:
B, C, H, W = x.shape
q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
attn = (q.transpose(-2, -1) @ k) * self.scale + ab
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
else: else:
self.ab = self.attention_biases[:, self.attention_bias_idxs] B, N, C = x.shape
qkv = self.qkv(x)
def forward(self, x): # x (B,N,C) q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
B, N, C = x.shape q = q.permute(0, 2, 1, 3)
qkv = self.qkv(x) k = k.permute(0, 2, 1, 3)
q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3) v = v.permute(0, 2, 1, 3)
q = q.permute(0, 2, 1, 3) ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
k = k.permute(0, 2, 1, 3) attn = q @ k.transpose(-2, -1) * self.scale + ab
v = v.permute(0, 2, 1, 3) attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
attn = q @ k.transpose(-2, -1) * self.scale + ab
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
x = self.proj(x) x = self.proj(x)
return x return x
class Subsample(torch.nn.Module): class AttentionSubsample(nn.Module):
def __init__(self, stride, resolution): def __init__(
super().__init__() self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
self.stride = stride act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False):
self.resolution = resolution
def forward(self, x):
B, N, C = x.shape
x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride]
return x.reshape(B, -1, C)
class AttentionSubsample(torch.nn.Module):
def __init__(self, in_dim, out_dim, key_dim, num_heads=8,
attn_ratio=2, act_layer=None, stride=2, resolution=14, resolution_=7):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
self.scale = key_dim ** -0.5 self.scale = key_dim ** -0.5
self.key_dim = key_dim self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim) self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * self.num_heads self.dh = self.d * self.num_heads
self.attn_ratio = attn_ratio self.attn_ratio = attn_ratio
self.resolution_ = resolution_ self.resolution_ = resolution_
self.resolution_2 = resolution_ ** 2 self.resolution_2 = resolution_ ** 2
h = self.dh + nh_kd self.use_conv = use_conv
self.kv = LinearNorm(in_dim, h, resolution=resolution) if self.use_conv:
ln_layer = ConvNorm
sub_layer = partial(nn.AvgPool2d, kernel_size=1, padding=0)
else:
ln_layer = LinearNorm
sub_layer = partial(Subsample, resolution=resolution)
self.q = torch.nn.Sequential( h = self.dh + nh_kd
Subsample(stride, resolution), self.kv = ln_layer(in_dim, h, resolution=resolution)
LinearNorm(in_dim, nh_kd, resolution=resolution_)) self.q = nn.Sequential(
self.proj = torch.nn.Sequential( sub_layer(stride=stride),
ln_layer(in_dim, nh_kd, resolution=resolution_))
self.proj = nn.Sequential(
act_layer(), act_layer(),
LinearNorm(self.dh, out_dim, resolution=resolution_)) ln_layer(self.dh, out_dim, resolution=resolution_))
self.stride = stride self.stride = stride
self.resolution = resolution self.resolution = resolution
@ -283,35 +364,43 @@ class AttentionSubsample(torch.nn.Module):
if offset not in attention_offsets: if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets) attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset]) idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N)) self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
self.ab = None
@torch.no_grad() @torch.no_grad()
def train(self, mode=True): def train(self, mode=True):
super().train(mode) super().train(mode)
if mode and hasattr(self, 'ab'): self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
del self.ab
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x): def forward(self, x):
B, N, C = x.shape if self.use_conv:
k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3) B, C, H, W = x.shape
k = k.permute(0, 2, 1, 3) # BHNC k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2)
v = v.permute(0, 2, 1, 3) # BHNC q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2)
q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
attn = (q.transpose(-2, -1) @ k) * self.scale + ab
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_)
else:
B, N, C = x.shape
k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3)
k = k.permute(0, 2, 1, 3) # BHNC
v = v.permute(0, 2, 1, 3) # BHNC
q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
attn = q @ k.transpose(-2, -1) * self.scale + ab attn = q @ k.transpose(-2, -1) * self.scale + ab
attn = attn.softmax(dim=-1) attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
x = self.proj(x) x = self.proj(x)
return x return x
class Levit(torch.nn.Module): class Levit(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage """ Vision Transformer with support for patch or hybrid CNN input stage
""" """
@ -321,45 +410,63 @@ class Levit(torch.nn.Module):
patch_size=16, patch_size=16,
in_chans=3, in_chans=3,
num_classes=1000, num_classes=1000,
embed_dim=[192], embed_dim=(192,),
key_dim=[64], key_dim=64,
depth=[12], depth=(12,),
num_heads=[3], num_heads=(3,),
attn_ratio=[2], attn_ratio=2,
mlp_ratio=[2], mlp_ratio=2,
hybrid_backbone=None, hybrid_backbone=None,
down_ops=[], down_ops=None,
attn_act_layer=torch.nn.Hardswish, act_layer=nn.Hardswish,
mlp_act_layer=torch.nn.Hardswish, attn_act_layer=nn.Hardswish,
distillation=True, distillation=True,
use_conv=False,
drop_path=0): drop_path=0):
super().__init__() super().__init__()
global FLOPS_COUNTER if isinstance(img_size, tuple):
# FIXME origin impl passes single img/res dim through whole hierarchy,
# not sure this model will be used enough to spend time fixing it.
assert img_size[0] == img_size[1]
img_size = img_size[0]
self.num_classes = num_classes self.num_classes = num_classes
self.num_features = embed_dim[-1] self.num_features = embed_dim[-1]
self.embed_dim = embed_dim self.embed_dim = embed_dim
N = len(embed_dim)
assert len(depth) == len(num_heads) == N
key_dim = to_ntuple(N)(key_dim)
attn_ratio = to_ntuple(N)(attn_ratio)
mlp_ratio = to_ntuple(N)(mlp_ratio)
down_ops = down_ops or (
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
('Subsample', key_dim[0], embed_dim[0] // key_dim[0], 4, 2, 2),
('Subsample', key_dim[0], embed_dim[1] // key_dim[1], 4, 2, 2),
('',)
)
self.distillation = distillation self.distillation = distillation
self.use_conv = use_conv
ln_layer = ConvNorm if self.use_conv else LinearNorm
self.patch_embed = hybrid_backbone self.patch_embed = hybrid_backbone or stem_b16(in_chans, embed_dim[0], activation=act_layer)
self.blocks = [] self.blocks = []
down_ops.append([''])
resolution = img_size // patch_size resolution = img_size // patch_size
for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)): zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)):
for _ in range(dpth): for _ in range(dpth):
self.blocks.append( self.blocks.append(
Residual( Residual(
Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution), Attention(
ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer,
resolution=resolution, use_conv=use_conv),
drop_path)) drop_path))
if mr > 0: if mr > 0:
h = int(ed * mr) h = int(ed * mr)
self.blocks.append( self.blocks.append(
Residual(torch.nn.Sequential( Residual(nn.Sequential(
LinearNorm(ed, h, resolution=resolution), ln_layer(ed, h, resolution=resolution),
mlp_act_layer(), act_layer(),
LinearNorm(h, ed, bn_weight_init=0, resolution=resolution), ln_layer(h, ed, bn_weight_init=0, resolution=resolution),
), drop_path)) ), drop_path))
if do[0] == 'Subsample': if do[0] == 'Subsample':
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
@ -368,22 +475,22 @@ class Levit(torch.nn.Module):
AttentionSubsample( AttentionSubsample(
*embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2],
attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5], attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5],
resolution=resolution, resolution_=resolution_)) resolution=resolution, resolution_=resolution_, use_conv=use_conv))
resolution = resolution_ resolution = resolution_
if do[4] > 0: # mlp_ratio if do[4] > 0: # mlp_ratio
h = int(embed_dim[i + 1] * do[4]) h = int(embed_dim[i + 1] * do[4])
self.blocks.append( self.blocks.append(
Residual(torch.nn.Sequential( Residual(nn.Sequential(
LinearNorm(embed_dim[i + 1], h, resolution=resolution), ln_layer(embed_dim[i + 1], h, resolution=resolution),
mlp_act_layer(), act_layer(),
LinearNorm(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution),
), drop_path)) ), drop_path))
self.blocks = torch.nn.Sequential(*self.blocks) self.blocks = nn.Sequential(*self.blocks)
# Classifier head # Classifier head
self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
if distillation: if distillation:
self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
else: else:
self.head_dist = None self.head_dist = None
@ -393,48 +500,44 @@ class Levit(torch.nn.Module):
def forward(self, x): def forward(self, x):
x = self.patch_embed(x) x = self.patch_embed(x)
x = x.flatten(2).transpose(1, 2) if not self.use_conv:
x = x.flatten(2).transpose(1, 2)
x = self.blocks(x) x = self.blocks(x)
x = x.mean(1) x = x.mean((-2, -1)) if self.use_conv else x.mean(1)
if self.distillation: if self.head_dist is not None:
x = self.head(x), self.head_dist(x) x, x_dist = self.head(x), self.head_dist(x)
if not self.training: if self.training and not torch.jit.is_scripting():
x = (x[0] + x[1]) / 2 return x, x_dist
else:
# during inference, return the average of both classifier predictions
return (x + x_dist) / 2
else: else:
x = self.head(x) x = self.head(x)
return x return x
def model_factory(C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse): def checkpoint_filter_fn(state_dict, model):
embed_dim = [int(x) for x in C.split('_')] if 'model' in state_dict:
num_heads = [int(x) for x in N.split('_')] # For deit models
depth = [int(x) for x in X.split('_')] state_dict = state_dict['model']
act = torch.nn.Hardswish D = model.state_dict()
model = Levit( for k in state_dict.keys():
patch_size=16, if D[k].ndim == 4 and state_dict[k].ndim == 2:
embed_dim=embed_dim, state_dict[k] = state_dict[k][:, :, None, None]
num_heads=num_heads, return state_dict
key_dim=[D] * 3,
depth=depth,
attn_ratio=[2, 2, 2], def create_levit(variant, pretrained=False, default_cfg=None, fuse=False, **kwargs):
mlp_ratio=[2, 2, 2], if kwargs.get('features_only', None):
down_ops=[ raise RuntimeError('features_only not implemented for Vision Transformer models.')
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
['Subsample', D, embed_dim[0] // D, 4, 2, 2], model_cfg = dict(**model_cfgs[variant], **kwargs)
['Subsample', D, embed_dim[1] // D, 4, 2, 2], model = build_model_with_cfg(
], Levit, variant, pretrained,
attn_act_layer=act, default_cfg=default_cfgs[variant],
mlp_act_layer=act, pretrained_filter_fn=checkpoint_filter_fn,
hybrid_backbone=b16(embed_dim[0], activation=act), **model_cfg)
num_classes=num_classes,
drop_path=drop_path,
distillation=distillation
)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(weights, map_location='cpu')
model.load_state_dict(checkpoint['model'])
#if fuse: #if fuse:
# utils.replace_batchnorm(model) # utils.replace_batchnorm(model)
return model return model

@ -1,400 +0,0 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
# Modified from
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# Copyright 2020 Ross Wightman, Apache-2.0 License
import itertools
import torch
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from .vision_transformer import trunc_normal_
from .registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
specification = {
'levit_c_128s': {
'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'},
'levit_c_128': {
'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'},
'levit_c_192': {
'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'},
'levit_c_256': {
'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'},
'levit_c_384': {
'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'},
}
__all__ = ['Levit']
@register_model
def levit_c_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_c_128s'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_c_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_c_128'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_c_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_c_192'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_c_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_c_256'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_c_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_c_384'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
class ConvNorm(torch.nn.Sequential):
def __init__(
self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000):
super().__init__()
self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = torch.nn.BatchNorm2d(b)
torch.nn.init.constant_(bn.weight, bn_weight_init)
torch.nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
@torch.no_grad()
def fuse(self):
c, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / \
(bn.running_var + bn.eps) ** 0.5
m = torch.nn.Conv2d(
w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
class NormLinear(torch.nn.Sequential):
def __init__(self, a, b, bias=True, std=0.02):
super().__init__()
self.add_module('bn', torch.nn.BatchNorm1d(a))
l = torch.nn.Linear(a, b, bias=bias)
trunc_normal_(l.weight, std=std)
if bias:
torch.nn.init.constant_(l.bias, 0)
self.add_module('l', l)
@torch.no_grad()
def fuse(self):
bn, l = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
b = bn.bias - self.bn.running_mean * \
self.bn.weight / (bn.running_var + bn.eps) ** 0.5
w = l.weight * w[None, :]
if l.bias is None:
b = b @ self.l.weight.T
else:
b = (l.weight @ b[:, None]).view(-1) + self.l.bias
m = torch.nn.Linear(w.size(1), w.size(0))
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
def b16(n, activation, resolution=224):
return torch.nn.Sequential(
ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution),
activation(),
ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2),
activation(),
ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4),
activation(),
ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8))
class Residual(torch.nn.Module):
def __init__(self, m, drop):
super().__init__()
self.m = m
self.drop = drop
def forward(self, x):
if self.training and self.drop > 0:
return x + self.m(x) * torch.rand(
x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach()
else:
return x + self.m(x)
class Attention(torch.nn.Module):
def __init__(self, dim, key_dim, num_heads=8,
attn_ratio=4, act_layer=None, resolution=14):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio
h = self.dh + nh_kd * 2
self.qkv = ConvNorm(dim, h, resolution=resolution)
self.proj = torch.nn.Sequential(
act_layer(),
ConvNorm(self.dh, dim, bn_weight_init=0, resolution=resolution))
points = list(itertools.product(range(resolution), range(resolution)))
N = len(points)
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
self.ab = None
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and self.ab is not None:
self.ab = None
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x): # x (B,C,H,W)
B, C, H, W = x.shape
q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
attn = (q.transpose(-2, -1) @ k) * self.scale + ab
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
x = self.proj(x)
return x
class AttentionSubsample(torch.nn.Module):
def __init__(
self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
act_layer=None, stride=2, resolution=14, resolution_=7):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * self.num_heads
self.attn_ratio = attn_ratio
self.resolution_ = resolution_
self.resolution_2 = resolution_ ** 2
h = self.dh + nh_kd
self.kv = ConvNorm(in_dim, h, resolution=resolution)
self.q = torch.nn.Sequential(
torch.nn.AvgPool2d(1, stride, 0),
ConvNorm(in_dim, nh_kd, resolution=resolution_))
self.proj = torch.nn.Sequential(
act_layer(),
ConvNorm(self.d * num_heads, out_dim, resolution=resolution_))
self.stride = stride
self.resolution = resolution
points = list(itertools.product(range(resolution), range(resolution)))
points_ = list(itertools.product(range(resolution_), range(resolution_)))
N = len(points)
N_ = len(points_)
attention_offsets = {}
idxs = []
for p1 in points_:
for p2 in points:
size = 1
offset = (
abs(p1[0] * stride - p2[0] + (size - 1) / 2),
abs(p1[1] * stride - p2[1] + (size - 1) / 2))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
self.ab = None
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and self.ab is not None:
self.ab = None
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x):
B, C, H, W = x.shape
k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2)
q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
attn = (q.transpose(-2, -1) @ k) * self.scale + ab
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_)
x = self.proj(x)
return x
class Levit(torch.nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=[192],
key_dim=[64],
depth=[12],
num_heads=[3],
attn_ratio=[2],
mlp_ratio=[2],
hybrid_backbone=None,
down_ops=[],
attn_act_layer=torch.nn.Hardswish,
mlp_act_layer=torch.nn.Hardswish,
distillation=True,
drop_path=0):
super().__init__()
self.num_classes = num_classes
self.num_features = embed_dim[-1]
self.embed_dim = embed_dim
self.distillation = distillation
self.patch_embed = hybrid_backbone
self.blocks = []
down_ops.append([''])
resolution = img_size // patch_size
for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)):
for _ in range(dpth):
self.blocks.append(
Residual(
Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution),
drop_path))
if mr > 0:
h = int(ed * mr)
self.blocks.append(
Residual(torch.nn.Sequential(
ConvNorm(ed, h, resolution=resolution),
mlp_act_layer(),
ConvNorm(h, ed, bn_weight_init=0, resolution=resolution),
), drop_path))
if do[0] == 'Subsample':
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
resolution_ = (resolution - 1) // do[5] + 1
self.blocks.append(
AttentionSubsample(
*embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], attn_ratio=do[3],
act_layer=attn_act_layer, stride=do[5],
resolution=resolution, resolution_=resolution_))
resolution = resolution_
if do[4] > 0: # mlp_ratio
h = int(embed_dim[i + 1] * do[4])
self.blocks.append(
Residual(torch.nn.Sequential(
ConvNorm(embed_dim[i + 1], h, resolution=resolution),
mlp_act_layer(),
ConvNorm(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution),
), drop_path))
self.blocks = torch.nn.Sequential(*self.blocks)
# Classifier head
self.head = NormLinear(
embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
if distillation:
self.head_dist = NormLinear(
embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
@torch.jit.ignore
def no_weight_decay(self):
return {x for x in self.state_dict().keys() if 'attention_biases' in x}
def forward(self, x):
x = self.patch_embed(x)
x = self.blocks(x)
x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
if self.distillation:
x = self.head(x), self.head_dist(x)
if not self.training:
x = (x[0] + x[1]) / 2
else:
x = self.head(x)
return x
def model_factory(C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse):
embed_dim = [int(x) for x in C.split('_')]
num_heads = [int(x) for x in N.split('_')]
depth = [int(x) for x in X.split('_')]
act = torch.nn.Hardswish
model = Levit(
patch_size=16,
embed_dim=embed_dim,
num_heads=num_heads,
key_dim=[D] * 3,
depth=depth,
attn_ratio=[2, 2, 2],
mlp_ratio=[2, 2, 2],
down_ops=[
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
['Subsample', D, embed_dim[0] // D, 4, 2, 2],
['Subsample', D, embed_dim[1] // D, 4, 2, 2],
],
attn_act_layer=act,
mlp_act_layer=act,
hybrid_backbone=b16(embed_dim[0], activation=act),
num_classes=num_classes,
drop_path=drop_path,
distillation=distillation
)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
weights, map_location='cpu')
d = checkpoint['model']
D = model.state_dict()
for k in d.keys():
if D[k].shape != d[k].shape:
d[k] = d[k][:, :, None, None]
model.load_state_dict(d)
#if fuse:
# utils.replace_batchnorm(model)
return model

@ -273,25 +273,14 @@ def _init_weights(m, n: str, head_bias: float = 0.):
nn.init.ones_(m.weight) nn.init.ones_(m.weight)
def _create_mixer(variant, pretrained=False, default_cfg=None, **kwargs): def _create_mixer(variant, pretrained=False, **kwargs):
if default_cfg is None:
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
if kwargs.get('features_only', None): if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for MLP-Mixer models.') raise RuntimeError('features_only not implemented for MLP-Mixer models.')
model = build_model_with_cfg( model = build_model_with_cfg(
MlpMixer, variant, pretrained, MlpMixer, variant, pretrained,
default_cfg=default_cfg, default_cfg=default_cfgs[variant],
img_size=img_size,
num_classes=num_classes,
**kwargs) **kwargs)
return model return model

@ -251,24 +251,14 @@ def checkpoint_filter_fn(state_dict, model):
def _create_pit(variant, pretrained=False, **kwargs): def _create_pit(variant, pretrained=False, **kwargs):
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
img_size = kwargs.pop('img_size', default_img_size)
num_classes = kwargs.pop('num_classes', default_num_classes)
if kwargs.get('features_only', None): if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.') raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg( model = build_model_with_cfg(
PoolingVisionTransformer, variant, pretrained, PoolingVisionTransformer, variant, pretrained,
default_cfg=default_cfg, default_cfg=default_cfgs[variant],
img_size=img_size,
num_classes=num_classes,
pretrained_filter_fn=checkpoint_filter_fn, pretrained_filter_fn=checkpoint_filter_fn,
**kwargs) **kwargs)
return model return model

@ -12,7 +12,7 @@ import torch.nn as nn
from functools import partial from functools import partial
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import load_pretrained from timm.models.helpers import build_model_with_cfg
from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers import Mlp, DropPath, trunc_normal_
from timm.models.layers.helpers import to_2tuple from timm.models.layers.helpers import to_2tuple
from timm.models.registry import register_model from timm.models.registry import register_model
@ -238,24 +238,31 @@ def checkpoint_filter_fn(state_dict, model):
return state_dict return state_dict
def _create_tnt(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
TNT, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model
@register_model @register_model
def tnt_s_patch16_224(pretrained=False, **kwargs): def tnt_s_patch16_224(pretrained=False, **kwargs):
model = TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, model_cfg = dict(
patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4,
qkv_bias=False, **kwargs) qkv_bias=False, **kwargs)
model.default_cfg = default_cfgs['tnt_s_patch16_224'] model = _create_tnt('tnt_s_patch16_224', pretrained=pretrained, **model_cfg)
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3),
filter_fn=checkpoint_filter_fn)
return model return model
@register_model @register_model
def tnt_b_patch16_224(pretrained=False, **kwargs): def tnt_b_patch16_224(pretrained=False, **kwargs):
model = TNT(patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4, model_cfg = dict(
patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4,
qkv_bias=False, **kwargs) qkv_bias=False, **kwargs)
model.default_cfg = default_cfgs['tnt_b_patch16_224'] model = _create_tnt('tnt_b_patch16_224', pretrained=pretrained, **model_cfg)
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model return model

@ -33,7 +33,7 @@ def _cfg(url='', **kwargs):
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head', 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head',
**kwargs **kwargs
} }
@ -361,25 +361,14 @@ class Twins(nn.Module):
return x return x
def _create_twins(variant, pretrained=False, default_cfg=None, **kwargs): def _create_twins(variant, pretrained=False, **kwargs):
if default_cfg is None:
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
if kwargs.get('features_only', None): if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.') raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg( model = build_model_with_cfg(
Twins, variant, pretrained, Twins, variant, pretrained,
default_cfg=default_cfg, default_cfg=default_cfgs[variant],
img_size=img_size,
num_classes=num_classes,
**kwargs) **kwargs)
return model return model

@ -1,3 +1,12 @@
""" Visformer
Paper: Visformer: The Vision-friendly Transformer - https://arxiv.org/abs/2104.12533
From original at https://github.com/danczs/Visformer
"""
from copy import deepcopy
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -22,6 +31,12 @@ def _cfg(url='', **kwargs):
} }
default_cfgs = dict(
visformer_tiny=_cfg(),
visformer_small=_cfg(),
)
class LayerNormBHWC(nn.LayerNorm): class LayerNormBHWC(nn.LayerNorm):
def __init__(self, dim): def __init__(self, dim):
super().__init__(dim) super().__init__(dim)
@ -300,87 +315,97 @@ class Visformer(nn.Module):
return x return x
def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
Visformer, variant, pretrained,
default_cfg=default_cfgs[variant],
**kwargs)
return model
@register_model @register_model
def visformer_tiny(pretrained=False, **kwargs): def visformer_tiny(pretrained=False, **kwargs):
model = Visformer( model_cfg = dict(
img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8, img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8,
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
embed_norm=nn.BatchNorm2d, **kwargs) embed_norm=nn.BatchNorm2d, **kwargs)
model.default_cfg = _cfg() model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg)
return model return model
@register_model @register_model
def visformer_small(pretrained=False, **kwargs): def visformer_small(pretrained=False, **kwargs):
model = Visformer( model_cfg = dict(
img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8, img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8,
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
embed_norm=nn.BatchNorm2d, **kwargs) embed_norm=nn.BatchNorm2d, **kwargs)
model.default_cfg = _cfg() model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg)
return model return model
@register_model # @register_model
def visformer_net1(pretrained=False, **kwargs): # def visformer_net1(pretrained=False, **kwargs):
model = Visformer( # model = Visformer(
init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', # init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
spatial_conv='000', vit_stem=True, conv_init=True, **kwargs) # spatial_conv='000', vit_stem=True, conv_init=True, **kwargs)
model.default_cfg = _cfg() # model.default_cfg = _cfg()
return model # return model
#
#
@register_model # @register_model
def visformer_net2(pretrained=False, **kwargs): # def visformer_net2(pretrained=False, **kwargs):
model = Visformer( # model = Visformer(
init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', # init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) # spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
model.default_cfg = _cfg() # model.default_cfg = _cfg()
return model # return model
#
#
@register_model # @register_model
def visformer_net3(pretrained=False, **kwargs): # def visformer_net3(pretrained=False, **kwargs):
model = Visformer( # model = Visformer(
init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', # init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) # spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
model.default_cfg = _cfg() # model.default_cfg = _cfg()
return model # return model
#
#
@register_model # @register_model
def visformer_net4(pretrained=False, **kwargs): # def visformer_net4(pretrained=False, **kwargs):
model = Visformer( # model = Visformer(
init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', # init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) # spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
model.default_cfg = _cfg() # model.default_cfg = _cfg()
return model # return model
#
#
@register_model # @register_model
def visformer_net5(pretrained=False, **kwargs): # def visformer_net5(pretrained=False, **kwargs):
model = Visformer( # model = Visformer(
init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', # init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
spatial_conv='111', vit_stem=False, conv_init=True, **kwargs) # spatial_conv='111', vit_stem=False, conv_init=True, **kwargs)
model.default_cfg = _cfg() # model.default_cfg = _cfg()
return model # return model
#
#
@register_model # @register_model
def visformer_net6(pretrained=False, **kwargs): # def visformer_net6(pretrained=False, **kwargs):
model = Visformer( # model = Visformer(
init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', # init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) # pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
model.default_cfg = _cfg() # model.default_cfg = _cfg()
return model # return model
#
#
@register_model # @register_model
def visformer_net7(pretrained=False, **kwargs): # def visformer_net7(pretrained=False, **kwargs):
model = Visformer( # model = Visformer(
init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000', # init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000',
pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) # pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
model.default_cfg = _cfg() # model.default_cfg = _cfg()
return model # return model

@ -387,21 +387,20 @@ def checkpoint_filter_fn(state_dict, model):
v = v.reshape(O, -1, H, W) v = v.reshape(O, -1, H, W)
elif k == 'pos_embed' and v.shape != model.pos_embed.shape: elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
# To resize pos embedding when using model at different size from pretrained weights # To resize pos embedding when using model at different size from pretrained weights
v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1), v = resize_pos_embed(
model.patch_embed.grid_size) v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
out_dict[k] = v out_dict[k] = v
return out_dict return out_dict
def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
if default_cfg is None: default_cfg = default_cfg or default_cfgs[variant]
default_cfg = deepcopy(default_cfgs[variant]) if kwargs.get('features_only', None):
overlay_external_default_cfg(default_cfg, kwargs) raise RuntimeError('features_only not implemented for Vision Transformer models.')
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes) # NOTE this extra code to support handling of repr size for in21k pretrained models
img_size = kwargs.pop('img_size', default_img_size) default_num_classes = default_cfg['num_classes']
num_classes = kwargs.get('num_classes', default_num_classes)
repr_size = kwargs.pop('representation_size', None) repr_size = kwargs.pop('representation_size', None)
if repr_size is not None and num_classes != default_num_classes: if repr_size is not None and num_classes != default_num_classes:
# Remove representation layer if fine-tuning. This may not always be the desired action, # Remove representation layer if fine-tuning. This may not always be the desired action,
@ -409,18 +408,12 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw
_logger.warning("Removing representation layer for fine-tuning.") _logger.warning("Removing representation layer for fine-tuning.")
repr_size = None repr_size = None
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg( model = build_model_with_cfg(
VisionTransformer, variant, pretrained, VisionTransformer, variant, pretrained,
default_cfg=default_cfg, default_cfg=default_cfg,
img_size=img_size,
num_classes=num_classes,
representation_size=repr_size, representation_size=repr_size,
pretrained_filter_fn=checkpoint_filter_fn, pretrained_filter_fn=checkpoint_filter_fn,
**kwargs) **kwargs)
return model return model

@ -27,7 +27,7 @@ def _cfg(url='', **kwargs):
return { return {
'url': url, 'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head', 'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head',
**kwargs **kwargs
@ -107,11 +107,10 @@ class HybridEmbed(nn.Module):
def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs): def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
default_cfg = deepcopy(default_cfgs[variant])
embed_layer = partial(HybridEmbed, backbone=backbone) embed_layer = partial(HybridEmbed, backbone=backbone)
kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
return _create_vision_transformer( return _create_vision_transformer(
variant, pretrained=pretrained, default_cfg=default_cfg, embed_layer=embed_layer, **kwargs) variant, pretrained=pretrained, embed_layer=embed_layer, default_cfg=default_cfgs[variant], **kwargs)
def _resnetv2(layers=(3, 4, 9), **kwargs): def _resnetv2(layers=(3, 4, 9), **kwargs):

Loading…
Cancel
Save