From bb50b69a57229a3ee30bbd460539c9a45e508532 Mon Sep 17 00:00:00 2001
From: Richard Chen
Date: Wed, 8 Sep 2021 11:20:59 -0400
Subject: [PATCH] fix for torch script

---
 timm/models/crossvit.py | 66 ++++++++++++++++++++++-------------------
 1 file changed, 36 insertions(+), 30 deletions(-)

diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py
index 6543fe35..0873fdcc 100644
--- a/timm/models/crossvit.py
+++ b/timm/models/crossvit.py
@@ -26,6 +26,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.hub
 from functools import partial
+from typing import List
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
@@ -135,23 +136,16 @@ class CrossAttention(nn.Module):
 
 class CrossAttentionBlock(nn.Module):
     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
-                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, has_mlp=True):
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = CrossAttention(
             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.has_mlp = has_mlp
-        if has_mlp:
-            self.norm2 = norm_layer(dim)
-            mlp_hidden_dim = int(dim * mlp_ratio)
-            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
 
     def forward(self, x):
         x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x)))
-        if self.has_mlp:
-            x = x + self.drop_path(self.mlp(self.norm2(x)))
         return x
 
 
@@ -192,14 +186,12 @@ class MultiScaleBlock(nn.Module):
             nh = num_heads[d_]
             if depth[-1] == 0:  # backward capability:
                 self.fusion.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale,
-                                                       drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer,
-                                                       has_mlp=False))
+                                                       drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
             else:
                 tmp = []
                 for _ in range(depth[-1]):
                     tmp.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale,
-                                                   drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer,
-                                                   has_mlp=False))
+                                                   drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
                 self.fusion.append(nn.Sequential(*tmp))
 
         self.revert_projs = nn.ModuleList()
@@ -210,16 +202,23 @@ class MultiScaleBlock(nn.Module):
                 tmp = [norm_layer(dim[(d+1) % num_branches]), act_layer(), nn.Linear(dim[(d+1) % num_branches], dim[d])]
             self.revert_projs.append(nn.Sequential(*tmp))
 
-    def forward(self, x):
-        outs_b = [block(x_) for x_, block in zip(x, self.blocks)]
+    def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
+
+        outs_b = []
+        for i, block in enumerate(self.blocks):
+            outs_b.append(block(x[i]))
+
         # only take the cls token out
-        proj_cls_token = [proj(x[:, 0:1]) for x, proj in zip(outs_b, self.projs)]
+        proj_cls_token = torch.jit.annotate(List[torch.Tensor], [])
+        for i, proj in enumerate(self.projs):
+            proj_cls_token.append(proj(outs_b[i][:, 0:1, ...]))
+
         # cross attention
         outs = []
-        for i in range(self.num_branches):
+        for i, (fusion, revert_proj) in enumerate(zip(self.fusion, self.revert_projs)):
             tmp = torch.cat((proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]), dim=1)
-            tmp = self.fusion[i](tmp)
-            reverted_proj_cls_token = self.revert_projs[i](tmp[:, 0:1, ...])
+            tmp = fusion(tmp)
+            reverted_proj_cls_token = revert_proj(tmp[:, 0:1, ...])
             tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1)
             outs.append(tmp)
         return outs
@@ -246,11 +245,15 @@ class CrossViT(nn.Module):
         self.num_branches = len(patch_size)
 
         self.patch_embed = nn.ModuleList()
-        self.pos_embed = nn.ParameterList([nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])) for i in range(self.num_branches)])
+
+        # hard-coded for torch jit script
+        for i in range(self.num_branches):
+            setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])))
+            setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i])))
+
         for im_s, p, d in zip(img_size, patch_size, embed_dim):
             self.patch_embed.append(PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=multi_conv))
 
-        self.cls_token = nn.ParameterList([nn.Parameter(torch.zeros(1, 1, embed_dim[i])) for i in range(self.num_branches)])
         self.pos_drop = nn.Dropout(p=drop_rate)
 
         total_depth = sum([sum(x[-2:]) for x in depth])
@@ -270,9 +273,10 @@ class CrossViT(nn.Module):
         self.head = nn.ModuleList([nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)])
 
         for i in range(self.num_branches):
-            if self.pos_embed[i].requires_grad:
-                trunc_normal_(self.pos_embed[i], std=.02)
-            trunc_normal_(self.cls_token[i], std=.02)
+            if hasattr(self, f'pos_embed_{i}'):
+                # if self.pos_embed[i].requires_grad:
+                trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
+            trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)
 
         self.apply(self._init_weights)
 
@@ -302,27 +306,29 @@ class CrossViT(nn.Module):
     def forward_features(self, x):
         B, C, H, W = x.shape
         xs = []
-        for i in range(self.num_branches):
+        for i, patch_embed in enumerate(self.patch_embed):
             x_ = torch.nn.functional.interpolate(x, size=(self.img_size[i], self.img_size[i]), mode='bicubic') if H != self.img_size[i] else x
-            tmp = self.patch_embed[i](x_)
-            cls_tokens = self.cls_token[i].expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+            tmp = patch_embed(x_)
+            cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1  # hard-coded for torch jit script
+            cls_tokens = cls_tokens.expand(B, -1, -1)
             tmp = torch.cat((cls_tokens, tmp), dim=1)
-            tmp = tmp + self.pos_embed[i]
+            pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1  # hard-coded for torch jit script
+            tmp = tmp + pos_embed
            tmp = self.pos_drop(tmp)
            xs.append(tmp)
 
-        for blk in self.blocks:
+        for i, blk in enumerate(self.blocks):
             xs = blk(xs)
 
         # NOTE: was before branch token section, move to here to assure all branch token are before layer norm
-        xs = [self.norm[i](x) for i, x in enumerate(xs)]
+        xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
         out = [x[:, 0] for x in xs]
 
         return out
 
     def forward(self, x):
         xs = self.forward_features(x)
-        ce_logits = [self.head[i](x) for i, x in enumerate(xs)]
+        ce_logits = [head(xs[i]) for i, head in enumerate(self.head)]
         ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0)
         return ce_logits
 
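
Note (not part of the patch): a minimal sketch of how the change can be exercised, assuming a timm checkout with this patch applied and that 'crossvit_small_240' is one of the registered CrossViT variants; the model name, input size, and tolerance below are illustrative, not taken from the patch.

    import torch
    import timm

    # Build a CrossViT model from the patched source (no pretrained weights needed).
    model = timm.create_model('crossvit_small_240', pretrained=False).eval()

    # The List[torch.Tensor] annotations and the per-branch pos_embed_{i} /
    # cls_token_{i} attributes introduced above are what this scripting step relies on.
    scripted = torch.jit.script(model)

    # Sanity check: scripted and eager outputs should agree on the same input.
    x = torch.randn(1, 3, 240, 240)
    with torch.no_grad():
        assert torch.allclose(scripted(x), model(x), atol=1e-5)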