diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index a9c05b0a..b09381b7 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -1,7 +1,7 @@ """ PyTorch FX Based Feature Extraction Helpers Using https://pytorch.org/vision/stable/feature_extraction.html """ -from typing import Callable, List, Dict, Union +from typing import Callable, List, Dict, Union, Type import torch from torch import nn @@ -35,7 +35,7 @@ except ImportError: pass -def register_notrace_module(module: nn.Module): +def register_notrace_module(module: Type[nn.Module]): """ Any module not under timm.models.layers should get this decorator if we don't want to trace through it. """