wip - make it possible to use fx graph in train and eval mode

pull/800/head
Alexander Soare 3 years ago
parent a6c24b936b
commit 02c3a75a45

@ -7,18 +7,21 @@ An extension/alternative to timm.models.features making use of PyTorch FX. Here,
3. Write the resulting graph into a GraphModule (PyTorch FX functionality) 3. Write the resulting graph into a GraphModule (PyTorch FX functionality)
Copyright 2021 Alexander Soare Copyright 2021 Alexander Soare
""" """
from typing import Callable, Dict from typing import Callable, Dict, Union, List, Optional
import math import math
from collections import OrderedDict from collections import OrderedDict
from pprint import pprint from pprint import pprint
from inspect import ismethod from inspect import ismethod
import re import re
import warnings import warnings
from copy import deepcopy
from itertools import chain
import torch import torch
from torch import nn from torch import nn
from torch import fx from torch import fx
import torch.nn.functional as F import torch.nn.functional as F
from torch.fx.graph_module import _copy_attr
from .features import _get_feature_info from .features import _get_feature_info
from .fx_helpers import fx_and, fx_float_to_int from .fx_helpers import fx_and, fx_float_to_int
@ -84,29 +87,48 @@ class LeafNodeTracer(TimmTracer):
return super().is_leaf_module(m, module_qualname) return super().is_leaf_module(m, module_qualname)
def _is_subseq(x, y):
"""Check if y is a subseqence of x
https://stackoverflow.com/a/24017747/4391249
"""
iter_x = iter(x)
return all(any(x_item == y_item for x_item in iter_x) for y_item in y)
# Taken from https://github.com/pytorch/examples/blob/master/fx/module_tracer.py with modifications for storing # Taken from https://github.com/pytorch/examples/blob/master/fx/module_tracer.py with modifications for storing
# qualified names for all Nodes, not just top-level Modules # qualified names for all Nodes, not just top-level Modules
class NodePathTracer(LeafNodeTracer): class NodePathTracer(LeafNodeTracer):
""" """
NodePathTracer is an FX tracer that, for each operation, also records the qualified name of the Node from which the NodePathTracer is an FX tracer that, for each operation, also records the
operation originated. A qualified name here is a `.` seperated path walking the hierarchy from top level module qualified name of the Node from which the operation originated. A
down to leaf operation or leaf module. The name of the top level module is not included as part of the qualified qualified name here is a `.` seperated path walking the hierarchy from top
name. For example, if we trace a module who's forward method applies a ReLU module, the qualified name for that level module down to leaf operation or leaf module. The name of the top
node will simply be 'relu'. level module is not included as part of the qualified name. For example,
if we trace a module who's forward method applies a ReLU module, the
qualified name for that node will simply be 'relu'.
Some notes on the specifics:
- Nodes are recorded to `self.node_to_qualname` which is a dictionary
mapping a given Node object to its qualified name.
- Nodes are recorded in the order which they are executed during
tracing.
- When a duplicate qualified name is encountered, a suffix of the form
_{int} is added. The counter starts from 1.
""" """
def __init__(self): def __init__(self, *args, **kwargs):
super().__init__() super(NodePathTracer, self).__init__(*args, **kwargs)
# Track the qualified name of the Node being traced # Track the qualified name of the Node being traced
self.current_module_qualname : str = '' self.current_module_qualname = ''
# A map from FX Node to the qualified name # A map from FX Node to the qualified name
self.node_to_qualname = OrderedDict() self.node_to_qualname = OrderedDict()
def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs): def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs):
""" """
Override of Tracer.call_module (see https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer.call_module). Override of `fx.Tracer.call_module`
This override: This override:
1) Stores away the qualified name of the caller for restoration later 1) Stores away the qualified name of the caller for restoration later
2) Installs the qualified name of the caller in `current_module_qualname` for retrieval by `create_proxy` 2) Adds the qualified name of the caller to
`current_module_qualname` for retrieval by `create_proxy`
3) Once a leaf module is reached, calls `create_proxy` 3) Once a leaf module is reached, calls `create_proxy`
4) Restores the caller's qualified name into current_module_qualname 4) Restores the caller's qualified name into current_module_qualname
""" """
@ -121,7 +143,8 @@ class NodePathTracer(LeafNodeTracer):
finally: finally:
self.current_module_qualname = old_qualname self.current_module_qualname = old_qualname
def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs, name=None, type_expr=None): def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs,
name=None, type_expr=None) -> fx.proxy.Proxy:
""" """
Override of `Tracer.create_proxy`. This override intercepts the recording Override of `Tracer.create_proxy`. This override intercepts the recording
of every operation and stores away the current traced module's qualified of every operation and stores away the current traced module's qualified
@ -132,160 +155,335 @@ class NodePathTracer(LeafNodeTracer):
self.current_module_qualname, proxy.node) self.current_module_qualname, proxy.node)
return proxy return proxy
def _get_node_qualname(self, module_qualname: str, node: fx.node.Node): def _get_node_qualname(
self, module_qualname: str, node: fx.node.Node) -> str:
node_qualname = module_qualname node_qualname = module_qualname
if node.op == 'call_module': if node.op == 'call_module':
# Node terminates in a leaf module so the module_qualname is a complete description of the node # Node terminates in a leaf module so the module_qualname is a
# Just need to check if this module has appeared before. If so add postfix counter starting from _1 for the # complete description of the node
# first reappearance (this follows the way that repeated leaf ops are enumerated by PyTorch FX)
for existing_qualname in reversed(self.node_to_qualname.values()): for existing_qualname in reversed(self.node_to_qualname.values()):
# Check to see if existing_qualname is of the form {node_qualname} or {node_qualname}_{int} # Check to see if existing_qualname is of the form
if re.match(rf'{node_qualname}(_[0-9]+)?$', existing_qualname) is not None: # {node_qualname} or {node_qualname}_{int}
if re.match(rf'{node_qualname}(_[0-9]+)?$',
existing_qualname) is not None:
postfix = existing_qualname.replace(node_qualname, '') postfix = existing_qualname.replace(node_qualname, '')
if len(postfix): if len(postfix):
# existing_qualname is of the form {node_qualname}_{int} # Existing_qualname is of the form {node_qualname}_{int}
next_index = int(postfix[1:]) + 1 next_index = int(postfix[1:]) + 1
else: else:
# existing_qualname is of the form {node_qualname} # existing_qualname is of the form {node_qualname}
next_index = 1 next_index = 1
node_qualname += f'_{next_index}' node_qualname += f'_{next_index}'
break break
else: else:
# Node terminates in non- leaf module so the node name needs to be appended # Node terminates in non- leaf module so the node name needs to be
if len(node_qualname) > 0: # only append '.' if we are deeper than the top level module # appended
if len(node_qualname) > 0:
# Only append '.' if we are deeper than the top level module
node_qualname += '.' node_qualname += '.'
node_qualname += str(node) node_qualname += str(node)
return node_qualname return node_qualname
def print_graph_node_qualified_names(model: nn.Module): def _warn_graph_differences(
train_tracer: NodePathTracer, eval_tracer: NodePathTracer):
""" """
Dev utility to prints nodes in order of execution. Useful for choosing `nodes` for a FeatureGraphNet design. Utility function for warning the user if there are differences between
This is useful for two reasons: the train graph and the eval graph.
1. Not all submodules are traced through. Some are treated as leaf modules. See `LeafNodeTracer` """
2. Leaf ops that occur more than once in the graph get a `_{counter}` postfix. train_nodes = list(train_tracer.node_to_qualname.values())
eval_nodes = list(eval_tracer.node_to_qualname.values())
WARNING: Changes to the operations in the original module might not change the module's overall behaviour, but they
may result in changes to the postfixes for the names of repeated ops, thereby breaking feature extraction. if len(train_nodes) == len(eval_nodes) and [
t == e for t, e in zip(train_nodes, eval_nodes)]:
return
suggestion_msg = (
"When choosing nodes for feature extraction, you may need to specify "
"output nodes for train and eval mode separately")
if _is_subseq(train_nodes, eval_nodes):
msg = ("NOTE: The nodes obtained by tracing the model in eval mode "
"are a subsequence of those obtained in train mode. ")
elif _is_subseq(eval_nodes, train_nodes):
msg = ("NOTE: The nodes obtained by tracing the model in train mode "
"are a subsequence of those obtained in eval mode. ")
else:
msg = ("The nodes obtained by tracing the model in train mode "
"are different to those obtained in eval mode. ")
warnings.warn(msg + suggestion_msg)
def print_graph_node_qualified_names(
model: nn.Module, tracer_kwargs: Dict = {}):
""" """
tracer = NodePathTracer() Dev utility to prints nodes in order of execution. Useful for choosing
tracer.trace(model) nodes for a FeatureGraphNet design. There are two reasons that qualified
pprint(list(tracer.node_to_qualname.values())) node names can't easily be read directly from the code for a model:
1. Not all submodules are traced through. Modules from `torch.nn` all
fall within this category.
2. Node qualified names that occur more than once in the graph get a
`_{counter}` postfix.
The model will be traced twice: once in train mode, and once in eval mode.
If there are discrepancies between the graphs produced, both sets will
be printed and the user will be warned.
Args:
model (nn.Module): model on which we will extract the features
tracer_kwargs (Dict): a dictionary of keywork arguments for
`NodePathTracer` (which passes them onto it's parent class
`torch.fx.Tracer`).
"""
train_tracer = NodePathTracer(**tracer_kwargs)
train_tracer.trace(model.train())
eval_tracer = NodePathTracer(**tracer_kwargs)
eval_tracer.trace(model.eval())
train_nodes = list(train_tracer.node_to_qualname.values())
eval_nodes = list(eval_tracer.node_to_qualname.values())
if len(train_nodes) == len(eval_nodes) and [
t == e for t, e in zip(train_nodes, eval_nodes)]:
# Nodes are aligned in train vs eval mode
pprint(list(train_tracer.node_to_qualname.values()))
return
print("Nodes from train mode:")
pprint(list(train_tracer.node_to_qualname.values()))
print()
print("Nodes from eval mode:")
pprint(list(eval_tracer.node_to_qualname.values()))
print()
_warn_graph_differences(train_tracer, eval_tracer)
class DualGraphModule(fx.GraphModule):
"""
A derivative of `fx.GraphModule`. Differs in the following ways:
- Requires a train and eval version of the underlying graph
- Copies submodules according to the nodes of both train and eval graphs.
- Calling train(mode) switches between train graph and eval graph.
"""
def __init__(self,
root: torch.nn.Module,
train_graph: fx.Graph,
eval_graph: fx.Graph,
class_name: str = 'GraphModule'):
"""
Args:
root (torch.nn.Module): module from which the copied module
hierarchy is built
train_graph (Graph): the graph that should be used in train mode
eval_graph (Graph): the graph that should be used in eval mode
"""
super(fx.GraphModule, self).__init__()
self.__class__.__name__ = class_name
self.train_graph = train_graph
self.eval_graph = eval_graph
# Copy all get_attr and call_module ops (indicated by BOTH train and
# eval graphs)
for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)):
if node.op in ['get_attr', 'call_module']:
assert isinstance(node.target, str)
_copy_attr(root, self, node.target)
# eval mode by default
self.eval()
self.graph = eval_graph
# (borrowed from fx.GraphModule):
# Store the Tracer class responsible for creating a Graph separately as part of the
# GraphModule state, except when the Tracer is defined in a local namespace.
# Locally defined Tracers are not pickleable. This is needed because torch.package will
# serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
# to re-create the Graph during deserialization.
# TODO uncomment this when https://github.com/pytorch/pytorch/pull/63121 is available
# assert self.eval_graph._tracer_cls == self.train_graph._tracer_cls, \
# "Train mode and eval mode should use the same tracer class"
# self._tracer_cls = None
# if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__:
# self._tracer_cls = self.graph._tracer_cls
def train(self, mode=True):
"""
Swap out the graph depending on the training mode.
NOTE this should be safe when calling model.eval() because that just
calls this with mode == False.
"""
if mode:
self.graph = self.train_graph
else:
self.graph = self.eval_graph
return super().train(mode=mode)
def get_intermediate_nodes(model: nn.Module, return_nodes: Dict[str, str]) -> nn.Module: def build_feature_graph_net(
model: nn.Module,
return_nodes: Union[List[str], Dict[str, str]],
train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
tracer_kwargs: Dict = {}) -> fx.GraphModule:
""" """
Creates a new FX-based module that returns intermediate nodes from a given model. This is achieved by re-writing Creates a new graph module that returns intermediate nodes from a given
the computation graph of the model via FX to return the desired nodes as outputs. All unused nodes are removed, model as dictionary with user specified keys as strings, and the requested
together with their corresponding parameters. outputs as values. This is achieved by re-writing the computation graph of
the model via FX to return the desired nodes as outputs. All unused nodes
are removed, together with their corresponding parameters.
A note on node specification: A node qualified name is specified as a `.`
seperated path walking the hierarchy from top level module down to leaf
operation or leaf module. For instance `blocks.5.3.bn1`. The keys of the
`return_nodes` argument should point to either a node's qualified name,
or some truncated version of it. For example, one could provide `blocks.5`
as a key, and the last node with that prefix will be selected.
`print_graph_node_qualified_names` is a useful helper function for getting
a list of qualified names of a model.
An attempt is made to keep all non-parametric properties of the original
model, but existing properties of the constructed `GraphModule` are not
overwritten.
Args: Args:
model (nn.Module): model on which we will extract the features model (nn.Module): model on which we will extract the features
return_nodes (Dict[name, new_name]): a dict containing the names (or partial names - see note below) of the return_nodes (Union[List[name], Dict[name, new_name]])): either a list
nodes for which the activations will be returned as the keys. The values of the dict are the names or a dict containing the names (or partial names - see note above)
of the returned activations (which the user can specify). of the nodes for which the activations will be returned. If it is
A note on node specification: A node is specified as a `.` seperated path walking the hierarchy from top a `Dict`, the keys are the qualified node names, and the values
level module down to leaf operation or leaf module. For instance `blocks.5.3.bn1`. Nevertheless, the keys are the user-specified keys for the graph module's returned
in this dict need not be fully specified. One could provide `blocks.5` as a key, and the last node with dictionary. If it is a `List`, it is treated as a `Dict` mapping
that prefix will be selected. node specification strings directly to output names.
While designing a feature extractor one can use the `print_graph_node_qualified_names` utility as a guide tracer_kwargs (Dict): a dictionary of keywork arguments for
to which nodes are available. `NodePathTracer` (which passes them onto it's parent class
Acknowledgement: Starter code from https://github.com/pytorch/vision/pull/3597 `torch.fx.Tracer`).
Examples::
>>> model = torchvision.models.resnet18()
>>> # extract layer1 and layer3, giving as names `feat1` and feat2`
>>> graph_module = torchvision.models._utils.build_feature_graph_net(m,
>>> {'layer1': 'feat1', 'layer3': 'feat2'})
>>> out = graph_module(torch.rand(1, 3, 224, 224))
>>> print([(k, v.shape) for k, v in out.items()])
>>> [('feat1', torch.Size([1, 64, 56, 56])),
>>> ('feat2', torch.Size([1, 256, 14, 14]))]
""" """
is_training = model.training
if isinstance(return_nodes, list):
return_nodes = {n: n for n in return_nodes}
return_nodes = {str(k): str(v) for k, v in return_nodes.items()} return_nodes = {str(k): str(v) for k, v in return_nodes.items()}
# Instantiate our NodePathTracer and use that to trace the model assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), \
tracer = NodePathTracer() ("If any of `train_return_nodes` and `eval_return_nodes` are "
graph = tracer.trace(model) "specified, then both should be specified")
name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__ if train_return_nodes is None:
graph_module = fx.GraphModule(tracer.root, graph, name) train_return_nodes = deepcopy(return_nodes)
eval_return_nodes = deepcopy(return_nodes)
available_nodes = [f'{v}.{k}' for k, v in tracer.node_to_qualname.items()]
# FIXME We don't know if we should expect this to happen # Repeat the tracing and graph rewriting for train and eval mode
assert len(set(available_nodes)) == len(available_nodes), \ tracers = {}
"There are duplicate nodes! Please raise an issue https://github.com/rwightman/pytorch-image-models/issues" graphs = {}
# Check that all outputs in return_nodes are present in the model return_nodes = {
for query in return_nodes.keys(): 'train': train_return_nodes,
if not any([m.startswith(query) for m in available_nodes]): 'eval': eval_return_nodes
raise ValueError(f"return_node: {query} is not present in model") }
for mode in ['train', 'eval']:
# Remove existing output nodes if mode == 'train':
orig_output_node = None model.train()
for n in reversed(graph_module.graph.nodes): elif mode == 'eval':
if n.op == "output": model.eval()
orig_output_node = n
assert orig_output_node # Instantiate our NodePathTracer and use that to trace the model
# And remove it tracer = NodePathTracer(**tracer_kwargs)
graph_module.graph.erase_node(orig_output_node) graph = tracer.trace(model)
# Find nodes corresponding to return_nodes and make them into output_nodes
nodes = [n for n in graph_module.graph.nodes] name = model.__class__.__name__ if isinstance(
output_nodes = OrderedDict() model, nn.Module) else model.__name__
for n in reversed(nodes): graph_module = fx.GraphModule(tracer.root, graph, name)
if 'tensor_constant' in str(n):
# NOTE Without this control flow we would get a None value for available_nodes = [f'{v}.{k}' for k, v in tracer.node_to_qualname.items()]
# `module_qualname = tracer.node_to_qualname.get(n)`. On the other hand, we can safely assume that we'll # FIXME We don't know if we should expect this to happen
# never need to get this as an interesting intermediate node. assert len(set(available_nodes)) == len(available_nodes), \
continue "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues"
module_qualname = tracer.node_to_qualname.get(n) # Check that all outputs in return_nodes are present in the model
for query in return_nodes: for query in return_nodes[mode].keys():
depth = query.count('.') if not any([m.startswith(query) for m in available_nodes]):
if '.'.join(module_qualname.split('.')[:depth+1]) == query: raise ValueError(f"return_node: {query} is not present in model")
output_nodes[return_nodes[query]] = n
return_nodes.pop(query) # Remove existing output nodes (train mode)
break orig_output_nodes = []
output_nodes = OrderedDict(reversed(list(output_nodes.items()))) for n in reversed(graph_module.graph.nodes):
if n.op == "output":
# And add them in the end of the graph orig_output_nodes.append(n)
with graph_module.graph.inserting_after(nodes[-1]): assert len(orig_output_nodes)
graph_module.graph.output(output_nodes) for n in orig_output_nodes:
graph_module.graph.erase_node(n)
# Remove unused modules / parameters
graph_module.graph.eliminate_dead_code() # Find nodes corresponding to return_nodes and make them into output_nodes
graph_module.recompile() nodes = [n for n in graph_module.graph.nodes]
graph_module = fx.GraphModule(graph_module, graph_module.graph, name) output_nodes = OrderedDict()
for n in reversed(nodes):
if 'tensor_constant' in str(n):
# NOTE Without this control flow we would get a None value for
# `module_qualname = tracer.node_to_qualname.get(n)`.
# On the other hand, we can safely assume that we'll never need to
# get this as an interesting intermediate node.
continue
module_qualname = tracer.node_to_qualname.get(n)
for query in return_nodes[mode]:
depth = query.count('.')
if '.'.join(module_qualname.split('.')[:depth + 1]) == query:
output_nodes[return_nodes[mode][query]] = n
return_nodes[mode].pop(query)
break
output_nodes = OrderedDict(reversed(list(output_nodes.items())))
# And add them in the end of the graph
with graph_module.graph.inserting_after(nodes[-1]):
graph_module.graph.output(output_nodes)
# Remove unused modules / parameters
graph_module.graph.eliminate_dead_code()
graph_module.recompile()
# Keep track of the tracer and graph so we can choose the main one
tracers[mode] = tracer
graphs[mode] = graph
# Warn user if there are any discrepancies between the graphs of the
# train and eval modes
_warn_graph_differences(tracers['train'], tracers['eval'])
# Build the final graph module
graph_module = DualGraphModule(
model, graphs['train'], graphs['eval'], class_name=name)
# Keep non-parameter model properties for reference
for attr_str in model.__dir__():
attr = getattr(model, attr_str)
if (not attr_str.startswith('_')
and attr_str not in graph_module.__dir__()
and not ismethod(attr)
and not isinstance(attr, (nn.Module, nn.Parameter))):
setattr(graph_module, attr_str, attr)
# Restore original training mode
graph_module.train(is_training)
return graph_module return graph_module
class FeatureGraphNet(nn.Module): class FeatureGraphNet(nn.Module):
"""
Take the provided model and transform it into a graph module. This class wraps the resulting graph module while
also keeping the original model's non-parameter properties for reference. The original model is discarded.
WARNING: Changes to the operations in the original module might not change the module's overall behaviour, but they
may result in changes to the postfixes for the names of repeated ops, thereby breaking feature extraction.
TODO: FIX THIS
WARNING: This puts the input model into eval mode prior to tracing. This means that any control flow dependent on
the model being in train mode will be lost.
"""
def __init__(self, model, out_indices, out_map=None): def __init__(self, model, out_indices, out_map=None):
super().__init__() super().__init__()
model.eval()
self.feature_info = _get_feature_info(model, out_indices) self.feature_info = _get_feature_info(model, out_indices)
if out_map is not None: if out_map is not None:
assert len(out_map) == len(out_indices) assert len(out_map) == len(out_indices)
# NOTE the feature_info key is innapropriately named 'module' because prior to FX only modules could be
# provided. Recall that here, we may also provide nodes referring to individual ops
return_nodes = {info['module']: out_map[i] if out_map is not None else info['module'] return_nodes = {info['module']: out_map[i] if out_map is not None else info['module']
for i, info in enumerate(self.feature_info) if i in out_indices} for i, info in enumerate(self.feature_info) if i in out_indices}
self.graph_module = get_intermediate_nodes(model, return_nodes) self.graph_module = build_feature_graph_net(model, return_nodes)
# Keep non-parameter model properties for reference
for attr_str in model.__dir__():
attr = getattr(model, attr_str)
if (not attr_str.startswith('_') and attr_str not in self.__dir__() and not ismethod(attr)
and not isinstance(attr, (nn.Module, nn.Parameter))):
setattr(self, attr_str, attr)
def forward(self, x): def forward(self, x):
return list(self.graph_module(x).values()) return list(self.graph_module(x).values())
def train(self, mode=True):
"""
NOTE: This also covers `self.eval()` as that just does self.train(False)
"""
if mode:
warnings.warn(
"Setting a FeatureGraphNet to training mode won't necessarily have the desired effect. Control "
"flow depending on `self.training` will follow the `False` path. See FeatureGraphNet doc-string "
"for more details.")
super().train(mode)
Loading…
Cancel
Save