EvoNorm and GroupNormAct options for debugging TPU / XLA concerns

pull/1239/head
Ross Wightman 2 years ago
parent ff0f709c20
commit 7bbbd5ef1b

@ -23,6 +23,7 @@ GPU, similar train speeds for EvoNormS variants and BatchNorm.
Hacked together by / Copyright 2020 Ross Wightman
"""
from typing import Sequence, Union
import torch
import torch.nn as nn
@ -38,36 +39,47 @@ def instance_std(x, eps: float = 1e-5):
def instance_rms(x, eps: float = 1e-5):
rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(dtype=x.dtype)
rms = x.float().square().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(x.dtype)
return rms.expand(x.shape)
def manual_var(x, dim: Union[int, Sequence[int]], diff_sqm: bool = False):
xm = x.mean(dim=dim, keepdim=True)
if diff_sqm:
# difference of squared mean and mean squared, faster on TPU can be less stable
var = (x.square().mean(dim=(2, 3, 4), keepdim=True) - xm.square()).clamp(0)
else:
var = (x - xm).square().mean(dim=(2, 3, 4), keepdim=True)
return var
def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False):
B, C, H, W = x.shape
x_dtype = x.dtype
_assert(C % groups == 0, '')
# x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues
# std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt()
x = x.reshape(B, groups, C // groups, H, W)
std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt()
return std.expand(x.shape).reshape(B, C, H, W).to(x_dtype)
torch.var()
if flatten:
x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues
std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
else:
x = x.reshape(B, groups, C // groups, H, W)
std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
return std.expand(x.shape).reshape(B, C, H, W)
def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False):
def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False, flatten: bool = False):
# This is a workaround for some stability / odd behaviour of .var and .std
# running on PyTorch XLA w/ TPUs. These manual var impl are producing much better results
B, C, H, W = x.shape
_assert(C % groups == 0, '')
x_dtype = x.dtype
x = x.float().reshape(B, groups, C // groups, H, W)
xm = x.mean(dim=(2, 3, 4), keepdim=True)
if diff_sqm:
# difference of squared mean and mean squared, faster on TPU
var = (x.square().mean(dim=(2, 3, 4), keepdim=True) - xm.square()).clamp(0)
if flatten:
x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues
var = manual_var(x, dim=-1, diff_sqm=diff_sqm)
else:
var = (x - xm).square().mean(dim=(2, 3, 4), keepdim=True)
return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W).to(x_dtype)
# group_std = group_std_tpu # temporary, for TPU / PT XLA
x = x.reshape(B, groups, C // groups, H, W)
var = manual_var(x, dim=(2, 3, 4), diff_sqm=diff_sqm)
return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W)
#group_std = group_std_tpu # FIXME TPU temporary
def group_rms(x, groups: int = 32, eps: float = 1e-5):
@ -75,8 +87,8 @@ def group_rms(x, groups: int = 32, eps: float = 1e-5):
_assert(C % groups == 0, '')
x_dtype = x.dtype
x = x.reshape(B, groups, C // groups, H, W)
sqm = x.square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype)
return sqm.expand(x.shape).reshape(B, C, H, W)
rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype)
return rms.expand(x.shape).reshape(B, C, H, W)
class EvoNorm2dB0(nn.Module):

@ -66,6 +66,31 @@ class BatchNormAct2d(nn.BatchNorm2d):
return x
def group_norm_tpu(x, w, b, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False, flatten: bool = False):
# This is a workaround for some odd behaviour running on PyTorch XLA w/ TPUs.
x_shape = x.shape
x_dtype = x.dtype
if flatten:
norm_shape = (x_shape[0], groups, -1)
reduce_dim = -1
else:
norm_shape = (x_shape[0], groups, x_shape[1] // groups) + x_shape[2:]
reduce_dim = tuple(range(2, x.ndim + 1))
affine_shape = (1, -1) + (1,) * (x.ndim - 2)
x = x.reshape(norm_shape)
# x = x.to(torch.float32) # for testing w/ AMP
xm = x.mean(dim=reduce_dim, keepdim=True)
if diff_sqm:
# difference of squared mean and mean squared, faster on TPU
var = (x.square().mean(dim=reduce_dim, keepdim=True) - xm.square()).clamp(0)
else:
var = (x - xm).square().mean(dim=reduce_dim, keepdim=True)
x = (x - xm.expand(norm_shape)) / var.add(eps).sqrt().expand(norm_shape)
x = x.reshape(x_shape) * w.view(affine_shape) + b.view(affine_shape)
# x = x.to(x_dtype) # for testing w/ AMP
return x
class GroupNormAct(nn.GroupNorm):
# NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True,
@ -80,6 +105,9 @@ class GroupNormAct(nn.GroupNorm):
self.act = nn.Identity()
def forward(self, x):
x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
if False: # FIXME TPU temporary while resolving some performance issues
x = group_norm_tpu(x, self.weight, self.bias, self.num_groups, self.eps)
else:
x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
x = self.act(x)
return x

Loading…
Cancel
Save