You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

81 lines
2.9 KiB

#
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#
import torch
import torch.nn as nn
# Reference: https://github.com/apple/ml-ane-transformers/blob/main/ane_transformers/reference/layer_norm.py
class LayerNormANE(nn.Module):
    """LayerNorm optimized for Apple Neural Engine (ANE) execution.

    The layer works on tensors in BC1S layout (batch, channels, 1, sequence)
    and normalizes over the channel axis (dim 1). A rank-3 input whose last
    dimension equals `num_channels` is treated as BSC and converted to BC1S
    first, so the output is always rank-4 BC1S.

    Note: This layer only supports normalization over the channel dim (the
    final dim of a BSC input). It expects `num_channels` as an argument and
    not `normalized_shape` which is used by `torch.nn.LayerNorm`.

    Note: When `elementwise_affine` is true the shift is applied *before* the
    scale, i.e. `(x_norm + bias) * weight`. This differs from
    `torch.nn.LayerNorm`, which computes `x_norm * weight + bias`.
    """

    def __init__(self,
                 num_channels,
                 clip_mag=None,
                 eps=1e-5,
                 elementwise_affine=True):
        """
        Args:
            num_channels: Number of channels (C) where the expected input data
                format is BC1S. S stands for sequence length.
            clip_mag: Optional float value to use for clamping the input range
                before layer norm is applied. If specified, helps reduce risk
                of overflow.
            eps: Small value to avoid dividing by zero.
            elementwise_affine: If true, adds learnable channel-wise shift
                (bias) and scale (weight) parameters.
        """
        super().__init__()
        # Principle 1: Picking the Right Data Format
        # (machinelearning.apple.com/research/apple-neural-engine)
        self.expected_rank = len("BC1S")

        self.num_channels = num_channels
        self.eps = eps
        self.clip_mag = clip_mag
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            # Allocated uninitialised; real values are set right below in
            # `_reset_parameters`.
            self.weight = nn.Parameter(torch.empty(num_channels))
            self.bias = nn.Parameter(torch.empty(num_channels))

        self._reset_parameters()

    def _reset_parameters(self):
        """Set the affine parameters to identity (weight=1, bias=0), if any."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, inputs):
        """Normalize `inputs` over the channel axis.

        Args:
            inputs: Tensor in BC1S layout (rank 4, `size(1) == num_channels`),
                or rank-3 BSC with `size(2) == num_channels`, which is
                converted to BC1S. The tensor is never modified in place.

        Returns:
            The normalized tensor in BC1S layout.

        Raises:
            AssertionError: If the (possibly converted) input is not rank 4 or
                its channel dimension does not equal `num_channels`.
        """
        # Principle 1: Picking the Right Data Format
        # (machinelearning.apple.com/research/apple-neural-engine)
        # Migrate the data format from BSC to BC1S (most conducive to ANE)
        if inputs.dim() == 3 and inputs.size(2) == self.num_channels:
            inputs = inputs.transpose(1, 2).unsqueeze(2)

        assert inputs.dim() == self.expected_rank, (
            f"expected a rank-{self.expected_rank} BC1S tensor, "
            f"got shape {tuple(inputs.shape)}")
        assert inputs.size(1) == self.num_channels, (
            f"expected {self.num_channels} channels at dim 1, "
            f"got shape {tuple(inputs.shape)}")

        if self.clip_mag is not None:
            # Out-of-place on purpose: an in-place clamp_ would overwrite the
            # caller's tensor (also through the transpose/unsqueeze views
            # above) and fails on leaf tensors that require grad.
            inputs = inputs.clamp(-self.clip_mag, self.clip_mag)

        channels_mean = inputs.mean(dim=1, keepdim=True)
        zero_mean = inputs - channels_mean
        zero_mean_sq = zero_mean * zero_mean
        # Reciprocal square root of the (biased) channel variance plus eps.
        denom = (zero_mean_sq.mean(dim=1, keepdim=True) + self.eps).rsqrt()
        out = zero_mean * denom

        if self.elementwise_affine:
            # Shift first, then scale (see the class docstring).
            out = (out + self.bias.view(1, self.num_channels, 1, 1)
                   ) * self.weight.view(1, self.num_channels, 1, 1)

        return out