|
|
@ -1,7 +1,7 @@
|
|
|
|
""" PyTorch FX Based Feature Extraction Helpers
|
|
|
|
""" PyTorch FX Based Feature Extraction Helpers
|
|
|
|
Using https://pytorch.org/vision/stable/feature_extraction.html
|
|
|
|
Using https://pytorch.org/vision/stable/feature_extraction.html
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
from typing import Callable, List, Dict, Union
|
|
|
|
from typing import Callable, List, Dict, Union, Type
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
from torch import nn
|
|
|
|
from torch import nn
|
|
|
@ -35,7 +35,7 @@ except ImportError:
|
|
|
|
pass
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def register_notrace_module(module: nn.Module):
|
|
|
|
def register_notrace_module(module: Type[nn.Module]):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
|
|
|
|
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|