From ab3ac3f25b1df54e45cc91daecb73bf2e6c30825 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 12 Aug 2021 15:31:02 +0100 Subject: [PATCH 01/13] Add FX based FeatureGraphNet capability --- timm/models/fx_features.py | 291 +++++++++++++++++++++++++++++++++++++ timm/models/fx_helpers.py | 17 +++ timm/models/helpers.py | 3 + 3 files changed, 311 insertions(+) create mode 100644 timm/models/fx_features.py create mode 100644 timm/models/fx_helpers.py diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py new file mode 100644 index 00000000..e00b080f --- /dev/null +++ b/timm/models/fx_features.py @@ -0,0 +1,291 @@ +""" PyTorch FX Based Feature Extraction Helpers +An extension/alternative to timm.models.features making use of PyTorch FX. Here, the idea is to: + 1. Symbolically trace a model producing a graph based intermediate representation (PyTorch FX functionality with + some custom tweaks) + 2. Identify desired feature extraction nodes and reconfigure them as output nodes while deleting all unecessary + nodes. (custom - inspired by https://github.com/pytorch/vision/pull/3597) + 3. Write the resulting graph into a GraphModule (PyTorch FX functionality) +Copyright 2021 Alexander Soare +""" +from typing import Callable, Dict +import math +from collections import OrderedDict +from pprint import pprint +from inspect import ismethod +import re +import warnings + +import torch +from torch import nn +from torch import fx +import torch.nn.functional as F + +from .features import _get_feature_info +from .fx_helpers import fx_and, fx_float_to_int + +# Layers we went to treat as leaf modules for FeatureGraphNet +from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame +from .layers import GatherExcite, DropPath +from .layers.non_local_attn import BilinearAttnTransform +from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame + + +# These modules will not be traced through. +_leaf_modules = { + Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, GatherExcite, DropPath, + BilinearAttnTransform, MaxPool2dSame, AvgPool2dSame +} + +try: + from .layers import InplaceAbn + _leaf_modules.add(InplaceAbn) +except ImportError: + pass + + +def register_leaf_module(module: nn.Module): + """ + Any module not under timm.models.layers should get this decorator if we don't want to trace through it. 
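A minimal usage sketch (the decorated class below is hypothetical, not a real timm module); any `nn.Module` subclass whose forward contains untraceable control flow is a candidate:

    @register_leaf_module  # keep FX from tracing into this forward
    class MyGatedBlock(nn.Module):
        def forward(self, x):
            # shape-dependent control flow like this cannot be symbolically traced
            if x.shape[-1] % 2 == 0:
                x = x.flip(-1)
            return x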
+ """ + _leaf_modules.add(module) + return module + + +# These functions will not be traced through +_autowrap_functions=(fx_float_to_int, fx_and) + + +class TimmTracer(fx.Tracer): + """ + Temporary bridge from torch.fx.Tracer to include any general workarounds required to make FX work for us + """ + def __init__(self, autowrap_modules=(math, ), autowrap_functions=(), enable_cpatching=False): + super().__init__(autowrap_modules=autowrap_modules, enable_cpatching=enable_cpatching) + # FIXME: This is a workaround pending on a PyTorch PR https://github.com/pytorch/pytorch/pull/62106 + self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions])) + + def create_node(self, kind, target, args, kwargs, name=None, type_expr=None): + # FIXME: This is a workaround pending on a PyTorch PR https://github.com/pytorch/pytorch/pull/62095 + if target == F.pad: + kwargs['value'] = float(kwargs['value']) + return super().create_node(kind, target, args, kwargs, name=name, type_expr=type_expr) + + +class LeafNodeTracer(TimmTracer): + """ + Account for desired leaf nodes according to _leaf_modules and _autowrap functions + """ + def __init__(self): + super().__init__(autowrap_functions=_autowrap_functions) + + def is_leaf_module(self, m: nn.Module, module_qualname: str) -> bool: + if isinstance(m, tuple(_leaf_modules)): + return True + return super().is_leaf_module(m, module_qualname) + + +# Taken from https://github.com/pytorch/examples/blob/master/fx/module_tracer.py with modifications for storing +# qualified names for all Nodes, not just top-level Modules +class NodePathTracer(LeafNodeTracer): + """ + NodePathTracer is an FX tracer that, for each operation, also records the qualified name of the Node from which the + operation originated. A qualified name here is a `.` seperated path walking the hierarchy from top level module + down to leaf operation or leaf module. The name of the top level module is not included as part of the qualified + name. For example, if we trace a module who's forward method applies a ReLU module, the qualified name for that + node will simply be 'relu'. + """ + def __init__(self): + super().__init__() + # Track the qualified name of the Node being traced + self.current_module_qualname : str = '' + # A map from FX Node to the qualified name + self.node_to_qualname = OrderedDict() + + def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs): + """ + Override of Tracer.call_module (see https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer.call_module). + This override: + 1) Stores away the qualified name of the caller for restoration later + 2) Installs the qualified name of the caller in `current_module_qualname` for retrieval by `create_proxy` + 3) Once a leaf module is reached, calls `create_proxy` + 4) Restores the caller's qualified name into current_module_qualname + """ + old_qualname = self.current_module_qualname + try: + module_qualname = self.path_of_module(m) + self.current_module_qualname = module_qualname + if not self.is_leaf_module(m, module_qualname): + out = forward(*args, **kwargs) + return out + return self.create_proxy('call_module', module_qualname, args, kwargs) + finally: + self.current_module_qualname = old_qualname + + def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs, name=None, type_expr=None): + """ + Override of `Tracer.create_proxy`. 
This override intercepts the recording + of every operation and stores away the current traced module's qualified + name in `node_to_qualname` + """ + proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr) + self.node_to_qualname[proxy.node] = self._get_node_qualname( + self.current_module_qualname, proxy.node) + return proxy + + def _get_node_qualname(self, module_qualname: str, node: fx.node.Node): + node_qualname = module_qualname + if node.op == 'call_module': + # Node terminates in a leaf module so the module_qualname is a complete description of the node + # Just need to check if this module has appeared before. If so add postfix counter starting from _1 for the + # first reappearance (this follows the way that repeated leaf ops are enumerated by PyTorch FX) + for existing_qualname in reversed(self.node_to_qualname.values()): + # Check to see if existing_qualname is of the form {node_qualname} or {node_qualname}_{int} + if re.match(rf'{node_qualname}(_[0-9]+)?$', existing_qualname) is not None: + postfix = existing_qualname.replace(node_qualname, '') + if len(postfix): + # existing_qualname is of the form {node_qualname}_{int} + next_index = int(postfix[1:]) + 1 + else: + # existing_qualname is of the form {node_qualname} + next_index = 1 + node_qualname += f'_{next_index}' + break + else: + # Node terminates in non- leaf module so the node name needs to be appended + if len(node_qualname) > 0: # only append '.' if we are deeper than the top level module + node_qualname += '.' + node_qualname += str(node) + return node_qualname + + +def print_graph_node_qualified_names(model: nn.Module): + """ + Dev utility to prints nodes in order of execution. Useful for choosing `nodes` for a FeatureGraphNet design. + This is useful for two reasons: + 1. Not all submodules are traced through. Some are treated as leaf modules. See `LeafNodeTracer` + 2. Leaf ops that occur more than once in the graph get a `_{counter}` postfix. + + WARNING: Changes to the operations in the original module might not change the module's overall behaviour, but they + may result in changes to the postfixes for the names of repeated ops, thereby breaking feature extraction. + """ + tracer = NodePathTracer() + tracer.trace(model) + pprint(list(tracer.node_to_qualname.values())) + + +def get_intermediate_nodes(model: nn.Module, return_nodes: Dict[str, str]) -> nn.Module: + """ + Creates a new FX-based module that returns intermediate nodes from a given model. This is achieved by re-writing + the computation graph of the model via FX to return the desired nodes as outputs. All unused nodes are removed, + together with their corresponding parameters. + Args: + model (nn.Module): model on which we will extract the features + return_nodes (Dict[name, new_name]): a dict containing the names (or partial names - see note below) of the + nodes for which the activations will be returned as the keys. The values of the dict are the names + of the returned activations (which the user can specify). + A note on node specification: A node is specified as a `.` seperated path walking the hierarchy from top + level module down to leaf operation or leaf module. For instance `blocks.5.3.bn1`. Nevertheless, the keys + in this dict need not be fully specified. One could provide `blocks.5` as a key, and the last node with + that prefix will be selected. + While designing a feature extractor one can use the `print_graph_node_qualified_names` utility as a guide + to which nodes are available. 
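A rough usage sketch (the model and node names are illustrative; check the printed list for the actual names available in your model):

    import torch
    import timm
    from timm.models.fx_features import get_intermediate_nodes, print_graph_node_qualified_names

    model = timm.create_model('resnet50', pretrained=False)
    print_graph_node_qualified_names(model)  # qualified node names in execution order
    extractor = get_intermediate_nodes(model, return_nodes={'layer2': 'stride8', 'layer3': 'stride16'})
    out = extractor(torch.randn(1, 3, 224, 224))  # dict keyed by the new names, e.g. out['stride8']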
+ Acknowledgement: Starter code from https://github.com/pytorch/vision/pull/3597 + """ + return_nodes = {str(k): str(v) for k, v in return_nodes.items()} + + # Instantiate our NodePathTracer and use that to trace the model + tracer = NodePathTracer() + graph = tracer.trace(model) + + name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__ + graph_module = fx.GraphModule(tracer.root, graph, name) + + available_nodes = [f'{v}.{k}' for k, v in tracer.node_to_qualname.items()] + # FIXME We don't know if we should expect this to happen + assert len(set(available_nodes)) == len(available_nodes), \ + "There are duplicate nodes! Please raise an issue https://github.com/rwightman/pytorch-image-models/issues" + # Check that all outputs in return_nodes are present in the model + for query in return_nodes.keys(): + if not any([m.startswith(query) for m in available_nodes]): + raise ValueError(f"return_node: {query} is not present in model") + + # Remove existing output nodes + orig_output_node = None + for n in reversed(graph_module.graph.nodes): + if n.op == "output": + orig_output_node = n + assert orig_output_node + # And remove it + graph_module.graph.erase_node(orig_output_node) + # Find nodes corresponding to return_nodes and make them into output_nodes + nodes = [n for n in graph_module.graph.nodes] + output_nodes = OrderedDict() + for n in reversed(nodes): + if 'tensor_constant' in str(n): + # NOTE Without this control flow we would get a None value for + # `module_qualname = tracer.node_to_qualname.get(n)`. On the other hand, we can safely assume that we'll + # never need to get this as an interesting intermediate node. + continue + module_qualname = tracer.node_to_qualname.get(n) + for query in return_nodes: + depth = query.count('.') + if '.'.join(module_qualname.split('.')[:depth+1]) == query: + output_nodes[return_nodes[query]] = n + return_nodes.pop(query) + break + output_nodes = OrderedDict(reversed(list(output_nodes.items()))) + + # And add them in the end of the graph + with graph_module.graph.inserting_after(nodes[-1]): + graph_module.graph.output(output_nodes) + + # Remove unused modules / parameters + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + graph_module = fx.GraphModule(graph_module, graph_module.graph, name) + return graph_module + + +class FeatureGraphNet(nn.Module): + """ + Take the provided model and transform it into a graph module. This class wraps the resulting graph module while + also keeping the original model's non-parameter properties for reference. The original model is discarded. + + WARNING: Changes to the operations in the original module might not change the module's overall behaviour, but they + may result in changes to the postfixes for the names of repeated ops, thereby breaking feature extraction. + + TODO: FIX THIS + WARNING: This puts the input model into eval mode prior to tracing. This means that any control flow dependent on + the model being in train mode will be lost. + """ + def __init__(self, model, out_indices, out_map=None): + super().__init__() + model.eval() + self.feature_info = _get_feature_info(model, out_indices) + if out_map is not None: + assert len(out_map) == len(out_indices) + # NOTE the feature_info key is innapropriately named 'module' because prior to FX only modules could be + # provided. 
Recall that here, we may also provide nodes referring to individual ops + return_nodes = {info['module']: out_map[i] if out_map is not None else info['module'] + for i, info in enumerate(self.feature_info) if i in out_indices} + self.graph_module = get_intermediate_nodes(model, return_nodes) + # Keep non-parameter model properties for reference + for attr_str in model.__dir__(): + attr = getattr(model, attr_str) + if (not attr_str.startswith('_') and attr_str not in self.__dir__() and not ismethod(attr) + and not isinstance(attr, (nn.Module, nn.Parameter))): + setattr(self, attr_str, attr) + + def forward(self, x): + return list(self.graph_module(x).values()) + + def train(self, mode=True): + """ + NOTE: This also covers `self.eval()` as that just does self.train(False) + """ + if mode: + warnings.warn( + "Setting a FeatureGraphNet to training mode won't necessarily have the desired effect. Control " + "flow depending on `self.training` will follow the `False` path. See FeatureGraphNet doc-string " + "for more details.") + super().train(mode) \ No newline at end of file diff --git a/timm/models/fx_helpers.py b/timm/models/fx_helpers.py new file mode 100644 index 00000000..1955d5b1 --- /dev/null +++ b/timm/models/fx_helpers.py @@ -0,0 +1,17 @@ + + +def fx_and(a: bool, b: bool) -> bool: + """ + Symbolic tracing helper to substitute for normal usage of `* and *` within `torch._assert`. + Hint: Symbolic tracing does not support control flow but since an `assert` is either a dead-end or not, this hack + is okay. + """ + return (a and b) + + +def fx_float_to_int(x: float) -> int: + """ + Symbolic tracing helper to substitute for inbuilt `int`. + Hint: Inbuilt `int` can't accept an argument of type `Proxy` + """ + return int(x) \ No newline at end of file diff --git a/timm/models/helpers.py b/timm/models/helpers.py index bd97cf20..4cb571f4 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -14,6 +14,7 @@ import torch.nn as nn from .features import FeatureListNet, FeatureDictNet, FeatureHookNet +from .fx_features import FeatureGraphNet from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url from .layers import Conv2dSame, Linear @@ -477,6 +478,8 @@ def build_model_with_cfg( feature_cls = feature_cls.lower() if 'hook' in feature_cls: feature_cls = FeatureHookNet + elif feature_cls == 'fx': + feature_cls = FeatureGraphNet else: assert False, f'Unknown feature class {feature_cls}' model = feature_cls(model, **feature_cfg) From bc3d4eb403b9f40c23ed5bedd8875bba911e3cdc Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sun, 7 Nov 2021 15:04:19 +0000 Subject: [PATCH 02/13] wip -rebase --- timm/models/cait.py | 8 +- timm/models/coat.py | 10 +- timm/models/convit.py | 10 +- timm/models/layers/bottleneck_attn.py | 9 +- timm/models/layers/evo_norm.py | 4 +- timm/models/layers/global_context.py | 3 +- timm/models/layers/halo_attn.py | 7 +- timm/models/layers/lambda_layer.py | 4 +- timm/models/layers/non_local_attn.py | 5 +- timm/models/layers/patch_embed.py | 4 + timm/models/layers/selective_kernel.py | 2 +- timm/models/layers/swin_attn.py | 183 +++++++++++++++++++++++++ timm/models/levit.py | 8 +- timm/models/nest.py | 16 ++- timm/models/nfnet.py | 2 + timm/models/rexnet.py | 3 +- timm/models/swin_transformer.py | 14 +- timm/models/tnt.py | 14 +- timm/models/twins.py | 10 +- timm/models/vgg.py | 2 + timm/models/visformer.py | 4 +- timm/models/vision_transformer.py | 4 +- timm/models/xcit.py | 6 +- 23 files changed, 269 insertions(+), 63 
deletions(-) create mode 100644 timm/models/layers/swin_attn.py diff --git a/timm/models/cait.py b/timm/models/cait.py index 69b4ba06..b6a18ce3 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -95,11 +95,11 @@ class ClassAttn(nn.Module): q = q * self.scale v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) - attn = (q @ k.transpose(-2, -1)) + attn = torch.matmul(q, k.transpose(-2, -1)) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C) + x_cls = torch.matmul(attn, v).transpose(1, 2).reshape(B, 1, C) x_cls = self.proj(x_cls) x_cls = self.proj_drop(x_cls) @@ -158,7 +158,7 @@ class TalkingHeadAttn(nn.Module): qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] - attn = (q @ k.transpose(-2, -1)) + attn = torch.matmul(q, k.transpose(-2, -1)) attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) @@ -167,7 +167,7 @@ class TalkingHeadAttn(nn.Module): attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x diff --git a/timm/models/coat.py b/timm/models/coat.py index f071715a..69b1bd9f 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -105,7 +105,7 @@ class ConvRelPosEnc(nn.Module): def forward(self, q, v, size: Tuple[int, int]): B, h, N, Ch = q.shape H, W = size - assert N == 1 + H * W + torch._assert(N == 1 + H * W, '') # Convolutional relative position encoding. q_img = q[:, :, 1:, :] # [B, h, H*W, Ch] @@ -149,8 +149,8 @@ class FactorAtt_ConvRelPosEnc(nn.Module): # Factorized attention. k_softmax = k.softmax(dim=2) - factor_att = k_softmax.transpose(-1, -2) @ v - factor_att = q @ factor_att + factor_att = torch.matmul(k_softmax.transpose(-1, -2), v) + factor_att = torch.matmul(q, factor_att) # Convolutional relative position encoding. crpe = self.crpe(q, v, size=size) # [B, h, N, Ch] @@ -177,7 +177,7 @@ class ConvPosEnc(nn.Module): def forward(self, x, size: Tuple[int, int]): B, N, C = x.shape H, W = size - assert N == 1 + H * W + torch._assert(N == 1 + H * W, '') # Extract CLS token and image tokens. cls_token, img_tokens = x[:, :1], x[:, 1:] # [B, 1, C], [B, H*W, C] @@ -275,7 +275,7 @@ class ParallelBlock(nn.Module): """ Feature map interpolation. 
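The hunks above follow a recurring pattern in this patch: a bare `assert` forces a traced Proxy into a bool, which FX symbolic tracing rejects, so it is replaced by `torch._assert`, which is recorded as an ordinary call node; `@` matmuls are likewise rewritten as explicit `torch.matmul` calls. A minimal sketch of the assert difference (toy module, not from timm):

    import torch
    from torch import fx

    class Toy(torch.nn.Module):
        def forward(self, q, k):
            # assert q.shape[-1] == k.shape[-1]             # would break symbolic tracing
            torch._assert(q.shape[-1] == k.shape[-1], '')   # traced as a call_function node
            return torch.matmul(q, k.transpose(-2, -1))

    print(fx.symbolic_trace(Toy()).graph)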
""" B, N, C = x.shape H, W = size - assert N == 1 + H * W + torch._assert(N == 1 + H * W, '') cls_token = x[:, :1, :] img_tokens = x[:, 1:, :] diff --git a/timm/models/convit.py b/timm/models/convit.py index f58249ec..603548f9 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -30,6 +30,7 @@ from .helpers import build_model_with_cfg from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp from .registry import register_model from .vision_transformer_hybrid import HybridEmbed +from .fx_features import register_leaf_module import torch import torch.nn as nn @@ -56,6 +57,7 @@ default_cfgs = { } +@register_leaf_module # FX can't symbolically trace control flow in forward method class GPSA(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., locality_strength=1.): @@ -82,7 +84,7 @@ class GPSA(nn.Module): self.rel_indices = self.get_rel_indices(N) attn = self.get_attention(x) v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x @@ -93,7 +95,7 @@ class GPSA(nn.Module): q, k = qk[0], qk[1] pos_score = self.rel_indices.expand(B, -1, -1, -1) pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2) - patch_score = (q @ k.transpose(-2, -1)) * self.scale + patch_score = torch.matmul(q, k.transpose(-2, -1)) * self.scale patch_score = patch_score.softmax(dim=-1) pos_score = pos_score.softmax(dim=-1) @@ -178,11 +180,11 @@ class MHSA(nn.Module): qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] - attn = (q @ k.transpose(-2, -1)) * self.scale + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py index f55fd989..305f9de3 100644 --- a/timm/models/layers/bottleneck_attn.py +++ b/timm/models/layers/bottleneck_attn.py @@ -22,6 +22,7 @@ import torch.nn.functional as F from .helpers import to_2tuple, make_divisible from .weight_init import trunc_normal_ +from timm.models.fx_helpers import fx_and def rel_logits_1d(q, rel_k, permute_mask: List[int]): @@ -36,7 +37,7 @@ def rel_logits_1d(q, rel_k, permute_mask: List[int]): permute_mask: permute output dim according to this """ B, H, W, dim = q.shape - x = (q @ rel_k.transpose(-1, -2)) + x = torch.matmul(q, rel_k.transpose(-1, -2)) x = x.reshape(-1, W, 2 * W -1) # pad to shift from relative to absolute indexing @@ -133,8 +134,8 @@ class BottleneckAttn(nn.Module): def forward(self, x): B, C, H, W = x.shape - assert H == self.pos_embed.height - assert W == self.pos_embed.width + torch._assert(H == self.pos_embed.height, '') + torch._assert(W == self.pos_embed.width, '') x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W @@ -154,5 +155,3 @@ class BottleneckAttn(nn.Module): out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W out = self.pool(out) return out - - diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py index 9023afd0..02aa0a0c 100644 --- a/timm/models/layers/evo_norm.py +++ b/timm/models/layers/evo_norm.py @@ -72,9 +72,9 @@ class 
EvoNormSample2d(nn.Module): nn.init.ones_(self.v) def forward(self, x): - assert x.dim() == 4, 'expected 4D input' + torch._assert(x.dim() == 4, 'expected 4D input') B, C, H, W = x.shape - assert C % self.groups == 0 + torch._assert(C % self.groups == 0, '') if self.apply_act: n = x * (x * self.v).sigmoid() x = x.reshape(B, self.groups, -1) diff --git a/timm/models/layers/global_context.py b/timm/models/layers/global_context.py index de7fb5c1..a0bb8a43 100644 --- a/timm/models/layers/global_context.py +++ b/timm/models/layers/global_context.py @@ -7,6 +7,7 @@ Official code consulted as reference: https://github.com/xvjiarui/GCNet Hacked together by / Copyright 2021 Ross Wightman """ +import torch from torch import nn as nn import torch.nn.functional as F @@ -52,7 +53,7 @@ class GlobalContext(nn.Module): if self.conv_attn is not None: attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) - context = x.reshape(B, C, H * W).unsqueeze(1) @ attn + context = torch.matmul(x.reshape(B, C, H * W).unsqueeze(1), attn) context = context.view(B, C, 1, 1) else: context = x.mean(dim=(2, 3), keepdim=True) diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index 4149e812..0bd611b1 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -24,6 +24,7 @@ import torch.nn.functional as F from .helpers import make_divisible from .weight_init import trunc_normal_ +from timm.models.fx_helpers import fx_and def rel_logits_1d(q, rel_k, permute_mask: List[int]): @@ -41,7 +42,7 @@ def rel_logits_1d(q, rel_k, permute_mask: List[int]): rel_size = rel_k.shape[0] win_size = (rel_size + 1) // 2 - x = (q @ rel_k.transpose(-1, -2)) + x = torch.matmul(q, rel_k.transpose(-1, -2)) x = x.reshape(-1, W, rel_size) # pad to shift from relative to absolute indexing @@ -167,8 +168,8 @@ class HaloAttn(nn.Module): def forward(self, x): B, C, H, W = x.shape - assert H % self.block_size == 0 - assert W % self.block_size == 0 + torch._assert(H % self.block_size == 0, '') + torch._assert(W % self.block_size == 0, '') num_h_blocks = H // self.block_size num_w_blocks = W // self.block_size num_blocks = num_h_blocks * num_w_blocks diff --git a/timm/models/layers/lambda_layer.py b/timm/models/layers/lambda_layer.py index e50b43c8..058426b6 100644 --- a/timm/models/layers/lambda_layer.py +++ b/timm/models/layers/lambda_layer.py @@ -116,8 +116,8 @@ class LambdaLayer(nn.Module): v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M - content_lam = k @ v # B, K, V - content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V + content_lam = torch.matmul(k, v) # B, K, V + content_out = torch.matmul(q, content_lam.unsqueeze(1)) # B, num_heads, M, V if self.pos_emb is None: position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K diff --git a/timm/models/layers/non_local_attn.py b/timm/models/layers/non_local_attn.py index a537d60e..517e28a8 100644 --- a/timm/models/layers/non_local_attn.py +++ b/timm/models/layers/non_local_attn.py @@ -10,6 +10,7 @@ from torch.nn import functional as F from .conv_bn_act import ConvBnAct from .helpers import make_divisible +from timm.models.fx_helpers import fx_and class NonLocalAttn(nn.Module): @@ -83,7 +84,7 @@ class BilinearAttnTransform(nn.Module): def resize_mat(self, x, t: int): B, C, block_size, block_size1 = x.shape - assert block_size == block_size1 + 
torch._assert(block_size == block_size1, '') if t <= 1: return x x = x.view(B * C, -1, 1, 1) @@ -95,7 +96,7 @@ class BilinearAttnTransform(nn.Module): return x def forward(self, x): - assert x.shape[-1] % self.block_size == 0 and x.shape[-2] % self.block_size == 0 + torch._assert(fx_and(x.shape[-1] % self.block_size == 0, x.shape[-2] % self.block_size == 0), '') B, C, H, W = x.shape out = self.conv1(x) rp = F.adaptive_max_pool2d(out, (self.block_size, 1)) diff --git a/timm/models/layers/patch_embed.py b/timm/models/layers/patch_embed.py index 6a7facef..157bc250 100644 --- a/timm/models/layers/patch_embed.py +++ b/timm/models/layers/patch_embed.py @@ -9,7 +9,11 @@ Hacked together by / Copyright 2020 Ross Wightman from torch import nn as nn from .helpers import to_2tuple +<<<<<<< HEAD from .trace_utils import _assert +======= +from timm.models.fx_helpers import fx_and +>>>>>>> Make all models FX traceable class PatchEmbed(nn.Module): diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py index f28b8d2e..69aca86b 100644 --- a/timm/models/layers/selective_kernel.py +++ b/timm/models/layers/selective_kernel.py @@ -34,7 +34,7 @@ class SelectiveKernelAttn(nn.Module): self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) def forward(self, x): - assert x.shape[1] == self.num_paths + torch._assert(x.shape[1] == self.num_paths, '') x = x.sum(1).mean((2, 3), keepdim=True) x = self.fc_reduce(x) x = self.bn(x) diff --git a/timm/models/layers/swin_attn.py b/timm/models/layers/swin_attn.py new file mode 100644 index 00000000..2a3731f3 --- /dev/null +++ b/timm/models/layers/swin_attn.py @@ -0,0 +1,183 @@ +""" Shifted Window Attn + +This is a WIP experiment to apply windowed attention from the Swin Transformer +to a stand-alone module for use as an attn block in conv nets. + +Based on original swin window code at https://github.com/microsoft/Swin-Transformer +Swin Transformer paper: https://arxiv.org/pdf/2103.14030.pdf +""" +from typing import Optional + +import torch +import torch.nn as nn + +from .drop import DropPath +from .helpers import to_2tuple +from .weight_init import trunc_normal_ +from timm.models.fx_helpers import fx_float_to_int + + +def window_partition(x, win_size: int): + """ + Args: + x: (B, H, W, C) + win_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // win_size, win_size, W // win_size, win_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C) + return windows + + +def window_reverse(windows, win_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + win_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = fx_float_to_int(windows.shape[0] / (H * W / win_size / win_size)) + x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + win_size (int): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + """ + + def __init__( + self, dim, dim_out=None, feat_size=None, stride=1, win_size=8, shift_size=None, num_heads=8, + qkv_bias=True, attn_drop=0.): + + super().__init__() + self.dim_out = dim_out or dim + self.feat_size = to_2tuple(feat_size) + self.win_size = win_size + self.shift_size = shift_size or win_size // 2 + if min(self.feat_size) <= win_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.win_size = min(self.feat_size) + assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-window_size" + self.num_heads = num_heads + head_dim = self.dim_out // num_heads + self.scale = head_dim ** -0.5 + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.feat_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = ( + slice(0, -self.win_size), + slice(-self.win_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = ( + slice(0, -self.win_size), + slice(-self.win_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, self.win_size) # num_win, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.win_size * self.win_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + self.register_buffer("attn_mask", attn_mask) + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + # 2 * Wh - 1 * 2 * Ww - 1, nH + torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads)) + trunc_normal_(self.relative_position_bias_table, std=.02) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.win_size) + coords_w = torch.arange(self.win_size) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.win_size - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.win_size - 1 + relative_coords[:, :, 0] *= 2 * self.win_size - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.softmax = nn.Softmax(dim=-1) + self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() + + def reset_parameters(self): + trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) + trunc_normal_(self.relative_position_bias_table, std=.02) + + def forward(self, x): + B, C, H, W = x.shape + x = x.permute(0, 2, 3, 1) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + win_size_sq = self.win_size * self.win_size + x_windows = window_partition(shifted_x, self.win_size) # num_win * B, window_size, window_size, C + x_windows = x_windows.view(-1, win_size_sq, C) # num_win * B, 
window_size*window_size, C + BW, N, _ = x_windows.shape + + qkv = self.qkv(x_windows) + qkv = qkv.reshape(BW, N, 3, self.num_heads, self.dim_out // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * self.scale + attn = torch.matmul(q, k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view(win_size_sq, win_size_sq, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh * Ww, Wh * Ww + attn = attn + relative_position_bias.unsqueeze(0) + if self.attn_mask is not None: + num_win = self.attn_mask.shape[0] + attn = attn.view(B, num_win, self.num_heads, N, N) + self.attn_mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + attn = self.attn_drop(attn) + + x = torch.matmul(attn, v).transpose(1, 2).reshape(BW, N, self.dim_out) + + # merge windows + x = x.view(-1, self.win_size, self.win_size, self.dim_out) + shifted_x = window_reverse(x, self.win_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H, W, self.dim_out).permute(0, 3, 1, 2) + x = self.pool(x) + return x + + diff --git a/timm/models/levit.py b/timm/models/levit.py index 9987e4ba..c4377bb1 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -293,10 +293,10 @@ class Attention(nn.Module): k = k.permute(0, 2, 1, 3) v = v.permute(0, 2, 1, 3) - attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) - x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, self.dh) x = self.proj(x) return x @@ -387,10 +387,10 @@ class AttentionSubsample(nn.Module): v = v.permute(0, 2, 1, 3) # BHNC q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) - attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) - x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) + x = torch.matmul(attn, v).transpose(1, 2).reshape(B, -1, self.dh) x = self.proj(x) return x diff --git a/timm/models/nest.py b/timm/models/nest.py index 9a477bf9..73f14da5 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -26,10 +26,12 @@ from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, named_apply +from .fx_helpers import fx_float_to_int from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_ from .layers import create_conv2d, create_pool2d, to_ntuple from .registry import register_model + _logger = logging.getLogger(__name__) @@ -83,12 +85,12 @@ class Attention(nn.Module): qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) - attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, N, N) + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale # (B, H, T, N, N) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) # (B, H, T, N, C'), permute -> (B, T, N, C', H) - x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(B, 
T, N, C) + x = torch.matmul(attn, v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C) x = self.proj(x) x = self.proj_drop(x) return x # (B, T, N, C) @@ -128,8 +130,8 @@ class ConvPool(nn.Module): """ x is expected to have shape (B, C, H, W) """ - assert x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims' - assert x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims' + torch._assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims') + torch._assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims') x = self.conv(x) # Layer norm done over channel dim only x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) @@ -144,8 +146,8 @@ def blockify(x, block_size: int): block_size (int): edge length of a single square block in units of H, W """ B, H, W, C = x.shape - assert H % block_size == 0, '`block_size` must divide input height evenly' - assert W % block_size == 0, '`block_size` must divide input width evenly' + torch._assert(H % block_size == 0, '`block_size` must divide input height evenly') + torch._assert(W % block_size == 0, '`block_size` must divide input width evenly') grid_height = H // block_size grid_width = W // block_size x = x.reshape(B, grid_height, block_size, grid_width, block_size, C) @@ -160,7 +162,7 @@ def deblockify(x, block_size: int): block_size (int): edge length of a single square block in units of desired H, W """ B, T, _, C = x.shape - grid_size = int(math.sqrt(T)) + grid_size = fx_float_to_int(math.sqrt(T)) height = width = grid_size * block_size x = x.reshape(B, grid_size, grid_size, block_size, block_size, C) x = x.transpose(2, 3).reshape(B, height, width, C) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 4e0f2b21..ec86dbb8 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -27,6 +27,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg +from timm.models.fx_features import register_leaf_module from .registry import register_model from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\ get_act_layer, get_act_fn, get_attn, make_divisible @@ -318,6 +319,7 @@ class DownsampleAvg(nn.Module): return self.conv(self.pool(x)) +@register_leaf_module # FX feature extraction was giving different valued features. Perhaps to do with control flow? class NormFreeBlock(nn.Module): """Normalization-Free pre-activation block. 
""" diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index 279780be..f27ce5d8 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -10,6 +10,7 @@ Changes for timm, feature extraction, and rounded channel variant hacked togethe Copyright 2020 Ross Wightman """ +import torch import torch.nn as nn from functools import partial from math import ceil @@ -92,7 +93,7 @@ class LinearBottleneck(nn.Module): if self.use_shortcut: if self.drop_path is not None: x = self.drop_path(x) - x[:, 0:self.in_channels] += shortcut + x = torch.cat([x[:, 0:self.in_channels] + shortcut, x[:, self.in_channels:]], dim=1) return x diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index 822aeef8..53c7bcd5 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -22,10 +22,12 @@ import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .fx_helpers import fx_float_to_int from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ from .registry import register_model from .vision_transformer import checkpoint_filter_fn, _init_vit_weights + _logger = logging.getLogger(__name__) @@ -111,7 +113,7 @@ def window_reverse(windows, window_size: int, H: int, W: int): Returns: x: (B, H, W, C) """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) + B = fx_float_to_int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x @@ -175,7 +177,7 @@ class WindowAttention(nn.Module): q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) q = q * self.scale - attn = (q @ k.transpose(-2, -1)) + attn = torch.matmul(q, k.transpose(-2, -1)) relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH @@ -192,7 +194,7 @@ class WindowAttention(nn.Module): attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = torch.matmul(attn, v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x @@ -270,7 +272,7 @@ class SwinTransformerBlock(nn.Module): def forward(self, x): H, W = self.input_resolution B, L, C = x.shape - assert L == H * W, "input feature has wrong size" + torch._assert(L == H * W, "input feature has wrong size") shortcut = x x = self.norm1(x) @@ -329,8 +331,8 @@ class PatchMerging(nn.Module): """ H, W = self.input_resolution B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 
+ torch._assert(L == H * W, "input feature has wrong size") + torch._assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.") x = x.view(B, H, W, C) diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 9829653c..f9510487 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -9,10 +9,10 @@ https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT import math import torch import torch.nn as nn -from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.helpers import build_model_with_cfg +from timm.models.fx_helpers import fx_and from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers.helpers import to_2tuple from timm.models.registry import register_model @@ -64,11 +64,11 @@ class Attention(nn.Module): q, k = qk.unbind(0) # make torchscript happy (cannot use tensor as tuple) v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) - attn = (q @ k.transpose(-2, -1)) * self.scale + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, -1) x = self.proj(x) x = self.proj_drop(x) return x @@ -109,7 +109,9 @@ class Block(nn.Module): pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed))) # outer B, N, C = patch_embed.size() - patch_embed[:, 1:] = patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)) + patch_embed = torch.cat( + [patch_embed[:, 0:1], patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))], + dim=1) patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed))) patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed))) return pixel_embed, patch_embed @@ -136,8 +138,8 @@ class PixelEmbed(nn.Module): def forward(self, x, pixel_pos): B, C, H, W = x.shape - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
+ torch._assert(fx_and(H == self.img_size[0], W == self.img_size[1]), + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") x = self.proj(x) x = self.unfold(x) x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1]) diff --git a/timm/models/twins.py b/timm/models/twins.py index 4aed09d9..7b5afafb 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -22,6 +22,7 @@ from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .layers import Mlp, DropPath, to_2tuple, trunc_normal_ +from .fx_features import register_leaf_module from .registry import register_model from .vision_transformer import Attention from .helpers import build_model_with_cfg, overlay_external_default_cfg @@ -62,6 +63,7 @@ default_cfgs = { Size_ = Tuple[int, int] +@register_leaf_module # FX can't symbolically trace control flow in forward method class LocallyGroupedAttn(nn.Module): """ LSA: self attention within a group """ @@ -98,10 +100,10 @@ class LocallyGroupedAttn(nn.Module): qkv = self.qkv(x).reshape( B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) q, k, v = qkv[0], qkv[1], qkv[2] - attn = (q @ k.transpose(-2, -1)) * self.scale + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + attn = torch.matmul(attn, v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) if pad_r > 0 or pad_b > 0: x = x[:, :H, :W, :].contiguous() @@ -183,11 +185,11 @@ class GlobalSubSampleAttn(nn.Module): kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) k, v = kv[0], kv[1] - attn = (q @ k.transpose(-2, -1)) * self.scale + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) diff --git a/timm/models/vgg.py b/timm/models/vgg.py index 8bea03e7..aee41b25 100644 --- a/timm/models/vgg.py +++ b/timm/models/vgg.py @@ -12,6 +12,7 @@ from typing import Union, List, Dict, Any, cast from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg +from .fx_features import register_leaf_module from .layers import ClassifierHead, ConvBnAct from .registry import register_model @@ -52,6 +53,7 @@ cfgs: Dict[str, List[Union[str, int]]] = { } +@register_leaf_module # FX can't symbolically trace control flow in forward method class ConvMlp(nn.Module): def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0, diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 6e832cd0..6ed43102 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -100,10 +100,10 @@ class Attention(nn.Module): x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3) q, k, v = x[0], x[1], x[2] - attn = (q @ k.transpose(-2, -1)) * self.scale + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = attn @ v + x = torch.matmul(attn, v) x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W) x = self.proj(x) diff --git a/timm/models/vision_transformer.py 
b/timm/models/vision_transformer.py index 94ae2666..fb939624 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -192,11 +192,11 @@ class Attention(nn.Module): qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) - attn = (q @ k.transpose(-2, -1)) * self.scale + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x diff --git a/timm/models/xcit.py b/timm/models/xcit.py index 2942ed8a..9be29046 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -21,6 +21,7 @@ from .vision_transformer import _cfg, Mlp from .registry import register_model from .layers import DropPath, trunc_normal_, to_2tuple from .cait import ClassAttn +from .fx_features import register_leaf_module def _cfg(url='', **kwargs): @@ -97,6 +98,7 @@ default_cfgs = { } +@register_leaf_module # FX can't symbolically trace torch.arange in forward method class PositionalEncodingFourier(nn.Module): """ Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper. @@ -272,12 +274,12 @@ class XCA(nn.Module): # Paper section 3.2 l2-Normalization and temperature scaling q = torch.nn.functional.normalize(q, dim=-1) k = torch.nn.functional.normalize(k, dim=-1) - attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = torch.matmul(q, k.transpose(-2, -1)) * self.temperature attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) # (B, H, C', N), permute -> (B, N, H, C') - x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C) + x = torch.matmul(attn, v).permute(0, 3, 1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x From a6c24b936ba91a02686fb10cf7fbbe50216226a4 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 12 Aug 2021 15:31:36 +0100 Subject: [PATCH 03/13] Tests to enforce all models FX traceable --- tests/test_models.py | 80 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/tests/test_models.py b/tests/test_models.py index c0d0e901..91fd3543 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -7,6 +7,7 @@ import fnmatch import timm from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \ get_model_default_value +from timm.models.fx_features import NodePathTracer if hasattr(torch._C, '_jit_set_profiling_executor'): # legacy executor is too slow to compile large models for unit tests @@ -297,3 +298,82 @@ def test_model_forward_features(model_name, batch_size): assert e == o.shape[1] assert o.shape[0] == batch_size assert not torch.isnan(o).any() + + +@pytest.mark.timeout(120) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) +@pytest.mark.parametrize('batch_size', [1]) +def test_model_forward_fx(model_name, batch_size): + """Symbolically trace each model and run single forward pass through the resulting GraphModule""" + model = create_model(model_name, pretrained=False) + model.eval() + + input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE) + if max(input_size) > MAX_FWD_SIZE: + pytest.skip("Fixed input size model > limit.") + + tracer = NodePathTracer() + graph = tracer.trace(model) + model = torch.fx.GraphModule(model, graph) + + 
inputs = torch.randn((batch_size, *input_size)) + outputs = model(inputs) + + assert outputs.shape[0] == batch_size + assert not torch.isnan(outputs).any(), 'Output included NaNs' + + +@pytest.mark.timeout(120) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True)) +@pytest.mark.parametrize('batch_size', [2]) +def test_model_backward_fx(model_name, batch_size): + """Symbolically trace each model and run single backward pass through the resulting GraphModule""" + input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE) + if max(input_size) > MAX_BWD_SIZE: + pytest.skip("Fixed input size model > limit.") + + model = create_model(model_name, pretrained=False, num_classes=42) + model.train() + num_params = sum([x.numel() for x in model.parameters()]) + + tracer = NodePathTracer() + graph = tracer.trace(model) + model = torch.fx.GraphModule(model, graph) + + inputs = torch.randn((batch_size, *input_size)) + outputs = model(inputs) + if isinstance(outputs, tuple): + outputs = torch.cat(outputs) + outputs.mean().backward() + for n, x in model.named_parameters(): + assert x.grad is not None, f'No gradient for {n}' + num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None]) + + assert outputs.shape[-1] == 42 + assert num_params == num_grad, 'Some parameters are missing gradients' + assert not torch.isnan(outputs).any(), 'Output included NaNs' + + +@pytest.mark.timeout(120) +@pytest.mark.parametrize( + 'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True)) +@pytest.mark.parametrize('batch_size', [1]) +def test_model_forward_fx_torchscript(model_name, batch_size): + """Symbolically trace each model, script it, and run single forward pass""" + input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) + if max(input_size) > MAX_JIT_SIZE: + pytest.skip("Fixed input size model > limit.") + + with set_scriptable(True): + model = create_model(model_name, pretrained=False) + model.eval() + + tracer = NodePathTracer() + graph = tracer.trace(model) + model = torch.fx.GraphModule(model, graph) + + model = torch.jit.script(model) + outputs = model(torch.randn((batch_size, *input_size))) + + assert outputs.shape[0] == batch_size + assert not torch.isnan(outputs).any(), 'Output included NaNs' \ No newline at end of file From 02c3a75a45deea8b8728da14e0d0b5106e06d98b Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sat, 28 Aug 2021 17:54:22 +0100 Subject: [PATCH 04/13] wip - make it possible to use fx graph in train and eval mode --- timm/models/fx_features.py | 460 ++++++++++++++++++++++++++----------- 1 file changed, 329 insertions(+), 131 deletions(-) diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index e00b080f..9a76e041 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -7,18 +7,21 @@ An extension/alternative to timm.models.features making use of PyTorch FX. Here, 3. 
Write the resulting graph into a GraphModule (PyTorch FX functionality) Copyright 2021 Alexander Soare """ -from typing import Callable, Dict +from typing import Callable, Dict, Union, List, Optional import math from collections import OrderedDict from pprint import pprint from inspect import ismethod import re import warnings +from copy import deepcopy +from itertools import chain import torch from torch import nn from torch import fx import torch.nn.functional as F +from torch.fx.graph_module import _copy_attr from .features import _get_feature_info from .fx_helpers import fx_and, fx_float_to_int @@ -84,29 +87,48 @@ class LeafNodeTracer(TimmTracer): return super().is_leaf_module(m, module_qualname) +def _is_subseq(x, y): + """Check if y is a subseqence of x + https://stackoverflow.com/a/24017747/4391249 + """ + iter_x = iter(x) + return all(any(x_item == y_item for x_item in iter_x) for y_item in y) + + # Taken from https://github.com/pytorch/examples/blob/master/fx/module_tracer.py with modifications for storing # qualified names for all Nodes, not just top-level Modules class NodePathTracer(LeafNodeTracer): """ - NodePathTracer is an FX tracer that, for each operation, also records the qualified name of the Node from which the - operation originated. A qualified name here is a `.` seperated path walking the hierarchy from top level module - down to leaf operation or leaf module. The name of the top level module is not included as part of the qualified - name. For example, if we trace a module who's forward method applies a ReLU module, the qualified name for that - node will simply be 'relu'. + NodePathTracer is an FX tracer that, for each operation, also records the + qualified name of the Node from which the operation originated. A + qualified name here is a `.` seperated path walking the hierarchy from top + level module down to leaf operation or leaf module. The name of the top + level module is not included as part of the qualified name. For example, + if we trace a module who's forward method applies a ReLU module, the + qualified name for that node will simply be 'relu'. + + Some notes on the specifics: + - Nodes are recorded to `self.node_to_qualname` which is a dictionary + mapping a given Node object to its qualified name. + - Nodes are recorded in the order which they are executed during + tracing. + - When a duplicate qualified name is encountered, a suffix of the form + _{int} is added. The counter starts from 1. """ - def __init__(self): - super().__init__() + def __init__(self, *args, **kwargs): + super(NodePathTracer, self).__init__(*args, **kwargs) # Track the qualified name of the Node being traced - self.current_module_qualname : str = '' + self.current_module_qualname = '' # A map from FX Node to the qualified name self.node_to_qualname = OrderedDict() def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs): """ - Override of Tracer.call_module (see https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer.call_module). 
+ Override of `fx.Tracer.call_module` This override: 1) Stores away the qualified name of the caller for restoration later - 2) Installs the qualified name of the caller in `current_module_qualname` for retrieval by `create_proxy` + 2) Adds the qualified name of the caller to + `current_module_qualname` for retrieval by `create_proxy` 3) Once a leaf module is reached, calls `create_proxy` 4) Restores the caller's qualified name into current_module_qualname """ @@ -121,7 +143,8 @@ class NodePathTracer(LeafNodeTracer): finally: self.current_module_qualname = old_qualname - def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs, name=None, type_expr=None): + def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs, + name=None, type_expr=None) -> fx.proxy.Proxy: """ Override of `Tracer.create_proxy`. This override intercepts the recording of every operation and stores away the current traced module's qualified @@ -132,160 +155,335 @@ class NodePathTracer(LeafNodeTracer): self.current_module_qualname, proxy.node) return proxy - def _get_node_qualname(self, module_qualname: str, node: fx.node.Node): + def _get_node_qualname( + self, module_qualname: str, node: fx.node.Node) -> str: node_qualname = module_qualname if node.op == 'call_module': - # Node terminates in a leaf module so the module_qualname is a complete description of the node - # Just need to check if this module has appeared before. If so add postfix counter starting from _1 for the - # first reappearance (this follows the way that repeated leaf ops are enumerated by PyTorch FX) + # Node terminates in a leaf module so the module_qualname is a + # complete description of the node for existing_qualname in reversed(self.node_to_qualname.values()): - # Check to see if existing_qualname is of the form {node_qualname} or {node_qualname}_{int} - if re.match(rf'{node_qualname}(_[0-9]+)?$', existing_qualname) is not None: + # Check to see if existing_qualname is of the form + # {node_qualname} or {node_qualname}_{int} + if re.match(rf'{node_qualname}(_[0-9]+)?$', + existing_qualname) is not None: postfix = existing_qualname.replace(node_qualname, '') if len(postfix): - # existing_qualname is of the form {node_qualname}_{int} + # Existing_qualname is of the form {node_qualname}_{int} next_index = int(postfix[1:]) + 1 else: # existing_qualname is of the form {node_qualname} next_index = 1 node_qualname += f'_{next_index}' - break + break else: - # Node terminates in non- leaf module so the node name needs to be appended - if len(node_qualname) > 0: # only append '.' if we are deeper than the top level module + # Node terminates in non- leaf module so the node name needs to be + # appended + if len(node_qualname) > 0: + # Only append '.' if we are deeper than the top level module node_qualname += '.' node_qualname += str(node) return node_qualname -def print_graph_node_qualified_names(model: nn.Module): +def _warn_graph_differences( + train_tracer: NodePathTracer, eval_tracer: NodePathTracer): """ - Dev utility to prints nodes in order of execution. Useful for choosing `nodes` for a FeatureGraphNet design. - This is useful for two reasons: - 1. Not all submodules are traced through. Some are treated as leaf modules. See `LeafNodeTracer` - 2. Leaf ops that occur more than once in the graph get a `_{counter}` postfix. 
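A usage sketch of that dev utility (the model name is illustrative; any traceable timm model should behave the same way):

    from timm import create_model
    from timm.models.fx_features import print_graph_node_qualified_names

    model = create_model('resnet18', pretrained=False)
    # prints qualified node names in execution order,
    # e.g. 'conv1', 'bn1', ..., 'layer1.0.conv1', 'layer1.0.bn1', ...
    print_graph_node_qualified_names(model)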
- - WARNING: Changes to the operations in the original module might not change the module's overall behaviour, but they - may result in changes to the postfixes for the names of repeated ops, thereby breaking feature extraction. + Utility function for warning the user if there are differences between + the train graph and the eval graph. + """ + train_nodes = list(train_tracer.node_to_qualname.values()) + eval_nodes = list(eval_tracer.node_to_qualname.values()) + + if len(train_nodes) == len(eval_nodes) and [ + t == e for t, e in zip(train_nodes, eval_nodes)]: + return + + suggestion_msg = ( + "When choosing nodes for feature extraction, you may need to specify " + "output nodes for train and eval mode separately") + + if _is_subseq(train_nodes, eval_nodes): + msg = ("NOTE: The nodes obtained by tracing the model in eval mode " + "are a subsequence of those obtained in train mode. ") + elif _is_subseq(eval_nodes, train_nodes): + msg = ("NOTE: The nodes obtained by tracing the model in train mode " + "are a subsequence of those obtained in eval mode. ") + else: + msg = ("The nodes obtained by tracing the model in train mode " + "are different to those obtained in eval mode. ") + warnings.warn(msg + suggestion_msg) + + +def print_graph_node_qualified_names( + model: nn.Module, tracer_kwargs: Dict = {}): """ - tracer = NodePathTracer() - tracer.trace(model) - pprint(list(tracer.node_to_qualname.values())) + Dev utility to prints nodes in order of execution. Useful for choosing + nodes for a FeatureGraphNet design. There are two reasons that qualified + node names can't easily be read directly from the code for a model: + 1. Not all submodules are traced through. Modules from `torch.nn` all + fall within this category. + 2. Node qualified names that occur more than once in the graph get a + `_{counter}` postfix. + The model will be traced twice: once in train mode, and once in eval mode. + If there are discrepancies between the graphs produced, both sets will + be printed and the user will be warned. + + Args: + model (nn.Module): model on which we will extract the features + tracer_kwargs (Dict): a dictionary of keywork arguments for + `NodePathTracer` (which passes them onto it's parent class + `torch.fx.Tracer`). + """ + train_tracer = NodePathTracer(**tracer_kwargs) + train_tracer.trace(model.train()) + eval_tracer = NodePathTracer(**tracer_kwargs) + eval_tracer.trace(model.eval()) + train_nodes = list(train_tracer.node_to_qualname.values()) + eval_nodes = list(eval_tracer.node_to_qualname.values()) + if len(train_nodes) == len(eval_nodes) and [ + t == e for t, e in zip(train_nodes, eval_nodes)]: + # Nodes are aligned in train vs eval mode + pprint(list(train_tracer.node_to_qualname.values())) + return + print("Nodes from train mode:") + pprint(list(train_tracer.node_to_qualname.values())) + print() + print("Nodes from eval mode:") + pprint(list(eval_tracer.node_to_qualname.values())) + print() + _warn_graph_differences(train_tracer, eval_tracer) + + +class DualGraphModule(fx.GraphModule): + """ + A derivative of `fx.GraphModule`. Differs in the following ways: + - Requires a train and eval version of the underlying graph + - Copies submodules according to the nodes of both train and eval graphs. + - Calling train(mode) switches between train graph and eval graph. 
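The need for separate train and eval graphs is easiest to see with a small sketch (the module below is hypothetical): symbolic tracing evaluates `self.training` as a plain Python bool, so whichever branch is active at trace time gets baked into the graph.

    import torch
    from torch import fx

    class WithAuxHead(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.head = torch.nn.Linear(8, 4)
            self.aux = torch.nn.Linear(8, 4)

        def forward(self, x):
            if self.training:
                return self.head(x), self.aux(x)  # extra output only in train mode
            return self.head(x)

    m = WithAuxHead()
    train_graph = fx.Tracer().trace(m.train())  # records both 'head' and 'aux' call_module nodes
    eval_graph = fx.Tracer().trace(m.eval())    # records only 'head'

DualGraphModule holds both graphs and swaps them in `train()`, so one feature extractor can serve both modes.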
+ """ + def __init__(self, + root: torch.nn.Module, + train_graph: fx.Graph, + eval_graph: fx.Graph, + class_name: str = 'GraphModule'): + """ + Args: + root (torch.nn.Module): module from which the copied module + hierarchy is built + train_graph (Graph): the graph that should be used in train mode + eval_graph (Graph): the graph that should be used in eval mode + """ + super(fx.GraphModule, self).__init__() + + self.__class__.__name__ = class_name + + self.train_graph = train_graph + self.eval_graph = eval_graph + + # Copy all get_attr and call_module ops (indicated by BOTH train and + # eval graphs) + for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)): + if node.op in ['get_attr', 'call_module']: + assert isinstance(node.target, str) + _copy_attr(root, self, node.target) + + # eval mode by default + self.eval() + self.graph = eval_graph + + # (borrowed from fx.GraphModule): + # Store the Tracer class responsible for creating a Graph separately as part of the + # GraphModule state, except when the Tracer is defined in a local namespace. + # Locally defined Tracers are not pickleable. This is needed because torch.package will + # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer + # to re-create the Graph during deserialization. + # TODO uncomment this when https://github.com/pytorch/pytorch/pull/63121 is available + # assert self.eval_graph._tracer_cls == self.train_graph._tracer_cls, \ + # "Train mode and eval mode should use the same tracer class" + # self._tracer_cls = None + # if self.graph._tracer_cls and '' not in self.graph._tracer_cls.__qualname__: + # self._tracer_cls = self.graph._tracer_cls + + def train(self, mode=True): + """ + Swap out the graph depending on the training mode. + NOTE this should be safe when calling model.eval() because that just + calls this with mode == False. + """ + if mode: + self.graph = self.train_graph + else: + self.graph = self.eval_graph + return super().train(mode=mode) -def get_intermediate_nodes(model: nn.Module, return_nodes: Dict[str, str]) -> nn.Module: +def build_feature_graph_net( + model: nn.Module, + return_nodes: Union[List[str], Dict[str, str]], + train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, + eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, + tracer_kwargs: Dict = {}) -> fx.GraphModule: """ - Creates a new FX-based module that returns intermediate nodes from a given model. This is achieved by re-writing - the computation graph of the model via FX to return the desired nodes as outputs. All unused nodes are removed, - together with their corresponding parameters. + Creates a new graph module that returns intermediate nodes from a given + model as dictionary with user specified keys as strings, and the requested + outputs as values. This is achieved by re-writing the computation graph of + the model via FX to return the desired nodes as outputs. All unused nodes + are removed, together with their corresponding parameters. + + A note on node specification: A node qualified name is specified as a `.` + seperated path walking the hierarchy from top level module down to leaf + operation or leaf module. For instance `blocks.5.3.bn1`. The keys of the + `return_nodes` argument should point to either a node's qualified name, + or some truncated version of it. For example, one could provide `blocks.5` + as a key, and the last node with that prefix will be selected. 
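As a concrete sketch of the truncated-key behaviour (model choice and shapes are illustrative, not prescribed by this patch):

    import torch
    from timm import create_model
    from timm.models.fx_features import build_feature_graph_net

    model = create_model('resnet18', pretrained=False)
    # 'layer2' and 'layer3' are prefixes; each resolves to the last node inside that submodule
    extractor = build_feature_graph_net(model, {'layer2': 'feat1', 'layer3': 'feat2'})
    out = extractor(torch.randn(1, 3, 224, 224))
    print({k: tuple(v.shape) for k, v in out.items()})
    # roughly {'feat1': (1, 128, 28, 28), 'feat2': (1, 256, 14, 14)} for resnet18 at 224x224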
+ `print_graph_node_qualified_names` is a useful helper function for getting + a list of qualified names of a model. + + An attempt is made to keep all non-parametric properties of the original + model, but existing properties of the constructed `GraphModule` are not + overwritten. + Args: model (nn.Module): model on which we will extract the features - return_nodes (Dict[name, new_name]): a dict containing the names (or partial names - see note below) of the - nodes for which the activations will be returned as the keys. The values of the dict are the names - of the returned activations (which the user can specify). - A note on node specification: A node is specified as a `.` seperated path walking the hierarchy from top - level module down to leaf operation or leaf module. For instance `blocks.5.3.bn1`. Nevertheless, the keys - in this dict need not be fully specified. One could provide `blocks.5` as a key, and the last node with - that prefix will be selected. - While designing a feature extractor one can use the `print_graph_node_qualified_names` utility as a guide - to which nodes are available. - Acknowledgement: Starter code from https://github.com/pytorch/vision/pull/3597 + return_nodes (Union[List[name], Dict[name, new_name]])): either a list + or a dict containing the names (or partial names - see note above) + of the nodes for which the activations will be returned. If it is + a `Dict`, the keys are the qualified node names, and the values + are the user-specified keys for the graph module's returned + dictionary. If it is a `List`, it is treated as a `Dict` mapping + node specification strings directly to output names. + tracer_kwargs (Dict): a dictionary of keywork arguments for + `NodePathTracer` (which passes them onto it's parent class + `torch.fx.Tracer`). + + Examples:: + + >>> model = torchvision.models.resnet18() + >>> # extract layer1 and layer3, giving as names `feat1` and feat2` + >>> graph_module = torchvision.models._utils.build_feature_graph_net(m, + >>> {'layer1': 'feat1', 'layer3': 'feat2'}) + >>> out = graph_module(torch.rand(1, 3, 224, 224)) + >>> print([(k, v.shape) for k, v in out.items()]) + >>> [('feat1', torch.Size([1, 64, 56, 56])), + >>> ('feat2', torch.Size([1, 256, 14, 14]))] + """ + is_training = model.training + + if isinstance(return_nodes, list): + return_nodes = {n: n for n in return_nodes} return_nodes = {str(k): str(v) for k, v in return_nodes.items()} - # Instantiate our NodePathTracer and use that to trace the model - tracer = NodePathTracer() - graph = tracer.trace(model) - - name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__ - graph_module = fx.GraphModule(tracer.root, graph, name) - - available_nodes = [f'{v}.{k}' for k, v in tracer.node_to_qualname.items()] - # FIXME We don't know if we should expect this to happen - assert len(set(available_nodes)) == len(available_nodes), \ - "There are duplicate nodes! 
Please raise an issue https://github.com/rwightman/pytorch-image-models/issues" - # Check that all outputs in return_nodes are present in the model - for query in return_nodes.keys(): - if not any([m.startswith(query) for m in available_nodes]): - raise ValueError(f"return_node: {query} is not present in model") - - # Remove existing output nodes - orig_output_node = None - for n in reversed(graph_module.graph.nodes): - if n.op == "output": - orig_output_node = n - assert orig_output_node - # And remove it - graph_module.graph.erase_node(orig_output_node) - # Find nodes corresponding to return_nodes and make them into output_nodes - nodes = [n for n in graph_module.graph.nodes] - output_nodes = OrderedDict() - for n in reversed(nodes): - if 'tensor_constant' in str(n): - # NOTE Without this control flow we would get a None value for - # `module_qualname = tracer.node_to_qualname.get(n)`. On the other hand, we can safely assume that we'll - # never need to get this as an interesting intermediate node. - continue - module_qualname = tracer.node_to_qualname.get(n) - for query in return_nodes: - depth = query.count('.') - if '.'.join(module_qualname.split('.')[:depth+1]) == query: - output_nodes[return_nodes[query]] = n - return_nodes.pop(query) - break - output_nodes = OrderedDict(reversed(list(output_nodes.items()))) - - # And add them in the end of the graph - with graph_module.graph.inserting_after(nodes[-1]): - graph_module.graph.output(output_nodes) - - # Remove unused modules / parameters - graph_module.graph.eliminate_dead_code() - graph_module.recompile() - graph_module = fx.GraphModule(graph_module, graph_module.graph, name) + assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), \ + ("If any of `train_return_nodes` and `eval_return_nodes` are " + "specified, then both should be specified") + + if train_return_nodes is None: + train_return_nodes = deepcopy(return_nodes) + eval_return_nodes = deepcopy(return_nodes) + + # Repeat the tracing and graph rewriting for train and eval mode + tracers = {} + graphs = {} + return_nodes = { + 'train': train_return_nodes, + 'eval': eval_return_nodes + } + for mode in ['train', 'eval']: + if mode == 'train': + model.train() + elif mode == 'eval': + model.eval() + + # Instantiate our NodePathTracer and use that to trace the model + tracer = NodePathTracer(**tracer_kwargs) + graph = tracer.trace(model) + + name = model.__class__.__name__ if isinstance( + model, nn.Module) else model.__name__ + graph_module = fx.GraphModule(tracer.root, graph, name) + + available_nodes = [f'{v}.{k}' for k, v in tracer.node_to_qualname.items()] + # FIXME We don't know if we should expect this to happen + assert len(set(available_nodes)) == len(available_nodes), \ + "There are duplicate nodes! 
Please raise an issue https://github.com/pytorch/vision/issues" + # Check that all outputs in return_nodes are present in the model + for query in return_nodes[mode].keys(): + if not any([m.startswith(query) for m in available_nodes]): + raise ValueError(f"return_node: {query} is not present in model") + + # Remove existing output nodes (train mode) + orig_output_nodes = [] + for n in reversed(graph_module.graph.nodes): + if n.op == "output": + orig_output_nodes.append(n) + assert len(orig_output_nodes) + for n in orig_output_nodes: + graph_module.graph.erase_node(n) + + # Find nodes corresponding to return_nodes and make them into output_nodes + nodes = [n for n in graph_module.graph.nodes] + output_nodes = OrderedDict() + for n in reversed(nodes): + if 'tensor_constant' in str(n): + # NOTE Without this control flow we would get a None value for + # `module_qualname = tracer.node_to_qualname.get(n)`. + # On the other hand, we can safely assume that we'll never need to + # get this as an interesting intermediate node. + continue + module_qualname = tracer.node_to_qualname.get(n) + for query in return_nodes[mode]: + depth = query.count('.') + if '.'.join(module_qualname.split('.')[:depth + 1]) == query: + output_nodes[return_nodes[mode][query]] = n + return_nodes[mode].pop(query) + break + output_nodes = OrderedDict(reversed(list(output_nodes.items()))) + + # And add them in the end of the graph + with graph_module.graph.inserting_after(nodes[-1]): + graph_module.graph.output(output_nodes) + + # Remove unused modules / parameters + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + # Keep track of the tracer and graph so we can choose the main one + tracers[mode] = tracer + graphs[mode] = graph + + # Warn user if there are any discrepancies between the graphs of the + # train and eval modes + _warn_graph_differences(tracers['train'], tracers['eval']) + + # Build the final graph module + graph_module = DualGraphModule( + model, graphs['train'], graphs['eval'], class_name=name) + + # Keep non-parameter model properties for reference + for attr_str in model.__dir__(): + attr = getattr(model, attr_str) + if (not attr_str.startswith('_') + and attr_str not in graph_module.__dir__() + and not ismethod(attr) + and not isinstance(attr, (nn.Module, nn.Parameter))): + setattr(graph_module, attr_str, attr) + + # Restore original training mode + graph_module.train(is_training) + return graph_module class FeatureGraphNet(nn.Module): - """ - Take the provided model and transform it into a graph module. This class wraps the resulting graph module while - also keeping the original model's non-parameter properties for reference. The original model is discarded. - - WARNING: Changes to the operations in the original module might not change the module's overall behaviour, but they - may result in changes to the postfixes for the names of repeated ops, thereby breaking feature extraction. - - TODO: FIX THIS - WARNING: This puts the input model into eval mode prior to tracing. This means that any control flow dependent on - the model being in train mode will be lost. - """ def __init__(self, model, out_indices, out_map=None): super().__init__() - model.eval() self.feature_info = _get_feature_info(model, out_indices) if out_map is not None: assert len(out_map) == len(out_indices) - # NOTE the feature_info key is innapropriately named 'module' because prior to FX only modules could be - # provided. 
Recall that here, we may also provide nodes referring to individual ops return_nodes = {info['module']: out_map[i] if out_map is not None else info['module'] for i, info in enumerate(self.feature_info) if i in out_indices} - self.graph_module = get_intermediate_nodes(model, return_nodes) - # Keep non-parameter model properties for reference - for attr_str in model.__dir__(): - attr = getattr(model, attr_str) - if (not attr_str.startswith('_') and attr_str not in self.__dir__() and not ismethod(attr) - and not isinstance(attr, (nn.Module, nn.Parameter))): - setattr(self, attr_str, attr) - + self.graph_module = build_feature_graph_net(model, return_nodes) + def forward(self, x): - return list(self.graph_module(x).values()) - - def train(self, mode=True): - """ - NOTE: This also covers `self.eval()` as that just does self.train(False) - """ - if mode: - warnings.warn( - "Setting a FeatureGraphNet to training mode won't necessarily have the desired effect. Control " - "flow depending on `self.training` will follow the `False` path. See FeatureGraphNet doc-string " - "for more details.") - super().train(mode) \ No newline at end of file + return list(self.graph_module(x).values()) \ No newline at end of file From 0149ec30d7d39ae8ed65b95bd6ed601055a97f4c Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sun, 7 Nov 2021 14:54:42 +0000 Subject: [PATCH 05/13] wip - attempting to rebase --- timm/models/fx_features.py | 4 ++-- timm/models/fx_helpers.py | 10 ---------- timm/models/layers/bottleneck_attn.py | 1 - timm/models/layers/halo_attn.py | 2 +- timm/models/layers/non_local_attn.py | 4 ++-- timm/models/layers/patch_embed.py | 4 ---- timm/models/tnt.py | 5 +++-- 7 files changed, 8 insertions(+), 22 deletions(-) diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index 9a76e041..310cc465 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -24,7 +24,7 @@ import torch.nn.functional as F from torch.fx.graph_module import _copy_attr from .features import _get_feature_info -from .fx_helpers import fx_and, fx_float_to_int +from .fx_helpers import fx_float_to_int # Layers we went to treat as leaf modules for FeatureGraphNet from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame @@ -55,7 +55,7 @@ def register_leaf_module(module: nn.Module): # These functions will not be traced through -_autowrap_functions=(fx_float_to_int, fx_and) +_autowrap_functions=(fx_float_to_int,) class TimmTracer(fx.Tracer): diff --git a/timm/models/fx_helpers.py b/timm/models/fx_helpers.py index 1955d5b1..878ba381 100644 --- a/timm/models/fx_helpers.py +++ b/timm/models/fx_helpers.py @@ -1,14 +1,4 @@ - -def fx_and(a: bool, b: bool) -> bool: - """ - Symbolic tracing helper to substitute for normal usage of `* and *` within `torch._assert`. - Hint: Symbolic tracing does not support control flow but since an `assert` is either a dead-end or not, this hack - is okay. - """ - return (a and b) - - def fx_float_to_int(x: float) -> int: """ Symbolic tracing helper to substitute for inbuilt `int`. 
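The surviving helper exists because calling the builtin `int()` on an FX Proxy forces a concrete value and raises during tracing. A self-contained sketch of the same idea (it uses `torch.fx.wrap` for brevity; the patch gets the equivalent effect by listing the helper in `_autowrap_functions`, and the module below is made up):

    import torch
    from torch import fx

    @fx.wrap  # keep this call as a single graph node instead of tracing into it
    def float_to_int(x: float) -> int:
        return int(x)

    class M(torch.nn.Module):
        def forward(self, x):
            # plain int(x.shape[-1] / 2) would raise under symbolic tracing
            half = float_to_int(x.shape[-1] / 2)
            return torch.nn.functional.adaptive_avg_pool1d(x, half)

    gm = fx.symbolic_trace(M())
    gm(torch.randn(2, 4, 10))  # float_to_int appears as a single call_function node in gm.graph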
diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py index 305f9de3..c56c5821 100644 --- a/timm/models/layers/bottleneck_attn.py +++ b/timm/models/layers/bottleneck_attn.py @@ -22,7 +22,6 @@ import torch.nn.functional as F from .helpers import to_2tuple, make_divisible from .weight_init import trunc_normal_ -from timm.models.fx_helpers import fx_and def rel_logits_1d(q, rel_k, permute_mask: List[int]): diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index 0bd611b1..babfcb06 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -24,7 +24,7 @@ import torch.nn.functional as F from .helpers import make_divisible from .weight_init import trunc_normal_ -from timm.models.fx_helpers import fx_and +from timm.models.fx_helpers import def rel_logits_1d(q, rel_k, permute_mask: List[int]): diff --git a/timm/models/layers/non_local_attn.py b/timm/models/layers/non_local_attn.py index 517e28a8..f933ece2 100644 --- a/timm/models/layers/non_local_attn.py +++ b/timm/models/layers/non_local_attn.py @@ -10,7 +10,6 @@ from torch.nn import functional as F from .conv_bn_act import ConvBnAct from .helpers import make_divisible -from timm.models.fx_helpers import fx_and class NonLocalAttn(nn.Module): @@ -96,7 +95,8 @@ class BilinearAttnTransform(nn.Module): return x def forward(self, x): - torch._assert(fx_and(x.shape[-1] % self.block_size == 0, x.shape[-2] % self.block_size == 0), '') + torch._assert(x.shape[-1] % self.block_size == 0, '') + torch._assert(x.shape[-2] % self.block_size == 0, '') B, C, H, W = x.shape out = self.conv1(x) rp = F.adaptive_max_pool2d(out, (self.block_size, 1)) diff --git a/timm/models/layers/patch_embed.py b/timm/models/layers/patch_embed.py index 157bc250..6a7facef 100644 --- a/timm/models/layers/patch_embed.py +++ b/timm/models/layers/patch_embed.py @@ -9,11 +9,7 @@ Hacked together by / Copyright 2020 Ross Wightman from torch import nn as nn from .helpers import to_2tuple -<<<<<<< HEAD from .trace_utils import _assert -======= -from timm.models.fx_helpers import fx_and ->>>>>>> Make all models FX traceable class PatchEmbed(nn.Module): diff --git a/timm/models/tnt.py b/timm/models/tnt.py index f9510487..92108fe5 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -12,7 +12,6 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.helpers import build_model_with_cfg -from timm.models.fx_helpers import fx_and from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers.helpers import to_2tuple from timm.models.registry import register_model @@ -138,7 +137,9 @@ class PixelEmbed(nn.Module): def forward(self, x, pixel_pos): B, C, H, W = x.shape - torch._assert(fx_and(H == self.img_size[0], W == self.img_size[1]), + torch._assert(H == self.img_size[0], + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") + torch._assert(W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") x = self.proj(x) x = self.unfold(x) From cf4561ca729c69bd9e5c600ca2dc40b7510c50b2 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 12 Aug 2021 15:31:02 +0100 Subject: [PATCH 06/13] Add FX based FeatureGraphNet capability --- timm/models/fx_features.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index 310cc465..3d253444 100644 --- 
a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -486,4 +486,4 @@ class FeatureGraphNet(nn.Module): self.graph_module = build_feature_graph_net(model, return_nodes) def forward(self, x): - return list(self.graph_module(x).values()) \ No newline at end of file + return list(self.graph_module(x).values()) From e051dce35451451b8ac7eee7b8abab38325a26b0 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 12 Aug 2021 15:31:24 +0100 Subject: [PATCH 07/13] Make all models FX traceable --- timm/models/layers/bottleneck_attn.py | 1 + timm/models/layers/halo_attn.py | 1 - timm/models/layers/non_local_attn.py | 1 + timm/models/tnt.py | 1 + 4 files changed, 3 insertions(+), 1 deletion(-) diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py index c56c5821..305f9de3 100644 --- a/timm/models/layers/bottleneck_attn.py +++ b/timm/models/layers/bottleneck_attn.py @@ -22,6 +22,7 @@ import torch.nn.functional as F from .helpers import to_2tuple, make_divisible from .weight_init import trunc_normal_ +from timm.models.fx_helpers import fx_and def rel_logits_1d(q, rel_k, permute_mask: List[int]): diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index babfcb06..ec93474f 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -24,7 +24,6 @@ import torch.nn.functional as F from .helpers import make_divisible from .weight_init import trunc_normal_ -from timm.models.fx_helpers import def rel_logits_1d(q, rel_k, permute_mask: List[int]): diff --git a/timm/models/layers/non_local_attn.py b/timm/models/layers/non_local_attn.py index f933ece2..5f83005c 100644 --- a/timm/models/layers/non_local_attn.py +++ b/timm/models/layers/non_local_attn.py @@ -10,6 +10,7 @@ from torch.nn import functional as F from .conv_bn_act import ConvBnAct from .helpers import make_divisible +from timm.models.fx_helpers import fx_and class NonLocalAttn(nn.Module): diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 92108fe5..298808c3 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -12,6 +12,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.helpers import build_model_with_cfg +from timm.models.fx_helpers import fx_and from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers.helpers import to_2tuple from timm.models.registry import register_model From b25ff9676848a25df5c87f489bcece89f216e749 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Fri, 12 Nov 2021 20:42:45 +0000 Subject: [PATCH 08/13] wip - pre-rebase --- tests/test_models.py | 61 +++- timm/models/cait.py | 8 +- timm/models/coat.py | 11 +- timm/models/convit.py | 10 +- timm/models/crossvit.py | 40 ++- timm/models/fx_features.py | 473 ++----------------------- timm/models/fx_helpers.py | 7 - timm/models/layers/bottleneck_attn.py | 8 +- timm/models/layers/evo_norm.py | 6 +- timm/models/layers/global_context.py | 3 +- timm/models/layers/halo_attn.py | 9 +- timm/models/layers/lambda_layer.py | 4 +- timm/models/layers/non_local_attn.py | 8 +- timm/models/layers/selective_kernel.py | 3 +- timm/models/layers/swin_attn.py | 183 ---------- timm/models/levit.py | 8 +- timm/models/nest.py | 19 +- timm/models/nfnet.py | 2 - timm/models/swin_transformer.py | 16 +- timm/models/tnt.py | 10 +- timm/models/twins.py | 12 +- timm/models/vgg.py | 4 +- timm/models/visformer.py | 4 +- timm/models/vision_transformer.py | 4 +- timm/models/xcit.py | 6 +- 25 files changed, 185 
insertions(+), 734 deletions(-) delete mode 100644 timm/models/fx_helpers.py delete mode 100644 timm/models/layers/swin_attn.py diff --git a/tests/test_models.py b/tests/test_models.py index 91fd3543..f7233ef3 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,10 +4,12 @@ import platform import os import fnmatch +from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer + import timm from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \ get_model_default_value -from timm.models.fx_features import NodePathTracer +from timm.models.fx_features import _leaf_modules, _autowrap_functions if hasattr(torch._C, '_jit_set_profiling_executor'): # legacy executor is too slow to compile large models for unit tests @@ -312,12 +314,14 @@ def test_model_forward_fx(model_name, batch_size): if max(input_size) > MAX_FWD_SIZE: pytest.skip("Fixed input size model > limit.") - tracer = NodePathTracer() - graph = tracer.trace(model) - model = torch.fx.GraphModule(model, graph) + train_nodes, eval_nodes = get_graph_node_names( + model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) + model = create_feature_extractor( + model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]], + tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) inputs = torch.randn((batch_size, *input_size)) - outputs = model(inputs) + outputs = model(inputs)[eval_nodes[-1]] assert outputs.shape[0] == batch_size assert not torch.isnan(outputs).any(), 'Output included NaNs' @@ -336,12 +340,30 @@ def test_model_backward_fx(model_name, batch_size): model.train() num_params = sum([x.numel() for x in model.parameters()]) - tracer = NodePathTracer() + input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE) + if max(input_size) > MAX_FWD_SIZE: + pytest.skip("Fixed input size model > limit.") + + # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode + # If so, we need to return all of them in order to check all grads + # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output + # node. 
Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names + tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions)) graph = tracer.trace(model) - model = torch.fx.GraphModule(model, graph) + graph_nodes = list(reversed(graph.nodes)) + output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()] + graph_node_names = [n.name for n in graph_nodes] + output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names] + train_nodes, eval_nodes = get_graph_node_names( + model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) + train_return_nodes = [train_nodes[ix] for ix in output_node_indices] + + model = create_feature_extractor( + model, train_return_nodes=train_return_nodes, eval_return_nodes=[eval_nodes[-1]], + tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) inputs = torch.randn((batch_size, *input_size)) - outputs = model(inputs) + outputs = tuple(model(inputs).values()) if isinstance(outputs, tuple): outputs = torch.cat(outputs) outputs.mean().backward() @@ -354,9 +376,14 @@ def test_model_backward_fx(model_name, batch_size): assert not torch.isnan(outputs).any(), 'Output included NaNs' +EXCLUDE_FX_JIT_FILTERS = [ + 'beit_*' # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow +] + @pytest.mark.timeout(120) @pytest.mark.parametrize( - 'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True)) + 'model_name', list_models( + exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True)) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_fx_torchscript(model_name, batch_size): """Symbolically trace each model, script it, and run single forward pass""" @@ -368,12 +395,18 @@ def test_model_forward_fx_torchscript(model_name, batch_size): model = create_model(model_name, pretrained=False) model.eval() - tracer = NodePathTracer() - graph = tracer.trace(model) - model = torch.fx.GraphModule(model, graph) + input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE) + if max(input_size) > MAX_FWD_SIZE: + pytest.skip("Fixed input size model > limit.") + + train_nodes, eval_nodes = get_graph_node_names( + model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) + model = create_feature_extractor( + model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]], + tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) model = torch.jit.script(model) - outputs = model(torch.randn((batch_size, *input_size))) + outputs = model(torch.randn((batch_size, *input_size)))[train_nodes[-1]] assert outputs.shape[0] == batch_size - assert not torch.isnan(outputs).any(), 'Output included NaNs' \ No newline at end of file + assert not torch.isnan(outputs).any(), 'Output included NaNs' diff --git a/timm/models/cait.py b/timm/models/cait.py index b6a18ce3..69b4ba06 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -95,11 +95,11 @@ class ClassAttn(nn.Module): q = q * self.scale v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) - attn = torch.matmul(q, k.transpose(-2, -1)) + attn = (q @ k.transpose(-2, -1)) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x_cls = 
torch.matmul(attn, v).transpose(1, 2).reshape(B, 1, C) + x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C) x_cls = self.proj(x_cls) x_cls = self.proj_drop(x_cls) @@ -158,7 +158,7 @@ class TalkingHeadAttn(nn.Module): qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] - attn = torch.matmul(q, k.transpose(-2, -1)) + attn = (q @ k.transpose(-2, -1)) attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) @@ -167,7 +167,7 @@ class TalkingHeadAttn(nn.Module): attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) attn = self.attn_drop(attn) - x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x diff --git a/timm/models/coat.py b/timm/models/coat.py index 69b1bd9f..dca655ea 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -19,6 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, overlay_external_default_cfg from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ from .registry import register_model +from .layers.trace_utils import _assert __all__ = [ @@ -105,7 +106,7 @@ class ConvRelPosEnc(nn.Module): def forward(self, q, v, size: Tuple[int, int]): B, h, N, Ch = q.shape H, W = size - torch._assert(N == 1 + H * W, '') + _assert(N == 1 + H * W, '') # Convolutional relative position encoding. q_img = q[:, :, 1:, :] # [B, h, H*W, Ch] @@ -149,8 +150,8 @@ class FactorAtt_ConvRelPosEnc(nn.Module): # Factorized attention. k_softmax = k.softmax(dim=2) - factor_att = torch.matmul(k_softmax.transpose(-1, -2), v) - factor_att = torch.matmul(q, factor_att) + factor_att = k_softmax.transpose(-1, -2) @ v + factor_att = q @ factor_att # Convolutional relative position encoding. crpe = self.crpe(q, v, size=size) # [B, h, N, Ch] @@ -177,7 +178,7 @@ class ConvPosEnc(nn.Module): def forward(self, x, size: Tuple[int, int]): B, N, C = x.shape H, W = size - torch._assert(N == 1 + H * W, '') + _assert(N == 1 + H * W, '') # Extract CLS token and image tokens. cls_token, img_tokens = x[:, :1], x[:, 1:] # [B, 1, C], [B, H*W, C] @@ -275,7 +276,7 @@ class ParallelBlock(nn.Module): """ Feature map interpolation. 
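The `_assert` substitutions in this patch (and the splitting of combined `assert ... and ...` checks elsewhere in it) follow from the same tracing constraint: a Python `assert` over an expression built from Proxies needs `bool()` on a Proxy, which FX cannot provide, and `and` between two Proxy conditions hits the same wall. `torch._assert` is symbolically traceable, so each condition is recorded as its own graph node. A minimal sketch (shapes and messages are made up):

    import torch
    from torch import fx

    class SizeCheck(torch.nn.Module):
        def forward(self, x):
            H, W = x.shape[-2], x.shape[-1]
            # `assert H == 4 and W == 4` would abort tracing;
            # each condition gets its own traceable assert instead:
            torch._assert(H == 4, 'unexpected H')
            torch._assert(W == 4, 'unexpected W')
            return x

    gm = fx.symbolic_trace(SizeCheck())
    gm(torch.randn(1, 3, 4, 4))  # the checks stay in the graph and still run at call time

timm routes the same check through its `_assert` helper in `layers.trace_utils`; plain `torch._assert` is shown here only to keep the sketch self-contained.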
""" B, N, C = x.shape H, W = size - torch._assert(N == 1 + H * W, '') + _assert(N == 1 + H * W, '') cls_token = x[:, :1, :] img_tokens = x[:, 1:, :] diff --git a/timm/models/convit.py b/timm/models/convit.py index 603548f9..d2f69b68 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -57,7 +57,7 @@ default_cfgs = { } -@register_leaf_module # FX can't symbolically trace control flow in forward method +@register_leaf_module # reason: FX can't symbolically trace control flow in forward method class GPSA(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., locality_strength=1.): @@ -84,7 +84,7 @@ class GPSA(nn.Module): self.rel_indices = self.get_rel_indices(N) attn = self.get_attention(x) v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) - x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x @@ -95,7 +95,7 @@ class GPSA(nn.Module): q, k = qk[0], qk[1] pos_score = self.rel_indices.expand(B, -1, -1, -1) pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2) - patch_score = torch.matmul(q, k.transpose(-2, -1)) * self.scale + patch_score = (q @ k.transpose(-2, -1)) * self.scale patch_score = patch_score.softmax(dim=-1) pos_score = pos_score.softmax(dim=-1) @@ -180,11 +180,11 @@ class MHSA(nn.Module): qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] - attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index 6e0160f9..3ba5b4c7 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -22,6 +22,7 @@ NOTE: model names have been renamed from originals to represent actual input res Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py """ +from typing import Tuple import torch import torch.nn as nn @@ -31,8 +32,9 @@ from functools import partial from typing import List from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_autowrap_function from .helpers import build_model_with_cfg -from .layers import DropPath, to_2tuple, trunc_normal_ +from .layers import DropPath, to_2tuple, trunc_normal_, _assert from .registry import register_model from .vision_transformer import Mlp, Block @@ -116,8 +118,10 @@ class PatchEmbed(nn.Module): def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
+ _assert(H == self.img_size[0], + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") + _assert(W == self.img_size[1], + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") x = self.proj(x).flatten(2).transpose(1, 2) return x @@ -255,6 +259,27 @@ def _compute_num_patches(img_size, patches): return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)] +@register_autowrap_function +def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False): # annotations for torchscript + """ + Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing. + Args: + x (Tensor): input image + ss (tuple[int, int]): height and width to scale to + crop_scale (bool): whether to crop instead of interpolate to achieve the desired scale. Defaults to False + Returns: + Tensor: the "scaled" image batch tensor + """ + H, W = x.shape[-2:] + if H != ss[0] or W != ss[1]: + if crop_scale and ss[0] <= H and ss[1] <= W: + cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.)) + x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]] + else: + x = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False) + return x + + class CrossViT(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ @@ -342,17 +367,12 @@ class CrossViT(nn.Module): range(self.num_branches)]) def forward_features(self, x): - B, C, H, W = x.shape + B = x.shape[0] xs = [] for i, patch_embed in enumerate(self.patch_embed): x_ = x ss = self.img_size_scaled[i] - if H != ss[0] or W != ss[1]: - if self.crop_scale and ss[0] <= H and ss[1] <= W: - cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.)) - x_ = x_[:, :, cu:cu + ss[0], cl:cl + ss[1]] - else: - x_ = torch.nn.functional.interpolate(x_, size=ss, mode='bicubic', align_corners=False) + x_ = scale_image(x_, ss, self.crop_scale) x_ = patch_embed(x_) cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script cls_tokens = cls_tokens.expand(B, -1, -1) diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index 3d253444..c8f296d4 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -1,42 +1,31 @@ """ PyTorch FX Based Feature Extraction Helpers -An extension/alternative to timm.models.features making use of PyTorch FX. Here, the idea is to: - 1. Symbolically trace a model producing a graph based intermediate representation (PyTorch FX functionality with - some custom tweaks) - 2. Identify desired feature extraction nodes and reconfigure them as output nodes while deleting all unecessary - nodes. (custom - inspired by https://github.com/pytorch/vision/pull/3597) - 3. 
Write the resulting graph into a GraphModule (PyTorch FX functionality) -Copyright 2021 Alexander Soare +Using https://pytorch.org/vision/stable/feature_extraction.html """ -from typing import Callable, Dict, Union, List, Optional -import math -from collections import OrderedDict -from pprint import pprint -from inspect import ismethod -import re -import warnings -from copy import deepcopy -from itertools import chain - -import torch +from typing import Callable from torch import nn -from torch import fx -import torch.nn.functional as F -from torch.fx.graph_module import _copy_attr from .features import _get_feature_info -from .fx_helpers import fx_float_to_int -# Layers we went to treat as leaf modules for FeatureGraphNet -from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame -from .layers import GatherExcite, DropPath +try: + from torchvision.models.feature_extraction import create_feature_extractor +except ImportError: + pass + +# Layers we went to treat as leaf modules +from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath from .layers.non_local_attn import BilinearAttnTransform from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame - -# These modules will not be traced through. +# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here +# BUT modules from timm.models should use the registration mechanism below _leaf_modules = { - Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, GatherExcite, DropPath, - BilinearAttnTransform, MaxPool2dSame, AvgPool2dSame + BatchNormAct2d, # reason: flow control for jit scripting + BilinearAttnTransform, # reason: flow control t <= 1 + BlurPool2d, # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1] + # Reason: get_same_padding has a max which raises a control flow error + Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame, + CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]) + DropPath, # reason: TypeError: rand recieved Proxy in `size` argument } try: @@ -54,425 +43,16 @@ def register_leaf_module(module: nn.Module): return module -# These functions will not be traced through -_autowrap_functions=(fx_float_to_int,) - - -class TimmTracer(fx.Tracer): - """ - Temporary bridge from torch.fx.Tracer to include any general workarounds required to make FX work for us - """ - def __init__(self, autowrap_modules=(math, ), autowrap_functions=(), enable_cpatching=False): - super().__init__(autowrap_modules=autowrap_modules, enable_cpatching=enable_cpatching) - # FIXME: This is a workaround pending on a PyTorch PR https://github.com/pytorch/pytorch/pull/62106 - self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions])) - - def create_node(self, kind, target, args, kwargs, name=None, type_expr=None): - # FIXME: This is a workaround pending on a PyTorch PR https://github.com/pytorch/pytorch/pull/62095 - if target == F.pad: - kwargs['value'] = float(kwargs['value']) - return super().create_node(kind, target, args, kwargs, name=name, type_expr=type_expr) - - -class LeafNodeTracer(TimmTracer): - """ - Account for desired leaf nodes according to _leaf_modules and _autowrap functions - """ - def __init__(self): - super().__init__(autowrap_functions=_autowrap_functions) - - def is_leaf_module(self, m: nn.Module, module_qualname: str) -> bool: - if 
isinstance(m, tuple(_leaf_modules)): - return True - return super().is_leaf_module(m, module_qualname) - - -def _is_subseq(x, y): - """Check if y is a subseqence of x - https://stackoverflow.com/a/24017747/4391249 - """ - iter_x = iter(x) - return all(any(x_item == y_item for x_item in iter_x) for y_item in y) - - -# Taken from https://github.com/pytorch/examples/blob/master/fx/module_tracer.py with modifications for storing -# qualified names for all Nodes, not just top-level Modules -class NodePathTracer(LeafNodeTracer): - """ - NodePathTracer is an FX tracer that, for each operation, also records the - qualified name of the Node from which the operation originated. A - qualified name here is a `.` seperated path walking the hierarchy from top - level module down to leaf operation or leaf module. The name of the top - level module is not included as part of the qualified name. For example, - if we trace a module who's forward method applies a ReLU module, the - qualified name for that node will simply be 'relu'. - - Some notes on the specifics: - - Nodes are recorded to `self.node_to_qualname` which is a dictionary - mapping a given Node object to its qualified name. - - Nodes are recorded in the order which they are executed during - tracing. - - When a duplicate qualified name is encountered, a suffix of the form - _{int} is added. The counter starts from 1. - """ - def __init__(self, *args, **kwargs): - super(NodePathTracer, self).__init__(*args, **kwargs) - # Track the qualified name of the Node being traced - self.current_module_qualname = '' - # A map from FX Node to the qualified name - self.node_to_qualname = OrderedDict() - - def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs): - """ - Override of `fx.Tracer.call_module` - This override: - 1) Stores away the qualified name of the caller for restoration later - 2) Adds the qualified name of the caller to - `current_module_qualname` for retrieval by `create_proxy` - 3) Once a leaf module is reached, calls `create_proxy` - 4) Restores the caller's qualified name into current_module_qualname - """ - old_qualname = self.current_module_qualname - try: - module_qualname = self.path_of_module(m) - self.current_module_qualname = module_qualname - if not self.is_leaf_module(m, module_qualname): - out = forward(*args, **kwargs) - return out - return self.create_proxy('call_module', module_qualname, args, kwargs) - finally: - self.current_module_qualname = old_qualname - - def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs, - name=None, type_expr=None) -> fx.proxy.Proxy: - """ - Override of `Tracer.create_proxy`. 
This override intercepts the recording - of every operation and stores away the current traced module's qualified - name in `node_to_qualname` - """ - proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr) - self.node_to_qualname[proxy.node] = self._get_node_qualname( - self.current_module_qualname, proxy.node) - return proxy - - def _get_node_qualname( - self, module_qualname: str, node: fx.node.Node) -> str: - node_qualname = module_qualname - if node.op == 'call_module': - # Node terminates in a leaf module so the module_qualname is a - # complete description of the node - for existing_qualname in reversed(self.node_to_qualname.values()): - # Check to see if existing_qualname is of the form - # {node_qualname} or {node_qualname}_{int} - if re.match(rf'{node_qualname}(_[0-9]+)?$', - existing_qualname) is not None: - postfix = existing_qualname.replace(node_qualname, '') - if len(postfix): - # Existing_qualname is of the form {node_qualname}_{int} - next_index = int(postfix[1:]) + 1 - else: - # existing_qualname is of the form {node_qualname} - next_index = 1 - node_qualname += f'_{next_index}' - break - else: - # Node terminates in non- leaf module so the node name needs to be - # appended - if len(node_qualname) > 0: - # Only append '.' if we are deeper than the top level module - node_qualname += '.' - node_qualname += str(node) - return node_qualname +# Functions we want to autowrap (treat them as leaves) +_autowrap_functions = set() -def _warn_graph_differences( - train_tracer: NodePathTracer, eval_tracer: NodePathTracer): +def register_autowrap_function(func: Callable): """ - Utility function for warning the user if there are differences between - the train graph and the eval graph. + Decorator for functions which ought not to be traced through """ - train_nodes = list(train_tracer.node_to_qualname.values()) - eval_nodes = list(eval_tracer.node_to_qualname.values()) - - if len(train_nodes) == len(eval_nodes) and [ - t == e for t, e in zip(train_nodes, eval_nodes)]: - return - - suggestion_msg = ( - "When choosing nodes for feature extraction, you may need to specify " - "output nodes for train and eval mode separately") - - if _is_subseq(train_nodes, eval_nodes): - msg = ("NOTE: The nodes obtained by tracing the model in eval mode " - "are a subsequence of those obtained in train mode. ") - elif _is_subseq(eval_nodes, train_nodes): - msg = ("NOTE: The nodes obtained by tracing the model in train mode " - "are a subsequence of those obtained in eval mode. ") - else: - msg = ("The nodes obtained by tracing the model in train mode " - "are different to those obtained in eval mode. ") - warnings.warn(msg + suggestion_msg) - - -def print_graph_node_qualified_names( - model: nn.Module, tracer_kwargs: Dict = {}): - """ - Dev utility to prints nodes in order of execution. Useful for choosing - nodes for a FeatureGraphNet design. There are two reasons that qualified - node names can't easily be read directly from the code for a model: - 1. Not all submodules are traced through. Modules from `torch.nn` all - fall within this category. - 2. Node qualified names that occur more than once in the graph get a - `_{counter}` postfix. - The model will be traced twice: once in train mode, and once in eval mode. - If there are discrepancies between the graphs produced, both sets will - be printed and the user will be warned. 
- - Args: - model (nn.Module): model on which we will extract the features - tracer_kwargs (Dict): a dictionary of keywork arguments for - `NodePathTracer` (which passes them onto it's parent class - `torch.fx.Tracer`). - """ - train_tracer = NodePathTracer(**tracer_kwargs) - train_tracer.trace(model.train()) - eval_tracer = NodePathTracer(**tracer_kwargs) - eval_tracer.trace(model.eval()) - train_nodes = list(train_tracer.node_to_qualname.values()) - eval_nodes = list(eval_tracer.node_to_qualname.values()) - if len(train_nodes) == len(eval_nodes) and [ - t == e for t, e in zip(train_nodes, eval_nodes)]: - # Nodes are aligned in train vs eval mode - pprint(list(train_tracer.node_to_qualname.values())) - return - print("Nodes from train mode:") - pprint(list(train_tracer.node_to_qualname.values())) - print() - print("Nodes from eval mode:") - pprint(list(eval_tracer.node_to_qualname.values())) - print() - _warn_graph_differences(train_tracer, eval_tracer) - - -class DualGraphModule(fx.GraphModule): - """ - A derivative of `fx.GraphModule`. Differs in the following ways: - - Requires a train and eval version of the underlying graph - - Copies submodules according to the nodes of both train and eval graphs. - - Calling train(mode) switches between train graph and eval graph. - """ - def __init__(self, - root: torch.nn.Module, - train_graph: fx.Graph, - eval_graph: fx.Graph, - class_name: str = 'GraphModule'): - """ - Args: - root (torch.nn.Module): module from which the copied module - hierarchy is built - train_graph (Graph): the graph that should be used in train mode - eval_graph (Graph): the graph that should be used in eval mode - """ - super(fx.GraphModule, self).__init__() - - self.__class__.__name__ = class_name - - self.train_graph = train_graph - self.eval_graph = eval_graph - - # Copy all get_attr and call_module ops (indicated by BOTH train and - # eval graphs) - for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)): - if node.op in ['get_attr', 'call_module']: - assert isinstance(node.target, str) - _copy_attr(root, self, node.target) - - # eval mode by default - self.eval() - self.graph = eval_graph - - # (borrowed from fx.GraphModule): - # Store the Tracer class responsible for creating a Graph separately as part of the - # GraphModule state, except when the Tracer is defined in a local namespace. - # Locally defined Tracers are not pickleable. This is needed because torch.package will - # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer - # to re-create the Graph during deserialization. - # TODO uncomment this when https://github.com/pytorch/pytorch/pull/63121 is available - # assert self.eval_graph._tracer_cls == self.train_graph._tracer_cls, \ - # "Train mode and eval mode should use the same tracer class" - # self._tracer_cls = None - # if self.graph._tracer_cls and '' not in self.graph._tracer_cls.__qualname__: - # self._tracer_cls = self.graph._tracer_cls - - def train(self, mode=True): - """ - Swap out the graph depending on the training mode. - NOTE this should be safe when calling model.eval() because that just - calls this with mode == False. 
- """ - if mode: - self.graph = self.train_graph - else: - self.graph = self.eval_graph - return super().train(mode=mode) - - -def build_feature_graph_net( - model: nn.Module, - return_nodes: Union[List[str], Dict[str, str]], - train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, - eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, - tracer_kwargs: Dict = {}) -> fx.GraphModule: - """ - Creates a new graph module that returns intermediate nodes from a given - model as dictionary with user specified keys as strings, and the requested - outputs as values. This is achieved by re-writing the computation graph of - the model via FX to return the desired nodes as outputs. All unused nodes - are removed, together with their corresponding parameters. - - A note on node specification: A node qualified name is specified as a `.` - seperated path walking the hierarchy from top level module down to leaf - operation or leaf module. For instance `blocks.5.3.bn1`. The keys of the - `return_nodes` argument should point to either a node's qualified name, - or some truncated version of it. For example, one could provide `blocks.5` - as a key, and the last node with that prefix will be selected. - `print_graph_node_qualified_names` is a useful helper function for getting - a list of qualified names of a model. - - An attempt is made to keep all non-parametric properties of the original - model, but existing properties of the constructed `GraphModule` are not - overwritten. - - Args: - model (nn.Module): model on which we will extract the features - return_nodes (Union[List[name], Dict[name, new_name]])): either a list - or a dict containing the names (or partial names - see note above) - of the nodes for which the activations will be returned. If it is - a `Dict`, the keys are the qualified node names, and the values - are the user-specified keys for the graph module's returned - dictionary. If it is a `List`, it is treated as a `Dict` mapping - node specification strings directly to output names. - tracer_kwargs (Dict): a dictionary of keywork arguments for - `NodePathTracer` (which passes them onto it's parent class - `torch.fx.Tracer`). 
- - Examples:: - - >>> model = torchvision.models.resnet18() - >>> # extract layer1 and layer3, giving as names `feat1` and feat2` - >>> graph_module = torchvision.models._utils.build_feature_graph_net(m, - >>> {'layer1': 'feat1', 'layer3': 'feat2'}) - >>> out = graph_module(torch.rand(1, 3, 224, 224)) - >>> print([(k, v.shape) for k, v in out.items()]) - >>> [('feat1', torch.Size([1, 64, 56, 56])), - >>> ('feat2', torch.Size([1, 256, 14, 14]))] - - """ - is_training = model.training - - if isinstance(return_nodes, list): - return_nodes = {n: n for n in return_nodes} - return_nodes = {str(k): str(v) for k, v in return_nodes.items()} - - assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), \ - ("If any of `train_return_nodes` and `eval_return_nodes` are " - "specified, then both should be specified") - - if train_return_nodes is None: - train_return_nodes = deepcopy(return_nodes) - eval_return_nodes = deepcopy(return_nodes) - - # Repeat the tracing and graph rewriting for train and eval mode - tracers = {} - graphs = {} - return_nodes = { - 'train': train_return_nodes, - 'eval': eval_return_nodes - } - for mode in ['train', 'eval']: - if mode == 'train': - model.train() - elif mode == 'eval': - model.eval() - - # Instantiate our NodePathTracer and use that to trace the model - tracer = NodePathTracer(**tracer_kwargs) - graph = tracer.trace(model) - - name = model.__class__.__name__ if isinstance( - model, nn.Module) else model.__name__ - graph_module = fx.GraphModule(tracer.root, graph, name) - - available_nodes = [f'{v}.{k}' for k, v in tracer.node_to_qualname.items()] - # FIXME We don't know if we should expect this to happen - assert len(set(available_nodes)) == len(available_nodes), \ - "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues" - # Check that all outputs in return_nodes are present in the model - for query in return_nodes[mode].keys(): - if not any([m.startswith(query) for m in available_nodes]): - raise ValueError(f"return_node: {query} is not present in model") - - # Remove existing output nodes (train mode) - orig_output_nodes = [] - for n in reversed(graph_module.graph.nodes): - if n.op == "output": - orig_output_nodes.append(n) - assert len(orig_output_nodes) - for n in orig_output_nodes: - graph_module.graph.erase_node(n) - - # Find nodes corresponding to return_nodes and make them into output_nodes - nodes = [n for n in graph_module.graph.nodes] - output_nodes = OrderedDict() - for n in reversed(nodes): - if 'tensor_constant' in str(n): - # NOTE Without this control flow we would get a None value for - # `module_qualname = tracer.node_to_qualname.get(n)`. - # On the other hand, we can safely assume that we'll never need to - # get this as an interesting intermediate node. 
- continue - module_qualname = tracer.node_to_qualname.get(n) - for query in return_nodes[mode]: - depth = query.count('.') - if '.'.join(module_qualname.split('.')[:depth + 1]) == query: - output_nodes[return_nodes[mode][query]] = n - return_nodes[mode].pop(query) - break - output_nodes = OrderedDict(reversed(list(output_nodes.items()))) - - # And add them in the end of the graph - with graph_module.graph.inserting_after(nodes[-1]): - graph_module.graph.output(output_nodes) - - # Remove unused modules / parameters - graph_module.graph.eliminate_dead_code() - graph_module.recompile() - - # Keep track of the tracer and graph so we can choose the main one - tracers[mode] = tracer - graphs[mode] = graph - - # Warn user if there are any discrepancies between the graphs of the - # train and eval modes - _warn_graph_differences(tracers['train'], tracers['eval']) - - # Build the final graph module - graph_module = DualGraphModule( - model, graphs['train'], graphs['eval'], class_name=name) - - # Keep non-parameter model properties for reference - for attr_str in model.__dir__(): - attr = getattr(model, attr_str) - if (not attr_str.startswith('_') - and attr_str not in graph_module.__dir__() - and not ismethod(attr) - and not isinstance(attr, (nn.Module, nn.Parameter))): - setattr(graph_module, attr_str, attr) - - # Restore original training mode - graph_module.train(is_training) - - return graph_module + _autowrap_functions.add(func) + return func class FeatureGraphNet(nn.Module): @@ -483,7 +63,10 @@ class FeatureGraphNet(nn.Module): assert len(out_map) == len(out_indices) return_nodes = {info['module']: out_map[i] if out_map is not None else info['module'] for i, info in enumerate(self.feature_info) if i in out_indices} - self.graph_module = build_feature_graph_net(model, return_nodes) + self.graph_module = create_feature_extractor( + model, return_nodes, + tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) def forward(self, x): return list(self.graph_module(x).values()) + \ No newline at end of file diff --git a/timm/models/fx_helpers.py b/timm/models/fx_helpers.py deleted file mode 100644 index 878ba381..00000000 --- a/timm/models/fx_helpers.py +++ /dev/null @@ -1,7 +0,0 @@ - -def fx_float_to_int(x: float) -> int: - """ - Symbolic tracing helper to substitute for inbuilt `int`. 
- Hint: Inbuilt `int` can't accept an argument of type `Proxy` - """ - return int(x) \ No newline at end of file diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py index 305f9de3..c3db464e 100644 --- a/timm/models/layers/bottleneck_attn.py +++ b/timm/models/layers/bottleneck_attn.py @@ -22,7 +22,7 @@ import torch.nn.functional as F from .helpers import to_2tuple, make_divisible from .weight_init import trunc_normal_ -from timm.models.fx_helpers import fx_and +from .trace_utils import _assert def rel_logits_1d(q, rel_k, permute_mask: List[int]): @@ -37,7 +37,7 @@ def rel_logits_1d(q, rel_k, permute_mask: List[int]): permute_mask: permute output dim according to this """ B, H, W, dim = q.shape - x = torch.matmul(q, rel_k.transpose(-1, -2)) + x = (q @ rel_k.transpose(-1, -2)) x = x.reshape(-1, W, 2 * W -1) # pad to shift from relative to absolute indexing @@ -134,8 +134,8 @@ class BottleneckAttn(nn.Module): def forward(self, x): B, C, H, W = x.shape - torch._assert(H == self.pos_embed.height, '') - torch._assert(W == self.pos_embed.width, '') + _assert(H == self.pos_embed.height, '') + _assert(W == self.pos_embed.width, '') x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py index 02aa0a0c..ecc5fb61 100644 --- a/timm/models/layers/evo_norm.py +++ b/timm/models/layers/evo_norm.py @@ -12,6 +12,8 @@ Hacked together by / Copyright 2020 Ross Wightman import torch import torch.nn as nn +from .trace_utils import _assert + class EvoNormBatch2d(nn.Module): def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None): @@ -72,9 +74,9 @@ class EvoNormSample2d(nn.Module): nn.init.ones_(self.v) def forward(self, x): - torch._assert(x.dim() == 4, 'expected 4D input') + _assert(x.dim() == 4, 'expected 4D input') B, C, H, W = x.shape - torch._assert(C % self.groups == 0, '') + _assert(C % self.groups == 0, '') if self.apply_act: n = x * (x * self.v).sigmoid() x = x.reshape(B, self.groups, -1) diff --git a/timm/models/layers/global_context.py b/timm/models/layers/global_context.py index a0bb8a43..de7fb5c1 100644 --- a/timm/models/layers/global_context.py +++ b/timm/models/layers/global_context.py @@ -7,7 +7,6 @@ Official code consulted as reference: https://github.com/xvjiarui/GCNet Hacked together by / Copyright 2021 Ross Wightman """ -import torch from torch import nn as nn import torch.nn.functional as F @@ -53,7 +52,7 @@ class GlobalContext(nn.Module): if self.conv_attn is not None: attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) - context = torch.matmul(x.reshape(B, C, H * W).unsqueeze(1), attn) + context = x.reshape(B, C, H * W).unsqueeze(1) @ attn context = context.view(B, C, 1, 1) else: context = x.mean(dim=(2, 3), keepdim=True) diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index ec93474f..f2ac64f8 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -16,7 +16,7 @@ The attention mechanism works but it's slow as implemented. 
Hacked together by / Copyright 2021 Ross Wightman """ -from typing import Tuple, List +from typing import List import torch from torch import nn @@ -24,6 +24,7 @@ import torch.nn.functional as F from .helpers import make_divisible from .weight_init import trunc_normal_ +from .trace_utils import _assert def rel_logits_1d(q, rel_k, permute_mask: List[int]): @@ -41,7 +42,7 @@ def rel_logits_1d(q, rel_k, permute_mask: List[int]): rel_size = rel_k.shape[0] win_size = (rel_size + 1) // 2 - x = torch.matmul(q, rel_k.transpose(-1, -2)) + x = (q @ rel_k.transpose(-1, -2)) x = x.reshape(-1, W, rel_size) # pad to shift from relative to absolute indexing @@ -167,8 +168,8 @@ class HaloAttn(nn.Module): def forward(self, x): B, C, H, W = x.shape - torch._assert(H % self.block_size == 0, '') - torch._assert(W % self.block_size == 0, '') + _assert(H % self.block_size == 0, '') + _assert(W % self.block_size == 0, '') num_h_blocks = H // self.block_size num_w_blocks = W // self.block_size num_blocks = num_h_blocks * num_w_blocks diff --git a/timm/models/layers/lambda_layer.py b/timm/models/layers/lambda_layer.py index 058426b6..e50b43c8 100644 --- a/timm/models/layers/lambda_layer.py +++ b/timm/models/layers/lambda_layer.py @@ -116,8 +116,8 @@ class LambdaLayer(nn.Module): v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M - content_lam = torch.matmul(k, v) # B, K, V - content_out = torch.matmul(q, content_lam.unsqueeze(1)) # B, num_heads, M, V + content_lam = k @ v # B, K, V + content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V if self.pos_emb is None: position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K diff --git a/timm/models/layers/non_local_attn.py b/timm/models/layers/non_local_attn.py index 5f83005c..881fa36d 100644 --- a/timm/models/layers/non_local_attn.py +++ b/timm/models/layers/non_local_attn.py @@ -10,7 +10,7 @@ from torch.nn import functional as F from .conv_bn_act import ConvBnAct from .helpers import make_divisible -from timm.models.fx_helpers import fx_and +from .trace_utils import _assert class NonLocalAttn(nn.Module): @@ -84,7 +84,7 @@ class BilinearAttnTransform(nn.Module): def resize_mat(self, x, t: int): B, C, block_size, block_size1 = x.shape - torch._assert(block_size == block_size1, '') + _assert(block_size == block_size1, '') if t <= 1: return x x = x.view(B * C, -1, 1, 1) @@ -96,8 +96,8 @@ class BilinearAttnTransform(nn.Module): return x def forward(self, x): - torch._assert(x.shape[-1] % self.block_size == 0, '') - torch._assert(x.shape[-2] % self.block_size == 0, '') + _assert(x.shape[-1] % self.block_size == 0, '') + _assert(x.shape[-2] % self.block_size == 0, '') B, C, H, W = x.shape out = self.conv1(x) rp = F.adaptive_max_pool2d(out, (self.block_size, 1)) diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py index 69aca86b..1aeb9294 100644 --- a/timm/models/layers/selective_kernel.py +++ b/timm/models/layers/selective_kernel.py @@ -9,6 +9,7 @@ from torch import nn as nn from .conv_bn_act import ConvBnAct from .helpers import make_divisible +from .trace_utils import _assert def _kernel_valid(k): @@ -34,7 +35,7 @@ class SelectiveKernelAttn(nn.Module): self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) def forward(self, x): - torch._assert(x.shape[1] == self.num_paths, '') + _assert(x.shape[1] == self.num_paths, '') x = x.sum(1).mean((2, 3), keepdim=True) x = self.fc_reduce(x) 
x = self.bn(x) diff --git a/timm/models/layers/swin_attn.py b/timm/models/layers/swin_attn.py deleted file mode 100644 index 2a3731f3..00000000 --- a/timm/models/layers/swin_attn.py +++ /dev/null @@ -1,183 +0,0 @@ -""" Shifted Window Attn - -This is a WIP experiment to apply windowed attention from the Swin Transformer -to a stand-alone module for use as an attn block in conv nets. - -Based on original swin window code at https://github.com/microsoft/Swin-Transformer -Swin Transformer paper: https://arxiv.org/pdf/2103.14030.pdf -""" -from typing import Optional - -import torch -import torch.nn as nn - -from .drop import DropPath -from .helpers import to_2tuple -from .weight_init import trunc_normal_ -from timm.models.fx_helpers import fx_float_to_int - - -def window_partition(x, win_size: int): - """ - Args: - x: (B, H, W, C) - win_size (int): window size - - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // win_size, win_size, W // win_size, win_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C) - return windows - - -def window_reverse(windows, win_size: int, H: int, W: int): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - win_size (int): Window size - H (int): Height of image - W (int): Width of image - - Returns: - x: (B, H, W, C) - """ - B = fx_float_to_int(windows.shape[0] / (H * W / win_size / win_size)) - x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. - win_size (int): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - attn_drop (float, optional): Dropout ratio of attention weight. 
Default: 0.0 - """ - - def __init__( - self, dim, dim_out=None, feat_size=None, stride=1, win_size=8, shift_size=None, num_heads=8, - qkv_bias=True, attn_drop=0.): - - super().__init__() - self.dim_out = dim_out or dim - self.feat_size = to_2tuple(feat_size) - self.win_size = win_size - self.shift_size = shift_size or win_size // 2 - if min(self.feat_size) <= win_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.win_size = min(self.feat_size) - assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-window_size" - self.num_heads = num_heads - head_dim = self.dim_out // num_heads - self.scale = head_dim ** -0.5 - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.feat_size - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = ( - slice(0, -self.win_size), - slice(-self.win_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = ( - slice(0, -self.win_size), - slice(-self.win_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - mask_windows = window_partition(img_mask, self.win_size) # num_win, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.win_size * self.win_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - else: - attn_mask = None - self.register_buffer("attn_mask", attn_mask) - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - # 2 * Wh - 1 * 2 * Ww - 1, nH - torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads)) - trunc_normal_(self.relative_position_bias_table, std=.02) - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.win_size) - coords_w = torch.arange(self.win_size) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.win_size - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.win_size - 1 - relative_coords[:, :, 0] *= 2 * self.win_size - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.softmax = nn.Softmax(dim=-1) - self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() - - def reset_parameters(self): - trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) - trunc_normal_(self.relative_position_bias_table, std=.02) - - def forward(self, x): - B, C, H, W = x.shape - x = x.permute(0, 2, 3, 1) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - win_size_sq = self.win_size * self.win_size - x_windows = window_partition(shifted_x, self.win_size) # num_win * B, window_size, window_size, C - x_windows = x_windows.view(-1, win_size_sq, C) # num_win * B, window_size*window_size, C - BW, N, _ = x_windows.shape - - qkv = 
self.qkv(x_windows) - qkv = qkv.reshape(BW, N, 3, self.num_heads, self.dim_out // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] - q = q * self.scale - attn = torch.matmul(q, k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index.view(-1)].view(win_size_sq, win_size_sq, -1) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh * Ww, Wh * Ww - attn = attn + relative_position_bias.unsqueeze(0) - if self.attn_mask is not None: - num_win = self.attn_mask.shape[0] - attn = attn.view(B, num_win, self.num_heads, N, N) + self.attn_mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - attn = self.attn_drop(attn) - - x = torch.matmul(attn, v).transpose(1, 2).reshape(BW, N, self.dim_out) - - # merge windows - x = x.view(-1, self.win_size, self.win_size, self.dim_out) - shifted_x = window_reverse(x, self.win_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - x = x.view(B, H, W, self.dim_out).permute(0, 3, 1, 2) - x = self.pool(x) - return x - - diff --git a/timm/models/levit.py b/timm/models/levit.py index c4377bb1..9987e4ba 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -293,10 +293,10 @@ class Attention(nn.Module): k = k.permute(0, 2, 1, 3) v = v.permute(0, 2, 1, 3) - attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + self.get_attention_biases(x.device) + attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) - x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, self.dh) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) x = self.proj(x) return x @@ -387,10 +387,10 @@ class AttentionSubsample(nn.Module): v = v.permute(0, 2, 1, 3) # BHNC q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) - attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + self.get_attention_biases(x.device) + attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) - x = torch.matmul(attn, v).transpose(1, 2).reshape(B, -1, self.dh) + x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) x = self.proj(x) return x diff --git a/timm/models/nest.py b/timm/models/nest.py index 73f14da5..c5951aea 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -25,13 +25,13 @@ import torch.nn.functional as F from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_autowrap_function from .helpers import build_model_with_cfg, named_apply -from .fx_helpers import fx_float_to_int from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_ +from .layers.trace_utils import _assert from .layers import create_conv2d, create_pool2d, to_ntuple from .registry import register_model - _logger = logging.getLogger(__name__) @@ -85,12 +85,12 @@ class Attention(nn.Module): qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) - attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale # (B, H, T, N, N) + attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, N, N) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) # (B, H, T, N, C'), permute -> (B, T, 
N, C', H) - x = torch.matmul(attn, v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C) + x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C) x = self.proj(x) x = self.proj_drop(x) return x # (B, T, N, C) @@ -130,8 +130,8 @@ class ConvPool(nn.Module): """ x is expected to have shape (B, C, H, W) """ - torch._assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims') - torch._assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims') + _assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims') + _assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims') x = self.conv(x) # Layer norm done over channel dim only x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) @@ -146,8 +146,8 @@ def blockify(x, block_size: int): block_size (int): edge length of a single square block in units of H, W """ B, H, W, C = x.shape - torch._assert(H % block_size == 0, '`block_size` must divide input height evenly') - torch._assert(W % block_size == 0, '`block_size` must divide input width evenly') + _assert(H % block_size == 0, '`block_size` must divide input height evenly') + _assert(W % block_size == 0, '`block_size` must divide input width evenly') grid_height = H // block_size grid_width = W // block_size x = x.reshape(B, grid_height, block_size, grid_width, block_size, C) @@ -155,6 +155,7 @@ def blockify(x, block_size: int): return x # (B, T, N, C) +@register_autowrap_function # reason: int receives Proxy def deblockify(x, block_size: int): """blocks to image Args: @@ -162,7 +163,7 @@ def deblockify(x, block_size: int): block_size (int): edge length of a single square block in units of desired H, W """ B, T, _, C = x.shape - grid_size = fx_float_to_int(math.sqrt(T)) + grid_size = int(math.sqrt(T)) height = width = grid_size * block_size x = x.reshape(B, grid_size, grid_size, block_size, block_size, C) x = x.transpose(2, 3).reshape(B, height, width, C) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index ec86dbb8..4e0f2b21 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -27,7 +27,6 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from timm.models.fx_features import register_leaf_module from .registry import register_model from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\ get_act_layer, get_act_fn, get_attn, make_divisible @@ -319,7 +318,6 @@ class DownsampleAvg(nn.Module): return self.conv(self.pool(x)) -@register_leaf_module # FX feature extraction was giving different valued features. Perhaps to do with control flow? class NormFreeBlock(nn.Module): """Normalization-Free pre-activation block. 
""" diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index 53c7bcd5..d5dd5513 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -21,9 +21,10 @@ import torch.nn as nn import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_autowrap_function from .helpers import build_model_with_cfg, overlay_external_default_cfg -from .fx_helpers import fx_float_to_int from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ +from .layers.trace_utils import _assert from .registry import register_model from .vision_transformer import checkpoint_filter_fn, _init_vit_weights @@ -102,6 +103,7 @@ def window_partition(x, window_size: int): return windows +@register_autowrap_function # reason: int argument is a Proxy def window_reverse(windows, window_size: int, H: int, W: int): """ Args: @@ -113,7 +115,7 @@ def window_reverse(windows, window_size: int, H: int, W: int): Returns: x: (B, H, W, C) """ - B = fx_float_to_int(windows.shape[0] / (H * W / window_size / window_size)) + B = int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x @@ -177,7 +179,7 @@ class WindowAttention(nn.Module): q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) q = q * self.scale - attn = torch.matmul(q, k.transpose(-2, -1)) + attn = (q @ k.transpose(-2, -1)) relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH @@ -194,7 +196,7 @@ class WindowAttention(nn.Module): attn = self.attn_drop(attn) - x = torch.matmul(attn, v).transpose(1, 2).reshape(B_, N, C) + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x @@ -272,7 +274,7 @@ class SwinTransformerBlock(nn.Module): def forward(self, x): H, W = self.input_resolution B, L, C = x.shape - torch._assert(L == H * W, "input feature has wrong size") + _assert(L == H * W, "input feature has wrong size") shortcut = x x = self.norm1(x) @@ -331,8 +333,8 @@ class PatchMerging(nn.Module): """ H, W = self.input_resolution B, L, C = x.shape - torch._assert(L == H * W, "input feature has wrong size") - torch._assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.") + _assert(L == H * W, "input feature has wrong size") + _assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.") x = x.view(B, H, W, C) diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 298808c3..1ad481f6 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -12,9 +12,9 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.helpers import build_model_with_cfg -from timm.models.fx_helpers import fx_and from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers.helpers import to_2tuple +from timm.models.layers.trace_utils import _assert from timm.models.registry import register_model from timm.models.vision_transformer import resize_pos_embed @@ -64,11 +64,11 @@ class Attention(nn.Module): q, k = qk.unbind(0) # make torchscript happy (cannot use tensor as tuple) v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) - attn = torch.matmul(q, k.transpose(-2, -1)) * 
self.scale + attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, -1) + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) x = self.proj(x) x = self.proj_drop(x) return x @@ -138,9 +138,9 @@ class PixelEmbed(nn.Module): def forward(self, x, pixel_pos): B, C, H, W = x.shape - torch._assert(H == self.img_size[0], + _assert(H == self.img_size[0], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") - torch._assert(W == self.img_size[1], + _assert(W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") x = self.proj(x) x = self.unfold(x) diff --git a/timm/models/twins.py b/timm/models/twins.py index 7b5afafb..9ae70d32 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -25,7 +25,7 @@ from .layers import Mlp, DropPath, to_2tuple, trunc_normal_ from .fx_features import register_leaf_module from .registry import register_model from .vision_transformer import Attention -from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .helpers import build_model_with_cfg def _cfg(url='', **kwargs): @@ -63,7 +63,7 @@ default_cfgs = { Size_ = Tuple[int, int] -@register_leaf_module # FX can't symbolically trace control flow in forward method +@register_leaf_module # reason: FX can't symbolically trace control flow in forward method class LocallyGroupedAttn(nn.Module): """ LSA: self attention within a group """ @@ -100,10 +100,10 @@ class LocallyGroupedAttn(nn.Module): qkv = self.qkv(x).reshape( B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) q, k, v = qkv[0], qkv[1], qkv[2] - attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - attn = torch.matmul(attn, v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) if pad_r > 0 or pad_b > 0: x = x[:, :H, :W, :].contiguous() @@ -185,11 +185,11 @@ class GlobalSubSampleAttn(nn.Module): kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) k, v = kv[0], kv[1] - attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) diff --git a/timm/models/vgg.py b/timm/models/vgg.py index aee41b25..0f62ac4e 100644 --- a/timm/models/vgg.py +++ b/timm/models/vgg.py @@ -13,7 +13,7 @@ from typing import Union, List, Dict, Any, cast from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg from .fx_features import register_leaf_module -from .layers import ClassifierHead, ConvBnAct +from .layers import ClassifierHead from .registry import register_model __all__ = [ @@ -53,7 +53,7 @@ cfgs: Dict[str, List[Union[str, int]]] = { } -@register_leaf_module # FX can't symbolically trace control flow in forward method +@register_leaf_module # reason: FX can't symbolically trace control flow in forward method class ConvMlp(nn.Module): def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0, diff --git 
a/timm/models/visformer.py b/timm/models/visformer.py index 6ed43102..6e832cd0 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -100,10 +100,10 @@ class Attention(nn.Module): x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3) q, k, v = x[0], x[1], x[2] - attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = torch.matmul(attn, v) + x = attn @ v x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W) x = self.proj(x) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index fb939624..94ae2666 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -192,11 +192,11 @@ class Attention(nn.Module): qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) - attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x diff --git a/timm/models/xcit.py b/timm/models/xcit.py index 9be29046..f5dd0683 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -98,7 +98,7 @@ default_cfgs = { } -@register_leaf_module # FX can't symbolically trace torch.arange in forward method +@register_leaf_module # reason: FX can't symbolically trace torch.arange in forward method class PositionalEncodingFourier(nn.Module): """ Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper. 
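The hunks in this patch apply two registration patterns across the model files: modules whose forward method FX cannot symbolically trace (data-dependent control flow, torch.arange on traced values) are registered as leaf modules, and small helpers that call int() on an FX Proxy are autowrapped. A minimal sketch of both patterns, assuming only the decorator semantics defined in fx_features.py above; MyCustomBlock and grid_side are hypothetical names used for illustration, not part of timm (a later patch in this series renames the decorators to register_notrace_module / register_notrace_function):

    import math
    from torch import nn

    from timm.models.fx_features import register_leaf_module, register_autowrap_function


    @register_leaf_module  # reason: FX can't symbolically trace control flow in forward method
    class MyCustomBlock(nn.Module):  # hypothetical example module
        def __init__(self, dim):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, x):
            if x.shape[1] > 1:  # branching on a traced value -> keep the whole module as a leaf
                x = x.mean(dim=1, keepdim=True)
            return self.proj(x)


    @register_autowrap_function  # reason: int() receives a Proxy during symbolic tracing
    def grid_side(num_tokens: int) -> int:  # hypothetical helper
        return int(math.sqrt(num_tokens))
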
@@ -274,12 +274,12 @@ class XCA(nn.Module): # Paper section 3.2 l2-Normalization and temperature scaling q = torch.nn.functional.normalize(q, dim=-1) k = torch.nn.functional.normalize(k, dim=-1) - attn = torch.matmul(q, k.transpose(-2, -1)) * self.temperature + attn = (q @ k.transpose(-2, -1)) * self.temperature attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) # (B, H, C', N), permute -> (B, N, H, C') - x = torch.matmul(attn, v).permute(0, 3, 1, 2).reshape(B, N, C) + x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x From d2994016e952e9b11b9d8a32020f7565a3c163b5 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Fri, 12 Nov 2021 21:16:53 +0000 Subject: [PATCH 09/13] Add try/except guards --- tests/test_models.py | 15 ++++++++++++++- timm/models/fx_features.py | 6 ++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index f7233ef3..e513dcaf 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,7 +4,11 @@ import platform import os import fnmatch -from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer +try: + from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer + has_fx_feature_extraction = True +except ImportError: + has_fx_feature_extraction = False import timm from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \ @@ -307,6 +311,9 @@ def test_model_forward_features(model_name, batch_size): @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_fx(model_name, batch_size): """Symbolically trace each model and run single forward pass through the resulting GraphModule""" + if not has_fx_feature_extraction: + pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") + model = create_model(model_name, pretrained=False) model.eval() @@ -332,6 +339,9 @@ def test_model_forward_fx(model_name, batch_size): @pytest.mark.parametrize('batch_size', [2]) def test_model_backward_fx(model_name, batch_size): """Symbolically trace each model and run single backward pass through the resulting GraphModule""" + if not has_fx_feature_extraction: + pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") + input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE) if max(input_size) > MAX_BWD_SIZE: pytest.skip("Fixed input size model > limit.") @@ -387,6 +397,9 @@ EXCLUDE_FX_JIT_FILTERS = [ @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_fx_torchscript(model_name, batch_size): """Symbolically trace each model, script it, and run single forward pass""" + if not has_fx_feature_extraction: + pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") + input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) if max(input_size) > MAX_JIT_SIZE: pytest.skip("Fixed input size model > limit.") diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index c8f296d4..a582cf9b 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -8,8 +8,9 @@ from .features import _get_feature_info try: from torchvision.models.feature_extraction import create_feature_extractor + has_fx_feature_extraction = True except ImportError: - pass + has_fx_feature_extraction = False # Layers we went to treat as leaf modules from .layers import Conv2dSame, 
ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath @@ -58,6 +59,7 @@ def register_autowrap_function(func: Callable): class FeatureGraphNet(nn.Module): def __init__(self, model, out_indices, out_map=None): super().__init__() + assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' self.feature_info = _get_feature_info(model, out_indices) if out_map is not None: assert len(out_map) == len(out_indices) @@ -66,7 +68,7 @@ class FeatureGraphNet(nn.Module): self.graph_module = create_feature_extractor( model, return_nodes, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) - + def forward(self, x): return list(self.graph_module(x).values()) \ No newline at end of file From 0262a0e8e16c31a4e8157a44dafbe5f6b8c21495 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sat, 13 Nov 2021 00:06:33 +0000 Subject: [PATCH 10/13] fx ready for review --- tests/test_models.py | 37 ++++++++++++++++++++++++++++++------- timm/models/nfnet.py | 2 ++ 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index e513dcaf..93152d9a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -310,7 +310,10 @@ def test_model_forward_features(model_name, batch_size): @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_fx(model_name, batch_size): - """Symbolically trace each model and run single forward pass through the resulting GraphModule""" + """ + Symbolically trace each model and run single forward pass through the resulting GraphModule + Also check that the output of a forward pass through the GraphModule is the same as that from the original Module + """ if not has_fx_feature_extraction: pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") @@ -321,15 +324,32 @@ def test_model_forward_fx(model_name, batch_size): if max(input_size) > MAX_FWD_SIZE: pytest.skip("Fixed input size model > limit.") + # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode + # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output + # node. 
Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names + tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions)) + graph = tracer.trace(model) + graph_nodes = list(reversed(graph.nodes)) + output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()] + graph_node_names = [n.name for n in graph_nodes] + output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names] train_nodes, eval_nodes = get_graph_node_names( model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) - model = create_feature_extractor( - model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]], + eval_return_nodes = [eval_nodes[ix] for ix in output_node_indices] + + fx_model = create_feature_extractor( + model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=eval_return_nodes, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) inputs = torch.randn((batch_size, *input_size)) - outputs = model(inputs)[eval_nodes[-1]] + outputs = model(inputs) + if isinstance(outputs, tuple): + outputs = torch.cat(outputs) + fx_outputs = tuple(fx_model(inputs).values()) + if isinstance(fx_outputs, tuple): + fx_outputs = torch.cat(fx_outputs) + assert torch.all(fx_outputs == outputs) assert outputs.shape[0] == batch_size assert not torch.isnan(outputs).any(), 'Output included NaNs' @@ -348,6 +368,7 @@ def test_model_backward_fx(model_name, batch_size): model = create_model(model_name, pretrained=False, num_classes=42) model.train() + num_params = sum([x.numel() for x in model.parameters()]) input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE) @@ -355,7 +376,6 @@ def test_model_backward_fx(model_name, batch_size): pytest.skip("Fixed input size model > limit.") # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode - # If so, we need to return all of them in order to check all grads # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output # node. 
Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions)) @@ -385,9 +405,12 @@ def test_model_backward_fx(model_name, batch_size): assert num_params == num_grad, 'Some parameters are missing gradients' assert not torch.isnan(outputs).any(), 'Output included NaNs' - +# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow EXCLUDE_FX_JIT_FILTERS = [ - 'beit_*' # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow + 'beit_*', + 'deit_*_distilled_patch16_224', + 'levit*', + 'pit_*_distilled_224', ] @pytest.mark.timeout(120) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 4e0f2b21..1d6cbb38 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -26,6 +26,7 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_leaf_module from .helpers import build_model_with_cfg from .registry import register_model from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\ @@ -318,6 +319,7 @@ class DownsampleAvg(nn.Module): return self.conv(self.pool(x)) +@register_leaf_module # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301 class NormFreeBlock(nn.Module): """Normalization-Free pre-activation block. """ From 65d827c7a6739b20dcd4c57216f20adc521a6b2a Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 15 Nov 2021 21:03:21 +0000 Subject: [PATCH 11/13] rename notrace registration and standardize trace_utils imports --- timm/models/coat.py | 2 +- timm/models/convit.py | 4 ++-- timm/models/crossvit.py | 4 ++-- timm/models/fx_features.py | 4 ++-- timm/models/nest.py | 6 +++--- timm/models/nfnet.py | 4 ++-- timm/models/swin_transformer.py | 6 +++--- timm/models/tnt.py | 2 +- timm/models/twins.py | 4 ++-- timm/models/vgg.py | 4 ++-- timm/models/xcit.py | 4 ++-- 11 files changed, 22 insertions(+), 22 deletions(-) diff --git a/timm/models/coat.py b/timm/models/coat.py index dca655ea..18ff8ab9 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -19,7 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, overlay_external_default_cfg from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ from .registry import register_model -from .layers.trace_utils import _assert +from .layers import _assert __all__ = [ diff --git a/timm/models/convit.py b/timm/models/convit.py index d2f69b68..6ef1da72 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -30,7 +30,7 @@ from .helpers import build_model_with_cfg from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp from .registry import register_model from .vision_transformer_hybrid import HybridEmbed -from .fx_features import register_leaf_module +from .fx_features import register_notrace_module import torch import torch.nn as nn @@ -57,7 +57,7 @@ default_cfgs = { } -@register_leaf_module # reason: FX can't symbolically trace control flow in forward method +@register_notrace_module # reason: FX can't symbolically trace control flow in forward method class GPSA(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., locality_strength=1.): diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index 
3ba5b4c7..ddc4f64c 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -32,7 +32,7 @@ from functools import partial from typing import List from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_autowrap_function +from .fx_features import register_notrace_function from .helpers import build_model_with_cfg from .layers import DropPath, to_2tuple, trunc_normal_, _assert from .registry import register_model @@ -259,7 +259,7 @@ def _compute_num_patches(img_size, patches): return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)] -@register_autowrap_function +@register_notrace_function def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False): # annotations for torchscript """ Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing. diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index a582cf9b..2e01586b 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -36,7 +36,7 @@ except ImportError: pass -def register_leaf_module(module: nn.Module): +def register_notrace_module(module: nn.Module): """ Any module not under timm.models.layers should get this decorator if we don't want to trace through it. """ @@ -48,7 +48,7 @@ def register_leaf_module(module: nn.Module): _autowrap_functions = set() -def register_autowrap_function(func: Callable): +def register_notrace_function(func: Callable): """ Decorator for functions which ought not to be traced through """ diff --git a/timm/models/nest.py b/timm/models/nest.py index c5951aea..22cf6099 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -25,10 +25,10 @@ import torch.nn.functional as F from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_autowrap_function +from .fx_features import register_notrace_function from .helpers import build_model_with_cfg, named_apply from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_ -from .layers.trace_utils import _assert +from .layers import _assert from .layers import create_conv2d, create_pool2d, to_ntuple from .registry import register_model @@ -155,7 +155,7 @@ def blockify(x, block_size: int): return x # (B, T, N, C) -@register_autowrap_function # reason: int receives Proxy +@register_notrace_function # reason: int receives Proxy def deblockify(x, block_size: int): """blocks to image Args: diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 1d6cbb38..973cbd66 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -26,7 +26,7 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_leaf_module +from .fx_features import register_notrace_module from .helpers import build_model_with_cfg from .registry import register_model from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\ @@ -319,7 +319,7 @@ class DownsampleAvg(nn.Module): return self.conv(self.pool(x)) -@register_leaf_module # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301 +@register_notrace_module # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301 class NormFreeBlock(nn.Module): """Normalization-Free pre-activation block. 
""" diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index d5dd5513..92057902 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -21,10 +21,10 @@ import torch.nn as nn import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_autowrap_function +from .fx_features import register_notrace_function from .helpers import build_model_with_cfg, overlay_external_default_cfg from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ -from .layers.trace_utils import _assert +from .layers import _assert from .registry import register_model from .vision_transformer import checkpoint_filter_fn, _init_vit_weights @@ -103,7 +103,7 @@ def window_partition(x, window_size: int): return windows -@register_autowrap_function # reason: int argument is a Proxy +@register_notrace_function # reason: int argument is a Proxy def window_reverse(windows, window_size: int, H: int, W: int): """ Args: diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 1ad481f6..d52f9ce6 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -14,7 +14,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.helpers import build_model_with_cfg from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers.helpers import to_2tuple -from timm.models.layers.trace_utils import _assert +from timm.models.layers import _assert from timm.models.registry import register_model from timm.models.vision_transformer import resize_pos_embed diff --git a/timm/models/twins.py b/timm/models/twins.py index 9ae70d32..67a939d4 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -22,7 +22,7 @@ from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .layers import Mlp, DropPath, to_2tuple, trunc_normal_ -from .fx_features import register_leaf_module +from .fx_features import register_notrace_module from .registry import register_model from .vision_transformer import Attention from .helpers import build_model_with_cfg @@ -63,7 +63,7 @@ default_cfgs = { Size_ = Tuple[int, int] -@register_leaf_module # reason: FX can't symbolically trace control flow in forward method +@register_notrace_module # reason: FX can't symbolically trace control flow in forward method class LocallyGroupedAttn(nn.Module): """ LSA: self attention within a group """ diff --git a/timm/models/vgg.py b/timm/models/vgg.py index 0f62ac4e..11f6d0ea 100644 --- a/timm/models/vgg.py +++ b/timm/models/vgg.py @@ -12,7 +12,7 @@ from typing import Union, List, Dict, Any, cast from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from .fx_features import register_leaf_module +from .fx_features import register_notrace_module from .layers import ClassifierHead from .registry import register_model @@ -53,7 +53,7 @@ cfgs: Dict[str, List[Union[str, int]]] = { } -@register_leaf_module # reason: FX can't symbolically trace control flow in forward method +@register_notrace_module # reason: FX can't symbolically trace control flow in forward method class ConvMlp(nn.Module): def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0, diff --git a/timm/models/xcit.py b/timm/models/xcit.py index f5dd0683..ac5e802c 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -21,7 +21,7 @@ from .vision_transformer import _cfg, Mlp from .registry import 
register_model from .layers import DropPath, trunc_normal_, to_2tuple from .cait import ClassAttn -from .fx_features import register_leaf_module +from .fx_features import register_notrace_module def _cfg(url='', **kwargs): @@ -98,7 +98,7 @@ default_cfgs = { } -@register_leaf_module # reason: FX can't symbolically trace torch.arange in forward method +@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method class PositionalEncodingFourier(nn.Module): """ Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper. From 1076a65df1fa8fe0612adbe21284f4722b585dac Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 18 Nov 2021 19:47:07 -0800 Subject: [PATCH 12/13] Minor post FX merge cleanup --- tests/test_models.py | 6 +++--- timm/models/fx_features.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 93152d9a..7a3f143e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -315,7 +315,7 @@ def test_model_forward_fx(model_name, batch_size): Also check that the output of a forward pass through the GraphModule is the same as that from the original Module """ if not has_fx_feature_extraction: - pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") + pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") model = create_model(model_name, pretrained=False) model.eval() @@ -360,7 +360,7 @@ def test_model_forward_fx(model_name, batch_size): def test_model_backward_fx(model_name, batch_size): """Symbolically trace each model and run single backward pass through the resulting GraphModule""" if not has_fx_feature_extraction: - pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") + pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE) if max(input_size) > MAX_BWD_SIZE: @@ -421,7 +421,7 @@ EXCLUDE_FX_JIT_FILTERS = [ def test_model_forward_fx_torchscript(model_name, batch_size): """Symbolically trace each model, script it, and run single forward pass""" if not has_fx_feature_extraction: - pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") + pytest.skip("Can't test FX. 
Torch >= 1.10 and Torchvision >= 0.11 are required.") input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) if max(input_size) > MAX_JIT_SIZE: diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index 2e01586b..5a25ee3e 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -71,4 +71,3 @@ class FeatureGraphNet(nn.Module): def forward(self, x): return list(self.graph_module(x).values()) - \ No newline at end of file From f2006b24370338643e30eba72c3b7a124ee4b5b3 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 18 Nov 2021 21:25:00 -0800 Subject: [PATCH 13/13] Cleanup qkv_bias cat in beit model so it can be traced --- tests/test_models.py | 1 - timm/models/beit.py | 10 +++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 7a3f143e..39e2dcdc 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -407,7 +407,6 @@ def test_model_backward_fx(model_name, batch_size): # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow EXCLUDE_FX_JIT_FILTERS = [ - 'beit_*', 'deit_*_distilled_patch16_224', 'levit*', 'pit_*_distilled_224', diff --git a/timm/models/beit.py b/timm/models/beit.py index 199c2a4b..f644b657 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -86,9 +86,11 @@ class Attention(nn.Module): self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) if qkv_bias: self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False) self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) else: self.q_bias = None + self.k_bias = None self.v_bias = None if window_size: @@ -127,13 +129,7 @@ class Attention(nn.Module): def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None): B, N, C = x.shape - qkv_bias = None - if self.q_bias is not None: - if torch.jit.is_scripting(): - # FIXME requires_grad breaks w/ torchscript - qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias), self.v_bias)) - else: - qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
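
Taken together, the series lands on torchvision's FX feature extraction driven by timm's no-trace registries. A rough end-to-end sketch of that path, assuming torch >= 1.10 and torchvision >= 0.11 as guarded above; the model name is arbitrary, and importing the private _leaf_modules / _autowrap_functions sets mirrors what the tests do rather than a public API:

    import torch
    from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

    import timm
    from timm.models.fx_features import FeatureGraphNet, _leaf_modules, _autowrap_functions

    model = timm.create_model('resnet18', pretrained=False).eval()
    tracer_kwargs = {'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}

    # timm's wrapper: pick outputs by feature_info index, get a list of feature maps back
    fgn = FeatureGraphNet(model, out_indices=(1, 2, 3, 4))
    features = fgn(torch.randn(2, 3, 224, 224))

    # torchvision's extractor used directly, as in the tests: return the final eval-mode node
    train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs)
    fx_model = create_feature_extractor(model, return_nodes={eval_nodes[-1]: 'out'}, tracer_kwargs=tracer_kwargs)
    out = fx_model(torch.randn(2, 3, 224, 224))['out']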