diff --git a/timm/models/davit.py b/timm/models/davit.py index 0fe89d5c..7de3c2bb 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -23,7 +23,7 @@ from torch import Tensor import torch.utils.checkpoint as checkpoint from .features import FeatureInfo -from .fx_features import register_notrace_function +from .fx_features import register_notrace_function, register_notrace_module from .helpers import build_model_with_cfg, pretrained_cfg_for_features from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp from .pretrained import generate_default_cfgs