Expand scope of testing for non-std vision transformer / mlp models. Some related cleanup and create fn cleanup for all vision transformer and mlp models. More CoaT weights.

pull/637/head
Ross Wightman 4 years ago
parent 18bf520ad1
commit bfc72f75d3

@ -26,29 +26,41 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
EXCLUDE_FILTERS = [
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*',
'*resnetrs350*', '*resnetrs420*'] + NON_STD_FILTERS
'*resnetrs350*', '*resnetrs420*']
else:
EXCLUDE_FILTERS = NON_STD_FILTERS
EXCLUDE_FILTERS = []
MAX_FWD_SIZE = 384
MAX_BWD_SIZE = 128
TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
TARGET_BWD_SIZE = 128
MAX_BWD_SIZE = 384
MAX_FWD_FEAT_SIZE = 448
def _get_input_size(model, target=None):
default_cfg = model.default_cfg
input_size = default_cfg['input_size']
if 'fixed_input_size' in default_cfg and default_cfg['fixed_input_size']:
return input_size
if 'min_input_size' in default_cfg:
if target and max(input_size) > target:
input_size = default_cfg['min_input_size']
else:
if target and max(input_size) > target:
input_size = tuple([min(x, target) for x in input_size])
return input_size
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-NUM_NON_STD]))
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward(model_name, batch_size):
"""Run a single forward pass with each model"""
model = create_model(model_name, pretrained=False)
model.eval()
input_size = model.default_cfg['input_size']
if any([x > MAX_FWD_SIZE for x in input_size]):
if is_model_default_key(model_name, 'fixed_input_size'):
pytest.skip("Fixed input size model > limit.")
# cap forward test at max res 384 * 384 to keep resource down
input_size = tuple([min(x, MAX_FWD_SIZE) for x in input_size])
input_size = _get_input_size(model, TARGET_FWD_SIZE)
if max(input_size) > MAX_FWD_SIZE:
pytest.skip("Fixed input size model > limit.")
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
@ -63,20 +75,16 @@ def test_model_backward(model_name, batch_size):
"""Run a single forward pass with each model"""
model = create_model(model_name, pretrained=False, num_classes=42)
num_params = sum([x.numel() for x in model.parameters()])
model.eval()
model.train()
input_size = model.default_cfg['input_size']
if not is_model_default_key(model_name, 'fixed_input_size'):
min_input_size = get_model_default_value(model_name, 'min_input_size')
if min_input_size is not None:
input_size = min_input_size
else:
if any([x > MAX_BWD_SIZE for x in input_size]):
# cap backward test at 128 * 128 to keep resource usage down
input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
input_size = _get_input_size(model, TARGET_BWD_SIZE)
if max(input_size) > MAX_BWD_SIZE:
pytest.skip("Fixed input size model > limit.")
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
outputs.mean().backward()
for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}'
@ -168,12 +176,9 @@ def test_model_forward_torchscript(model_name, batch_size):
model = create_model(model_name, pretrained=False)
model.eval()
if has_model_default_key(model_name, 'fixed_input_size'):
input_size = get_model_default_value(model_name, 'input_size')
elif has_model_default_key(model_name, 'min_input_size'):
input_size = get_model_default_value(model_name, 'min_input_size')
else:
input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already...
input_size = _get_input_size(model, 128)
if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional
pytest.skip("Fixed input size model > limit.")
model = torch.jit.script(model)
outputs = model(torch.randn((batch_size, *input_size)))
@ -184,7 +189,7 @@ def test_model_forward_torchscript(model_name, batch_size):
EXCLUDE_FEAT_FILTERS = [
'*pruned*', # hopefully fix at some point
]
] + NON_STD_FILTERS
if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
EXCLUDE_FEAT_FILTERS += ['*resnext101_32x32d', '*resnext101_32x16d']
@ -200,12 +205,9 @@ def test_model_forward_features(model_name, batch_size):
expected_channels = model.feature_info.channels()
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
if has_model_default_key(model_name, 'fixed_input_size'):
input_size = get_model_default_value(model_name, 'input_size')
elif has_model_default_key(model_name, 'min_input_size'):
input_size = get_model_default_value(model_name, 'min_input_size')
else:
input_size = (3, 96, 96) # jit compile is already a bit slow and we've tested normal res already...
input_size = _get_input_size(model, 96) # jit compile is already a bit slow and we've tested normal res already...
if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional
pytest.skip("Fixed input size model > limit.")
outputs = model(torch.randn((batch_size, *input_size)))
assert len(expected_channels) == len(outputs)

@ -16,8 +16,8 @@ from .hrnet import *
from .inception_resnet_v2 import *
from .inception_v3 import *
from .inception_v4 import *
from .levitc import *
from .levit import *
#from .levit import *
from .mlp_mixer import *
from .mobilenetv3 import *
from .nasnet import *

