Add levit, levit_c, and visformer model defs. Largely untested and not finished cleanup.

pull/637/head
Ross Wightman 3 years ago
parent 165fb354b2
commit ecc7552c5c

@ -15,7 +15,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
torch._C._jit_set_profiling_mode(False)
# transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', 'mixer_*']
NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', 'mixer_*', 'levit*', 'visformer*']
NUM_NON_STD = len(NON_STD_FILTERS)
# exclude models that cause specific test failures

@ -15,6 +15,8 @@ from .hrnet import *
from .inception_resnet_v2 import *
from .inception_v3 import *
from .inception_v4 import *
from .levitc import *
from .levit import *
from .mlp_mixer import *
from .mobilenetv3 import *
from .nasnet import *
@ -34,6 +36,7 @@ from .swin_transformer import *
from .tnt import *
from .tresnet import *
from .vgg import *
from .visformer import *
from .vision_transformer import *
from .vision_transformer_hybrid import *
from .vovnet import *

@ -15,7 +15,7 @@ from .helpers import to_2tuple
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
@ -23,6 +23,7 @@ class PatchEmbed(nn.Module):
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
@ -31,6 +32,8 @@ class PatchEmbed(nn.Module):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x

@ -0,0 +1,440 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
# Modified from
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# Copyright 2020 Ross Wightman, Apache-2.0 License
import itertools
import torch
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from .vision_transformer import trunc_normal_
from .registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
specification = {
'levit_128s': {
'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'},
'levit_128': {
'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'},
'levit_192': {
'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'},
'levit_256': {
'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'},
'levit_384': {
'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'},
}
__all__ = ['Levit']
@register_model
def levit_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_128s'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_128'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_192'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_256'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_384'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
class ConvNorm(torch.nn.Sequential):
def __init__(
self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000):
super().__init__()
self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = torch.nn.BatchNorm2d(b)
torch.nn.init.constant_(bn.weight, bn_weight_init)
torch.nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
@torch.no_grad()
def fuse(self):
c, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
m = torch.nn.Conv2d(
w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
class LinearNorm(torch.nn.Sequential):
def __init__(self, a, b, bn_weight_init=1, resolution=-100000):
super().__init__()
self.add_module('c', torch.nn.Linear(a, b, bias=False))
bn = torch.nn.BatchNorm1d(b)
torch.nn.init.constant_(bn.weight, bn_weight_init)
torch.nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
@torch.no_grad()
def fuse(self):
l, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = l.weight * w[:, None]
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
m = torch.nn.Linear(w.size(1), w.size(0))
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
def forward(self, x):
l, bn = self._modules.values()
x = l(x)
return bn(x.flatten(0, 1)).reshape_as(x)
class NormLinear(torch.nn.Sequential):
def __init__(self, a, b, bias=True, std=0.02):
super().__init__()
self.add_module('bn', torch.nn.BatchNorm1d(a))
l = torch.nn.Linear(a, b, bias=bias)
trunc_normal_(l.weight, std=std)
if bias:
torch.nn.init.constant_(l.bias, 0)
self.add_module('l', l)
@torch.no_grad()
def fuse(self):
bn, l = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5
w = l.weight * w[None, :]
if l.bias is None:
b = b @ self.l.weight.T
else:
b = (l.weight @ b[:, None]).view(-1) + self.l.bias
m = torch.nn.Linear(w.size(1), w.size(0))
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
def b16(n, activation, resolution=224):
return torch.nn.Sequential(
ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution),
activation(),
ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2),
activation(),
ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4),
activation(),
ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8))
class Residual(torch.nn.Module):
def __init__(self, m, drop):
super().__init__()
self.m = m
self.drop = drop
def forward(self, x):
if self.training and self.drop > 0:
return x + self.m(x) * torch.rand(
x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach()
else:
return x + self.m(x)
class Attention(torch.nn.Module):
def __init__(
self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio
h = self.dh + nh_kd * 2
self.qkv = LinearNorm(dim, h, resolution=resolution)
self.proj = torch.nn.Sequential(
act_layer(),
LinearNorm(self.dh, dim, bn_weight_init=0, resolution=resolution))
points = list(itertools.product(range(resolution), range(resolution)))
N = len(points)
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and hasattr(self, 'ab'):
del self.ab
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x): # x (B,N,C)
B, N, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
attn = q @ k.transpose(-2, -1) * self.scale + ab
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
x = self.proj(x)
return x
class Subsample(torch.nn.Module):
def __init__(self, stride, resolution):
super().__init__()
self.stride = stride
self.resolution = resolution
def forward(self, x):
B, N, C = x.shape
x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride]
return x.reshape(B, -1, C)
class AttentionSubsample(torch.nn.Module):
def __init__(self, in_dim, out_dim, key_dim, num_heads=8,
attn_ratio=2, act_layer=None, stride=2, resolution=14, resolution_=7):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * self.num_heads
self.attn_ratio = attn_ratio
self.resolution_ = resolution_
self.resolution_2 = resolution_ ** 2
h = self.dh + nh_kd
self.kv = LinearNorm(in_dim, h, resolution=resolution)
self.q = torch.nn.Sequential(
Subsample(stride, resolution),
LinearNorm(in_dim, nh_kd, resolution=resolution_))
self.proj = torch.nn.Sequential(
act_layer(),
LinearNorm(self.dh, out_dim, resolution=resolution_))
self.stride = stride
self.resolution = resolution
points = list(itertools.product(range(resolution), range(resolution)))
points_ = list(itertools.product(range(resolution_), range(resolution_)))
N = len(points)
N_ = len(points_)
attention_offsets = {}
idxs = []
for p1 in points_:
for p2 in points:
size = 1
offset = (
abs(p1[0] * stride - p2[0] + (size - 1) / 2),
abs(p1[1] * stride - p2[1] + (size - 1) / 2))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and hasattr(self, 'ab'):
del self.ab
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x):
B, N, C = x.shape
k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3)
k = k.permute(0, 2, 1, 3) # BHNC
v = v.permute(0, 2, 1, 3) # BHNC
q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
attn = q @ k.transpose(-2, -1) * self.scale + ab
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
x = self.proj(x)
return x
class Levit(torch.nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=[192],
key_dim=[64],
depth=[12],
num_heads=[3],
attn_ratio=[2],
mlp_ratio=[2],
hybrid_backbone=None,
down_ops=[],
attn_act_layer=torch.nn.Hardswish,
mlp_act_layer=torch.nn.Hardswish,
distillation=True,
drop_path=0):
super().__init__()
global FLOPS_COUNTER
self.num_classes = num_classes
self.num_features = embed_dim[-1]
self.embed_dim = embed_dim
self.distillation = distillation
self.patch_embed = hybrid_backbone
self.blocks = []
down_ops.append([''])
resolution = img_size // patch_size
for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)):
for _ in range(dpth):
self.blocks.append(
Residual(
Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution),
drop_path))
if mr > 0:
h = int(ed * mr)
self.blocks.append(
Residual(torch.nn.Sequential(
LinearNorm(ed, h, resolution=resolution),
mlp_act_layer(),
LinearNorm(h, ed, bn_weight_init=0, resolution=resolution),
), drop_path))
if do[0] == 'Subsample':
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
resolution_ = (resolution - 1) // do[5] + 1
self.blocks.append(
AttentionSubsample(
*embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2],
attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5],
resolution=resolution, resolution_=resolution_))
resolution = resolution_
if do[4] > 0: # mlp_ratio
h = int(embed_dim[i + 1] * do[4])
self.blocks.append(
Residual(torch.nn.Sequential(
LinearNorm(embed_dim[i + 1], h, resolution=resolution),
mlp_act_layer(),
LinearNorm(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution),
), drop_path))
self.blocks = torch.nn.Sequential(*self.blocks)
# Classifier head
self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
if distillation:
self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
else:
self.head_dist = None
@torch.jit.ignore
def no_weight_decay(self):
return {x for x in self.state_dict().keys() if 'attention_biases' in x}
def forward(self, x):
x = self.patch_embed(x)
x = x.flatten(2).transpose(1, 2)
x = self.blocks(x)
x = x.mean(1)
if self.distillation:
x = self.head(x), self.head_dist(x)
if not self.training:
x = (x[0] + x[1]) / 2
else:
x = self.head(x)
return x
def model_factory(C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse):
embed_dim = [int(x) for x in C.split('_')]
num_heads = [int(x) for x in N.split('_')]
depth = [int(x) for x in X.split('_')]
act = torch.nn.Hardswish
model = Levit(
patch_size=16,
embed_dim=embed_dim,
num_heads=num_heads,
key_dim=[D] * 3,
depth=depth,
attn_ratio=[2, 2, 2],
mlp_ratio=[2, 2, 2],
down_ops=[
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
['Subsample', D, embed_dim[0] // D, 4, 2, 2],
['Subsample', D, embed_dim[1] // D, 4, 2, 2],
],
attn_act_layer=act,
mlp_act_layer=act,
hybrid_backbone=b16(embed_dim[0], activation=act),
num_classes=num_classes,
drop_path=drop_path,
distillation=distillation
)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(weights, map_location='cpu')
model.load_state_dict(checkpoint['model'])
#if fuse:
# utils.replace_batchnorm(model)
return model

@ -0,0 +1,400 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
# Modified from
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# Copyright 2020 Ross Wightman, Apache-2.0 License
import itertools
import torch
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from .vision_transformer import trunc_normal_
from .registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
specification = {
'levit_c_128s': {
'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'},
'levit_c_128': {
'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'},
'levit_c_192': {
'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'},
'levit_c_256': {
'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'},
'levit_c_384': {
'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'},
}
__all__ = ['Levit']
@register_model
def levit_c_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_c_128s'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_c_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_c_128'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_c_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_c_192'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_c_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_c_256'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_c_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_c_384'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
class ConvNorm(torch.nn.Sequential):
def __init__(
self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000):
super().__init__()
self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = torch.nn.BatchNorm2d(b)
torch.nn.init.constant_(bn.weight, bn_weight_init)
torch.nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
@torch.no_grad()
def fuse(self):
c, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / \
(bn.running_var + bn.eps) ** 0.5
m = torch.nn.Conv2d(
w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
class NormLinear(torch.nn.Sequential):
def __init__(self, a, b, bias=True, std=0.02):
super().__init__()
self.add_module('bn', torch.nn.BatchNorm1d(a))
l = torch.nn.Linear(a, b, bias=bias)
trunc_normal_(l.weight, std=std)
if bias:
torch.nn.init.constant_(l.bias, 0)
self.add_module('l', l)
@torch.no_grad()
def fuse(self):
bn, l = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
b = bn.bias - self.bn.running_mean * \
self.bn.weight / (bn.running_var + bn.eps) ** 0.5
w = l.weight * w[None, :]
if l.bias is None:
b = b @ self.l.weight.T
else:
b = (l.weight @ b[:, None]).view(-1) + self.l.bias
m = torch.nn.Linear(w.size(1), w.size(0))
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
def b16(n, activation, resolution=224):
return torch.nn.Sequential(
ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution),
activation(),
ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2),
activation(),
ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4),
activation(),
ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8))
class Residual(torch.nn.Module):
def __init__(self, m, drop):
super().__init__()
self.m = m
self.drop = drop
def forward(self, x):
if self.training and self.drop > 0:
return x + self.m(x) * torch.rand(
x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach()
else:
return x + self.m(x)
class Attention(torch.nn.Module):
def __init__(self, dim, key_dim, num_heads=8,
attn_ratio=4, act_layer=None, resolution=14):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio
h = self.dh + nh_kd * 2
self.qkv = ConvNorm(dim, h, resolution=resolution)
self.proj = torch.nn.Sequential(
act_layer(),
ConvNorm(self.dh, dim, bn_weight_init=0, resolution=resolution))
points = list(itertools.product(range(resolution), range(resolution)))
N = len(points)
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
self.ab = None
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and self.ab is not None:
self.ab = None
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x): # x (B,C,H,W)
B, C, H, W = x.shape
q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
attn = (q.transpose(-2, -1) @ k) * self.scale + ab
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
x = self.proj(x)
return x
class AttentionSubsample(torch.nn.Module):
def __init__(
self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
act_layer=None, stride=2, resolution=14, resolution_=7):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * self.num_heads
self.attn_ratio = attn_ratio
self.resolution_ = resolution_
self.resolution_2 = resolution_ ** 2
h = self.dh + nh_kd
self.kv = ConvNorm(in_dim, h, resolution=resolution)
self.q = torch.nn.Sequential(
torch.nn.AvgPool2d(1, stride, 0),
ConvNorm(in_dim, nh_kd, resolution=resolution_))
self.proj = torch.nn.Sequential(
act_layer(),
ConvNorm(self.d * num_heads, out_dim, resolution=resolution_))
self.stride = stride
self.resolution = resolution
points = list(itertools.product(range(resolution), range(resolution)))
points_ = list(itertools.product(range(resolution_), range(resolution_)))
N = len(points)
N_ = len(points_)
attention_offsets = {}
idxs = []
for p1 in points_:
for p2 in points:
size = 1
offset = (
abs(p1[0] * stride - p2[0] + (size - 1) / 2),
abs(p1[1] * stride - p2[1] + (size - 1) / 2))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
self.ab = None
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and self.ab is not None:
self.ab = None
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x):
B, C, H, W = x.shape
k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2)
q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
attn = (q.transpose(-2, -1) @ k) * self.scale + ab
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_)
x = self.proj(x)
return x
class Levit(torch.nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=[192],
key_dim=[64],
depth=[12],
num_heads=[3],
attn_ratio=[2],
mlp_ratio=[2],
hybrid_backbone=None,
down_ops=[],
attn_act_layer=torch.nn.Hardswish,
mlp_act_layer=torch.nn.Hardswish,
distillation=True,
drop_path=0):
super().__init__()
self.num_classes = num_classes
self.num_features = embed_dim[-1]
self.embed_dim = embed_dim
self.distillation = distillation
self.patch_embed = hybrid_backbone
self.blocks = []
down_ops.append([''])
resolution = img_size // patch_size
for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)):
for _ in range(dpth):
self.blocks.append(
Residual(
Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution),
drop_path))
if mr > 0:
h = int(ed * mr)
self.blocks.append(
Residual(torch.nn.Sequential(
ConvNorm(ed, h, resolution=resolution),
mlp_act_layer(),
ConvNorm(h, ed, bn_weight_init=0, resolution=resolution),
), drop_path))
if do[0] == 'Subsample':
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
resolution_ = (resolution - 1) // do[5] + 1
self.blocks.append(
AttentionSubsample(
*embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], attn_ratio=do[3],
act_layer=attn_act_layer, stride=do[5],
resolution=resolution, resolution_=resolution_))
resolution = resolution_
if do[4] > 0: # mlp_ratio
h = int(embed_dim[i + 1] * do[4])
self.blocks.append(
Residual(torch.nn.Sequential(
ConvNorm(embed_dim[i + 1], h, resolution=resolution),
mlp_act_layer(),
ConvNorm(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution),
), drop_path))
self.blocks = torch.nn.Sequential(*self.blocks)
# Classifier head
self.head = NormLinear(
embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
if distillation:
self.head_dist = NormLinear(
embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
@torch.jit.ignore
def no_weight_decay(self):
return {x for x in self.state_dict().keys() if 'attention_biases' in x}
def forward(self, x):
x = self.patch_embed(x)
x = self.blocks(x)
x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
if self.distillation:
x = self.head(x), self.head_dist(x)
if not self.training:
x = (x[0] + x[1]) / 2
else:
x = self.head(x)
return x
def model_factory(C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse):
embed_dim = [int(x) for x in C.split('_')]
num_heads = [int(x) for x in N.split('_')]
depth = [int(x) for x in X.split('_')]
act = torch.nn.Hardswish
model = Levit(
patch_size=16,
embed_dim=embed_dim,
num_heads=num_heads,
key_dim=[D] * 3,
depth=depth,
attn_ratio=[2, 2, 2],
mlp_ratio=[2, 2, 2],
down_ops=[
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
['Subsample', D, embed_dim[0] // D, 4, 2, 2],
['Subsample', D, embed_dim[1] // D, 4, 2, 2],
],
attn_act_layer=act,
mlp_act_layer=act,
hybrid_backbone=b16(embed_dim[0], activation=act),
num_classes=num_classes,
drop_path=drop_path,
distillation=distillation
)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
weights, map_location='cpu')
d = checkpoint['model']
D = model.state_dict()
for k in d.keys():
if D[k].shape != d[k].shape:
d[k] = d[k][:, :, None, None]
model.load_state_dict(d)
#if fuse:
# utils.replace_batchnorm(model)
return model

@ -0,0 +1,377 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed
from .registry import register_model
__all__ = ['Visformer']
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
class LayerNormBHWC(nn.LayerNorm):
def __init__(self, dim):
super().__init__(dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.layer_norm(
x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
class SpatialMlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None,
act_layer=nn.GELU, drop=0., group=8, spatial_conv=False):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.in_features = in_features
self.out_features = out_features
self.spatial_conv = spatial_conv
if self.spatial_conv:
if group < 2: # net setting
hidden_features = in_features * 5 // 6
else:
hidden_features = in_features * 2
self.hidden_features = hidden_features
self.group = group
self.drop = nn.Dropout(drop)
self.conv1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0, bias=False)
self.act1 = act_layer()
if self.spatial_conv:
self.conv2 = nn.Conv2d(
hidden_features, hidden_features, 3, stride=1, padding=1, groups=self.group, bias=False)
self.act2 = act_layer()
else:
self.conv2 = None
self.act2 = None
self.conv3 = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0, bias=False)
def forward(self, x):
x = self.conv1(x)
x = self.act1(x)
x = self.drop(x)
if self.conv2 is not None:
x = self.conv2(x)
x = self.act2(x)
x = self.conv3(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.num_heads = num_heads
head_dim = round(dim // num_heads * head_dim_ratio)
self.head_dim = head_dim
self.scale = head_dim ** -0.5
self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, C, H, W = x.shape
x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3)
q, k, v = x[0], x[1], x[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4.,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormBHWC,
group=8, attn_disabled=False, spatial_conv=False):
super().__init__()
self.spatial_conv = spatial_conv
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
if attn_disabled:
self.norm1 = None
self.attn = None
else:
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, head_dim_ratio=head_dim_ratio, attn_drop=attn_drop, proj_drop=drop)
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = SpatialMlp(
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
group=group, spatial_conv=spatial_conv) # new setting
def forward(self, x):
if self.attn is not None:
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class Visformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384,
depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=LayerNormBHWC, attn_stage='111', pos_embed=True, spatial_conv='111',
vit_stem=False, group=8, pool=True, conv_init=False, embed_norm=None):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim
self.init_channels = init_channels
self.img_size = img_size
self.vit_stem = vit_stem
self.pool = pool
self.conv_init = conv_init
if isinstance(depth, (list, tuple)):
self.stage_num1, self.stage_num2, self.stage_num3 = depth
depth = sum(depth)
else:
self.stage_num1 = self.stage_num3 = depth // 3
self.stage_num2 = depth - self.stage_num1 - self.stage_num3
self.pos_embed = pos_embed
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
# stage 1
if self.vit_stem:
self.stem = None
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
img_size //= 16
else:
if self.init_channels is None:
self.stem = None
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans,
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
img_size //= 8
else:
self.stem = nn.Sequential(
nn.Conv2d(3, self.init_channels, 7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(self.init_channels),
nn.ReLU(inplace=True)
)
img_size //= 2
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels,
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
img_size //= 4
if self.pos_embed:
if self.vit_stem:
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size))
else:
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, img_size, img_size))
self.pos_drop = nn.Dropout(p=drop_rate)
self.stage1 = nn.ModuleList([
Block(
dim=embed_dim//2, num_heads=num_heads, head_dim_ratio=0.5, mlp_ratio=mlp_ratio,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
group=group, attn_disabled=(attn_stage[0] == '0'), spatial_conv=(spatial_conv[0] == '1')
)
for i in range(self.stage_num1)
])
#stage2
if not self.vit_stem:
self.patch_embed2 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2,
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
img_size //= 2
if self.pos_embed:
self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size))
self.stage2 = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
group=group, attn_disabled=(attn_stage[1] == '0'), spatial_conv=(spatial_conv[1] == '1')
)
for i in range(self.stage_num1, self.stage_num1+self.stage_num2)
])
# stage 3
if not self.vit_stem:
self.patch_embed3 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim,
embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False)
img_size //= 2
if self.pos_embed:
self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, img_size, img_size))
self.stage3 = nn.ModuleList([
Block(
dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
group=group, attn_disabled=(attn_stage[2] == '0'), spatial_conv=(spatial_conv[2] == '1')
)
for i in range(self.stage_num1+self.stage_num2, depth)
])
# head
if self.pool:
self.global_pooling = nn.AdaptiveAvgPool2d(1)
head_dim = embed_dim if self.vit_stem else embed_dim * 2
self.norm = norm_layer(head_dim)
self.head = nn.Linear(head_dim, num_classes)
# weights init
if self.pos_embed:
trunc_normal_(self.pos_embed1, std=0.02)
if not self.vit_stem:
trunc_normal_(self.pos_embed2, std=0.02)
trunc_normal_(self.pos_embed3, std=0.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
if self.conv_init:
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
else:
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0.)
def forward(self, x):
if self.stem is not None:
x = self.stem(x)
# stage 1
x = self.patch_embed1(x)
if self.pos_embed:
x = x + self.pos_embed1
x = self.pos_drop(x)
for b in self.stage1:
x = b(x)
# stage 2
if not self.vit_stem:
x = self.patch_embed2(x)
if self.pos_embed:
x = x + self.pos_embed2
x = self.pos_drop(x)
for b in self.stage2:
x = b(x)
# stage3
if not self.vit_stem:
x = self.patch_embed3(x)
if self.pos_embed:
x = x + self.pos_embed3
x = self.pos_drop(x)
for b in self.stage3:
x = b(x)
# head
x = self.norm(x)
if self.pool:
x = self.global_pooling(x)
else:
x = x[:, :, 0, 0]
x = self.head(x.view(x.size(0), -1))
return x
@register_model
def visformer_tiny(pretrained=False, **kwargs):
model = Visformer(
img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8,
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
embed_norm=nn.BatchNorm2d, **kwargs)
return model
@register_model
def visformer_small(pretrained=False, **kwargs):
model = Visformer(
img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8,
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
embed_norm=nn.BatchNorm2d, **kwargs)
return model
@register_model
def visformer_net1(pretrained=False, **kwargs):
model = Visformer(
init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
spatial_conv='000', vit_stem=True, conv_init=True, **kwargs)
return model
@register_model
def visformer_net2(pretrained=False, **kwargs):
model = Visformer(
init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
return model
@register_model
def visformer_net3(pretrained=False, **kwargs):
model = Visformer(
init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
return model
@register_model
def visformer_net4(pretrained=False, **kwargs):
model = Visformer(init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
return model
@register_model
def visformer_net5(pretrained=False, **kwargs):
model = Visformer(
init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
spatial_conv='111', vit_stem=False, conv_init=True, **kwargs)
return model
@register_model
def visformer_net6(pretrained=False, **kwargs):
model = Visformer(
init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
return model
@register_model
def visformer_net7(pretrained=False, **kwargs):
model = Visformer(
init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000',
pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
return model
Loading…
Cancel
Save