You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/models/fx_features.py

489 lines
21 KiB

""" PyTorch FX Based Feature Extraction Helpers
An extension/alternative to timm.models.features making use of PyTorch FX. Here, the idea is to:
1. Symbolically trace a model producing a graph based intermediate representation (PyTorch FX functionality with
some custom tweaks)
2. Identify desired feature extraction nodes and reconfigure them as output nodes while deleting all unnecessary
nodes. (custom - inspired by https://github.com/pytorch/vision/pull/3597)
3. Write the resulting graph into a GraphModule (PyTorch FX functionality)
Copyright 2021 Alexander Soare
"""
from typing import Callable, Dict, Union, List, Optional
import math
from collections import OrderedDict
from pprint import pprint
from inspect import ismethod
import re
import warnings
from copy import deepcopy
from itertools import chain
import torch
from torch import nn
from torch import fx
import torch.nn.functional as F
from torch.fx.graph_module import _copy_attr
from .features import _get_feature_info
from .fx_helpers import fx_float_to_int
# Layers we want to treat as leaf modules for FeatureGraphNet
from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame
from .layers import GatherExcite, DropPath
from .layers.non_local_attn import BilinearAttnTransform
from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
# Module classes that are never traced through: when the tracer meets an instance of one of these it
# records a single `call_module` node instead of descending into its forward. Extend at runtime with
# `register_leaf_module`.
_leaf_modules = {
    AvgPool2dSame,
    BatchNormAct2d,
    BilinearAttnTransform,
    BlurPool2d,
    CondConv2d,
    Conv2dSame,
    DropPath,
    GatherExcite,
    MaxPool2dSame,
    ScaledStdConv2dSame,
    StdConv2dSame,
}

# InplaceAbn may not be importable in every install; register it only when the import succeeds.
try:
    from .layers import InplaceAbn
except ImportError:
    pass
else:
    _leaf_modules.add(InplaceAbn)
def register_leaf_module(module: nn.Module):
    """
    Decorator that marks a module class as an FX leaf for feature extraction tracing.

    The class is added to `_leaf_modules`, so tracing records calls to its instances as a single
    `call_module` node instead of tracing through them. Any module not under timm.models.layers
    should get this decorator if we don't want to trace through it.

    Returns:
        `module` itself, unchanged, so this can be stacked as a class decorator.
    """
    _leaf_modules.update((module,))
    return module
# Free functions that the tracer wraps as opaque call nodes instead of tracing into their bodies
# (handed to TimmTracer by LeafNodeTracer).
_autowrap_functions = (
    fx_float_to_int,
)
class TimmTracer(fx.Tracer):
    """
    Temporary bridge from torch.fx.Tracer to include any general workarounds required to make FX work for us.

    Args:
        autowrap_modules: forwarded to `fx.Tracer`; defaults to `(math,)`.
        autowrap_functions: extra free functions whose ids are added to the tracer's autowrap set.
        enable_cpatching: forwarded to `fx.Tracer`.
    """

    def __init__(self, autowrap_modules=(math, ), autowrap_functions=(), enable_cpatching=False):
        super().__init__(autowrap_modules=autowrap_modules, enable_cpatching=enable_cpatching)
        # FIXME: This is a workaround pending on a PyTorch PR https://github.com/pytorch/pytorch/pull/62106
        # NOTE(review): `_autowrap_function_ids` is a private fx.Tracer attribute; this relies on fx internals.
        self._autowrap_function_ids.update(id(f) for f in autowrap_functions)

    @staticmethod
    def _fix_pad_kwargs(target, kwargs):
        """
        Return `kwargs` with a numeric `value` keyword for `F.pad` cast to `float`.

        Only calls whose target is `F.pad` are touched, and only when `value` was passed as a keyword
        holding an `int`/`float`. A missing `value` (caller relied on the default), `None`, or a
        non-numeric value such as a traced node is left as-is instead of raising. The input dict is
        never modified; a copy is returned when a cast is applied.
        """
        if target is not F.pad:
            return kwargs
        value = kwargs.get('value')
        if isinstance(value, (int, float)):
            kwargs = dict(kwargs)
            kwargs['value'] = float(value)
        return kwargs

    def create_node(self, kind, target, args, kwargs, name=None, type_expr=None):
        # FIXME: This is a workaround pending on a PyTorch PR https://github.com/pytorch/pytorch/pull/62095
        kwargs = self._fix_pad_kwargs(target, kwargs)
        return super().create_node(kind, target, args, kwargs, name=name, type_expr=type_expr)
