@@ -26,6 +26,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.hub
 from functools import partial
+from typing import List

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
@@ -135,23 +136,16 @@ class CrossAttention(nn.Module):
 class CrossAttentionBlock(nn.Module):

     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
-                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, has_mlp=True):
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = CrossAttention(
             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.has_mlp = has_mlp
-        if has_mlp:
-            self.norm2 = norm_layer(dim)
-            mlp_hidden_dim = int(dim * mlp_ratio)
-            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

     def forward(self, x):
         x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x)))
-        if self.has_mlp:
-            x = x + self.drop_path(self.mlp(self.norm2(x)))

         return x

@@ -192,14 +186,12 @@ class MultiScaleBlock(nn.Module):
             nh = num_heads[d_]
             if depth[-1] == 0: # backward capability:
                 self.fusion.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale,
-                                                       drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer,
-                                                       has_mlp=False))
+                                                       drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
             else:
                 tmp = []
                 for _ in range(depth[-1]):
                     tmp.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale,
-                                                   drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer,
-                                                   has_mlp=False))
+                                                   drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
                 self.fusion.append(nn.Sequential(*tmp))

         self.revert_projs = nn.ModuleList()
@@ -210,16 +202,23 @@ class MultiScaleBlock(nn.Module):
                 tmp = [norm_layer(dim[(d+1) % num_branches]), act_layer(), nn.Linear(dim[(d+1) % num_branches], dim[d])]
             self.revert_projs.append(nn.Sequential(*tmp))

-    def forward(self, x):
-        outs_b = [block(x_) for x_, block in zip(x, self.blocks)]
+    def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
+
+        outs_b = []
+        for i, block in enumerate(self.blocks):
+            outs_b.append(block(x[i]))
+
         # only take the cls token out
-        proj_cls_token = [proj(x[:, 0:1]) for x, proj in zip(outs_b, self.projs)]
+        proj_cls_token = torch.jit.annotate(List[torch.Tensor], [])
+        for i, proj in enumerate(self.projs):
+            proj_cls_token.append(proj(outs_b[i][:, 0:1, ...]))
+
         # cross attention
         outs = []
-        for i in range(self.num_branches):
+        for i, (fusion, revert_proj) in enumerate(zip(self.fusion, self.revert_projs)):
             tmp = torch.cat((proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]), dim=1)
-            tmp = self.fusion[i](tmp)
-            reverted_proj_cls_token = self.revert_projs[i](tmp[:, 0:1, ...])
+            tmp = fusion(tmp)
+            reverted_proj_cls_token = revert_proj(tmp[:, 0:1, ...])
             tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1)
             outs.append(tmp)
         return outs
@@ -246,11 +245,15 @@ class CrossViT(nn.Module):
         self.num_branches = len(patch_size)

         self.patch_embed = nn.ModuleList()
-        self.pos_embed = nn.ParameterList([nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])) for i in range(self.num_branches)])
+
+        # hard-coded for torch jit script
+        for i in range(self.num_branches):
+            setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])))
+            setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i])))
+
         for im_s, p, d in zip(img_size, patch_size, embed_dim):
             self.patch_embed.append(PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=multi_conv))

-        self.cls_token = nn.ParameterList([nn.Parameter(torch.zeros(1, 1, embed_dim[i])) for i in range(self.num_branches)])
         self.pos_drop = nn.Dropout(p=drop_rate)

         total_depth = sum([sum(x[-2:]) for x in depth])
@@ -270,9 +273,10 @@ class CrossViT(nn.Module):
         self.head = nn.ModuleList([nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)])

         for i in range(self.num_branches):
-            if self.pos_embed[i].requires_grad:
-                trunc_normal_(self.pos_embed[i], std=.02)
-            trunc_normal_(self.cls_token[i], std=.02)
+            if hasattr(self, f'pos_embed_{i}'):
+                # if self.pos_embed[i].requires_grad:
+                trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
+            trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)

         self.apply(self._init_weights)

@@ -302,27 +306,29 @@ class CrossViT(nn.Module):
     def forward_features(self, x):
         B, C, H, W = x.shape
         xs = []
-        for i in range(self.num_branches):
+        for i, patch_embed in enumerate(self.patch_embed):
             x_ = torch.nn.functional.interpolate(x, size=(self.img_size[i], self.img_size[i]), mode='bicubic') if H != self.img_size[i] else x
-            tmp = self.patch_embed[i](x_)
-            cls_tokens = self.cls_token[i].expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+            tmp = patch_embed(x_)
+            cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
+            cls_tokens = cls_tokens.expand(B, -1, -1)
             tmp = torch.cat((cls_tokens, tmp), dim=1)
-            tmp = tmp + self.pos_embed[i]
+            pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1 # hard-coded for torch jit script
+            tmp = tmp + pos_embed
             tmp = self.pos_drop(tmp)
             xs.append(tmp)

-        for blk in self.blocks:
+        for i, blk in enumerate(self.blocks):
             xs = blk(xs)

         # NOTE: was before branch token section, move to here to assure all branch token are before layer norm
-        xs = [self.norm[i](x) for i, x in enumerate(xs)]
+        xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
         out = [x[:, 0] for x in xs]

         return out

     def forward(self, x):
         xs = self.forward_features(x)
-        ce_logits = [self.head[i](x) for i, x in enumerate(xs)]
+        ce_logits = [head(xs[i]) for i, head in enumerate(self.head)]
         ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0)
         return ce_logits

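Taken together, the hunks above swap out the constructs TorchScript cannot compile (indexing nn.ParameterList with a loop variable, an untyped list argument to forward, list comprehensions over module containers) for script-friendly equivalents: per-branch pos_embed_{i}/cls_token_{i} attributes, List[torch.Tensor] annotations, and torch.jit.annotate'd lists. A minimal usage sketch of what this enables, assuming a timm build that already includes the patch; the variant name crossvit_15_240 is one of timm's registered CrossViT models and is only used here as an example:

    import torch
    import timm

    # build a registered CrossViT variant; pretrained weights are not needed to script it
    model = timm.create_model('crossvit_15_240', pretrained=False).eval()

    # with the typed forward() and the hard-coded pos_embed_{i}/cls_token_{i} attributes,
    # torch.jit.script should now compile the whole model
    scripted = torch.jit.script(model)

    # the scripted and eager paths should produce the same averaged two-branch logits
    x = torch.randn(1, 3, 240, 240)
    with torch.no_grad():
        assert torch.allclose(scripted(x), model(x), atol=1e-5)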