diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 6cefda28..042efc05 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -107,7 +107,8 @@ class Attention(nn.Module): def forward(self, x): B, N, C = x.shape - q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) @@ -204,6 +205,9 @@ class VisionTransformer(nn.Module): num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm): super().__init__() + self.num_classes = num_classes + self.embed_dim = embed_dim + if hybrid_backbone is not None: self.patch_embed = HybridEmbed( hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) @@ -229,7 +233,7 @@ class VisionTransformer(nn.Module): #self.repr_act = nn.Tanh() # Classifier head - self.head = nn.Linear(embed_dim, num_classes) + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) @@ -244,11 +248,18 @@ class VisionTransformer(nn.Module): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) - @property + @torch.jit.ignore def no_weight_decay(self): return {'pos_embed', 'cls_token'} - def forward(self, x): + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): B = x.shape[0] x = self.patch_embed(x) @@ -261,7 +272,11 @@ class VisionTransformer(nn.Module): x = blk(x) x = self.norm(x) - x = self.head(x[:, 0]) + return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) return x @@ -284,7 +299,7 @@ def vit_small_patch16_224(pretrained=False, **kwargs): model.default_cfg = default_cfgs['vit_small_patch16_224'] if pretrained: load_pretrained( - model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter) + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter) return model @@ -297,7 +312,7 @@ def vit_base_patch16_224(pretrained=False, **kwargs): model.default_cfg = default_cfgs['vit_base_patch16_224'] if pretrained: load_pretrained( - model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter) + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter) return model @@ -308,8 +323,7 @@ def vit_base_patch16_384(pretrained=False, **kwargs): norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) model.default_cfg = default_cfgs['vit_base_patch16_384'] if pretrained: - load_pretrained( - model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) return model @@ -320,8 +334,7 @@ def vit_base_patch32_384(pretrained=False, **kwargs): norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) model.default_cfg = default_cfgs['vit_base_patch32_384'] if pretrained: - load_pretrained( - model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) return model @@ -339,8 +352,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs): norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) model.default_cfg = default_cfgs['vit_large_patch16_384'] if pretrained: - load_pretrained( - model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) return model @@ -351,8 +363,7 @@ def vit_large_patch32_384(pretrained=False, **kwargs): norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) model.default_cfg = default_cfgs['vit_large_patch32_384'] if pretrained: - load_pretrained( - model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3)) + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) return model diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index d3592e80..c4a43a2e 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -43,7 +43,7 @@ def create_optimizer(args, model, filter_bias_and_bn=True): if weight_decay and filter_bias_and_bn: skip = {} if hasattr(model, 'no_weight_decay'): - skip = model.no_weight_decay + skip = model.no_weight_decay() parameters = add_weight_decay(model, weight_decay, skip) weight_decay = 0. else: