From f670d98cb8ec70ed6e03b4be60a18faf4dc913b5 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 24 Mar 2022 21:40:34 -0700 Subject: [PATCH] Make a few more layers symbolically traceable (remove from FX leaf modules) * remove dtype kwarg from .to() calls in EvoNorm as it messed up script + trace combo * BatchNormAct2d always uses custom forward (cut & paste from original) instead of super().forward. Fixes #1176 * BlurPool groups==channels, no need to use input.dim[1] --- timm/models/fx_features.py | 10 +--------- timm/models/layers/blur_pool.py | 2 +- timm/models/layers/drop.py | 16 ++++++++-------- timm/models/layers/evo_norm.py | 30 +++++++++++++++--------------- timm/models/layers/norm_act.py | 21 ++++++--------------- 5 files changed, 31 insertions(+), 48 deletions(-) diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index cbb51980..a9c05b0a 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -15,25 +15,17 @@ except ImportError: has_fx_feature_extraction = False # Layers we went to treat as leaf modules -from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath -from .layers import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2 -from .layers import EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a +from .layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame from .layers.non_local_attn import BilinearAttnTransform from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame # NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here # BUT modules from timm.models should use the registration mechanism below _leaf_modules = { - BatchNormAct2d, # reason: flow control for jit scripting BilinearAttnTransform, # reason: flow control t <= 1 - BlurPool2d, # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1] # Reason: get_same_padding has a max which raises a control flow error Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame, CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]) - DropPath, # reason: TypeError: rand recieved Proxy in `size` argument - EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2, # to(dtype) use that causes tracing failure (on scripted models only?) - EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, - } try: diff --git a/timm/models/layers/blur_pool.py b/timm/models/layers/blur_pool.py index ca4ce756..e73d8863 100644 --- a/timm/models/layers/blur_pool.py +++ b/timm/models/layers/blur_pool.py @@ -39,4 +39,4 @@ class BlurPool2d(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: x = F.pad(x, self.padding, 'reflect') - return F.conv2d(x, self.filt, stride=self.stride, groups=x.shape[1]) + return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels) diff --git a/timm/models/layers/drop.py b/timm/models/layers/drop.py index 14945efc..ae065277 100644 --- a/timm/models/layers/drop.py +++ b/timm/models/layers/drop.py @@ -107,13 +107,13 @@ class DropBlock2d(nn.Module): def __init__( self, - drop_prob=0.1, - block_size=7, - gamma_scale=1.0, - with_noise=False, - inplace=False, - batchwise=False, - fast=True): + drop_prob: float = 0.1, + block_size: int = 7, + gamma_scale: float = 1.0, + with_noise: bool = False, + inplace: bool = False, + batchwise: bool = False, + fast: bool = True): super(DropBlock2d, self).__init__() self.drop_prob = drop_prob self.gamma_scale = gamma_scale @@ -157,7 +157,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: b class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ - def __init__(self, drop_prob=None, scale_by_keep=True): + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): super(DropPath, self).__init__() self.drop_prob = drop_prob self.scale_by_keep = scale_by_keep diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py index 42636236..b643302c 100644 --- a/timm/models/layers/evo_norm.py +++ b/timm/models/layers/evo_norm.py @@ -92,7 +92,7 @@ def group_rms(x, groups: int = 32, eps: float = 1e-5): _assert(C % groups == 0, '') x_dtype = x.dtype x = x.reshape(B, groups, C // groups, H, W) - rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype) + rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(x_dtype) return rms.expand(x.shape).reshape(B, C, H, W) @@ -160,14 +160,14 @@ class EvoNorm2dB1(nn.Module): n = x.numel() / x.shape[1] self.running_var.copy_( self.running_var * (1 - self.momentum) + - var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1))) + var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1))) else: var = self.running_var - var = var.to(dtype=x_dtype).view(v_shape) + var = var.to(x_dtype).view(v_shape) left = var.add(self.eps).sqrt_() right = (x + 1) * instance_rms(x, self.eps) x = x / left.max(right) - return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) class EvoNorm2dB2(nn.Module): @@ -195,14 +195,14 @@ class EvoNorm2dB2(nn.Module): n = x.numel() / x.shape[1] self.running_var.copy_( self.running_var * (1 - self.momentum) + - var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1))) + var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1))) else: var = self.running_var - var = var.to(dtype=x_dtype).view(v_shape) + var = var.to(x_dtype).view(v_shape) left = var.add(self.eps).sqrt_() right = instance_rms(x, self.eps) - x x = x / left.max(right) - return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) class EvoNorm2dS0(nn.Module): @@ -231,9 +231,9 @@ class EvoNorm2dS0(nn.Module): x_dtype = x.dtype v_shape = (1, -1, 1, 1) if self.v is not None: - v = self.v.view(v_shape).to(dtype=x_dtype) + v = self.v.view(v_shape).to(x_dtype) x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps) - return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) class EvoNorm2dS0a(EvoNorm2dS0): @@ -247,10 +247,10 @@ class EvoNorm2dS0a(EvoNorm2dS0): v_shape = (1, -1, 1, 1) d = group_std(x, self.groups, self.eps) if self.v is not None: - v = self.v.view(v_shape).to(dtype=x_dtype) + v = self.v.view(v_shape).to(x_dtype) x = x * (x * v).sigmoid() x = x / d - return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) class EvoNorm2dS1(nn.Module): @@ -284,7 +284,7 @@ class EvoNorm2dS1(nn.Module): v_shape = (1, -1, 1, 1) if self.apply_act: x = self.act(x) / group_std(x, self.groups, self.eps) - return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) class EvoNorm2dS1a(EvoNorm2dS1): @@ -299,7 +299,7 @@ class EvoNorm2dS1a(EvoNorm2dS1): x_dtype = x.dtype v_shape = (1, -1, 1, 1) x = self.act(x) / group_std(x, self.groups, self.eps) - return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) class EvoNorm2dS2(nn.Module): @@ -332,7 +332,7 @@ class EvoNorm2dS2(nn.Module): v_shape = (1, -1, 1, 1) if self.apply_act: x = self.act(x) / group_rms(x, self.groups, self.eps) - return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) class EvoNorm2dS2a(EvoNorm2dS2): @@ -347,4 +347,4 @@ class EvoNorm2dS2a(EvoNorm2dS2): x_dtype = x.dtype v_shape = (1, -1, 1, 1) x = self.act(x) / group_rms(x, self.groups, self.eps) - return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py index ae3f75c6..34c4fd64 100644 --- a/timm/models/layers/norm_act.py +++ b/timm/models/layers/norm_act.py @@ -6,6 +6,7 @@ import torch from torch import nn as nn from torch.nn import functional as F +from .trace_utils import _assert from .create_act import get_act_layer @@ -29,9 +30,10 @@ class BatchNormAct2d(nn.BatchNorm2d): else: self.act = nn.Identity() - def _forward_jit(self, x): - """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function - """ + def forward(self, x): + # cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing + _assert(x.ndim == 4, f'expected 4D input (got {x.ndim}D input)') + # exponential_average_factor is set to self.momentum # (when it is available) only so that it gets updated # in ONNX graph when this node is exported to ONNX. @@ -63,7 +65,7 @@ class BatchNormAct2d(nn.BatchNorm2d): passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are used for normalization (i.e. in eval mode when buffers are not None). """ - return F.batch_norm( + x = F.batch_norm( x, # If buffers are not to be tracked, ensure that they won't be updated self.running_mean if not self.training or self.track_running_stats else None, @@ -74,17 +76,6 @@ class BatchNormAct2d(nn.BatchNorm2d): exponential_average_factor, self.eps, ) - - @torch.jit.ignore - def _forward_python(self, x): - return super(BatchNormAct2d, self).forward(x) - - def forward(self, x): - # FIXME cannot call parent forward() and maintain jit.script compatibility? - if torch.jit.is_scripting(): - x = self._forward_jit(x) - else: - x = self._forward_python(x) x = self.drop(x) x = self.act(x) return x