|
|
|
@ -373,7 +373,7 @@ class VisionTransformer(nn.Module):
|
|
|
|
|
def __init__(self, img_size=224, patch_size=None, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
|
|
|
|
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
|
|
|
|
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None,
|
|
|
|
|
act_layer=None):
|
|
|
|
|
act_layer=None, weight_init=''):
|
|
|
|
|
"""
|
|
|
|
|
Args:
|
|
|
|
|
img_size (int, tuple): input image size
|
|
|
|
@ -434,17 +434,13 @@ class VisionTransformer(nn.Module):
|
|
|
|
|
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
|
|
|
|
|
|
|
|
|
trunc_normal_(self.pos_embed, std=.02)
|
|
|
|
|
trunc_normal_(self.cls_token, std=.02)
|
|
|
|
|
self.apply(self._init_weights)
|
|
|
|
|
|
|
|
|
|
def _init_weights(self, m):
|
|
|
|
|
if isinstance(m, nn.Linear):
|
|
|
|
|
trunc_normal_(m.weight, std=.02)
|
|
|
|
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
|
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
|
elif isinstance(m, nn.LayerNorm):
|
|
|
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
|
nn.init.constant_(m.weight, 1.0)
|
|
|
|
|
if weight_init != 'jax': # leave as zeros to match JAX impl
|
|
|
|
|
trunc_normal_(self.cls_token, std=.02)
|
|
|
|
|
for n, m in self.named_modules():
|
|
|
|
|
if weight_init == 'jax':
|
|
|
|
|
_init_weights_jax(m, n)
|
|
|
|
|
else:
|
|
|
|
|
_init_weights_original(m, n)
|
|
|
|
|
|
|
|
|
|
@torch.jit.ignore
|
|
|
|
|
def no_weight_decay(self):
|
|
|
|
@ -479,6 +475,58 @@ class VisionTransformer(nn.Module):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _init_weights_original(m: nn.Module, n: str = ''):
|
|
|
|
|
if isinstance(m, nn.Linear):
|
|
|
|
|
trunc_normal_(m.weight, std=.02)
|
|
|
|
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
|
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
|
elif isinstance(m, nn.LayerNorm):
|
|
|
|
|
nn.init.zeros_(m.bias)
|
|
|
|
|
nn.init.ones_(m.weight)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _init_weights_jax(m: nn.Module, n: str):
|
|
|
|
|
""" Weight init scheme closer to the official JAX impl than my original init"""
|
|
|
|
|
|
|
|
|
|
def _fan_in(tensor):
|
|
|
|
|
dimensions = tensor.dim()
|
|
|
|
|
if dimensions < 2:
|
|
|
|
|
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
|
|
|
|
|
|
|
|
|
|
num_input_fmaps = tensor.size(1)
|
|
|
|
|
receptive_field_size = 1
|
|
|
|
|
if tensor.dim() > 2:
|
|
|
|
|
receptive_field_size = tensor[0][0].numel()
|
|
|
|
|
fan_in = num_input_fmaps * receptive_field_size
|
|
|
|
|
return fan_in
|
|
|
|
|
|
|
|
|
|
def _lecun_normal(w):
|
|
|
|
|
stddev = (1.0 / _fan_in(w)) ** 0.5 / .87962566103423978
|
|
|
|
|
trunc_normal_(w, 0, stddev)
|
|
|
|
|
|
|
|
|
|
if isinstance(m, nn.Linear):
|
|
|
|
|
if 'head' in n:
|
|
|
|
|
nn.init.zeros_(m.weight)
|
|
|
|
|
nn.init.zeros_(m.bias)
|
|
|
|
|
elif 'pre_logits' in n:
|
|
|
|
|
_lecun_normal(m.weight)
|
|
|
|
|
nn.init.zeros_(m.bias)
|
|
|
|
|
else:
|
|
|
|
|
nn.init.xavier_uniform_(m.weight)
|
|
|
|
|
if m.bias is not None:
|
|
|
|
|
if 'mlp' in n:
|
|
|
|
|
nn.init.normal_(m.bias, 0, 1e-6)
|
|
|
|
|
else:
|
|
|
|
|
nn.init.zeros_(m.bias)
|
|
|
|
|
elif isinstance(m, nn.Conv2d):
|
|
|
|
|
_lecun_normal(m.weight)
|
|
|
|
|
if m.bias is not None:
|
|
|
|
|
nn.init.zeros_(m.bias)
|
|
|
|
|
elif isinstance(m, nn.LayerNorm):
|
|
|
|
|
nn.init.constant_(m.bias, 0.)
|
|
|
|
|
nn.init.constant_(m.weight, 1.)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DistilledVisionTransformer(VisionTransformer):
|
|
|
|
|
""" Vision Transformer with distillation token.
|
|
|
|
|
|
|
|
|
@ -496,7 +544,7 @@ class DistilledVisionTransformer(VisionTransformer):
|
|
|
|
|
|
|
|
|
|
trunc_normal_(self.dist_token, std=.02)
|
|
|
|
|
trunc_normal_(self.pos_embed, std=.02)
|
|
|
|
|
self.head_dist.apply(self._init_weights)
|
|
|
|
|
self.head_dist.apply(_init_weights_original)
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
|
B = x.shape[0]
|
|
|
|
|