|
|
|
@ -1,6 +1,24 @@
|
|
|
|
|
"""These modules are adapted from those of timm, see
|
|
|
|
|
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
|
|
|
|
""" ConViT Model
|
|
|
|
|
|
|
|
|
|
@article{d2021convit,
|
|
|
|
|
title={ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases},
|
|
|
|
|
author={d'Ascoli, St{\'e}phane and Touvron, Hugo and Leavitt, Matthew and Morcos, Ari and Biroli, Giulio and Sagun, Levent},
|
|
|
|
|
journal={arXiv preprint arXiv:2103.10697},
|
|
|
|
|
year={2021}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Paper link: https://arxiv.org/abs/2103.10697
|
|
|
|
|
Original code: https://github.com/facebookresearch/convit, original copyright below
|
|
|
|
|
"""
|
|
|
|
|
# Copyright (c) 2015-present, Facebook, Inc.
|
|
|
|
|
# All rights reserved.
|
|
|
|
|
#
|
|
|
|
|
# This source code is licensed under the CC-by-NC license found in the
|
|
|
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
#
|
|
|
|
|
'''These modules are adapted from those of timm, see
|
|
|
|
|
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
@ -9,8 +27,9 @@ import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
from .helpers import build_model_with_cfg
|
|
|
|
|
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
|
|
|
|
from timm.models.registry import register_model
|
|
|
|
|
from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp
|
|
|
|
|
from .registry import register_model
|
|
|
|
|
from .vision_transformer_hybrid import HybridEmbed
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
@ -29,7 +48,7 @@ def _cfg(url='', **kwargs):
|
|
|
|
|
default_cfgs = {
|
|
|
|
|
# ConViT
|
|
|
|
|
'convit_tiny': _cfg(
|
|
|
|
|
url="https://dl.fbaipublicfiles.com/convit/convit_tiny.pth"),
|
|
|
|
|
url="https://dl.fbaipublicfiles.com/convit/convit_tiny.pth"),
|
|
|
|
|
'convit_small': _cfg(
|
|
|
|
|
url="https://dl.fbaipublicfiles.com/convit/convit_small.pth"),
|
|
|
|
|
'convit_base': _cfg(
|
|
|
|
@ -37,71 +56,31 @@ default_cfgs = {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Mlp(nn.Module):
|
|
|
|
|
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
|
|
|
|
super().__init__()
|
|
|
|
|
out_features = out_features or in_features
|
|
|
|
|
hidden_features = hidden_features or in_features
|
|
|
|
|
self.fc1 = nn.Linear(in_features, hidden_features)
|
|
|
|
|
self.act = act_layer()
|
|
|
|
|
self.fc2 = nn.Linear(hidden_features, out_features)
|
|
|
|
|
self.drop = nn.Dropout(drop)
|
|
|
|
|
self.apply(self._init_weights)
|
|
|
|
|
|
|
|
|
|
def _init_weights(self, m):
|
|
|
|
|
if isinstance(m, nn.Linear):
|
|
|
|
|
trunc_normal_(m.weight, std=.02)
|
|
|
|
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
|
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
|
elif isinstance(m, nn.LayerNorm):
|
|
|
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
|
nn.init.constant_(m.weight, 1.0)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
x = self.fc1(x)
|
|
|
|
|
x = self.act(x)
|
|
|
|
|
x = self.drop(x)
|
|
|
|
|
x = self.fc2(x)
|
|
|
|
|
x = self.drop(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GPSA(nn.Module):
|
|
|
|
|
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
|
|
|
|
|
locality_strength=1., use_local_init=True):
|
|
|
|
|
locality_strength=1.):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.num_heads = num_heads
|
|
|
|
|
self.dim = dim
|
|
|
|
|
head_dim = dim // num_heads
|
|
|
|
|
self.scale = qk_scale or head_dim ** -0.5
|
|
|
|
|
self.locality_strength = locality_strength
|
|
|
|
|
|
|
|
|
|
self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
|
|
|
|
|
self.v = nn.Linear(dim, dim, bias=qkv_bias)
|
|
|
|
|
|
|
|
|
|
self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
|
|
|
|
|
self.v = nn.Linear(dim, dim, bias=qkv_bias)
|
|
|
|
|
|
|
|
|
|
self.attn_drop = nn.Dropout(attn_drop)
|
|
|
|
|
self.proj = nn.Linear(dim, dim)
|
|
|
|
|
self.pos_proj = nn.Linear(3, num_heads)
|
|
|
|
|
self.proj_drop = nn.Dropout(proj_drop)
|
|
|
|
|
self.locality_strength = locality_strength
|
|
|
|
|
self.gating_param = nn.Parameter(torch.ones(self.num_heads))
|
|
|
|
|
self.apply(self._init_weights)
|
|
|
|
|
if use_local_init:
|
|
|
|
|
self.local_init(locality_strength=locality_strength)
|
|
|
|
|
self.rel_indices: torch.Tensor = torch.zeros(1, 1, 1, 3) # silly torchscript hack, won't work with None
|
|
|
|
|
|
|
|
|
|
def _init_weights(self, m):
|
|
|
|
|
if isinstance(m, nn.Linear):
|
|
|
|
|
trunc_normal_(m.weight, std=.02)
|
|
|
|
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
|
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
|
elif isinstance(m, nn.LayerNorm):
|
|
|
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
|
nn.init.constant_(m.weight, 1.0)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
B, N, C = x.shape
|
|
|
|
|
if not hasattr(self, 'rel_indices') or self.rel_indices.size(1)!=N:
|
|
|
|
|
self.get_rel_indices(N)
|
|
|
|
|
|
|
|
|
|
if self.rel_indices is None or self.rel_indices.shape[1] != N:
|
|
|
|
|
self.rel_indices = self.get_rel_indices(N)
|
|
|
|
|
attn = self.get_attention(x)
|
|
|
|
|
v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
|
|
|
|
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
|
|
|
@ -110,61 +89,58 @@ class GPSA(nn.Module):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def get_attention(self, x):
|
|
|
|
|
B, N, C = x.shape
|
|
|
|
|
B, N, C = x.shape
|
|
|
|
|
qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
|
|
|
q, k = qk[0], qk[1]
|
|
|
|
|
pos_score = self.rel_indices.expand(B, -1, -1,-1)
|
|
|
|
|
pos_score = self.pos_proj(pos_score).permute(0,3,1,2)
|
|
|
|
|
pos_score = self.rel_indices.expand(B, -1, -1, -1)
|
|
|
|
|
pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2)
|
|
|
|
|
patch_score = (q @ k.transpose(-2, -1)) * self.scale
|
|
|
|
|
patch_score = patch_score.softmax(dim=-1)
|
|
|
|
|
pos_score = pos_score.softmax(dim=-1)
|
|
|
|
|
|
|
|
|
|
gating = self.gating_param.view(1,-1,1,1)
|
|
|
|
|
attn = (1.-torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score
|
|
|
|
|
gating = self.gating_param.view(1, -1, 1, 1)
|
|
|
|
|
attn = (1. - torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score
|
|
|
|
|
attn /= attn.sum(dim=-1).unsqueeze(-1)
|
|
|
|
|
attn = self.attn_drop(attn)
|
|
|
|
|
return attn
|
|
|
|
|
|
|
|
|
|
def get_attention_map(self, x, return_map = False):
|
|
|
|
|
|
|
|
|
|
attn_map = self.get_attention(x).mean(0) # average over batch
|
|
|
|
|
distances = self.rel_indices.squeeze()[:,:,-1]**.5
|
|
|
|
|
dist = torch.einsum('nm,hnm->h', (distances, attn_map))
|
|
|
|
|
dist /= distances.size(0)
|
|
|
|
|
def get_attention_map(self, x, return_map=False):
|
|
|
|
|
attn_map = self.get_attention(x).mean(0) # average over batch
|
|
|
|
|
distances = self.rel_indices.squeeze()[:, :, -1] ** .5
|
|
|
|
|
dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / distances.size(0)
|
|
|
|
|
if return_map:
|
|
|
|
|
return dist, attn_map
|
|
|
|
|
else:
|
|
|
|
|
return dist
|
|
|
|
|
|
|
|
|
|
def local_init(self, locality_strength=1.):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def local_init(self):
|
|
|
|
|
self.v.weight.data.copy_(torch.eye(self.dim))
|
|
|
|
|
locality_distance = 1 #max(1,1/locality_strength**.5)
|
|
|
|
|
|
|
|
|
|
kernel_size = int(self.num_heads**.5)
|
|
|
|
|
center = (kernel_size-1)/2 if kernel_size%2==0 else kernel_size//2
|
|
|
|
|
locality_distance = 1 # max(1,1/locality_strength**.5)
|
|
|
|
|
|
|
|
|
|
kernel_size = int(self.num_heads ** .5)
|
|
|
|
|
center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2
|
|
|
|
|
for h1 in range(kernel_size):
|
|
|
|
|
for h2 in range(kernel_size):
|
|
|
|
|
position = h1+kernel_size*h2
|
|
|
|
|
self.pos_proj.weight.data[position,2] = -1
|
|
|
|
|
self.pos_proj.weight.data[position,1] = 2*(h1-center)*locality_distance
|
|
|
|
|
self.pos_proj.weight.data[position,0] = 2*(h2-center)*locality_distance
|
|
|
|
|
self.pos_proj.weight.data *= locality_strength
|
|
|
|
|
|
|
|
|
|
def get_rel_indices(self, num_patches):
|
|
|
|
|
img_size = int(num_patches**.5)
|
|
|
|
|
rel_indices = torch.zeros(1, num_patches, num_patches, 3)
|
|
|
|
|
ind = torch.arange(img_size).view(1,-1) - torch.arange(img_size).view(-1, 1)
|
|
|
|
|
indx = ind.repeat(img_size,img_size)
|
|
|
|
|
indy = ind.repeat_interleave(img_size,dim=0).repeat_interleave(img_size,dim=1)
|
|
|
|
|
indd = indx**2 + indy**2
|
|
|
|
|
rel_indices[:,:,:,2] = indd.unsqueeze(0)
|
|
|
|
|
rel_indices[:,:,:,1] = indy.unsqueeze(0)
|
|
|
|
|
rel_indices[:,:,:,0] = indx.unsqueeze(0)
|
|
|
|
|
position = h1 + kernel_size * h2
|
|
|
|
|
self.pos_proj.weight.data[position, 2] = -1
|
|
|
|
|
self.pos_proj.weight.data[position, 1] = 2 * (h1 - center) * locality_distance
|
|
|
|
|
self.pos_proj.weight.data[position, 0] = 2 * (h2 - center) * locality_distance
|
|
|
|
|
self.pos_proj.weight.data *= self.locality_strength
|
|
|
|
|
|
|
|
|
|
def get_rel_indices(self, num_patches: int) -> torch.Tensor:
|
|
|
|
|
img_size = int(num_patches ** .5)
|
|
|
|
|
rel_indices = torch.zeros(1, num_patches, num_patches, 3)
|
|
|
|
|
ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
|
|
|
|
|
indx = ind.repeat(img_size, img_size)
|
|
|
|
|
indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
|
|
|
|
|
indd = indx ** 2 + indy ** 2
|
|
|
|
|
rel_indices[:, :, :, 2] = indd.unsqueeze(0)
|
|
|
|
|
rel_indices[:, :, :, 1] = indy.unsqueeze(0)
|
|
|
|
|
rel_indices[:, :, :, 0] = indx.unsqueeze(0)
|
|
|
|
|
device = self.qk.weight.device
|
|
|
|
|
self.rel_indices = rel_indices.to(device)
|
|
|
|
|
return rel_indices.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MHSA(nn.Module):
|
|
|
|
|
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
|
|
|
|
super().__init__()
|
|
|
|
@ -176,41 +152,28 @@ class MHSA(nn.Module):
|
|
|
|
|
self.attn_drop = nn.Dropout(attn_drop)
|
|
|
|
|
self.proj = nn.Linear(dim, dim)
|
|
|
|
|
self.proj_drop = nn.Dropout(proj_drop)
|
|
|
|
|
self.apply(self._init_weights)
|
|
|
|
|
|
|
|
|
|
def _init_weights(self, m):
|
|
|
|
|
if isinstance(m, nn.Linear):
|
|
|
|
|
trunc_normal_(m.weight, std=.02)
|
|
|
|
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
|
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
|
elif isinstance(m, nn.LayerNorm):
|
|
|
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
|
nn.init.constant_(m.weight, 1.0)
|
|
|
|
|
|
|
|
|
|
def get_attention_map(self, x, return_map = False):
|
|
|
|
|
def get_attention_map(self, x, return_map=False):
|
|
|
|
|
B, N, C = x.shape
|
|
|
|
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
|
|
|
q, k, v = qkv[0], qkv[1], qkv[2]
|
|
|
|
|
attn_map = (q @ k.transpose(-2, -1)) * self.scale
|
|
|
|
|
attn_map = attn_map.softmax(dim=-1).mean(0)
|
|
|
|
|
|
|
|
|
|
img_size = int(N**.5)
|
|
|
|
|
ind = torch.arange(img_size).view(1,-1) - torch.arange(img_size).view(-1, 1)
|
|
|
|
|
indx = ind.repeat(img_size,img_size)
|
|
|
|
|
indy = ind.repeat_interleave(img_size,dim=0).repeat_interleave(img_size,dim=1)
|
|
|
|
|
indd = indx**2 + indy**2
|
|
|
|
|
distances = indd**.5
|
|
|
|
|
img_size = int(N ** .5)
|
|
|
|
|
ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
|
|
|
|
|
indx = ind.repeat(img_size, img_size)
|
|
|
|
|
indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
|
|
|
|
|
indd = indx ** 2 + indy ** 2
|
|
|
|
|
distances = indd ** .5
|
|
|
|
|
distances = distances.to('cuda')
|
|
|
|
|
|
|
|
|
|
dist = torch.einsum('nm,hnm->h', (distances, attn_map))
|
|
|
|
|
dist /= N
|
|
|
|
|
|
|
|
|
|
dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / N
|
|
|
|
|
if return_map:
|
|
|
|
|
return dist, attn_map
|
|
|
|
|
else:
|
|
|
|
|
return dist
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
B, N, C = x.shape
|
|
|
|
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
|
|
@ -228,15 +191,19 @@ class MHSA(nn.Module):
|
|
|
|
|
|
|
|
|
|
class Block(nn.Module):
|
|
|
|
|
|
|
|
|
|
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
|
|
|
|
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
|
|
|
|
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
|
|
|
self.use_gpsa = use_gpsa
|
|
|
|
|
if self.use_gpsa:
|
|
|
|
|
self.attn = GPSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, **kwargs)
|
|
|
|
|
self.attn = GPSA(
|
|
|
|
|
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
|
|
|
|
|
proj_drop=drop, **kwargs)
|
|
|
|
|
else:
|
|
|
|
|
self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, **kwargs)
|
|
|
|
|
self.attn = MHSA(
|
|
|
|
|
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
|
|
|
|
|
proj_drop=drop, **kwargs)
|
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
|
self.norm2 = norm_layer(dim)
|
|
|
|
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
|
|
|
@ -246,75 +213,12 @@ class Block(nn.Module):
|
|
|
|
|
x = x + self.drop_path(self.attn(self.norm1(x)))
|
|
|
|
|
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PatchEmbed(nn.Module):
|
|
|
|
|
""" Image to Patch Embedding, from timm
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
|
|
|
|
super().__init__()
|
|
|
|
|
img_size = to_2tuple(img_size)
|
|
|
|
|
patch_size = to_2tuple(patch_size)
|
|
|
|
|
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
|
|
|
|
self.img_size = img_size
|
|
|
|
|
self.patch_size = patch_size
|
|
|
|
|
self.num_patches = num_patches
|
|
|
|
|
|
|
|
|
|
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
|
|
|
|
self.apply(self._init_weights)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
|
assert H == self.img_size[0] and W == self.img_size[1], \
|
|
|
|
|
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
|
|
|
|
x = self.proj(x).flatten(2).transpose(1, 2)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def _init_weights(self, m):
|
|
|
|
|
if isinstance(m, nn.Linear):
|
|
|
|
|
trunc_normal_(m.weight, std=.02)
|
|
|
|
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
|
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
|
elif isinstance(m, nn.LayerNorm):
|
|
|
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
|
nn.init.constant_(m.weight, 1.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HybridEmbed(nn.Module):
|
|
|
|
|
""" CNN Feature Map Embedding, from timm
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
|
|
|
|
|
super().__init__()
|
|
|
|
|
assert isinstance(backbone, nn.Module)
|
|
|
|
|
img_size = to_2tuple(img_size)
|
|
|
|
|
self.img_size = img_size
|
|
|
|
|
self.backbone = backbone
|
|
|
|
|
if feature_size is None:
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
training = backbone.training
|
|
|
|
|
if training:
|
|
|
|
|
backbone.eval()
|
|
|
|
|
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
|
|
|
|
|
feature_size = o.shape[-2:]
|
|
|
|
|
feature_dim = o.shape[1]
|
|
|
|
|
backbone.train(training)
|
|
|
|
|
else:
|
|
|
|
|
feature_size = to_2tuple(feature_size)
|
|
|
|
|
feature_dim = self.backbone.feature_info.channels()[-1]
|
|
|
|
|
self.num_patches = feature_size[0] * feature_size[1]
|
|
|
|
|
self.proj = nn.Linear(feature_dim, embed_dim)
|
|
|
|
|
self.apply(self._init_weights)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
x = self.backbone(x)[-1]
|
|
|
|
|
x = x.flatten(2).transpose(1, 2)
|
|
|
|
|
x = self.proj(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConViT(nn.Module):
|
|
|
|
|
""" Vision Transformer with support for patch or hybrid CNN input stage
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
|
|
|
|
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
|
|
|
|
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None,
|
|
|
|
@ -335,7 +239,7 @@ class ConViT(nn.Module):
|
|
|
|
|
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
|
|
|
|
num_patches = self.patch_embed.num_patches
|
|
|
|
|
self.num_patches = num_patches
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
|
|
|
|
self.pos_drop = nn.Dropout(p=drop_rate)
|
|
|
|
|
|
|
|
|
@ -350,7 +254,7 @@ class ConViT(nn.Module):
|
|
|
|
|
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
|
|
|
|
use_gpsa=True,
|
|
|
|
|
locality_strength=locality_strength)
|
|
|
|
|
if i<local_up_to_layer else
|
|
|
|
|
if i < local_up_to_layer else
|
|
|
|
|
Block(
|
|
|
|
|
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
|
|
|
|
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
|
|
|
@ -363,7 +267,10 @@ class ConViT(nn.Module):
|
|
|
|
|
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
|
|
|
|
|
|
|
|
|
trunc_normal_(self.cls_token, std=.02)
|
|
|
|
|
self.head.apply(self._init_weights)
|
|
|
|
|
self.apply(self._init_weights)
|
|
|
|
|
for n, m in self.named_modules():
|
|
|
|
|
if hasattr(m, 'local_init'):
|
|
|
|
|
m.local_init()
|
|
|
|
|
|
|
|
|
|
def _init_weights(self, m):
|
|
|
|
|
if isinstance(m, nn.Linear):
|
|
|
|
@ -395,8 +302,8 @@ class ConViT(nn.Module):
|
|
|
|
|
x = x + self.pos_embed
|
|
|
|
|
x = self.pos_drop(x)
|
|
|
|
|
|
|
|
|
|
for u,blk in enumerate(self.blocks):
|
|
|
|
|
if u == self.local_up_to_layer :
|
|
|
|
|
for u, blk in enumerate(self.blocks):
|
|
|
|
|
if u == self.local_up_to_layer:
|
|
|
|
|
x = torch.cat((cls_tokens, x), dim=1)
|
|
|
|
|
x = blk(x)
|
|
|
|
|
|
|
|
|
@ -415,30 +322,29 @@ def _create_convit(variant, pretrained=False, **kwargs):
|
|
|
|
|
default_cfg=default_cfgs[variant],
|
|
|
|
|
**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def convit_tiny(pretrained=False, **kwargs):
|
|
|
|
|
model_args = dict(
|
|
|
|
|
local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
|
|
|
|
|
local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
|
|
|
|
|
num_heads=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
|
|
|
|
model = _create_convit(
|
|
|
|
|
variant='convit_tiny', pretrained=pretrained, **model_args)
|
|
|
|
|
model = _create_convit(variant='convit_tiny', pretrained=pretrained, **model_args)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def convit_small(pretrained=False, **kwargs):
|
|
|
|
|
model_args = dict(
|
|
|
|
|
local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
|
|
|
|
|
local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
|
|
|
|
|
num_heads=9, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
|
|
|
|
model = _create_convit(
|
|
|
|
|
variant='convit_small', pretrained=pretrained, **model_args)
|
|
|
|
|
model = _create_convit(variant='convit_small', pretrained=pretrained, **model_args)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def convit_base(pretrained=False, **kwargs):
|
|
|
|
|
model_args = dict(
|
|
|
|
|
local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
|
|
|
|
|
local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
|
|
|
|
|
num_heads=16, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
|
|
|
|
model = _create_convit(
|
|
|
|
|
variant='convit_base', pretrained=pretrained, **model_args)
|
|
|
|
|
model = _create_convit(variant='convit_base', pretrained=pretrained, **model_args)
|
|
|
|
|
return model
|
|
|
|
|