You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
79 lines
2.4 KiB
79 lines
2.4 KiB
""" 'Fast' Normalization Functions
|
|
|
|
For GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32.
|
|
|
|
Additionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast)
|
|
|
|
Hacked together by / Copyright 2022 Ross Wightman
|
|
"""
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
from torch.nn import functional as F
|
|
|
|
# APEX is optional. When its fused LayerNorm can be imported, fast_layer_norm
# dispatches to it; otherwise only the native PyTorch path is available.
has_apex = False
try:
    from apex.normalization.fused_layer_norm import fused_layer_norm_affine
    has_apex = True
except ImportError:
    pass
|
|
|
|
|
|
# Module-wide switch for the "fast" (i.e. lower-precision) norm paths. Toggle it
# with set_fast_norm() and query it with is_fast_norm(). It is off by default
# for now, so reduced-precision issues can be avoided simply by leaving it off.
_USE_FAST_NORM: bool = False
|
|
|
|
|
|
def is_fast_norm():
    """Return the current value of the module-level fast-norm switch.

    The value is whatever was last stored by ``set_fast_norm`` (module default ``False``).
    """
    return _USE_FAST_NORM
|
|
|
|
|
|
def set_fast_norm(enable=True):
    """Switch the lower-precision ("fast") norm paths on or off for the whole module.

    Args:
        enable: new value for the module-level switch that ``is_fast_norm()``
            returns. Defaults to ``True`` (enable).
    """
    # rebind the module-level flag rather than creating a local
    global _USE_FAST_NORM
    _USE_FAST_NORM = enable
|
|
|
|
|
|
def fast_group_norm(
    x: torch.Tensor,
    num_groups: int,
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-5
) -> torch.Tensor:
    """Group norm that skips the float32 upcast native AMP normally applies.

    When autocast is enabled (as reported by ``torch.is_autocast_enabled()``):
    - the input and any provided affine params are cast to the autocast GPU dtype;
    - ``F.group_norm`` is then run with CUDA autocast disabled, so the math stays
      in that lower-precision dtype.

    Under TorchScript the plain ``F.group_norm`` call is used.

    Args:
        x: input tensor whose channels are split into ``num_groups`` groups.
        num_groups: number of groups to separate the channels into.
        weight: optional per-channel scale; may be ``None``.
        bias: optional per-channel shift; may be ``None``.
        eps: value added to the variance for numerical stability.

    Returns:
        The group-normalized tensor.
    """
    if torch.jit.is_scripting():
        # currently cannot use is_autocast_enabled within torchscript
        return F.group_norm(x, num_groups, weight, bias, eps)

    if torch.is_autocast_enabled():
        # normally native AMP casts GN inputs to float32
        # here we use the low precision autocast dtype
        # FIXME what to do re CPU autocast?
        dt = torch.get_autocast_gpu_dtype()
        x = x.to(dt)
        # weight / bias are optional; casting a None would raise AttributeError
        if weight is not None:
            weight = weight.to(dt)
        if bias is not None:
            bias = bias.to(dt)

    with torch.cuda.amp.autocast(enabled=False):
        return F.group_norm(x, num_groups, weight, bias, eps)
|
|
|
|
|
|
def fast_layer_norm(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-5
) -> torch.Tensor:
    """Layer norm that skips the float32 upcast native AMP normally applies.

    Three paths, tried in order:
    1. APEX fused LayerNorm, when APEX was importable (``has_apex``) and the call
       is affine (both ``weight`` and ``bias`` given) on a CUDA tensor.
    2. Native ``F.layer_norm`` in the autocast GPU dtype, when autocast is enabled
       (as reported by ``torch.is_autocast_enabled()``). The input and any provided
       affine params are cast first, and CUDA autocast is disabled for the call so
       the computation stays in that lower-precision dtype.
    3. Plain ``F.layer_norm`` otherwise.

    Under TorchScript the plain ``F.layer_norm`` call is used.

    Args:
        x: input tensor.
        normalized_shape: trailing dimensions to normalize over.
        weight: optional elementwise scale; may be ``None``.
        bias: optional elementwise shift; may be ``None``.
        eps: value added to the variance for numerical stability.

    Returns:
        The layer-normalized tensor.
    """
    if torch.jit.is_scripting():
        # currently cannot use is_autocast_enabled within torchscript
        return F.layer_norm(x, normalized_shape, weight, bias, eps)

    # fused_layer_norm_affine dereferences weight/bias and launches CUDA kernels,
    # so only use it for affine calls on CUDA tensors; fall back to native LN otherwise.
    if has_apex and weight is not None and bias is not None and x.is_cuda:
        return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)

    if torch.is_autocast_enabled():
        # normally native AMP casts LN inputs to float32
        # apex LN does not, this is behaving like Apex
        # FIXME what to do re CPU autocast?
        dt = torch.get_autocast_gpu_dtype()
        x = x.to(dt)
        # weight / bias are optional; casting a None would raise AttributeError
        if weight is not None:
            weight = weight.to(dt)
        if bias is not None:
            bias = bias.to(dt)

    with torch.cuda.amp.autocast(enabled=False):
        return F.layer_norm(x, normalized_shape, weight, bias, eps)
|