Cleanup experimental vit weight init a bit

pull/450/head
Ross Wightman 4 years ago
parent f42f1df26c
commit cf5fec5047

@ -31,4 +31,4 @@ from .split_attn import SplitAttnConv2d
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .weight_init import trunc_normal_ from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_

@ -2,6 +2,8 @@ import torch
import math import math
import warnings import warnings
from torch.nn.init import _calculate_fan_in_and_fan_out
def _no_grad_trunc_normal_(tensor, mean, std, a, b): def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW # Cut & paste from PyTorch official master until it's in a few official releases - RW
@ -58,3 +60,30 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
>>> nn.init.trunc_normal_(w) >>> nn.init.trunc_normal_(w)
""" """
return _no_grad_trunc_normal_(tensor, mean, std, a, b) return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == 'fan_in':
denom = fan_in
elif mode == 'fan_out':
denom = fan_out
elif mode == 'fan_avg':
denom = (fan_in + fan_out) / 2
variance = scale / denom
if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
elif distribution == "normal":
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')

@ -28,7 +28,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained from .helpers import load_pretrained
from .layers import StdConv2dSame, StdConv2d, DropPath, to_2tuple, trunc_normal_ from .layers import StdConv2dSame, StdConv2d, DropPath, to_2tuple, trunc_normal_, lecun_normal_
from .resnet import resnet26d, resnet50d from .resnet import resnet26d, resnet50d
from .resnetv2 import ResNetV2, create_resnetv2_stem from .resnetv2 import ResNetV2, create_resnetv2_stem
from .registry import register_model from .registry import register_model
@ -373,7 +373,7 @@ class VisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=None, in_chans=3, num_classes=1000, embed_dim=768, depth=12, def __init__(self, img_size=224, patch_size=None, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None,
act_layer=None, weight_init=''): act_layer=None, weight_init='new_nlhb'):
""" """
Args: Args:
img_size (int, tuple): input image size img_size (int, tuple): input image size
@ -433,14 +433,20 @@ class VisionTransformer(nn.Module):
# Classifier head # Classifier head
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self._init_weights(weight_init)
def _init_weights(self, weight_init: str):
trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.pos_embed, std=.02)
if weight_init != 'jax': # leave as zeros to match JAX impl if weight_init.startswith('jax'):
init_fn = _init_weights_jax
# leave cls token as zeros to match jax impl
else:
trunc_normal_(self.cls_token, std=.02) trunc_normal_(self.cls_token, std=.02)
init_fn = _init_weights_new if weight_init.startswith('new') else _init_weights_old
hb = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
init_fn = partial(init_fn, head_bias=hb)
for n, m in self.named_modules(): for n, m in self.named_modules():
if weight_init == 'jax': init_fn(m, n)
_init_weights_jax(m, n)
else:
_init_weights_original(m, n)
@torch.jit.ignore @torch.jit.ignore
def no_weight_decay(self): def no_weight_decay(self):
@ -475,41 +481,42 @@ class VisionTransformer(nn.Module):
return x return x
def _init_weights_original(m: nn.Module, n: str = ''): def _init_weights_old(m: nn.Module, n: str = '', head_bias: float = 0.):
if isinstance(m, (nn.Conv2d, nn.Linear)): if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02) trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None: if m.bias is not None:
nn.init.constant_(m.bias, 0) if 'head' in n:
nn.init.constant_(m.bias, head_bias)
else:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm): elif isinstance(m, nn.LayerNorm):
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)
nn.init.ones_(m.weight) nn.init.ones_(m.weight)
def _init_weights_jax(m: nn.Module, n: str): def _init_weights_new(m: nn.Module, n: str = '', head_bias: float = 0.):
""" Weight init scheme closer to the official JAX impl than my original init""" if isinstance(m, (nn.Conv2d, nn.Linear)):
#trunc_normal_(m.weight, std=.02)
def _fan_in(tensor): lecun_normal_(m.weight)
dimensions = tensor.dim() if m.bias is not None:
if dimensions < 2: if 'head' in n:
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions") nn.init.constant_(m.bias, head_bias)
else:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.zeros_(m.bias)
nn.init.ones_(m.weight)
num_input_fmaps = tensor.size(1)
receptive_field_size = 1
if tensor.dim() > 2:
receptive_field_size = tensor[0][0].numel()
fan_in = num_input_fmaps * receptive_field_size
return fan_in
def _lecun_normal(w): def _init_weights_jax(m: nn.Module, n: str, head_bias: float = 0.):
stddev = (1.0 / _fan_in(w)) ** 0.5 / .87962566103423978 """ Attempt at weight init scheme closer to the official JAX impl than my original init"""
trunc_normal_(w, 0, stddev)
if isinstance(m, nn.Linear): if isinstance(m, nn.Linear):
if 'head' in n: if 'head' in n:
nn.init.zeros_(m.weight) nn.init.zeros_(m.weight)
nn.init.zeros_(m.bias) nn.init.constant_(m.bias, head_bias)
elif 'pre_logits' in n: elif 'pre_logits' in n:
_lecun_normal(m.weight) lecun_normal_(m.weight)
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)
else: else:
nn.init.xavier_uniform_(m.weight) nn.init.xavier_uniform_(m.weight)
@ -519,7 +526,7 @@ def _init_weights_jax(m: nn.Module, n: str):
else: else:
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d): elif isinstance(m, nn.Conv2d):
_lecun_normal(m.weight) lecun_normal_(m.weight)
if m.bias is not None: if m.bias is not None:
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm): elif isinstance(m, nn.LayerNorm):
@ -544,7 +551,7 @@ class DistilledVisionTransformer(VisionTransformer):
trunc_normal_(self.dist_token, std=.02) trunc_normal_(self.dist_token, std=.02)
trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.pos_embed, std=.02)
self.head_dist.apply(_init_weights_original) self.head_dist.apply(_init_weights_new)
def forward_features(self, x): def forward_features(self, x):
B = x.shape[0] B = x.shape[0]

Loading…
Cancel
Save