@@ -147,9 +147,10 @@ class PoolingVisionTransformer(nn.Module):
     A PyTorch implement of 'Rethinking Spatial Dimensions of Vision Transformers'
         - https://arxiv.org/abs/2103.16302
     """
-    def __init__(self, img_size, patch_size, stride, base_dims, depth, heads,
-                 mlp_ratio, num_classes=1000, in_chans=3, distilled=False, global_pool='token',
-                 attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
+    def __init__(
+            self, img_size, patch_size, stride, base_dims, depth, heads,
+            mlp_ratio, num_classes=1000, in_chans=3, global_pool='token',
+            distilled=False, attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
         super(PoolingVisionTransformer, self).__init__()
         assert global_pool in ('token',)
 
@@ -193,6 +194,7 @@ class PoolingVisionTransformer(nn.Module):
         self.head_dist = None
         if distilled:
             self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+        self.distilled_training = False  # must set this True to train w/ distillation token
 
         trunc_normal_(self.pos_embed, std=.02)
         trunc_normal_(self.cls_token, std=.02)
@@ -207,6 +209,10 @@ class PoolingVisionTransformer(nn.Module):
     def no_weight_decay(self):
         return {'pos_embed', 'cls_token'}
 
+    @torch.jit.ignore
+    def set_distilled_training(self, enable=True):
+        self.distilled_training = enable
+
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
         assert not enable, 'gradient checkpointing not supported'
@@ -231,16 +237,30 @@ class PoolingVisionTransformer(nn.Module):
         cls_tokens = self.norm(cls_tokens)
         return cls_tokens
 
-    def forward(self, x):
-        x = self.forward_features(x)
+    def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
         if self.head_dist is not None:
-            x, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1])  # x must be a tuple
-            if self.training and not torch.jit.is_scripting():
+            assert self.global_pool == 'token'
+            x, x_dist = x[:, 0], x[:, 1]
+            if not pre_logits:
+                x = self.head(x)
+                x_dist = self.head_dist(x_dist)
+            if self.distilled_training and self.training and not torch.jit.is_scripting():
+                # only return separate classification predictions when training in distilled mode
                 return x, x_dist
             else:
+                # during standard train / finetune, inference average the classifier predictions
                 return (x + x_dist) / 2
         else:
-            return self.head(x[:, 0])
+            if self.global_pool == 'token':
+                x = x[:, 0]
+            if not pre_logits:
+                x = self.head(x)
+            return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
 
 
 def checkpoint_filter_fn(state_dict, model):
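
Not part of the diff: a minimal usage sketch of how the new set_distilled_training toggle changes what forward_head returns, assuming the timm API shown above. The model name pit_ti_distilled_224 is an assumption (any distilled PiT config with a head_dist would do).

    import torch
    import timm  # assumes a timm build containing the PiT changes above

    # a distilled PiT variant so that head_dist is not None; model name assumed
    model = timm.create_model('pit_ti_distilled_224', pretrained=False)
    x = torch.randn(2, 3, 224, 224)

    # default: distilled_training is False, so even in train mode the two heads are averaged
    model.train()
    out = model(x)            # single tensor, shape [2, num_classes]

    # enable the new flag to get separate class / distillation logits while training
    model.set_distilled_training(True)
    out, out_dist = model(x)  # tuple of two [2, num_classes] tensors

    # in eval mode (or under torch.jit scripting) the averaged prediction is always returned
    model.eval()
    out = model(x)            # single tensor again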