Merge jax and original weight init

pull/533/head
Ross Wightman 3 years ago
parent acbd698c83
commit bf2ca6bdf4

@ -289,40 +289,19 @@ class VisionTransformer(nn.Module):
assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
trunc_normal_(self.pos_embed, std=.02)
if self.dist_token is not None:
trunc_normal_(self.dist_token, std=.02)
if weight_init.startswith('jax'):
# leave cls token as zeros to match jax impl
for n, m in self.named_modules():
_init_weights_jax(m, n, head_bias=head_bias)
_init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
else:
trunc_normal_(self.cls_token, std=.02)
if self.dist_token is not None:
trunc_normal_(self.dist_token, std=.02)
for n, m in self.named_modules():
self._init_weights(m, n, head_bias=head_bias)
def _init_weights(self, m, n: str = '', head_bias: float = 0., init_conv=False):
# This impl does not exactly match the official JAX version.
# When called w/o n, head_bias, init_conv args it will behave exactly the same
# as my original init for compatibility with downstream use cases (ie DeiT).
if isinstance(m, nn.Linear):
if n.startswith('head'):
nn.init.zeros_(m.weight)
nn.init.constant_(m.bias, head_bias)
elif n.startswith('pre_logits'):
lecun_normal_(m.weight)
nn.init.zeros_(m.bias)
else:
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif init_conv and isinstance(m, nn.Conv2d):
# NOTE conv was left to pytorch default init originally
lecun_normal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.zeros_(m.bias)
nn.init.ones_(m.weight)
self.apply(_init_vit_weights)
def _init_weights(self, m):
# this fn left here for compat with downstream users
_init_vit_weights(m)
@torch.jit.ignore
def no_weight_decay(self):
@ -369,9 +348,12 @@ class VisionTransformer(nn.Module):
return x
def _init_weights_jax(m: nn.Module, n: str, head_bias: float = 0.):
# A weight init scheme closer to the official JAX impl than my original init
# NOTE: requires module name so cannot be used via module.apply()
def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
""" ViT weight initialization
* When called without n, head_bias, jax_impl args it will behave exactly the same
as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
* When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
"""
if isinstance(m, nn.Linear):
if n.startswith('head'):
nn.init.zeros_(m.weight)
@ -380,13 +362,19 @@ def _init_weights_jax(m: nn.Module, n: str, head_bias: float = 0.):
lecun_normal_(m.weight)
nn.init.zeros_(m.bias)
else:
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
if 'mlp' in n:
nn.init.normal_(m.bias, 0, 1e-6)
else:
if jax_impl:
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
if 'mlp' in n:
nn.init.normal_(m.bias, std=1e-6)
else:
nn.init.zeros_(m.bias)
else:
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
elif jax_impl and isinstance(m, nn.Conv2d):
# NOTE conv was left to pytorch default in my original init
lecun_normal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)

Loading…
Cancel
Save