From 4de57ccf0123650bf759960d9ac64dca6263da7c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 18 Mar 2021 15:35:22 -0700 Subject: [PATCH] Add weight init scheme that's closer to JAX impl --- timm/models/vision_transformer.py | 74 +++++++++++++++++++++++++------ 1 file changed, 61 insertions(+), 13 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index aed295ec..5fb5c7c7 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -373,7 +373,7 @@ class VisionTransformer(nn.Module): def __init__(self, img_size=224, patch_size=None, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, - act_layer=None): + act_layer=None, weight_init=''): """ Args: img_size (int, tuple): input image size @@ -434,17 +434,13 @@ class VisionTransformer(nn.Module): self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() trunc_normal_(self.pos_embed, std=.02) - trunc_normal_(self.cls_token, std=.02) - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) + if weight_init != 'jax': # leave as zeros to match JAX impl + trunc_normal_(self.cls_token, std=.02) + for n, m in self.named_modules(): + if weight_init == 'jax': + _init_weights_jax(m, n) + else: + _init_weights_original(m, n) @torch.jit.ignore def no_weight_decay(self): @@ -479,6 +475,58 @@ class VisionTransformer(nn.Module): return x +def _init_weights_original(m: nn.Module, n: str = ''): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.zeros_(m.bias) + nn.init.ones_(m.weight) + + +def _init_weights_jax(m: nn.Module, n: str): + """ Weight init scheme closer to the official JAX impl than my original init""" + + def _fan_in(tensor): + dimensions = tensor.dim() + if dimensions < 2: + raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions") + + num_input_fmaps = tensor.size(1) + receptive_field_size = 1 + if tensor.dim() > 2: + receptive_field_size = tensor[0][0].numel() + fan_in = num_input_fmaps * receptive_field_size + return fan_in + + def _lecun_normal(w): + stddev = (1.0 / _fan_in(w)) ** 0.5 / .87962566103423978 + trunc_normal_(w, 0, stddev) + + if isinstance(m, nn.Linear): + if 'head' in n: + nn.init.zeros_(m.weight) + nn.init.zeros_(m.bias) + elif 'pre_logits' in n: + _lecun_normal(m.weight) + nn.init.zeros_(m.bias) + else: + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + if 'mlp' in n: + nn.init.normal_(m.bias, 0, 1e-6) + else: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Conv2d): + _lecun_normal(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0.) + nn.init.constant_(m.weight, 1.) + + class DistilledVisionTransformer(VisionTransformer): """ Vision Transformer with distillation token. @@ -496,7 +544,7 @@ class DistilledVisionTransformer(VisionTransformer): trunc_normal_(self.dist_token, std=.02) trunc_normal_(self.pos_embed, std=.02) - self.head_dist.apply(self._init_weights) + self.head_dist.apply(_init_weights_original) def forward_features(self, x): B = x.shape[0]