@@ -289,40 +289,19 @@ class VisionTransformer(nn.Module):
         assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
         head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
         trunc_normal_(self.pos_embed, std=.02)
+        if self.dist_token is not None:
+            trunc_normal_(self.dist_token, std=.02)
         if weight_init.startswith('jax'):
             # leave cls token as zeros to match jax impl
             for n, m in self.named_modules():
-                _init_weights_jax(m, n, head_bias=head_bias)
+                _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
         else:
             trunc_normal_(self.cls_token, std=.02)
-            if self.dist_token is not None:
-                trunc_normal_(self.dist_token, std=.02)
-            for n, m in self.named_modules():
-                self._init_weights(m, n, head_bias=head_bias)
-
-    def _init_weights(self, m, n: str = '', head_bias: float = 0., init_conv=False):
-        # This impl does not exactly match the official JAX version.
-        # When called w/o n, head_bias, init_conv args it will behave exactly the same
-        # as my original init for compatibility with downstream use cases (ie DeiT).
-        if isinstance(m, nn.Linear):
-            if n.startswith('head'):
-                nn.init.zeros_(m.weight)
-                nn.init.constant_(m.bias, head_bias)
-            elif n.startswith('pre_logits'):
-                lecun_normal_(m.weight)
-                nn.init.zeros_(m.bias)
-            else:
-                trunc_normal_(m.weight, std=.02)
-                if m.bias is not None:
-                    nn.init.zeros_(m.bias)
-        elif init_conv and isinstance(m, nn.Conv2d):
-            # NOTE conv was left to pytorch default init originally
-            lecun_normal_(m.weight)
-            if m.bias is not None:
-                nn.init.zeros_(m.bias)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.zeros_(m.bias)
-            nn.init.ones_(m.weight)
+            self.apply(_init_vit_weights)
+
+    def _init_weights(self, m):
+        # this fn left here for compat with downstream users
+        _init_vit_weights(m)
 
     @torch.jit.ignore
     def no_weight_decay(self):
@@ -369,9 +348,12 @@ class VisionTransformer(nn.Module):
         return x
 
 
-def _init_weights_jax(m: nn.Module, n: str, head_bias: float = 0.):
-    # A weight init scheme closer to the official JAX impl than my original init
-    # NOTE: requires module name so cannot be used via module.apply()
+def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
+    """ ViT weight initialization
+    * When called without n, head_bias, jax_impl args it will behave exactly the same
+      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
+    * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
+    """
     if isinstance(m, nn.Linear):
         if n.startswith('head'):
             nn.init.zeros_(m.weight)
@@ -380,13 +362,19 @@ def _init_weights_jax(m: nn.Module, n: str, head_bias: float = 0.):
             lecun_normal_(m.weight)
             nn.init.zeros_(m.bias)
         else:
-            nn.init.xavier_uniform_(m.weight)
-            if m.bias is not None:
-                if 'mlp' in n:
-                    nn.init.normal_(m.bias, 0, 1e-6)
-                else:
+            if jax_impl:
+                nn.init.xavier_uniform_(m.weight)
+                if m.bias is not None:
+                    if 'mlp' in n:
+                        nn.init.normal_(m.bias, std=1e-6)
+                    else:
+                        nn.init.zeros_(m.bias)
+            else:
+                trunc_normal_(m.weight, std=.02)
+                if m.bias is not None:
                     nn.init.zeros_(m.bias)
-    elif isinstance(m, nn.Conv2d):
+    elif jax_impl and isinstance(m, nn.Conv2d):
+        # NOTE conv was left to pytorch default in my original init
         lecun_normal_(m.weight)
         if m.bias is not None:
             nn.init.zeros_(m.bias)
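
Usage sketch (not part of the patch): after this change a single _init_vit_weights serves both code paths, and the diff itself shows the two ways it is driven: the default path keeps the original trunc-normal init and can go through module.apply(), while the jax path needs the module name and therefore iterates named_modules(). The helper name reinit_vit below is hypothetical, and the import path assumes a timm release from this era that still exposes _init_vit_weights; it is a minimal sketch, not the library's API.

    import math
    import torch.nn as nn
    # assumed import path for timm versions that ship the function added in this patch
    from timm.models.vision_transformer import VisionTransformer, _init_vit_weights

    def reinit_vit(model: nn.Module, weight_init: str = '', num_classes: int = 1000):
        # 'nlhb' mirrors the negative-log head-bias handling in the hunk above
        head_bias = -math.log(num_classes) if 'nlhb' in weight_init else 0.
        if weight_init.startswith('jax'):
            # jax-style init needs the module name, so module.apply() cannot be used
            for n, m in model.named_modules():
                _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
        else:
            # original timm init; same effect as self.apply(_init_vit_weights) in the patch
            model.apply(_init_vit_weights)

    model = VisionTransformer()  # default ViT-Base-style config
    reinit_vit(model, weight_init='jax_nlhb')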