# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
# Modified from
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# Copyright 2020 Ross Wightman, Apache-2.0 License
"""LeViT image classifiers and their building blocks.

Each registered ``levit_*`` entry point builds a :class:`Levit` from a compact
string specification (see ``specification``) via :func:`model_factory`.
"""
import itertools

import torch

from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from .vision_transformer import trunc_normal_
from .registry import register_model


def _cfg(url='', **kwargs):
    """Return the default data/pretrained config dict; ``kwargs`` override entries."""
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        # The stem is ``b16``: a Sequential whose first ConvNorm holds its conv in ``c``.
        'first_conv': 'patch_embed.0.c',
        # Heads are NormLinear modules whose Linear layer is registered as ``l``.
        'classifier': ('head.l', 'head_dist.l'),
        **kwargs
    }


# Per-variant hyper-parameters, parsed by ``model_factory``:
#   C: embedding dim per stage, D: key dim, N: heads per stage, X: depth per stage
#   (stage values joined with '_'), drop_path: residual drop rate, weights: checkpoint URL.
specification = {
    'levit_128s': {
        'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0,
        'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'},
    'levit_128': {
        'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0,
        'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'},
    'levit_192': {
        'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0,
        'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'},
    'levit_256': {
        'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0,
        'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'},
    'levit_384': {
        'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1,
        'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'},
}

__all__ = ['Levit']


@register_model
def levit_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
    """Build LeViT-128S. Extra ``kwargs`` are accepted for registry compatibility and ignored."""
    return model_factory(**specification['levit_128s'], num_classes=num_classes,
                         distillation=distillation, pretrained=pretrained, fuse=fuse)


@register_model
def levit_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
    """Build LeViT-128. Extra ``kwargs`` are accepted for registry compatibility and ignored."""
    return model_factory(**specification['levit_128'], num_classes=num_classes,
                         distillation=distillation, pretrained=pretrained, fuse=fuse)


@register_model
def levit_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
    """Build LeViT-192. Extra ``kwargs`` are accepted for registry compatibility and ignored."""
    return model_factory(**specification['levit_192'], num_classes=num_classes,
                         distillation=distillation, pretrained=pretrained, fuse=fuse)


@register_model
def levit_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
    """Build LeViT-256. Extra ``kwargs`` are accepted for registry compatibility and ignored."""
    return model_factory(**specification['levit_256'], num_classes=num_classes,
                         distillation=distillation, pretrained=pretrained, fuse=fuse)


@register_model
def levit_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs):
    """Build LeViT-384. Extra ``kwargs`` are accepted for registry compatibility and ignored."""
    return model_factory(**specification['levit_384'], num_classes=num_classes,
                         distillation=distillation, pretrained=pretrained, fuse=fuse)


class ConvNorm(torch.nn.Sequential):
    """Bias-free ``Conv2d`` (submodule ``c``) followed by ``BatchNorm2d`` (submodule ``bn``).

    ``bn_weight_init`` sets the initial BatchNorm scale; the BatchNorm bias starts at 0.
    ``resolution`` is accepted for interface compatibility but is not used.
    """

    def __init__(
            self, a, b, ks=1, stride=1, pad=0, dilation=1,
            groups=1, bn_weight_init=1, resolution=-10000):
        super().__init__()
        self.add_module('c', torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False))
        bn = torch.nn.BatchNorm2d(b)
        torch.nn.init.constant_(bn.weight, bn_weight_init)
        torch.nn.init.constant_(bn.bias, 0)
        self.add_module('bn', bn)

    @torch.no_grad()
    def fuse(self):
        """Return a single biased ``Conv2d`` equal to conv + BatchNorm (running stats)."""
        conv, bn = self._modules.values()
        scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = conv.weight * scale[:, None, None, None]
        b = bn.bias - bn.running_mean * scale
        # Conv2d weights are (out, in // groups, kh, kw): multiply by the group count
        # to recover in_channels, otherwise grouped convolutions cannot be rebuilt.
        fused = torch.nn.Conv2d(
            w.size(1) * conv.groups, w.size(0), w.shape[2:], stride=conv.stride,
            padding=conv.padding, dilation=conv.dilation, groups=conv.groups).to(w)
        fused.weight.data.copy_(w)
        fused.bias.data.copy_(b)
        return fused


class LinearNorm(torch.nn.Sequential):
    """Bias-free ``Linear`` (submodule ``c``) followed by ``BatchNorm1d`` (submodule ``bn``).

    ``bn_weight_init`` sets the initial BatchNorm scale; with 0 the block initially
    outputs zeros. ``resolution`` is accepted for interface compatibility but is not used.
    """

    def __init__(self, a, b, bn_weight_init=1, resolution=-100000):
        super().__init__()
        self.add_module('c', torch.nn.Linear(a, b, bias=False))
        bn = torch.nn.BatchNorm1d(b)
        torch.nn.init.constant_(bn.weight, bn_weight_init)
        torch.nn.init.constant_(bn.bias, 0)
        self.add_module('bn', bn)

    @torch.no_grad()
    def fuse(self):
        """Return a single biased ``Linear`` equal to linear + BatchNorm (running stats)."""
        linear, bn = self._modules.values()
        scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = linear.weight * scale[:, None]
        b = bn.bias - bn.running_mean * scale
        fused = torch.nn.Linear(w.size(1), w.size(0)).to(w)
        fused.weight.data.copy_(w)
        fused.bias.data.copy_(b)
        return fused

    def forward(self, x):
        linear, bn = self._modules.values()
        x = linear(x)
        # BatchNorm1d normalises over dim 1 of a 2-D input, so merge the first two
        # dims (batch and tokens for the (B, N, C) inputs used in this file) first.
        return bn(x.flatten(0, 1)).reshape_as(x)


class NormLinear(torch.nn.Sequential):
    """``BatchNorm1d`` (submodule ``bn``) followed by a ``Linear`` (submodule ``l``).

    The linear weight uses truncated-normal init with ``std``; its bias (if any) starts at 0.
    """

    def __init__(self, a, b, bias=True, std=0.02):
        super().__init__()
        self.add_module('bn', torch.nn.BatchNorm1d(a))
        linear = torch.nn.Linear(a, b, bias=bias)
        trunc_normal_(linear.weight, std=std)
        if bias:
            torch.nn.init.constant_(linear.bias, 0)
        self.add_module('l', linear)

    @torch.no_grad()
    def fuse(self):
        """Return a single biased ``Linear`` equal to BatchNorm (running stats) + linear."""
        bn, linear = self._modules.values()
        scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
        shift = bn.bias - bn.running_mean * scale
        w = linear.weight * scale[None, :]
        # The BatchNorm shift passes through the linear map: W @ shift (+ linear bias).
        b = (linear.weight @ shift[:, None]).view(-1)
        if linear.bias is not None:
            b = b + linear.bias
        fused = torch.nn.Linear(w.size(1), w.size(0)).to(w)
        fused.weight.data.copy_(w)
        fused.bias.data.copy_(b)
        return fused


def b16(n, activation, resolution=224):
    """Convolutional stem: four stride-2 3x3 ConvNorms (3 -> n/8 -> n/4 -> n/2 -> n channels).

    The overall stride is 16; ``activation`` is instantiated between convs.
    """
    return torch.nn.Sequential(
        ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution),
        activation(),
        ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2),
        activation(),
        ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4),
        activation(),
        ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8))


class Residual(torch.nn.Module):
    """``x + m(x)``, with per-sample dropping of the branch while training.

    When training and ``drop > 0``, each sample's branch output is zeroed with
    probability ``drop`` and the survivors are rescaled by ``1 / (1 - drop)``.
    """

    def __init__(self, m, drop):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        if self.training and self.drop > 0:
            branch = self.m(x)
            keep = torch.rand(x.size(0), 1, 1, device=x.device).ge_(self.drop)
            # Cast the mask to the activation dtype so half-precision models are
            # not silently promoted to float32 by the multiply/add below.
            mask = keep.div(1 - self.drop).to(x.dtype).detach()
            return x + branch * mask
        return x + self.m(x)


class _AttentionBiasCache(torch.nn.Module):
    """Shared handling of the gathered relative-position attention biases.

    Subclasses define the ``attention_biases`` parameter and the
    ``attention_bias_idxs`` index buffer. While training the biases are
    gathered on each call so gradients reach ``attention_biases``. In eval
    mode the gathered tensor is cached per (device, dtype), so moving or
    casting the module never reuses a tensor from the old device/dtype. The
    cache is cleared on every ``train()``/``eval()`` call and on state-dict loads.
    """

    def __init__(self):
        super().__init__()
        self.attention_bias_cache = {}

    def train(self, mode=True):
        super().train(mode)
        self.attention_bias_cache = {}
        return self

    def _load_from_state_dict(self, *args, **kwargs):
        # New weights invalidate anything gathered from the old ones.
        self.attention_bias_cache = {}
        super()._load_from_state_dict(*args, **kwargs)

    def get_attention_biases(self):
        """Return ``attention_biases`` gathered by ``attention_bias_idxs``."""
        if self.training:
            return self.attention_biases[:, self.attention_bias_idxs]
        key = (self.attention_biases.device, self.attention_biases.dtype)
        cached = self.attention_bias_cache.get(key)
        if cached is None:
            # Computed without autograd, so eval-mode outputs do not backpropagate
            # into attention_biases through the cached tensor.
            with torch.no_grad():
                cached = self.attention_biases[:, self.attention_bias_idxs]
            self.attention_bias_cache[key] = cached
        return cached


