diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py
index f8d8d8c0..89fb859c 100644
--- a/timm/models/layers/__init__.py
+++ b/timm/models/layers/__init__.py
@@ -31,4 +31,4 @@ from .split_attn import SplitAttnConv2d
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
-from .weight_init import trunc_normal_
+from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_
diff --git a/timm/models/layers/weight_init.py b/timm/models/layers/weight_init.py
index d731029f..305a2fd0 100644
--- a/timm/models/layers/weight_init.py
+++ b/timm/models/layers/weight_init.py
@@ -2,6 +2,8 @@ import torch
 import math
 import warnings
 
+from torch.nn.init import _calculate_fan_in_and_fan_out
+
 
 def _no_grad_trunc_normal_(tensor, mean, std, a, b):
     # Cut & paste from PyTorch official master until it's in a few official releases - RW
@@ -58,3 +60,30 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
         >>> nn.init.trunc_normal_(w)
     """
     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
+    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+    if mode == 'fan_in':
+        denom = fan_in
+    elif mode == 'fan_out':
+        denom = fan_out
+    elif mode == 'fan_avg':
+        denom = (fan_in + fan_out) / 2
+
+    variance = scale / denom
+
+    if distribution == "truncated_normal":
+        # constant is stddev of standard normal truncated to (-2, 2)
+        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
+    elif distribution == "normal":
+        tensor.normal_(std=math.sqrt(variance))
+    elif distribution == "uniform":
+        bound = math.sqrt(3 * variance)
+        tensor.uniform_(-bound, bound)
+    else:
+        raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 42943fab..45c1eddb 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -28,7 +28,7 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
-from .layers import StdConv2dSame, StdConv2d, DropPath, to_2tuple, trunc_normal_
+from .layers import StdConv2dSame, StdConv2d, DropPath, to_2tuple, trunc_normal_, lecun_normal_
 from .resnet import resnet26d, resnet50d
 from .resnetv2 import ResNetV2, create_resnetv2_stem
 from .registry import register_model
@@ -373,7 +373,7 @@ class VisionTransformer(nn.Module):
     def __init__(self, img_size=224, patch_size=None, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                  num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None,
-                 act_layer=None, weight_init=''):
+                 act_layer=None, weight_init='new_nlhb'):
         """
         Args:
             img_size (int, tuple): input image size
@@ -433,14 +433,20 @@ class VisionTransformer(nn.Module):
         # Classifier head
         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
+        self._init_weights(weight_init)
+
+    def _init_weights(self, weight_init: str):
         trunc_normal_(self.pos_embed, std=.02)
-        if weight_init != 'jax':  # leave as zeros to match JAX impl
+        if weight_init.startswith('jax'):
+            init_fn = _init_weights_jax
+            # leave cls token as zeros to match jax impl
+        else:
             trunc_normal_(self.cls_token, std=.02)
+            init_fn = _init_weights_new if weight_init.startswith('new') else _init_weights_old
+        hb = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
+        init_fn = partial(init_fn, head_bias=hb)
         for n, m in self.named_modules():
-            if weight_init == 'jax':
-                _init_weights_jax(m, n)
-            else:
-                _init_weights_original(m, n)
+            init_fn(m, n)
 
     @torch.jit.ignore
     def no_weight_decay(self):
@@ -475,41 +481,42 @@ class VisionTransformer(nn.Module):
         return x
 
 
-def _init_weights_original(m: nn.Module, n: str = ''):
-    if isinstance(m, (nn.Conv2d, nn.Linear)):
+def _init_weights_old(m: nn.Module, n: str = '', head_bias: float = 0.):
+    if isinstance(m, nn.Linear):
         trunc_normal_(m.weight, std=.02)
-        if isinstance(m, nn.Linear) and m.bias is not None:
-            nn.init.constant_(m.bias, 0)
+        if m.bias is not None:
+            if 'head' in n:
+                nn.init.constant_(m.bias, head_bias)
+            else:
+                nn.init.zeros_(m.bias)
     elif isinstance(m, nn.LayerNorm):
         nn.init.zeros_(m.bias)
         nn.init.ones_(m.weight)
 
 
-def _init_weights_jax(m: nn.Module, n: str):
-    """ Weight init scheme closer to the official JAX impl than my original init"""
-
-    def _fan_in(tensor):
-        dimensions = tensor.dim()
-        if dimensions < 2:
-            raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
+def _init_weights_new(m: nn.Module, n: str = '', head_bias: float = 0.):
+    if isinstance(m, (nn.Conv2d, nn.Linear)):
+        #trunc_normal_(m.weight, std=.02)
+        lecun_normal_(m.weight)
+        if m.bias is not None:
+            if 'head' in n:
+                nn.init.constant_(m.bias, head_bias)
+            else:
+                nn.init.zeros_(m.bias)
+    elif isinstance(m, nn.LayerNorm):
+        nn.init.zeros_(m.bias)
+        nn.init.ones_(m.weight)
 
-        num_input_fmaps = tensor.size(1)
-        receptive_field_size = 1
-        if tensor.dim() > 2:
-            receptive_field_size = tensor[0][0].numel()
-        fan_in = num_input_fmaps * receptive_field_size
-        return fan_in
 
-    def _lecun_normal(w):
-        stddev = (1.0 / _fan_in(w)) ** 0.5 / .87962566103423978
-        trunc_normal_(w, 0, stddev)
+def _init_weights_jax(m: nn.Module, n: str, head_bias: float = 0.):
+    """ Attempt at weight init scheme closer to the official JAX impl than my original init"""
 
     if isinstance(m, nn.Linear):
         if 'head' in n:
             nn.init.zeros_(m.weight)
-            nn.init.zeros_(m.bias)
+            nn.init.constant_(m.bias, head_bias)
         elif 'pre_logits' in n:
-            _lecun_normal(m.weight)
+            lecun_normal_(m.weight)
             nn.init.zeros_(m.bias)
         else:
             nn.init.xavier_uniform_(m.weight)
@@ -519,7 +526,7 @@ def _init_weights_jax(m: nn.Module, n: str):
                 else:
                     nn.init.zeros_(m.bias)
     elif isinstance(m, nn.Conv2d):
-        _lecun_normal(m.weight)
+        lecun_normal_(m.weight)
         if m.bias is not None:
             nn.init.zeros_(m.bias)
     elif isinstance(m, nn.LayerNorm):
@@ -544,7 +551,7 @@ class DistilledVisionTransformer(VisionTransformer):
 
         trunc_normal_(self.dist_token, std=.02)
         trunc_normal_(self.pos_embed, std=.02)
-        self.head_dist.apply(_init_weights_original)
+        self.head_dist.apply(_init_weights_new)
 
     def forward_features(self, x):
         B = x.shape[0]
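Not part of the patch, but a quick sketch of what the new helpers do, assuming the `timm.models.layers` exports added above (`variance_scaling_`, `lecun_normal_`); the `nn.Linear` layer here is just a stand-in for illustration:

```python
import math
import torch.nn as nn

from timm.models.layers import lecun_normal_, variance_scaling_

fc = nn.Linear(768, 3072)  # weight shape (3072, 768) -> fan_in = 768

# lecun_normal_ is fan-in variance scaling drawn from a truncated normal;
# the 0.8796... constant in variance_scaling_ compensates for the truncation.
lecun_normal_(fc.weight)
print(fc.weight.std().item())  # roughly 0.04 for fan_in=768

# The same call spelled out through the general-purpose helper:
variance_scaling_(fc.weight, scale=1.0, mode='fan_in', distribution='truncated_normal')

# The VisionTransformer weight_init string picks a scheme ('jax*', 'new*', else the
# old init); when it contains 'nlhb', the classifier head bias starts at
# -log(num_classes) so the initial softmax output is roughly uniform.
print(-math.log(1000))  # head bias for a 1000-class model, about -6.91
```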