Better fix for #954 that doesn't break torchscript, pull torch._assert into timm namespace when it exists

more_datasets
Ross Wightman 3 years ago
parent 4f0f9cb348
commit 2ddef942b9

@ -1,17 +1,7 @@
import torch
try: try:
from torch.overrides import has_torch_function, handle_torch_function from torch import _assert
except ImportError: except ImportError:
from torch._overrides import has_torch_function, handle_torch_function def _assert(condition: bool, message: str):
def _assert(condition, message):
r"""A wrapper around Python's assert which is symbolically traceable.
This is based on _assert method in torch.__init__.py but brought here to avoid reliance
on internal torch fn and allow compatibility with PyTorch < 1.8.
"""
if type(condition) is not torch.Tensor and has_torch_function((condition,)):
return handle_torch_function(_assert, (condition,), condition, message)
assert condition, message assert condition, message

Loading…
Cancel
Save