Better fix for #954 that doesn't break torchscript, pull torch._assert into timm namespace when it exists

more_datasets
Ross Wightman 3 years ago
parent 4f0f9cb348
commit 2ddef942b9

@ -1,18 +1,8 @@
import torch
try:
from torch.overrides import has_torch_function, handle_torch_function
from torch import _assert
except ImportError:
from torch._overrides import has_torch_function, handle_torch_function
def _assert(condition, message):
r"""A wrapper around Python's assert which is symbolically traceable.
This is based on _assert method in torch.__init__.py but brought here to avoid reliance
on internal torch fn and allow compatibility with PyTorch < 1.8.
"""
if type(condition) is not torch.Tensor and has_torch_function((condition,)):
return handle_torch_function(_assert, (condition,), condition, message)
assert condition, message
def _assert(condition: bool, message: str):
assert condition, message
def _float_to_int(x: float) -> int:

Loading…
Cancel
Save