Merge jax and original weight init

pull/533/head
Ross Wightman 3 years ago
parent acbd698c83
commit bf2ca6bdf4

@@ -289,40 +289,19 @@ class VisionTransformer(nn.Module):
         assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
         head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
         trunc_normal_(self.pos_embed, std=.02)
+        if self.dist_token is not None:
+            trunc_normal_(self.dist_token, std=.02)
         if weight_init.startswith('jax'):
             # leave cls token as zeros to match jax impl
             for n, m in self.named_modules():
-                _init_weights_jax(m, n, head_bias=head_bias)
+                _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
         else:
             trunc_normal_(self.cls_token, std=.02)
-            if self.dist_token is not None:
-                trunc_normal_(self.dist_token, std=.02)
-            for n, m in self.named_modules():
-                self._init_weights(m, n, head_bias=head_bias)
+            self.apply(_init_vit_weights)

-    def _init_weights(self, m, n: str = '', head_bias: float = 0., init_conv=False):
-        # This impl does not exactly match the official JAX version.
-        # When called w/o n, head_bias, init_conv args it will behave exactly the same
-        # as my original init for compatibility with downstream use cases (ie DeiT).
-        if isinstance(m, nn.Linear):
-            if n.startswith('head'):
-                nn.init.zeros_(m.weight)
-                nn.init.constant_(m.bias, head_bias)
-            elif n.startswith('pre_logits'):
-                lecun_normal_(m.weight)
-                nn.init.zeros_(m.bias)
-            else:
-                trunc_normal_(m.weight, std=.02)
-                if m.bias is not None:
-                    nn.init.zeros_(m.bias)
-        elif init_conv and isinstance(m, nn.Conv2d):
-            # NOTE conv was left to pytorch default init originally
-            lecun_normal_(m.weight)
-            if m.bias is not None:
-                nn.init.zeros_(m.bias)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.zeros_(m.bias)
-            nn.init.ones_(m.weight)
+    def _init_weights(self, m):
+        # this fn left here for compat with downstream users
+        _init_vit_weights(m)

     @torch.jit.ignore
     def no_weight_decay(self):
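Note on the two dispatch paths above: nn.Module.apply() hands the callback only the module, never its qualified name, and the name is what selects the head / pre_logits / mlp branches (the removed _init_weights_jax comment noted it "cannot be used via module.apply()"). A minimal sketch of both paths, assuming _init_vit_weights (defined in the hunks below) is importable from timm.models.vision_transformer at this revision; the two helper function names are illustrative only, not part of the commit:

import math
import torch.nn as nn
from timm.models.vision_transformer import _init_vit_weights  # assumed import path at this revision

def init_jax_style(model: nn.Module, num_classes: int) -> None:
    # jax path: walk named_modules() so each module's qualified name reaches the init fn
    head_bias = -math.log(num_classes)  # 'nlhb' head bias, as computed in the diff above
    for n, m in model.named_modules():
        _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)

def init_original_style(model: nn.Module) -> None:
    # compat path: apply() passes only the module, so the default trunc_normal_ branch runs
    model.apply(_init_vit_weights)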
@@ -369,9 +348,12 @@ class VisionTransformer(nn.Module):
         return x


-def _init_weights_jax(m: nn.Module, n: str, head_bias: float = 0.):
-    # A weight init scheme closer to the official JAX impl than my original init
-    # NOTE: requires module name so cannot be used via module.apply()
+def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
+    """ ViT weight initialization
+    * When called without n, head_bias, jax_impl args it will behave exactly the same
+      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
+    * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
+    """
     if isinstance(m, nn.Linear):
         if n.startswith('head'):
             nn.init.zeros_(m.weight)
@@ -380,13 +362,19 @@ def _init_weights_jax(m: nn.Module, n: str, head_bias: float = 0.):
             lecun_normal_(m.weight)
             nn.init.zeros_(m.bias)
         else:
-            nn.init.xavier_uniform_(m.weight)
-            if m.bias is not None:
-                if 'mlp' in n:
-                    nn.init.normal_(m.bias, 0, 1e-6)
-                else:
-                    nn.init.zeros_(m.bias)
-    elif isinstance(m, nn.Conv2d):
+            if jax_impl:
+                nn.init.xavier_uniform_(m.weight)
+                if m.bias is not None:
+                    if 'mlp' in n:
+                        nn.init.normal_(m.bias, std=1e-6)
+                    else:
+                        nn.init.zeros_(m.bias)
+            else:
+                trunc_normal_(m.weight, std=.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+    elif jax_impl and isinstance(m, nn.Conv2d):
+        # NOTE conv was left to pytorch default in my original init
         lecun_normal_(m.weight)
         if m.bias is not None:
             nn.init.zeros_(m.bias)
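A hedged usage sketch of the merged initializer's two calling conventions from the docstring, assuming _init_vit_weights is importable from timm.models.vision_transformer at this revision; the qualified module name string is made up for illustration:

import torch.nn as nn
from timm.models.vision_transformer import _init_vit_weights  # assumed import path at this revision

fc = nn.Linear(16, 16)

# Default call (compat / DeiT style): trunc_normal_(std=.02) weights, zero bias.
_init_vit_weights(fc)

# JAX-style call with a qualified name containing 'mlp': xavier_uniform_ weights,
# bias drawn from Normal(std=1e-6), matching the jax_impl branch above.
_init_vit_weights(fc, n='blocks.0.mlp.fc1', head_bias=0., jax_impl=True)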
