You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/models/levit.py

441 lines
16 KiB

# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
# Modified from
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# Copyright 2020 Ross Wightman, Apache-2.0 License
import itertools
import torch
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from .vision_transformer import trunc_normal_
from .registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
specification = {
'levit_128s': {
'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'},
'levit_128': {
'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'},
'levit_192': {
'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'},
'levit_256': {
'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'},
'levit_384': {
'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1,
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'},
}
__all__ = ['Levit']
@register_model
def levit_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_128s'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_128'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_192'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_256'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
@register_model
def levit_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
return model_factory(**specification['levit_384'], num_classes=num_classes,
distillation=distillation, pretrained=pretrained, fuse=fuse)
class ConvNorm(torch.nn.Sequential):
def __init__(
self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000):
super().__init__()
self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = torch.nn.BatchNorm2d(b)
torch.nn.init.constant_(bn.weight, bn_weight_init)
torch.nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
@torch.no_grad()
def fuse(self):
c, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
m = torch.nn.Conv2d(
w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
class LinearNorm(torch.nn.Sequential):
def __init__(self, a, b, bn_weight_init=1, resolution=-100000):
super().__init__()
self.add_module('c', torch.nn.Linear(a, b, bias=False))
bn = torch.nn.BatchNorm1d(b)
torch.nn.init.constant_(bn.weight, bn_weight_init)
torch.nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
@torch.no_grad()
def fuse(self):
l, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = l.weight * w[:, None]
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
m = torch.nn.Linear(w.size(1), w.size(0))
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
def forward(self, x):
l, bn = self._modules.values()
x = l(x)
return bn(x.flatten(0, 1)).reshape_as(x)
class NormLinear(torch.nn.Sequential):
def __init__(self, a, b, bias=True, std=0.02):
super().__init__()
self.add_module('bn', torch.nn.BatchNorm1d(a))
l = torch.nn.Linear(a, b, bias=bias)
trunc_normal_(l.weight, std=std)
if bias:
torch.nn.init.constant_(l.bias, 0)
self.add_module('l', l)
@torch.no_grad()
def fuse(self):
bn, l = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5
w = l.weight * w[None, :]
if l.bias is None:
b = b @ self.l.weight.T
else:
b = (l.weight @ b[:, None]).view(-1) + self.l.bias
m = torch.nn.Linear(w.size(1), w.size(0))
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
def b16(n, activation, resolution=224):
return torch.nn.Sequential(
ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution),
activation(),
ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2),
activation(),
ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4),
activation(),
ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8))
class Residual(torch.nn.Module):
def __init__(self, m, drop):
super().__init__()
self.m = m
self.drop = drop
def forward(self, x):
if self.training and self.drop > 0:
return x + self.m(x) * torch.rand(
x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach()
else:
return x + self.m(x)
class Attention(torch.nn.Module):
def __init__(
self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio
h = self.dh + nh_kd * 2
self.qkv = LinearNorm(dim, h, resolution=resolution)
self.proj = torch.nn.Sequential(
act_layer(),
LinearNorm(self.dh, dim, bn_weight_init=0, resolution=resolution))
points = list(itertools.product(range(resolution), range(resolution)))
N = len(points)
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and hasattr(self, 'ab'):
del self.ab
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x): # x (B,N,C)
B, N, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
attn = q @ k.transpose(-2, -1) * self.scale + ab
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
x = self.proj(x)
return x
class Subsample(torch.nn.Module):
def __init__(self, stride, resolution):
super().__init__()
self.stride = stride
self.resolution = resolution
def forward(self, x):
B, N, C = x.shape
x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride]
return x.reshape(B, -1, C)
class AttentionSubsample(torch.nn.Module):
def __init__(self, in_dim, out_dim, key_dim, num_heads=8,
attn_ratio=2, act_layer=None, stride=2, resolution=14, resolution_=7):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * self.num_heads
self.attn_ratio = attn_ratio
self.resolution_ = resolution_
self.resolution_2 = resolution_ ** 2
h = self.dh + nh_kd
self.kv = LinearNorm(in_dim, h, resolution=resolution)
self.q = torch.nn.Sequential(
Subsample(stride, resolution),
LinearNorm(in_dim, nh_kd, resolution=resolution_))
self.proj = torch.nn.Sequential(
act_layer(),
LinearNorm(self.dh, out_dim, resolution=resolution_))
self.stride = stride
self.resolution = resolution
points = list(itertools.product(range(resolution), range(resolution)))
points_ = list(itertools.product(range(resolution_), range(resolution_)))
N = len(points)
N_ = len(points_)
attention_offsets = {}
idxs = []
for p1 in points_:
for p2 in points:
size = 1
offset = (
abs(p1[0] * stride - p2[0] + (size - 1) / 2),
abs(p1[1] * stride - p2[1] + (size - 1) / 2))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and hasattr(self, 'ab'):
del self.ab
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x):
B, N, C = x.shape
k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3)
k = k.permute(0, 2, 1, 3) # BHNC
v = v.permute(0, 2, 1, 3) # BHNC
q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
attn = q @ k.transpose(-2, -1) * self.scale + ab
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
x = self.proj(x)
return x
class Levit(torch.nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=[192],
key_dim=[64],
depth=[12],
num_heads=[3],
attn_ratio=[2],
mlp_ratio=[2],
hybrid_backbone=None,
down_ops=[],
attn_act_layer=torch.nn.Hardswish,
mlp_act_layer=torch.nn.Hardswish,
distillation=True,
drop_path=0):
super().__init__()
global FLOPS_COUNTER
self.num_classes = num_classes
self.num_features = embed_dim[-1]
self.embed_dim = embed_dim
self.distillation = distillation
self.patch_embed = hybrid_backbone
self.blocks = []
down_ops.append([''])
resolution = img_size // patch_size
for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)):
for _ in range(dpth):
self.blocks.append(
Residual(
Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution),
drop_path))
if mr > 0:
h = int(ed * mr)
self.blocks.append(
Residual(torch.nn.Sequential(
LinearNorm(ed, h, resolution=resolution),
mlp_act_layer(),
LinearNorm(h, ed, bn_weight_init=0, resolution=resolution),
), drop_path))
if do[0] == 'Subsample':
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
resolution_ = (resolution - 1) // do[5] + 1
self.blocks.append(
AttentionSubsample(
*embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2],
attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5],
resolution=resolution, resolution_=resolution_))
resolution = resolution_
if do[4] > 0: # mlp_ratio
h = int(embed_dim[i + 1] * do[4])
self.blocks.append(
Residual(torch.nn.Sequential(
LinearNorm(embed_dim[i + 1], h, resolution=resolution),
mlp_act_layer(),
LinearNorm(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution),
), drop_path))
self.blocks = torch.nn.Sequential(*self.blocks)
# Classifier head
self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
if distillation:
self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
else:
self.head_dist = None
@torch.jit.ignore
def no_weight_decay(self):
return {x for x in self.state_dict().keys() if 'attention_biases' in x}
def forward(self, x):
x = self.patch_embed(x)
x = x.flatten(2).transpose(1, 2)
x = self.blocks(x)
x = x.mean(1)
if self.distillation:
x = self.head(x), self.head_dist(x)
if not self.training:
x = (x[0] + x[1]) / 2
else:
x = self.head(x)
return x
def model_factory(C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse):
embed_dim = [int(x) for x in C.split('_')]
num_heads = [int(x) for x in N.split('_')]
depth = [int(x) for x in X.split('_')]
act = torch.nn.Hardswish
model = Levit(
patch_size=16,
embed_dim=embed_dim,
num_heads=num_heads,
key_dim=[D] * 3,
depth=depth,
attn_ratio=[2, 2, 2],
mlp_ratio=[2, 2, 2],
down_ops=[
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
['Subsample', D, embed_dim[0] // D, 4, 2, 2],
['Subsample', D, embed_dim[1] // D, 4, 2, 2],
],
attn_act_layer=act,
mlp_act_layer=act,
hybrid_backbone=b16(embed_dim[0], activation=act),
num_classes=num_classes,
drop_path=drop_path,
distillation=distillation
)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(weights, map_location='cpu')
model.load_state_dict(checkpoint['model'])
#if fuse:
# utils.replace_batchnorm(model)
return model