diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py
index 12eebdc5..6e0160f9 100644
--- a/timm/models/crossvit.py
+++ b/timm/models/crossvit.py
@@ -40,7 +40,7 @@ from .vision_transformer import Mlp, Block
 def _cfg(url='', **kwargs):
     return {
         'url': url,
-        'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None,
+        'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None,
         'crop_pct': 0.875, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
         'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'),
         'classifier': ('head.0', 'head.1'),
@@ -56,7 +56,7 @@ default_cfgs = {
     ),
     'crossvit_15_dagger_408': _cfg(
         url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth',
-        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
+        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
     ),
     'crossvit_18_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_224.pth'),
     'crossvit_18_dagger_240': _cfg(
@@ -65,7 +65,7 @@ default_cfgs = {
     ),
     'crossvit_18_dagger_408': _cfg(
         url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_384.pth',
-        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
+        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
     ),
     'crossvit_9_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_224.pth'),
     'crossvit_9_dagger_240': _cfg(
@@ -263,7 +263,7 @@ class CrossViT(nn.Module):
             self, img_size=224, img_scale=(1.0, 1.0), patch_size=(8, 16), in_chans=3, num_classes=1000,
             embed_dim=(192, 384), depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)), num_heads=(6, 12), mlp_ratio=(2., 2., 4.),
             qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
-            norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=False
+            norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=False, crop_scale=False,
     ):
         super().__init__()
@@ -271,6 +271,7 @@ class CrossViT(nn.Module):
         self.img_size = to_2tuple(img_size)
         img_scale = to_2tuple(img_scale)
         self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale]
+        self.crop_scale = crop_scale  # crop instead of interpolate for scale
         num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
         self.num_branches = len(patch_size)
         self.embed_dim = embed_dim
@@ -307,8 +308,7 @@ class CrossViT(nn.Module):
             for i in range(self.num_branches)])
 
         for i in range(self.num_branches):
-            if hasattr(self, f'pos_embed_{i}'):
-                trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
+            trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
             trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)
 
         self.apply(self._init_weights)
@@ -324,9 +324,12 @@ class CrossViT(nn.Module):
 
     @torch.jit.ignore
     def no_weight_decay(self):
-        out = {'cls_token'}
-        if self.pos_embed[0].requires_grad:
-            out.add('pos_embed')
+        out = set()
+        for i in range(self.num_branches):
+            out.add(f'cls_token_{i}')
+            pe = getattr(self, f'pos_embed_{i}', None)
+            if pe is not None and pe.requires_grad:
+                out.add(f'pos_embed_{i}')
         return out
 
     def get_classifier(self):
@@ -342,23 +345,29 @@ class CrossViT(nn.Module):
         B, C, H, W = x.shape
         xs = []
         for i, patch_embed in enumerate(self.patch_embed):
+            x_ = x
             ss = self.img_size_scaled[i]
-            x_ = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False) if H != ss[0] else x
-            tmp = patch_embed(x_)
+            if H != ss[0] or W != ss[1]:
+                if self.crop_scale and ss[0] <= H and ss[1] <= W:
+                    cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
+                    x_ = x_[:, :, cu:cu + ss[0], cl:cl + ss[1]]
+                else:
+                    x_ = torch.nn.functional.interpolate(x_, size=ss, mode='bicubic', align_corners=False)
+            x_ = patch_embed(x_)
             cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1  # hard-coded for torch jit script
             cls_tokens = cls_tokens.expand(B, -1, -1)
-            tmp = torch.cat((cls_tokens, tmp), dim=1)
+            x_ = torch.cat((cls_tokens, x_), dim=1)
             pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1  # hard-coded for torch jit script
-            tmp = tmp + pos_embed
-            tmp = self.pos_drop(tmp)
-            xs.append(tmp)
+            x_ = x_ + pos_embed
+            x_ = self.pos_drop(x_)
+            xs.append(x_)
 
         for i, blk in enumerate(self.blocks):
             xs = blk(xs)
 
         # NOTE: was before branch token section, move to here to assure all branch token are before layer norm
         xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
-        return [x[:, 0] for x in xs]
+        return [xo[:, 0] for xo in xs]
 
     def forward(self, x):
         xs = self.forward_features(x)
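
For reference, a minimal runnable sketch of the per-branch scaling behavior the patch introduces (the helper name `scale_image` is hypothetical, not part of timm): with `crop_scale=True`, a branch input is center-cropped to its scaled size whenever that size fits inside the incoming image; otherwise it falls back to the bicubic interpolation that was previously the only path.

    import torch
    import torch.nn.functional as F

    def scale_image(x, ss, crop_scale=False):
        # Mirrors the scaling step in CrossViT.forward_features above.
        # x: (B, C, H, W) input batch; ss: (h, w) target size for one branch.
        H, W = x.shape[-2:]
        if H != ss[0] or W != ss[1]:
            if crop_scale and ss[0] <= H and ss[1] <= W:
                # center crop, offsets rounded to the nearest pixel
                cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
                x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]]
            else:
                # default path: bicubic resize
                x = F.interpolate(x, size=ss, mode='bicubic', align_corners=False)
        return x

    x = torch.randn(2, 3, 240, 240)
    assert scale_image(x, (224, 224), crop_scale=True).shape == (2, 3, 224, 224)
    assert scale_image(x, (240, 240)) is x  # size already matches: no-op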
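
The reworked `no_weight_decay` now reports the real per-branch parameter names (`cls_token_0`, `pos_embed_0`, ...) instead of the nonexistent `cls_token`/`pos_embed`, so weight-decay exclusion actually matches the registered parameters. A sketch of how such a set is commonly consumed when building optimizer parameter groups (an assumed usage pattern, not timm's exact optimizer factory):

    import torch

    def param_groups(model, weight_decay=0.05):
        # Split parameters into decay / no-decay groups using the names
        # returned by model.no_weight_decay().
        skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else set()
        decay, no_decay = [], []
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            (no_decay if name in skip else decay).append(p)
        return [
            {'params': decay, 'weight_decay': weight_decay},
            {'params': no_decay, 'weight_decay': 0.},
        ]

    # e.g. optimizer = torch.optim.AdamW(param_groups(model), lr=5e-4)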