Fix #954 by bringing traceable _assert into timm to allow compat w/ PyTorch < 1.8

more_datasets
Ross Wightman 3 years ago
parent a41de1f666
commit 4f0f9cb348

@ -36,4 +36,5 @@ from .split_attn import SplitAttn
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .trace_utils import _assert, _float_to_int
from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_ from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_

@ -6,10 +6,10 @@ Based on the impl in https://github.com/google-research/vision_transformer
Hacked together by / Copyright 2020 Ross Wightman Hacked together by / Copyright 2020 Ross Wightman
""" """
import torch
from torch import nn as nn from torch import nn as nn
from .helpers import to_2tuple from .helpers import to_2tuple
from .trace_utils import _assert
class PatchEmbed(nn.Module): class PatchEmbed(nn.Module):
@ -30,8 +30,8 @@ class PatchEmbed(nn.Module):
def forward(self, x): def forward(self, x):
B, C, H, W = x.shape B, C, H, W = x.shape
torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
x = self.proj(x) x = self.proj(x)
if self.flatten: if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC x = x.flatten(2).transpose(1, 2) # BCHW -> BNC

@ -0,0 +1,23 @@
import torch
try:
from torch.overrides import has_torch_function, handle_torch_function
except ImportError:
from torch._overrides import has_torch_function, handle_torch_function
def _assert(condition, message):
r"""A wrapper around Python's assert which is symbolically traceable.
This is based on _assert method in torch.__init__.py but brought here to avoid reliance
on internal torch fn and allow compatibility with PyTorch < 1.8.
"""
if type(condition) is not torch.Tensor and has_torch_function((condition,)):
return handle_torch_function(_assert, (condition,), condition, message)
assert condition, message
def _float_to_int(x: float) -> int:
"""
Symbolic tracing helper to substitute for inbuilt `int`.
Hint: Inbuilt `int` can't accept an argument of type `Proxy`
"""
return int(x)
Loading…
Cancel
Save