From 7ab9d4555c50ff07d61e78e128fbd667816041fd Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Wed, 1 Sep 2021 17:13:12 -0400 Subject: [PATCH 1/4] add crossvit --- tests/test_models.py | 2 +- timm/models/__init__.py | 1 + timm/models/crossvit.py | 443 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 445 insertions(+), 1 deletion(-) create mode 100644 timm/models/crossvit.py diff --git a/tests/test_models.py b/tests/test_models.py index 50838c8a..d06f306b 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -17,7 +17,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): # transformer models don't support many of the spatial / feature based model functionalities NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', - 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*'] + 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*'] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 56c812d4..843e9ae0 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -1,6 +1,7 @@ from .byoanet import * from .byobnet import * from .cait import * +from .crossvit import * from .coat import * from .convit import * from .cspnet import * diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py new file mode 100644 index 00000000..6543fe35 --- /dev/null +++ b/timm/models/crossvit.py @@ -0,0 +1,443 @@ +""" CrossViT Model + +@inproceedings{ + chen2021crossvit, + title={{CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification}}, + author={Chun-Fu (Richard) Chen and Quanfu Fan and Rameswar Panda}, + booktitle={International Conference on Computer Vision (ICCV)}, + year={2021} +} + +Paper link: https://arxiv.org/abs/2103.14899 +Original code: https://github.com/IBM/CrossViT/blob/main/models/crossvit.py +""" + +# Copyright IBM All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + + +""" +Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.hub +from functools import partial + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import DropPath, to_2tuple, trunc_normal_ +from .registry import register_model +from .vision_transformer import Mlp, Block + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True, + # 'first_conv': 'patch_embed.proj', + 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'crossvit_15_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_224.pth'), + 'crossvit_15_dagger_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_224.pth'), + 'crossvit_15_dagger_384': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth'), + 'crossvit_18_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_224.pth'), + 'crossvit_18_dagger_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_224.pth'), + 'crossvit_18_dagger_384': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_384.pth'), + 'crossvit_9_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_224.pth'), + 'crossvit_9_dagger_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_dagger_224.pth'), + 'crossvit_base_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_base_224.pth'), + 'crossvit_small_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_small_224.pth'), + 'crossvit_tiny_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_tiny_224.pth'), +} + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi_conv=False): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + if multi_conv: + if patch_size[0] == 12: + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3), + nn.ReLU(inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0), + nn.ReLU(inplace=True), + nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1), + ) + elif patch_size[0] == 16: + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3), + nn.ReLU(inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1), + ) + else: + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class CrossAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.wq = nn.Linear(dim, dim, bias=qkv_bias) + self.wk = nn.Linear(dim, dim, bias=qkv_bias) + self.wv = nn.Linear(dim, dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + + B, N, C = x.shape + q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # B1C -> B1H(C/H) -> BH1(C/H) + k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H) + v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H) + + attn = (q @ k.transpose(-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CrossAttentionBlock(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, has_mlp=True): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = CrossAttention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.has_mlp = has_mlp + if has_mlp: + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x))) + if self.has_mlp: + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class MultiScaleBlock(nn.Module): + + def __init__(self, dim, patches, depth, num_heads, mlp_ratio, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + + num_branches = len(dim) + self.num_branches = num_branches + # different branch could have different embedding size, the first one is the base + self.blocks = nn.ModuleList() + for d in range(num_branches): + tmp = [] + for i in range(depth[d]): + tmp.append( + Block(dim=dim[d], num_heads=num_heads[d], mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, drop_path=drop_path[i], norm_layer=norm_layer)) + if len(tmp) != 0: + self.blocks.append(nn.Sequential(*tmp)) + + if len(self.blocks) == 0: + self.blocks = None + + self.projs = nn.ModuleList() + for d in range(num_branches): + if dim[d] == dim[(d+1) % num_branches] and False: + tmp = [nn.Identity()] + else: + tmp = [norm_layer(dim[d]), act_layer(), nn.Linear(dim[d], dim[(d+1) % num_branches])] + self.projs.append(nn.Sequential(*tmp)) + + self.fusion = nn.ModuleList() + for d in range(num_branches): + d_ = (d+1) % num_branches + nh = num_heads[d_] + if depth[-1] == 0: # backward capability: + self.fusion.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer, + has_mlp=False)) + else: + tmp = [] + for _ in range(depth[-1]): + tmp.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer, + has_mlp=False)) + self.fusion.append(nn.Sequential(*tmp)) + + self.revert_projs = nn.ModuleList() + for d in range(num_branches): + if dim[(d+1) % num_branches] == dim[d] and False: + tmp = [nn.Identity()] + else: + tmp = [norm_layer(dim[(d+1) % num_branches]), act_layer(), nn.Linear(dim[(d+1) % num_branches], dim[d])] + self.revert_projs.append(nn.Sequential(*tmp)) + + def forward(self, x): + outs_b = [block(x_) for x_, block in zip(x, self.blocks)] + # only take the cls token out + proj_cls_token = [proj(x[:, 0:1]) for x, proj in zip(outs_b, self.projs)] + # cross attention + outs = [] + for i in range(self.num_branches): + tmp = torch.cat((proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]), dim=1) + tmp = self.fusion[i](tmp) + reverted_proj_cls_token = self.revert_projs[i](tmp[:, 0:1, ...]) + tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1) + outs.append(tmp) + return outs + + +def _compute_num_patches(img_size, patches): + return [i // p * i // p for i, p in zip(img_size,patches)] + + +class CrossViT(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + def __init__(self, img_size=(224, 224), patch_size=(8, 16), in_chans=3, num_classes=1000, embed_dim=(192, 384), depth=([1, 3, 1], [1, 3, 1], [1, 3, 1]), + num_heads=(6, 12), mlp_ratio=(2., 2., 4.), qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, multi_conv=False): + super().__init__() + + self.num_classes = num_classes + if not isinstance(img_size, list): + img_size = to_2tuple(img_size) + self.img_size = img_size + + num_patches = _compute_num_patches(img_size, patch_size) + self.num_branches = len(patch_size) + + self.patch_embed = nn.ModuleList() + self.pos_embed = nn.ParameterList([nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])) for i in range(self.num_branches)]) + for im_s, p, d in zip(img_size, patch_size, embed_dim): + self.patch_embed.append(PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=multi_conv)) + + self.cls_token = nn.ParameterList([nn.Parameter(torch.zeros(1, 1, embed_dim[i])) for i in range(self.num_branches)]) + self.pos_drop = nn.Dropout(p=drop_rate) + + total_depth = sum([sum(x[-2:]) for x in depth]) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)] # stochastic depth decay rule + dpr_ptr = 0 + self.blocks = nn.ModuleList() + for idx, block_cfg in enumerate(depth): + curr_depth = max(block_cfg[:-1]) + block_cfg[-1] + dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth] + blk = MultiScaleBlock(embed_dim, num_patches, block_cfg, num_heads=num_heads, mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr_, + norm_layer=norm_layer) + dpr_ptr += curr_depth + self.blocks.append(blk) + + self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)]) + self.head = nn.ModuleList([nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)]) + + for i in range(self.num_branches): + if self.pos_embed[i].requires_grad: + trunc_normal_(self.pos_embed[i], std=.02) + trunc_normal_(self.cls_token[i], std=.02) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + out = {'cls_token'} + if self.pos_embed[0].requires_grad: + out.add('pos_embed') + return out + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B, C, H, W = x.shape + xs = [] + for i in range(self.num_branches): + x_ = torch.nn.functional.interpolate(x, size=(self.img_size[i], self.img_size[i]), mode='bicubic') if H != self.img_size[i] else x + tmp = self.patch_embed[i](x_) + cls_tokens = self.cls_token[i].expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + tmp = torch.cat((cls_tokens, tmp), dim=1) + tmp = tmp + self.pos_embed[i] + tmp = self.pos_drop(tmp) + xs.append(tmp) + + for blk in self.blocks: + xs = blk(xs) + + # NOTE: was before branch token section, move to here to assure all branch token are before layer norm + xs = [self.norm[i](x) for i, x in enumerate(xs)] + out = [x[:, 0] for x in xs] + + return out + + def forward(self, x): + xs = self.forward_features(x) + ce_logits = [self.head[i](x) for i, x in enumerate(xs)] + ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0) + return ce_logits + + +def _create_crossvit(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + return build_model_with_cfg( + CrossViT, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + + +@register_model +def crossvit_tiny_224(pretrained=False, **kwargs): + model_args = dict( + img_size=[240, 224], patch_size=[12, 16], embed_dim=[96, 192], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]], + num_heads=[3, 3], mlp_ratio=[4, 4, 1], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model = _create_crossvit(variant='crossvit_tiny_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_small_224(pretrained=False, **kwargs): + model_args = dict(img_size=[240, 224], + patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]], + num_heads=[6, 6], mlp_ratio=[4, 4, 1], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model = _create_crossvit(variant='crossvit_small_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_base_224(pretrained=False, **kwargs): + model_args = dict(img_size=[240, 224], + patch_size=[12, 16], embed_dim=[384, 768], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]], + num_heads=[12, 12], mlp_ratio=[4, 4, 1], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model = _create_crossvit(variant='crossvit_base_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_9_224(pretrained=False, **kwargs): + model_args = dict(img_size=[240, 224], + patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]], + num_heads=[4, 4], mlp_ratio=[3, 3, 1], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model = _create_crossvit(variant='crossvit_9_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_15_224(pretrained=False, **kwargs): + model_args = dict(img_size=[240, 224], + patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]], + num_heads=[6, 6], mlp_ratio=[3, 3, 1], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model = _create_crossvit(variant='crossvit_15_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_18_224(pretrained=False, **kwargs): + model_args = dict(img_size=[240, 224], + patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]], + num_heads=[7, 7], mlp_ratio=[3, 3, 1], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model = _create_crossvit(variant='crossvit_18_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_9_dagger_224(pretrained=False, **kwargs): + model_args = dict(img_size=[240, 224], + patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]], + num_heads=[4, 4], mlp_ratio=[3, 3, 1], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs) + model = _create_crossvit(variant='crossvit_9_dagger_224', pretrained=pretrained, **model_args) + return model + +@register_model +def crossvit_15_dagger_224(pretrained=False, **kwargs): + model_args = dict(img_size=[240, 224], + patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]], + num_heads=[6, 6], mlp_ratio=[3, 3, 1], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs) + model = _create_crossvit(variant='crossvit_15_dagger_224', pretrained=pretrained, **model_args) + return model + +@register_model +def crossvit_15_dagger_384(pretrained=False, **kwargs): + model_args = dict(img_size=[408, 384], + patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]], + num_heads=[6, 6], mlp_ratio=[3, 3, 1], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs) + model = _create_crossvit(variant='crossvit_15_dagger_384', pretrained=pretrained, **model_args) + return model + +@register_model +def crossvit_18_dagger_224(pretrained=False, **kwargs): + model_args = dict(img_size=[240, 224], + patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]], + num_heads=[7, 7], mlp_ratio=[3, 3, 1], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs) + model = _create_crossvit(variant='crossvit_18_dagger_224', pretrained=pretrained, **model_args) + return model + +@register_model +def crossvit_18_dagger_384(pretrained=False, **kwargs): + model_args = dict(img_size=[408, 384], + patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]], + num_heads=[7, 7], mlp_ratio=[3, 3, 1], qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs) + model = _create_crossvit(variant='crossvit_18_dagger_384', pretrained=pretrained, **model_args) + return model From bb50b69a57229a3ee30bbd460539c9a45e508532 Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Wed, 8 Sep 2021 11:20:59 -0400 Subject: [PATCH 2/4] fix for torch script --- timm/models/crossvit.py | 66 ++++++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index 6543fe35..0873fdcc 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -26,6 +26,7 @@ import torch.nn as nn import torch.nn.functional as F import torch.hub from functools import partial +from typing import List from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg @@ -135,23 +136,16 @@ class CrossAttention(nn.Module): class CrossAttentionBlock(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, has_mlp=True): + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.norm1 = norm_layer(dim) self.attn = CrossAttention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.has_mlp = has_mlp - if has_mlp: - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) def forward(self, x): x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x))) - if self.has_mlp: - x = x + self.drop_path(self.mlp(self.norm2(x))) return x @@ -192,14 +186,12 @@ class MultiScaleBlock(nn.Module): nh = num_heads[d_] if depth[-1] == 0: # backward capability: self.fusion.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer, - has_mlp=False)) + drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer)) else: tmp = [] for _ in range(depth[-1]): tmp.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer, - has_mlp=False)) + drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer)) self.fusion.append(nn.Sequential(*tmp)) self.revert_projs = nn.ModuleList() @@ -210,16 +202,23 @@ class MultiScaleBlock(nn.Module): tmp = [norm_layer(dim[(d+1) % num_branches]), act_layer(), nn.Linear(dim[(d+1) % num_branches], dim[d])] self.revert_projs.append(nn.Sequential(*tmp)) - def forward(self, x): - outs_b = [block(x_) for x_, block in zip(x, self.blocks)] + def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: + + outs_b = [] + for i, block in enumerate(self.blocks): + outs_b.append(block(x[i])) + # only take the cls token out - proj_cls_token = [proj(x[:, 0:1]) for x, proj in zip(outs_b, self.projs)] + proj_cls_token = torch.jit.annotate(List[torch.Tensor], []) + for i, proj in enumerate(self.projs): + proj_cls_token.append(proj(outs_b[i][:, 0:1, ...])) + # cross attention outs = [] - for i in range(self.num_branches): + for i, (fusion, revert_proj) in enumerate(zip(self.fusion, self.revert_projs)): tmp = torch.cat((proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]), dim=1) - tmp = self.fusion[i](tmp) - reverted_proj_cls_token = self.revert_projs[i](tmp[:, 0:1, ...]) + tmp = fusion(tmp) + reverted_proj_cls_token = revert_proj(tmp[:, 0:1, ...]) tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1) outs.append(tmp) return outs @@ -246,11 +245,15 @@ class CrossViT(nn.Module): self.num_branches = len(patch_size) self.patch_embed = nn.ModuleList() - self.pos_embed = nn.ParameterList([nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])) for i in range(self.num_branches)]) + + # hard-coded for torch jit script + for i in range(self.num_branches): + setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i]))) + setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i]))) + for im_s, p, d in zip(img_size, patch_size, embed_dim): self.patch_embed.append(PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=multi_conv)) - self.cls_token = nn.ParameterList([nn.Parameter(torch.zeros(1, 1, embed_dim[i])) for i in range(self.num_branches)]) self.pos_drop = nn.Dropout(p=drop_rate) total_depth = sum([sum(x[-2:]) for x in depth]) @@ -270,9 +273,10 @@ class CrossViT(nn.Module): self.head = nn.ModuleList([nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)]) for i in range(self.num_branches): - if self.pos_embed[i].requires_grad: - trunc_normal_(self.pos_embed[i], std=.02) - trunc_normal_(self.cls_token[i], std=.02) + if hasattr(self, f'pos_embed_{i}'): + # if self.pos_embed[i].requires_grad: + trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02) + trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02) self.apply(self._init_weights) @@ -302,27 +306,29 @@ class CrossViT(nn.Module): def forward_features(self, x): B, C, H, W = x.shape xs = [] - for i in range(self.num_branches): + for i, patch_embed in enumerate(self.patch_embed): x_ = torch.nn.functional.interpolate(x, size=(self.img_size[i], self.img_size[i]), mode='bicubic') if H != self.img_size[i] else x - tmp = self.patch_embed[i](x_) - cls_tokens = self.cls_token[i].expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + tmp = patch_embed(x_) + cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script + cls_tokens = cls_tokens.expand(B, -1, -1) tmp = torch.cat((cls_tokens, tmp), dim=1) - tmp = tmp + self.pos_embed[i] + pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1 # hard-coded for torch jit script + tmp = tmp + pos_embed tmp = self.pos_drop(tmp) xs.append(tmp) - for blk in self.blocks: + for i, blk in enumerate(self.blocks): xs = blk(xs) # NOTE: was before branch token section, move to here to assure all branch token are before layer norm - xs = [self.norm[i](x) for i, x in enumerate(xs)] + xs = [norm(xs[i]) for i, norm in enumerate(self.norm)] out = [x[:, 0] for x in xs] return out def forward(self, x): xs = self.forward_features(x) - ce_logits = [self.head[i](x) for i, x in enumerate(xs)] + ce_logits = [head(xs[i]) for i, head in enumerate(self.head)] ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0) return ce_logits From 3718c5a5bd4fd814b13a73c4cb3116d91337b7ff Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Wed, 8 Sep 2021 11:53:05 -0400 Subject: [PATCH 3/4] fix loading pretrained model --- timm/models/crossvit.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index 0873fdcc..f9296b74 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -337,11 +337,23 @@ def _create_crossvit(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') + def pretrained_filter_fn(state_dict): + new_state_dict = {} + for key in state_dict.keys(): + if 'pos_embed' in key or 'cls_token' in key: + new_key = key.replace(".", "_") + else: + new_key = key + new_state_dict[new_key] = state_dict[key] + return new_state_dict + return build_model_with_cfg( CrossViT, variant, pretrained, default_cfg=default_cfgs[variant], + pretrained_filter_fn=pretrained_filter_fn, **kwargs) + @register_model def crossvit_tiny_224(pretrained=False, **kwargs): From 9fe5798beea786574fb52bf8a71174861f93e600 Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Wed, 8 Sep 2021 21:58:17 -0400 Subject: [PATCH 4/4] fix bug for reset classifier and fix for validating the dimension --- timm/models/crossvit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index f9296b74..ff529064 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -243,7 +243,8 @@ class CrossViT(nn.Module): num_patches = _compute_num_patches(img_size, patch_size) self.num_branches = len(patch_size) - + self.embed_dim = embed_dim + self.num_features = embed_dim[0] # to pass the tests self.patch_embed = nn.ModuleList() # hard-coded for torch jit script @@ -274,7 +275,6 @@ class CrossViT(nn.Module): for i in range(self.num_branches): if hasattr(self, f'pos_embed_{i}'): - # if self.pos_embed[i].requires_grad: trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02) trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02) @@ -301,7 +301,7 @@ class CrossViT(nn.Module): def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.ModuleList([nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)]) def forward_features(self, x): B, C, H, W = x.shape