diff --git a/timm/models/layers/fast_norm.py b/timm/models/layers/fast_norm.py index 9a34a15e..fb35e47d 100644 --- a/timm/models/layers/fast_norm.py +++ b/timm/models/layers/fast_norm.py @@ -1,3 +1,11 @@ +""" 'Fast' Normalization Functions + +For GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32. + +Additionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast) + +Hacked together by / Copyright 2022 Ross Wightman +""" from typing import List, Optional import torch @@ -37,6 +45,7 @@ def fast_group_norm( if torch.is_autocast_enabled(): # normally native AMP casts GN inputs to float32 # here we use the low precision autocast dtype + # FIXME what to do re CPU autocast? dt = torch.get_autocast_gpu_dtype() x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) @@ -62,6 +71,7 @@ def fast_layer_norm( # normally native AMP casts LN inputs to float32 # apex LN does not, this is behaving like Apex dt = torch.get_autocast_gpu_dtype() + # FIXME what to do re CPU autocast? x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) with torch.cuda.amp.autocast(enabled=False): diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index 2ff8fc08..42445a49 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -1,4 +1,8 @@ """ Normalization layers and wrappers + +Norm layer definitions that support fast norm and consistent channel arg order (always first arg). + +Hacked together by / Copyright 2022 Ross Wightman """ import torch diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py index dc077160..ff075fbc 100644 --- a/timm/models/layers/norm_act.py +++ b/timm/models/layers/norm_act.py @@ -1,4 +1,16 @@ """ Normalization + Activation Layers + +Provides Norm+Act fns for standard PyTorch norm layers such as +* BatchNorm +* GroupNorm +* LayerNorm + +This allows swapping with alternative layers that are natively both norm + act such as +* EvoNorm (evo_norm.py) +* FilterResponseNorm (filter_response_norm.py) +* InplaceABN (inplace_abn.py) + +Hacked together by / Copyright 2022 Ross Wightman """ from typing import Union, List, Optional, Any