Fix #262, num_classes arg mixup. Make vision_transformers a bit closer to other models wrt get/reset classfier/forward_features. Fix torchscript for ViT.

pull/268/head
Ross Wightman 4 years ago
parent da1b90e5c9
commit f944242cb0

@ -107,7 +107,8 @@ class Attention(nn.Module):
def forward(self, x): def forward(self, x):
B, N, C = x.shape B, N, C = x.shape
q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1) attn = attn.softmax(dim=-1)
@ -204,6 +205,9 @@ class VisionTransformer(nn.Module):
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm): drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
super().__init__() super().__init__()
self.num_classes = num_classes
self.embed_dim = embed_dim
if hybrid_backbone is not None: if hybrid_backbone is not None:
self.patch_embed = HybridEmbed( self.patch_embed = HybridEmbed(
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
@ -229,7 +233,7 @@ class VisionTransformer(nn.Module):
#self.repr_act = nn.Tanh() #self.repr_act = nn.Tanh()
# Classifier head # Classifier head
self.head = nn.Linear(embed_dim, num_classes) self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02) trunc_normal_(self.cls_token, std=.02)
@ -244,11 +248,18 @@ class VisionTransformer(nn.Module):
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0) nn.init.constant_(m.weight, 1.0)
@property @torch.jit.ignore
def no_weight_decay(self): def no_weight_decay(self):
return {'pos_embed', 'cls_token'} return {'pos_embed', 'cls_token'}
def forward(self, x): def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
B = x.shape[0] B = x.shape[0]
x = self.patch_embed(x) x = self.patch_embed(x)
@ -261,7 +272,11 @@ class VisionTransformer(nn.Module):
x = blk(x) x = blk(x)
x = self.norm(x) x = self.norm(x)
x = self.head(x[:, 0]) return x[:, 0]
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x return x
@ -284,7 +299,7 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
model.default_cfg = default_cfgs['vit_small_patch16_224'] model.default_cfg = default_cfgs['vit_small_patch16_224']
if pretrained: if pretrained:
load_pretrained( load_pretrained(
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter) model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
return model return model
@ -297,7 +312,7 @@ def vit_base_patch16_224(pretrained=False, **kwargs):
model.default_cfg = default_cfgs['vit_base_patch16_224'] model.default_cfg = default_cfgs['vit_base_patch16_224']
if pretrained: if pretrained:
load_pretrained( load_pretrained(
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter) model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
return model return model
@ -308,8 +323,7 @@ def vit_base_patch16_384(pretrained=False, **kwargs):
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['vit_base_patch16_384'] model.default_cfg = default_cfgs['vit_base_patch16_384']
if pretrained: if pretrained:
load_pretrained( load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model return model
@ -320,8 +334,7 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['vit_base_patch32_384'] model.default_cfg = default_cfgs['vit_base_patch32_384']
if pretrained: if pretrained:
load_pretrained( load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model return model
@ -339,8 +352,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['vit_large_patch16_384'] model.default_cfg = default_cfgs['vit_large_patch16_384']
if pretrained: if pretrained:
load_pretrained( load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model return model
@ -351,8 +363,7 @@ def vit_large_patch32_384(pretrained=False, **kwargs):
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = default_cfgs['vit_large_patch32_384'] model.default_cfg = default_cfgs['vit_large_patch32_384']
if pretrained: if pretrained:
load_pretrained( load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model return model

@ -43,7 +43,7 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
if weight_decay and filter_bias_and_bn: if weight_decay and filter_bias_and_bn:
skip = {} skip = {}
if hasattr(model, 'no_weight_decay'): if hasattr(model, 'no_weight_decay'):
skip = model.no_weight_decay skip = model.no_weight_decay()
parameters = add_weight_decay(model, weight_decay, skip) parameters = add_weight_decay(model, weight_decay, skip)
weight_decay = 0. weight_decay = 0.
else: else:

Loading…
Cancel
Save