class Attention(_AttentionBiasCache):
    """Multi-head self-attention over a ``resolution`` x ``resolution`` token grid.

    A learned bias is added per head and per absolute (|dy|, |dx|) offset
    between token positions.
    """

    def __init__(
            self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14):
        super().__init__()
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        self.attn_ratio = attn_ratio
        h = self.dh + nh_kd * 2
        self.qkv = LinearNorm(dim, h, resolution=resolution)
        self.proj = torch.nn.Sequential(
            act_layer(),
            LinearNorm(self.dh, dim, bn_weight_init=0, resolution=resolution))

        points = list(itertools.product(range(resolution), range(resolution)))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))

    def forward(self, x):  # x (B,N,C)
        B, N, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.view(B, N, self.num_heads, -1).split(
            [self.key_dim, self.key_dim, self.d], dim=3)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases()
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
        return self.proj(x)


class Subsample(torch.nn.Module):
    """Keep every ``stride``-th token along both axes of a square token grid."""

    def __init__(self, stride, resolution):
        super().__init__()
        self.stride = stride
        self.resolution = resolution

    def forward(self, x):
        B, N, C = x.shape
        x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride]
        return x.reshape(B, -1, C)


class AttentionSubsample(_AttentionBiasCache):
    """Attention whose queries come from a strided subsample of the input grid.

    Keys/values use the full ``resolution`` grid; queries use the subsampled
    ``resolution_`` grid, so the output has ``resolution_ ** 2`` tokens of width ``out_dim``.
    """

    def __init__(self, in_dim, out_dim, key_dim, num_heads=8,
                 attn_ratio=2, act_layer=None, stride=2, resolution=14, resolution_=7):
        super().__init__()
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * self.num_heads
        self.attn_ratio = attn_ratio
        self.resolution_ = resolution_
        self.resolution_2 = resolution_ ** 2
        h = self.dh + nh_kd
        self.kv = LinearNorm(in_dim, h, resolution=resolution)
        self.q = torch.nn.Sequential(
            Subsample(stride, resolution),
            LinearNorm(in_dim, nh_kd, resolution=resolution_))
        self.proj = torch.nn.Sequential(
            act_layer(), LinearNorm(self.dh, out_dim, resolution=resolution_))

        self.stride = stride
        self.resolution = resolution
        points = list(itertools.product(range(resolution), range(resolution)))
        points_ = list(itertools.product(range(resolution_), range(resolution_)))
        N = len(points)
        N_ = len(points_)
        attention_offsets = {}
        idxs = []
        for p1 in points_:
            for p2 in points:
                size = 1
                # Offset between the query's position on the full grid and the key.
                offset = (
                    abs(p1[0] * stride - p2[0] + (size - 1) / 2),
                    abs(p1[1] * stride - p2[1] + (size - 1) / 2))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))

    def forward(self, x):
        B, N, C = x.shape
        k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3)
        k = k.permute(0, 2, 1, 3)  # BHNC
        v = v.permute(0, 2, 1, 3)  # BHNC
        q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)

        attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases()
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
        return self.proj(x)


