@@ -1,3 +1,11 @@
""" 'Fast' Normalization Functions
For GroupNorm and LayerNorm, these functions bypass the typical AMP upcast to float32.
Additionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast).
Hacked together by / Copyright 2022 Ross Wightman
"""
from typing import List, Optional
import torch
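
The behaviour the docstring describes is easy to check outside this file: under native AMP autocast, F.layer_norm and F.group_norm normally run in (and return) float32, while the approach in this diff keeps the low-precision autocast dtype. A minimal sketch of that comparison, assuming a CUDA device; none of the code below is part of the diff:

    import torch
    import torch.nn.functional as F

    x = torch.randn(8, 32, device='cuda', dtype=torch.float16)
    w = torch.ones(32, device='cuda', dtype=torch.float16)
    b = torch.zeros(32, device='cuda', dtype=torch.float16)

    with torch.autocast('cuda', dtype=torch.float16):
        # native AMP: layer_norm is on the float32 cast policy list, so the output is upcast
        y_native = F.layer_norm(x, (32,), w, b, 1e-5)

        # 'fast' path: cast inputs to the autocast dtype, then run with autocast disabled
        dt = torch.get_autocast_gpu_dtype()
        with torch.autocast('cuda', enabled=False):
            y_fast = F.layer_norm(x.to(dt), (32,), w.to(dt), b.to(dt), 1e-5)

    print(y_native.dtype, y_fast.dtype)  # torch.float32 torch.float16
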
@@ -37,6 +45,7 @@ def fast_group_norm(
    if torch.is_autocast_enabled():
        # normally native AMP casts GN inputs to float32
        # here we use the low precision autocast dtype
        # FIXME what to do re CPU autocast?
        dt = torch.get_autocast_gpu_dtype()
        x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)
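
The hunk above ends at the downcast; presumably the function then runs the actual normalization with autocast disabled so AMP does not re-upcast the result. A rough sketch of what the complete function plausibly looks like under that assumption (the name fast_group_norm_sketch and the body below are illustrative, not taken from the diff):

    from typing import Optional

    import torch
    import torch.nn.functional as F

    def fast_group_norm_sketch(
            x: torch.Tensor,
            num_groups: int,
            weight: Optional[torch.Tensor] = None,
            bias: Optional[torch.Tensor] = None,
            eps: float = 1e-5,
    ) -> torch.Tensor:
        if torch.is_autocast_enabled():
            # mirror the casting logic shown in the hunk above
            dt = torch.get_autocast_gpu_dtype()
            x = x.to(dt)
            weight = weight.to(dt) if weight is not None else None
            bias = bias.to(dt) if bias is not None else None
        # run group_norm with autocast disabled so it is not upcast to float32
        with torch.cuda.amp.autocast(enabled=False):
            return F.group_norm(x, num_groups, weight, bias, eps)
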
@@ -62,6 +71,7 @@ def fast_layer_norm(
        # normally native AMP casts LN inputs to float32
        # apex LN does not; this matches apex behaviour
        dt = torch.get_autocast_gpu_dtype()
        # FIXME what to do re CPU autocast?
        x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)

    with torch.cuda.amp.autocast(enabled=False):
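
The layer-norm hunk stops at the autocast(enabled=False) context manager, and the APEX fused LN mentioned in the docstring does not appear in the shown lines at all. A rough sketch of how that fallback is commonly wired together with the casting logic above; the import guard, the helper name, and the fused_layer_norm_affine argument order are assumptions, not taken from the diff:

    from typing import List

    import torch
    import torch.nn.functional as F

    # optional APEX import guard (assumed pattern, not shown in this diff)
    try:
        from apex.normalization.fused_layer_norm import fused_layer_norm_affine
        has_apex = True
    except ImportError:
        has_apex = False

    def fast_layer_norm_sketch(
            x: torch.Tensor,
            normalized_shape: List[int],
            weight: torch.Tensor,
            bias: torch.Tensor,
            eps: float = 1e-5,
    ) -> torch.Tensor:
        if has_apex:
            # APEX fused LN runs in the input dtype, i.e. it does not upcast
            return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)
        if torch.is_autocast_enabled():
            # same downcast as in the hunk above
            dt = torch.get_autocast_gpu_dtype()
            x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)
        with torch.cuda.amp.autocast(enabled=False):
            return F.layer_norm(x, normalized_shape, weight, bias, eps)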