""" 'Fast' normalization functions.

Layer/group norm helpers that skip the float32 upcast native AMP would perform,
running the normalization in the lower-precision autocast dtype instead, and
using the APEX fused layer norm when it is available.
"""
from typing import List, Optional

import torch
from torch.nn import functional as F

try:
    from apex.normalization.fused_layer_norm import fused_layer_norm_affine
    has_apex = True
except ImportError:
    has_apex = False


# fast (ie lower precision LN) can be disabled with this flag if issues crop up
_USE_FAST_NORM = False  # defaulting to False for now


def is_fast_norm():
    return _USE_FAST_NORM


def set_fast_norm(enable=True):
    global _USE_FAST_NORM
    _USE_FAST_NORM = enable


def fast_group_norm(
        x: torch.Tensor,
        num_groups: int,
        weight: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        eps: float = 1e-5
) -> torch.Tensor:
    if torch.jit.is_scripting():
        # currently cannot use is_autocast_enabled within torchscript
        return F.group_norm(x, num_groups, weight, bias, eps)

    if torch.is_autocast_enabled():
        # normally native AMP casts GN inputs to float32
        # here we use the low precision autocast dtype
        dt = torch.get_autocast_gpu_dtype()
        x = x.to(dt)
        weight = weight.to(dt) if weight is not None else None
        bias = bias.to(dt) if bias is not None else None

    with torch.cuda.amp.autocast(enabled=False):
        return F.group_norm(x, num_groups, weight, bias, eps)


def fast_layer_norm(
        x: torch.Tensor,
        normalized_shape: List[int],
        weight: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        eps: float = 1e-5
) -> torch.Tensor:
    if torch.jit.is_scripting():
        # currently cannot use is_autocast_enabled within torchscript
        return F.layer_norm(x, normalized_shape, weight, bias, eps)

    if has_apex:
        return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)

    if torch.is_autocast_enabled():
        # normally native AMP casts LN inputs to float32
        # apex LN does not, this is behaving like Apex
        dt = torch.get_autocast_gpu_dtype()
        x = x.to(dt)
        weight = weight.to(dt) if weight is not None else None
        bias = bias.to(dt) if bias is not None else None

    with torch.cuda.amp.autocast(enabled=False):
        return F.layer_norm(x, normalized_shape, weight, bias, eps)
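

# --- Hypothetical usage sketch, illustration only (not part of this module) ---
# A norm layer could snapshot the fast-norm flag at construction time and route
# its forward pass through fast_layer_norm. The class name FastLayerNorm and its
# constructor arguments below are assumptions made purely for this example.
class FastLayerNorm(torch.nn.LayerNorm):
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
        self._fast_norm = is_fast_norm()  # flag captured when the layer is built

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._fast_norm:
            # low precision path: stays in the autocast dtype (or uses APEX if present)
            return fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # default path: native F.layer_norm, which AMP will upcast to float32
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)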