class LeafNodeTracer(TimmTracer):
    """
    Account for desired leaf nodes according to `_leaf_modules` and `_autowrap_functions`.

    Instances of classes in `_leaf_modules` are treated as leaf modules (not traced through), and the
    functions in `_autowrap_functions` are always autowrapped.

    Keyword arguments are forwarded to `TimmTracer` (e.g. `autowrap_modules`, `enable_cpatching`), so
    subclasses such as `NodePathTracer(**tracer_kwargs)` can pass tracer options through. Any
    `autowrap_functions` given here are wrapped in addition to `_autowrap_functions`.
    """

    def __init__(self, *, autowrap_functions=(), **kwargs):
        super().__init__(
            autowrap_functions=tuple(_autowrap_functions) + tuple(autowrap_functions), **kwargs)

    def is_leaf_module(self, m: nn.Module, module_qualname: str) -> bool:
        # Registered timm layers are always leaves; otherwise defer to the default fx.Tracer rule.
        if isinstance(m, tuple(_leaf_modules)):
            return True
        return super().is_leaf_module(m, module_qualname)
def _is_subseq(x, y):
    """Return True if `y` is a subsequence of `x`.

    Every element of `y` must occur in `x` in the same relative order (not necessarily
    contiguously). Adapted from https://stackoverflow.com/a/24017747/4391249
    """
    remaining = iter(x)
    for wanted in y:
        # Advance the shared iterator until `wanted` shows up; consumed items are never revisited.
        for candidate in remaining:
            if candidate == wanted:
                break
        else:
            return False
    return True
