Make a few more layers symbolically traceable (remove from FX leaf modules)

* Remove the dtype kwarg from .to() calls in EvoNorm; the keyword form broke the script + trace combination
* BatchNormAct2d always uses a custom forward (cut & paste of the original BatchNorm2d forward) instead of super().forward(). Fixes #1176
* BlurPool2d: groups == channels, so there is no need to read input.shape[1]
pull/1194/head
Ross Wightman
parent a9ecb880e5
commit f670d98cb8
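All three changes serve the same goal: these layers should pass both torch.jit.script and torch.fx symbolic tracing without being special-cased. A minimal, hedged smoke test of that combination could look like the sketch below ('resnet26t' is only an illustrative model name, not one singled out by this commit):

import torch
import torch.fx
import timm

# Hedged sketch: exercise both compilation paths the commit message refers to.
model = timm.create_model('resnet26t', pretrained=False).eval()
scripted = torch.jit.script(model)        # TorchScript path
traced = torch.fx.symbolic_trace(model)   # FX symbolic tracing path
x = torch.randn(1, 3, 224, 224)
assert torch.allclose(model(x), traced(x))  # traced graph should match eager output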

@@ -15,25 +15,17 @@ except ImportError:
has_fx_feature_extraction = False
# Layers we want to treat as leaf modules
from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
from .layers import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2
from .layers import EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
from .layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
from .layers.non_local_attn import BilinearAttnTransform
from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
# BUT modules from timm.models should use the registration mechanism below
_leaf_modules = {
BatchNormAct2d, # reason: flow control for jit scripting
BilinearAttnTransform, # reason: flow control t <= 1
BlurPool2d, # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1]
# Reason: get_same_padding has a max which raises a control flow error
Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
DropPath, # reason: TypeError: rand received Proxy in `size` argument
EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2, # to(dtype) use that causes tracing failure (on scripted models only?)
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a,
}
try:
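For context, the tracer calls leaf modules as opaque units instead of tracing into their forward(), which is what lets the data-dependent control flow noted above survive. A minimal sketch of how a leaf set like this can be consumed with a plain torch.fx tracer (an illustration only, separate from the registration mechanism the NOTE above refers to):

import torch.fx
from torch import nn

class LeafTracer(torch.fx.Tracer):
    # Hedged sketch: instances of the given classes become single call_module nodes,
    # so branching on x.shape etc. inside them is never symbolically traced.
    def __init__(self, leaf_types):
        super().__init__()
        self.leaf_types = tuple(leaf_types)

    def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
        return isinstance(m, self.leaf_types) or super().is_leaf_module(m, module_qualified_name)

# usage sketch:
#   graph = LeafTracer(_leaf_modules).trace(model)
#   gm = torch.fx.GraphModule(model, graph)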

@@ -39,4 +39,4 @@ class BlurPool2d(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.pad(x, self.padding, 'reflect')
return F.conv2d(x, self.filt, stride=self.stride, groups=x.shape[1])
return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels)
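The groups change works because the blur filter is depthwise: its group count equals the channel count fixed at construction, so a stored int can replace x.shape[1], which becomes an FX Proxy during tracing and made F.conv2d fail. A standalone sketch of the pattern (not the timm class itself):

import torch
import torch.fx
import torch.nn.functional as F
from torch import nn

class TinyBlur(nn.Module):
    # Hedged stand-in for a depthwise anti-aliasing blur; illustrative only.
    def __init__(self, channels: int, stride: int = 2):
        super().__init__()
        self.channels = channels
        self.stride = stride
        coeffs = torch.tensor([1., 2., 1.])
        filt = (coeffs[:, None] * coeffs[None, :]) / coeffs.sum() ** 2
        self.register_buffer('filt', filt[None, None].repeat(channels, 1, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.pad(x, (1, 1, 1, 1), mode='reflect')
        # groups=self.channels stays a plain int under tracing; x.shape[1] would be a Proxy
        return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels)

torch.fx.symbolic_trace(TinyBlur(8))  # traces without leaf-module treatment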

@@ -107,13 +107,13 @@ class DropBlock2d(nn.Module):
def __init__(
self,
drop_prob=0.1,
block_size=7,
gamma_scale=1.0,
with_noise=False,
inplace=False,
batchwise=False,
fast=True):
drop_prob: float = 0.1,
block_size: int = 7,
gamma_scale: float = 1.0,
with_noise: bool = False,
inplace: bool = False,
batchwise: bool = False,
fast: bool = True):
super(DropBlock2d, self).__init__()
self.drop_prob = drop_prob
self.gamma_scale = gamma_scale
@@ -157,7 +157,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None, scale_by_keep=True):
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
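As the docstring says, DropPath belongs on the main (residual) path of a block; a typical, illustrative usage pattern mirroring how timm blocks wire it up:

from torch import nn
from timm.models.layers import DropPath

class ResidualMLP(nn.Module):
    # Illustrative block: the whole branch is dropped per sample while training.
    def __init__(self, dim: int, drop_path: float = 0.1):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        return x + self.drop_path(self.mlp(x))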

@@ -92,7 +92,7 @@ def group_rms(x, groups: int = 32, eps: float = 1e-5):
_assert(C % groups == 0, '')
x_dtype = x.dtype
x = x.reshape(B, groups, C // groups, H, W)
rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype)
rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(x_dtype)
return rms.expand(x.shape).reshape(B, C, H, W)
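For reference, group_rms is a per-group root-mean-square over the (channels-in-group, H, W) dimensions, broadcast back to the input shape; a quick standalone restatement of the computation above:

import torch

B, C, H, W, groups, eps = 2, 8, 4, 4, 4, 1e-5
x = torch.randn(B, C, H, W)
xg = x.reshape(B, groups, C // groups, H, W)
rms = xg.square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt()
print(rms.expand(xg.shape).reshape(B, C, H, W).shape)  # torch.Size([2, 8, 4, 4])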
@@ -160,14 +160,14 @@ class EvoNorm2dB1(nn.Module):
n = x.numel() / x.shape[1]
self.running_var.copy_(
self.running_var * (1 - self.momentum) +
var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1)))
var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1)))
else:
var = self.running_var
var = var.to(dtype=x_dtype).view(v_shape)
var = var.to(x_dtype).view(v_shape)
left = var.add(self.eps).sqrt_()
right = (x + 1) * instance_rms(x, self.eps)
x = x / left.max(right)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
class EvoNorm2dB2(nn.Module):
@@ -195,14 +195,14 @@ class EvoNorm2dB2(nn.Module):
n = x.numel() / x.shape[1]
self.running_var.copy_(
self.running_var * (1 - self.momentum) +
var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1)))
var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1)))
else:
var = self.running_var
var = var.to(dtype=x_dtype).view(v_shape)
var = var.to(x_dtype).view(v_shape)
left = var.add(self.eps).sqrt_()
right = instance_rms(x, self.eps) - x
x = x / left.max(right)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
class EvoNorm2dS0(nn.Module):
@@ -231,9 +231,9 @@ class EvoNorm2dS0(nn.Module):
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
if self.v is not None:
v = self.v.view(v_shape).to(dtype=x_dtype)
v = self.v.view(v_shape).to(x_dtype)
x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
class EvoNorm2dS0a(EvoNorm2dS0):
@@ -247,10 +247,10 @@ class EvoNorm2dS0a(EvoNorm2dS0):
v_shape = (1, -1, 1, 1)
d = group_std(x, self.groups, self.eps)
if self.v is not None:
v = self.v.view(v_shape).to(dtype=x_dtype)
v = self.v.view(v_shape).to(x_dtype)
x = x * (x * v).sigmoid()
x = x / d
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
class EvoNorm2dS1(nn.Module):
@@ -284,7 +284,7 @@ class EvoNorm2dS1(nn.Module):
v_shape = (1, -1, 1, 1)
if self.apply_act:
x = self.act(x) / group_std(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
class EvoNorm2dS1a(EvoNorm2dS1):
@@ -299,7 +299,7 @@ class EvoNorm2dS1a(EvoNorm2dS1):
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
x = self.act(x) / group_std(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
class EvoNorm2dS2(nn.Module):
@@ -332,7 +332,7 @@ class EvoNorm2dS2(nn.Module):
v_shape = (1, -1, 1, 1)
if self.apply_act:
x = self.act(x) / group_rms(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
class EvoNorm2dS2a(EvoNorm2dS2):
@@ -347,4 +347,4 @@ class EvoNorm2dS2a(EvoNorm2dS2):
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
x = self.act(x) / group_rms(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
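The repeated edit through this file is mechanical: Tensor.to(x_dtype) replaces Tensor.to(dtype=x_dtype). The two are numerically identical; the keyword form is what the commit message reports as breaking the script + trace combination. A standalone stand-in for the affine tail of these layers, checkable in isolation (not the timm classes themselves):

import torch
import torch.fx
from torch import nn

class ScaleBias(nn.Module):
    # Hedged stand-in: same weight/bias cast-and-broadcast pattern as the EvoNorm tails.
    def __init__(self, num_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_dtype = x.dtype
        v_shape = (1, -1, 1, 1)
        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)

gm = torch.fx.symbolic_trace(ScaleBias(8))  # traces without leaf-module treatment
gm(torch.randn(2, 8, 4, 4))                 # and the traced module still runs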

@@ -6,6 +6,7 @@ import torch
from torch import nn as nn
from torch.nn import functional as F
from .trace_utils import _assert
from .create_act import get_act_layer
@@ -29,9 +30,10 @@ class BatchNormAct2d(nn.BatchNorm2d):
else:
self.act = nn.Identity()
def _forward_jit(self, x):
""" A cut & paste of the contents of the PyTorch BatchNorm2d forward function
"""
def forward(self, x):
# cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing
_assert(x.ndim == 4, f'expected 4D input (got {x.ndim}D input)')
# exponential_average_factor is set to self.momentum
# (when it is available) only so that it gets updated
# in ONNX graph when this node is exported to ONNX.
@@ -63,7 +65,7 @@ class BatchNormAct2d(nn.BatchNorm2d):
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
used for normalization (i.e. in eval mode when buffers are not None).
"""
return F.batch_norm(
x = F.batch_norm(
x,
# If buffers are not to be tracked, ensure that they won't be updated
self.running_mean if not self.training or self.track_running_stats else None,
@@ -74,17 +76,6 @@ class BatchNormAct2d(nn.BatchNorm2d):
exponential_average_factor,
self.eps,
)
@torch.jit.ignore
def _forward_python(self, x):
return super(BatchNormAct2d, self).forward(x)
def forward(self, x):
# FIXME cannot call parent forward() and maintain jit.script compatibility?
if torch.jit.is_scripting():
x = self._forward_jit(x)
else:
x = self._forward_python(x)
x = self.drop(x)
x = self.act(x)
return x
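With the BatchNorm body inlined above, scripting and eager execution share a single code path and the old torch.jit.is_scripting() switch is gone. A hedged smoke test:

import torch
from timm.models.layers import BatchNormAct2d

bn = BatchNormAct2d(32)
scripted = torch.jit.script(bn)             # no super().forward() or is_scripting() branch involved
out = scripted(torch.randn(2, 32, 8, 8))    # norm, then drop, then act
print(out.shape)                            # torch.Size([2, 32, 8, 8])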
