From bf2ca6bdf474ab0f27b4fa7cdca9348e60058f20 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 1 Apr 2021 18:11:51 -0700
Subject: [PATCH] Merge jax and original weight init

---
 timm/models/vision_transformer.py | 64 +++++++++++++------------------
 1 file changed, 26 insertions(+), 38 deletions(-)

diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index cd73cc11..81f8ae9f 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -289,40 +289,19 @@ class VisionTransformer(nn.Module):
         assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
         head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
         trunc_normal_(self.pos_embed, std=.02)
+        if self.dist_token is not None:
+            trunc_normal_(self.dist_token, std=.02)
         if weight_init.startswith('jax'):
             # leave cls token as zeros to match jax impl
             for n, m in self.named_modules():
-                _init_weights_jax(m, n, head_bias=head_bias)
+                _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
         else:
             trunc_normal_(self.cls_token, std=.02)
-            if self.dist_token is not None:
-                trunc_normal_(self.dist_token, std=.02)
-            for n, m in self.named_modules():
-                self._init_weights(m, n, head_bias=head_bias)
-
-    def _init_weights(self, m, n: str = '', head_bias: float = 0., init_conv=False):
-        # This impl does not exactly match the official JAX version.
-        # When called w/o n, head_bias, init_conv args it will behave exactly the same
-        # as my original init for compatibility with downstream use cases (ie DeiT).
-        if isinstance(m, nn.Linear):
-            if n.startswith('head'):
-                nn.init.zeros_(m.weight)
-                nn.init.constant_(m.bias, head_bias)
-            elif n.startswith('pre_logits'):
-                lecun_normal_(m.weight)
-                nn.init.zeros_(m.bias)
-            else:
-                trunc_normal_(m.weight, std=.02)
-                if m.bias is not None:
-                    nn.init.zeros_(m.bias)
-        elif init_conv and isinstance(m, nn.Conv2d):
-            # NOTE conv was left to pytorch default init originally
-            lecun_normal_(m.weight)
-            if m.bias is not None:
-                nn.init.zeros_(m.bias)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.zeros_(m.bias)
-            nn.init.ones_(m.weight)
+            self.apply(_init_vit_weights)
+
+    def _init_weights(self, m):
+        # this fn left here for compat with downstream users
+        _init_vit_weights(m)
 
     @torch.jit.ignore
     def no_weight_decay(self):
@@ -369,9 +348,12 @@ class VisionTransformer(nn.Module):
         return x
 
 
-def _init_weights_jax(m: nn.Module, n: str, head_bias: float = 0.):
-    # A weight init scheme closer to the official JAX impl than my original init
-    # NOTE: requires module name so cannot be used via module.apply()
+def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
+    """ ViT weight initialization
+    * When called without n, head_bias, jax_impl args it will behave exactly the same
+      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
+    * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
+    """
     if isinstance(m, nn.Linear):
         if n.startswith('head'):
             nn.init.zeros_(m.weight)
@@ -380,13 +362,19 @@ def _init_weights_jax(m: nn.Module, n: str, head_bias: float = 0.):
             lecun_normal_(m.weight)
             nn.init.zeros_(m.bias)
         else:
-            nn.init.xavier_uniform_(m.weight)
-            if m.bias is not None:
-                if 'mlp' in n:
-                    nn.init.normal_(m.bias, 0, 1e-6)
-                else:
+            if jax_impl:
+                nn.init.xavier_uniform_(m.weight)
+                if m.bias is not None:
+                    if 'mlp' in n:
+                        nn.init.normal_(m.bias, std=1e-6)
+                    else:
+                        nn.init.zeros_(m.bias)
+            else:
+                trunc_normal_(m.weight, std=.02)
+                if m.bias is not None:
                     nn.init.zeros_(m.bias)
-    elif isinstance(m, nn.Conv2d):
+    elif jax_impl and isinstance(m, nn.Conv2d):
+        # NOTE conv was left to pytorch default in my original init
        lecun_normal_(m.weight)
         if m.bias is not None:
             nn.init.zeros_(m.bias)
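Not part of the patch itself: a minimal sketch of how the merged _init_vit_weights can be driven from the two call paths shown above, assuming a timm checkout that contains this commit (the function lives at module level in timm/models/vision_transformer.py at this point). The default path passes no module name, so nn.Module.apply() suffices; the jax_impl path branches on the name ('head', 'pre_logits', 'mlp'), so it must be fed from named_modules(). The toy model below is a hypothetical stand-in, not the real VisionTransformer.

import math
import torch.nn as nn
from timm.models.vision_transformer import _init_vit_weights

# Hypothetical stand-in for a ViT; only the attribute names matter for the init logic.
model = nn.Module()
model.head = nn.Linear(768, 1000)
model.blocks = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768))

# Original / DeiT-compatible path: n defaults to '', so every Linear gets
# trunc_normal_ init and apply() is enough.
model.apply(_init_vit_weights)

# JAX-style path: needs module names, so it is driven from named_modules().
# head_bias mirrors the 'nlhb' option for a 1000-class head.
head_bias = -math.log(1000)
for n, m in model.named_modules():
    _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)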