class Levit(torch.nn.Module):
    """LeViT classifier: a patch-embedding backbone, then stages of residual attention/MLP
    blocks joined by ``AttentionSubsample`` stages, then mean-pooled linear head(s).

    ``down_ops`` entries have the form
    ``['Subsample', key_dim, num_heads, attn_ratio, mlp_ratio, stride]``.
    ``in_chans`` is accepted for interface compatibility but not used; input channels
    are determined by ``hybrid_backbone``. With ``distillation`` a second head is
    built. In eval mode, the two head outputs are averaged.
    """

    def __init__(self, img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=[192],
                 key_dim=[64],
                 depth=[12],
                 num_heads=[3],
                 attn_ratio=[2],
                 mlp_ratio=[2],
                 hybrid_backbone=None,
                 down_ops=[],
                 attn_act_layer=torch.nn.Hardswish,
                 mlp_act_layer=torch.nn.Hardswish,
                 distillation=True,
                 drop_path=0):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = embed_dim[-1]
        self.embed_dim = embed_dim
        self.distillation = distillation

        self.patch_embed = hybrid_backbone

        # Build a new list instead of appending in place, so neither the caller's
        # list nor the shared default is modified. The [''] entry pairs with the
        # last stage and means "no subsampling after it".
        down_ops = [*down_ops, ['']]
        blocks = []
        resolution = img_size // patch_size
        # Block order determines the ``blocks.<i>`` state_dict keys; keep it stable.
        for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
                zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)):
            for _ in range(dpth):
                blocks.append(
                    Residual(
                        Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer,
                                  resolution=resolution),
                        drop_path))
                if mr > 0:
                    h = int(ed * mr)
                    blocks.append(
                        Residual(torch.nn.Sequential(
                            LinearNorm(ed, h, resolution=resolution),
                            mlp_act_layer(),
                            LinearNorm(h, ed, bn_weight_init=0, resolution=resolution),
                        ), drop_path))
            if do[0] == 'Subsample':
                # ('Subsample', key_dim, num_heads, attn_ratio, mlp_ratio, stride)
                resolution_ = (resolution - 1) // do[5] + 1
                blocks.append(
                    AttentionSubsample(
                        *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2],
                        attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5],
                        resolution=resolution, resolution_=resolution_))
                resolution = resolution_
                if do[4] > 0:  # mlp_ratio
                    h = int(embed_dim[i + 1] * do[4])
                    blocks.append(
                        Residual(torch.nn.Sequential(
                            LinearNorm(embed_dim[i + 1], h, resolution=resolution),
                            mlp_act_layer(),
                            LinearNorm(h, embed_dim[i + 1], bn_weight_init=0,
                                       resolution=resolution),
                        ), drop_path))
        self.blocks = torch.nn.Sequential(*blocks)

        # Classifier head
        self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
        if distillation:
            self.head_dist = (NormLinear(embed_dim[-1], num_classes)
                              if num_classes > 0 else torch.nn.Identity())
        else:
            self.head_dist = None

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of the attention-bias parameters, which should not be weight-decayed."""
        return {x for x in self.state_dict().keys() if 'attention_biases' in x}

    def forward(self, x):
        x = self.patch_embed(x)
        # Assumes a (B, C, H, W) backbone output; turn it into (B, H*W, C) tokens.
        x = x.flatten(2).transpose(1, 2)
        x = self.blocks(x)
        x = x.mean(1)
        if self.distillation:
            x = self.head(x), self.head_dist(x)
            if not self.training:
                x = (x[0] + x[1]) / 2
        else:
            x = self.head(x)
        return x


def _replace_batchnorm(net):
    """Recursively replace each child that has a ``fuse()`` method with its fused module."""
    for child_name, child in net.named_children():
        if hasattr(child, 'fuse'):
            setattr(net, child_name, child.fuse())
        else:
            _replace_batchnorm(child)


def model_factory(C, D, X, N, drop_path, weights,
                  num_classes, distillation, pretrained, fuse):
    """Build a three-stage :class:`Levit` from a ``specification`` entry.

    Args:
        C, N, X: '_'-joined per-stage embedding dims, head counts and depths.
        D: key dimension used by every stage and subsampling block.
        drop_path: residual drop rate passed to every ``Residual``.
        weights: checkpoint URL, loaded (from its ``'model'`` entry) when ``pretrained``.
        fuse: fold every Conv/Linear + BatchNorm pair into a single layer using the
            BatchNorm running statistics (done after loading weights; intended for inference).
    """
    embed_dim = [int(x) for x in C.split('_')]
    num_heads = [int(x) for x in N.split('_')]
    depth = [int(x) for x in X.split('_')]
    act = torch.nn.Hardswish
    model = Levit(
        patch_size=16,
        embed_dim=embed_dim,
        num_heads=num_heads,
        key_dim=[D] * 3,
        depth=depth,
        attn_ratio=[2, 2, 2],
        mlp_ratio=[2, 2, 2],
        down_ops=[
            # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
            ['Subsample', D, embed_dim[0] // D, 4, 2, 2],
            ['Subsample', D, embed_dim[1] // D, 4, 2, 2],
        ],
        attn_act_layer=act,
        mlp_act_layer=act,
        hybrid_backbone=b16(embed_dim[0], activation=act),
        num_classes=num_classes,
        drop_path=drop_path,
        distillation=distillation
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(weights, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
    if fuse:
        _replace_batchnorm(model)
    return model