From 02c3a75a45deea8b8728da14e0d0b5106e06d98b Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Sat, 28 Aug 2021 17:54:22 +0100
Subject: [PATCH] wip - make it possible to use fx graph in train and eval mode

---
 timm/models/fx_features.py | 460 ++++++++++++++++++++++++++-----------
 1 file changed, 329 insertions(+), 131 deletions(-)

diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py
index e00b080f..9a76e041 100644
--- a/timm/models/fx_features.py
+++ b/timm/models/fx_features.py
@@ -7,18 +7,21 @@ An extension/alternative to timm.models.features making use of PyTorch FX. Here,
 3. Write the resulting graph into a GraphModule (PyTorch FX functionality)
 Copyright 2021 Alexander Soare
 """
-from typing import Callable, Dict
+from typing import Callable, Dict, Union, List, Optional
 import math
 from collections import OrderedDict
 from pprint import pprint
 from inspect import ismethod
 import re
 import warnings
+from copy import deepcopy
+from itertools import chain
 
 import torch
 from torch import nn
 from torch import fx
 import torch.nn.functional as F
+from torch.fx.graph_module import _copy_attr
 
 from .features import _get_feature_info
 from .fx_helpers import fx_and, fx_float_to_int
@@ -84,29 +87,48 @@ class LeafNodeTracer(TimmTracer):
         return super().is_leaf_module(m, module_qualname)
 
 
+def _is_subseq(x, y):
+    """Check if y is a subsequence of x
+    https://stackoverflow.com/a/24017747/4391249
+    """
+    iter_x = iter(x)
+    return all(any(x_item == y_item for x_item in iter_x) for y_item in y)
+
+
 # Taken from https://github.com/pytorch/examples/blob/master/fx/module_tracer.py with modifications for storing
 # qualified names for all Nodes, not just top-level Modules
 class NodePathTracer(LeafNodeTracer):
     """
-    NodePathTracer is an FX tracer that, for each operation, also records the qualified name of the Node from which the
-    operation originated. A qualified name here is a `.` seperated path walking the hierarchy from top level module
-    down to leaf operation or leaf module. The name of the top level module is not included as part of the qualified
-    name. For example, if we trace a module who's forward method applies a ReLU module, the qualified name for that
-    node will simply be 'relu'.
+    NodePathTracer is an FX tracer that, for each operation, also records the
+    qualified name of the Node from which the operation originated. A
+    qualified name here is a `.` separated path walking the hierarchy from
+    the top level module down to the leaf operation or leaf module. The name
+    of the top level module is not included as part of the qualified name.
+    For example, if we trace a module whose forward method applies a ReLU
+    module, the qualified name for that node will simply be 'relu'.
+
+    Some notes on the specifics:
+        - Nodes are recorded to `self.node_to_qualname`, which is a
+          dictionary mapping a given Node object to its qualified name.
+        - Nodes are recorded in the order in which they are executed during
+          tracing.
+        - When a duplicate qualified name is encountered, a suffix of the
+          form `_{int}` is added. The counter starts from 1.
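+
+    An illustrative sketch (the toy module below is hypothetical) of the
+    duplicate-name suffixing when one leaf module is applied twice:
+
+        >>> class Reuse(nn.Module):
+        ...     def __init__(self):
+        ...         super().__init__()
+        ...         self.relu = nn.ReLU()
+        ...     def forward(self, x):
+        ...         return self.relu(self.relu(x))
+        >>> tracer = NodePathTracer()
+        >>> graph = tracer.trace(Reuse())
+        >>> list(tracer.node_to_qualname.values())  # roughly
+        ['x', 'relu', 'relu_1']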
""" - def __init__(self): - super().__init__() + def __init__(self, *args, **kwargs): + super(NodePathTracer, self).__init__(*args, **kwargs) # Track the qualified name of the Node being traced - self.current_module_qualname : str = '' + self.current_module_qualname = '' # A map from FX Node to the qualified name self.node_to_qualname = OrderedDict() def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs): """ - Override of Tracer.call_module (see https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer.call_module). + Override of `fx.Tracer.call_module` This override: 1) Stores away the qualified name of the caller for restoration later - 2) Installs the qualified name of the caller in `current_module_qualname` for retrieval by `create_proxy` + 2) Adds the qualified name of the caller to + `current_module_qualname` for retrieval by `create_proxy` 3) Once a leaf module is reached, calls `create_proxy` 4) Restores the caller's qualified name into current_module_qualname """ @@ -121,7 +143,8 @@ class NodePathTracer(LeafNodeTracer): finally: self.current_module_qualname = old_qualname - def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs, name=None, type_expr=None): + def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs, + name=None, type_expr=None) -> fx.proxy.Proxy: """ Override of `Tracer.create_proxy`. This override intercepts the recording of every operation and stores away the current traced module's qualified @@ -132,160 +155,335 @@ class NodePathTracer(LeafNodeTracer): self.current_module_qualname, proxy.node) return proxy - def _get_node_qualname(self, module_qualname: str, node: fx.node.Node): + def _get_node_qualname( + self, module_qualname: str, node: fx.node.Node) -> str: node_qualname = module_qualname if node.op == 'call_module': - # Node terminates in a leaf module so the module_qualname is a complete description of the node - # Just need to check if this module has appeared before. If so add postfix counter starting from _1 for the - # first reappearance (this follows the way that repeated leaf ops are enumerated by PyTorch FX) + # Node terminates in a leaf module so the module_qualname is a + # complete description of the node for existing_qualname in reversed(self.node_to_qualname.values()): - # Check to see if existing_qualname is of the form {node_qualname} or {node_qualname}_{int} - if re.match(rf'{node_qualname}(_[0-9]+)?$', existing_qualname) is not None: + # Check to see if existing_qualname is of the form + # {node_qualname} or {node_qualname}_{int} + if re.match(rf'{node_qualname}(_[0-9]+)?$', + existing_qualname) is not None: postfix = existing_qualname.replace(node_qualname, '') if len(postfix): - # existing_qualname is of the form {node_qualname}_{int} + # Existing_qualname is of the form {node_qualname}_{int} next_index = int(postfix[1:]) + 1 else: # existing_qualname is of the form {node_qualname} next_index = 1 node_qualname += f'_{next_index}' - break + break else: - # Node terminates in non- leaf module so the node name needs to be appended - if len(node_qualname) > 0: # only append '.' if we are deeper than the top level module + # Node terminates in non- leaf module so the node name needs to be + # appended + if len(node_qualname) > 0: + # Only append '.' if we are deeper than the top level module node_qualname += '.' 
             node_qualname += str(node)
         return node_qualname
 
 
-def print_graph_node_qualified_names(model: nn.Module):
+def _warn_graph_differences(
+        train_tracer: NodePathTracer, eval_tracer: NodePathTracer):
     """
-    Dev utility to prints nodes in order of execution. Useful for choosing `nodes` for a FeatureGraphNet design.
-    This is useful for two reasons:
-    1. Not all submodules are traced through. Some are treated as leaf modules. See `LeafNodeTracer`
-    2. Leaf ops that occur more than once in the graph get a `_{counter}` postfix.
-
-    WARNING: Changes to the operations in the original module might not change the module's overall behaviour, but they
-    may result in changes to the postfixes for the names of repeated ops, thereby breaking feature extraction.
+    Utility function for warning the user if there are differences between
+    the train graph and the eval graph.
+    """
+    train_nodes = list(train_tracer.node_to_qualname.values())
+    eval_nodes = list(eval_tracer.node_to_qualname.values())
+
+    if len(train_nodes) == len(eval_nodes) and all(
+            t == e for t, e in zip(train_nodes, eval_nodes)):
+        return
+
+    suggestion_msg = (
+        "When choosing nodes for feature extraction, you may need to specify "
+        "output nodes for train and eval mode separately")
+
+    if _is_subseq(train_nodes, eval_nodes):
+        msg = ("NOTE: The nodes obtained by tracing the model in eval mode "
+               "are a subsequence of those obtained in train mode. ")
+    elif _is_subseq(eval_nodes, train_nodes):
+        msg = ("NOTE: The nodes obtained by tracing the model in train mode "
+               "are a subsequence of those obtained in eval mode. ")
+    else:
+        msg = ("The nodes obtained by tracing the model in train mode "
+               "are different from those obtained in eval mode. ")
+    warnings.warn(msg + suggestion_msg)
+
+
+def print_graph_node_qualified_names(
+        model: nn.Module, tracer_kwargs: Dict = {}):
     """
-    tracer = NodePathTracer()
-    tracer.trace(model)
-    pprint(list(tracer.node_to_qualname.values()))
+    Dev utility to print nodes in order of execution. Useful for choosing
+    nodes for a FeatureGraphNet design. There are two reasons that qualified
+    node names can't easily be read directly from the code for a model:
+        1. Not all submodules are traced through. Modules from `torch.nn` all
+           fall within this category.
+        2. Node qualified names that occur more than once in the graph get a
+           `_{counter}` postfix.
+    The model will be traced twice: once in train mode, and once in eval
+    mode. If there are discrepancies between the graphs produced, both sets
+    will be printed and the user will be warned.
+
+    Args:
+        model (nn.Module): model on which we will extract the features
+        tracer_kwargs (Dict): a dictionary of keyword arguments for
+            `NodePathTracer` (which passes them onto its parent class
+            `torch.fx.Tracer`).
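+
+    Example (a sketch; the model name and the exact names printed are
+    illustrative):
+
+        >>> import timm
+        >>> model = timm.create_model('resnet18')
+        >>> print_graph_node_qualified_names(model)
+        ['conv1', 'bn1', 'act1', 'maxpool', 'layer1.0.conv1', ...]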
+    """
+    train_tracer = NodePathTracer(**tracer_kwargs)
+    train_tracer.trace(model.train())
+    eval_tracer = NodePathTracer(**tracer_kwargs)
+    eval_tracer.trace(model.eval())
+    train_nodes = list(train_tracer.node_to_qualname.values())
+    eval_nodes = list(eval_tracer.node_to_qualname.values())
+    if len(train_nodes) == len(eval_nodes) and all(
+            t == e for t, e in zip(train_nodes, eval_nodes)):
+        # Nodes are aligned in train vs eval mode
+        pprint(train_nodes)
+        return
+    print("Nodes from train mode:")
+    pprint(train_nodes)
+    print()
+    print("Nodes from eval mode:")
+    pprint(eval_nodes)
+    print()
+    _warn_graph_differences(train_tracer, eval_tracer)
+
+
+class DualGraphModule(fx.GraphModule):
+    """
+    A derivative of `fx.GraphModule`. Differs in the following ways:
+    - Requires a train and eval version of the underlying graph.
+    - Copies submodules according to the nodes of both train and eval graphs.
+    - Calling train(mode) switches between the train graph and eval graph.
+    """
+    def __init__(self,
+                 root: torch.nn.Module,
+                 train_graph: fx.Graph,
+                 eval_graph: fx.Graph,
+                 class_name: str = 'GraphModule'):
+        """
+        Args:
+            root (torch.nn.Module): module from which the copied module
+                hierarchy is built
+            train_graph (fx.Graph): the graph that should be used in train
+                mode
+            eval_graph (fx.Graph): the graph that should be used in eval mode
+        """
+        super(fx.GraphModule, self).__init__()
+
+        self.__class__.__name__ = class_name
+
+        self.train_graph = train_graph
+        self.eval_graph = eval_graph
+
+        # Copy all get_attr and call_module ops (indicated by BOTH the train
+        # and eval graphs)
+        for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)):
+            if node.op in ['get_attr', 'call_module']:
+                assert isinstance(node.target, str)
+                _copy_attr(root, self, node.target)
+
+        # eval mode by default
+        self.eval()
+        self.graph = eval_graph
+
+        # (borrowed from fx.GraphModule):
+        # Store the Tracer class responsible for creating a Graph separately
+        # as part of the GraphModule state, except when the Tracer is defined
+        # in a local namespace. Locally defined Tracers are not pickleable.
+        # This is needed because torch.package will serialize a GraphModule
+        # without retaining the Graph, and needs to use the correct Tracer to
+        # re-create the Graph during deserialization.
+        # TODO uncomment this when https://github.com/pytorch/pytorch/pull/63121 is available
+        # assert self.eval_graph._tracer_cls == self.train_graph._tracer_cls, \
+        #     "Train mode and eval mode should use the same tracer class"
+        # self._tracer_cls = None
+        # if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__:
+        #     self._tracer_cls = self.graph._tracer_cls
+
+    def train(self, mode=True):
+        """
+        Swap out the graph depending on the training mode.
+        NOTE this should be safe when calling model.eval() because that just
+        calls this with mode == False.
+        """
+        if mode:
+            self.graph = self.train_graph
+        else:
+            self.graph = self.eval_graph
+        return super().train(mode=mode)
 
 
-def get_intermediate_nodes(model: nn.Module, return_nodes: Dict[str, str]) -> nn.Module:
+def build_feature_graph_net(
+        model: nn.Module,
+        return_nodes: Union[List[str], Dict[str, str]],
+        train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
+        eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
+        tracer_kwargs: Dict = {}) -> fx.GraphModule:
     """
-    Creates a new FX-based module that returns intermediate nodes from a given model. This is achieved by re-writing
-    the computation graph of the model via FX to return the desired nodes as outputs. All unused nodes are removed,
-    together with their corresponding parameters.
+    Creates a new graph module that returns intermediate nodes from a given
+    model as a dictionary with user-specified keys as strings, and the
+    requested outputs as values. This is achieved by re-writing the
+    computation graph of the model via FX to return the desired nodes as
+    outputs. All unused nodes are removed, together with their corresponding
+    parameters.
+
+    A note on node specification: a node qualified name is specified as a `.`
+    separated path walking the hierarchy from top level module down to leaf
+    operation or leaf module. For instance `blocks.5.3.bn1`. The keys of the
+    `return_nodes` argument should point to either a node's qualified name,
+    or some truncated version of it. For example, one could provide
+    `blocks.5` as a key, and the last node with that prefix will be selected.
+    `print_graph_node_qualified_names` is a useful helper function for
+    getting a list of qualified names of a model.
+
+    An attempt is made to keep all non-parametric properties of the original
+    model, but existing properties of the constructed `GraphModule` are not
+    overwritten.
+
     Args:
         model (nn.Module): model on which we will extract the features
-        return_nodes (Dict[name, new_name]): a dict containing the names (or partial names - see note below) of the
-            nodes for which the activations will be returned as the keys. The values of the dict are the names
-            of the returned activations (which the user can specify).
-            A note on node specification: A node is specified as a `.` seperated path walking the hierarchy from top
-            level module down to leaf operation or leaf module. For instance `blocks.5.3.bn1`. Nevertheless, the keys
-            in this dict need not be fully specified. One could provide `blocks.5` as a key, and the last node with
-            that prefix will be selected.
-            While designing a feature extractor one can use the `print_graph_node_qualified_names` utility as a guide
-            to which nodes are available.
-    Acknowledgement: Starter code from https://github.com/pytorch/vision/pull/3597
+        return_nodes (Union[List[name], Dict[name, new_name]]): either a list
+            or a dict containing the names (or partial names - see note
+            above) of the nodes for which the activations will be returned.
+            If it is a `Dict`, the keys are the qualified node names, and the
+            values are the user-specified keys for the graph module's
+            returned dictionary. If it is a `List`, it is treated as a `Dict`
+            mapping node specification strings directly to output names.
+        tracer_kwargs (Dict): a dictionary of keyword arguments for
+            `NodePathTracer` (which passes them onto its parent class
+            `torch.fx.Tracer`).
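+        train_return_nodes (Optional[Union[List[str], Dict[str, str]]]):
+            similar to `return_nodes`, but used only for the graph traced in
+            train mode. If specified, `eval_return_nodes` must also be
+            specified, and `return_nodes` is ignored.
+        eval_return_nodes (Optional[Union[List[str], Dict[str, str]]]):
+            similar to `return_nodes`, but used only for the graph traced in
+            eval mode. If specified, `train_return_nodes` must also be
+            specified, and `return_nodes` is ignored.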
+
+    Examples::
+
+        >>> model = torchvision.models.resnet18()
+        >>> # Extract layer1 and layer3, giving them names `feat1` and `feat2`
+        >>> graph_module = build_feature_graph_net(
+        ...     model, {'layer1': 'feat1', 'layer3': 'feat2'})
+        >>> out = graph_module(torch.rand(1, 3, 224, 224))
+        >>> print([(k, v.shape) for k, v in out.items()])
+        [('feat1', torch.Size([1, 64, 56, 56])),
+         ('feat2', torch.Size([1, 256, 14, 14]))]
     """
+    is_training = model.training
+
+    if isinstance(return_nodes, list):
+        return_nodes = {n: n for n in return_nodes}
     return_nodes = {str(k): str(v) for k, v in return_nodes.items()}
 
-    # Instantiate our NodePathTracer and use that to trace the model
-    tracer = NodePathTracer()
-    graph = tracer.trace(model)
-
-    name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__
-    graph_module = fx.GraphModule(tracer.root, graph, name)
-
-    available_nodes = [f'{v}.{k}' for k, v in tracer.node_to_qualname.items()]
-    # FIXME We don't know if we should expect this to happen
-    assert len(set(available_nodes)) == len(available_nodes), \
-        "There are duplicate nodes! Please raise an issue https://github.com/rwightman/pytorch-image-models/issues"
-    # Check that all outputs in return_nodes are present in the model
-    for query in return_nodes.keys():
-        if not any([m.startswith(query) for m in available_nodes]):
-            raise ValueError(f"return_node: {query} is not present in model")
-
-    # Remove existing output nodes
-    orig_output_node = None
-    for n in reversed(graph_module.graph.nodes):
-        if n.op == "output":
-            orig_output_node = n
-    assert orig_output_node
-    # And remove it
-    graph_module.graph.erase_node(orig_output_node)
-    # Find nodes corresponding to return_nodes and make them into output_nodes
-    nodes = [n for n in graph_module.graph.nodes]
-    output_nodes = OrderedDict()
-    for n in reversed(nodes):
-        if 'tensor_constant' in str(n):
-            # NOTE Without this control flow we would get a None value for
-            # `module_qualname = tracer.node_to_qualname.get(n)`. On the other hand, we can safely assume that we'll
-            # never need to get this as an interesting intermediate node.
-            continue
-        module_qualname = tracer.node_to_qualname.get(n)
-        for query in return_nodes:
-            depth = query.count('.')
-            if '.'.join(module_qualname.split('.')[:depth+1]) == query:
-                output_nodes[return_nodes[query]] = n
-                return_nodes.pop(query)
-                break
-    output_nodes = OrderedDict(reversed(list(output_nodes.items())))
-
-    # And add them in the end of the graph
-    with graph_module.graph.inserting_after(nodes[-1]):
-        graph_module.graph.output(output_nodes)
-
-    # Remove unused modules / parameters
-    graph_module.graph.eliminate_dead_code()
-    graph_module.recompile()
-    graph_module = fx.GraphModule(graph_module, graph_module.graph, name)
+    assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), \
+        ("If any of `train_return_nodes` and `eval_return_nodes` are "
+         "specified, then both should be specified")
+
+    if train_return_nodes is None:
+        train_return_nodes = deepcopy(return_nodes)
+        eval_return_nodes = deepcopy(return_nodes)
+
+    # Repeat the tracing and graph rewriting for train and eval mode
+    tracers = {}
+    graphs = {}
+    return_nodes = {
+        'train': train_return_nodes,
+        'eval': eval_return_nodes
+    }
+    for mode in ['train', 'eval']:
+        if mode == 'train':
+            model.train()
+        elif mode == 'eval':
+            model.eval()
+
+        # Instantiate our NodePathTracer and use that to trace the model
+        tracer = NodePathTracer(**tracer_kwargs)
+        graph = tracer.trace(model)
+
+        name = model.__class__.__name__ if isinstance(
+            model, nn.Module) else model.__name__
+        graph_module = fx.GraphModule(tracer.root, graph, name)
+
+        available_nodes = [
+            f'{v}.{k}' for k, v in tracer.node_to_qualname.items()]
+        # FIXME We don't know if we should expect this to happen
+        assert len(set(available_nodes)) == len(available_nodes), \
+            ("There are duplicate nodes! Please raise an issue "
+             "https://github.com/rwightman/pytorch-image-models/issues")
+        # Check that all outputs in return_nodes are present in the model
+        for query in return_nodes[mode].keys():
+            if not any([m.startswith(query) for m in available_nodes]):
+                raise ValueError(
+                    f"return_node: {query} is not present in model")
+
+        # Remove any existing output nodes
+        orig_output_nodes = []
+        for n in reversed(graph_module.graph.nodes):
+            if n.op == "output":
+                orig_output_nodes.append(n)
+        assert len(orig_output_nodes)
+        for n in orig_output_nodes:
+            graph_module.graph.erase_node(n)
+
+        # Find nodes corresponding to return_nodes and make them into
+        # output_nodes
+        nodes = [n for n in graph_module.graph.nodes]
+        output_nodes = OrderedDict()
+        for n in reversed(nodes):
+            if 'tensor_constant' in str(n):
+                # NOTE Without this control flow we would get a None value for
+                # `module_qualname = tracer.node_to_qualname.get(n)`.
+                # On the other hand, we can safely assume that we'll never
+                # need to get this as an interesting intermediate node.
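+                # (`_tensor_constant*` nodes are constant tensors that FX
+                # lifts into the module during tracing; they are not
+                # meaningful intermediate activations)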
+                continue
+            module_qualname = tracer.node_to_qualname.get(n)
+            for query in return_nodes[mode]:
+                depth = query.count('.')
+                if '.'.join(module_qualname.split('.')[:depth + 1]) == query:
+                    output_nodes[return_nodes[mode][query]] = n
+                    return_nodes[mode].pop(query)
+                    break
+        output_nodes = OrderedDict(reversed(list(output_nodes.items())))
+
+        # And add them in the end of the graph
+        with graph_module.graph.inserting_after(nodes[-1]):
+            graph_module.graph.output(output_nodes)
+
+        # Remove unused modules / parameters
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+
+        # Keep track of the tracer and graph so we can choose the main one
+        tracers[mode] = tracer
+        graphs[mode] = graph
+
+    # Warn user if there are any discrepancies between the graphs of the
+    # train and eval modes
+    _warn_graph_differences(tracers['train'], tracers['eval'])
+
+    # Build the final graph module
+    graph_module = DualGraphModule(
+        model, graphs['train'], graphs['eval'], class_name=name)
+
+    # Keep non-parameter model properties for reference
+    for attr_str in model.__dir__():
+        attr = getattr(model, attr_str)
+        if (not attr_str.startswith('_')
+                and attr_str not in graph_module.__dir__()
+                and not ismethod(attr)
+                and not isinstance(attr, (nn.Module, nn.Parameter))):
+            setattr(graph_module, attr_str, attr)
+
+    # Restore original training mode
+    graph_module.train(is_training)
+
     return graph_module
 
 
 class FeatureGraphNet(nn.Module):
-    """
-    Take the provided model and transform it into a graph module. This class wraps the resulting graph module while
-    also keeping the original model's non-parameter properties for reference. The original model is discarded.
-
-    WARNING: Changes to the operations in the original module might not change the module's overall behaviour, but they
-    may result in changes to the postfixes for the names of repeated ops, thereby breaking feature extraction.
-
-    TODO: FIX THIS
-    WARNING: This puts the input model into eval mode prior to tracing. This means that any control flow dependent on
-    the model being in train mode will be lost.
-    """
     def __init__(self, model, out_indices, out_map=None):
         super().__init__()
-        model.eval()
         self.feature_info = _get_feature_info(model, out_indices)
         if out_map is not None:
             assert len(out_map) == len(out_indices)
-        # NOTE the feature_info key is innapropriately named 'module' because prior to FX only modules could be
-        # provided. Recall that here, we may also provide nodes referring to individual ops
         return_nodes = {info['module']: out_map[i] if out_map is not None else info['module']
                         for i, info in enumerate(self.feature_info) if i in out_indices}
-        self.graph_module = get_intermediate_nodes(model, return_nodes)
-        # Keep non-parameter model properties for reference
-        for attr_str in model.__dir__():
-            attr = getattr(model, attr_str)
-            if (not attr_str.startswith('_') and attr_str not in self.__dir__() and not ismethod(attr)
-                    and not isinstance(attr, (nn.Module, nn.Parameter))):
-                setattr(self, attr_str, attr)
-
+        self.graph_module = build_feature_graph_net(model, return_nodes)
+
     def forward(self, x):
-        return list(self.graph_module(x).values())
-
-    def train(self, mode=True):
-        """
-        NOTE: This also covers `self.eval()` as that just does self.train(False)
-        """
-        if mode:
-            warnings.warn(
-                "Setting a FeatureGraphNet to training mode won't necessarily have the desired effect. Control "
-                "flow depending on `self.training` will follow the `False` path. See FeatureGraphNet doc-string "
-                "for more details.")
-        super().train(mode)
\ No newline at end of file
+        return list(self.graph_module(x).values())
\ No newline at end of file
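
--
A minimal usage sketch (illustrative only; assumes a timm model whose feature
stages are named `layer1`..`layer4`, as in ResNet):

    import torch
    import timm
    from timm.models.fx_features import build_feature_graph_net

    model = timm.create_model('resnet50')
    # The last node under each prefix is selected, i.e. each stage's output
    graph_module = build_feature_graph_net(model, ['layer1', 'layer4'])
    out = graph_module(torch.rand(1, 3, 224, 224))
    print({k: v.shape for k, v in out.items()})

    # DualGraphModule swaps the underlying graph when the mode changes
    graph_module.eval()
    graph_module.train()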