Merge pull request #637 from rwightman/levit_visformer_rednet
LeVit, Visformer, RedNet/Involution models and layerspull/669/head
commit
07d952c7a7
@ -0,0 +1,50 @@
|
||||
""" PyTorch Involution Layer
|
||||
|
||||
Official impl: https://github.com/d-li14/involution/blob/main/cls/mmcls/models/utils/involution_naive.py
|
||||
Paper: `Involution: Inverting the Inherence of Convolution for Visual Recognition` - https://arxiv.org/abs/2103.06255
|
||||
"""
|
||||
import torch.nn as nn
|
||||
from .conv_bn_act import ConvBnAct
|
||||
from .create_conv2d import create_conv2d
|
||||
|
||||
|
||||
class Involution(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
group_size=16,
|
||||
reduction_ratio=4,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
act_layer=nn.ReLU,
|
||||
):
|
||||
super(Involution, self).__init__()
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.channels = channels
|
||||
self.group_size = group_size
|
||||
self.groups = self.channels // self.group_size
|
||||
self.conv1 = ConvBnAct(
|
||||
in_channels=channels,
|
||||
out_channels=channels // reduction_ratio,
|
||||
kernel_size=1,
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer)
|
||||
self.conv2 = self.conv = create_conv2d(
|
||||
in_channels=channels // reduction_ratio,
|
||||
out_channels=kernel_size**2 * self.groups,
|
||||
kernel_size=1,
|
||||
stride=1)
|
||||
self.avgpool = nn.AvgPool2d(stride, stride) if stride == 2 else nn.Identity()
|
||||
self.unfold = nn.Unfold(kernel_size, 1, (kernel_size-1)//2, stride)
|
||||
|
||||
def forward(self, x):
|
||||
weight = self.conv2(self.conv1(self.avgpool(x)))
|
||||
B, C, H, W = weight.shape
|
||||
KK = int(self.kernel_size ** 2)
|
||||
weight = weight.view(B, self.groups, KK, H, W).unsqueeze(2)
|
||||
out = self.unfold(x).view(B, self.groups, self.group_size, KK, H, W)
|
||||
out = (weight * out).sum(dim=3).view(B, self.channels, H, W)
|
||||
return out
|
@ -0,0 +1,568 @@
|
||||
""" LeViT
|
||||
|
||||
Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference`
|
||||
- https://arxiv.org/abs/2104.01136
|
||||
|
||||
@article{graham2021levit,
|
||||
title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
|
||||
author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze},
|
||||
journal={arXiv preprint arXiv:22104.01136},
|
||||
year={2021}
|
||||
}
|
||||
|
||||
Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright bellow.
|
||||
|
||||
This version combines both conv/linear models and fixes torchscript compatibility.
|
||||
|
||||
Modifications by/coyright Copyright 2021 Ross Wightman
|
||||
"""
|
||||
|
||||
# Copyright (c) 2015-present, Facebook, Inc.
|
||||
# All rights reserved.
|
||||
|
||||
# Modified from
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
# Copyright 2020 Ross Wightman, Apache-2.0 License
|
||||
import itertools
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
|
||||
from .helpers import build_model_with_cfg, overlay_external_default_cfg
|
||||
from .layers import to_ntuple
|
||||
from .vision_transformer import trunc_normal_
|
||||
from .registry import register_model
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.0.c', 'classifier': ('head.l', 'head_dist.l'),
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = dict(
|
||||
levit_128s=_cfg(
|
||||
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'
|
||||
),
|
||||
levit_128=_cfg(
|
||||
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'
|
||||
),
|
||||
levit_192=_cfg(
|
||||
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'
|
||||
),
|
||||
levit_256=_cfg(
|
||||
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'
|
||||
),
|
||||
levit_384=_cfg(
|
||||
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'
|
||||
),
|
||||
)
|
||||
|
||||
model_cfgs = dict(
|
||||
levit_128s=dict(
|
||||
embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)),
|
||||
levit_128=dict(
|
||||
embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)),
|
||||
levit_192=dict(
|
||||
embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)),
|
||||
levit_256=dict(
|
||||
embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)),
|
||||
levit_384=dict(
|
||||
embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)),
|
||||
)
|
||||
|
||||
__all__ = ['Levit']
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_128s(pretrained=False, fuse=False,distillation=True, use_conv=False, **kwargs):
|
||||
return create_levit(
|
||||
'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_128(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
|
||||
return create_levit(
|
||||
'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_192(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
|
||||
return create_levit(
|
||||
'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_256(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
|
||||
return create_levit(
|
||||
'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_384(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
|
||||
return create_levit(
|
||||
'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_c_128s(pretrained=False, fuse=False, distillation=True, use_conv=True,**kwargs):
|
||||
return create_levit(
|
||||
'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_c_128(pretrained=False, fuse=False,distillation=True, use_conv=True, **kwargs):
|
||||
return create_levit(
|
||||
'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_c_192(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
|
||||
return create_levit(
|
||||
'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_c_256(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
|
||||
return create_levit(
|
||||
'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_c_384(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
|
||||
return create_levit(
|
||||
'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
class ConvNorm(nn.Sequential):
|
||||
def __init__(
|
||||
self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000):
|
||||
super().__init__()
|
||||
self.add_module('c', nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
|
||||
bn = nn.BatchNorm2d(b)
|
||||
nn.init.constant_(bn.weight, bn_weight_init)
|
||||
nn.init.constant_(bn.bias, 0)
|
||||
self.add_module('bn', bn)
|
||||
|
||||
@torch.no_grad()
|
||||
def fuse(self):
|
||||
c, bn = self._modules.values()
|
||||
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
|
||||
w = c.weight * w[:, None, None, None]
|
||||
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
|
||||
m = nn.Conv2d(
|
||||
w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
|
||||
padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
|
||||
m.weight.data.copy_(w)
|
||||
m.bias.data.copy_(b)
|
||||
return m
|
||||
|
||||
|
||||
class LinearNorm(nn.Sequential):
|
||||
def __init__(self, a, b, bn_weight_init=1, resolution=-100000):
|
||||
super().__init__()
|
||||
self.add_module('c', nn.Linear(a, b, bias=False))
|
||||
bn = nn.BatchNorm1d(b)
|
||||
nn.init.constant_(bn.weight, bn_weight_init)
|
||||
nn.init.constant_(bn.bias, 0)
|
||||
self.add_module('bn', bn)
|
||||
|
||||
@torch.no_grad()
|
||||
def fuse(self):
|
||||
l, bn = self._modules.values()
|
||||
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
|
||||
w = l.weight * w[:, None]
|
||||
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
|
||||
m = nn.Linear(w.size(1), w.size(0))
|
||||
m.weight.data.copy_(w)
|
||||
m.bias.data.copy_(b)
|
||||
return m
|
||||
|
||||
def forward(self, x):
|
||||
x = self.c(x)
|
||||
return self.bn(x.flatten(0, 1)).reshape_as(x)
|
||||
|
||||
|
||||
class NormLinear(nn.Sequential):
|
||||
def __init__(self, a, b, bias=True, std=0.02):
|
||||
super().__init__()
|
||||
self.add_module('bn', nn.BatchNorm1d(a))
|
||||
l = nn.Linear(a, b, bias=bias)
|
||||
trunc_normal_(l.weight, std=std)
|
||||
if bias:
|
||||
nn.init.constant_(l.bias, 0)
|
||||
self.add_module('l', l)
|
||||
|
||||
@torch.no_grad()
|
||||
def fuse(self):
|
||||
bn, l = self._modules.values()
|
||||
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
|
||||
b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5
|
||||
w = l.weight * w[None, :]
|
||||
if l.bias is None:
|
||||
b = b @ self.l.weight.T
|
||||
else:
|
||||
b = (l.weight @ b[:, None]).view(-1) + self.l.bias
|
||||
m = nn.Linear(w.size(1), w.size(0))
|
||||
m.weight.data.copy_(w)
|
||||
m.bias.data.copy_(b)
|
||||
return m
|
||||
|
||||
|
||||
def stem_b16(in_chs, out_chs, activation, resolution=224):
|
||||
return nn.Sequential(
|
||||
ConvNorm(in_chs, out_chs // 8, 3, 2, 1, resolution=resolution),
|
||||
activation(),
|
||||
ConvNorm(out_chs // 8, out_chs // 4, 3, 2, 1, resolution=resolution // 2),
|
||||
activation(),
|
||||
ConvNorm(out_chs // 4, out_chs // 2, 3, 2, 1, resolution=resolution // 4),
|
||||
activation(),
|
||||
ConvNorm(out_chs // 2, out_chs, 3, 2, 1, resolution=resolution // 8))
|
||||
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, m, drop):
|
||||
super().__init__()
|
||||
self.m = m
|
||||
self.drop = drop
|
||||
|
||||
def forward(self, x):
|
||||
if self.training and self.drop > 0:
|
||||
return x + self.m(x) * torch.rand(
|
||||
x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach()
|
||||
else:
|
||||
return x + self.m(x)
|
||||
|
||||
|
||||
class Subsample(nn.Module):
|
||||
def __init__(self, stride, resolution):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.resolution = resolution
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride]
|
||||
return x.reshape(B, -1, C)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
ab: Dict[str, torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False):
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.scale = key_dim ** -0.5
|
||||
self.key_dim = key_dim
|
||||
self.nh_kd = nh_kd = key_dim * num_heads
|
||||
self.d = int(attn_ratio * key_dim)
|
||||
self.dh = int(attn_ratio * key_dim) * num_heads
|
||||
self.attn_ratio = attn_ratio
|
||||
self.use_conv = use_conv
|
||||
ln_layer = ConvNorm if self.use_conv else LinearNorm
|
||||
h = self.dh + nh_kd * 2
|
||||
self.qkv = ln_layer(dim, h, resolution=resolution)
|
||||
self.proj = nn.Sequential(
|
||||
act_layer(),
|
||||
ln_layer(self.dh, dim, bn_weight_init=0, resolution=resolution))
|
||||
|
||||
points = list(itertools.product(range(resolution), range(resolution)))
|
||||
N = len(points)
|
||||
attention_offsets = {}
|
||||
idxs = []
|
||||
for p1 in points:
|
||||
for p2 in points:
|
||||
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
|
||||
if offset not in attention_offsets:
|
||||
attention_offsets[offset] = len(attention_offsets)
|
||||
idxs.append(attention_offsets[offset])
|
||||
self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
|
||||
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
|
||||
self.ab = {}
|
||||
|
||||
@torch.no_grad()
|
||||
def train(self, mode=True):
|
||||
super().train(mode)
|
||||
if mode and self.ab:
|
||||
self.ab = {} # clear ab cache
|
||||
|
||||
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
|
||||
if self.training:
|
||||
return self.attention_biases[:, self.attention_bias_idxs]
|
||||
else:
|
||||
device_key = str(device)
|
||||
if device_key not in self.ab:
|
||||
self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
|
||||
return self.ab[device_key]
|
||||
|
||||
def forward(self, x): # x (B,C,H,W)
|
||||
if self.use_conv:
|
||||
B, C, H, W = x.shape
|
||||
q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2)
|
||||
|
||||
attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
|
||||
attn = attn.softmax(dim=-1)
|
||||
|
||||
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
|
||||
else:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
|
||||
attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
|
||||
attn = attn.softmax(dim=-1)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttentionSubsample(nn.Module):
|
||||
ab: Dict[str, torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
|
||||
act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.scale = key_dim ** -0.5
|
||||
self.key_dim = key_dim
|
||||
self.nh_kd = nh_kd = key_dim * num_heads
|
||||
self.d = int(attn_ratio * key_dim)
|
||||
self.dh = self.d * self.num_heads
|
||||
self.attn_ratio = attn_ratio
|
||||
self.resolution_ = resolution_
|
||||
self.resolution_2 = resolution_ ** 2
|
||||
self.use_conv = use_conv
|
||||
if self.use_conv:
|
||||
ln_layer = ConvNorm
|
||||
sub_layer = partial(nn.AvgPool2d, kernel_size=1, padding=0)
|
||||
else:
|
||||
ln_layer = LinearNorm
|
||||
sub_layer = partial(Subsample, resolution=resolution)
|
||||
|
||||
h = self.dh + nh_kd
|
||||
self.kv = ln_layer(in_dim, h, resolution=resolution)
|
||||
self.q = nn.Sequential(
|
||||
sub_layer(stride=stride),
|
||||
ln_layer(in_dim, nh_kd, resolution=resolution_))
|
||||
self.proj = nn.Sequential(
|
||||
act_layer(),
|
||||
ln_layer(self.dh, out_dim, resolution=resolution_))
|
||||
|
||||
self.stride = stride
|
||||
self.resolution = resolution
|
||||
points = list(itertools.product(range(resolution), range(resolution)))
|
||||
points_ = list(itertools.product(range(resolution_), range(resolution_)))
|
||||
N = len(points)
|
||||
N_ = len(points_)
|
||||
attention_offsets = {}
|
||||
idxs = []
|
||||
for p1 in points_:
|
||||
for p2 in points:
|
||||
size = 1
|
||||
offset = (
|
||||
abs(p1[0] * stride - p2[0] + (size - 1) / 2),
|
||||
abs(p1[1] * stride - p2[1] + (size - 1) / 2))
|
||||
if offset not in attention_offsets:
|
||||
attention_offsets[offset] = len(attention_offsets)
|
||||
idxs.append(attention_offsets[offset])
|
||||
self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
|
||||
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
|
||||
self.ab = {} # per-device attention_biases cache
|
||||
|
||||
@torch.no_grad()
|
||||
def train(self, mode=True):
|
||||
super().train(mode)
|
||||
if mode and self.ab:
|
||||
self.ab = {} # clear ab cache
|
||||
|
||||
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
|
||||
if self.training:
|
||||
return self.attention_biases[:, self.attention_bias_idxs]
|
||||
else:
|
||||
device_key = str(device)
|
||||
if device_key not in self.ab:
|
||||
self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
|
||||
return self.ab[device_key]
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_conv:
|
||||
B, C, H, W = x.shape
|
||||
k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2)
|
||||
q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2)
|
||||
|
||||
attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
|
||||
attn = attn.softmax(dim=-1)
|
||||
|
||||
x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_)
|
||||
else:
|
||||
B, N, C = x.shape
|
||||
k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3)
|
||||
k = k.permute(0, 2, 1, 3) # BHNC
|
||||
v = v.permute(0, 2, 1, 3) # BHNC
|
||||
q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
|
||||
|
||||
attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
|
||||
attn = attn.softmax(dim=-1)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class Levit(nn.Module):
|
||||
""" Vision Transformer with support for patch or hybrid CNN input stage
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
embed_dim=(192,),
|
||||
key_dim=64,
|
||||
depth=(12,),
|
||||
num_heads=(3,),
|
||||
attn_ratio=2,
|
||||
mlp_ratio=2,
|
||||
hybrid_backbone=None,
|
||||
down_ops=None,
|
||||
act_layer=nn.Hardswish,
|
||||
attn_act_layer=nn.Hardswish,
|
||||
distillation=True,
|
||||
use_conv=False,
|
||||
drop_path=0):
|
||||
super().__init__()
|
||||
if isinstance(img_size, tuple):
|
||||
# FIXME origin impl passes single img/res dim through whole hierarchy,
|
||||
# not sure this model will be used enough to spend time fixing it.
|
||||
assert img_size[0] == img_size[1]
|
||||
img_size = img_size[0]
|
||||
self.num_classes = num_classes
|
||||
self.num_features = embed_dim[-1]
|
||||
self.embed_dim = embed_dim
|
||||
N = len(embed_dim)
|
||||
assert len(depth) == len(num_heads) == N
|
||||
key_dim = to_ntuple(N)(key_dim)
|
||||
attn_ratio = to_ntuple(N)(attn_ratio)
|
||||
mlp_ratio = to_ntuple(N)(mlp_ratio)
|
||||
down_ops = down_ops or (
|
||||
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
|
||||
('Subsample', key_dim[0], embed_dim[0] // key_dim[0], 4, 2, 2),
|
||||
('Subsample', key_dim[0], embed_dim[1] // key_dim[1], 4, 2, 2),
|
||||
('',)
|
||||
)
|
||||
self.distillation = distillation
|
||||
self.use_conv = use_conv
|
||||
ln_layer = ConvNorm if self.use_conv else LinearNorm
|
||||
|
||||
self.patch_embed = hybrid_backbone or stem_b16(in_chans, embed_dim[0], activation=act_layer)
|
||||
|
||||
self.blocks = []
|
||||
resolution = img_size // patch_size
|
||||
for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
|
||||
zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)):
|
||||
for _ in range(dpth):
|
||||
self.blocks.append(
|
||||
Residual(
|
||||
Attention(
|
||||
ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer,
|
||||
resolution=resolution, use_conv=use_conv),
|
||||
drop_path))
|
||||
if mr > 0:
|
||||
h = int(ed * mr)
|
||||
self.blocks.append(
|
||||
Residual(nn.Sequential(
|
||||
ln_layer(ed, h, resolution=resolution),
|
||||
act_layer(),
|
||||
ln_layer(h, ed, bn_weight_init=0, resolution=resolution),
|
||||
), drop_path))
|
||||
if do[0] == 'Subsample':
|
||||
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
|
||||
resolution_ = (resolution - 1) // do[5] + 1
|
||||
self.blocks.append(
|
||||
AttentionSubsample(
|
||||
*embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2],
|
||||
attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5],
|
||||
resolution=resolution, resolution_=resolution_, use_conv=use_conv))
|
||||
resolution = resolution_
|
||||
if do[4] > 0: # mlp_ratio
|
||||
h = int(embed_dim[i + 1] * do[4])
|
||||
self.blocks.append(
|
||||
Residual(nn.Sequential(
|
||||
ln_layer(embed_dim[i + 1], h, resolution=resolution),
|
||||
act_layer(),
|
||||
ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution),
|
||||
), drop_path))
|
||||
self.blocks = nn.Sequential(*self.blocks)
|
||||
|
||||
# Classifier head
|
||||
self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
|
||||
if distillation:
|
||||
self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
|
||||
else:
|
||||
self.head_dist = None
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {x for x in self.state_dict().keys() if 'attention_biases' in x}
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x)
|
||||
if not self.use_conv:
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.blocks(x)
|
||||
x = x.mean((-2, -1)) if self.use_conv else x.mean(1)
|
||||
if self.head_dist is not None:
|
||||
x, x_dist = self.head(x), self.head_dist(x)
|
||||
if self.training and not torch.jit.is_scripting():
|
||||
return x, x_dist
|
||||
else:
|
||||
# during inference, return the average of both classifier predictions
|
||||
return (x + x_dist) / 2
|
||||
else:
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
if 'model' in state_dict:
|
||||
# For deit models
|
||||
state_dict = state_dict['model']
|
||||
D = model.state_dict()
|
||||
for k in state_dict.keys():
|
||||
if D[k].ndim == 4 and state_dict[k].ndim == 2:
|
||||
state_dict[k] = state_dict[k][:, :, None, None]
|
||||
return state_dict
|
||||
|
||||
|
||||
def create_levit(variant, pretrained=False, default_cfg=None, fuse=False, **kwargs):
|
||||
if kwargs.get('features_only', None):
|
||||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
||||
|
||||
model_cfg = dict(**model_cfgs[variant], **kwargs)
|
||||
model = build_model_with_cfg(
|
||||
Levit, variant, pretrained,
|
||||
default_cfg=default_cfgs[variant],
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
**model_cfg)
|
||||
#if fuse:
|
||||
# utils.replace_batchnorm(model)
|
||||
return model
|
||||
|
@ -0,0 +1,414 @@
|
||||
""" Visformer
|
||||
|
||||
Paper: Visformer: The Vision-friendly Transformer - https://arxiv.org/abs/2104.12533
|
||||
|
||||
From original at https://github.com/danczs/Visformer
|
||||
|
||||
"""
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg, overlay_external_default_cfg
|
||||
from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed
|
||||
from .registry import register_model
|
||||
|
||||
|
||||
__all__ = ['Visformer']
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'stem.0', 'classifier': 'head',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = dict(
|
||||
visformer_tiny=_cfg(),
|
||||
visformer_small=_cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/visformer_small-839e1f5b.pth'
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class LayerNormBHWC(nn.LayerNorm):
|
||||
def __init__(self, dim):
|
||||
super().__init__(dim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return F.layer_norm(
|
||||
x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
|
||||
|
||||
|
||||
class SpatialMlp(nn.Module):
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None,
|
||||
act_layer=nn.GELU, drop=0., group=8, spatial_conv=False):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.spatial_conv = spatial_conv
|
||||
if self.spatial_conv:
|
||||
if group < 2: # net setting
|
||||
hidden_features = in_features * 5 // 6
|
||||
else:
|
||||
hidden_features = in_features * 2
|
||||
self.hidden_features = hidden_features
|
||||
self.group = group
|
||||
self.drop = nn.Dropout(drop)
|
||||
self.conv1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0, bias=False)
|
||||
self.act1 = act_layer()
|
||||
if self.spatial_conv:
|
||||
self.conv2 = nn.Conv2d(
|
||||
hidden_features, hidden_features, 3, stride=1, padding=1, groups=self.group, bias=False)
|
||||
self.act2 = act_layer()
|
||||
else:
|
||||
self.conv2 = None
|
||||
self.act2 = None
|
||||
self.conv3 = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.act1(x)
|
||||
x = self.drop(x)
|
||||
if self.conv2 is not None:
|
||||
x = self.conv2(x)
|
||||
x = self.act2(x)
|
||||
x = self.conv3(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=0.):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
head_dim = round(dim // num_heads * head_dim_ratio)
|
||||
self.head_dim = head_dim
|
||||
self.scale = head_dim ** -0.5
|
||||
self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3)
|
||||
q, k, v = x[0], x[1], x[2]
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
|
||||
x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4.,
|
||||
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormBHWC,
|
||||
group=8, attn_disabled=False, spatial_conv=False):
|
||||
super().__init__()
|
||||
self.spatial_conv = spatial_conv
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
if attn_disabled:
|
||||
self.norm1 = None
|
||||
self.attn = None
|
||||
else:
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim, num_heads=num_heads, head_dim_ratio=head_dim_ratio, attn_drop=attn_drop, proj_drop=drop)
|
||||
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = SpatialMlp(
|
||||
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
|
||||
group=group, spatial_conv=spatial_conv) # new setting
|
||||
|
||||
def forward(self, x):
|
||||
if self.attn is not None:
|
||||
x = x + self.drop_path(self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class Visformer(nn.Module):
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384,
|
||||
depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
||||
norm_layer=LayerNormBHWC, attn_stage='111', pos_embed=True, spatial_conv='111',
|
||||
vit_stem=False, group=8, pool=True, conv_init=False, embed_norm=None):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.num_features = self.embed_dim = embed_dim
|
||||
self.init_channels = init_channels
|
||||
self.img_size = img_size
|
||||
self.vit_stem = vit_stem
|
||||
self.pool = pool
|
||||
self.conv_init = conv_init
|
||||
if isinstance(depth, (list, tuple)):
|
||||
self.stage_num1, self.stage_num2, self.stage_num3 = depth
|
||||
depth = sum(depth)
|
||||
else:
|
||||
self.stage_num1 = self.stage_num3 = depth // 3
|
||||
self.stage_num2 = depth - self.stage_num1 - self.stage_num3
|
||||
self.pos_embed = pos_embed
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
|
||||
|
||||
# stage 1
|
||||
if self.vit_stem:
|
||||
self.stem = None
|
||||
self.patch_embed1 = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
|
||||
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
|
||||
img_size //= 16
|
||||
else:
|
||||
if self.init_channels is None:
|
||||
self.stem = None
|
||||
self.patch_embed1 = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans,
|
||||
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
|
||||
img_size //= 8
|
||||
else:
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False),
|
||||
nn.BatchNorm2d(self.init_channels),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
img_size //= 2
|
||||
self.patch_embed1 = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels,
|
||||
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
|
||||
img_size //= 4
|
||||
|
||||
if self.pos_embed:
|
||||
if self.vit_stem:
|
||||
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size))
|
||||
else:
|
||||
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, img_size, img_size))
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
self.stage1 = nn.ModuleList([
|
||||
Block(
|
||||
dim=embed_dim//2, num_heads=num_heads, head_dim_ratio=0.5, mlp_ratio=mlp_ratio,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||
group=group, attn_disabled=(attn_stage[0] == '0'), spatial_conv=(spatial_conv[0] == '1')
|
||||
)
|
||||
for i in range(self.stage_num1)
|
||||
])
|
||||
|
||||
#stage2
|
||||
if not self.vit_stem:
|
||||
self.patch_embed2 = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2,
|
||||
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
|
||||
img_size //= 2
|
||||
if self.pos_embed:
|
||||
self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size))
|
||||
self.stage2 = nn.ModuleList([
|
||||
Block(
|
||||
dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||
group=group, attn_disabled=(attn_stage[1] == '0'), spatial_conv=(spatial_conv[1] == '1')
|
||||
)
|
||||
for i in range(self.stage_num1, self.stage_num1+self.stage_num2)
|
||||
])
|
||||
|
||||
# stage 3
|
||||
if not self.vit_stem:
|
||||
self.patch_embed3 = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim,
|
||||
embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False)
|
||||
img_size //= 2
|
||||
if self.pos_embed:
|
||||
self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, img_size, img_size))
|
||||
self.stage3 = nn.ModuleList([
|
||||
Block(
|
||||
dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||
group=group, attn_disabled=(attn_stage[2] == '0'), spatial_conv=(spatial_conv[2] == '1')
|
||||
)
|
||||
for i in range(self.stage_num1+self.stage_num2, depth)
|
||||
])
|
||||
|
||||
# head
|
||||
if self.pool:
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
head_dim = embed_dim if self.vit_stem else embed_dim * 2
|
||||
self.norm = norm_layer(head_dim)
|
||||
self.head = nn.Linear(head_dim, num_classes)
|
||||
|
||||
# weights init
|
||||
if self.pos_embed:
|
||||
trunc_normal_(self.pos_embed1, std=0.02)
|
||||
if not self.vit_stem:
|
||||
trunc_normal_(self.pos_embed2, std=0.02)
|
||||
trunc_normal_(self.pos_embed3, std=0.02)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
if self.conv_init:
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
else:
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0.)
|
||||
|
||||
def forward(self, x):
|
||||
if self.stem is not None:
|
||||
x = self.stem(x)
|
||||
|
||||
# stage 1
|
||||
x = self.patch_embed1(x)
|
||||
if self.pos_embed:
|
||||
x = x + self.pos_embed1
|
||||
x = self.pos_drop(x)
|
||||
for b in self.stage1:
|
||||
x = b(x)
|
||||
|
||||
# stage 2
|
||||
if not self.vit_stem:
|
||||
x = self.patch_embed2(x)
|
||||
if self.pos_embed:
|
||||
x = x + self.pos_embed2
|
||||
x = self.pos_drop(x)
|
||||
for b in self.stage2:
|
||||
x = b(x)
|
||||
|
||||
# stage3
|
||||
if not self.vit_stem:
|
||||
x = self.patch_embed3(x)
|
||||
if self.pos_embed:
|
||||
x = x + self.pos_embed3
|
||||
x = self.pos_drop(x)
|
||||
for b in self.stage3:
|
||||
x = b(x)
|
||||
|
||||
# head
|
||||
x = self.norm(x)
|
||||
if self.pool:
|
||||
x = self.global_pooling(x)
|
||||
else:
|
||||
x = x[:, :, 0, 0]
|
||||
|
||||
x = self.head(x.view(x.size(0), -1))
|
||||
return x
|
||||
|
||||
|
||||
def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs):
|
||||
if kwargs.get('features_only', None):
|
||||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
||||
model = build_model_with_cfg(
|
||||
Visformer, variant, pretrained,
|
||||
default_cfg=default_cfgs[variant],
|
||||
**kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def visformer_tiny(pretrained=False, **kwargs):
|
||||
model_cfg = dict(
|
||||
img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8,
|
||||
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
|
||||
embed_norm=nn.BatchNorm2d, **kwargs)
|
||||
model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def visformer_small(pretrained=False, **kwargs):
|
||||
model_cfg = dict(
|
||||
img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8,
|
||||
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
|
||||
embed_norm=nn.BatchNorm2d, **kwargs)
|
||||
model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg)
|
||||
return model
|
||||
|
||||
|
||||
# @register_model
|
||||
# def visformer_net1(pretrained=False, **kwargs):
|
||||
# model = Visformer(
|
||||
# init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
|
||||
# spatial_conv='000', vit_stem=True, conv_init=True, **kwargs)
|
||||
# model.default_cfg = _cfg()
|
||||
# return model
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def visformer_net2(pretrained=False, **kwargs):
|
||||
# model = Visformer(
|
||||
# init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
|
||||
# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
|
||||
# model.default_cfg = _cfg()
|
||||
# return model
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def visformer_net3(pretrained=False, **kwargs):
|
||||
# model = Visformer(
|
||||
# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
|
||||
# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
|
||||
# model.default_cfg = _cfg()
|
||||
# return model
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def visformer_net4(pretrained=False, **kwargs):
|
||||
# model = Visformer(
|
||||
# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
|
||||
# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
|
||||
# model.default_cfg = _cfg()
|
||||
# return model
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def visformer_net5(pretrained=False, **kwargs):
|
||||
# model = Visformer(
|
||||
# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
|
||||
# spatial_conv='111', vit_stem=False, conv_init=True, **kwargs)
|
||||
# model.default_cfg = _cfg()
|
||||
# return model
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def visformer_net6(pretrained=False, **kwargs):
|
||||
# model = Visformer(
|
||||
# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
|
||||
# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
|
||||
# model.default_cfg = _cfg()
|
||||
# return model
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def visformer_net7(pretrained=False, **kwargs):
|
||||
# model = Visformer(
|
||||
# init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000',
|
||||
# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
|
||||
# model.default_cfg = _cfg()
|
||||
# return model
|
||||
|
||||
|
||||
|
||||
|
@ -1 +1 @@
|
||||
__version__ = '0.4.9'
|
||||
__version__ = '0.4.10'
|
||||
|
Loading…
Reference in new issue