Update type hint for `register_notrace_module`

register_notrace_module is used to decorate types (i.e. subclasses of nn.Module).
It is not called on module instances.
pull/1363/head
Jasha10 2 years ago committed by GitHub
parent d7b55a9429
commit 56c3a84db3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,7 +1,7 @@
""" PyTorch FX Based Feature Extraction Helpers
Using https://pytorch.org/vision/stable/feature_extraction.html
"""
from typing import Callable, List, Dict, Union
from typing import Callable, List, Dict, Union, Type
import torch
from torch import nn
@ -35,7 +35,7 @@ except ImportError:
pass
def register_notrace_module(module: nn.Module):
def register_notrace_module(module: Type[nn.Module]):
"""
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
"""

Loading…
Cancel
Save