""" PyTorch FX Based Feature Extraction Helpers
Using https://pytorch.org/vision/stable/feature_extraction.html
"""
from typing import Callable, List, Dict, Type, Union

import torch
from torch import nn

from .features import _get_feature_info

try:
    from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
    has_fx_feature_extraction = True
except ImportError:
    has_fx_feature_extraction = False

# Layers we want to treat as leaf modules
from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
from .layers import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2
from .layers import EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
from .layers.non_local_attn import BilinearAttnTransform
from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame

# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
# BUT modules from timm.models should use the registration mechanism below
_leaf_modules = {
    BatchNormAct2d,  # reason: flow control for jit scripting
    BilinearAttnTransform,  # reason: flow control t <= 1
    BlurPool2d,  # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1]
    # Reason: get_same_padding has a max which raises a control flow error
    Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
    CondConv2d,  # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
    DropPath,  # reason: TypeError: rand received Proxy in `size` argument
    EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,  # to(dtype) use that causes tracing failure (on scripted models only?)
    EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a,
}

try:
    from .layers import InplaceAbn
    _leaf_modules.add(InplaceAbn)
except ImportError:
    # InplaceAbn is optional here; if it cannot be imported it is simply not registered as a leaf
    pass


def register_notrace_module(module: Type[nn.Module]):
    """
    Any module not under timm.models.layers should get this decorator if we don't want to trace through it.

    Adds ``module`` (a module class) to the leaf-module set handed to the FX tracer by
    ``create_feature_extractor`` and returns it unchanged, so it can be used as a class decorator.
    """
    _leaf_modules.add(module)
    return module


# Functions we want to autowrap (treat them as leaves)
_autowrap_functions = set()


def register_notrace_function(func: Callable):
    """
    Decorator for functions which ought not to be traced through

    Adds ``func`` to the autowrap-function set handed to the FX tracer by ``create_feature_extractor``
    and returns it unchanged.
    """
    _autowrap_functions.add(func)
    return func


def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
    """ Build a torchvision FX feature extractor for ``model`` using timm's leaf modules / autowrap functions.

    Args:
        model: model to trace
        return_nodes: node names to return features from (dict of node name -> output name, or list of names)

    Returns:
        The module produced by torchvision's ``create_feature_extractor``.

    Raises:
        AssertionError: if torchvision FX feature extraction could not be imported.
    """
    assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
    return _create_feature_extractor(
        model, return_nodes,
        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
    )


def _get_return_nodes(feature_info, out_indices, out_map=None) -> Dict[str, str]:
    """ Map module (node) names of the selected feature levels to their output names.

    Entries of ``feature_info`` are visited in order and kept when their index is in ``out_indices``.
    ``out_map``, when given, is aligned by position with ``out_indices`` (i.e. ``out_map[k]`` names the
    feature at ``out_indices[k]``); otherwise the module name itself is used as the output name.
    """
    out_indices = list(out_indices)
    return_nodes = {}
    for i, info in enumerate(feature_info):
        if i not in out_indices:
            continue
        # Look up the position of feature index i within out_indices; indexing out_map by i directly
        # is wrong (and raises IndexError) whenever out_indices is not 0..k-1.
        name = out_map[out_indices.index(i)] if out_map is not None else info['module']
        return_nodes[info['module']] = name
    return return_nodes


class FeatureGraphNet(nn.Module):
    """ A FX Graph based feature extractor that works with the model feature_info metadata

    Args:
        model: model to extract features from; its feature_info metadata is read via _get_feature_info
        out_indices: indices into the model's feature_info of the features to return
        out_map: optional output names, one per entry of out_indices (matched by position);
            defaults to each feature's module name
    """
    def __init__(self, model, out_indices, out_map=None):
        super().__init__()
        assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
        self.feature_info = _get_feature_info(model, out_indices)
        if out_map is not None:
            assert len(out_map) == len(out_indices)
        return_nodes = _get_return_nodes(self.feature_info, out_indices, out_map)
        self.graph_module = create_feature_extractor(model, return_nodes)

    def forward(self, x):
        # graph_module returns a dict of outputs; expose its values as a list
        return list(self.graph_module(x).values())


class FeatureExtractNet(nn.Module):
    """ A standalone feature extraction wrapper that maps dict -> list or single tensor

    NOTE:
      * one can use create_feature_extractor directly if dictionary output is desired
      * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info
        metadata for builtin feature extraction mode

    Args:
        model: model to extract features from
        return_nodes: node names to return features from (dict or list)
        squeeze_out: if only one output, and output in list format, flatten to single tensor
    """
    def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
        super().__init__()
        self.squeeze_out = squeeze_out
        self.graph_module = create_feature_extractor(model, return_nodes)

    def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
        out = list(self.graph_module(x).values())
        if self.squeeze_out and len(out) == 1:
            return out[0]
        return out