Fix for TorchScript

pull/841/head
Richard Chen 3 years ago
parent 7ab9d4555c
commit bb50b69a57

@@ -26,6 +26,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.hub
from functools import partial
from typing import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
@@ -135,23 +136,16 @@ class CrossAttention(nn.Module):
class CrossAttentionBlock(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, has_mlp=True):
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = CrossAttention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.has_mlp = has_mlp
if has_mlp:
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x)))
if self.has_mlp:
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
@@ -192,14 +186,12 @@ class MultiScaleBlock(nn.Module):
nh = num_heads[d_]
if depth[-1] == 0: # backward capability:
self.fusion.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer,
has_mlp=False))
drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
else:
tmp = []
for _ in range(depth[-1]):
tmp.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer,
has_mlp=False))
drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
self.fusion.append(nn.Sequential(*tmp))
self.revert_projs = nn.ModuleList()
@@ -210,16 +202,23 @@ class MultiScaleBlock(nn.Module):
tmp = [norm_layer(dim[(d+1) % num_branches]), act_layer(), nn.Linear(dim[(d+1) % num_branches], dim[d])]
self.revert_projs.append(nn.Sequential(*tmp))
def forward(self, x):
outs_b = [block(x_) for x_, block in zip(x, self.blocks)]
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
outs_b = []
for i, block in enumerate(self.blocks):
outs_b.append(block(x[i]))
# only take the cls token out
proj_cls_token = [proj(x[:, 0:1]) for x, proj in zip(outs_b, self.projs)]
proj_cls_token = torch.jit.annotate(List[torch.Tensor], [])
for i, proj in enumerate(self.projs):
proj_cls_token.append(proj(outs_b[i][:, 0:1, ...]))
# cross attention
outs = []
for i in range(self.num_branches):
for i, (fusion, revert_proj) in enumerate(zip(self.fusion, self.revert_projs)):
tmp = torch.cat((proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]), dim=1)
tmp = self.fusion[i](tmp)
reverted_proj_cls_token = self.revert_projs[i](tmp[:, 0:1, ...])
tmp = fusion(tmp)
reverted_proj_cls_token = revert_proj(tmp[:, 0:1, ...])
tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1)
outs.append(tmp)
return outs
@@ -246,11 +245,15 @@ class CrossViT(nn.Module):
self.num_branches = len(patch_size)
self.patch_embed = nn.ModuleList()
self.pos_embed = nn.ParameterList([nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])) for i in range(self.num_branches)])
# hard-coded for torch jit script
for i in range(self.num_branches):
setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])))
setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i])))
for im_s, p, d in zip(img_size, patch_size, embed_dim):
self.patch_embed.append(PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=multi_conv))
self.cls_token = nn.ParameterList([nn.Parameter(torch.zeros(1, 1, embed_dim[i])) for i in range(self.num_branches)])
self.pos_drop = nn.Dropout(p=drop_rate)
total_depth = sum([sum(x[-2:]) for x in depth])
@@ -270,9 +273,10 @@ class CrossViT(nn.Module):
self.head = nn.ModuleList([nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)])
for i in range(self.num_branches):
if self.pos_embed[i].requires_grad:
trunc_normal_(self.pos_embed[i], std=.02)
trunc_normal_(self.cls_token[i], std=.02)
if hasattr(self, f'pos_embed_{i}'):
# if self.pos_embed[i].requires_grad:
trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)
self.apply(self._init_weights)
@@ -302,27 +306,29 @@ class CrossViT(nn.Module):
def forward_features(self, x):
B, C, H, W = x.shape
xs = []
for i in range(self.num_branches):
for i, patch_embed in enumerate(self.patch_embed):
x_ = torch.nn.functional.interpolate(x, size=(self.img_size[i], self.img_size[i]), mode='bicubic') if H != self.img_size[i] else x
tmp = self.patch_embed[i](x_)
cls_tokens = self.cls_token[i].expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
tmp = patch_embed(x_)
cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
cls_tokens = cls_tokens.expand(B, -1, -1)
tmp = torch.cat((cls_tokens, tmp), dim=1)
tmp = tmp + self.pos_embed[i]
pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1 # hard-coded for torch jit script
tmp = tmp + pos_embed
tmp = self.pos_drop(tmp)
xs.append(tmp)
for blk in self.blocks:
for i, blk in enumerate(self.blocks):
xs = blk(xs)
# NOTE: was before branch token section, move to here to assure all branch token are before layer norm
xs = [self.norm[i](x) for i, x in enumerate(xs)]
xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
out = [x[:, 0] for x in xs]
return out
def forward(self, x):
xs = self.forward_features(x)
ce_logits = [self.head[i](x) for i, x in enumerate(xs)]
ce_logits = [head(xs[i]) for i, head in enumerate(self.head)]
ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0)
return ce_logits

Loading…
Cancel
Save