@ -306,26 +306,15 @@ def checkpoint_filter_fn(state_dict, model=None):
return checkpoint_no_module
def _create_cait(variant, pretrained=False, default_cfg=None, **kwargs):
if default_cfg is None:
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
def _create_cait(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
Cait, variant, pretrained,
default_cfg=default_cfg,
img_size=img_size,
num_classes=num_classes,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model

@ -7,19 +7,19 @@ Official CoaT code at: https://github.com/mlpc-ucsd/CoaT
Modified from timm/models/vision_transformer.py
"""
from typing import Tuple, Dict, Any, Optional
from copy import deepcopy
from functools import partial
from typing import Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import load_pretrained
from timm.models.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from .registry import register_model
from functools import partial
from torch import nn
__all__ = [
"coat_tiny",
@ -34,7 +34,7 @@ def _cfg_coat(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed1.proj', 'classifier': 'head',
**kwargs
@ -42,15 +42,21 @@ def _cfg_coat(url='', **kwargs):
default_cfgs = {
'coat_tiny': _cfg_coat(),
'coat_mini': _cfg_coat(),
'coat_tiny': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_tiny-473c2a20.pth'
),
'coat_mini': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_mini-2c6baf49.pth'
),
'coat_lite_tiny': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_tiny-461b07a7.pth'
),
'coat_lite_mini': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_mini-d7842000.pth'
),
'coat_lite_small': _cfg_coat(),
'coat_lite_small': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_small-fea1d5a1.pth'
),
}
@ -120,11 +126,11 @@ class ConvRelPosEnc(nn.Module):
class FactorAtt_ConvRelPosEnc(nn.Module):
""" Factorized attention with convolutional relative position encoding class. """
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., shared_crpe=None):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., shared_crpe=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used.
@ -190,9 +196,8 @@ class ConvPosEnc(nn.Module):
class SerialBlock(nn.Module):
""" Serial block class.
Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
shared_cpe=None, shared_crpe=None):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_cpe=None, shared_crpe=None):
super().__init__()
# Conv-Attention.
@ -200,8 +205,7 @@ class SerialBlock(nn.Module):
self.norm1 = norm_layer(dim)
self.factoratt_crpe = FactorAtt_ConvRelPosEnc(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpe)
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpe)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
# MLP.
@ -226,27 +230,24 @@ class SerialBlock(nn.Module):
class ParallelBlock(nn.Module):
""" Parallel block class. """
def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
shared_cpes=None, shared_crpes=None):
def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_crpes=None):
super().__init__()
# Conv-Attention.
self.cpes = shared_cpes
self.norm12 = norm_layer(dims[1])
self.norm13 = norm_layer(dims[2])
self.norm14 = norm_layer(dims[3])
self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc(
dims[1], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
dims[1], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpes[1]
)
self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc(
dims[2], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
dims[2], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpes[2]
)
self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc(
dims[3], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
dims[3], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpes[3]
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -262,15 +263,15 @@ class ParallelBlock(nn.Module):
self.mlp2 = self.mlp3 = self.mlp4 = Mlp(
in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def upsample(self, x, factor, size):
def upsample(self, x, factor: float, size: Tuple[int, int]):
""" Feature map up-sampling. """
return self.interpolate(x, scale_factor=factor, size=size)
def downsample(self, x, factor, size):
def downsample(self, x, factor: float, size: Tuple[int, int]):
""" Feature map down-sampling. """
return self.interpolate(x, scale_factor=1.0/factor, size=size)
def interpolate(self, x, scale_factor, size):
def interpolate(self, x, scale_factor: float, size: Tuple[int, int]):
""" Feature map interpolation. """
B, N, C = x.shape
H, W = size
@ -280,33 +281,28 @@ class ParallelBlock(nn.Module):
img_tokens = x[:, 1:, :]
img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W)
img_tokens = F.interpolate(img_tokens, scale_factor=scale_factor, mode='bilinear')
img_tokens = F.interpolate(
img_tokens, scale_factor=scale_factor, recompute_scale_factor=False, mode='bilinear', align_corners=False)
img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2)
out = torch.cat((cls_token, img_tokens), dim=1)
return out
def forward(self, x1, x2, x3, x4, sizes):
_, (H2, W2), (H3, W3), (H4, W4) = sizes
# Conv-Attention.
x2 = self.cpes[1](x2, size=(H2, W2)) # Note: x1 is ignored.
x3 = self.cpes[2](x3, size=(H3, W3))
x4 = self.cpes[3](x4, size=(H4, W4))
def forward(self, x1, x2, x3, x4, sizes: List[Tuple[int, int]]):
_, S2, S3, S4 = sizes
cur2 = self.norm12(x2)
cur3 = self.norm13(x3)
cur4 = self.norm14(x4)
cur2 = self.factoratt_crpe2(cur2, size=(H2, W2))
cur3 = self.factoratt_crpe3(cur3, size=(H3, W3))
cur4 = self.factoratt_crpe4(cur4, size=(H4, W4))
upsample3_2 = self.upsample(cur3, factor=2, size=(H3, W3))
upsample4_3 = self.upsample(cur4, factor=2, size=(H4, W4))
upsample4_2 = self.upsample(cur4, factor=4, size=(H4, W4))
downsample2_3 = self.downsample(cur2, factor=2, size=(H2, W2))
downsample3_4 = self.downsample(cur3, factor=2, size=(H3, W3))
downsample2_4 = self.downsample(cur2, factor=4, size=(H2, W2))
cur2 = self.factoratt_crpe2(cur2, size=S2)
cur3 = self.factoratt_crpe3(cur3, size=S3)
cur4 = self.factoratt_crpe4(cur4, size=S4)
upsample3_2 = self.upsample(cur3, factor=2., size=S3)
upsample4_3 = self.upsample(cur4, factor=2., size=S4)
upsample4_2 = self.upsample(cur4, factor=4., size=S4)
downsample2_3 = self.downsample(cur2, factor=2., size=S2)
downsample3_4 = self.downsample(cur3, factor=2., size=S3)
downsample2_4 = self.downsample(cur2, factor=4., size=S2)
cur2 = cur2 + upsample3_2 + upsample4_2
cur3 = cur3 + upsample4_3 + downsample2_3
cur4 = cur4 + downsample3_4 + downsample2_4
@ -330,11 +326,11 @@ class ParallelBlock(nn.Module):
class CoaT(nn.Module):
""" CoaT class. """
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[0, 0, 0, 0],
serial_depths=[0, 0, 0, 0], parallel_depth=0,
num_heads=0, mlp_ratios=[0, 0, 0, 0], qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
return_interm_layers=False, out_features = None, crpe_window=None, **kwargs):
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=(0, 0, 0, 0),
serial_depths=(0, 0, 0, 0), parallel_depth=0, num_heads=0, mlp_ratios=(0, 0, 0, 0), qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
return_interm_layers=False, out_features=None, crpe_window=None, **kwargs):
super().__init__()
crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
self.return_interm_layers = return_interm_layers
@ -342,17 +338,18 @@ class CoaT(nn.Module):
self.num_classes = num_classes
# Patch embeddings.
img_size = to_2tuple(img_size)
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dims[0], norm_layer=nn.LayerNorm)
self.patch_embed2 = PatchEmbed(
img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0],
img_size=[x // 4 for x in img_size], patch_size=2, in_chans=embed_dims[0],
embed_dim=embed_dims[1], norm_layer=nn.LayerNorm)
self.patch_embed3 = PatchEmbed(
img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1],
img_size=[x // 8 for x in img_size], patch_size=2, in_chans=embed_dims[1],
embed_dim=embed_dims[2], norm_layer=nn.LayerNorm)
self.patch_embed4 = PatchEmbed(
img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2],
img_size=[x // 16 for x in img_size], patch_size=2, in_chans=embed_dims[2],
embed_dim=embed_dims[3], norm_layer=nn.LayerNorm)
# Class tokens.
@ -380,7 +377,7 @@ class CoaT(nn.Module):
# Serial blocks 1.
self.serial_blocks1 = nn.ModuleList([
SerialBlock(
dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe1, shared_crpe=self.crpe1
)
@ -390,7 +387,7 @@ class CoaT(nn.Module):
# Serial blocks 2.
self.serial_blocks2 = nn.ModuleList([
SerialBlock(
dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe2, shared_crpe=self.crpe2
)
@ -400,7 +397,7 @@ class CoaT(nn.Module):
# Serial blocks 3.
self.serial_blocks3 = nn.ModuleList([
SerialBlock(
dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe3, shared_crpe=self.crpe3
)
@ -410,7 +407,7 @@ class CoaT(nn.Module):
# Serial blocks 4.
self.serial_blocks4 = nn.ModuleList([
SerialBlock(
dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe4, shared_crpe=self.crpe4
)
@ -422,10 +419,9 @@ class CoaT(nn.Module):
if self.parallel_depth > 0:
self.parallel_blocks = nn.ModuleList([
ParallelBlock(
dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, qk_scale=qk_scale,
dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpes=[self.cpe1, self.cpe2, self.cpe3, self.cpe4],
shared_crpes=[self.crpe1, self.crpe2, self.crpe3, self.crpe4]
shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4)
)
for _ in range(parallel_depth)]
)
@ -434,9 +430,11 @@ class CoaT(nn.Module):
# Classification head(s).
if not self.return_interm_layers:
self.norm1 = norm_layer(embed_dims[0])
self.norm2 = norm_layer(embed_dims[1])
self.norm3 = norm_layer(embed_dims[2])
if self.parallel_blocks is not None:
self.norm2 = norm_layer(embed_dims[1])
self.norm3 = norm_layer(embed_dims[2])
else:
self.norm2 = self.norm3 = None
self.norm4 = norm_layer(embed_dims[3])
if self.parallel_depth > 0:
@ -546,6 +544,7 @@ class CoaT(nn.Module):
# Parallel blocks.
for blk in self.parallel_blocks:
x2, x3, x4 = self.cpe2(x2, (H2, W2)), self.cpe3(x3, (H3, W3)), self.cpe4(x4, (H4, W4))
x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)])
if not torch.jit.is_scripting() and self.return_interm_layers:
@ -590,52 +589,70 @@ class CoaT(nn.Module):
return x
def checkpoint_filter_fn(state_dict, model):
out_dict = {}
for k, v in state_dict.items():
# original model had unused norm layers, removing them requires filtering pretrained checkpoints
if k.startswith('norm1') or \
(model.norm2 is None and k.startswith('norm2')) or \
(model.norm3 is None and k.startswith('norm3')):
continue
out_dict[k] = v
return out_dict
def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
CoaT, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model
@register_model
def coat_tiny(pretrained=False, **kwargs):
model = CoaT(
model_cfg = dict(
patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6,
num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
model.default_cfg = default_cfgs['coat_tiny']
model = _create_coat('coat_tiny', pretrained=pretrained, **model_cfg)
return model
@register_model
def coat_mini(pretrained=False, **kwargs):
model = CoaT(
model_cfg = dict(
patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6,
num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
model.default_cfg = default_cfgs['coat_mini']
model = _create_coat('coat_mini', pretrained=pretrained, **model_cfg)
return model
@register_model
def coat_lite_tiny(pretrained=False, **kwargs):
model = CoaT(
model_cfg = dict(
patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0,
num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
# FIXME use builder
model.default_cfg = default_cfgs['coat_lite_tiny']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model = _create_coat('coat_lite_tiny', pretrained=pretrained, **model_cfg)
return model
@register_model
def coat_lite_mini(pretrained=False, **kwargs):
model = CoaT(
model_cfg = dict(
patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0,
num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
# FIXME use builder
model.default_cfg = default_cfgs['coat_lite_mini']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model = _create_coat('coat_lite_mini', pretrained=pretrained, **model_cfg)
return model
@register_model
def coat_lite_small(pretrained=False, **kwargs):
model = CoaT(
model_cfg = dict(
patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0,
num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
model.default_cfg = default_cfgs['coat_lite_small']
model = _create_coat('coat_lite_small', pretrained=pretrained, **model_cfg)
return model

@ -39,7 +39,7 @@ def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
@ -317,6 +317,9 @@ class ConViT(nn.Module):
def _create_convit(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
return build_model_with_cfg(
ConViT, variant, pretrained,
default_cfg=default_cfgs[variant],

@ -44,7 +44,7 @@ def load_state_dict(checkpoint_path, use_ema=False):
raise FileNotFoundError()
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=False):
state_dict = load_state_dict(checkpoint_path, use_ema)
model.load_state_dict(state_dict, strict=strict)
@ -378,7 +378,11 @@ def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
# Overlay default cfg values from `external_default_cfg` if it exists in kwargs
overlay_external_default_cfg(default_cfg, kwargs)
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
default_kwarg_names = ('num_classes', 'global_pool', 'in_chans')
if default_cfg.get('fixed_input_size', False):
# if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
default_kwarg_names += ('img_size',)
set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg)
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
filter_kwargs(kwargs, names=kwargs_filter)

@ -1,3 +1,22 @@
""" LeViT
Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference`
- https://arxiv.org/abs/2104.01136
@article{graham2021levit,
title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze},
journal={arXiv preprint arXiv:22104.01136},
year={2021}
}
Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright bellow.
This version combines both conv/linear models and fixes torchscript compatibility.
Modifications by/coyright Copyright 2021 Ross Wightman
"""
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
@ -5,10 +24,15 @@
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# Copyright 2020 Ross Wightman, Apache-2.0 License
import itertools
from copy import deepcopy
from functools import partial
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import to_ntuple
from .vision_transformer import trunc_normal_
from .registry import register_model
@ -19,70 +43,113 @@ def _cfg(url='', **kwargs):
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
'first_conv': 'patch_embed.0.c', 'classifier': ('head.l', 'head_dist.l'),
**kwargs
}
specification = {
'levit_128s': {
'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'},
'levit_128': {
'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'},
'levit_192': {
'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'},
'levit_256': {
'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'},
'levit_384': {
'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'},
}
default_cfgs = dict(
levit_128s=_cfg(
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'
),
levit_128=_cfg(
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'
),
levit_192=_cfg(
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'
),
levit_256=_cfg(
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'
),
levit_384=_cfg(
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'
),
)
model_cfgs = dict(
levit_128s=dict(
embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)),
levit_128=dict(
embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)),
levit_192=dict(
embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)),
levit_256=dict(
embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)),
levit_384=dict(
embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)),
)
__all__ = ['Levit']
@register_model
def levit_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_128s'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
def levit_128s(pretrained=False, fuse=False,distillation=True, use_conv=False, **kwargs):
return create_levit(
'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_128(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
return create_levit(
'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_192(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
return create_levit(
'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_128'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
def levit_256(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
return create_levit(
'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_192'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
def levit_384(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
return create_levit(
'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_256'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
def levit_c_128s(pretrained=False, fuse=False, distillation=True, use_conv=True,**kwargs):
return create_levit(
'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_384'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
def levit_c_128(pretrained=False, fuse=False,distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
class ConvNorm(torch.nn.Sequential):
@register_model
def levit_c_192(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_256(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_384(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
class ConvNorm(nn.Sequential):
def __init__(
self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000):
super().__init__()
self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = torch.nn.BatchNorm2d(b)
torch.nn.init.constant_(bn.weight, bn_weight_init)
torch.nn.init.constant_(bn.bias, 0)
self.add_module('c', nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = nn.BatchNorm2d(b)
nn.init.constant_(bn.weight, bn_weight_init)
nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
@torch.no_grad()
@ -91,7 +158,7 @@ class ConvNorm(torch.nn.Sequential):
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
m = torch.nn.Conv2d(
m = nn.Conv2d(
w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
m.weight.data.copy_(w)
@ -99,13 +166,13 @@ class ConvNorm(torch.nn.Sequential):
return m
class LinearNorm(torch.nn.Sequential):
class LinearNorm(nn.Sequential):
def __init__(self, a, b, bn_weight_init=1, resolution=-100000):
super().__init__()
self.add_module('c', torch.nn.Linear(a, b, bias=False))
bn = torch.nn.BatchNorm1d(b)
torch.nn.init.constant_(bn.weight, bn_weight_init)
torch.nn.init.constant_(bn.bias, 0)
self.add_module('c', nn.Linear(a, b, bias=False))
bn = nn.BatchNorm1d(b)
nn.init.constant_(bn.weight, bn_weight_init)
nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
@torch.no_grad()
@ -114,25 +181,24 @@ class LinearNorm(torch.nn.Sequential):
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = l.weight * w[:, None]
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
m = torch.nn.Linear(w.size(1), w.size(0))
m = nn.Linear(w.size(1), w.size(0))
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
def forward(self, x):
l, bn = self._modules.values()
x = l(x)
return bn(x.flatten(0, 1)).reshape_as(x)
x = self.c(x)
return self.bn(x.flatten(0, 1)).reshape_as(x)
class NormLinear(torch.nn.Sequential):
class NormLinear(nn.Sequential):
def __init__(self, a, b, bias=True, std=0.02):
super().__init__()
self.add_module('bn', torch.nn.BatchNorm1d(a))
l = torch.nn.Linear(a, b, bias=bias)
self.add_module('bn', nn.BatchNorm1d(a))
l = nn.Linear(a, b, bias=bias)
trunc_normal_(l.weight, std=std)
if bias:
torch.nn.init.constant_(l.bias, 0)
nn.init.constant_(l.bias, 0)
self.add_module('l', l)
@torch.no_grad()
@ -145,24 +211,24 @@ class NormLinear(torch.nn.Sequential):
b = b @ self.l.weight.T
else:
b = (l.weight @ b[:, None]).view(-1) + self.l.bias
m = torch.nn.Linear(w.size(1), w.size(0))
m = nn.Linear(w.size(1), w.size(0))
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
def b16(n, activation, resolution=224):
return torch.nn.Sequential(
ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution),
def stem_b16(in_chs, out_chs, activation, resolution=224):
return nn.Sequential(
ConvNorm(in_chs, out_chs // 8, 3, 2, 1, resolution=resolution),
activation(),
ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2),
ConvNorm(out_chs // 8, out_chs // 4, 3, 2, 1, resolution=resolution // 2),
activation(),
ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4),
ConvNorm(out_chs // 4, out_chs // 2, 3, 2, 1, resolution=resolution // 4),
activation(),
ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8))
ConvNorm(out_chs // 2, out_chs, 3, 2, 1, resolution=resolution // 8))
class Residual(torch.nn.Module):
class Residual(nn.Module):
def __init__(self, m, drop):
super().__init__()
self.m = m
@ -176,10 +242,23 @@ class Residual(torch.nn.Module):
return x + self.m(x)
class Attention(torch.nn.Module):
class Subsample(nn.Module):
def __init__(self, stride, resolution):
super().__init__()
self.stride = stride
self.resolution = resolution
def forward(self, x):
B, N, C = x.shape
x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride]
return x.reshape(B, -1, C)
class Attention(nn.Module):
def __init__(
self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14):
self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
@ -187,11 +266,13 @@ class Attention(torch.nn.Module):
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio
self.use_conv = use_conv
ln_layer = ConvNorm if self.use_conv else LinearNorm
h = self.dh + nh_kd * 2
self.qkv = LinearNorm(dim, h, resolution=resolution)
self.proj = torch.nn.Sequential(
self.qkv = ln_layer(dim, h, resolution=resolution)
self.proj = nn.Sequential(
act_layer(),
LinearNorm(self.dh, dim, bn_weight_init=0, resolution=resolution))
ln_layer(self.dh, dim, bn_weight_init=0, resolution=resolution))
points = list(itertools.product(range(resolution), range(resolution)))
N = len(points)
@ -203,68 +284,68 @@ class Attention(torch.nn.Module):
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
self.ab = None
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and hasattr(self, 'ab'):
del self.ab
self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x): # x (B,C,H,W)
if self.use_conv:
B, C, H, W = x.shape
q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
attn = (q.transpose(-2, -1) @ k) * self.scale + ab
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x): # x (B,N,C)
B, N, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
attn = q @ k.transpose(-2, -1) * self.scale + ab
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
B, N, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
attn = q @ k.transpose(-2, -1) * self.scale + ab
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
x = self.proj(x)
return x
class Subsample(torch.nn.Module):
def __init__(self, stride, resolution):
super().__init__()
self.stride = stride
self.resolution = resolution
def forward(self, x):
B, N, C = x.shape
x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride]
return x.reshape(B, -1, C)
class AttentionSubsample(torch.nn.Module):
def __init__(self, in_dim, out_dim, key_dim, num_heads=8,
attn_ratio=2, act_layer=None, stride=2, resolution=14, resolution_=7):
class AttentionSubsample(nn.Module):
def __init__(
self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * self.num_heads
self.dh = self.d * self.num_heads
self.attn_ratio = attn_ratio
self.resolution_ = resolution_
self.resolution_2 = resolution_ ** 2
h = self.dh + nh_kd
self.kv = LinearNorm(in_dim, h, resolution=resolution)
self.use_conv = use_conv
if self.use_conv:
ln_layer = ConvNorm
sub_layer = partial(nn.AvgPool2d, kernel_size=1, padding=0)
else:
ln_layer = LinearNorm
sub_layer = partial(Subsample, resolution=resolution)
self.q = torch.nn.Sequential(
Subsample(stride, resolution),
LinearNorm(in_dim, nh_kd, resolution=resolution_))
self.proj = torch.nn.Sequential(
h = self.dh + nh_kd
self.kv = ln_layer(in_dim, h, resolution=resolution)
self.q = nn.Sequential(
sub_layer(stride=stride),
ln_layer(in_dim, nh_kd, resolution=resolution_))
self.proj = nn.Sequential(
act_layer(),
LinearNorm(self.dh, out_dim, resolution=resolution_))
ln_layer(self.dh, out_dim, resolution=resolution_))
self.stride = stride
self.resolution = resolution
@ -283,35 +364,43 @@ class AttentionSubsample(torch.nn.Module):
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
self.ab = None
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and hasattr(self, 'ab'):
del self.ab
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x):
B, N, C = x.shape
k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3)
k = k.permute(0, 2, 1, 3) # BHNC
v = v.permute(0, 2, 1, 3) # BHNC
q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
if self.use_conv:
B, C, H, W = x.shape
k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2)
q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
attn = (q.transpose(-2, -1) @ k) * self.scale + ab
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_)
else:
B, N, C = x.shape
k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3)
k = k.permute(0, 2, 1, 3) # BHNC
v = v.permute(0, 2, 1, 3) # BHNC
q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
attn = q @ k.transpose(-2, -1) * self.scale + ab
attn = attn.softmax(dim=-1)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
attn = q @ k.transpose(-2, -1) * self.scale + ab
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
x = self.proj(x)
return x
class Levit(torch.nn.Module):
class Levit(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
@ -321,45 +410,63 @@ class Levit(torch.nn.Module):
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=[192],
key_dim=[64],
depth=[12],
num_heads=[3],
attn_ratio=[2],
mlp_ratio=[2],
embed_dim=(192,),
key_dim=64,
depth=(12,),
num_heads=(3,),
attn_ratio=2,
mlp_ratio=2,
hybrid_backbone=None,
down_ops=[],
attn_act_layer=torch.nn.Hardswish,
mlp_act_layer=torch.nn.Hardswish,
down_ops=None,
act_layer=nn.Hardswish,
attn_act_layer=nn.Hardswish,
distillation=True,
use_conv=False,
drop_path=0):
super().__init__()
global FLOPS_COUNTER
if isinstance(img_size, tuple):
# FIXME origin impl passes single img/res dim through whole hierarchy,
# not sure this model will be used enough to spend time fixing it.
assert img_size[0] == img_size[1]
img_size = img_size[0]
self.num_classes = num_classes
self.num_features = embed_dim[-1]
self.embed_dim = embed_dim
N = len(embed_dim)
assert len(depth) == len(num_heads) == N
key_dim = to_ntuple(N)(key_dim)
attn_ratio = to_ntuple(N)(attn_ratio)
mlp_ratio = to_ntuple(N)(mlp_ratio)
down_ops = down_ops or (
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
('Subsample', key_dim[0], embed_dim[0] // key_dim[0], 4, 2, 2),
('Subsample', key_dim[0], embed_dim[1] // key_dim[1], 4, 2, 2),
('',)
)
self.distillation = distillation
self.use_conv = use_conv
ln_layer = ConvNorm if self.use_conv else LinearNorm
self.patch_embed = hybrid_backbone
self.patch_embed = hybrid_backbone or stem_b16(in_chans, embed_dim[0], activation=act_layer)
self.blocks = []
down_ops.append([''])
resolution = img_size // patch_size
for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)):
for _ in range(dpth):
self.blocks.append(
Residual(
Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution),
Attention(
ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer,
resolution=resolution, use_conv=use_conv),
drop_path))
if mr > 0:
h = int(ed * mr)
self.blocks.append(
Residual(torch.nn.Sequential(
LinearNorm(ed, h, resolution=resolution),
mlp_act_layer(),
LinearNorm(h, ed, bn_weight_init=0, resolution=resolution),
Residual(nn.Sequential(
ln_layer(ed, h, resolution=resolution),
act_layer(),
ln_layer(h, ed, bn_weight_init=0, resolution=resolution),
), drop_path))
if do[0] == 'Subsample':
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
@ -368,22 +475,22 @@ class Levit(torch.nn.Module):
AttentionSubsample(
*embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2],
attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5],
resolution=resolution, resolution_=resolution_))
resolution=resolution, resolution_=resolution_, use_conv=use_conv))
resolution = resolution_
if do[4] > 0: # mlp_ratio
h = int(embed_dim[i + 1] * do[4])
self.blocks.append(
Residual(torch.nn.Sequential(
LinearNorm(embed_dim[i + 1], h, resolution=resolution),
mlp_act_layer(),
LinearNorm(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution),
Residual(nn.Sequential(
ln_layer(embed_dim[i + 1], h, resolution=resolution),
act_layer(),
ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution),
), drop_path))
self.blocks = torch.nn.Sequential(*self.blocks)
self.blocks = nn.Sequential(*self.blocks)
# Classifier head
self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
if distillation:
self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
else:
self.head_dist = None
@ -393,48 +500,44 @@ class Levit(torch.nn.Module):
def forward(self, x):
x = self.patch_embed(x)
x = x.flatten(2).transpose(1, 2)
if not self.use_conv:
x = x.flatten(2).transpose(1, 2)
x = self.blocks(x)
x = x.mean(1)
if self.distillation:
x = self.head(x), self.head_dist(x)
if not self.training:
x = (x[0] + x[1]) / 2
x = x.mean((-2, -1)) if self.use_conv else x.mean(1)
if self.head_dist is not None:
x, x_dist = self.head(x), self.head_dist(x)
if self.training and not torch.jit.is_scripting():
return x, x_dist
else:
# during inference, return the average of both classifier predictions
return (x + x_dist) / 2
else:
x = self.head(x)
return x
def model_factory(C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse):
embed_dim = [int(x) for x in C.split('_')]
num_heads = [int(x) for x in N.split('_')]
depth = [int(x) for x in X.split('_')]
act = torch.nn.Hardswish
model = Levit(
patch_size=16,
embed_dim=embed_dim,
num_heads=num_heads,
key_dim=[D] * 3,
depth=depth,
attn_ratio=[2, 2, 2],
mlp_ratio=[2, 2, 2],
down_ops=[
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
['Subsample', D, embed_dim[0] // D, 4, 2, 2],
['Subsample', D, embed_dim[1] // D, 4, 2, 2],
],
attn_act_layer=act,
mlp_act_layer=act,
hybrid_backbone=b16(embed_dim[0], activation=act),
num_classes=num_classes,
drop_path=drop_path,
distillation=distillation
)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(weights, map_location='cpu')
model.load_state_dict(checkpoint['model'])
def checkpoint_filter_fn(state_dict, model):
if 'model' in state_dict:
# For deit models
state_dict = state_dict['model']
D = model.state_dict()
for k in state_dict.keys():
if D[k].ndim == 4 and state_dict[k].ndim == 2:
state_dict[k] = state_dict[k][:, :, None, None]
return state_dict
def create_levit(variant, pretrained=False, default_cfg=None, fuse=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model_cfg = dict(**model_cfgs[variant], **kwargs)
model = build_model_with_cfg(
Levit, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
**model_cfg)
#if fuse:
# utils.replace_batchnorm(model)
return model

@ -1,400 +0,0 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
# Modified from
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# Copyright 2020 Ross Wightman, Apache-2.0 License
import itertools
import torch
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from .vision_transformer import trunc_normal_
from .registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
specification = {
'levit_c_128s': {
'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'},
'levit_c_128': {
'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'},
'levit_c_192': {
'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'},
'levit_c_256': {
'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'},
'levit_c_384': {
'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'},
}
__all__ = ['Levit']
@register_model
def levit_c_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_c_128s'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_c_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_c_128'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_c_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_c_192'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_c_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_c_256'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_c_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_c_384'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
class ConvNorm(torch.nn.Sequential):
def __init__(
self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000):
super().__init__()
self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = torch.nn.BatchNorm2d(b)
torch.nn.init.constant_(bn.weight, bn_weight_init)
torch.nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
@torch.no_grad()
def fuse(self):
c, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / \
(bn.running_var + bn.eps) ** 0.5
m = torch.nn.Conv2d(
w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
class NormLinear(torch.nn.Sequential):
def __init__(self, a, b, bias=True, std=0.02):
super().__init__()
self.add_module('bn', torch.nn.BatchNorm1d(a))
l = torch.nn.Linear(a, b, bias=bias)
trunc_normal_(l.weight, std=std)
if bias:
torch.nn.init.constant_(l.bias, 0)
self.add_module('l', l)
@torch.no_grad()
def fuse(self):
bn, l = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
b = bn.bias - self.bn.running_mean * \
self.bn.weight / (bn.running_var + bn.eps) ** 0.5
w = l.weight * w[None, :]
if l.bias is None:
b = b @ self.l.weight.T
else:
b = (l.weight @ b[:, None]).view(-1) + self.l.bias
m = torch.nn.Linear(w.size(1), w.size(0))
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
def b16(n, activation, resolution=224):
return torch.nn.Sequential(
ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution),
activation(),
ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2),
activation(),
ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4),
activation(),
ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8))
class Residual(torch.nn.Module):
def __init__(self, m, drop):
super().__init__()
self.m = m
self.drop = drop
def forward(self, x):
if self.training and self.drop > 0:
return x + self.m(x) * torch.rand(
x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach()
else:
return x + self.m(x)
class Attention(torch.nn.Module):
def __init__(self, dim, key_dim, num_heads=8,
attn_ratio=4, act_layer=None, resolution=14):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio
h = self.dh + nh_kd * 2
self.qkv = ConvNorm(dim, h, resolution=resolution)
self.proj = torch.nn.Sequential(
act_layer(),
ConvNorm(self.dh, dim, bn_weight_init=0, resolution=resolution))
points = list(itertools.product(range(resolution), range(resolution)))
N = len(points)
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
self.ab = None
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and self.ab is not None:
self.ab = None
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x): # x (B,C,H,W)
B, C, H, W = x.shape
q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
attn = (q.transpose(-2, -1) @ k) * self.scale + ab
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
x = self.proj(x)
return x
class AttentionSubsample(torch.nn.Module):
def __init__(
self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
act_layer=None, stride=2, resolution=14, resolution_=7):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * self.num_heads
self.attn_ratio = attn_ratio
self.resolution_ = resolution_
self.resolution_2 = resolution_ ** 2
h = self.dh + nh_kd
self.kv = ConvNorm(in_dim, h, resolution=resolution)
self.q = torch.nn.Sequential(
torch.nn.AvgPool2d(1, stride, 0),
ConvNorm(in_dim, nh_kd, resolution=resolution_))
self.proj = torch.nn.Sequential(
act_layer(),
ConvNorm(self.d * num_heads, out_dim, resolution=resolution_))
self.stride = stride
self.resolution = resolution
points = list(itertools.product(range(resolution), range(resolution)))
points_ = list(itertools.product(range(resolution_), range(resolution_)))
N = len(points)
N_ = len(points_)
attention_offsets = {}
idxs = []
for p1 in points_:
for p2 in points:
size = 1
offset = (
abs(p1[0] * stride - p2[0] + (size - 1) / 2),
abs(p1[1] * stride - p2[1] + (size - 1) / 2))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
self.ab = None
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and self.ab is not None:
self.ab = None
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x):
B, C, H, W = x.shape
k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2)
q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
attn = (q.transpose(-2, -1) @ k) * self.scale + ab
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_)
x = self.proj(x)
return x
class Levit(torch.nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=[192],
key_dim=[64],
depth=[12],
num_heads=[3],
attn_ratio=[2],
mlp_ratio=[2],
hybrid_backbone=None,
down_ops=[],
attn_act_layer=torch.nn.Hardswish,
mlp_act_layer=torch.nn.Hardswish,
distillation=True,
drop_path=0):
super().__init__()
self.num_classes = num_classes
self.num_features = embed_dim[-1]
self.embed_dim = embed_dim
self.distillation = distillation
self.patch_embed = hybrid_backbone
self.blocks = []
down_ops.append([''])
resolution = img_size // patch_size
for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)):
for _ in range(dpth):
self.blocks.append(
Residual(
Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution),
drop_path))
if mr > 0:
h = int(ed * mr)
self.blocks.append(
Residual(torch.nn.Sequential(
ConvNorm(ed, h, resolution=resolution),
mlp_act_layer(),
ConvNorm(h, ed, bn_weight_init=0, resolution=resolution),
), drop_path))
if do[0] == 'Subsample':
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
resolution_ = (resolution - 1) // do[5] + 1
self.blocks.append(
AttentionSubsample(
*embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], attn_ratio=do[3],
act_layer=attn_act_layer, stride=do[5],
resolution=resolution, resolution_=resolution_))
resolution = resolution_
if do[4] > 0: # mlp_ratio
h = int(embed_dim[i + 1] * do[4])
self.blocks.append(
Residual(torch.nn.Sequential(
ConvNorm(embed_dim[i + 1], h, resolution=resolution),
mlp_act_layer(),
ConvNorm(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution),
), drop_path))
self.blocks = torch.nn.Sequential(*self.blocks)
# Classifier head
self.head = NormLinear(
embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
if distillation:
self.head_dist = NormLinear(
embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
@torch.jit.ignore
def no_weight_decay(self):
return {x for x in self.state_dict().keys() if 'attention_biases' in x}
def forward(self, x):
x = self.patch_embed(x)
x = self.blocks(x)
x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
if self.distillation:
x = self.head(x), self.head_dist(x)
if not self.training:
x = (x[0] + x[1]) / 2
else:
x = self.head(x)
return x
def model_factory(C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse):
embed_dim = [int(x) for x in C.split('_')]
num_heads = [int(x) for x in N.split('_')]
depth = [int(x) for x in X.split('_')]
act = torch.nn.Hardswish
model = Levit(
patch_size=16,
embed_dim=embed_dim,
num_heads=num_heads,
key_dim=[D] * 3,
depth=depth,
attn_ratio=[2, 2, 2],
mlp_ratio=[2, 2, 2],
down_ops=[
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
['Subsample', D, embed_dim[0] // D, 4, 2, 2],
['Subsample', D, embed_dim[1] // D, 4, 2, 2],
],
attn_act_layer=act,
mlp_act_layer=act,
hybrid_backbone=b16(embed_dim[0], activation=act),
num_classes=num_classes,
drop_path=drop_path,
distillation=distillation
)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
weights, map_location='cpu')
d = checkpoint['model']
D = model.state_dict()
for k in d.keys():
if D[k].shape != d[k].shape:
d[k] = d[k][:, :, None, None]
model.load_state_dict(d)
#if fuse:
# utils.replace_batchnorm(model)
return model

@ -273,25 +273,14 @@ def _init_weights(m, n: str, head_bias: float = 0.):
nn.init.ones_(m.weight)
def _create_mixer(variant, pretrained=False, default_cfg=None, **kwargs):
if default_cfg is None:
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
def _create_mixer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for MLP-Mixer models.')
model = build_model_with_cfg(
MlpMixer, variant, pretrained,
default_cfg=default_cfg,
img_size=img_size,
num_classes=num_classes,
default_cfg=default_cfgs[variant],
**kwargs)
return model

@ -251,24 +251,14 @@ def checkpoint_filter_fn(state_dict, model):
def _create_pit(variant, pretrained=False, **kwargs):
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
img_size = kwargs.pop('img_size', default_img_size)
num_classes = kwargs.pop('num_classes', default_num_classes)
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
PoolingVisionTransformer, variant, pretrained,
default_cfg=default_cfg,
img_size=img_size,
num_classes=num_classes,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model

@ -12,7 +12,7 @@ import torch.nn as nn
from functools import partial
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import load_pretrained
from timm.models.helpers import build_model_with_cfg
from timm.models.layers import Mlp, DropPath, trunc_normal_
from timm.models.layers.helpers import to_2tuple
from timm.models.registry import register_model
@ -238,24 +238,31 @@ def checkpoint_filter_fn(state_dict, model):
return state_dict
def _create_tnt(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
TNT, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model
@register_model
def tnt_s_patch16_224(pretrained=False, **kwargs):
model = TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4,
model_cfg = dict(
patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4,
qkv_bias=False, **kwargs)
model.default_cfg = default_cfgs['tnt_s_patch16_224']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3),
filter_fn=checkpoint_filter_fn)
model = _create_tnt('tnt_s_patch16_224', pretrained=pretrained, **model_cfg)
return model
@register_model
def tnt_b_patch16_224(pretrained=False, **kwargs):
model = TNT(patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4,
model_cfg = dict(
patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4,
qkv_bias=False, **kwargs)
model.default_cfg = default_cfgs['tnt_b_patch16_224']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model = _create_tnt('tnt_b_patch16_224', pretrained=pretrained, **model_cfg)
return model

@ -33,7 +33,7 @@ def _cfg(url='', **kwargs):
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
'first_conv': 'patch_embeds.0.proj', 'classifier': 'head',
**kwargs
}
@ -361,25 +361,14 @@ class Twins(nn.Module):
return x
def _create_twins(variant, pretrained=False, default_cfg=None, **kwargs):
if default_cfg is None:
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
def _create_twins(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
Twins, variant, pretrained,
default_cfg=default_cfg,
img_size=img_size,
num_classes=num_classes,
default_cfg=default_cfgs[variant],
**kwargs)
return model

@ -1,3 +1,12 @@
""" Visformer
Paper: Visformer: The Vision-friendly Transformer - https://arxiv.org/abs/2104.12533
From original at https://github.com/danczs/Visformer
"""
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -22,6 +31,12 @@ def _cfg(url='', **kwargs):
}
default_cfgs = dict(
visformer_tiny=_cfg(),
visformer_small=_cfg(),
)
class LayerNormBHWC(nn.LayerNorm):
def __init__(self, dim):
super().__init__(dim)
@ -300,87 +315,97 @@ class Visformer(nn.Module):
return x
def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
Visformer, variant, pretrained,
default_cfg=default_cfgs[variant],
**kwargs)
return model
@register_model
def visformer_tiny(pretrained=False, **kwargs):
model = Visformer(
model_cfg = dict(
img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8,
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
embed_norm=nn.BatchNorm2d, **kwargs)
model.default_cfg = _cfg()
model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg)
return model
@register_model
def visformer_small(pretrained=False, **kwargs):
model = Visformer(
model_cfg = dict(
img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8,
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
embed_norm=nn.BatchNorm2d, **kwargs)
model.default_cfg = _cfg()
model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg)
return model
@register_model
def visformer_net1(pretrained=False, **kwargs):
model = Visformer(
init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
spatial_conv='000', vit_stem=True, conv_init=True, **kwargs)
model.default_cfg = _cfg()
return model
@register_model
def visformer_net2(pretrained=False, **kwargs):
model = Visformer(
init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
model.default_cfg = _cfg()
return model
@register_model
def visformer_net3(pretrained=False, **kwargs):
model = Visformer(
init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
model.default_cfg = _cfg()
return model
@register_model
def visformer_net4(pretrained=False, **kwargs):
model = Visformer(
init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
model.default_cfg = _cfg()
return model
@register_model
def visformer_net5(pretrained=False, **kwargs):
model = Visformer(
init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
spatial_conv='111', vit_stem=False, conv_init=True, **kwargs)
model.default_cfg = _cfg()
return model
@register_model
def visformer_net6(pretrained=False, **kwargs):
model = Visformer(
init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
model.default_cfg = _cfg()
return model
@register_model
def visformer_net7(pretrained=False, **kwargs):
model = Visformer(
init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000',
pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
model.default_cfg = _cfg()
return model
# @register_model
# def visformer_net1(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
# spatial_conv='000', vit_stem=True, conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model
#
#
# @register_model
# def visformer_net2(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model
#
#
# @register_model
# def visformer_net3(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model
#
#
# @register_model
# def visformer_net4(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model
#
#
# @register_model
# def visformer_net5(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
# spatial_conv='111', vit_stem=False, conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model
#
#
# @register_model
# def visformer_net6(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model
#
#
# @register_model
# def visformer_net7(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000',
# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model

@ -387,21 +387,20 @@ def checkpoint_filter_fn(state_dict, model):
v = v.reshape(O, -1, H, W)
elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
# To resize pos embedding when using model at different size from pretrained weights
v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1),
model.patch_embed.grid_size)
v = resize_pos_embed(
v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
out_dict[k] = v
return out_dict
def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
if default_cfg is None:
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
default_cfg = default_cfg or default_cfgs[variant]
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
# NOTE this extra code to support handling of repr size for in21k pretrained models
default_num_classes = default_cfg['num_classes']
num_classes = kwargs.get('num_classes', default_num_classes)
repr_size = kwargs.pop('representation_size', None)
if repr_size is not None and num_classes != default_num_classes:
# Remove representation layer if fine-tuning. This may not always be the desired action,
@ -409,18 +408,12 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw
_logger.warning("Removing representation layer for fine-tuning.")
repr_size = None
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
VisionTransformer, variant, pretrained,
default_cfg=default_cfg,
img_size=img_size,
num_classes=num_classes,
representation_size=repr_size,
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model

@ -27,7 +27,7 @@ def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head',
**kwargs
@ -107,11 +107,10 @@ class HybridEmbed(nn.Module):
def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
default_cfg = deepcopy(default_cfgs[variant])
embed_layer = partial(HybridEmbed, backbone=backbone)
kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
return _create_vision_transformer(
variant, pretrained=pretrained, default_cfg=default_cfg, embed_layer=embed_layer, **kwargs)
variant, pretrained=pretrained, embed_layer=embed_layer, default_cfg=default_cfgs[variant], **kwargs)
def _resnetv2(layers=(3, 4, 9), **kwargs):

Loading…
Cancel
Save