Fix pit implementation to be clsoer to deit/levit re distillation head handling

pull/1014/head
Ross Wightman 2 years ago
parent 0862e6ebae
commit 5f47518f27

@ -72,7 +72,7 @@ class VisionTransformerDistilled(VisionTransformer):
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim))
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
self.distilled_training = False
self.distilled_training = False # must set this True to train w/ distillation token
self.init_weights(weight_init)

@ -539,7 +539,7 @@ class LevitDistilled(Levit):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.head_dist = NormLinear(self.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
self.distilled_training = False
self.distilled_training = False # must set this True to train w/ distillation token
@torch.jit.ignore
def get_classifier(self):

@ -147,9 +147,10 @@ class PoolingVisionTransformer(nn.Module):
A PyTorch implement of 'Rethinking Spatial Dimensions of Vision Transformers'
- https://arxiv.org/abs/2103.16302
"""
def __init__(self, img_size, patch_size, stride, base_dims, depth, heads,
mlp_ratio, num_classes=1000, in_chans=3, distilled=False, global_pool='token',
attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
def __init__(
self, img_size, patch_size, stride, base_dims, depth, heads,
mlp_ratio, num_classes=1000, in_chans=3, global_pool='token',
distilled=False, attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
super(PoolingVisionTransformer, self).__init__()
assert global_pool in ('token',)
@ -193,6 +194,7 @@ class PoolingVisionTransformer(nn.Module):
self.head_dist = None
if distilled:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
self.distilled_training = False # must set this True to train w/ distillation token
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
@ -207,6 +209,10 @@ class PoolingVisionTransformer(nn.Module):
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
@torch.jit.ignore
def set_distilled_training(self, enable=True):
self.distilled_training = enable
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
@ -231,16 +237,30 @@ class PoolingVisionTransformer(nn.Module):
cls_tokens = self.norm(cls_tokens)
return cls_tokens
def forward(self, x):
x = self.forward_features(x)
def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
if self.head_dist is not None:
x, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1]) # x must be a tuple
if self.training and not torch.jit.is_scripting():
assert self.global_pool == 'token'
x, x_dist = x[:, 0], x[:, 1]
if not pre_logits:
x = self.head(x)
x_dist = self.head_dist(x_dist)
if self.distilled_training and self.training and not torch.jit.is_scripting():
# only return separate classification predictions when training in distilled mode
return x, x_dist
else:
# during standard train / finetune, inference average the classifier predictions
return (x + x_dist) / 2
else:
return self.head(x[:, 0])
if self.global_pool == 'token':
x = x[:, 0]
if not pre_logits:
x = self.head(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def checkpoint_filter_fn(state_dict, model):

Loading…
Cancel
Save