#
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#

from typing import Optional

import torch
import torch.nn as nn


# Reference: https://github.com/apple/ml-ane-transformers/blob/main/ane_transformers/reference/layer_norm.py
class LayerNormANE(nn.Module):
    """LayerNorm optimized for Apple Neural Engine (ANE) execution.

    The layer normalizes over the channel axis (dim 1) of a rank-4 ``BC1S``
    tensor. A rank-3 input whose last dimension equals ``num_channels`` is
    treated as ``BSC`` and converted to ``BC1S`` first, so for such inputs
    the normalization is over the final dim.

    Note: It expects `num_channels` as an argument and not `normalized_shape`
    which is used by `torch.nn.LayerNorm`. The affine step is applied as
    ``(x_hat + bias) * weight`` (bias is added *before* scaling), which also
    differs from `torch.nn.LayerNorm`.
    """

    def __init__(
        self,
        num_channels: int,
        clip_mag: Optional[float] = None,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
    ):
        """
        Args:
            num_channels: Number of channels (C) where the expected input data
                format is BC1S. S stands for sequence length.
            clip_mag: Optional float value to use for clamping the input range
                before layer norm is applied. If specified, helps reduce risk
                of overflow.
            eps: Small value to avoid dividing by zero.
            elementwise_affine: If true, adds learnable channel-wise shift
                (bias) and scale (weight) parameters.
        """
        super().__init__()
        # Principle 1: Picking the Right Data Format
        # (machinelearning.apple.com/research/apple-neural-engine)
        self.expected_rank = len("BC1S")

        self.num_channels = num_channels
        self.eps = eps
        self.clip_mag = clip_mag
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            # Allocated uninitialized; _reset_parameters() fills them in.
            self.weight = nn.Parameter(torch.empty(num_channels))
            self.bias = nn.Parameter(torch.empty(num_channels))

        self._reset_parameters()

    def _reset_parameters(self):
        """Set weight to ones and bias to zeros (no-op without affine params)."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Normalize `inputs` over the channel axis.

        Args:
            inputs: Rank-4 tensor whose dim 1 has size `num_channels` (BC1S),
                or a rank-3 tensor whose last dim has size `num_channels`
                (BSC), which is converted to BC1S. The tensor is never
                modified in place.

        Returns:
            The normalized rank-4 tensor with the same shape as the
            (possibly converted) input.

        Raises:
            AssertionError: If the input is not rank 4 after the optional
                conversion, or if dim 1 does not have `num_channels` entries.
        """
        input_rank = inputs.dim()

        # Principle 1: Picking the Right Data Format
        # (machinelearning.apple.com/research/apple-neural-engine)
        # Migrate the data format from BSC to BC1S (most conducive to ANE)
        if input_rank == 3 and inputs.size(2) == self.num_channels:
            inputs = inputs.transpose(1, 2).unsqueeze(2)
            input_rank = inputs.dim()

        assert input_rank == self.expected_rank, (
            f"expected a rank-{self.expected_rank} input, got rank {input_rank}"
        )
        assert inputs.size(1) == self.num_channels, (
            f"expected {self.num_channels} channels in dim 1, "
            f"got {inputs.size(1)}"
        )

        if self.clip_mag is not None:
            # Out-of-place on purpose: an in-place clamp_ would overwrite the
            # caller's tensor (the BSC -> BC1S conversion above only creates a
            # view) and fails on leaf tensors that require grad.
            inputs = inputs.clamp(-self.clip_mag, self.clip_mag)

        channels_mean = inputs.mean(dim=1, keepdim=True)
        zero_mean = inputs - channels_mean
        zero_mean_sq = zero_mean * zero_mean
        # Reciprocal of the standard deviation over channels (eps inside the
        # square root), so the division becomes a multiplication.
        denom = (zero_mean_sq.mean(dim=1, keepdim=True) + self.eps).rsqrt()
        out = zero_mean * denom

        if self.elementwise_affine:
            # Bias is added before scaling: (x_hat + bias) * weight.
            bias = self.bias.view(1, self.num_channels, 1, 1)
            weight = self.weight.view(1, self.num_channels, 1, 1)
            out = (out + bias) * weight

        return out