diff --git a/tests/test_models.py b/tests/test_models.py index ced2fd76..1e1de498 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -15,7 +15,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): torch._C._jit_set_profiling_mode(False) # transformer models don't support many of the spatial / feature based model functionalities -NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', 'mixer_*'] +NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', 'mixer_*', 'levit*', 'visformer*'] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 46ea155f..821012e2 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -15,6 +15,8 @@ from .hrnet import * from .inception_resnet_v2 import * from .inception_v3 import * from .inception_v4 import * +from .levitc import * +from .levit import * from .mlp_mixer import * from .mobilenetv3 import * from .nasnet import * @@ -34,6 +36,7 @@ from .swin_transformer import * from .tnt import * from .tresnet import * from .vgg import * +from .visformer import * from .vision_transformer import * from .vision_transformer_hybrid import * from .vovnet import * diff --git a/timm/models/layers/patch_embed.py b/timm/models/layers/patch_embed.py index b06f9982..42997fb8 100644 --- a/timm/models/layers/patch_embed.py +++ b/timm/models/layers/patch_embed.py @@ -15,7 +15,7 @@ from .helpers import to_2tuple class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None): + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -23,6 +23,7 @@ class PatchEmbed(nn.Module): self.patch_size = patch_size self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() @@ -31,6 +32,8 @@ class PatchEmbed(nn.Module): B, C, H, W = x.shape assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC x = self.norm(x) return x diff --git a/timm/models/levit.py b/timm/models/levit.py new file mode 100644 index 00000000..997b44d7 --- /dev/null +++ b/timm/models/levit.py @@ -0,0 +1,440 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. + +# Modified from +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# Copyright 2020 Ross Wightman, Apache-2.0 License +import itertools + +import torch + +from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN +from .vision_transformer import trunc_normal_ +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +specification = { + 'levit_128s': { + 'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'}, + 'levit_128': { + 'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'}, + 'levit_192': { + 'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'}, + 'levit_256': { + 'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'}, + 'levit_384': { + 'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'}, +} + +__all__ = ['Levit'] + + +@register_model +def levit_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_128s'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_128'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_192'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_256'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_384'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +class ConvNorm(torch.nn.Sequential): + def __init__( + self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000): + super().__init__() + self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = torch.nn.BatchNorm2d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = torch.nn.Conv2d( + w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, + padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class LinearNorm(torch.nn.Sequential): + def __init__(self, a, b, bn_weight_init=1, resolution=-100000): + super().__init__() + self.add_module('c', torch.nn.Linear(a, b, bias=False)) + bn = torch.nn.BatchNorm1d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + @torch.no_grad() + def fuse(self): + l, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = l.weight * w[:, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = torch.nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + def forward(self, x): + l, bn = self._modules.values() + x = l(x) + return bn(x.flatten(0, 1)).reshape_as(x) + + +class NormLinear(torch.nn.Sequential): + def __init__(self, a, b, bias=True, std=0.02): + super().__init__() + self.add_module('bn', torch.nn.BatchNorm1d(a)) + l = torch.nn.Linear(a, b, bias=bias) + trunc_normal_(l.weight, std=std) + if bias: + torch.nn.init.constant_(l.bias, 0) + self.add_module('l', l) + + @torch.no_grad() + def fuse(self): + bn, l = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = l.weight * w[None, :] + if l.bias is None: + b = b @ self.l.weight.T + else: + b = (l.weight @ b[:, None]).view(-1) + self.l.bias + m = torch.nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def b16(n, activation, resolution=224): + return torch.nn.Sequential( + ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution), + activation(), + ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), + activation(), + ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), + activation(), + ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8)) + + +class Residual(torch.nn.Module): + def __init__(self, m, drop): + super().__init__() + self.m = m + self.drop = drop + + def forward(self, x): + if self.training and self.drop > 0: + return x + self.m(x) * torch.rand( + x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach() + else: + return x + self.m(x) + + +class Attention(torch.nn.Module): + def __init__( + self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + self.qkv = LinearNorm(dim, h, resolution=resolution) + self.proj = torch.nn.Sequential( + act_layer(), + LinearNorm(self.dh, dim, bn_weight_init=0, resolution=resolution)) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N)) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,N,C) + B, N, C = x.shape + qkv = self.qkv(x) + q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab + attn = q @ k.transpose(-2, -1) * self.scale + ab + + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class Subsample(torch.nn.Module): + def __init__(self, stride, resolution): + super().__init__() + self.stride = stride + self.resolution = resolution + + def forward(self, x): + B, N, C = x.shape + x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride] + return x.reshape(B, -1, C) + + +class AttentionSubsample(torch.nn.Module): + def __init__(self, in_dim, out_dim, key_dim, num_heads=8, + attn_ratio=2, act_layer=None, stride=2, resolution=14, resolution_=7): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * self.num_heads + self.attn_ratio = attn_ratio + self.resolution_ = resolution_ + self.resolution_2 = resolution_ ** 2 + h = self.dh + nh_kd + self.kv = LinearNorm(in_dim, h, resolution=resolution) + + self.q = torch.nn.Sequential( + Subsample(stride, resolution), + LinearNorm(in_dim, nh_kd, resolution=resolution_)) + self.proj = torch.nn.Sequential( + act_layer(), + LinearNorm(self.dh, out_dim, resolution=resolution_)) + + self.stride = stride + self.resolution = resolution + points = list(itertools.product(range(resolution), range(resolution))) + points_ = list(itertools.product(range(resolution_), range(resolution_))) + N = len(points) + N_ = len(points_) + attention_offsets = {} + idxs = [] + for p1 in points_: + for p2 in points: + size = 1 + offset = ( + abs(p1[0] * stride - p2[0] + (size - 1) / 2), + abs(p1[1] * stride - p2[1] + (size - 1) / 2)) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N)) + + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): + B, N, C = x.shape + k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3) + k = k.permute(0, 2, 1, 3) # BHNC + v = v.permute(0, 2, 1, 3) # BHNC + q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) + + ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab + attn = q @ k.transpose(-2, -1) * self.scale + ab + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) + x = self.proj(x) + return x + + +class Levit(torch.nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=[192], + key_dim=[64], + depth=[12], + num_heads=[3], + attn_ratio=[2], + mlp_ratio=[2], + hybrid_backbone=None, + down_ops=[], + attn_act_layer=torch.nn.Hardswish, + mlp_act_layer=torch.nn.Hardswish, + distillation=True, + drop_path=0): + super().__init__() + global FLOPS_COUNTER + + self.num_classes = num_classes + self.num_features = embed_dim[-1] + self.embed_dim = embed_dim + self.distillation = distillation + + self.patch_embed = hybrid_backbone + + self.blocks = [] + down_ops.append(['']) + resolution = img_size // patch_size + for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( + zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)): + for _ in range(dpth): + self.blocks.append( + Residual( + Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution), + drop_path)) + if mr > 0: + h = int(ed * mr) + self.blocks.append( + Residual(torch.nn.Sequential( + LinearNorm(ed, h, resolution=resolution), + mlp_act_layer(), + LinearNorm(h, ed, bn_weight_init=0, resolution=resolution), + ), drop_path)) + if do[0] == 'Subsample': + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + resolution_ = (resolution - 1) // do[5] + 1 + self.blocks.append( + AttentionSubsample( + *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], + attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5], + resolution=resolution, resolution_=resolution_)) + resolution = resolution_ + if do[4] > 0: # mlp_ratio + h = int(embed_dim[i + 1] * do[4]) + self.blocks.append( + Residual(torch.nn.Sequential( + LinearNorm(embed_dim[i + 1], h, resolution=resolution), + mlp_act_layer(), + LinearNorm(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), + ), drop_path)) + self.blocks = torch.nn.Sequential(*self.blocks) + + # Classifier head + self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + if distillation: + self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + else: + self.head_dist = None + + @torch.jit.ignore + def no_weight_decay(self): + return {x for x in self.state_dict().keys() if 'attention_biases' in x} + + def forward(self, x): + x = self.patch_embed(x) + x = x.flatten(2).transpose(1, 2) + x = self.blocks(x) + x = x.mean(1) + if self.distillation: + x = self.head(x), self.head_dist(x) + if not self.training: + x = (x[0] + x[1]) / 2 + else: + x = self.head(x) + return x + + +def model_factory(C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse): + embed_dim = [int(x) for x in C.split('_')] + num_heads = [int(x) for x in N.split('_')] + depth = [int(x) for x in X.split('_')] + act = torch.nn.Hardswish + model = Levit( + patch_size=16, + embed_dim=embed_dim, + num_heads=num_heads, + key_dim=[D] * 3, + depth=depth, + attn_ratio=[2, 2, 2], + mlp_ratio=[2, 2, 2], + down_ops=[ + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + ['Subsample', D, embed_dim[0] // D, 4, 2, 2], + ['Subsample', D, embed_dim[1] // D, 4, 2, 2], + ], + attn_act_layer=act, + mlp_act_layer=act, + hybrid_backbone=b16(embed_dim[0], activation=act), + num_classes=num_classes, + drop_path=drop_path, + distillation=distillation + ) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url(weights, map_location='cpu') + model.load_state_dict(checkpoint['model']) + #if fuse: + # utils.replace_batchnorm(model) + + return model diff --git a/timm/models/levitc.py b/timm/models/levitc.py new file mode 100644 index 00000000..1a422953 --- /dev/null +++ b/timm/models/levitc.py @@ -0,0 +1,400 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. + +# Modified from +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# Copyright 2020 Ross Wightman, Apache-2.0 License +import itertools + +import torch + +from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN +from .vision_transformer import trunc_normal_ +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +specification = { + 'levit_c_128s': { + 'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'}, + 'levit_c_128': { + 'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'}, + 'levit_c_192': { + 'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'}, + 'levit_c_256': { + 'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'}, + 'levit_c_384': { + 'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'}, +} + +__all__ = ['Levit'] + + +@register_model +def levit_c_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_c_128s'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_c_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_c_128'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_c_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_c_192'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_c_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_c_256'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def levit_c_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): + return model_factory(**specification['levit_c_384'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +class ConvNorm(torch.nn.Sequential): + def __init__( + self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000): + super().__init__() + self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = torch.nn.BatchNorm2d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps) ** 0.5 + m = torch.nn.Conv2d( + w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, + padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class NormLinear(torch.nn.Sequential): + def __init__(self, a, b, bias=True, std=0.02): + super().__init__() + self.add_module('bn', torch.nn.BatchNorm1d(a)) + l = torch.nn.Linear(a, b, bias=bias) + trunc_normal_(l.weight, std=std) + if bias: + torch.nn.init.constant_(l.bias, 0) + self.add_module('l', l) + + @torch.no_grad() + def fuse(self): + bn, l = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + b = bn.bias - self.bn.running_mean * \ + self.bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = l.weight * w[None, :] + if l.bias is None: + b = b @ self.l.weight.T + else: + b = (l.weight @ b[:, None]).view(-1) + self.l.bias + m = torch.nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def b16(n, activation, resolution=224): + return torch.nn.Sequential( + ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution), + activation(), + ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), + activation(), + ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), + activation(), + ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8)) + + +class Residual(torch.nn.Module): + def __init__(self, m, drop): + super().__init__() + self.m = m + self.drop = drop + + def forward(self, x): + if self.training and self.drop > 0: + return x + self.m(x) * torch.rand( + x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach() + else: + return x + self.m(x) + + +class Attention(torch.nn.Module): + def __init__(self, dim, key_dim, num_heads=8, + attn_ratio=4, act_layer=None, resolution=14): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + self.qkv = ConvNorm(dim, h, resolution=resolution) + self.proj = torch.nn.Sequential( + act_layer(), + ConvNorm(self.dh, dim, bn_weight_init=0, resolution=resolution)) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N)) + self.ab = None + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.ab is not None: + self.ab = None + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,C,H,W) + B, C, H, W = x.shape + q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2) + ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab + attn = (q.transpose(-2, -1) @ k) * self.scale + ab + attn = attn.softmax(dim=-1) + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + x = self.proj(x) + return x + + +class AttentionSubsample(torch.nn.Module): + def __init__( + self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2, + act_layer=None, stride=2, resolution=14, resolution_=7): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * self.num_heads + self.attn_ratio = attn_ratio + self.resolution_ = resolution_ + self.resolution_2 = resolution_ ** 2 + h = self.dh + nh_kd + self.kv = ConvNorm(in_dim, h, resolution=resolution) + self.q = torch.nn.Sequential( + torch.nn.AvgPool2d(1, stride, 0), + ConvNorm(in_dim, nh_kd, resolution=resolution_)) + self.proj = torch.nn.Sequential( + act_layer(), + ConvNorm(self.d * num_heads, out_dim, resolution=resolution_)) + + self.stride = stride + self.resolution = resolution + points = list(itertools.product(range(resolution), range(resolution))) + points_ = list(itertools.product(range(resolution_), range(resolution_))) + N = len(points) + N_ = len(points_) + attention_offsets = {} + idxs = [] + for p1 in points_: + for p2 in points: + size = 1 + offset = ( + abs(p1[0] * stride - p2[0] + (size - 1) / 2), + abs(p1[1] * stride - p2[1] + (size - 1) / 2)) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N)) + self.ab = None + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.ab is not None: + self.ab = None + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): + B, C, H, W = x.shape + k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2) + q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2) + ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab + attn = (q.transpose(-2, -1) @ k) * self.scale + ab + attn = attn.softmax(dim=-1) + + x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_) + x = self.proj(x) + return x + + +class Levit(torch.nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=[192], + key_dim=[64], + depth=[12], + num_heads=[3], + attn_ratio=[2], + mlp_ratio=[2], + hybrid_backbone=None, + down_ops=[], + attn_act_layer=torch.nn.Hardswish, + mlp_act_layer=torch.nn.Hardswish, + distillation=True, + drop_path=0): + super().__init__() + self.num_classes = num_classes + self.num_features = embed_dim[-1] + self.embed_dim = embed_dim + self.distillation = distillation + + self.patch_embed = hybrid_backbone + + self.blocks = [] + down_ops.append(['']) + resolution = img_size // patch_size + for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( + zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)): + for _ in range(dpth): + self.blocks.append( + Residual( + Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution), + drop_path)) + if mr > 0: + h = int(ed * mr) + self.blocks.append( + Residual(torch.nn.Sequential( + ConvNorm(ed, h, resolution=resolution), + mlp_act_layer(), + ConvNorm(h, ed, bn_weight_init=0, resolution=resolution), + ), drop_path)) + if do[0] == 'Subsample': + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + resolution_ = (resolution - 1) // do[5] + 1 + self.blocks.append( + AttentionSubsample( + *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], attn_ratio=do[3], + act_layer=attn_act_layer, stride=do[5], + resolution=resolution, resolution_=resolution_)) + resolution = resolution_ + if do[4] > 0: # mlp_ratio + h = int(embed_dim[i + 1] * do[4]) + self.blocks.append( + Residual(torch.nn.Sequential( + ConvNorm(embed_dim[i + 1], h, resolution=resolution), + mlp_act_layer(), + ConvNorm(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), + ), drop_path)) + self.blocks = torch.nn.Sequential(*self.blocks) + + # Classifier head + self.head = NormLinear( + embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + if distillation: + self.head_dist = NormLinear( + embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + + @torch.jit.ignore + def no_weight_decay(self): + return {x for x in self.state_dict().keys() if 'attention_biases' in x} + + def forward(self, x): + x = self.patch_embed(x) + x = self.blocks(x) + x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1) + if self.distillation: + x = self.head(x), self.head_dist(x) + if not self.training: + x = (x[0] + x[1]) / 2 + else: + x = self.head(x) + return x + + +def model_factory(C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse): + embed_dim = [int(x) for x in C.split('_')] + num_heads = [int(x) for x in N.split('_')] + depth = [int(x) for x in X.split('_')] + act = torch.nn.Hardswish + model = Levit( + patch_size=16, + embed_dim=embed_dim, + num_heads=num_heads, + key_dim=[D] * 3, + depth=depth, + attn_ratio=[2, 2, 2], + mlp_ratio=[2, 2, 2], + down_ops=[ + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + ['Subsample', D, embed_dim[0] // D, 4, 2, 2], + ['Subsample', D, embed_dim[1] // D, 4, 2, 2], + ], + attn_act_layer=act, + mlp_act_layer=act, + hybrid_backbone=b16(embed_dim[0], activation=act), + num_classes=num_classes, + drop_path=drop_path, + distillation=distillation + ) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + weights, map_location='cpu') + d = checkpoint['model'] + D = model.state_dict() + for k in d.keys(): + if D[k].shape != d[k].shape: + d[k] = d[k][:, :, None, None] + model.load_state_dict(d) + #if fuse: + # utils.replace_batchnorm(model) + + return model + diff --git a/timm/models/visformer.py b/timm/models/visformer.py new file mode 100644 index 00000000..0d213ad5 --- /dev/null +++ b/timm/models/visformer.py @@ -0,0 +1,377 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed +from .registry import register_model + + +__all__ = ['Visformer'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +class LayerNormBHWC(nn.LayerNorm): + def __init__(self, dim): + super().__init__(dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.layer_norm( + x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + + +class SpatialMlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, + act_layer=nn.GELU, drop=0., group=8, spatial_conv=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.in_features = in_features + self.out_features = out_features + self.spatial_conv = spatial_conv + if self.spatial_conv: + if group < 2: # net setting + hidden_features = in_features * 5 // 6 + else: + hidden_features = in_features * 2 + self.hidden_features = hidden_features + self.group = group + self.drop = nn.Dropout(drop) + self.conv1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0, bias=False) + self.act1 = act_layer() + if self.spatial_conv: + self.conv2 = nn.Conv2d( + hidden_features, hidden_features, 3, stride=1, padding=1, groups=self.group, bias=False) + self.act2 = act_layer() + else: + self.conv2 = None + self.act2 = None + self.conv3 = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0, bias=False) + + def forward(self, x): + x = self.conv1(x) + x = self.act1(x) + x = self.drop(x) + if self.conv2 is not None: + x = self.conv2(x) + x = self.act2(x) + x = self.conv3(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=0.): + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = round(dim // num_heads * head_dim_ratio) + self.head_dim = head_dim + self.scale = head_dim ** -0.5 + self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, C, H, W = x.shape + x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3) + q, k, v = x[0], x[1], x[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__(self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4., + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormBHWC, + group=8, attn_disabled=False, spatial_conv=False): + super().__init__() + self.spatial_conv = spatial_conv + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + if attn_disabled: + self.norm1 = None + self.attn = None + else: + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, head_dim_ratio=head_dim_ratio, attn_drop=attn_drop, proj_drop=drop) + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = SpatialMlp( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, + group=group, spatial_conv=spatial_conv) # new setting + + def forward(self, x): + if self.attn is not None: + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Visformer(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384, + depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + norm_layer=LayerNormBHWC, attn_stage='111', pos_embed=True, spatial_conv='111', + vit_stem=False, group=8, pool=True, conv_init=False, embed_norm=None): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim + self.init_channels = init_channels + self.img_size = img_size + self.vit_stem = vit_stem + self.pool = pool + self.conv_init = conv_init + if isinstance(depth, (list, tuple)): + self.stage_num1, self.stage_num2, self.stage_num3 = depth + depth = sum(depth) + else: + self.stage_num1 = self.stage_num3 = depth // 3 + self.stage_num2 = depth - self.stage_num1 - self.stage_num3 + self.pos_embed = pos_embed + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + # stage 1 + if self.vit_stem: + self.stem = None + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) + img_size //= 16 + else: + if self.init_channels is None: + self.stem = None + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans, + embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) + img_size //= 8 + else: + self.stem = nn.Sequential( + nn.Conv2d(3, self.init_channels, 7, stride=2, padding=3, bias=False), + nn.BatchNorm2d(self.init_channels), + nn.ReLU(inplace=True) + ) + img_size //= 2 + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels, + embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) + img_size //= 4 + + if self.pos_embed: + if self.vit_stem: + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size)) + else: + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, img_size, img_size)) + self.pos_drop = nn.Dropout(p=drop_rate) + self.stage1 = nn.ModuleList([ + Block( + dim=embed_dim//2, num_heads=num_heads, head_dim_ratio=0.5, mlp_ratio=mlp_ratio, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + group=group, attn_disabled=(attn_stage[0] == '0'), spatial_conv=(spatial_conv[0] == '1') + ) + for i in range(self.stage_num1) + ]) + + #stage2 + if not self.vit_stem: + self.patch_embed2 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2, + embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) + img_size //= 2 + if self.pos_embed: + self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size)) + self.stage2 = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + group=group, attn_disabled=(attn_stage[1] == '0'), spatial_conv=(spatial_conv[1] == '1') + ) + for i in range(self.stage_num1, self.stage_num1+self.stage_num2) + ]) + + # stage 3 + if not self.vit_stem: + self.patch_embed3 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim, + embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False) + img_size //= 2 + if self.pos_embed: + self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, img_size, img_size)) + self.stage3 = nn.ModuleList([ + Block( + dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + group=group, attn_disabled=(attn_stage[2] == '0'), spatial_conv=(spatial_conv[2] == '1') + ) + for i in range(self.stage_num1+self.stage_num2, depth) + ]) + + # head + if self.pool: + self.global_pooling = nn.AdaptiveAvgPool2d(1) + head_dim = embed_dim if self.vit_stem else embed_dim * 2 + self.norm = norm_layer(head_dim) + self.head = nn.Linear(head_dim, num_classes) + + # weights init + if self.pos_embed: + trunc_normal_(self.pos_embed1, std=0.02) + if not self.vit_stem: + trunc_normal_(self.pos_embed2, std=0.02) + trunc_normal_(self.pos_embed3, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + if self.conv_init: + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + else: + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0.) + + def forward(self, x): + if self.stem is not None: + x = self.stem(x) + + # stage 1 + x = self.patch_embed1(x) + if self.pos_embed: + x = x + self.pos_embed1 + x = self.pos_drop(x) + for b in self.stage1: + x = b(x) + + # stage 2 + if not self.vit_stem: + x = self.patch_embed2(x) + if self.pos_embed: + x = x + self.pos_embed2 + x = self.pos_drop(x) + for b in self.stage2: + x = b(x) + + # stage3 + if not self.vit_stem: + x = self.patch_embed3(x) + if self.pos_embed: + x = x + self.pos_embed3 + x = self.pos_drop(x) + for b in self.stage3: + x = b(x) + + # head + x = self.norm(x) + if self.pool: + x = self.global_pooling(x) + else: + x = x[:, :, 0, 0] + + x = self.head(x.view(x.size(0), -1)) + return x + + +@register_model +def visformer_tiny(pretrained=False, **kwargs): + model = Visformer( + img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8, + attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, + embed_norm=nn.BatchNorm2d, **kwargs) + return model + + +@register_model +def visformer_small(pretrained=False, **kwargs): + model = Visformer( + img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8, + attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, + embed_norm=nn.BatchNorm2d, **kwargs) + return model + + +@register_model +def visformer_net1(pretrained=False, **kwargs): + model = Visformer( + init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', + spatial_conv='000', vit_stem=True, conv_init=True, **kwargs) + return model + + +@register_model +def visformer_net2(pretrained=False, **kwargs): + model = Visformer( + init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', + spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) + return model + + +@register_model +def visformer_net3(pretrained=False, **kwargs): + model = Visformer( + init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', + spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) + return model + + +@register_model +def visformer_net4(pretrained=False, **kwargs): + model = Visformer(init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', + spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) + return model + + +@register_model +def visformer_net5(pretrained=False, **kwargs): + model = Visformer( + init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', + spatial_conv='111', vit_stem=False, conv_init=True, **kwargs) + return model + + +@register_model +def visformer_net6(pretrained=False, **kwargs): + model = Visformer( + init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', + pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) + return model + + +@register_model +def visformer_net7(pretrained=False, **kwargs): + model = Visformer( + init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000', + pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) + return model + + + +