# Taken from https://github.com/pytorch/examples/blob/master/fx/module_tracer.py with modifications for storing
# qualified names for all Nodes, not just top-level Modules
class NodePathTracer(LeafNodeTracer):
    """
    NodePathTracer is an FX tracer that, for each operation, also records the
    qualified name of the Node from which the operation originated. A
    qualified name here is a `.` separated path walking the hierarchy from top
    level module down to leaf operation or leaf module. The name of the top
    level module is not included as part of the qualified name. For example,
    if we trace a module whose forward method applies a ReLU module, the
    qualified name for that node will simply be 'relu'.

    Some notes on the specifics:
        - Nodes are recorded to `self.node_to_qualname` which is a dictionary
          mapping a given Node object to its qualified name.
        - Nodes are recorded in the order in which they are created during
          tracing.
        - When a leaf module is called more than once, later calls get a
          suffix of the form _{int}. The counter starts from 1.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Qualified name of the module whose forward is currently being traced ('' = top level)
        self.current_module_qualname = ''
        # A map from FX Node to its qualified name, in creation order
        self.node_to_qualname = OrderedDict()

    def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs):
        """
        Override of `fx.Tracer.call_module`.

        This override:
        1) Stores away the qualified name of the caller for restoration later
        2) Sets `current_module_qualname` to this module's path for retrieval
           by `create_proxy`
        3) Once a leaf module is reached, calls `create_proxy`
        4) Restores the caller's qualified name into `current_module_qualname`
        """
        old_qualname = self.current_module_qualname
        try:
            module_qualname = self.path_of_module(m)
            self.current_module_qualname = module_qualname
            if not self.is_leaf_module(m, module_qualname):
                # Trace through, so inner ops are recorded under this module's qualified name
                return forward(*args, **kwargs)
            return self.create_proxy('call_module', module_qualname, args, kwargs)
        finally:
            # Restore even if tracing raises, so later nodes are not attributed to this module
            self.current_module_qualname = old_qualname

    def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs,
                     name=None, type_expr=None) -> fx.proxy.Proxy:
        """
        Override of `Tracer.create_proxy`. This override intercepts the recording
        of every operation and stores away the current traced module's qualified
        name in `node_to_qualname`.
        """
        proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr)
        self.node_to_qualname[proxy.node] = self._get_node_qualname(
            self.current_module_qualname, proxy.node)
        return proxy

    def _get_node_qualname(self, module_qualname: str, node: fx.node.Node) -> str:
        """
        Build the qualified name for `node` created while tracing `module_qualname`.

        For `call_module` nodes the module path is already a full description; if a name of the form
        `{module_qualname}` or `{module_qualname}_{int}` was recorded before, the most recent one
        determines the next `_{int}` suffix. Other nodes get their FX node name appended to the
        module path (with a `.` separator unless at the top level).
        """
        node_qualname = module_qualname
        if node.op == 'call_module':
            # Escape the qualname: it contains '.', which would otherwise match any character.
            pattern = re.compile(rf'{re.escape(node_qualname)}(?:_([0-9]+))?')
            for existing_qualname in reversed(self.node_to_qualname.values()):
                match = pattern.fullmatch(existing_qualname)
                if match is None:
                    continue
                # Read the counter from the captured group rather than by string replacement, which
                # breaks when the qualname itself occurs inside the suffix (e.g. '1' vs '1_1').
                suffix = match.group(1)
                next_index = int(suffix) + 1 if suffix is not None else 1
                node_qualname += f'_{next_index}'
                break
        else:
            # Node terminates in a non-leaf module so the node name needs to be appended
            if node_qualname:
                # Only append '.' if we are deeper than the top level module
                node_qualname += '.'
            node_qualname += str(node)
        return node_qualname
def _warn_graph_differences(
        train_tracer: NodePathTracer, eval_tracer: NodePathTracer):
    """
    Warn the user if the train-mode and eval-mode traces recorded different nodes.

    The ordered lists of qualified node names recorded by each tracer are compared. Nothing is
    emitted when they are identical; otherwise a single warning states whether one sequence is a
    subsequence of the other, and suggests specifying output nodes per mode.

    Args:
        train_tracer (NodePathTracer): tracer that traced the model in train mode
        eval_tracer (NodePathTracer): tracer that traced the model in eval mode
    """
    train_nodes = list(train_tracer.node_to_qualname.values())
    eval_nodes = list(eval_tracer.node_to_qualname.values())
    # List equality requires equal length AND every element pairwise equal.
    if train_nodes == eval_nodes:
        return

    suggestion_msg = (
        "When choosing nodes for feature extraction, you may need to specify "
        "output nodes for train and eval mode separately")
    if _is_subseq(train_nodes, eval_nodes):
        msg = ("NOTE: The nodes obtained by tracing the model in eval mode "
               "are a subsequence of those obtained in train mode. ")
    elif _is_subseq(eval_nodes, train_nodes):
        msg = ("NOTE: The nodes obtained by tracing the model in train mode "
               "are a subsequence of those obtained in eval mode. ")
    else:
        msg = ("The nodes obtained by tracing the model in train mode "
               "are different to those obtained in eval mode. ")
    warnings.warn(msg + suggestion_msg)
def print_graph_node_qualified_names(
        model: nn.Module, tracer_kwargs: Optional[Dict] = None):
    """
    Dev utility to print nodes in order of execution. Useful for choosing
    nodes for a FeatureGraphNet design. There are two reasons that qualified
    node names can't easily be read directly from the code for a model:
        1. Not all submodules are traced through. Modules from `torch.nn` all
           fall within this category.
        2. Node qualified names that occur more than once in the graph get a
           `_{counter}` postfix.

    The model will be traced twice: once in train mode, and once in eval mode.
    If there are discrepancies between the graphs produced, both sets will
    be printed and the user will be warned. The model's original training mode
    is restored afterwards (also if tracing raises).

    Args:
        model (nn.Module): model on which we will extract the features
        tracer_kwargs (Dict, optional): keyword arguments used to construct
            each `NodePathTracer`. Defaults to no arguments.
    """
    if tracer_kwargs is None:
        tracer_kwargs = {}
    was_training = model.training
    try:
        train_tracer = NodePathTracer(**tracer_kwargs)
        train_tracer.trace(model.train())
        eval_tracer = NodePathTracer(**tracer_kwargs)
        eval_tracer.trace(model.eval())
    finally:
        # Tracing toggles the mode; hand the model back in the state the caller had it.
        model.train(was_training)

    train_nodes = list(train_tracer.node_to_qualname.values())
    eval_nodes = list(eval_tracer.node_to_qualname.values())
    # List equality requires equal length AND every element pairwise equal.
    if train_nodes == eval_nodes:
        # Nodes are aligned in train vs eval mode
        pprint(train_nodes)
        return

    print("Nodes from train mode:")
    pprint(train_nodes)
    print()
    print("Nodes from eval mode:")
    pprint(eval_nodes)
    print()
    _warn_graph_differences(train_tracer, eval_tracer)
class DualGraphModule(fx.GraphModule):
    """
    A derivative of `fx.GraphModule`. Differs in the following ways:
    - Requires a train and eval version of the underlying graph
    - Copies submodules according to the nodes of both train and eval graphs.
    - Calling train(mode) switches between train graph and eval graph.

    A freshly constructed instance is in eval mode, with `self.graph` set to the eval graph.
    """
    def __init__(self,
                 root: torch.nn.Module,
                 train_graph: fx.Graph,
                 eval_graph: fx.Graph,
                 class_name: str = 'GraphModule'):
        """
        Args:
            root (torch.nn.Module): module from which the copied module
                hierarchy is built
            train_graph (Graph): the graph that should be used in train mode
            eval_graph (Graph): the graph that should be used in eval mode
            class_name (str): value assigned to `self.__class__.__name__`
        """
        # Start the MRO lookup *after* fx.GraphModule, i.e. skip fx.GraphModule.__init__ and run the next
        # base initializer (NOTE(review): presumably nn.Module's). GraphModule is constructed from a single
        # graph, so the attribute copying and graph installation for both graphs are done by hand below.
        super(fx.GraphModule, self).__init__()
        # NOTE(review): this assigns on the class object, not the instance; confirm fx.GraphModule gives
        # each instance its own class so that other instances are not renamed too.
        self.__class__.__name__ = class_name
        self.train_graph = train_graph
        self.eval_graph = eval_graph
        # Copy all get_attr and call_module targets referenced by EITHER the train or the eval graph
        # (the union of both); a target present in both graphs is simply copied twice.
        for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)):
            if node.op in ['get_attr', 'call_module']:
                assert isinstance(node.target, str)
                _copy_attr(root, self, node.target)
        # eval mode by default. self.eval() goes through the `train` override below (mode == False),
        # which already installs eval_graph; the explicit assignment that follows is redundant with it.
        self.eval()
        self.graph = eval_graph
        # (borrowed from fx.GraphModule):
        # Store the Tracer class responsible for creating a Graph separately as part of the
        # GraphModule state, except when the Tracer is defined in a local namespace.
        # Locally defined Tracers are not pickleable. This is needed because torch.package will
        # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
        # to re-create the Graph during deserialization.
        # TODO uncomment this when https://github.com/pytorch/pytorch/pull/63121 is available
        # assert self.eval_graph._tracer_cls == self.train_graph._tracer_cls, \
        #     "Train mode and eval mode should use the same tracer class"
        # self._tracer_cls = None
        # if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__:
        #     self._tracer_cls = self.graph._tracer_cls

    def train(self, mode=True):
        """
        Swap out the graph depending on the training mode, then apply the usual
        module train/eval switch.

        NOTE this should be safe when calling model.eval() because that just
        calls this with mode == False.

        Args:
            mode (bool): True selects `train_graph`, False selects `eval_graph`.

        Returns:
            Whatever the parent `train` returns for `mode`.
        """
        # Assigning self.graph installs the selected graph as the one this module runs.
        if mode:
            self.graph = self.train_graph
        else:
            self.graph = self.eval_graph
        return super().train(mode=mode)
def _normalize_return_nodes(return_nodes: Union[List[str], Dict[str, str]]) -> Dict[str, str]:
    """Return a fresh `{node_query: output_name}` dict of strings from a list/tuple or dict spec."""
    if isinstance(return_nodes, (list, tuple)):
        return_nodes = {n: n for n in return_nodes}
    # Always build a new dict so later processing never mutates the caller's object.
    return {str(k): str(v) for k, v in return_nodes.items()}


def build_feature_graph_net(
        model: nn.Module,
        return_nodes: Union[List[str], Dict[str, str]],
        train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
        eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
        tracer_kwargs: Optional[Dict] = None) -> fx.GraphModule:
    """
    Creates a new graph module that returns intermediate nodes from a given
    model as dictionary with user specified keys as strings, and the requested
    outputs as values. This is achieved by re-writing the computation graph of
    the model via FX to return the desired nodes as outputs. All unused nodes
    are removed, together with their corresponding parameters.

    A note on node specification: A node qualified name is specified as a `.`
    separated path walking the hierarchy from top level module down to leaf
    operation or leaf module. For instance `blocks.5.3.bn1`. The keys of the
    `return_nodes` argument should point to either a node's qualified name,
    or some truncated version of it. For example, one could provide `blocks.5`
    as a key, and the last node with that prefix will be selected.
    `print_graph_node_qualified_names` is a useful helper function for getting
    a list of qualified names of a model.

    An attempt is made to keep all non-parametric properties of the original
    model, but existing properties of the constructed `GraphModule` are not
    overwritten.

    The model is traced once in train mode and once in eval mode. Its original
    training mode is restored afterwards (also if an error is raised), and the
    returned module starts in that same mode.

    Args:
        model (nn.Module): model on which we will extract the features
        return_nodes (Union[List[name], Dict[name, new_name]]): either a list
            or a dict containing the names (or partial names - see note above)
            of the nodes for which the activations will be returned. If it is
            a `Dict`, the keys are the qualified node names, and the values
            are the user-specified keys for the graph module's returned
            dictionary. If it is a `List`, it is treated as a `Dict` mapping
            node specification strings directly to output names.
        train_return_nodes (Union[List[name], Dict[name, new_name]], optional):
            same format as `return_nodes`, used for the train-mode graph only.
            Must be given together with `eval_return_nodes`; when both are
            given, `return_nodes` is not used for node selection.
        eval_return_nodes (Union[List[name], Dict[name, new_name]], optional):
            same as `train_return_nodes`, for the eval-mode graph.
        tracer_kwargs (Dict, optional): keyword arguments used to construct
            each `NodePathTracer`. Defaults to no arguments.

    Returns:
        A `DualGraphModule` whose output is a dict of the requested nodes.

    Raises:
        AssertionError: if only one of `train_return_nodes` / `eval_return_nodes`
            is given.
        ValueError: if a requested node is not present in the traced model, or
            cannot be matched to a node of the train or eval graph.

    Examples::

        >>> from timm.models.fx_features import build_feature_graph_net
        >>> model = torchvision.models.resnet18()
        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
        >>> graph_module = build_feature_graph_net(
        >>>     model, {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = graph_module(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>> [('feat1', torch.Size([1, 64, 56, 56])),
        >>>  ('feat2', torch.Size([1, 256, 14, 14]))]
    """
    if tracer_kwargs is None:
        tracer_kwargs = {}
    is_training = model.training
    return_nodes = _normalize_return_nodes(return_nodes)

    assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), \
        ("If any of `train_return_nodes` and `eval_return_nodes` are "
         "specified, then both should be specified")

    if train_return_nodes is None:
        train_return_nodes = _normalize_return_nodes(return_nodes)
        eval_return_nodes = _normalize_return_nodes(return_nodes)
    else:
        train_return_nodes = _normalize_return_nodes(train_return_nodes)
        eval_return_nodes = _normalize_return_nodes(eval_return_nodes)

    # Repeat the tracing and graph rewriting for train and eval mode
    tracers = {}
    graphs = {}
    mode_return_nodes = {
        'train': train_return_nodes,
        'eval': eval_return_nodes,
    }
    try:
        for mode in ['train', 'eval']:
            if mode == 'train':
                model.train()
            elif mode == 'eval':
                model.eval()

            # Instantiate our NodePathTracer and use that to trace the model
            tracer = NodePathTracer(**tracer_kwargs)
            graph = tracer.trace(model)

            name = model.__class__.__name__ if isinstance(
                model, nn.Module) else model.__name__
            graph_module = fx.GraphModule(tracer.root, graph, name)

            available_nodes = [f'{v}.{k}' for k, v in tracer.node_to_qualname.items()]
            # FIXME We don't know if we should expect this to happen
            assert len(set(available_nodes)) == len(available_nodes), \
                "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues"
            # Check that all outputs in return_nodes are present in the model
            for query in mode_return_nodes[mode]:
                if not any(m.startswith(query) for m in available_nodes):
                    raise ValueError(f"return_node: {query} is not present in model")

            # Remove existing output nodes (collected first so the graph is not mutated mid-iteration)
            orig_output_nodes = [n for n in reversed(graph_module.graph.nodes) if n.op == "output"]
            assert len(orig_output_nodes)
            for n in orig_output_nodes:
                graph_module.graph.erase_node(n)

            # Find nodes corresponding to return_nodes and make them into output_nodes.
            # Walking backwards means each query binds to the LAST node carrying its prefix.
            nodes = list(graph_module.graph.nodes)
            pending = dict(mode_return_nodes[mode])  # queries not yet matched to a node
            output_nodes = OrderedDict()
            for n in reversed(nodes):
                if 'tensor_constant' in str(n):
                    # We can safely assume that we'll never need to get this as an
                    # interesting intermediate node.
                    continue
                module_qualname = tracer.node_to_qualname.get(n)
                if module_qualname is None:
                    # Only nodes created through `create_proxy` are recorded in node_to_qualname; a node
                    # created any other way has no qualified name and cannot be addressed by a query.
                    continue
                for query in pending:
                    depth = query.count('.')
                    if '.'.join(module_qualname.split('.')[:depth + 1]) == query:
                        # Safe to pop: we break out of the iteration over `pending` immediately.
                        output_nodes[pending.pop(query)] = n
                        break
            if pending:
                # A query can pass the prefix check above (e.g. 'stem.conv' vs 'stem.conv1') yet match no
                # node hierarchically; fail loudly instead of silently dropping that output.
                raise ValueError(
                    f"return_node(s) {list(pending)} could not be matched to any node in {mode} mode")
            # Restore graph order for the outputs
            output_nodes = OrderedDict(reversed(list(output_nodes.items())))

            # And add them in the end of the graph
            with graph_module.graph.inserting_after(nodes[-1]):
                graph_module.graph.output(output_nodes)

            # Remove unused modules / parameters
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()

            # Keep track of the tracer and graph so we can choose the main one
            tracers[mode] = tracer
            graphs[mode] = graph
    finally:
        # Tracing switched the model between train and eval; hand it back in the caller's mode.
        model.train(is_training)

    # Warn user if there are any discrepancies between the graphs of the
    # train and eval modes
    _warn_graph_differences(tracers['train'], tracers['eval'])

    # Build the final graph module
    graph_module = DualGraphModule(
        model, graphs['train'], graphs['eval'], class_name=name)

    # Keep non-parameter model properties for reference. Cheap name checks run before getattr so
    # private/shadowed attributes are never evaluated.
    existing_attrs = set(graph_module.__dir__())
    for attr_str in model.__dir__():
        if attr_str.startswith('_') or attr_str in existing_attrs:
            continue
        attr = getattr(model, attr_str)
        if not ismethod(attr) and not isinstance(attr, (nn.Module, nn.Parameter)):
            setattr(graph_module, attr_str, attr)
            existing_attrs.add(attr_str)

    # Restore original training mode
    graph_module.train(is_training)
    return graph_module
class FeatureGraphNet(nn.Module):
    """
    Feature extraction wrapper around `build_feature_graph_net`.

    The modules to tap are taken from the feature info returned by
    `_get_feature_info(model, out_indices)`; `forward` returns their outputs as a list.

    Args:
        model (nn.Module): model to extract features from
        out_indices: indices into the model's feature info of the features to return
        out_map (optional): output names, aligned position-by-position with `out_indices`;
            when omitted, the module names are used as output names
    """

    def __init__(self, model, out_indices, out_map=None):
        super().__init__()
        self.feature_info = _get_feature_info(model, out_indices)
        if out_map is not None:
            assert len(out_map) == len(out_indices)
        # out_map is aligned with out_indices by *position*, so map each feature index to its position
        # in out_indices instead of indexing out_map with the feature index itself (which breaks for
        # non-contiguous indices such as (2, 3, 4)).
        out_positions = {idx: pos for pos, idx in enumerate(out_indices)}
        return_nodes = {}
        for i, info in enumerate(self.feature_info):
            if i not in out_positions:
                continue
            module_name = info['module']
            return_nodes[module_name] = out_map[out_positions[i]] if out_map is not None else module_name
        self.graph_module = build_feature_graph_net(model, return_nodes)

    def forward(self, x):
        """Run the feature graph on `x` and return the values of its output dict as a list."""
        return list(self.graph_module(x).values())