diff --git a/timm/models/layers/trace_utils.py b/timm/models/layers/trace_utils.py index ae7fd7c9..83970729 100644 --- a/timm/models/layers/trace_utils.py +++ b/timm/models/layers/trace_utils.py @@ -1,18 +1,8 @@ -import torch try: - from torch.overrides import has_torch_function, handle_torch_function + from torch import _assert except ImportError: - from torch._overrides import has_torch_function, handle_torch_function - - -def _assert(condition, message): - r"""A wrapper around Python's assert which is symbolically traceable. - This is based on _assert method in torch.__init__.py but brought here to avoid reliance - on internal torch fn and allow compatibility with PyTorch < 1.8. - """ - if type(condition) is not torch.Tensor and has_torch_function((condition,)): - return handle_torch_function(_assert, (condition,), condition, message) - assert condition, message + def _assert(condition: bool, message: str): + assert condition, message def _float_to_int(x: float) -> int: