From 4f0f9cb348eef16198309a679f3af5fd2132ac73 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 2 Nov 2021 09:21:40 -0700
Subject: [PATCH] Fix #954 by bringing traceable _assert into timm to allow
 compat w/ PyTorch < 1.8

---
 timm/models/layers/__init__.py    |  1 +
 timm/models/layers/patch_embed.py |  6 +++---
 timm/models/layers/trace_utils.py | 23 +++++++++++++++++++++++
 3 files changed, 27 insertions(+), 3 deletions(-)
 create mode 100644 timm/models/layers/trace_utils.py

diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py
index e9a5f18f..4831af9a 100644
--- a/timm/models/layers/__init__.py
+++ b/timm/models/layers/__init__.py
@@ -36,4 +36,5 @@ from .split_attn import SplitAttn
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
+from .trace_utils import _assert, _float_to_int
 from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_
diff --git a/timm/models/layers/patch_embed.py b/timm/models/layers/patch_embed.py
index 41528efa..6a7facef 100644
--- a/timm/models/layers/patch_embed.py
+++ b/timm/models/layers/patch_embed.py
@@ -6,10 +6,10 @@ Based on the impl in https://github.com/google-research/vision_transformer
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
-import torch
 from torch import nn as nn
 
 from .helpers import to_2tuple
+from .trace_utils import _assert
 
 
 class PatchEmbed(nn.Module):
@@ -30,8 +30,8 @@ class PatchEmbed(nn.Module):
 
     def forward(self, x):
         B, C, H, W = x.shape
-        torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
-        torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
+        _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
+        _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
         x = self.proj(x)
         if self.flatten:
             x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
diff --git a/timm/models/layers/trace_utils.py b/timm/models/layers/trace_utils.py
new file mode 100644
index 00000000..ae7fd7c9
--- /dev/null
+++ b/timm/models/layers/trace_utils.py
@@ -0,0 +1,23 @@
+import torch
+try:
+    from torch.overrides import has_torch_function, handle_torch_function
+except ImportError:
+    from torch._overrides import has_torch_function, handle_torch_function
+
+
+def _assert(condition, message):
+    r"""A wrapper around Python's assert which is symbolically traceable.
+    This is based on _assert method in torch.__init__.py but brought here to avoid reliance
+    on internal torch fn and allow compatibility with PyTorch < 1.8.
+    """
+    if type(condition) is not torch.Tensor and has_torch_function((condition,)):
+        return handle_torch_function(_assert, (condition,), condition, message)
+    assert condition, message
+
+
+def _float_to_int(x: float) -> int:
+    """
+    Symbolic tracing helper to substitute for inbuilt `int`.
+    Hint: Inbuilt `int` can't accept an argument of type `Proxy`
+    """
+    return int(x)
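
Note (illustrative, not part of the patch): under torch.fx symbolic tracing, forward() receives Proxy objects, so a bare Python `assert` on a Proxy condition fails (bool(Proxy) raises a TraceError), while the __torch_function__-aware _assert above is simply recorded as a call_function node in the graph. A minimal sketch, assuming PyTorch >= 1.8 for torch.fx and timm with this patch applied; the module name Crop and the sizes are made up for the demo:

    import torch
    import torch.fx
    from timm.models.layers import _assert  # helper added by this patch

    class Crop(torch.nn.Module):
        def forward(self, x):
            # A bare `assert x.shape[2] == 224` would break symbolic_trace:
            # the condition is a Proxy, and bool(Proxy) raises a TraceError.
            _assert(x.shape[2] == 224, 'expected height 224')
            return x.narrow(2, 0, 112)  # crop height without Proxy-unfriendly control flow

    traced = torch.fx.symbolic_trace(Crop())  # _assert is recorded as a graph node
    traced(torch.randn(1, 3, 224, 224))       # at runtime it acts like a normal assert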
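
Similarly, _float_to_int exists so a tracer can keep the float -> int conversion as a leaf call rather than executing the builtin int() on a Proxy, which (as the docstring notes) is not supported. A hypothetical sketch, assuming a PyTorch version whose fx.Tracer supports the autowrap_functions argument (1.10+); HalfWidth and the shapes are invented for illustration:

    import torch
    import torch.fx
    from timm.models.layers import _float_to_int  # helper added by this patch

    class HalfWidth(torch.nn.Module):
        def forward(self, x):
            # int(x.shape[3] * 0.5) would fail during tracing: int() can't take a Proxy.
            n = _float_to_int(x.shape[3] * 0.5)
            return x.narrow(3, 0, n)

    tracer = torch.fx.Tracer(autowrap_functions=(_float_to_int,))
    graph = tracer.trace(HalfWidth())              # _float_to_int is kept as a leaf node
    gm = torch.fx.GraphModule(tracer.root, graph)
    gm(torch.randn(1, 3, 8, 8))                    # runtime computes int(8 * 0.5) = 4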