diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py
index b1a64db3..b1f452ff 100644
--- a/timm/models/layers/__init__.py
+++ b/timm/models/layers/__init__.py
@@ -39,4 +39,4 @@ from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
 from .trace_utils import _assert, _float_to_int
-from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_
+from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py
index 85297420..345f67bc 100644
--- a/timm/models/layers/norm.py
+++ b/timm/models/layers/norm.py
@@ -16,8 +16,8 @@ class GroupNorm(nn.GroupNorm):
 
 class LayerNorm2d(nn.LayerNorm):
     """ LayerNorm for channels of '2D' spatial BCHW tensors """
-    def __init__(self, num_channels):
-        super().__init__(num_channels)
+    def __init__(self, num_channels, eps=1e-6):
+        super().__init__(num_channels, eps=eps)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return F.layer_norm(
diff --git a/timm/models/layers/weight_init.py b/timm/models/layers/weight_init.py
index 305a2fd0..4a160931 100644
--- a/timm/models/layers/weight_init.py
+++ b/timm/models/layers/weight_init.py
@@ -49,6 +49,11 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
     with values outside :math:`[a, b]` redrawn until they are within
     the bounds. The method used for generating the random values works
     best when :math:`a \leq \text{mean} \leq b`.
+
+    NOTE: this impl is similar to the PyTorch trunc_normal_; the bounds [a, b] are
+    applied while sampling the normal with mean/std applied, so the a, b args
+    should be adjusted to match the range of the mean, std args.
+
     Args:
         tensor: an n-dimensional `torch.Tensor`
         mean: the mean of the normal distribution
@@ -62,6 +67,35 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
 
 
+def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
+    # type: (Tensor, float, float, float, float) -> Tensor
+    r"""Fills the input Tensor with values drawn from a truncated
+    normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+
+    NOTE: this 'tf' variant behaves closer to the TensorFlow / JAX impl, where the
+    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0,
+    and the result is subsequently scaled and shifted by the mean and std args.
+
+    Args:
+        tensor: an n-dimensional `torch.Tensor`
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the minimum cutoff value
+        b: the maximum cutoff value
+    Examples:
+        >>> w = torch.empty(3, 5)
+        >>> trunc_normal_tf_(w)
+    """
+    _no_grad_trunc_normal_(tensor, 0, 1.0, a, b)
+    with torch.no_grad():
+        tensor.mul_(std).add_(mean)
+    return tensor
+
+
 def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
     fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
     if mode == 'fan_in':
@@ -75,7 +109,7 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
 
     if distribution == "truncated_normal":
         # constant is stddev of standard normal truncated to (-2, 2)
-        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
+        trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978)
     elif distribution == "normal":
         tensor.normal_(std=math.sqrt(variance))
     elif distribution == "uniform":
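
To make the two NOTEs above concrete: trunc_normal_ applies the [a, b] bounds to the already scaled/shifted distribution, while trunc_normal_tf_ truncates the unit normal first and scales/shifts afterwards. A minimal sketch of the difference, assuming the patched timm package is importable (the shapes and std below are illustrative, not part of the patch):

    import torch
    from timm.models.layers import trunc_normal_, trunc_normal_tf_

    # PyTorch-style: a=-2, b=2 bound the *final* values; with std=0.02
    # the bounds sit ~100 sigma out, so effectively nothing is truncated.
    w1 = torch.empty(1000, 1000)
    trunc_normal_(w1, mean=0., std=0.02, a=-2., b=2.)

    # TF/JAX-style: a=-2, b=2 bound the *unit* normal before scaling,
    # so the final values are truncated at +/- 2 * std = +/- 0.04.
    w2 = torch.empty(1000, 1000)
    trunc_normal_tf_(w2, mean=0., std=0.02, a=-2., b=2.)

    print(w1.abs().max())  # roughly 0.1: the extremes of an untruncated N(0, 0.02)
    print(w2.abs().max())  # <= 0.04 by construction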
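
Likewise, the constant .87962566103423978 in variance_scaling_ is the standard deviation of a standard normal truncated to (-2, 2); dividing the target std by it compensates for the variance removed by truncation, so the filled tensor ends up with the requested variance. A quick sanity check (scipy is an assumption here, not a dependency of this patch):

    from scipy.stats import truncnorm

    # std of N(0, 1) restricted to (-2, 2); matches the constant above
    print(truncnorm.std(-2, 2))  # ~0.8796256610342398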