|
|
@ -49,6 +49,11 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
|
|
|
with values outside :math:`[a, b]` redrawn until they are within
|
|
|
|
with values outside :math:`[a, b]` redrawn until they are within
|
|
|
|
the bounds. The method used for generating the random values works
|
|
|
|
the bounds. The method used for generating the random values works
|
|
|
|
best when :math:`a \leq \text{mean} \leq b`.
|
|
|
|
best when :math:`a \leq \text{mean} \leq b`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
|
|
|
|
|
|
|
|
applied while sampling the normal with mean/std applied, therefore a, b args
|
|
|
|
|
|
|
|
should be adjusted to match the range of mean, std args.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
tensor: an n-dimensional `torch.Tensor`
|
|
|
|
tensor: an n-dimensional `torch.Tensor`
|
|
|
|
mean: the mean of the normal distribution
|
|
|
|
mean: the mean of the normal distribution
|
|
|
@ -62,6 +67,35 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
|
|
|
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
|
|
|
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
|
|
|
|
|
|
|
|
# type: (Tensor, float, float, float, float) -> Tensor
|
|
|
|
|
|
|
|
r"""Fills the input Tensor with values drawn from a truncated
|
|
|
|
|
|
|
|
normal distribution. The values are effectively drawn from the
|
|
|
|
|
|
|
|
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
|
|
|
|
|
|
|
with values outside :math:`[a, b]` redrawn until they are within
|
|
|
|
|
|
|
|
the bounds. The method used for generating the random values works
|
|
|
|
|
|
|
|
best when :math:`a \leq \text{mean} \leq b`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
|
|
|
|
|
|
|
|
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
|
|
|
|
|
|
|
|
and the result is subsquently scaled and shifted by the mean and std args.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
tensor: an n-dimensional `torch.Tensor`
|
|
|
|
|
|
|
|
mean: the mean of the normal distribution
|
|
|
|
|
|
|
|
std: the standard deviation of the normal distribution
|
|
|
|
|
|
|
|
a: the minimum cutoff value
|
|
|
|
|
|
|
|
b: the maximum cutoff value
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
|
|
|
>>> w = torch.empty(3, 5)
|
|
|
|
|
|
|
|
>>> nn.init.trunc_normal_(w)
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
_no_grad_trunc_normal_(tensor, 0, 1.0, a, b)
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
|
|
|
tensor.mul_(std).add_(mean)
|
|
|
|
|
|
|
|
return tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
|
|
|
|
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
|
|
|
|
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
|
|
|
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
|
|
|
if mode == 'fan_in':
|
|
|
|
if mode == 'fan_in':
|
|
|
@ -75,7 +109,7 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
|
|
|
|
|
|
|
|
|
|
|
|
if distribution == "truncated_normal":
|
|
|
|
if distribution == "truncated_normal":
|
|
|
|
# constant is stddev of standard normal truncated to (-2, 2)
|
|
|
|
# constant is stddev of standard normal truncated to (-2, 2)
|
|
|
|
trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
|
|
|
|
trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978)
|
|
|
|
elif distribution == "normal":
|
|
|
|
elif distribution == "normal":
|
|
|
|
tensor.normal_(std=math.sqrt(variance))
|
|
|
|
tensor.normal_(std=math.sqrt(variance))
|
|
|
|
elif distribution == "uniform":
|
|
|
|
elif distribution == "uniform":
|
|
|
|