Merge pull request #1363 from Jasha10/patch-1

Update type hint for `register_notrace_module`
pull/1606/head
Ross Wightman 2 years ago committed by GitHub
commit 45c447fc15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,7 +1,7 @@
""" PyTorch FX Based Feature Extraction Helpers
Using https://pytorch.org/vision/stable/feature_extraction.html
"""
from typing import Callable, List, Dict, Union
from typing import Callable, List, Dict, Union, Type
import torch
from torch import nn
@ -35,7 +35,7 @@ except ImportError:
pass
def register_notrace_module(module: nn.Module):
def register_notrace_module(module: Type[nn.Module]):
"""
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
"""

Loading…
Cancel
Save