You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/models/layers/trace_utils.py

24 lines
854 B

import torch
try:
from torch.overrides import has_torch_function, handle_torch_function
except ImportError:
from torch._overrides import has_torch_function, handle_torch_function
def _assert(condition, message):
r"""A wrapper around Python's assert which is symbolically traceable.
This is based on _assert method in torch.__init__.py but brought here to avoid reliance
on internal torch fn and allow compatibility with PyTorch < 1.8.
"""
if type(condition) is not torch.Tensor and has_torch_function((condition,)):
return handle_torch_function(_assert, (condition,), condition, message)
assert condition, message
def _float_to_int(x: float) -> int:
"""
Symbolic tracing helper to substitute for inbuilt `int`.
Hint: Inbuilt `int` can't accept an argument of type `Proxy`
"""
return int(x)