From 7bbbd5ef1b2ad03ace04982152a8ca395fed4f43 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 8 Dec 2021 14:05:12 -0800
Subject: [PATCH] EvoNorm and GroupNormAct options for debugging TPU / XLA concerns

---
 timm/models/layers/evo_norm.py | 47 ++++++++++++++++++++-------------
 timm/models/layers/norm_act.py | 30 ++++++++++++++++++++-
 2 files changed, 58 insertions(+), 19 deletions(-)

diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py
index d42c502c..5032a527 100644
--- a/timm/models/layers/evo_norm.py
+++ b/timm/models/layers/evo_norm.py
@@ -23,6 +23,7 @@ GPU, similar train speeds for EvoNormS variants and BatchNorm.
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
+from typing import Sequence, Union
 
 import torch
 import torch.nn as nn
@@ -38,36 +39,46 @@ def instance_std(x, eps: float = 1e-5):
 
 
 def instance_rms(x, eps: float = 1e-5):
-    rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(dtype=x.dtype)
+    rms = x.float().square().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(x.dtype)
     return rms.expand(x.shape)
 
 
+def manual_var(x, dim: Union[int, Sequence[int]], diff_sqm: bool = False):
+    xm = x.mean(dim=dim, keepdim=True)
+    if diff_sqm:
+        # difference of squared mean and mean squared, faster on TPU, can be less stable
+        var = (x.square().mean(dim=dim, keepdim=True) - xm.square()).clamp(0)
+    else:
+        var = (x - xm).square().mean(dim=dim, keepdim=True)
+    return var
+
+
 def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False):
     B, C, H, W = x.shape
     x_dtype = x.dtype
     _assert(C % groups == 0, '')
-    # x = x.reshape(B, groups, -1)  # FIXME simpler shape causing TPU / XLA issues
-    # std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt()
-    x = x.reshape(B, groups, C // groups, H, W)
-    std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt()
-    return std.expand(x.shape).reshape(B, C, H, W).to(x_dtype)
+    if flatten:
+        x = x.reshape(B, groups, -1)  # FIXME simpler shape causing TPU / XLA issues
+        std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
+    else:
+        x = x.reshape(B, groups, C // groups, H, W)
+        std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
+    return std.expand(x.shape).reshape(B, C, H, W)
 
 
-def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False):
+def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False, flatten: bool = False):
     # This is a workaround for some stability / odd behaviour of .var and .std
     # running on PyTorch XLA w/ TPUs. These manual var impl are producing much better results
     B, C, H, W = x.shape
     _assert(C % groups == 0, '')
-    x_dtype = x.dtype
-    x = x.float().reshape(B, groups, C // groups, H, W)
-    xm = x.mean(dim=(2, 3, 4), keepdim=True)
-    if diff_sqm:
-        # difference of squared mean and mean squared, faster on TPU
-        var = (x.square().mean(dim=(2, 3, 4), keepdim=True) - xm.square()).clamp(0)
+    if flatten:
+        x = x.reshape(B, groups, -1)  # FIXME simpler shape causing TPU / XLA issues
+        var = manual_var(x, dim=-1, diff_sqm=diff_sqm)
     else:
-        var = (x - xm).square().mean(dim=(2, 3, 4), keepdim=True)
-    return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W).to(x_dtype)
-# group_std = group_std_tpu  # temporary, for TPU / PT XLA
+        x = x.reshape(B, groups, C // groups, H, W)
+        var = manual_var(x, dim=(2, 3, 4), diff_sqm=diff_sqm)
+    return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W)
+#group_std = group_std_tpu  # FIXME TPU temporary
 
 
 def group_rms(x, groups: int = 32, eps: float = 1e-5):
@@ -75,8 +86,8 @@ def group_rms(x, groups: int = 32, eps: float = 1e-5):
     _assert(C % groups == 0, '')
     x_dtype = x.dtype
     x = x.reshape(B, groups, C // groups, H, W)
-    sqm = x.square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype)
-    return sqm.expand(x.shape).reshape(B, C, H, W)
+    rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype)
+    return rms.expand(x.shape).reshape(B, C, H, W)
 
 
 class EvoNorm2dB0(nn.Module):
diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py
index 2e15181f..40bd57ef 100644
--- a/timm/models/layers/norm_act.py
+++ b/timm/models/layers/norm_act.py
@@ -66,6 +66,31 @@ class BatchNormAct2d(nn.BatchNorm2d):
         return x
 
 
+def group_norm_tpu(x, w, b, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False, flatten: bool = False):
+    # This is a workaround for some odd behaviour running on PyTorch XLA w/ TPUs.
+    x_shape = x.shape
+    x_dtype = x.dtype
+    if flatten:
+        norm_shape = (x_shape[0], groups, -1)
+        reduce_dim = -1
+    else:
+        norm_shape = (x_shape[0], groups, x_shape[1] // groups) + x_shape[2:]
+        reduce_dim = tuple(range(2, x.ndim + 1))
+    affine_shape = (1, -1) + (1,) * (x.ndim - 2)
+    x = x.reshape(norm_shape)
+    # x = x.to(torch.float32)  # for testing w/ AMP
+    xm = x.mean(dim=reduce_dim, keepdim=True)
+    if diff_sqm:
+        # difference of squared mean and mean squared, faster on TPU
+        var = (x.square().mean(dim=reduce_dim, keepdim=True) - xm.square()).clamp(0)
+    else:
+        var = (x - xm).square().mean(dim=reduce_dim, keepdim=True)
+    x = (x - xm.expand(norm_shape)) / var.add(eps).sqrt().expand(norm_shape)
+    x = x.reshape(x_shape) * w.view(affine_shape) + b.view(affine_shape)
+    # x = x.to(x_dtype)  # for testing w/ AMP
+    return x
+
+
 class GroupNormAct(nn.GroupNorm):
     # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
     def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True,
@@ -80,6 +105,9 @@ class GroupNormAct(nn.GroupNorm):
             self.act = nn.Identity()
 
     def forward(self, x):
-        x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        if False:  # FIXME TPU temporary while resolving some performance issues
+            x = group_norm_tpu(x, self.weight, self.bias, self.num_groups, self.eps)
+        else:
+            x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
         x = self.act(x)
         return x
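A quick CPU sanity check of the new debug paths (a minimal sketch, not part of the patch itself; it assumes a checkout with this patch applied is importable and that atol=1e-5 is an acceptable float32 tolerance):

import torch
import torch.nn.functional as F

from timm.models.layers.evo_norm import group_std, group_std_tpu
from timm.models.layers.norm_act import group_norm_tpu

torch.manual_seed(0)
x = torch.randn(2, 64, 8, 8)            # (B, C, H, W), C divisible by groups
w, b = torch.ones(64), torch.zeros(64)  # affine params for the group norm check

# The manual-var workaround should track the reference group_std (torch .var)
# for both the grouped 5D reshape and the flattened 3D reshape, with/without diff_sqm.
ref = group_std(x, groups=32)
for flatten in (False, True):
    for diff_sqm in (False, True):
        alt = group_std_tpu(x, groups=32, diff_sqm=diff_sqm, flatten=flatten)
        print(f'group_std flatten={flatten} diff_sqm={diff_sqm}:',
              torch.allclose(ref, alt, atol=1e-5))

# group_norm_tpu should match F.group_norm for the default (non-flatten) path.
ref = F.group_norm(x, 32, w, b, eps=1e-5)
alt = group_norm_tpu(x, w, b, groups=32, eps=1e-5)
print('group_norm_tpu vs F.group_norm:', torch.allclose(ref, alt, atol=1e-5))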