From ab3ac3f25b1df54e45cc91daecb73bf2e6c30825 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 12 Aug 2021 15:31:02 +0100 Subject: [PATCH 01/26] Add FX based FeatureGraphNet capability --- timm/models/fx_features.py | 291 +++++++++++++++++++++++++++++++++++++ timm/models/fx_helpers.py | 17 +++ timm/models/helpers.py | 3 + 3 files changed, 311 insertions(+) create mode 100644 timm/models/fx_features.py create mode 100644 timm/models/fx_helpers.py diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py new file mode 100644 index 00000000..e00b080f --- /dev/null +++ b/timm/models/fx_features.py @@ -0,0 +1,291 @@ +""" PyTorch FX Based Feature Extraction Helpers +An extension/alternative to timm.models.features making use of PyTorch FX. Here, the idea is to: + 1. Symbolically trace a model producing a graph based intermediate representation (PyTorch FX functionality with + some custom tweaks) + 2. Identify desired feature extraction nodes and reconfigure them as output nodes while deleting all unecessary + nodes. (custom - inspired by https://github.com/pytorch/vision/pull/3597) + 3. Write the resulting graph into a GraphModule (PyTorch FX functionality) +Copyright 2021 Alexander Soare +""" +from typing import Callable, Dict +import math +from collections import OrderedDict +from pprint import pprint +from inspect import ismethod +import re +import warnings + +import torch +from torch import nn +from torch import fx +import torch.nn.functional as F + +from .features import _get_feature_info +from .fx_helpers import fx_and, fx_float_to_int + +# Layers we went to treat as leaf modules for FeatureGraphNet +from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame +from .layers import GatherExcite, DropPath +from .layers.non_local_attn import BilinearAttnTransform +from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame + + +# These modules will not be traced through. +_leaf_modules = { + Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, GatherExcite, DropPath, + BilinearAttnTransform, MaxPool2dSame, AvgPool2dSame +} + +try: + from .layers import InplaceAbn + _leaf_modules.add(InplaceAbn) +except ImportError: + pass + + +def register_leaf_module(module: nn.Module): + """ + Any module not under timm.models.layers should get this decorator if we don't want to trace through it. + """ + _leaf_modules.add(module) + return module + + +# These functions will not be traced through +_autowrap_functions=(fx_float_to_int, fx_and) + + +class TimmTracer(fx.Tracer): + """ + Temporary bridge from torch.fx.Tracer to include any general workarounds required to make FX work for us + """ + def __init__(self, autowrap_modules=(math, ), autowrap_functions=(), enable_cpatching=False): + super().__init__(autowrap_modules=autowrap_modules, enable_cpatching=enable_cpatching) + # FIXME: This is a workaround pending on a PyTorch PR https://github.com/pytorch/pytorch/pull/62106 + self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions])) + + def create_node(self, kind, target, args, kwargs, name=None, type_expr=None): + # FIXME: This is a workaround pending on a PyTorch PR https://github.com/pytorch/pytorch/pull/62095 + if target == F.pad: + kwargs['value'] = float(kwargs['value']) + return super().create_node(kind, target, args, kwargs, name=name, type_expr=type_expr) + + +class LeafNodeTracer(TimmTracer): + """ + Account for desired leaf nodes according to _leaf_modules and _autowrap functions + """ + def __init__(self): + super().__init__(autowrap_functions=_autowrap_functions) + + def is_leaf_module(self, m: nn.Module, module_qualname: str) -> bool: + if isinstance(m, tuple(_leaf_modules)): + return True + return super().is_leaf_module(m, module_qualname) + + +# Taken from https://github.com/pytorch/examples/blob/master/fx/module_tracer.py with modifications for storing +# qualified names for all Nodes, not just top-level Modules +class NodePathTracer(LeafNodeTracer): + """ + NodePathTracer is an FX tracer that, for each operation, also records the qualified name of the Node from which the + operation originated. A qualified name here is a `.` seperated path walking the hierarchy from top level module + down to leaf operation or leaf module. The name of the top level module is not included as part of the qualified + name. For example, if we trace a module who's forward method applies a ReLU module, the qualified name for that + node will simply be 'relu'. + """ + def __init__(self): + super().__init__() + # Track the qualified name of the Node being traced + self.current_module_qualname : str = '' + # A map from FX Node to the qualified name + self.node_to_qualname = OrderedDict() + + def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs): + """ + Override of Tracer.call_module (see https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer.call_module). + This override: + 1) Stores away the qualified name of the caller for restoration later + 2) Installs the qualified name of the caller in `current_module_qualname` for retrieval by `create_proxy` + 3) Once a leaf module is reached, calls `create_proxy` + 4) Restores the caller's qualified name into current_module_qualname + """ + old_qualname = self.current_module_qualname + try: + module_qualname = self.path_of_module(m) + self.current_module_qualname = module_qualname + if not self.is_leaf_module(m, module_qualname): + out = forward(*args, **kwargs) + return out + return self.create_proxy('call_module', module_qualname, args, kwargs) + finally: + self.current_module_qualname = old_qualname + + def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs, name=None, type_expr=None): + """ + Override of `Tracer.create_proxy`. This override intercepts the recording + of every operation and stores away the current traced module's qualified + name in `node_to_qualname` + """ + proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr) + self.node_to_qualname[proxy.node] = self._get_node_qualname( + self.current_module_qualname, proxy.node) + return proxy + + def _get_node_qualname(self, module_qualname: str, node: fx.node.Node): + node_qualname = module_qualname + if node.op == 'call_module': + # Node terminates in a leaf module so the module_qualname is a complete description of the node + # Just need to check if this module has appeared before. If so add postfix counter starting from _1 for the + # first reappearance (this follows the way that repeated leaf ops are enumerated by PyTorch FX) + for existing_qualname in reversed(self.node_to_qualname.values()): + # Check to see if existing_qualname is of the form {node_qualname} or {node_qualname}_{int} + if re.match(rf'{node_qualname}(_[0-9]+)?$', existing_qualname) is not None: + postfix = existing_qualname.replace(node_qualname, '') + if len(postfix): + # existing_qualname is of the form {node_qualname}_{int} + next_index = int(postfix[1:]) + 1 + else: + # existing_qualname is of the form {node_qualname} + next_index = 1 + node_qualname += f'_{next_index}' + break + else: + # Node terminates in non- leaf module so the node name needs to be appended + if len(node_qualname) > 0: # only append '.' if we are deeper than the top level module + node_qualname += '.' + node_qualname += str(node) + return node_qualname + + +def print_graph_node_qualified_names(model: nn.Module): + """ + Dev utility to prints nodes in order of execution. Useful for choosing `nodes` for a FeatureGraphNet design. + This is useful for two reasons: + 1. Not all submodules are traced through. Some are treated as leaf modules. See `LeafNodeTracer` + 2. Leaf ops that occur more than once in the graph get a `_{counter}` postfix. + + WARNING: Changes to the operations in the original module might not change the module's overall behaviour, but they + may result in changes to the postfixes for the names of repeated ops, thereby breaking feature extraction. + """ + tracer = NodePathTracer() + tracer.trace(model) + pprint(list(tracer.node_to_qualname.values())) + + +def get_intermediate_nodes(model: nn.Module, return_nodes: Dict[str, str]) -> nn.Module: + """ + Creates a new FX-based module that returns intermediate nodes from a given model. This is achieved by re-writing + the computation graph of the model via FX to return the desired nodes as outputs. All unused nodes are removed, + together with their corresponding parameters. + Args: + model (nn.Module): model on which we will extract the features + return_nodes (Dict[name, new_name]): a dict containing the names (or partial names - see note below) of the + nodes for which the activations will be returned as the keys. The values of the dict are the names + of the returned activations (which the user can specify). + A note on node specification: A node is specified as a `.` seperated path walking the hierarchy from top + level module down to leaf operation or leaf module. For instance `blocks.5.3.bn1`. Nevertheless, the keys + in this dict need not be fully specified. One could provide `blocks.5` as a key, and the last node with + that prefix will be selected. + While designing a feature extractor one can use the `print_graph_node_qualified_names` utility as a guide + to which nodes are available. + Acknowledgement: Starter code from https://github.com/pytorch/vision/pull/3597 + """ + return_nodes = {str(k): str(v) for k, v in return_nodes.items()} + + # Instantiate our NodePathTracer and use that to trace the model + tracer = NodePathTracer() + graph = tracer.trace(model) + + name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__ + graph_module = fx.GraphModule(tracer.root, graph, name) + + available_nodes = [f'{v}.{k}' for k, v in tracer.node_to_qualname.items()] + # FIXME We don't know if we should expect this to happen + assert len(set(available_nodes)) == len(available_nodes), \ + "There are duplicate nodes! Please raise an issue https://github.com/rwightman/pytorch-image-models/issues" + # Check that all outputs in return_nodes are present in the model + for query in return_nodes.keys(): + if not any([m.startswith(query) for m in available_nodes]): + raise ValueError(f"return_node: {query} is not present in model") + + # Remove existing output nodes + orig_output_node = None + for n in reversed(graph_module.graph.nodes): + if n.op == "output": + orig_output_node = n + assert orig_output_node + # And remove it + graph_module.graph.erase_node(orig_output_node) + # Find nodes corresponding to return_nodes and make them into output_nodes + nodes = [n for n in graph_module.graph.nodes] + output_nodes = OrderedDict() + for n in reversed(nodes): + if 'tensor_constant' in str(n): + # NOTE Without this control flow we would get a None value for + # `module_qualname = tracer.node_to_qualname.get(n)`. On the other hand, we can safely assume that we'll + # never need to get this as an interesting intermediate node. + continue + module_qualname = tracer.node_to_qualname.get(n) + for query in return_nodes: + depth = query.count('.') + if '.'.join(module_qualname.split('.')[:depth+1]) == query: + output_nodes[return_nodes[query]] = n + return_nodes.pop(query) + break + output_nodes = OrderedDict(reversed(list(output_nodes.items()))) + + # And add them in the end of the graph + with graph_module.graph.inserting_after(nodes[-1]): + graph_module.graph.output(output_nodes) + + # Remove unused modules / parameters + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + graph_module = fx.GraphModule(graph_module, graph_module.graph, name) + return graph_module + + +class FeatureGraphNet(nn.Module): + """ + Take the provided model and transform it into a graph module. This class wraps the resulting graph module while + also keeping the original model's non-parameter properties for reference. The original model is discarded. + + WARNING: Changes to the operations in the original module might not change the module's overall behaviour, but they + may result in changes to the postfixes for the names of repeated ops, thereby breaking feature extraction. + + TODO: FIX THIS + WARNING: This puts the input model into eval mode prior to tracing. This means that any control flow dependent on + the model being in train mode will be lost. + """ + def __init__(self, model, out_indices, out_map=None): + super().__init__() + model.eval() + self.feature_info = _get_feature_info(model, out_indices) + if out_map is not None: + assert len(out_map) == len(out_indices) + # NOTE the feature_info key is innapropriately named 'module' because prior to FX only modules could be + # provided. Recall that here, we may also provide nodes referring to individual ops + return_nodes = {info['module']: out_map[i] if out_map is not None else info['module'] + for i, info in enumerate(self.feature_info) if i in out_indices} + self.graph_module = get_intermediate_nodes(model, return_nodes) + # Keep non-parameter model properties for reference + for attr_str in model.__dir__(): + attr = getattr(model, attr_str) + if (not attr_str.startswith('_') and attr_str not in self.__dir__() and not ismethod(attr) + and not isinstance(attr, (nn.Module, nn.Parameter))): + setattr(self, attr_str, attr) + + def forward(self, x): + return list(self.graph_module(x).values()) + + def train(self, mode=True): + """ + NOTE: This also covers `self.eval()` as that just does self.train(False) + """ + if mode: + warnings.warn( + "Setting a FeatureGraphNet to training mode won't necessarily have the desired effect. Control " + "flow depending on `self.training` will follow the `False` path. See FeatureGraphNet doc-string " + "for more details.") + super().train(mode) \ No newline at end of file diff --git a/timm/models/fx_helpers.py b/timm/models/fx_helpers.py new file mode 100644 index 00000000..1955d5b1 --- /dev/null +++ b/timm/models/fx_helpers.py @@ -0,0 +1,17 @@ + + +def fx_and(a: bool, b: bool) -> bool: + """ + Symbolic tracing helper to substitute for normal usage of `* and *` within `torch._assert`. + Hint: Symbolic tracing does not support control flow but since an `assert` is either a dead-end or not, this hack + is okay. + """ + return (a and b) + + +def fx_float_to_int(x: float) -> int: + """ + Symbolic tracing helper to substitute for inbuilt `int`. + Hint: Inbuilt `int` can't accept an argument of type `Proxy` + """ + return int(x) \ No newline at end of file diff --git a/timm/models/helpers.py b/timm/models/helpers.py index bd97cf20..4cb571f4 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -14,6 +14,7 @@ import torch.nn as nn from .features import FeatureListNet, FeatureDictNet, FeatureHookNet +from .fx_features import FeatureGraphNet from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url from .layers import Conv2dSame, Linear @@ -477,6 +478,8 @@ def build_model_with_cfg( feature_cls = feature_cls.lower() if 'hook' in feature_cls: feature_cls = FeatureHookNet + elif feature_cls == 'fx': + feature_cls = FeatureGraphNet else: assert False, f'Unknown feature class {feature_cls}' model = feature_cls(model, **feature_cfg) From bc3d4eb403b9f40c23ed5bedd8875bba911e3cdc Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sun, 7 Nov 2021 15:04:19 +0000 Subject: [PATCH 02/26] wip -rebase --- timm/models/cait.py | 8 +- timm/models/coat.py | 10 +- timm/models/convit.py | 10 +- timm/models/layers/bottleneck_attn.py | 9 +- timm/models/layers/evo_norm.py | 4 +- timm/models/layers/global_context.py | 3 +- timm/models/layers/halo_attn.py | 7 +- timm/models/layers/lambda_layer.py | 4 +- timm/models/layers/non_local_attn.py | 5 +- timm/models/layers/patch_embed.py | 4 + timm/models/layers/selective_kernel.py | 2 +- timm/models/layers/swin_attn.py | 183 +++++++++++++++++++++++++ timm/models/levit.py | 8 +- timm/models/nest.py | 16 ++- timm/models/nfnet.py | 2 + timm/models/rexnet.py | 3 +- timm/models/swin_transformer.py | 14 +- timm/models/tnt.py | 14 +- timm/models/twins.py | 10 +- timm/models/vgg.py | 2 + timm/models/visformer.py | 4 +- timm/models/vision_transformer.py | 4 +- timm/models/xcit.py | 6 +- 23 files changed, 269 insertions(+), 63 deletions(-) create mode 100644 timm/models/layers/swin_attn.py diff --git a/timm/models/cait.py b/timm/models/cait.py index 69b4ba06..b6a18ce3 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -95,11 +95,11 @@ class ClassAttn(nn.Module): q = q * self.scale v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) - attn = (q @ k.transpose(-2, -1)) + attn = torch.matmul(q, k.transpose(-2, -1)) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C) + x_cls = torch.matmul(attn, v).transpose(1, 2).reshape(B, 1, C) x_cls = self.proj(x_cls) x_cls = self.proj_drop(x_cls) @@ -158,7 +158,7 @@ class TalkingHeadAttn(nn.Module): qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] - attn = (q @ k.transpose(-2, -1)) + attn = torch.matmul(q, k.transpose(-2, -1)) attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) @@ -167,7 +167,7 @@ class TalkingHeadAttn(nn.Module): attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x diff --git a/timm/models/coat.py b/timm/models/coat.py index f071715a..69b1bd9f 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -105,7 +105,7 @@ class ConvRelPosEnc(nn.Module): def forward(self, q, v, size: Tuple[int, int]): B, h, N, Ch = q.shape H, W = size - assert N == 1 + H * W + torch._assert(N == 1 + H * W, '') # Convolutional relative position encoding. q_img = q[:, :, 1:, :] # [B, h, H*W, Ch] @@ -149,8 +149,8 @@ class FactorAtt_ConvRelPosEnc(nn.Module): # Factorized attention. k_softmax = k.softmax(dim=2) - factor_att = k_softmax.transpose(-1, -2) @ v - factor_att = q @ factor_att + factor_att = torch.matmul(k_softmax.transpose(-1, -2), v) + factor_att = torch.matmul(q, factor_att) # Convolutional relative position encoding. crpe = self.crpe(q, v, size=size) # [B, h, N, Ch] @@ -177,7 +177,7 @@ class ConvPosEnc(nn.Module): def forward(self, x, size: Tuple[int, int]): B, N, C = x.shape H, W = size - assert N == 1 + H * W + torch._assert(N == 1 + H * W, '') # Extract CLS token and image tokens. cls_token, img_tokens = x[:, :1], x[:, 1:] # [B, 1, C], [B, H*W, C] @@ -275,7 +275,7 @@ class ParallelBlock(nn.Module): """ Feature map interpolation. """ B, N, C = x.shape H, W = size - assert N == 1 + H * W + torch._assert(N == 1 + H * W, '') cls_token = x[:, :1, :] img_tokens = x[:, 1:, :] diff --git a/timm/models/convit.py b/timm/models/convit.py index f58249ec..603548f9 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -30,6 +30,7 @@ from .helpers import build_model_with_cfg from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp from .registry import register_model from .vision_transformer_hybrid import HybridEmbed +from .fx_features import register_leaf_module import torch import torch.nn as nn @@ -56,6 +57,7 @@ default_cfgs = { } +@register_leaf_module # FX can't symbolically trace control flow in forward method class GPSA(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., locality_strength=1.): @@ -82,7 +84,7 @@ class GPSA(nn.Module): self.rel_indices = self.get_rel_indices(N) attn = self.get_attention(x) v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x @@ -93,7 +95,7 @@ class GPSA(nn.Module): q, k = qk[0], qk[1] pos_score = self.rel_indices.expand(B, -1, -1, -1) pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2) - patch_score = (q @ k.transpose(-2, -1)) * self.scale + patch_score = torch.matmul(q, k.transpose(-2, -1)) * self.scale patch_score = patch_score.softmax(dim=-1) pos_score = pos_score.softmax(dim=-1) @@ -178,11 +180,11 @@ class MHSA(nn.Module): qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] - attn = (q @ k.transpose(-2, -1)) * self.scale + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py index f55fd989..305f9de3 100644 --- a/timm/models/layers/bottleneck_attn.py +++ b/timm/models/layers/bottleneck_attn.py @@ -22,6 +22,7 @@ import torch.nn.functional as F from .helpers import to_2tuple, make_divisible from .weight_init import trunc_normal_ +from timm.models.fx_helpers import fx_and def rel_logits_1d(q, rel_k, permute_mask: List[int]): @@ -36,7 +37,7 @@ def rel_logits_1d(q, rel_k, permute_mask: List[int]): permute_mask: permute output dim according to this """ B, H, W, dim = q.shape - x = (q @ rel_k.transpose(-1, -2)) + x = torch.matmul(q, rel_k.transpose(-1, -2)) x = x.reshape(-1, W, 2 * W -1) # pad to shift from relative to absolute indexing @@ -133,8 +134,8 @@ class BottleneckAttn(nn.Module): def forward(self, x): B, C, H, W = x.shape - assert H == self.pos_embed.height - assert W == self.pos_embed.width + torch._assert(H == self.pos_embed.height, '') + torch._assert(W == self.pos_embed.width, '') x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W @@ -154,5 +155,3 @@ class BottleneckAttn(nn.Module): out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W out = self.pool(out) return out - - diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py index 9023afd0..02aa0a0c 100644 --- a/timm/models/layers/evo_norm.py +++ b/timm/models/layers/evo_norm.py @@ -72,9 +72,9 @@ class EvoNormSample2d(nn.Module): nn.init.ones_(self.v) def forward(self, x): - assert x.dim() == 4, 'expected 4D input' + torch._assert(x.dim() == 4, 'expected 4D input') B, C, H, W = x.shape - assert C % self.groups == 0 + torch._assert(C % self.groups == 0, '') if self.apply_act: n = x * (x * self.v).sigmoid() x = x.reshape(B, self.groups, -1) diff --git a/timm/models/layers/global_context.py b/timm/models/layers/global_context.py index de7fb5c1..a0bb8a43 100644 --- a/timm/models/layers/global_context.py +++ b/timm/models/layers/global_context.py @@ -7,6 +7,7 @@ Official code consulted as reference: https://github.com/xvjiarui/GCNet Hacked together by / Copyright 2021 Ross Wightman """ +import torch from torch import nn as nn import torch.nn.functional as F @@ -52,7 +53,7 @@ class GlobalContext(nn.Module): if self.conv_attn is not None: attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) - context = x.reshape(B, C, H * W).unsqueeze(1) @ attn + context = torch.matmul(x.reshape(B, C, H * W).unsqueeze(1), attn) context = context.view(B, C, 1, 1) else: context = x.mean(dim=(2, 3), keepdim=True) diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index 4149e812..0bd611b1 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -24,6 +24,7 @@ import torch.nn.functional as F from .helpers import make_divisible from .weight_init import trunc_normal_ +from timm.models.fx_helpers import fx_and def rel_logits_1d(q, rel_k, permute_mask: List[int]): @@ -41,7 +42,7 @@ def rel_logits_1d(q, rel_k, permute_mask: List[int]): rel_size = rel_k.shape[0] win_size = (rel_size + 1) // 2 - x = (q @ rel_k.transpose(-1, -2)) + x = torch.matmul(q, rel_k.transpose(-1, -2)) x = x.reshape(-1, W, rel_size) # pad to shift from relative to absolute indexing @@ -167,8 +168,8 @@ class HaloAttn(nn.Module): def forward(self, x): B, C, H, W = x.shape - assert H % self.block_size == 0 - assert W % self.block_size == 0 + torch._assert(H % self.block_size == 0, '') + torch._assert(W % self.block_size == 0, '') num_h_blocks = H // self.block_size num_w_blocks = W // self.block_size num_blocks = num_h_blocks * num_w_blocks diff --git a/timm/models/layers/lambda_layer.py b/timm/models/layers/lambda_layer.py index e50b43c8..058426b6 100644 --- a/timm/models/layers/lambda_layer.py +++ b/timm/models/layers/lambda_layer.py @@ -116,8 +116,8 @@ class LambdaLayer(nn.Module): v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M - content_lam = k @ v # B, K, V - content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V + content_lam = torch.matmul(k, v) # B, K, V + content_out = torch.matmul(q, content_lam.unsqueeze(1)) # B, num_heads, M, V if self.pos_emb is None: position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K diff --git a/timm/models/layers/non_local_attn.py b/timm/models/layers/non_local_attn.py index a537d60e..517e28a8 100644 --- a/timm/models/layers/non_local_attn.py +++ b/timm/models/layers/non_local_attn.py @@ -10,6 +10,7 @@ from torch.nn import functional as F from .conv_bn_act import ConvBnAct from .helpers import make_divisible +from timm.models.fx_helpers import fx_and class NonLocalAttn(nn.Module): @@ -83,7 +84,7 @@ class BilinearAttnTransform(nn.Module): def resize_mat(self, x, t: int): B, C, block_size, block_size1 = x.shape - assert block_size == block_size1 + torch._assert(block_size == block_size1, '') if t <= 1: return x x = x.view(B * C, -1, 1, 1) @@ -95,7 +96,7 @@ class BilinearAttnTransform(nn.Module): return x def forward(self, x): - assert x.shape[-1] % self.block_size == 0 and x.shape[-2] % self.block_size == 0 + torch._assert(fx_and(x.shape[-1] % self.block_size == 0, x.shape[-2] % self.block_size == 0), '') B, C, H, W = x.shape out = self.conv1(x) rp = F.adaptive_max_pool2d(out, (self.block_size, 1)) diff --git a/timm/models/layers/patch_embed.py b/timm/models/layers/patch_embed.py index 6a7facef..157bc250 100644 --- a/timm/models/layers/patch_embed.py +++ b/timm/models/layers/patch_embed.py @@ -9,7 +9,11 @@ Hacked together by / Copyright 2020 Ross Wightman from torch import nn as nn from .helpers import to_2tuple +<<<<<<< HEAD from .trace_utils import _assert +======= +from timm.models.fx_helpers import fx_and +>>>>>>> Make all models FX traceable class PatchEmbed(nn.Module): diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py index f28b8d2e..69aca86b 100644 --- a/timm/models/layers/selective_kernel.py +++ b/timm/models/layers/selective_kernel.py @@ -34,7 +34,7 @@ class SelectiveKernelAttn(nn.Module): self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) def forward(self, x): - assert x.shape[1] == self.num_paths + torch._assert(x.shape[1] == self.num_paths, '') x = x.sum(1).mean((2, 3), keepdim=True) x = self.fc_reduce(x) x = self.bn(x) diff --git a/timm/models/layers/swin_attn.py b/timm/models/layers/swin_attn.py new file mode 100644 index 00000000..2a3731f3 --- /dev/null +++ b/timm/models/layers/swin_attn.py @@ -0,0 +1,183 @@ +""" Shifted Window Attn + +This is a WIP experiment to apply windowed attention from the Swin Transformer +to a stand-alone module for use as an attn block in conv nets. + +Based on original swin window code at https://github.com/microsoft/Swin-Transformer +Swin Transformer paper: https://arxiv.org/pdf/2103.14030.pdf +""" +from typing import Optional + +import torch +import torch.nn as nn + +from .drop import DropPath +from .helpers import to_2tuple +from .weight_init import trunc_normal_ +from timm.models.fx_helpers import fx_float_to_int + + +def window_partition(x, win_size: int): + """ + Args: + x: (B, H, W, C) + win_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // win_size, win_size, W // win_size, win_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C) + return windows + + +def window_reverse(windows, win_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + win_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = fx_float_to_int(windows.shape[0] / (H * W / win_size / win_size)) + x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + win_size (int): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + """ + + def __init__( + self, dim, dim_out=None, feat_size=None, stride=1, win_size=8, shift_size=None, num_heads=8, + qkv_bias=True, attn_drop=0.): + + super().__init__() + self.dim_out = dim_out or dim + self.feat_size = to_2tuple(feat_size) + self.win_size = win_size + self.shift_size = shift_size or win_size // 2 + if min(self.feat_size) <= win_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.win_size = min(self.feat_size) + assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-window_size" + self.num_heads = num_heads + head_dim = self.dim_out // num_heads + self.scale = head_dim ** -0.5 + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.feat_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = ( + slice(0, -self.win_size), + slice(-self.win_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = ( + slice(0, -self.win_size), + slice(-self.win_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, self.win_size) # num_win, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.win_size * self.win_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + self.register_buffer("attn_mask", attn_mask) + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + # 2 * Wh - 1 * 2 * Ww - 1, nH + torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads)) + trunc_normal_(self.relative_position_bias_table, std=.02) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.win_size) + coords_w = torch.arange(self.win_size) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.win_size - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.win_size - 1 + relative_coords[:, :, 0] *= 2 * self.win_size - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.softmax = nn.Softmax(dim=-1) + self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() + + def reset_parameters(self): + trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) + trunc_normal_(self.relative_position_bias_table, std=.02) + + def forward(self, x): + B, C, H, W = x.shape + x = x.permute(0, 2, 3, 1) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + win_size_sq = self.win_size * self.win_size + x_windows = window_partition(shifted_x, self.win_size) # num_win * B, window_size, window_size, C + x_windows = x_windows.view(-1, win_size_sq, C) # num_win * B, window_size*window_size, C + BW, N, _ = x_windows.shape + + qkv = self.qkv(x_windows) + qkv = qkv.reshape(BW, N, 3, self.num_heads, self.dim_out // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * self.scale + attn = torch.matmul(q, k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view(win_size_sq, win_size_sq, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh * Ww, Wh * Ww + attn = attn + relative_position_bias.unsqueeze(0) + if self.attn_mask is not None: + num_win = self.attn_mask.shape[0] + attn = attn.view(B, num_win, self.num_heads, N, N) + self.attn_mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + attn = self.attn_drop(attn) + + x = torch.matmul(attn, v).transpose(1, 2).reshape(BW, N, self.dim_out) + + # merge windows + x = x.view(-1, self.win_size, self.win_size, self.dim_out) + shifted_x = window_reverse(x, self.win_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H, W, self.dim_out).permute(0, 3, 1, 2) + x = self.pool(x) + return x + + diff --git a/timm/models/levit.py b/timm/models/levit.py index 9987e4ba..c4377bb1 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -293,10 +293,10 @@ class Attention(nn.Module): k = k.permute(0, 2, 1, 3) v = v.permute(0, 2, 1, 3) - attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) - x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, self.dh) x = self.proj(x) return x @@ -387,10 +387,10 @@ class AttentionSubsample(nn.Module): v = v.permute(0, 2, 1, 3) # BHNC q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) - attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) - x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) + x = torch.matmul(attn, v).transpose(1, 2).reshape(B, -1, self.dh) x = self.proj(x) return x diff --git a/timm/models/nest.py b/timm/models/nest.py index 9a477bf9..73f14da5 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -26,10 +26,12 @@ from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, named_apply +from .fx_helpers import fx_float_to_int from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_ from .layers import create_conv2d, create_pool2d, to_ntuple from .registry import register_model + _logger = logging.getLogger(__name__) @@ -83,12 +85,12 @@ class Attention(nn.Module): qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) - attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, N, N) + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale # (B, H, T, N, N) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) # (B, H, T, N, C'), permute -> (B, T, N, C', H) - x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C) + x = torch.matmul(attn, v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C) x = self.proj(x) x = self.proj_drop(x) return x # (B, T, N, C) @@ -128,8 +130,8 @@ class ConvPool(nn.Module): """ x is expected to have shape (B, C, H, W) """ - assert x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims' - assert x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims' + torch._assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims') + torch._assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims') x = self.conv(x) # Layer norm done over channel dim only x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) @@ -144,8 +146,8 @@ def blockify(x, block_size: int): block_size (int): edge length of a single square block in units of H, W """ B, H, W, C = x.shape - assert H % block_size == 0, '`block_size` must divide input height evenly' - assert W % block_size == 0, '`block_size` must divide input width evenly' + torch._assert(H % block_size == 0, '`block_size` must divide input height evenly') + torch._assert(W % block_size == 0, '`block_size` must divide input width evenly') grid_height = H // block_size grid_width = W // block_size x = x.reshape(B, grid_height, block_size, grid_width, block_size, C) @@ -160,7 +162,7 @@ def deblockify(x, block_size: int): block_size (int): edge length of a single square block in units of desired H, W """ B, T, _, C = x.shape - grid_size = int(math.sqrt(T)) + grid_size = fx_float_to_int(math.sqrt(T)) height = width = grid_size * block_size x = x.reshape(B, grid_size, grid_size, block_size, block_size, C) x = x.transpose(2, 3).reshape(B, height, width, C) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 4e0f2b21..ec86dbb8 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -27,6 +27,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg +from timm.models.fx_features import register_leaf_module from .registry import register_model from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\ get_act_layer, get_act_fn, get_attn, make_divisible @@ -318,6 +319,7 @@ class DownsampleAvg(nn.Module): return self.conv(self.pool(x)) +@register_leaf_module # FX feature extraction was giving different valued features. Perhaps to do with control flow? class NormFreeBlock(nn.Module): """Normalization-Free pre-activation block. """ diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index 279780be..f27ce5d8 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -10,6 +10,7 @@ Changes for timm, feature extraction, and rounded channel variant hacked togethe Copyright 2020 Ross Wightman """ +import torch import torch.nn as nn from functools import partial from math import ceil @@ -92,7 +93,7 @@ class LinearBottleneck(nn.Module): if self.use_shortcut: if self.drop_path is not None: x = self.drop_path(x) - x[:, 0:self.in_channels] += shortcut + x = torch.cat([x[:, 0:self.in_channels] + shortcut, x[:, self.in_channels:]], dim=1) return x diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index 822aeef8..53c7bcd5 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -22,10 +22,12 @@ import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .fx_helpers import fx_float_to_int from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ from .registry import register_model from .vision_transformer import checkpoint_filter_fn, _init_vit_weights + _logger = logging.getLogger(__name__) @@ -111,7 +113,7 @@ def window_reverse(windows, window_size: int, H: int, W: int): Returns: x: (B, H, W, C) """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) + B = fx_float_to_int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x @@ -175,7 +177,7 @@ class WindowAttention(nn.Module): q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) q = q * self.scale - attn = (q @ k.transpose(-2, -1)) + attn = torch.matmul(q, k.transpose(-2, -1)) relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH @@ -192,7 +194,7 @@ class WindowAttention(nn.Module): attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = torch.matmul(attn, v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x @@ -270,7 +272,7 @@ class SwinTransformerBlock(nn.Module): def forward(self, x): H, W = self.input_resolution B, L, C = x.shape - assert L == H * W, "input feature has wrong size" + torch._assert(L == H * W, "input feature has wrong size") shortcut = x x = self.norm1(x) @@ -329,8 +331,8 @@ class PatchMerging(nn.Module): """ H, W = self.input_resolution B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + torch._assert(L == H * W, "input feature has wrong size") + torch._assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.") x = x.view(B, H, W, C) diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 9829653c..f9510487 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -9,10 +9,10 @@ https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT import math import torch import torch.nn as nn -from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.helpers import build_model_with_cfg +from timm.models.fx_helpers import fx_and from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers.helpers import to_2tuple from timm.models.registry import register_model @@ -64,11 +64,11 @@ class Attention(nn.Module): q, k = qk.unbind(0) # make torchscript happy (cannot use tensor as tuple) v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) - attn = (q @ k.transpose(-2, -1)) * self.scale + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, -1) x = self.proj(x) x = self.proj_drop(x) return x @@ -109,7 +109,9 @@ class Block(nn.Module): pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed))) # outer B, N, C = patch_embed.size() - patch_embed[:, 1:] = patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)) + patch_embed = torch.cat( + [patch_embed[:, 0:1], patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))], + dim=1) patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed))) patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed))) return pixel_embed, patch_embed @@ -136,8 +138,8 @@ class PixelEmbed(nn.Module): def forward(self, x, pixel_pos): B, C, H, W = x.shape - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + torch._assert(fx_and(H == self.img_size[0], W == self.img_size[1]), + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") x = self.proj(x) x = self.unfold(x) x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1]) diff --git a/timm/models/twins.py b/timm/models/twins.py index 4aed09d9..7b5afafb 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -22,6 +22,7 @@ from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .layers import Mlp, DropPath, to_2tuple, trunc_normal_ +from .fx_features import register_leaf_module from .registry import register_model from .vision_transformer import Attention from .helpers import build_model_with_cfg, overlay_external_default_cfg @@ -62,6 +63,7 @@ default_cfgs = { Size_ = Tuple[int, int] +@register_leaf_module # FX can't symbolically trace control flow in forward method class LocallyGroupedAttn(nn.Module): """ LSA: self attention within a group """ @@ -98,10 +100,10 @@ class LocallyGroupedAttn(nn.Module): qkv = self.qkv(x).reshape( B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) q, k, v = qkv[0], qkv[1], qkv[2] - attn = (q @ k.transpose(-2, -1)) * self.scale + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + attn = torch.matmul(attn, v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) if pad_r > 0 or pad_b > 0: x = x[:, :H, :W, :].contiguous() @@ -183,11 +185,11 @@ class GlobalSubSampleAttn(nn.Module): kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) k, v = kv[0], kv[1] - attn = (q @ k.transpose(-2, -1)) * self.scale + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) diff --git a/timm/models/vgg.py b/timm/models/vgg.py index 8bea03e7..aee41b25 100644 --- a/timm/models/vgg.py +++ b/timm/models/vgg.py @@ -12,6 +12,7 @@ from typing import Union, List, Dict, Any, cast from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg +from .fx_features import register_leaf_module from .layers import ClassifierHead, ConvBnAct from .registry import register_model @@ -52,6 +53,7 @@ cfgs: Dict[str, List[Union[str, int]]] = { } +@register_leaf_module # FX can't symbolically trace control flow in forward method class ConvMlp(nn.Module): def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0, diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 6e832cd0..6ed43102 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -100,10 +100,10 @@ class Attention(nn.Module): x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3) q, k, v = x[0], x[1], x[2] - attn = (q @ k.transpose(-2, -1)) * self.scale + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = attn @ v + x = torch.matmul(attn, v) x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W) x = self.proj(x) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 94ae2666..fb939624 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -192,11 +192,11 @@ class Attention(nn.Module): qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) - attn = (q @ k.transpose(-2, -1)) * self.scale + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x diff --git a/timm/models/xcit.py b/timm/models/xcit.py index 2942ed8a..9be29046 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -21,6 +21,7 @@ from .vision_transformer import _cfg, Mlp from .registry import register_model from .layers import DropPath, trunc_normal_, to_2tuple from .cait import ClassAttn +from .fx_features import register_leaf_module def _cfg(url='', **kwargs): @@ -97,6 +98,7 @@ default_cfgs = { } +@register_leaf_module # FX can't symbolically trace torch.arange in forward method class PositionalEncodingFourier(nn.Module): """ Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper. @@ -272,12 +274,12 @@ class XCA(nn.Module): # Paper section 3.2 l2-Normalization and temperature scaling q = torch.nn.functional.normalize(q, dim=-1) k = torch.nn.functional.normalize(k, dim=-1) - attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = torch.matmul(q, k.transpose(-2, -1)) * self.temperature attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) # (B, H, C', N), permute -> (B, N, H, C') - x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C) + x = torch.matmul(attn, v).permute(0, 3, 1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x From a6c24b936ba91a02686fb10cf7fbbe50216226a4 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 12 Aug 2021 15:31:36 +0100 Subject: [PATCH 03/26] Tests to enforce all models FX traceable --- tests/test_models.py | 80 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/tests/test_models.py b/tests/test_models.py index c0d0e901..91fd3543 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -7,6 +7,7 @@ import fnmatch import timm from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \ get_model_default_value +from timm.models.fx_features import NodePathTracer if hasattr(torch._C, '_jit_set_profiling_executor'): # legacy executor is too slow to compile large models for unit tests @@ -297,3 +298,82 @@ def test_model_forward_features(model_name, batch_size): assert e == o.shape[1] assert o.shape[0] == batch_size assert not torch.isnan(o).any() + + +@pytest.mark.timeout(120) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) +@pytest.mark.parametrize('batch_size', [1]) +def test_model_forward_fx(model_name, batch_size): + """Symbolically trace each model and run single forward pass through the resulting GraphModule""" + model = create_model(model_name, pretrained=False) + model.eval() + + input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE) + if max(input_size) > MAX_FWD_SIZE: + pytest.skip("Fixed input size model > limit.") + + tracer = NodePathTracer() + graph = tracer.trace(model) + model = torch.fx.GraphModule(model, graph) + + inputs = torch.randn((batch_size, *input_size)) + outputs = model(inputs) + + assert outputs.shape[0] == batch_size + assert not torch.isnan(outputs).any(), 'Output included NaNs' + + +@pytest.mark.timeout(120) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True)) +@pytest.mark.parametrize('batch_size', [2]) +def test_model_backward_fx(model_name, batch_size): + """Symbolically trace each model and run single backward pass through the resulting GraphModule""" + input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE) + if max(input_size) > MAX_BWD_SIZE: + pytest.skip("Fixed input size model > limit.") + + model = create_model(model_name, pretrained=False, num_classes=42) + model.train() + num_params = sum([x.numel() for x in model.parameters()]) + + tracer = NodePathTracer() + graph = tracer.trace(model) + model = torch.fx.GraphModule(model, graph) + + inputs = torch.randn((batch_size, *input_size)) + outputs = model(inputs) + if isinstance(outputs, tuple): + outputs = torch.cat(outputs) + outputs.mean().backward() + for n, x in model.named_parameters(): + assert x.grad is not None, f'No gradient for {n}' + num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None]) + + assert outputs.shape[-1] == 42 + assert num_params == num_grad, 'Some parameters are missing gradients' + assert not torch.isnan(outputs).any(), 'Output included NaNs' + + +@pytest.mark.timeout(120) +@pytest.mark.parametrize( + 'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True)) +@pytest.mark.parametrize('batch_size', [1]) +def test_model_forward_fx_torchscript(model_name, batch_size): + """Symbolically trace each model, script it, and run single forward pass""" + input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) + if max(input_size) > MAX_JIT_SIZE: + pytest.skip("Fixed input size model > limit.") + + with set_scriptable(True): + model = create_model(model_name, pretrained=False) + model.eval() + + tracer = NodePathTracer() + graph = tracer.trace(model) + model = torch.fx.GraphModule(model, graph) + + model = torch.jit.script(model) + outputs = model(torch.randn((batch_size, *input_size))) + + assert outputs.shape[0] == batch_size + assert not torch.isnan(outputs).any(), 'Output included NaNs' \ No newline at end of file From 02c3a75a45deea8b8728da14e0d0b5106e06d98b Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sat, 28 Aug 2021 17:54:22 +0100 Subject: [PATCH 04/26] wip - make it possible to use fx graph in train and eval mode --- timm/models/fx_features.py | 460 ++++++++++++++++++++++++++----------- 1 file changed, 329 insertions(+), 131 deletions(-) diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index e00b080f..9a76e041 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -7,18 +7,21 @@ An extension/alternative to timm.models.features making use of PyTorch FX. Here, 3. Write the resulting graph into a GraphModule (PyTorch FX functionality) Copyright 2021 Alexander Soare """ -from typing import Callable, Dict +from typing import Callable, Dict, Union, List, Optional import math from collections import OrderedDict from pprint import pprint from inspect import ismethod import re import warnings +from copy import deepcopy +from itertools import chain import torch from torch import nn from torch import fx import torch.nn.functional as F +from torch.fx.graph_module import _copy_attr from .features import _get_feature_info from .fx_helpers import fx_and, fx_float_to_int @@ -84,29 +87,48 @@ class LeafNodeTracer(TimmTracer): return super().is_leaf_module(m, module_qualname) +def _is_subseq(x, y): + """Check if y is a subseqence of x + https://stackoverflow.com/a/24017747/4391249 + """ + iter_x = iter(x) + return all(any(x_item == y_item for x_item in iter_x) for y_item in y) + + # Taken from https://github.com/pytorch/examples/blob/master/fx/module_tracer.py with modifications for storing # qualified names for all Nodes, not just top-level Modules class NodePathTracer(LeafNodeTracer): """ - NodePathTracer is an FX tracer that, for each operation, also records the qualified name of the Node from which the - operation originated. A qualified name here is a `.` seperated path walking the hierarchy from top level module - down to leaf operation or leaf module. The name of the top level module is not included as part of the qualified - name. For example, if we trace a module who's forward method applies a ReLU module, the qualified name for that - node will simply be 'relu'. + NodePathTracer is an FX tracer that, for each operation, also records the + qualified name of the Node from which the operation originated. A + qualified name here is a `.` seperated path walking the hierarchy from top + level module down to leaf operation or leaf module. The name of the top + level module is not included as part of the qualified name. For example, + if we trace a module who's forward method applies a ReLU module, the + qualified name for that node will simply be 'relu'. + + Some notes on the specifics: + - Nodes are recorded to `self.node_to_qualname` which is a dictionary + mapping a given Node object to its qualified name. + - Nodes are recorded in the order which they are executed during + tracing. + - When a duplicate qualified name is encountered, a suffix of the form + _{int} is added. The counter starts from 1. """ - def __init__(self): - super().__init__() + def __init__(self, *args, **kwargs): + super(NodePathTracer, self).__init__(*args, **kwargs) # Track the qualified name of the Node being traced - self.current_module_qualname : str = '' + self.current_module_qualname = '' # A map from FX Node to the qualified name self.node_to_qualname = OrderedDict() def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs): """ - Override of Tracer.call_module (see https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer.call_module). + Override of `fx.Tracer.call_module` This override: 1) Stores away the qualified name of the caller for restoration later - 2) Installs the qualified name of the caller in `current_module_qualname` for retrieval by `create_proxy` + 2) Adds the qualified name of the caller to + `current_module_qualname` for retrieval by `create_proxy` 3) Once a leaf module is reached, calls `create_proxy` 4) Restores the caller's qualified name into current_module_qualname """ @@ -121,7 +143,8 @@ class NodePathTracer(LeafNodeTracer): finally: self.current_module_qualname = old_qualname - def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs, name=None, type_expr=None): + def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs, + name=None, type_expr=None) -> fx.proxy.Proxy: """ Override of `Tracer.create_proxy`. This override intercepts the recording of every operation and stores away the current traced module's qualified @@ -132,160 +155,335 @@ class NodePathTracer(LeafNodeTracer): self.current_module_qualname, proxy.node) return proxy - def _get_node_qualname(self, module_qualname: str, node: fx.node.Node): + def _get_node_qualname( + self, module_qualname: str, node: fx.node.Node) -> str: node_qualname = module_qualname if node.op == 'call_module': - # Node terminates in a leaf module so the module_qualname is a complete description of the node - # Just need to check if this module has appeared before. If so add postfix counter starting from _1 for the - # first reappearance (this follows the way that repeated leaf ops are enumerated by PyTorch FX) + # Node terminates in a leaf module so the module_qualname is a + # complete description of the node for existing_qualname in reversed(self.node_to_qualname.values()): - # Check to see if existing_qualname is of the form {node_qualname} or {node_qualname}_{int} - if re.match(rf'{node_qualname}(_[0-9]+)?$', existing_qualname) is not None: + # Check to see if existing_qualname is of the form + # {node_qualname} or {node_qualname}_{int} + if re.match(rf'{node_qualname}(_[0-9]+)?$', + existing_qualname) is not None: postfix = existing_qualname.replace(node_qualname, '') if len(postfix): - # existing_qualname is of the form {node_qualname}_{int} + # Existing_qualname is of the form {node_qualname}_{int} next_index = int(postfix[1:]) + 1 else: # existing_qualname is of the form {node_qualname} next_index = 1 node_qualname += f'_{next_index}' - break + break else: - # Node terminates in non- leaf module so the node name needs to be appended - if len(node_qualname) > 0: # only append '.' if we are deeper than the top level module + # Node terminates in non- leaf module so the node name needs to be + # appended + if len(node_qualname) > 0: + # Only append '.' if we are deeper than the top level module node_qualname += '.' node_qualname += str(node) return node_qualname -def print_graph_node_qualified_names(model: nn.Module): +def _warn_graph_differences( + train_tracer: NodePathTracer, eval_tracer: NodePathTracer): """ - Dev utility to prints nodes in order of execution. Useful for choosing `nodes` for a FeatureGraphNet design. - This is useful for two reasons: - 1. Not all submodules are traced through. Some are treated as leaf modules. See `LeafNodeTracer` - 2. Leaf ops that occur more than once in the graph get a `_{counter}` postfix. - - WARNING: Changes to the operations in the original module might not change the module's overall behaviour, but they - may result in changes to the postfixes for the names of repeated ops, thereby breaking feature extraction. + Utility function for warning the user if there are differences between + the train graph and the eval graph. + """ + train_nodes = list(train_tracer.node_to_qualname.values()) + eval_nodes = list(eval_tracer.node_to_qualname.values()) + + if len(train_nodes) == len(eval_nodes) and [ + t == e for t, e in zip(train_nodes, eval_nodes)]: + return + + suggestion_msg = ( + "When choosing nodes for feature extraction, you may need to specify " + "output nodes for train and eval mode separately") + + if _is_subseq(train_nodes, eval_nodes): + msg = ("NOTE: The nodes obtained by tracing the model in eval mode " + "are a subsequence of those obtained in train mode. ") + elif _is_subseq(eval_nodes, train_nodes): + msg = ("NOTE: The nodes obtained by tracing the model in train mode " + "are a subsequence of those obtained in eval mode. ") + else: + msg = ("The nodes obtained by tracing the model in train mode " + "are different to those obtained in eval mode. ") + warnings.warn(msg + suggestion_msg) + + +def print_graph_node_qualified_names( + model: nn.Module, tracer_kwargs: Dict = {}): """ - tracer = NodePathTracer() - tracer.trace(model) - pprint(list(tracer.node_to_qualname.values())) + Dev utility to prints nodes in order of execution. Useful for choosing + nodes for a FeatureGraphNet design. There are two reasons that qualified + node names can't easily be read directly from the code for a model: + 1. Not all submodules are traced through. Modules from `torch.nn` all + fall within this category. + 2. Node qualified names that occur more than once in the graph get a + `_{counter}` postfix. + The model will be traced twice: once in train mode, and once in eval mode. + If there are discrepancies between the graphs produced, both sets will + be printed and the user will be warned. + + Args: + model (nn.Module): model on which we will extract the features + tracer_kwargs (Dict): a dictionary of keywork arguments for + `NodePathTracer` (which passes them onto it's parent class + `torch.fx.Tracer`). + """ + train_tracer = NodePathTracer(**tracer_kwargs) + train_tracer.trace(model.train()) + eval_tracer = NodePathTracer(**tracer_kwargs) + eval_tracer.trace(model.eval()) + train_nodes = list(train_tracer.node_to_qualname.values()) + eval_nodes = list(eval_tracer.node_to_qualname.values()) + if len(train_nodes) == len(eval_nodes) and [ + t == e for t, e in zip(train_nodes, eval_nodes)]: + # Nodes are aligned in train vs eval mode + pprint(list(train_tracer.node_to_qualname.values())) + return + print("Nodes from train mode:") + pprint(list(train_tracer.node_to_qualname.values())) + print() + print("Nodes from eval mode:") + pprint(list(eval_tracer.node_to_qualname.values())) + print() + _warn_graph_differences(train_tracer, eval_tracer) + + +class DualGraphModule(fx.GraphModule): + """ + A derivative of `fx.GraphModule`. Differs in the following ways: + - Requires a train and eval version of the underlying graph + - Copies submodules according to the nodes of both train and eval graphs. + - Calling train(mode) switches between train graph and eval graph. + """ + def __init__(self, + root: torch.nn.Module, + train_graph: fx.Graph, + eval_graph: fx.Graph, + class_name: str = 'GraphModule'): + """ + Args: + root (torch.nn.Module): module from which the copied module + hierarchy is built + train_graph (Graph): the graph that should be used in train mode + eval_graph (Graph): the graph that should be used in eval mode + """ + super(fx.GraphModule, self).__init__() + + self.__class__.__name__ = class_name + + self.train_graph = train_graph + self.eval_graph = eval_graph + + # Copy all get_attr and call_module ops (indicated by BOTH train and + # eval graphs) + for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)): + if node.op in ['get_attr', 'call_module']: + assert isinstance(node.target, str) + _copy_attr(root, self, node.target) + + # eval mode by default + self.eval() + self.graph = eval_graph + + # (borrowed from fx.GraphModule): + # Store the Tracer class responsible for creating a Graph separately as part of the + # GraphModule state, except when the Tracer is defined in a local namespace. + # Locally defined Tracers are not pickleable. This is needed because torch.package will + # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer + # to re-create the Graph during deserialization. + # TODO uncomment this when https://github.com/pytorch/pytorch/pull/63121 is available + # assert self.eval_graph._tracer_cls == self.train_graph._tracer_cls, \ + # "Train mode and eval mode should use the same tracer class" + # self._tracer_cls = None + # if self.graph._tracer_cls and '' not in self.graph._tracer_cls.__qualname__: + # self._tracer_cls = self.graph._tracer_cls + + def train(self, mode=True): + """ + Swap out the graph depending on the training mode. + NOTE this should be safe when calling model.eval() because that just + calls this with mode == False. + """ + if mode: + self.graph = self.train_graph + else: + self.graph = self.eval_graph + return super().train(mode=mode) -def get_intermediate_nodes(model: nn.Module, return_nodes: Dict[str, str]) -> nn.Module: +def build_feature_graph_net( + model: nn.Module, + return_nodes: Union[List[str], Dict[str, str]], + train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, + eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, + tracer_kwargs: Dict = {}) -> fx.GraphModule: """ - Creates a new FX-based module that returns intermediate nodes from a given model. This is achieved by re-writing - the computation graph of the model via FX to return the desired nodes as outputs. All unused nodes are removed, - together with their corresponding parameters. + Creates a new graph module that returns intermediate nodes from a given + model as dictionary with user specified keys as strings, and the requested + outputs as values. This is achieved by re-writing the computation graph of + the model via FX to return the desired nodes as outputs. All unused nodes + are removed, together with their corresponding parameters. + + A note on node specification: A node qualified name is specified as a `.` + seperated path walking the hierarchy from top level module down to leaf + operation or leaf module. For instance `blocks.5.3.bn1`. The keys of the + `return_nodes` argument should point to either a node's qualified name, + or some truncated version of it. For example, one could provide `blocks.5` + as a key, and the last node with that prefix will be selected. + `print_graph_node_qualified_names` is a useful helper function for getting + a list of qualified names of a model. + + An attempt is made to keep all non-parametric properties of the original + model, but existing properties of the constructed `GraphModule` are not + overwritten. + Args: model (nn.Module): model on which we will extract the features - return_nodes (Dict[name, new_name]): a dict containing the names (or partial names - see note below) of the - nodes for which the activations will be returned as the keys. The values of the dict are the names - of the returned activations (which the user can specify). - A note on node specification: A node is specified as a `.` seperated path walking the hierarchy from top - level module down to leaf operation or leaf module. For instance `blocks.5.3.bn1`. Nevertheless, the keys - in this dict need not be fully specified. One could provide `blocks.5` as a key, and the last node with - that prefix will be selected. - While designing a feature extractor one can use the `print_graph_node_qualified_names` utility as a guide - to which nodes are available. - Acknowledgement: Starter code from https://github.com/pytorch/vision/pull/3597 + return_nodes (Union[List[name], Dict[name, new_name]])): either a list + or a dict containing the names (or partial names - see note above) + of the nodes for which the activations will be returned. If it is + a `Dict`, the keys are the qualified node names, and the values + are the user-specified keys for the graph module's returned + dictionary. If it is a `List`, it is treated as a `Dict` mapping + node specification strings directly to output names. + tracer_kwargs (Dict): a dictionary of keywork arguments for + `NodePathTracer` (which passes them onto it's parent class + `torch.fx.Tracer`). + + Examples:: + + >>> model = torchvision.models.resnet18() + >>> # extract layer1 and layer3, giving as names `feat1` and feat2` + >>> graph_module = torchvision.models._utils.build_feature_graph_net(m, + >>> {'layer1': 'feat1', 'layer3': 'feat2'}) + >>> out = graph_module(torch.rand(1, 3, 224, 224)) + >>> print([(k, v.shape) for k, v in out.items()]) + >>> [('feat1', torch.Size([1, 64, 56, 56])), + >>> ('feat2', torch.Size([1, 256, 14, 14]))] + """ + is_training = model.training + + if isinstance(return_nodes, list): + return_nodes = {n: n for n in return_nodes} return_nodes = {str(k): str(v) for k, v in return_nodes.items()} - # Instantiate our NodePathTracer and use that to trace the model - tracer = NodePathTracer() - graph = tracer.trace(model) - - name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__ - graph_module = fx.GraphModule(tracer.root, graph, name) - - available_nodes = [f'{v}.{k}' for k, v in tracer.node_to_qualname.items()] - # FIXME We don't know if we should expect this to happen - assert len(set(available_nodes)) == len(available_nodes), \ - "There are duplicate nodes! Please raise an issue https://github.com/rwightman/pytorch-image-models/issues" - # Check that all outputs in return_nodes are present in the model - for query in return_nodes.keys(): - if not any([m.startswith(query) for m in available_nodes]): - raise ValueError(f"return_node: {query} is not present in model") - - # Remove existing output nodes - orig_output_node = None - for n in reversed(graph_module.graph.nodes): - if n.op == "output": - orig_output_node = n - assert orig_output_node - # And remove it - graph_module.graph.erase_node(orig_output_node) - # Find nodes corresponding to return_nodes and make them into output_nodes - nodes = [n for n in graph_module.graph.nodes] - output_nodes = OrderedDict() - for n in reversed(nodes): - if 'tensor_constant' in str(n): - # NOTE Without this control flow we would get a None value for - # `module_qualname = tracer.node_to_qualname.get(n)`. On the other hand, we can safely assume that we'll - # never need to get this as an interesting intermediate node. - continue - module_qualname = tracer.node_to_qualname.get(n) - for query in return_nodes: - depth = query.count('.') - if '.'.join(module_qualname.split('.')[:depth+1]) == query: - output_nodes[return_nodes[query]] = n - return_nodes.pop(query) - break - output_nodes = OrderedDict(reversed(list(output_nodes.items()))) - - # And add them in the end of the graph - with graph_module.graph.inserting_after(nodes[-1]): - graph_module.graph.output(output_nodes) - - # Remove unused modules / parameters - graph_module.graph.eliminate_dead_code() - graph_module.recompile() - graph_module = fx.GraphModule(graph_module, graph_module.graph, name) + assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), \ + ("If any of `train_return_nodes` and `eval_return_nodes` are " + "specified, then both should be specified") + + if train_return_nodes is None: + train_return_nodes = deepcopy(return_nodes) + eval_return_nodes = deepcopy(return_nodes) + + # Repeat the tracing and graph rewriting for train and eval mode + tracers = {} + graphs = {} + return_nodes = { + 'train': train_return_nodes, + 'eval': eval_return_nodes + } + for mode in ['train', 'eval']: + if mode == 'train': + model.train() + elif mode == 'eval': + model.eval() + + # Instantiate our NodePathTracer and use that to trace the model + tracer = NodePathTracer(**tracer_kwargs) + graph = tracer.trace(model) + + name = model.__class__.__name__ if isinstance( + model, nn.Module) else model.__name__ + graph_module = fx.GraphModule(tracer.root, graph, name) + + available_nodes = [f'{v}.{k}' for k, v in tracer.node_to_qualname.items()] + # FIXME We don't know if we should expect this to happen + assert len(set(available_nodes)) == len(available_nodes), \ + "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues" + # Check that all outputs in return_nodes are present in the model + for query in return_nodes[mode].keys(): + if not any([m.startswith(query) for m in available_nodes]): + raise ValueError(f"return_node: {query} is not present in model") + + # Remove existing output nodes (train mode) + orig_output_nodes = [] + for n in reversed(graph_module.graph.nodes): + if n.op == "output": + orig_output_nodes.append(n) + assert len(orig_output_nodes) + for n in orig_output_nodes: + graph_module.graph.erase_node(n) + + # Find nodes corresponding to return_nodes and make them into output_nodes + nodes = [n for n in graph_module.graph.nodes] + output_nodes = OrderedDict() + for n in reversed(nodes): + if 'tensor_constant' in str(n): + # NOTE Without this control flow we would get a None value for + # `module_qualname = tracer.node_to_qualname.get(n)`. + # On the other hand, we can safely assume that we'll never need to + # get this as an interesting intermediate node. + continue + module_qualname = tracer.node_to_qualname.get(n) + for query in return_nodes[mode]: + depth = query.count('.') + if '.'.join(module_qualname.split('.')[:depth + 1]) == query: + output_nodes[return_nodes[mode][query]] = n + return_nodes[mode].pop(query) + break + output_nodes = OrderedDict(reversed(list(output_nodes.items()))) + + # And add them in the end of the graph + with graph_module.graph.inserting_after(nodes[-1]): + graph_module.graph.output(output_nodes) + + # Remove unused modules / parameters + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + # Keep track of the tracer and graph so we can choose the main one + tracers[mode] = tracer + graphs[mode] = graph + + # Warn user if there are any discrepancies between the graphs of the + # train and eval modes + _warn_graph_differences(tracers['train'], tracers['eval']) + + # Build the final graph module + graph_module = DualGraphModule( + model, graphs['train'], graphs['eval'], class_name=name) + + # Keep non-parameter model properties for reference + for attr_str in model.__dir__(): + attr = getattr(model, attr_str) + if (not attr_str.startswith('_') + and attr_str not in graph_module.__dir__() + and not ismethod(attr) + and not isinstance(attr, (nn.Module, nn.Parameter))): + setattr(graph_module, attr_str, attr) + + # Restore original training mode + graph_module.train(is_training) + return graph_module class FeatureGraphNet(nn.Module): - """ - Take the provided model and transform it into a graph module. This class wraps the resulting graph module while - also keeping the original model's non-parameter properties for reference. The original model is discarded. - - WARNING: Changes to the operations in the original module might not change the module's overall behaviour, but they - may result in changes to the postfixes for the names of repeated ops, thereby breaking feature extraction. - - TODO: FIX THIS - WARNING: This puts the input model into eval mode prior to tracing. This means that any control flow dependent on - the model being in train mode will be lost. - """ def __init__(self, model, out_indices, out_map=None): super().__init__() - model.eval() self.feature_info = _get_feature_info(model, out_indices) if out_map is not None: assert len(out_map) == len(out_indices) - # NOTE the feature_info key is innapropriately named 'module' because prior to FX only modules could be - # provided. Recall that here, we may also provide nodes referring to individual ops return_nodes = {info['module']: out_map[i] if out_map is not None else info['module'] for i, info in enumerate(self.feature_info) if i in out_indices} - self.graph_module = get_intermediate_nodes(model, return_nodes) - # Keep non-parameter model properties for reference - for attr_str in model.__dir__(): - attr = getattr(model, attr_str) - if (not attr_str.startswith('_') and attr_str not in self.__dir__() and not ismethod(attr) - and not isinstance(attr, (nn.Module, nn.Parameter))): - setattr(self, attr_str, attr) - + self.graph_module = build_feature_graph_net(model, return_nodes) + def forward(self, x): - return list(self.graph_module(x).values()) - - def train(self, mode=True): - """ - NOTE: This also covers `self.eval()` as that just does self.train(False) - """ - if mode: - warnings.warn( - "Setting a FeatureGraphNet to training mode won't necessarily have the desired effect. Control " - "flow depending on `self.training` will follow the `False` path. See FeatureGraphNet doc-string " - "for more details.") - super().train(mode) \ No newline at end of file + return list(self.graph_module(x).values()) \ No newline at end of file From 0149ec30d7d39ae8ed65b95bd6ed601055a97f4c Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sun, 7 Nov 2021 14:54:42 +0000 Subject: [PATCH 05/26] wip - attempting to rebase --- timm/models/fx_features.py | 4 ++-- timm/models/fx_helpers.py | 10 ---------- timm/models/layers/bottleneck_attn.py | 1 - timm/models/layers/halo_attn.py | 2 +- timm/models/layers/non_local_attn.py | 4 ++-- timm/models/layers/patch_embed.py | 4 ---- timm/models/tnt.py | 5 +++-- 7 files changed, 8 insertions(+), 22 deletions(-) diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index 9a76e041..310cc465 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -24,7 +24,7 @@ import torch.nn.functional as F from torch.fx.graph_module import _copy_attr from .features import _get_feature_info -from .fx_helpers import fx_and, fx_float_to_int +from .fx_helpers import fx_float_to_int # Layers we went to treat as leaf modules for FeatureGraphNet from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame @@ -55,7 +55,7 @@ def register_leaf_module(module: nn.Module): # These functions will not be traced through -_autowrap_functions=(fx_float_to_int, fx_and) +_autowrap_functions=(fx_float_to_int,) class TimmTracer(fx.Tracer): diff --git a/timm/models/fx_helpers.py b/timm/models/fx_helpers.py index 1955d5b1..878ba381 100644 --- a/timm/models/fx_helpers.py +++ b/timm/models/fx_helpers.py @@ -1,14 +1,4 @@ - -def fx_and(a: bool, b: bool) -> bool: - """ - Symbolic tracing helper to substitute for normal usage of `* and *` within `torch._assert`. - Hint: Symbolic tracing does not support control flow but since an `assert` is either a dead-end or not, this hack - is okay. - """ - return (a and b) - - def fx_float_to_int(x: float) -> int: """ Symbolic tracing helper to substitute for inbuilt `int`. diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py index 305f9de3..c56c5821 100644 --- a/timm/models/layers/bottleneck_attn.py +++ b/timm/models/layers/bottleneck_attn.py @@ -22,7 +22,6 @@ import torch.nn.functional as F from .helpers import to_2tuple, make_divisible from .weight_init import trunc_normal_ -from timm.models.fx_helpers import fx_and def rel_logits_1d(q, rel_k, permute_mask: List[int]): diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index 0bd611b1..babfcb06 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -24,7 +24,7 @@ import torch.nn.functional as F from .helpers import make_divisible from .weight_init import trunc_normal_ -from timm.models.fx_helpers import fx_and +from timm.models.fx_helpers import def rel_logits_1d(q, rel_k, permute_mask: List[int]): diff --git a/timm/models/layers/non_local_attn.py b/timm/models/layers/non_local_attn.py index 517e28a8..f933ece2 100644 --- a/timm/models/layers/non_local_attn.py +++ b/timm/models/layers/non_local_attn.py @@ -10,7 +10,6 @@ from torch.nn import functional as F from .conv_bn_act import ConvBnAct from .helpers import make_divisible -from timm.models.fx_helpers import fx_and class NonLocalAttn(nn.Module): @@ -96,7 +95,8 @@ class BilinearAttnTransform(nn.Module): return x def forward(self, x): - torch._assert(fx_and(x.shape[-1] % self.block_size == 0, x.shape[-2] % self.block_size == 0), '') + torch._assert(x.shape[-1] % self.block_size == 0, '') + torch._assert(x.shape[-2] % self.block_size == 0, '') B, C, H, W = x.shape out = self.conv1(x) rp = F.adaptive_max_pool2d(out, (self.block_size, 1)) diff --git a/timm/models/layers/patch_embed.py b/timm/models/layers/patch_embed.py index 157bc250..6a7facef 100644 --- a/timm/models/layers/patch_embed.py +++ b/timm/models/layers/patch_embed.py @@ -9,11 +9,7 @@ Hacked together by / Copyright 2020 Ross Wightman from torch import nn as nn from .helpers import to_2tuple -<<<<<<< HEAD from .trace_utils import _assert -======= -from timm.models.fx_helpers import fx_and ->>>>>>> Make all models FX traceable class PatchEmbed(nn.Module): diff --git a/timm/models/tnt.py b/timm/models/tnt.py index f9510487..92108fe5 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -12,7 +12,6 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.helpers import build_model_with_cfg -from timm.models.fx_helpers import fx_and from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers.helpers import to_2tuple from timm.models.registry import register_model @@ -138,7 +137,9 @@ class PixelEmbed(nn.Module): def forward(self, x, pixel_pos): B, C, H, W = x.shape - torch._assert(fx_and(H == self.img_size[0], W == self.img_size[1]), + torch._assert(H == self.img_size[0], + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") + torch._assert(W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") x = self.proj(x) x = self.unfold(x) From cf4561ca729c69bd9e5c600ca2dc40b7510c50b2 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 12 Aug 2021 15:31:02 +0100 Subject: [PATCH 06/26] Add FX based FeatureGraphNet capability --- timm/models/fx_features.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index 310cc465..3d253444 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -486,4 +486,4 @@ class FeatureGraphNet(nn.Module): self.graph_module = build_feature_graph_net(model, return_nodes) def forward(self, x): - return list(self.graph_module(x).values()) \ No newline at end of file + return list(self.graph_module(x).values()) From e051dce35451451b8ac7eee7b8abab38325a26b0 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 12 Aug 2021 15:31:24 +0100 Subject: [PATCH 07/26] Make all models FX traceable --- timm/models/layers/bottleneck_attn.py | 1 + timm/models/layers/halo_attn.py | 1 - timm/models/layers/non_local_attn.py | 1 + timm/models/tnt.py | 1 + 4 files changed, 3 insertions(+), 1 deletion(-) diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py index c56c5821..305f9de3 100644 --- a/timm/models/layers/bottleneck_attn.py +++ b/timm/models/layers/bottleneck_attn.py @@ -22,6 +22,7 @@ import torch.nn.functional as F from .helpers import to_2tuple, make_divisible from .weight_init import trunc_normal_ +from timm.models.fx_helpers import fx_and def rel_logits_1d(q, rel_k, permute_mask: List[int]): diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index babfcb06..ec93474f 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -24,7 +24,6 @@ import torch.nn.functional as F from .helpers import make_divisible from .weight_init import trunc_normal_ -from timm.models.fx_helpers import def rel_logits_1d(q, rel_k, permute_mask: List[int]): diff --git a/timm/models/layers/non_local_attn.py b/timm/models/layers/non_local_attn.py index f933ece2..5f83005c 100644 --- a/timm/models/layers/non_local_attn.py +++ b/timm/models/layers/non_local_attn.py @@ -10,6 +10,7 @@ from torch.nn import functional as F from .conv_bn_act import ConvBnAct from .helpers import make_divisible +from timm.models.fx_helpers import fx_and class NonLocalAttn(nn.Module): diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 92108fe5..298808c3 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -12,6 +12,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.helpers import build_model_with_cfg +from timm.models.fx_helpers import fx_and from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers.helpers import to_2tuple from timm.models.registry import register_model From b25ff9676848a25df5c87f489bcece89f216e749 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Fri, 12 Nov 2021 20:42:45 +0000 Subject: [PATCH 08/26] wip - pre-rebase --- tests/test_models.py | 61 +++- timm/models/cait.py | 8 +- timm/models/coat.py | 11 +- timm/models/convit.py | 10 +- timm/models/crossvit.py | 40 ++- timm/models/fx_features.py | 473 ++----------------------- timm/models/fx_helpers.py | 7 - timm/models/layers/bottleneck_attn.py | 8 +- timm/models/layers/evo_norm.py | 6 +- timm/models/layers/global_context.py | 3 +- timm/models/layers/halo_attn.py | 9 +- timm/models/layers/lambda_layer.py | 4 +- timm/models/layers/non_local_attn.py | 8 +- timm/models/layers/selective_kernel.py | 3 +- timm/models/layers/swin_attn.py | 183 ---------- timm/models/levit.py | 8 +- timm/models/nest.py | 19 +- timm/models/nfnet.py | 2 - timm/models/swin_transformer.py | 16 +- timm/models/tnt.py | 10 +- timm/models/twins.py | 12 +- timm/models/vgg.py | 4 +- timm/models/visformer.py | 4 +- timm/models/vision_transformer.py | 4 +- timm/models/xcit.py | 6 +- 25 files changed, 185 insertions(+), 734 deletions(-) delete mode 100644 timm/models/fx_helpers.py delete mode 100644 timm/models/layers/swin_attn.py diff --git a/tests/test_models.py b/tests/test_models.py index 91fd3543..f7233ef3 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,10 +4,12 @@ import platform import os import fnmatch +from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer + import timm from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \ get_model_default_value -from timm.models.fx_features import NodePathTracer +from timm.models.fx_features import _leaf_modules, _autowrap_functions if hasattr(torch._C, '_jit_set_profiling_executor'): # legacy executor is too slow to compile large models for unit tests @@ -312,12 +314,14 @@ def test_model_forward_fx(model_name, batch_size): if max(input_size) > MAX_FWD_SIZE: pytest.skip("Fixed input size model > limit.") - tracer = NodePathTracer() - graph = tracer.trace(model) - model = torch.fx.GraphModule(model, graph) + train_nodes, eval_nodes = get_graph_node_names( + model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) + model = create_feature_extractor( + model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]], + tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) inputs = torch.randn((batch_size, *input_size)) - outputs = model(inputs) + outputs = model(inputs)[eval_nodes[-1]] assert outputs.shape[0] == batch_size assert not torch.isnan(outputs).any(), 'Output included NaNs' @@ -336,12 +340,30 @@ def test_model_backward_fx(model_name, batch_size): model.train() num_params = sum([x.numel() for x in model.parameters()]) - tracer = NodePathTracer() + input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE) + if max(input_size) > MAX_FWD_SIZE: + pytest.skip("Fixed input size model > limit.") + + # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode + # If so, we need to return all of them in order to check all grads + # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output + # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names + tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions)) graph = tracer.trace(model) - model = torch.fx.GraphModule(model, graph) + graph_nodes = list(reversed(graph.nodes)) + output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()] + graph_node_names = [n.name for n in graph_nodes] + output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names] + train_nodes, eval_nodes = get_graph_node_names( + model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) + train_return_nodes = [train_nodes[ix] for ix in output_node_indices] + + model = create_feature_extractor( + model, train_return_nodes=train_return_nodes, eval_return_nodes=[eval_nodes[-1]], + tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) inputs = torch.randn((batch_size, *input_size)) - outputs = model(inputs) + outputs = tuple(model(inputs).values()) if isinstance(outputs, tuple): outputs = torch.cat(outputs) outputs.mean().backward() @@ -354,9 +376,14 @@ def test_model_backward_fx(model_name, batch_size): assert not torch.isnan(outputs).any(), 'Output included NaNs' +EXCLUDE_FX_JIT_FILTERS = [ + 'beit_*' # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow +] + @pytest.mark.timeout(120) @pytest.mark.parametrize( - 'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True)) + 'model_name', list_models( + exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True)) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_fx_torchscript(model_name, batch_size): """Symbolically trace each model, script it, and run single forward pass""" @@ -368,12 +395,18 @@ def test_model_forward_fx_torchscript(model_name, batch_size): model = create_model(model_name, pretrained=False) model.eval() - tracer = NodePathTracer() - graph = tracer.trace(model) - model = torch.fx.GraphModule(model, graph) + input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE) + if max(input_size) > MAX_FWD_SIZE: + pytest.skip("Fixed input size model > limit.") + + train_nodes, eval_nodes = get_graph_node_names( + model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) + model = create_feature_extractor( + model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]], + tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) model = torch.jit.script(model) - outputs = model(torch.randn((batch_size, *input_size))) + outputs = model(torch.randn((batch_size, *input_size)))[train_nodes[-1]] assert outputs.shape[0] == batch_size - assert not torch.isnan(outputs).any(), 'Output included NaNs' \ No newline at end of file + assert not torch.isnan(outputs).any(), 'Output included NaNs' diff --git a/timm/models/cait.py b/timm/models/cait.py index b6a18ce3..69b4ba06 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -95,11 +95,11 @@ class ClassAttn(nn.Module): q = q * self.scale v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) - attn = torch.matmul(q, k.transpose(-2, -1)) + attn = (q @ k.transpose(-2, -1)) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x_cls = torch.matmul(attn, v).transpose(1, 2).reshape(B, 1, C) + x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C) x_cls = self.proj(x_cls) x_cls = self.proj_drop(x_cls) @@ -158,7 +158,7 @@ class TalkingHeadAttn(nn.Module): qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] - attn = torch.matmul(q, k.transpose(-2, -1)) + attn = (q @ k.transpose(-2, -1)) attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) @@ -167,7 +167,7 @@ class TalkingHeadAttn(nn.Module): attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) attn = self.attn_drop(attn) - x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x diff --git a/timm/models/coat.py b/timm/models/coat.py index 69b1bd9f..dca655ea 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -19,6 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, overlay_external_default_cfg from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ from .registry import register_model +from .layers.trace_utils import _assert __all__ = [ @@ -105,7 +106,7 @@ class ConvRelPosEnc(nn.Module): def forward(self, q, v, size: Tuple[int, int]): B, h, N, Ch = q.shape H, W = size - torch._assert(N == 1 + H * W, '') + _assert(N == 1 + H * W, '') # Convolutional relative position encoding. q_img = q[:, :, 1:, :] # [B, h, H*W, Ch] @@ -149,8 +150,8 @@ class FactorAtt_ConvRelPosEnc(nn.Module): # Factorized attention. k_softmax = k.softmax(dim=2) - factor_att = torch.matmul(k_softmax.transpose(-1, -2), v) - factor_att = torch.matmul(q, factor_att) + factor_att = k_softmax.transpose(-1, -2) @ v + factor_att = q @ factor_att # Convolutional relative position encoding. crpe = self.crpe(q, v, size=size) # [B, h, N, Ch] @@ -177,7 +178,7 @@ class ConvPosEnc(nn.Module): def forward(self, x, size: Tuple[int, int]): B, N, C = x.shape H, W = size - torch._assert(N == 1 + H * W, '') + _assert(N == 1 + H * W, '') # Extract CLS token and image tokens. cls_token, img_tokens = x[:, :1], x[:, 1:] # [B, 1, C], [B, H*W, C] @@ -275,7 +276,7 @@ class ParallelBlock(nn.Module): """ Feature map interpolation. """ B, N, C = x.shape H, W = size - torch._assert(N == 1 + H * W, '') + _assert(N == 1 + H * W, '') cls_token = x[:, :1, :] img_tokens = x[:, 1:, :] diff --git a/timm/models/convit.py b/timm/models/convit.py index 603548f9..d2f69b68 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -57,7 +57,7 @@ default_cfgs = { } -@register_leaf_module # FX can't symbolically trace control flow in forward method +@register_leaf_module # reason: FX can't symbolically trace control flow in forward method class GPSA(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., locality_strength=1.): @@ -84,7 +84,7 @@ class GPSA(nn.Module): self.rel_indices = self.get_rel_indices(N) attn = self.get_attention(x) v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) - x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x @@ -95,7 +95,7 @@ class GPSA(nn.Module): q, k = qk[0], qk[1] pos_score = self.rel_indices.expand(B, -1, -1, -1) pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2) - patch_score = torch.matmul(q, k.transpose(-2, -1)) * self.scale + patch_score = (q @ k.transpose(-2, -1)) * self.scale patch_score = patch_score.softmax(dim=-1) pos_score = pos_score.softmax(dim=-1) @@ -180,11 +180,11 @@ class MHSA(nn.Module): qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] - attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index 6e0160f9..3ba5b4c7 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -22,6 +22,7 @@ NOTE: model names have been renamed from originals to represent actual input res Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py """ +from typing import Tuple import torch import torch.nn as nn @@ -31,8 +32,9 @@ from functools import partial from typing import List from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_autowrap_function from .helpers import build_model_with_cfg -from .layers import DropPath, to_2tuple, trunc_normal_ +from .layers import DropPath, to_2tuple, trunc_normal_, _assert from .registry import register_model from .vision_transformer import Mlp, Block @@ -116,8 +118,10 @@ class PatchEmbed(nn.Module): def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + _assert(H == self.img_size[0], + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") + _assert(W == self.img_size[1], + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") x = self.proj(x).flatten(2).transpose(1, 2) return x @@ -255,6 +259,27 @@ def _compute_num_patches(img_size, patches): return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)] +@register_autowrap_function +def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False): # annotations for torchscript + """ + Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing. + Args: + x (Tensor): input image + ss (tuple[int, int]): height and width to scale to + crop_scale (bool): whether to crop instead of interpolate to achieve the desired scale. Defaults to False + Returns: + Tensor: the "scaled" image batch tensor + """ + H, W = x.shape[-2:] + if H != ss[0] or W != ss[1]: + if crop_scale and ss[0] <= H and ss[1] <= W: + cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.)) + x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]] + else: + x = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False) + return x + + class CrossViT(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ @@ -342,17 +367,12 @@ class CrossViT(nn.Module): range(self.num_branches)]) def forward_features(self, x): - B, C, H, W = x.shape + B = x.shape[0] xs = [] for i, patch_embed in enumerate(self.patch_embed): x_ = x ss = self.img_size_scaled[i] - if H != ss[0] or W != ss[1]: - if self.crop_scale and ss[0] <= H and ss[1] <= W: - cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.)) - x_ = x_[:, :, cu:cu + ss[0], cl:cl + ss[1]] - else: - x_ = torch.nn.functional.interpolate(x_, size=ss, mode='bicubic', align_corners=False) + x_ = scale_image(x_, ss, self.crop_scale) x_ = patch_embed(x_) cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script cls_tokens = cls_tokens.expand(B, -1, -1) diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index 3d253444..c8f296d4 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -1,42 +1,31 @@ """ PyTorch FX Based Feature Extraction Helpers -An extension/alternative to timm.models.features making use of PyTorch FX. Here, the idea is to: - 1. Symbolically trace a model producing a graph based intermediate representation (PyTorch FX functionality with - some custom tweaks) - 2. Identify desired feature extraction nodes and reconfigure them as output nodes while deleting all unecessary - nodes. (custom - inspired by https://github.com/pytorch/vision/pull/3597) - 3. Write the resulting graph into a GraphModule (PyTorch FX functionality) -Copyright 2021 Alexander Soare +Using https://pytorch.org/vision/stable/feature_extraction.html """ -from typing import Callable, Dict, Union, List, Optional -import math -from collections import OrderedDict -from pprint import pprint -from inspect import ismethod -import re -import warnings -from copy import deepcopy -from itertools import chain - -import torch +from typing import Callable from torch import nn -from torch import fx -import torch.nn.functional as F -from torch.fx.graph_module import _copy_attr from .features import _get_feature_info -from .fx_helpers import fx_float_to_int -# Layers we went to treat as leaf modules for FeatureGraphNet -from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame -from .layers import GatherExcite, DropPath +try: + from torchvision.models.feature_extraction import create_feature_extractor +except ImportError: + pass + +# Layers we went to treat as leaf modules +from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath from .layers.non_local_attn import BilinearAttnTransform from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame - -# These modules will not be traced through. +# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here +# BUT modules from timm.models should use the registration mechanism below _leaf_modules = { - Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, GatherExcite, DropPath, - BilinearAttnTransform, MaxPool2dSame, AvgPool2dSame + BatchNormAct2d, # reason: flow control for jit scripting + BilinearAttnTransform, # reason: flow control t <= 1 + BlurPool2d, # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1] + # Reason: get_same_padding has a max which raises a control flow error + Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame, + CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]) + DropPath, # reason: TypeError: rand recieved Proxy in `size` argument } try: @@ -54,425 +43,16 @@ def register_leaf_module(module: nn.Module): return module -# These functions will not be traced through -_autowrap_functions=(fx_float_to_int,) - - -class TimmTracer(fx.Tracer): - """ - Temporary bridge from torch.fx.Tracer to include any general workarounds required to make FX work for us - """ - def __init__(self, autowrap_modules=(math, ), autowrap_functions=(), enable_cpatching=False): - super().__init__(autowrap_modules=autowrap_modules, enable_cpatching=enable_cpatching) - # FIXME: This is a workaround pending on a PyTorch PR https://github.com/pytorch/pytorch/pull/62106 - self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions])) - - def create_node(self, kind, target, args, kwargs, name=None, type_expr=None): - # FIXME: This is a workaround pending on a PyTorch PR https://github.com/pytorch/pytorch/pull/62095 - if target == F.pad: - kwargs['value'] = float(kwargs['value']) - return super().create_node(kind, target, args, kwargs, name=name, type_expr=type_expr) - - -class LeafNodeTracer(TimmTracer): - """ - Account for desired leaf nodes according to _leaf_modules and _autowrap functions - """ - def __init__(self): - super().__init__(autowrap_functions=_autowrap_functions) - - def is_leaf_module(self, m: nn.Module, module_qualname: str) -> bool: - if isinstance(m, tuple(_leaf_modules)): - return True - return super().is_leaf_module(m, module_qualname) - - -def _is_subseq(x, y): - """Check if y is a subseqence of x - https://stackoverflow.com/a/24017747/4391249 - """ - iter_x = iter(x) - return all(any(x_item == y_item for x_item in iter_x) for y_item in y) - - -# Taken from https://github.com/pytorch/examples/blob/master/fx/module_tracer.py with modifications for storing -# qualified names for all Nodes, not just top-level Modules -class NodePathTracer(LeafNodeTracer): - """ - NodePathTracer is an FX tracer that, for each operation, also records the - qualified name of the Node from which the operation originated. A - qualified name here is a `.` seperated path walking the hierarchy from top - level module down to leaf operation or leaf module. The name of the top - level module is not included as part of the qualified name. For example, - if we trace a module who's forward method applies a ReLU module, the - qualified name for that node will simply be 'relu'. - - Some notes on the specifics: - - Nodes are recorded to `self.node_to_qualname` which is a dictionary - mapping a given Node object to its qualified name. - - Nodes are recorded in the order which they are executed during - tracing. - - When a duplicate qualified name is encountered, a suffix of the form - _{int} is added. The counter starts from 1. - """ - def __init__(self, *args, **kwargs): - super(NodePathTracer, self).__init__(*args, **kwargs) - # Track the qualified name of the Node being traced - self.current_module_qualname = '' - # A map from FX Node to the qualified name - self.node_to_qualname = OrderedDict() - - def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs): - """ - Override of `fx.Tracer.call_module` - This override: - 1) Stores away the qualified name of the caller for restoration later - 2) Adds the qualified name of the caller to - `current_module_qualname` for retrieval by `create_proxy` - 3) Once a leaf module is reached, calls `create_proxy` - 4) Restores the caller's qualified name into current_module_qualname - """ - old_qualname = self.current_module_qualname - try: - module_qualname = self.path_of_module(m) - self.current_module_qualname = module_qualname - if not self.is_leaf_module(m, module_qualname): - out = forward(*args, **kwargs) - return out - return self.create_proxy('call_module', module_qualname, args, kwargs) - finally: - self.current_module_qualname = old_qualname - - def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs, - name=None, type_expr=None) -> fx.proxy.Proxy: - """ - Override of `Tracer.create_proxy`. This override intercepts the recording - of every operation and stores away the current traced module's qualified - name in `node_to_qualname` - """ - proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr) - self.node_to_qualname[proxy.node] = self._get_node_qualname( - self.current_module_qualname, proxy.node) - return proxy - - def _get_node_qualname( - self, module_qualname: str, node: fx.node.Node) -> str: - node_qualname = module_qualname - if node.op == 'call_module': - # Node terminates in a leaf module so the module_qualname is a - # complete description of the node - for existing_qualname in reversed(self.node_to_qualname.values()): - # Check to see if existing_qualname is of the form - # {node_qualname} or {node_qualname}_{int} - if re.match(rf'{node_qualname}(_[0-9]+)?$', - existing_qualname) is not None: - postfix = existing_qualname.replace(node_qualname, '') - if len(postfix): - # Existing_qualname is of the form {node_qualname}_{int} - next_index = int(postfix[1:]) + 1 - else: - # existing_qualname is of the form {node_qualname} - next_index = 1 - node_qualname += f'_{next_index}' - break - else: - # Node terminates in non- leaf module so the node name needs to be - # appended - if len(node_qualname) > 0: - # Only append '.' if we are deeper than the top level module - node_qualname += '.' - node_qualname += str(node) - return node_qualname +# Functions we want to autowrap (treat them as leaves) +_autowrap_functions = set() -def _warn_graph_differences( - train_tracer: NodePathTracer, eval_tracer: NodePathTracer): +def register_autowrap_function(func: Callable): """ - Utility function for warning the user if there are differences between - the train graph and the eval graph. + Decorator for functions which ought not to be traced through """ - train_nodes = list(train_tracer.node_to_qualname.values()) - eval_nodes = list(eval_tracer.node_to_qualname.values()) - - if len(train_nodes) == len(eval_nodes) and [ - t == e for t, e in zip(train_nodes, eval_nodes)]: - return - - suggestion_msg = ( - "When choosing nodes for feature extraction, you may need to specify " - "output nodes for train and eval mode separately") - - if _is_subseq(train_nodes, eval_nodes): - msg = ("NOTE: The nodes obtained by tracing the model in eval mode " - "are a subsequence of those obtained in train mode. ") - elif _is_subseq(eval_nodes, train_nodes): - msg = ("NOTE: The nodes obtained by tracing the model in train mode " - "are a subsequence of those obtained in eval mode. ") - else: - msg = ("The nodes obtained by tracing the model in train mode " - "are different to those obtained in eval mode. ") - warnings.warn(msg + suggestion_msg) - - -def print_graph_node_qualified_names( - model: nn.Module, tracer_kwargs: Dict = {}): - """ - Dev utility to prints nodes in order of execution. Useful for choosing - nodes for a FeatureGraphNet design. There are two reasons that qualified - node names can't easily be read directly from the code for a model: - 1. Not all submodules are traced through. Modules from `torch.nn` all - fall within this category. - 2. Node qualified names that occur more than once in the graph get a - `_{counter}` postfix. - The model will be traced twice: once in train mode, and once in eval mode. - If there are discrepancies between the graphs produced, both sets will - be printed and the user will be warned. - - Args: - model (nn.Module): model on which we will extract the features - tracer_kwargs (Dict): a dictionary of keywork arguments for - `NodePathTracer` (which passes them onto it's parent class - `torch.fx.Tracer`). - """ - train_tracer = NodePathTracer(**tracer_kwargs) - train_tracer.trace(model.train()) - eval_tracer = NodePathTracer(**tracer_kwargs) - eval_tracer.trace(model.eval()) - train_nodes = list(train_tracer.node_to_qualname.values()) - eval_nodes = list(eval_tracer.node_to_qualname.values()) - if len(train_nodes) == len(eval_nodes) and [ - t == e for t, e in zip(train_nodes, eval_nodes)]: - # Nodes are aligned in train vs eval mode - pprint(list(train_tracer.node_to_qualname.values())) - return - print("Nodes from train mode:") - pprint(list(train_tracer.node_to_qualname.values())) - print() - print("Nodes from eval mode:") - pprint(list(eval_tracer.node_to_qualname.values())) - print() - _warn_graph_differences(train_tracer, eval_tracer) - - -class DualGraphModule(fx.GraphModule): - """ - A derivative of `fx.GraphModule`. Differs in the following ways: - - Requires a train and eval version of the underlying graph - - Copies submodules according to the nodes of both train and eval graphs. - - Calling train(mode) switches between train graph and eval graph. - """ - def __init__(self, - root: torch.nn.Module, - train_graph: fx.Graph, - eval_graph: fx.Graph, - class_name: str = 'GraphModule'): - """ - Args: - root (torch.nn.Module): module from which the copied module - hierarchy is built - train_graph (Graph): the graph that should be used in train mode - eval_graph (Graph): the graph that should be used in eval mode - """ - super(fx.GraphModule, self).__init__() - - self.__class__.__name__ = class_name - - self.train_graph = train_graph - self.eval_graph = eval_graph - - # Copy all get_attr and call_module ops (indicated by BOTH train and - # eval graphs) - for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)): - if node.op in ['get_attr', 'call_module']: - assert isinstance(node.target, str) - _copy_attr(root, self, node.target) - - # eval mode by default - self.eval() - self.graph = eval_graph - - # (borrowed from fx.GraphModule): - # Store the Tracer class responsible for creating a Graph separately as part of the - # GraphModule state, except when the Tracer is defined in a local namespace. - # Locally defined Tracers are not pickleable. This is needed because torch.package will - # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer - # to re-create the Graph during deserialization. - # TODO uncomment this when https://github.com/pytorch/pytorch/pull/63121 is available - # assert self.eval_graph._tracer_cls == self.train_graph._tracer_cls, \ - # "Train mode and eval mode should use the same tracer class" - # self._tracer_cls = None - # if self.graph._tracer_cls and '' not in self.graph._tracer_cls.__qualname__: - # self._tracer_cls = self.graph._tracer_cls - - def train(self, mode=True): - """ - Swap out the graph depending on the training mode. - NOTE this should be safe when calling model.eval() because that just - calls this with mode == False. - """ - if mode: - self.graph = self.train_graph - else: - self.graph = self.eval_graph - return super().train(mode=mode) - - -def build_feature_graph_net( - model: nn.Module, - return_nodes: Union[List[str], Dict[str, str]], - train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, - eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, - tracer_kwargs: Dict = {}) -> fx.GraphModule: - """ - Creates a new graph module that returns intermediate nodes from a given - model as dictionary with user specified keys as strings, and the requested - outputs as values. This is achieved by re-writing the computation graph of - the model via FX to return the desired nodes as outputs. All unused nodes - are removed, together with their corresponding parameters. - - A note on node specification: A node qualified name is specified as a `.` - seperated path walking the hierarchy from top level module down to leaf - operation or leaf module. For instance `blocks.5.3.bn1`. The keys of the - `return_nodes` argument should point to either a node's qualified name, - or some truncated version of it. For example, one could provide `blocks.5` - as a key, and the last node with that prefix will be selected. - `print_graph_node_qualified_names` is a useful helper function for getting - a list of qualified names of a model. - - An attempt is made to keep all non-parametric properties of the original - model, but existing properties of the constructed `GraphModule` are not - overwritten. - - Args: - model (nn.Module): model on which we will extract the features - return_nodes (Union[List[name], Dict[name, new_name]])): either a list - or a dict containing the names (or partial names - see note above) - of the nodes for which the activations will be returned. If it is - a `Dict`, the keys are the qualified node names, and the values - are the user-specified keys for the graph module's returned - dictionary. If it is a `List`, it is treated as a `Dict` mapping - node specification strings directly to output names. - tracer_kwargs (Dict): a dictionary of keywork arguments for - `NodePathTracer` (which passes them onto it's parent class - `torch.fx.Tracer`). - - Examples:: - - >>> model = torchvision.models.resnet18() - >>> # extract layer1 and layer3, giving as names `feat1` and feat2` - >>> graph_module = torchvision.models._utils.build_feature_graph_net(m, - >>> {'layer1': 'feat1', 'layer3': 'feat2'}) - >>> out = graph_module(torch.rand(1, 3, 224, 224)) - >>> print([(k, v.shape) for k, v in out.items()]) - >>> [('feat1', torch.Size([1, 64, 56, 56])), - >>> ('feat2', torch.Size([1, 256, 14, 14]))] - - """ - is_training = model.training - - if isinstance(return_nodes, list): - return_nodes = {n: n for n in return_nodes} - return_nodes = {str(k): str(v) for k, v in return_nodes.items()} - - assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), \ - ("If any of `train_return_nodes` and `eval_return_nodes` are " - "specified, then both should be specified") - - if train_return_nodes is None: - train_return_nodes = deepcopy(return_nodes) - eval_return_nodes = deepcopy(return_nodes) - - # Repeat the tracing and graph rewriting for train and eval mode - tracers = {} - graphs = {} - return_nodes = { - 'train': train_return_nodes, - 'eval': eval_return_nodes - } - for mode in ['train', 'eval']: - if mode == 'train': - model.train() - elif mode == 'eval': - model.eval() - - # Instantiate our NodePathTracer and use that to trace the model - tracer = NodePathTracer(**tracer_kwargs) - graph = tracer.trace(model) - - name = model.__class__.__name__ if isinstance( - model, nn.Module) else model.__name__ - graph_module = fx.GraphModule(tracer.root, graph, name) - - available_nodes = [f'{v}.{k}' for k, v in tracer.node_to_qualname.items()] - # FIXME We don't know if we should expect this to happen - assert len(set(available_nodes)) == len(available_nodes), \ - "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues" - # Check that all outputs in return_nodes are present in the model - for query in return_nodes[mode].keys(): - if not any([m.startswith(query) for m in available_nodes]): - raise ValueError(f"return_node: {query} is not present in model") - - # Remove existing output nodes (train mode) - orig_output_nodes = [] - for n in reversed(graph_module.graph.nodes): - if n.op == "output": - orig_output_nodes.append(n) - assert len(orig_output_nodes) - for n in orig_output_nodes: - graph_module.graph.erase_node(n) - - # Find nodes corresponding to return_nodes and make them into output_nodes - nodes = [n for n in graph_module.graph.nodes] - output_nodes = OrderedDict() - for n in reversed(nodes): - if 'tensor_constant' in str(n): - # NOTE Without this control flow we would get a None value for - # `module_qualname = tracer.node_to_qualname.get(n)`. - # On the other hand, we can safely assume that we'll never need to - # get this as an interesting intermediate node. - continue - module_qualname = tracer.node_to_qualname.get(n) - for query in return_nodes[mode]: - depth = query.count('.') - if '.'.join(module_qualname.split('.')[:depth + 1]) == query: - output_nodes[return_nodes[mode][query]] = n - return_nodes[mode].pop(query) - break - output_nodes = OrderedDict(reversed(list(output_nodes.items()))) - - # And add them in the end of the graph - with graph_module.graph.inserting_after(nodes[-1]): - graph_module.graph.output(output_nodes) - - # Remove unused modules / parameters - graph_module.graph.eliminate_dead_code() - graph_module.recompile() - - # Keep track of the tracer and graph so we can choose the main one - tracers[mode] = tracer - graphs[mode] = graph - - # Warn user if there are any discrepancies between the graphs of the - # train and eval modes - _warn_graph_differences(tracers['train'], tracers['eval']) - - # Build the final graph module - graph_module = DualGraphModule( - model, graphs['train'], graphs['eval'], class_name=name) - - # Keep non-parameter model properties for reference - for attr_str in model.__dir__(): - attr = getattr(model, attr_str) - if (not attr_str.startswith('_') - and attr_str not in graph_module.__dir__() - and not ismethod(attr) - and not isinstance(attr, (nn.Module, nn.Parameter))): - setattr(graph_module, attr_str, attr) - - # Restore original training mode - graph_module.train(is_training) - - return graph_module + _autowrap_functions.add(func) + return func class FeatureGraphNet(nn.Module): @@ -483,7 +63,10 @@ class FeatureGraphNet(nn.Module): assert len(out_map) == len(out_indices) return_nodes = {info['module']: out_map[i] if out_map is not None else info['module'] for i, info in enumerate(self.feature_info) if i in out_indices} - self.graph_module = build_feature_graph_net(model, return_nodes) + self.graph_module = create_feature_extractor( + model, return_nodes, + tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) def forward(self, x): return list(self.graph_module(x).values()) + \ No newline at end of file diff --git a/timm/models/fx_helpers.py b/timm/models/fx_helpers.py deleted file mode 100644 index 878ba381..00000000 --- a/timm/models/fx_helpers.py +++ /dev/null @@ -1,7 +0,0 @@ - -def fx_float_to_int(x: float) -> int: - """ - Symbolic tracing helper to substitute for inbuilt `int`. - Hint: Inbuilt `int` can't accept an argument of type `Proxy` - """ - return int(x) \ No newline at end of file diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py index 305f9de3..c3db464e 100644 --- a/timm/models/layers/bottleneck_attn.py +++ b/timm/models/layers/bottleneck_attn.py @@ -22,7 +22,7 @@ import torch.nn.functional as F from .helpers import to_2tuple, make_divisible from .weight_init import trunc_normal_ -from timm.models.fx_helpers import fx_and +from .trace_utils import _assert def rel_logits_1d(q, rel_k, permute_mask: List[int]): @@ -37,7 +37,7 @@ def rel_logits_1d(q, rel_k, permute_mask: List[int]): permute_mask: permute output dim according to this """ B, H, W, dim = q.shape - x = torch.matmul(q, rel_k.transpose(-1, -2)) + x = (q @ rel_k.transpose(-1, -2)) x = x.reshape(-1, W, 2 * W -1) # pad to shift from relative to absolute indexing @@ -134,8 +134,8 @@ class BottleneckAttn(nn.Module): def forward(self, x): B, C, H, W = x.shape - torch._assert(H == self.pos_embed.height, '') - torch._assert(W == self.pos_embed.width, '') + _assert(H == self.pos_embed.height, '') + _assert(W == self.pos_embed.width, '') x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py index 02aa0a0c..ecc5fb61 100644 --- a/timm/models/layers/evo_norm.py +++ b/timm/models/layers/evo_norm.py @@ -12,6 +12,8 @@ Hacked together by / Copyright 2020 Ross Wightman import torch import torch.nn as nn +from .trace_utils import _assert + class EvoNormBatch2d(nn.Module): def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None): @@ -72,9 +74,9 @@ class EvoNormSample2d(nn.Module): nn.init.ones_(self.v) def forward(self, x): - torch._assert(x.dim() == 4, 'expected 4D input') + _assert(x.dim() == 4, 'expected 4D input') B, C, H, W = x.shape - torch._assert(C % self.groups == 0, '') + _assert(C % self.groups == 0, '') if self.apply_act: n = x * (x * self.v).sigmoid() x = x.reshape(B, self.groups, -1) diff --git a/timm/models/layers/global_context.py b/timm/models/layers/global_context.py index a0bb8a43..de7fb5c1 100644 --- a/timm/models/layers/global_context.py +++ b/timm/models/layers/global_context.py @@ -7,7 +7,6 @@ Official code consulted as reference: https://github.com/xvjiarui/GCNet Hacked together by / Copyright 2021 Ross Wightman """ -import torch from torch import nn as nn import torch.nn.functional as F @@ -53,7 +52,7 @@ class GlobalContext(nn.Module): if self.conv_attn is not None: attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) - context = torch.matmul(x.reshape(B, C, H * W).unsqueeze(1), attn) + context = x.reshape(B, C, H * W).unsqueeze(1) @ attn context = context.view(B, C, 1, 1) else: context = x.mean(dim=(2, 3), keepdim=True) diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index ec93474f..f2ac64f8 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -16,7 +16,7 @@ The attention mechanism works but it's slow as implemented. Hacked together by / Copyright 2021 Ross Wightman """ -from typing import Tuple, List +from typing import List import torch from torch import nn @@ -24,6 +24,7 @@ import torch.nn.functional as F from .helpers import make_divisible from .weight_init import trunc_normal_ +from .trace_utils import _assert def rel_logits_1d(q, rel_k, permute_mask: List[int]): @@ -41,7 +42,7 @@ def rel_logits_1d(q, rel_k, permute_mask: List[int]): rel_size = rel_k.shape[0] win_size = (rel_size + 1) // 2 - x = torch.matmul(q, rel_k.transpose(-1, -2)) + x = (q @ rel_k.transpose(-1, -2)) x = x.reshape(-1, W, rel_size) # pad to shift from relative to absolute indexing @@ -167,8 +168,8 @@ class HaloAttn(nn.Module): def forward(self, x): B, C, H, W = x.shape - torch._assert(H % self.block_size == 0, '') - torch._assert(W % self.block_size == 0, '') + _assert(H % self.block_size == 0, '') + _assert(W % self.block_size == 0, '') num_h_blocks = H // self.block_size num_w_blocks = W // self.block_size num_blocks = num_h_blocks * num_w_blocks diff --git a/timm/models/layers/lambda_layer.py b/timm/models/layers/lambda_layer.py index 058426b6..e50b43c8 100644 --- a/timm/models/layers/lambda_layer.py +++ b/timm/models/layers/lambda_layer.py @@ -116,8 +116,8 @@ class LambdaLayer(nn.Module): v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M - content_lam = torch.matmul(k, v) # B, K, V - content_out = torch.matmul(q, content_lam.unsqueeze(1)) # B, num_heads, M, V + content_lam = k @ v # B, K, V + content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V if self.pos_emb is None: position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K diff --git a/timm/models/layers/non_local_attn.py b/timm/models/layers/non_local_attn.py index 5f83005c..881fa36d 100644 --- a/timm/models/layers/non_local_attn.py +++ b/timm/models/layers/non_local_attn.py @@ -10,7 +10,7 @@ from torch.nn import functional as F from .conv_bn_act import ConvBnAct from .helpers import make_divisible -from timm.models.fx_helpers import fx_and +from .trace_utils import _assert class NonLocalAttn(nn.Module): @@ -84,7 +84,7 @@ class BilinearAttnTransform(nn.Module): def resize_mat(self, x, t: int): B, C, block_size, block_size1 = x.shape - torch._assert(block_size == block_size1, '') + _assert(block_size == block_size1, '') if t <= 1: return x x = x.view(B * C, -1, 1, 1) @@ -96,8 +96,8 @@ class BilinearAttnTransform(nn.Module): return x def forward(self, x): - torch._assert(x.shape[-1] % self.block_size == 0, '') - torch._assert(x.shape[-2] % self.block_size == 0, '') + _assert(x.shape[-1] % self.block_size == 0, '') + _assert(x.shape[-2] % self.block_size == 0, '') B, C, H, W = x.shape out = self.conv1(x) rp = F.adaptive_max_pool2d(out, (self.block_size, 1)) diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py index 69aca86b..1aeb9294 100644 --- a/timm/models/layers/selective_kernel.py +++ b/timm/models/layers/selective_kernel.py @@ -9,6 +9,7 @@ from torch import nn as nn from .conv_bn_act import ConvBnAct from .helpers import make_divisible +from .trace_utils import _assert def _kernel_valid(k): @@ -34,7 +35,7 @@ class SelectiveKernelAttn(nn.Module): self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) def forward(self, x): - torch._assert(x.shape[1] == self.num_paths, '') + _assert(x.shape[1] == self.num_paths, '') x = x.sum(1).mean((2, 3), keepdim=True) x = self.fc_reduce(x) x = self.bn(x) diff --git a/timm/models/layers/swin_attn.py b/timm/models/layers/swin_attn.py deleted file mode 100644 index 2a3731f3..00000000 --- a/timm/models/layers/swin_attn.py +++ /dev/null @@ -1,183 +0,0 @@ -""" Shifted Window Attn - -This is a WIP experiment to apply windowed attention from the Swin Transformer -to a stand-alone module for use as an attn block in conv nets. - -Based on original swin window code at https://github.com/microsoft/Swin-Transformer -Swin Transformer paper: https://arxiv.org/pdf/2103.14030.pdf -""" -from typing import Optional - -import torch -import torch.nn as nn - -from .drop import DropPath -from .helpers import to_2tuple -from .weight_init import trunc_normal_ -from timm.models.fx_helpers import fx_float_to_int - - -def window_partition(x, win_size: int): - """ - Args: - x: (B, H, W, C) - win_size (int): window size - - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // win_size, win_size, W // win_size, win_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C) - return windows - - -def window_reverse(windows, win_size: int, H: int, W: int): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - win_size (int): Window size - H (int): Height of image - W (int): Width of image - - Returns: - x: (B, H, W, C) - """ - B = fx_float_to_int(windows.shape[0] / (H * W / win_size / win_size)) - x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. - win_size (int): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - """ - - def __init__( - self, dim, dim_out=None, feat_size=None, stride=1, win_size=8, shift_size=None, num_heads=8, - qkv_bias=True, attn_drop=0.): - - super().__init__() - self.dim_out = dim_out or dim - self.feat_size = to_2tuple(feat_size) - self.win_size = win_size - self.shift_size = shift_size or win_size // 2 - if min(self.feat_size) <= win_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.win_size = min(self.feat_size) - assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-window_size" - self.num_heads = num_heads - head_dim = self.dim_out // num_heads - self.scale = head_dim ** -0.5 - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.feat_size - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = ( - slice(0, -self.win_size), - slice(-self.win_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = ( - slice(0, -self.win_size), - slice(-self.win_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - mask_windows = window_partition(img_mask, self.win_size) # num_win, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.win_size * self.win_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - else: - attn_mask = None - self.register_buffer("attn_mask", attn_mask) - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - # 2 * Wh - 1 * 2 * Ww - 1, nH - torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads)) - trunc_normal_(self.relative_position_bias_table, std=.02) - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.win_size) - coords_w = torch.arange(self.win_size) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.win_size - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.win_size - 1 - relative_coords[:, :, 0] *= 2 * self.win_size - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.softmax = nn.Softmax(dim=-1) - self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() - - def reset_parameters(self): - trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) - trunc_normal_(self.relative_position_bias_table, std=.02) - - def forward(self, x): - B, C, H, W = x.shape - x = x.permute(0, 2, 3, 1) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - win_size_sq = self.win_size * self.win_size - x_windows = window_partition(shifted_x, self.win_size) # num_win * B, window_size, window_size, C - x_windows = x_windows.view(-1, win_size_sq, C) # num_win * B, window_size*window_size, C - BW, N, _ = x_windows.shape - - qkv = self.qkv(x_windows) - qkv = qkv.reshape(BW, N, 3, self.num_heads, self.dim_out // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] - q = q * self.scale - attn = torch.matmul(q, k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index.view(-1)].view(win_size_sq, win_size_sq, -1) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh * Ww, Wh * Ww - attn = attn + relative_position_bias.unsqueeze(0) - if self.attn_mask is not None: - num_win = self.attn_mask.shape[0] - attn = attn.view(B, num_win, self.num_heads, N, N) + self.attn_mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - attn = self.attn_drop(attn) - - x = torch.matmul(attn, v).transpose(1, 2).reshape(BW, N, self.dim_out) - - # merge windows - x = x.view(-1, self.win_size, self.win_size, self.dim_out) - shifted_x = window_reverse(x, self.win_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - x = x.view(B, H, W, self.dim_out).permute(0, 3, 1, 2) - x = self.pool(x) - return x - - diff --git a/timm/models/levit.py b/timm/models/levit.py index c4377bb1..9987e4ba 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -293,10 +293,10 @@ class Attention(nn.Module): k = k.permute(0, 2, 1, 3) v = v.permute(0, 2, 1, 3) - attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + self.get_attention_biases(x.device) + attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) - x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, self.dh) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) x = self.proj(x) return x @@ -387,10 +387,10 @@ class AttentionSubsample(nn.Module): v = v.permute(0, 2, 1, 3) # BHNC q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) - attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + self.get_attention_biases(x.device) + attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) - x = torch.matmul(attn, v).transpose(1, 2).reshape(B, -1, self.dh) + x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) x = self.proj(x) return x diff --git a/timm/models/nest.py b/timm/models/nest.py index 73f14da5..c5951aea 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -25,13 +25,13 @@ import torch.nn.functional as F from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_autowrap_function from .helpers import build_model_with_cfg, named_apply -from .fx_helpers import fx_float_to_int from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_ +from .layers.trace_utils import _assert from .layers import create_conv2d, create_pool2d, to_ntuple from .registry import register_model - _logger = logging.getLogger(__name__) @@ -85,12 +85,12 @@ class Attention(nn.Module): qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) - attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale # (B, H, T, N, N) + attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, N, N) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) # (B, H, T, N, C'), permute -> (B, T, N, C', H) - x = torch.matmul(attn, v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C) + x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C) x = self.proj(x) x = self.proj_drop(x) return x # (B, T, N, C) @@ -130,8 +130,8 @@ class ConvPool(nn.Module): """ x is expected to have shape (B, C, H, W) """ - torch._assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims') - torch._assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims') + _assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims') + _assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims') x = self.conv(x) # Layer norm done over channel dim only x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) @@ -146,8 +146,8 @@ def blockify(x, block_size: int): block_size (int): edge length of a single square block in units of H, W """ B, H, W, C = x.shape - torch._assert(H % block_size == 0, '`block_size` must divide input height evenly') - torch._assert(W % block_size == 0, '`block_size` must divide input width evenly') + _assert(H % block_size == 0, '`block_size` must divide input height evenly') + _assert(W % block_size == 0, '`block_size` must divide input width evenly') grid_height = H // block_size grid_width = W // block_size x = x.reshape(B, grid_height, block_size, grid_width, block_size, C) @@ -155,6 +155,7 @@ def blockify(x, block_size: int): return x # (B, T, N, C) +@register_autowrap_function # reason: int receives Proxy def deblockify(x, block_size: int): """blocks to image Args: @@ -162,7 +163,7 @@ def deblockify(x, block_size: int): block_size (int): edge length of a single square block in units of desired H, W """ B, T, _, C = x.shape - grid_size = fx_float_to_int(math.sqrt(T)) + grid_size = int(math.sqrt(T)) height = width = grid_size * block_size x = x.reshape(B, grid_size, grid_size, block_size, block_size, C) x = x.transpose(2, 3).reshape(B, height, width, C) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index ec86dbb8..4e0f2b21 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -27,7 +27,6 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from timm.models.fx_features import register_leaf_module from .registry import register_model from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\ get_act_layer, get_act_fn, get_attn, make_divisible @@ -319,7 +318,6 @@ class DownsampleAvg(nn.Module): return self.conv(self.pool(x)) -@register_leaf_module # FX feature extraction was giving different valued features. Perhaps to do with control flow? class NormFreeBlock(nn.Module): """Normalization-Free pre-activation block. """ diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index 53c7bcd5..d5dd5513 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -21,9 +21,10 @@ import torch.nn as nn import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_autowrap_function from .helpers import build_model_with_cfg, overlay_external_default_cfg -from .fx_helpers import fx_float_to_int from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ +from .layers.trace_utils import _assert from .registry import register_model from .vision_transformer import checkpoint_filter_fn, _init_vit_weights @@ -102,6 +103,7 @@ def window_partition(x, window_size: int): return windows +@register_autowrap_function # reason: int argument is a Proxy def window_reverse(windows, window_size: int, H: int, W: int): """ Args: @@ -113,7 +115,7 @@ def window_reverse(windows, window_size: int, H: int, W: int): Returns: x: (B, H, W, C) """ - B = fx_float_to_int(windows.shape[0] / (H * W / window_size / window_size)) + B = int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x @@ -177,7 +179,7 @@ class WindowAttention(nn.Module): q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) q = q * self.scale - attn = torch.matmul(q, k.transpose(-2, -1)) + attn = (q @ k.transpose(-2, -1)) relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH @@ -194,7 +196,7 @@ class WindowAttention(nn.Module): attn = self.attn_drop(attn) - x = torch.matmul(attn, v).transpose(1, 2).reshape(B_, N, C) + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x @@ -272,7 +274,7 @@ class SwinTransformerBlock(nn.Module): def forward(self, x): H, W = self.input_resolution B, L, C = x.shape - torch._assert(L == H * W, "input feature has wrong size") + _assert(L == H * W, "input feature has wrong size") shortcut = x x = self.norm1(x) @@ -331,8 +333,8 @@ class PatchMerging(nn.Module): """ H, W = self.input_resolution B, L, C = x.shape - torch._assert(L == H * W, "input feature has wrong size") - torch._assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.") + _assert(L == H * W, "input feature has wrong size") + _assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.") x = x.view(B, H, W, C) diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 298808c3..1ad481f6 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -12,9 +12,9 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.helpers import build_model_with_cfg -from timm.models.fx_helpers import fx_and from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers.helpers import to_2tuple +from timm.models.layers.trace_utils import _assert from timm.models.registry import register_model from timm.models.vision_transformer import resize_pos_embed @@ -64,11 +64,11 @@ class Attention(nn.Module): q, k = qk.unbind(0) # make torchscript happy (cannot use tensor as tuple) v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) - attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, -1) + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) x = self.proj(x) x = self.proj_drop(x) return x @@ -138,9 +138,9 @@ class PixelEmbed(nn.Module): def forward(self, x, pixel_pos): B, C, H, W = x.shape - torch._assert(H == self.img_size[0], + _assert(H == self.img_size[0], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") - torch._assert(W == self.img_size[1], + _assert(W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") x = self.proj(x) x = self.unfold(x) diff --git a/timm/models/twins.py b/timm/models/twins.py index 7b5afafb..9ae70d32 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -25,7 +25,7 @@ from .layers import Mlp, DropPath, to_2tuple, trunc_normal_ from .fx_features import register_leaf_module from .registry import register_model from .vision_transformer import Attention -from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .helpers import build_model_with_cfg def _cfg(url='', **kwargs): @@ -63,7 +63,7 @@ default_cfgs = { Size_ = Tuple[int, int] -@register_leaf_module # FX can't symbolically trace control flow in forward method +@register_leaf_module # reason: FX can't symbolically trace control flow in forward method class LocallyGroupedAttn(nn.Module): """ LSA: self attention within a group """ @@ -100,10 +100,10 @@ class LocallyGroupedAttn(nn.Module): qkv = self.qkv(x).reshape( B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) q, k, v = qkv[0], qkv[1], qkv[2] - attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - attn = torch.matmul(attn, v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) if pad_r > 0 or pad_b > 0: x = x[:, :H, :W, :].contiguous() @@ -185,11 +185,11 @@ class GlobalSubSampleAttn(nn.Module): kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) k, v = kv[0], kv[1] - attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) diff --git a/timm/models/vgg.py b/timm/models/vgg.py index aee41b25..0f62ac4e 100644 --- a/timm/models/vgg.py +++ b/timm/models/vgg.py @@ -13,7 +13,7 @@ from typing import Union, List, Dict, Any, cast from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg from .fx_features import register_leaf_module -from .layers import ClassifierHead, ConvBnAct +from .layers import ClassifierHead from .registry import register_model __all__ = [ @@ -53,7 +53,7 @@ cfgs: Dict[str, List[Union[str, int]]] = { } -@register_leaf_module # FX can't symbolically trace control flow in forward method +@register_leaf_module # reason: FX can't symbolically trace control flow in forward method class ConvMlp(nn.Module): def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0, diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 6ed43102..6e832cd0 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -100,10 +100,10 @@ class Attention(nn.Module): x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3) q, k, v = x[0], x[1], x[2] - attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = torch.matmul(attn, v) + x = attn @ v x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W) x = self.proj(x) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index fb939624..94ae2666 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -192,11 +192,11 @@ class Attention(nn.Module): qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) - attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x diff --git a/timm/models/xcit.py b/timm/models/xcit.py index 9be29046..f5dd0683 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -98,7 +98,7 @@ default_cfgs = { } -@register_leaf_module # FX can't symbolically trace torch.arange in forward method +@register_leaf_module # reason: FX can't symbolically trace torch.arange in forward method class PositionalEncodingFourier(nn.Module): """ Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper. @@ -274,12 +274,12 @@ class XCA(nn.Module): # Paper section 3.2 l2-Normalization and temperature scaling q = torch.nn.functional.normalize(q, dim=-1) k = torch.nn.functional.normalize(k, dim=-1) - attn = torch.matmul(q, k.transpose(-2, -1)) * self.temperature + attn = (q @ k.transpose(-2, -1)) * self.temperature attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) # (B, H, C', N), permute -> (B, N, H, C') - x = torch.matmul(attn, v).permute(0, 3, 1, 2).reshape(B, N, C) + x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x From d2994016e952e9b11b9d8a32020f7565a3c163b5 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Fri, 12 Nov 2021 21:16:53 +0000 Subject: [PATCH 09/26] Add try/except guards --- tests/test_models.py | 15 ++++++++++++++- timm/models/fx_features.py | 6 ++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index f7233ef3..e513dcaf 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,7 +4,11 @@ import platform import os import fnmatch -from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer +try: + from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer + has_fx_feature_extraction = True +except ImportError: + has_fx_feature_extraction = False import timm from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \ @@ -307,6 +311,9 @@ def test_model_forward_features(model_name, batch_size): @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_fx(model_name, batch_size): """Symbolically trace each model and run single forward pass through the resulting GraphModule""" + if not has_fx_feature_extraction: + pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") + model = create_model(model_name, pretrained=False) model.eval() @@ -332,6 +339,9 @@ def test_model_forward_fx(model_name, batch_size): @pytest.mark.parametrize('batch_size', [2]) def test_model_backward_fx(model_name, batch_size): """Symbolically trace each model and run single backward pass through the resulting GraphModule""" + if not has_fx_feature_extraction: + pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") + input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE) if max(input_size) > MAX_BWD_SIZE: pytest.skip("Fixed input size model > limit.") @@ -387,6 +397,9 @@ EXCLUDE_FX_JIT_FILTERS = [ @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_fx_torchscript(model_name, batch_size): """Symbolically trace each model, script it, and run single forward pass""" + if not has_fx_feature_extraction: + pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") + input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) if max(input_size) > MAX_JIT_SIZE: pytest.skip("Fixed input size model > limit.") diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index c8f296d4..a582cf9b 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -8,8 +8,9 @@ from .features import _get_feature_info try: from torchvision.models.feature_extraction import create_feature_extractor + has_fx_feature_extraction = True except ImportError: - pass + has_fx_feature_extraction = False # Layers we went to treat as leaf modules from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath @@ -58,6 +59,7 @@ def register_autowrap_function(func: Callable): class FeatureGraphNet(nn.Module): def __init__(self, model, out_indices, out_map=None): super().__init__() + assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' self.feature_info = _get_feature_info(model, out_indices) if out_map is not None: assert len(out_map) == len(out_indices) @@ -66,7 +68,7 @@ class FeatureGraphNet(nn.Module): self.graph_module = create_feature_extractor( model, return_nodes, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) - + def forward(self, x): return list(self.graph_module(x).values()) \ No newline at end of file From 0262a0e8e16c31a4e8157a44dafbe5f6b8c21495 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sat, 13 Nov 2021 00:06:33 +0000 Subject: [PATCH 10/26] fx ready for review --- tests/test_models.py | 37 ++++++++++++++++++++++++++++++------- timm/models/nfnet.py | 2 ++ 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index e513dcaf..93152d9a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -310,7 +310,10 @@ def test_model_forward_features(model_name, batch_size): @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_fx(model_name, batch_size): - """Symbolically trace each model and run single forward pass through the resulting GraphModule""" + """ + Symbolically trace each model and run single forward pass through the resulting GraphModule + Also check that the output of a forward pass through the GraphModule is the same as that from the original Module + """ if not has_fx_feature_extraction: pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") @@ -321,15 +324,32 @@ def test_model_forward_fx(model_name, batch_size): if max(input_size) > MAX_FWD_SIZE: pytest.skip("Fixed input size model > limit.") + # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode + # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output + # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names + tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions)) + graph = tracer.trace(model) + graph_nodes = list(reversed(graph.nodes)) + output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()] + graph_node_names = [n.name for n in graph_nodes] + output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names] train_nodes, eval_nodes = get_graph_node_names( model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) - model = create_feature_extractor( - model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]], + eval_return_nodes = [eval_nodes[ix] for ix in output_node_indices] + + fx_model = create_feature_extractor( + model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=eval_return_nodes, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) inputs = torch.randn((batch_size, *input_size)) - outputs = model(inputs)[eval_nodes[-1]] + outputs = model(inputs) + if isinstance(outputs, tuple): + outputs = torch.cat(outputs) + fx_outputs = tuple(fx_model(inputs).values()) + if isinstance(fx_outputs, tuple): + fx_outputs = torch.cat(fx_outputs) + assert torch.all(fx_outputs == outputs) assert outputs.shape[0] == batch_size assert not torch.isnan(outputs).any(), 'Output included NaNs' @@ -348,6 +368,7 @@ def test_model_backward_fx(model_name, batch_size): model = create_model(model_name, pretrained=False, num_classes=42) model.train() + num_params = sum([x.numel() for x in model.parameters()]) input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE) @@ -355,7 +376,6 @@ def test_model_backward_fx(model_name, batch_size): pytest.skip("Fixed input size model > limit.") # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode - # If so, we need to return all of them in order to check all grads # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions)) @@ -385,9 +405,12 @@ def test_model_backward_fx(model_name, batch_size): assert num_params == num_grad, 'Some parameters are missing gradients' assert not torch.isnan(outputs).any(), 'Output included NaNs' - +# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow EXCLUDE_FX_JIT_FILTERS = [ - 'beit_*' # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow + 'beit_*', + 'deit_*_distilled_patch16_224', + 'levit*', + 'pit_*_distilled_224', ] @pytest.mark.timeout(120) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 4e0f2b21..1d6cbb38 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -26,6 +26,7 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_leaf_module from .helpers import build_model_with_cfg from .registry import register_model from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\ @@ -318,6 +319,7 @@ class DownsampleAvg(nn.Module): return self.conv(self.pool(x)) +@register_leaf_module # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301 class NormFreeBlock(nn.Module): """Normalization-Free pre-activation block. """ From 5220711d8727d87ac21312443df64a99750fe046 Mon Sep 17 00:00:00 2001 From: Martins Bruveris Date: Sun, 14 Nov 2021 11:01:48 +0000 Subject: [PATCH 11/26] Added B/8 models to ViT. --- timm/models/vision_transformer.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 94ae2666..6e568abf 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -88,6 +88,9 @@ default_cfgs = { url='https://storage.googleapis.com/vit_models/augreg/' 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch8_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 'vit_large_patch32_224': _cfg( url='', # no official model weights for this combo, only for in21k ), @@ -118,6 +121,9 @@ default_cfgs = { 'vit_base_patch16_224_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', num_classes=21843), + 'vit_base_patch8_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843), 'vit_large_patch32_224_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', num_classes=21843), @@ -640,6 +646,16 @@ def vit_base_patch16_384(pretrained=False, **kwargs): return model +@register_model +def vit_base_patch8_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_large_patch32_224(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. @@ -756,6 +772,18 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs): return model +@register_model +def vit_base_patch8_224_in21k(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict( + patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_large_patch32_224_in21k(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). From cfa414cad2662502fd6b71bffd720fbd8cfba395 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 14 Nov 2021 12:52:19 -0800 Subject: [PATCH 12/26] Matching two bits_and_tpu changes for TFDs wrapper * change 'samples' -> 'examples' for tfds wrapper to match tfds naming * add class_to_idx for image classification datasets in tfds wrapper --- timm/data/parsers/parser_tfds.py | 103 ++++++++++++++++++++----------- 1 file changed, 66 insertions(+), 37 deletions(-) diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py index 67db6891..ee5893c4 100644 --- a/timm/data/parsers/parser_tfds.py +++ b/timm/data/parsers/parser_tfds.py @@ -30,38 +30,47 @@ from .parser import Parser MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities -SHUFFLE_SIZE = 16384 # samples to shuffle in DS queue -PREFETCH_SIZE = 2048 # samples to prefetch +SHUFFLE_SIZE = 8192 # examples to shuffle in DS queue +PREFETCH_SIZE = 2048 # examples to prefetch -def even_split_indices(split, n, num_samples): - partitions = [round(i * num_samples / n) for i in range(n + 1)] - return [f"{split}[{partitions[i]}:{partitions[i+1]}]" for i in range(n)] +def even_split_indices(split, n, num_examples): + partitions = [round(i * num_examples / n) for i in range(n + 1)] + return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)] + + +def get_class_labels(info): + if 'label' not in info.features: + return {} + class_label = info.features['label'] + class_to_idx = {n: class_label.str2int(n) for n in class_label.names} + return class_to_idx class ParserTfds(Parser): """ Wrap Tensorflow Datasets for use in PyTorch There several things to be aware of: - * To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of + * To prevent excessive examples being dropped per epoch w/ distributed training or multiplicity of dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last https://github.com/pytorch/pytorch/issues/33413 * With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch from each worker could be a different size. For training this is worked around by option above, for - validation extra samples are inserted iff distributed mode is enabled so that the batches being reduced + validation extra examples are inserted iff distributed mode is enabled so that the batches being reduced across replicas are of same size. This will slightly alter the results, distributed validation will not be 100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse - since there are up to N * J extra samples with IterableDatasets. + since there are up to N * J extra examples with IterableDatasets. * The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of replicas and dataloader workers you can use. For really small datasets that only contain a few shards you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the benefit of distributed training or fast dataloading should be much less for small datasets. - * This wrapper is currently configured to return individual, decompressed image samples from the TFDS + * This wrapper is currently configured to return individual, decompressed image examples from the TFDS dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream components. """ + def __init__( self, root, @@ -72,6 +81,10 @@ class ParserTfds(Parser): download=False, repeats=0, seed=42, + input_name='image', + input_image='RGB', + target_name='label', + target_image='', prefetch_size=None, shuffle_size=None, max_threadpool_size=None @@ -83,10 +96,14 @@ class ParserTfds(Parser): name: tfds dataset name (eg `imagenet2012`) split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`) is_training: training mode, shuffle enabled, dataset len rounded by batch_size - batch_size: batch_size to use to unsure total samples % batch_size == 0 in training across all dis nodes + batch_size: batch_size to use to unsure total examples % batch_size == 0 in training across all dis nodes download: download and build TFDS dataset if set, otherwise must use tfds CLI repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1) seed: common seed for shard shuffle across all distributed/worker instances + input_name: name of Feature to return as data (input) + input_image: image mode if input is an image (currently PIL mode string) + target_name: name of Feature to return as target (label) + target_image: image mode if target is an image (currently PIL mode string) prefetch_size: override default tf.data prefetch buffer size shuffle_size: override default tf.data shuffle buffer size max_threadpool_size: override default threadpool size for tf.data @@ -96,22 +113,29 @@ class ParserTfds(Parser): self.split = split self.is_training = is_training if self.is_training: - assert batch_size is not None,\ + assert batch_size is not None, \ "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper" self.batch_size = batch_size self.repeats = repeats self.common_seed = seed # a seed that's fixed across all worker / distributed instances + + # performance settings self.prefetch_size = prefetch_size or PREFETCH_SIZE self.shuffle_size = shuffle_size or SHUFFLE_SIZE self.max_threadpool_size = max_threadpool_size or MAX_TP_SIZE # TFDS builder and split information + self.input_name = input_name # FIXME support tuples / lists of inputs and targets and full range of Feature + self.input_image = input_image + self.target_name = target_name + self.target_image = target_image self.builder = tfds.builder(name, data_dir=root) # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag if download: self.builder.download_and_prepare() + self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {} self.split_info = self.builder.info.splits[split] - self.num_samples = self.split_info.num_examples + self.num_examples = self.split_info.num_examples # Distributed world state self.dist_rank = 0 @@ -154,21 +178,21 @@ class ParserTfds(Parser): InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used. My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True) between the splits each iteration, but that understanding could be wrong. - + I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing the data across workers. For training InputContext is used to assign shards to nodes unless num_shards in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or - for validation where we can't drop samples and need to avoid minimize uneven splits to avoid padding. + for validation where we can't drop examples and need to avoid minimize uneven splits to avoid padding. """ should_subsplit = self.global_num_workers > 1 and ( self.split_info.num_shards < self.global_num_workers or not self.is_training) if should_subsplit: - # split the dataset w/o using sharding for more even samples / worker, can result in less optimal + # split the dataset w/o using sharding for more even examples / worker, can result in less optimal # read patterns for distributed training (overlap across shards) so better to use InputContext there if has_buggy_even_splits: # my even_split workaround doesn't work on subsplits, upgrade tfds! if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo): - subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples) + subsplits = even_split_indices(self.split, self.global_num_workers, self.num_examples) self.subsplit = subsplits[global_worker_id] else: subsplits = tfds.even_splits(self.split, self.global_num_workers) @@ -199,8 +223,8 @@ class ParserTfds(Parser): # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading ds = ds.repeat() # allow wrap around and break iteration manually if self.is_training: - ds = ds.shuffle(min(self.num_samples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed) - ds = ds.prefetch(min(self.num_samples // self.global_num_workers, self.prefetch_size)) + ds = ds.shuffle(min(self.num_examples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed) + ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size)) self.ds = tfds.as_numpy(ds) def __iter__(self): @@ -209,44 +233,49 @@ class ParserTfds(Parser): # Compute a rounded up sample count that is used to: # 1. make batches even cross workers & replicas in distributed validation. - # This adds extra samples and will slightly alter validation results. + # This adds extra examples and will slightly alter validation results. # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size # batches are produced (underlying tfds iter wraps around) - target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self.global_num_workers) + target_example_count = math.ceil(max(1, self.repeats) * self.num_examples / self.global_num_workers) if self.is_training: # round up to nearest batch_size per worker-replica - target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size + target_example_count = math.ceil(target_example_count / self.batch_size) * self.batch_size # Iterate until exhausted or sample count hits target when training (ds.repeat enabled) - sample_count = 0 - for sample in self.ds: - img = Image.fromarray(sample['image'], mode='RGB') - yield img, sample['label'] - sample_count += 1 - if self.is_training and sample_count >= target_sample_count: + example_count = 0 + for example in self.ds: + input_data = example[self.input_name] + if self.input_image: + input_data = Image.fromarray(input_data, mode=self.input_image) + target_data = example[self.target_name] + if self.target_image: + target_data = Image.fromarray(target_data, mode=self.target_image) + yield input_data, target_data + example_count += 1 + if self.is_training and example_count >= target_example_count: # Need to break out of loop when repeat() is enabled for training w/ oversampling - # this results in extra samples per epoch but seems more desirable than dropping + # this results in extra examples per epoch but seems more desirable than dropping # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes) break - # Pad across distributed nodes (make counts equal by adding samples) + # Pad across distributed nodes (make counts equal by adding examples) if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \ - 0 < sample_count < target_sample_count: + 0 < example_count < target_example_count: # Validation batch padding only done for distributed training where results are reduced across nodes. # For single process case, it won't matter if workers return different batch sizes. # If using input_context or % based splits, sample count can vary significantly across workers and this # approach should not be used (hence disabled if self.subsplit isn't set). - while sample_count < target_sample_count: - yield img, sample['label'] # yield prev sample again - sample_count += 1 + while example_count < target_example_count: + yield input_data, target_data # yield prev sample again + example_count += 1 def __len__(self): - # this is just an estimate and does not factor in extra samples added to pad batches based on + # this is just an estimate and does not factor in extra examples added to pad batches based on # complete worker & replica info (not available until init in dataloader). - return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas) + return math.ceil(max(1, self.repeats) * self.num_examples / self.dist_num_replicas) def _filename(self, index, basename=False, absolute=False): - assert False, "Not supported" # no random access to samples + assert False, "Not supported" # no random access to examples def filenames(self, basename=False, absolute=False): """ Return all filenames in dataset, overrides base""" @@ -254,7 +283,7 @@ class ParserTfds(Parser): self._lazy_init() names = [] for sample in self.ds: - if len(names) > self.num_samples: + if len(names) > self.num_examples: break # safety for ds.repeat() case if 'file_name' in sample: name = sample['file_name'] From 9b2daf2a35184854d20325d6b0f69d99e607cc7d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 14 Nov 2021 13:17:27 -0800 Subject: [PATCH 13/26] Add ResNeXt-50 weights 81.1 top-1 @ 224, 82 @ 288 with A1 'high aug' recipe --- timm/models/resnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 1c7cbba2..998a739e 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -91,8 +91,8 @@ default_cfgs = { # ResNeXt 'resnext50_32x4d': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d_ra-d733960d.pth', - interpolation='bicubic'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnext50_32x4d_a1h-0146ab0a.pth', + interpolation='bicubic', crop_pct=0.95), 'resnext50d_32x4d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50d_32x4d-103e99f8.pth', interpolation='bicubic', From 65d827c7a6739b20dcd4c57216f20adc521a6b2a Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 15 Nov 2021 21:03:21 +0000 Subject: [PATCH 14/26] rename notrace registration and standardize trace_utils imports --- timm/models/coat.py | 2 +- timm/models/convit.py | 4 ++-- timm/models/crossvit.py | 4 ++-- timm/models/fx_features.py | 4 ++-- timm/models/nest.py | 6 +++--- timm/models/nfnet.py | 4 ++-- timm/models/swin_transformer.py | 6 +++--- timm/models/tnt.py | 2 +- timm/models/twins.py | 4 ++-- timm/models/vgg.py | 4 ++-- timm/models/xcit.py | 4 ++-- 11 files changed, 22 insertions(+), 22 deletions(-) diff --git a/timm/models/coat.py b/timm/models/coat.py index dca655ea..18ff8ab9 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -19,7 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, overlay_external_default_cfg from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ from .registry import register_model -from .layers.trace_utils import _assert +from .layers import _assert __all__ = [ diff --git a/timm/models/convit.py b/timm/models/convit.py index d2f69b68..6ef1da72 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -30,7 +30,7 @@ from .helpers import build_model_with_cfg from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp from .registry import register_model from .vision_transformer_hybrid import HybridEmbed -from .fx_features import register_leaf_module +from .fx_features import register_notrace_module import torch import torch.nn as nn @@ -57,7 +57,7 @@ default_cfgs = { } -@register_leaf_module # reason: FX can't symbolically trace control flow in forward method +@register_notrace_module # reason: FX can't symbolically trace control flow in forward method class GPSA(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., locality_strength=1.): diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index 3ba5b4c7..ddc4f64c 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -32,7 +32,7 @@ from functools import partial from typing import List from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_autowrap_function +from .fx_features import register_notrace_function from .helpers import build_model_with_cfg from .layers import DropPath, to_2tuple, trunc_normal_, _assert from .registry import register_model @@ -259,7 +259,7 @@ def _compute_num_patches(img_size, patches): return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)] -@register_autowrap_function +@register_notrace_function def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False): # annotations for torchscript """ Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing. diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index a582cf9b..2e01586b 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -36,7 +36,7 @@ except ImportError: pass -def register_leaf_module(module: nn.Module): +def register_notrace_module(module: nn.Module): """ Any module not under timm.models.layers should get this decorator if we don't want to trace through it. """ @@ -48,7 +48,7 @@ def register_leaf_module(module: nn.Module): _autowrap_functions = set() -def register_autowrap_function(func: Callable): +def register_notrace_function(func: Callable): """ Decorator for functions which ought not to be traced through """ diff --git a/timm/models/nest.py b/timm/models/nest.py index c5951aea..22cf6099 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -25,10 +25,10 @@ import torch.nn.functional as F from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_autowrap_function +from .fx_features import register_notrace_function from .helpers import build_model_with_cfg, named_apply from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_ -from .layers.trace_utils import _assert +from .layers import _assert from .layers import create_conv2d, create_pool2d, to_ntuple from .registry import register_model @@ -155,7 +155,7 @@ def blockify(x, block_size: int): return x # (B, T, N, C) -@register_autowrap_function # reason: int receives Proxy +@register_notrace_function # reason: int receives Proxy def deblockify(x, block_size: int): """blocks to image Args: diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 1d6cbb38..973cbd66 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -26,7 +26,7 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_leaf_module +from .fx_features import register_notrace_module from .helpers import build_model_with_cfg from .registry import register_model from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\ @@ -319,7 +319,7 @@ class DownsampleAvg(nn.Module): return self.conv(self.pool(x)) -@register_leaf_module # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301 +@register_notrace_module # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301 class NormFreeBlock(nn.Module): """Normalization-Free pre-activation block. """ diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index d5dd5513..92057902 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -21,10 +21,10 @@ import torch.nn as nn import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .fx_features import register_autowrap_function +from .fx_features import register_notrace_function from .helpers import build_model_with_cfg, overlay_external_default_cfg from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ -from .layers.trace_utils import _assert +from .layers import _assert from .registry import register_model from .vision_transformer import checkpoint_filter_fn, _init_vit_weights @@ -103,7 +103,7 @@ def window_partition(x, window_size: int): return windows -@register_autowrap_function # reason: int argument is a Proxy +@register_notrace_function # reason: int argument is a Proxy def window_reverse(windows, window_size: int, H: int, W: int): """ Args: diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 1ad481f6..d52f9ce6 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -14,7 +14,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.helpers import build_model_with_cfg from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers.helpers import to_2tuple -from timm.models.layers.trace_utils import _assert +from timm.models.layers import _assert from timm.models.registry import register_model from timm.models.vision_transformer import resize_pos_embed diff --git a/timm/models/twins.py b/timm/models/twins.py index 9ae70d32..67a939d4 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -22,7 +22,7 @@ from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .layers import Mlp, DropPath, to_2tuple, trunc_normal_ -from .fx_features import register_leaf_module +from .fx_features import register_notrace_module from .registry import register_model from .vision_transformer import Attention from .helpers import build_model_with_cfg @@ -63,7 +63,7 @@ default_cfgs = { Size_ = Tuple[int, int] -@register_leaf_module # reason: FX can't symbolically trace control flow in forward method +@register_notrace_module # reason: FX can't symbolically trace control flow in forward method class LocallyGroupedAttn(nn.Module): """ LSA: self attention within a group """ diff --git a/timm/models/vgg.py b/timm/models/vgg.py index 0f62ac4e..11f6d0ea 100644 --- a/timm/models/vgg.py +++ b/timm/models/vgg.py @@ -12,7 +12,7 @@ from typing import Union, List, Dict, Any, cast from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from .fx_features import register_leaf_module +from .fx_features import register_notrace_module from .layers import ClassifierHead from .registry import register_model @@ -53,7 +53,7 @@ cfgs: Dict[str, List[Union[str, int]]] = { } -@register_leaf_module # reason: FX can't symbolically trace control flow in forward method +@register_notrace_module # reason: FX can't symbolically trace control flow in forward method class ConvMlp(nn.Module): def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0, diff --git a/timm/models/xcit.py b/timm/models/xcit.py index f5dd0683..ac5e802c 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -21,7 +21,7 @@ from .vision_transformer import _cfg, Mlp from .registry import register_model from .layers import DropPath, trunc_normal_, to_2tuple from .cait import ClassAttn -from .fx_features import register_leaf_module +from .fx_features import register_notrace_module def _cfg(url='', **kwargs): @@ -98,7 +98,7 @@ default_cfgs = { } -@register_leaf_module # reason: FX can't symbolically trace torch.arange in forward method +@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method class PositionalEncodingFourier(nn.Module): """ Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper. From 78b36bf46c21be8557b8b84ff9261637d234fe47 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 18 Nov 2021 14:59:51 -0800 Subject: [PATCH 15/26] Places365 doesn't exist in some still used torchvision version --- timm/data/dataset_factory.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py index 03b03cf5..e86bcc29 100644 --- a/timm/data/dataset_factory.py +++ b/timm/data/dataset_factory.py @@ -1,7 +1,11 @@ import os -from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST,\ - Places365, ImageNet, ImageFolder +from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder +try: + from torchvision.datasets import Places365 + has_places365 = True +except ImportError: + has_places365 = False try: from torchvision.datasets import INaturalist has_inaturalist = True @@ -104,6 +108,7 @@ def create_dataset( split = '2021_valid' ds = INaturalist(version=split, target_type=target_type, **torch_kwargs) elif name == 'places365': + assert has_places365, 'Please update to a newer PyTorch and torchvision for Places365 dataset.' if split in _TRAIN_SYNONYM: split = 'train-standard' elif split in _EVAL_SYNONYM: From 1076a65df1fa8fe0612adbe21284f4722b585dac Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 18 Nov 2021 19:47:07 -0800 Subject: [PATCH 16/26] Minor post FX merge cleanup --- tests/test_models.py | 6 +++--- timm/models/fx_features.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 93152d9a..7a3f143e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -315,7 +315,7 @@ def test_model_forward_fx(model_name, batch_size): Also check that the output of a forward pass through the GraphModule is the same as that from the original Module """ if not has_fx_feature_extraction: - pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") + pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") model = create_model(model_name, pretrained=False) model.eval() @@ -360,7 +360,7 @@ def test_model_forward_fx(model_name, batch_size): def test_model_backward_fx(model_name, batch_size): """Symbolically trace each model and run single backward pass through the resulting GraphModule""" if not has_fx_feature_extraction: - pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") + pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE) if max(input_size) > MAX_BWD_SIZE: @@ -421,7 +421,7 @@ EXCLUDE_FX_JIT_FILTERS = [ def test_model_forward_fx_torchscript(model_name, batch_size): """Symbolically trace each model, script it, and run single forward pass""" if not has_fx_feature_extraction: - pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") + pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) if max(input_size) > MAX_JIT_SIZE: diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index 2e01586b..5a25ee3e 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -71,4 +71,3 @@ class FeatureGraphNet(nn.Module): def forward(self, x): return list(self.graph_module(x).values()) - \ No newline at end of file From f2006b24370338643e30eba72c3b7a124ee4b5b3 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 18 Nov 2021 21:25:00 -0800 Subject: [PATCH 17/26] Cleanup qkv_bias cat in beit model so it can be traced --- tests/test_models.py | 1 - timm/models/beit.py | 10 +++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 7a3f143e..39e2dcdc 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -407,7 +407,6 @@ def test_model_backward_fx(model_name, batch_size): # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow EXCLUDE_FX_JIT_FILTERS = [ - 'beit_*', 'deit_*_distilled_patch16_224', 'levit*', 'pit_*_distilled_224', diff --git a/timm/models/beit.py b/timm/models/beit.py index 199c2a4b..f644b657 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -86,9 +86,11 @@ class Attention(nn.Module): self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) if qkv_bias: self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False) self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) else: self.q_bias = None + self.k_bias = None self.v_bias = None if window_size: @@ -127,13 +129,7 @@ class Attention(nn.Module): def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None): B, N, C = x.shape - qkv_bias = None - if self.q_bias is not None: - if torch.jit.is_scripting(): - # FIXME requires_grad breaks w/ torchscript - qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias), self.v_bias)) - else: - qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) From 9d6aad44f8fd32e89e5cca503efe3ada5071cc2a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 18 Nov 2021 21:38:04 -0800 Subject: [PATCH 18/26] Update tests to run Python 3.9, PyTorch 1.10, torchvision 0.11.1 --- .github/workflows/tests.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1136f306..9e0a4aac 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,9 +16,9 @@ jobs: strategy: matrix: os: [ubuntu-latest, macOS-latest] - python: ['3.8'] - torch: ['1.9.0'] - torchvision: ['0.10.0'] + python: ['3.9'] + torch: ['1.10.0'] + torchvision: ['0.11.1'] runs-on: ${{ matrix.os }} steps: @@ -30,7 +30,7 @@ jobs: - name: Install testing dependencies run: | python -m pip install --upgrade pip - pip install pytest pytest-timeout + pip install pytest pytest-timeout expecttest - name: Install torch on mac if: startsWith(matrix.os, 'macOS') run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }} From bdd3dff0ca0841e19e36d14e106691e43d35d996 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 19 Nov 2021 08:39:48 -0800 Subject: [PATCH 19/26] beit_large models killing GitHub actions test, filter out --- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 39e2dcdc..f4520720 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -33,7 +33,7 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): EXCLUDE_FILTERS = [ '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*', - '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*'] + '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'beit_large*'] else: EXCLUDE_FILTERS = [] From 9b3519545d6bf901047dccd24832793c95919cd4 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 19 Nov 2021 14:24:12 -0800 Subject: [PATCH 20/26] Attempt to reduce memory footprint of FX tests for GitHub actions runs --- tests/test_models.py | 107 ++++++++++++++++++------------------------- 1 file changed, 45 insertions(+), 62 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index f4520720..1750d540 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -33,7 +33,7 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): EXCLUDE_FILTERS = [ '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*', - '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'beit_large*'] + '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*'] else: EXCLUDE_FILTERS = [] @@ -45,6 +45,10 @@ TARGET_JIT_SIZE = 128 MAX_JIT_SIZE = 320 TARGET_FFEAT_SIZE = 96 MAX_FFEAT_SIZE = 256 +TARGET_FWD_FX_SIZE = 128 +MAX_FWD_FX_SIZE = 224 +TARGET_BWD_FX_SIZE = 128 +MAX_BWD_FX_SIZE = 224 def _get_input_size(model=None, model_name='', target=None): @@ -306,6 +310,30 @@ def test_model_forward_features(model_name, batch_size): assert not torch.isnan(o).any() +def _create_fx_model(model, train=False): + # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode + # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output + # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names + train_nodes, eval_nodes = get_graph_node_names( + model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) + + eval_return_nodes = [eval_nodes[-1]] + train_return_nodes = [train_nodes[-1]] + if train: + tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions)) + graph = tracer.trace(model) + graph_nodes = list(reversed(graph.nodes)) + output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()] + graph_node_names = [n.name for n in graph_nodes] + output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names] + train_return_nodes = [train_nodes[ix] for ix in output_node_indices] + + fx_model = create_feature_extractor( + model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes, + tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) + return fx_model + + @pytest.mark.timeout(120) @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) @pytest.mark.parametrize('batch_size', [1]) @@ -320,39 +348,23 @@ def test_model_forward_fx(model_name, batch_size): model = create_model(model_name, pretrained=False) model.eval() - input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE) - if max(input_size) > MAX_FWD_SIZE: + input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE) + if max(input_size) > MAX_FWD_FX_SIZE: pytest.skip("Fixed input size model > limit.") - - # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode - # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output - # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names - tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions)) - graph = tracer.trace(model) - graph_nodes = list(reversed(graph.nodes)) - output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()] - graph_node_names = [n.name for n in graph_nodes] - output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names] - train_nodes, eval_nodes = get_graph_node_names( - model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) - eval_return_nodes = [eval_nodes[ix] for ix in output_node_indices] - - fx_model = create_feature_extractor( - model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=eval_return_nodes, - tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) - inputs = torch.randn((batch_size, *input_size)) outputs = model(inputs) if isinstance(outputs, tuple): outputs = torch.cat(outputs) - fx_outputs = tuple(fx_model(inputs).values()) + + model = _create_fx_model(model) + fx_outputs = tuple(model(inputs).values()) if isinstance(fx_outputs, tuple): fx_outputs = torch.cat(fx_outputs) assert torch.all(fx_outputs == outputs) assert outputs.shape[0] == batch_size assert not torch.isnan(outputs).any(), 'Output included NaNs' - + @pytest.mark.timeout(120) @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True)) @@ -362,38 +374,16 @@ def test_model_backward_fx(model_name, batch_size): if not has_fx_feature_extraction: pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") - input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE) - if max(input_size) > MAX_BWD_SIZE: + input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE) + if max(input_size) > MAX_BWD_FX_SIZE: pytest.skip("Fixed input size model > limit.") model = create_model(model_name, pretrained=False, num_classes=42) - model.train() - num_params = sum([x.numel() for x in model.parameters()]) + model.train() - input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE) - if max(input_size) > MAX_FWD_SIZE: - pytest.skip("Fixed input size model > limit.") - - # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode - # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output - # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names - tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions)) - graph = tracer.trace(model) - graph_nodes = list(reversed(graph.nodes)) - output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()] - graph_node_names = [n.name for n in graph_nodes] - output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names] - train_nodes, eval_nodes = get_graph_node_names( - model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) - train_return_nodes = [train_nodes[ix] for ix in output_node_indices] - - model = create_feature_extractor( - model, train_return_nodes=train_return_nodes, eval_return_nodes=[eval_nodes[-1]], - tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) - - inputs = torch.randn((batch_size, *input_size)) - outputs = tuple(model(inputs).values()) + model = _create_fx_model(model, train=True) + outputs = tuple(model(torch.randn((batch_size, *input_size))).values()) if isinstance(outputs, tuple): outputs = torch.cat(outputs) outputs.mean().backward() @@ -412,6 +402,7 @@ EXCLUDE_FX_JIT_FILTERS = [ 'pit_*_distilled_224', ] + @pytest.mark.timeout(120) @pytest.mark.parametrize( 'model_name', list_models( @@ -430,18 +421,10 @@ def test_model_forward_fx_torchscript(model_name, batch_size): model = create_model(model_name, pretrained=False) model.eval() - input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE) - if max(input_size) > MAX_FWD_SIZE: - pytest.skip("Fixed input size model > limit.") - - train_nodes, eval_nodes = get_graph_node_names( - model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) - model = create_feature_extractor( - model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]], - tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) - - model = torch.jit.script(model) - outputs = model(torch.randn((batch_size, *input_size)))[train_nodes[-1]] + model = torch.jit.script(_create_fx_model(model)) + outputs = tuple(model(torch.randn((batch_size, *input_size))).values()) + if isinstance(outputs, tuple): + outputs = torch.cat(outputs) assert outputs.shape[0] == batch_size assert not torch.isnan(outputs).any(), 'Output included NaNs' From c976a410d9295224f23cf413466fd3387bf651c1 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 19 Nov 2021 14:24:43 -0800 Subject: [PATCH 21/26] Add ResNet-50 w/ GN (resnet50_gn) and SEBotNet-33-TS (sebotnet33ts_256) model defs and weights. Update halonet50ts weights w/ slightly better variant in1k val, more robust to test sets. --- timm/models/byoanet.py | 28 +++++++++++++++++++++++++++- timm/models/layers/norm.py | 2 +- timm/models/resnet.py | 15 ++++++++++++++- 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index b05cd91a..7fc7f82e 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -36,6 +36,9 @@ default_cfgs = { 'botnet26t_256': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + 'sebotnet33ts_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sebotnet33ts_a1h2_256-957e3c3e.pth', + fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94), 'botnet50ts_256': _cfg( url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), @@ -51,7 +54,7 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94), 'halonet50ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_a1h_256-c6d7ff15.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_a1h2_256-f3a3daee.pth', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94), 'eca_halonext26ts': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_c_256-06906299.pth', @@ -97,6 +100,22 @@ model_cfgs = dict( self_attn_layer='bottleneck', self_attn_kwargs=dict() ), + sebotnet33ts=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg('self_attn', d=2, c=1536, s=2, gs=0, br=0.333), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='', + act_layer='silu', + num_features=1280, + attn_layer='se', + self_attn_layer='bottleneck', + self_attn_kwargs=dict() + ), botnet50ts=ByoModelCfg( blocks=( ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), @@ -322,6 +341,13 @@ def botnet26t_256(pretrained=False, **kwargs): return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs) +@register_model +def sebotnet33ts_256(pretrained=False, **kwargs): + """ Bottleneck Transformer w/ a ResNet33-t backbone, SE attn for non Halo blocks, SiLU, + """ + return _create_byoanet('sebotnet33ts_256', 'sebotnet33ts', pretrained=pretrained, **kwargs) + + @register_model def botnet50ts_256(pretrained=False, **kwargs): """ Bottleneck Transformer w/ ResNet50-T backbone, silu act. diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index aace107b..85297420 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -6,7 +6,7 @@ import torch.nn.functional as F class GroupNorm(nn.GroupNorm): - def __init__(self, num_channels, num_groups, eps=1e-5, affine=True): + def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True): # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN super().__init__(num_groups, num_channels, eps=eps, affine=affine) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 998a739e..bbcae9a3 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -15,7 +15,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, create_attn, get_attn, create_classifier +from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, create_classifier from .registry import register_model __all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this @@ -89,6 +89,11 @@ default_cfgs = { interpolation='bicubic'), 'wide_resnet101_2': _cfg(url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth'), + # ResNets w/ alternative norm layers + 'resnet50_gn': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_gn_a1h2-8fe6c4d0.pth', + crop_pct=0.94, interpolation='bicubic'), + # ResNeXt 'resnext50_32x4d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnext50_32x4d_a1h-0146ab0a.pth', @@ -881,6 +886,14 @@ def wide_resnet101_2(pretrained=False, **kwargs): return _create_resnet('wide_resnet101_2', pretrained, **model_args) +@register_model +def resnet50_gn(pretrained=False, **kwargs): + """Constructs a ResNet-50 model w/ GroupNorm + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('resnet50_gn', pretrained, norm_layer=GroupNorm, **model_args) + + @register_model def resnext50_32x4d(pretrained=False, **kwargs): """Constructs a ResNeXt50-32x4d model. From 3819bef93e55f7f0ec4b0e2718624f3782559dd0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 19 Nov 2021 17:35:41 -0800 Subject: [PATCH 22/26] Add FX test exclusion since it uses more ram and barfs on GitHub actions. Will take a few iterations to include needed models :( --- tests/test_models.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 1750d540..5fde43da 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -334,8 +334,14 @@ def _create_fx_model(model, train=False): return fx_model +EXCLUDE_FX_FILTERS = [] +# not enough memory to run fx on more models than other tests +if 'GITHUB_ACTIONS' in os.environ: + EXCLUDE_FX_FILTERS += ['beit_large*', 'swin_large*'] + + @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS)) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_fx(model_name, batch_size): """ @@ -367,7 +373,8 @@ def test_model_forward_fx(model_name, batch_size): @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True)) +@pytest.mark.parametrize('model_name', list_models( + exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True)) @pytest.mark.parametrize('batch_size', [2]) def test_model_backward_fx(model_name, batch_size): """Symbolically trace each model and run single backward pass through the resulting GraphModule""" @@ -400,7 +407,7 @@ EXCLUDE_FX_JIT_FILTERS = [ 'deit_*_distilled_patch16_224', 'levit*', 'pit_*_distilled_224', -] +] + EXCLUDE_FX_FILTERS @pytest.mark.timeout(120) From af607b75cc7aa91eaece8a4478939e33a17cd1e9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 19 Nov 2021 17:37:00 -0800 Subject: [PATCH 23/26] Prep a set of ResNetV2 models with GroupNorm, EvoNormB0, EvoNormS0 for BN free model experiments on TPU and IPU --- timm/models/layers/evo_norm.py | 2 +- timm/models/layers/norm_act.py | 2 +- timm/models/resnetv2.py | 47 ++++++++++++++++++++++------------ 3 files changed, 33 insertions(+), 18 deletions(-) diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py index ecc5fb61..50367f9b 100644 --- a/timm/models/layers/evo_norm.py +++ b/timm/models/layers/evo_norm.py @@ -55,7 +55,7 @@ class EvoNormBatch2d(nn.Module): class EvoNormSample2d(nn.Module): - def __init__(self, num_features, apply_act=True, groups=8, eps=1e-5, drop_block=None): + def __init__(self, num_features, apply_act=True, groups=32, eps=1e-5, drop_block=None): super(EvoNormSample2d, self).__init__() self.apply_act = apply_act # apply activation (non-linearity) self.groups = groups diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py index 02cabe88..2e15181f 100644 --- a/timm/models/layers/norm_act.py +++ b/timm/models/layers/norm_act.py @@ -68,7 +68,7 @@ class BatchNormAct2d(nn.BatchNorm2d): class GroupNormAct(nn.GroupNorm): # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args - def __init__(self, num_channels, num_groups, eps=1e-5, affine=True, + def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True, apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine) if isinstance(act_layer, str): diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 43940cc3..e38eaf5e 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -120,6 +120,13 @@ default_cfgs = { interpolation='bicubic'), 'resnetv2_152d': _cfg( interpolation='bicubic', first_conv='stem.conv1'), + + 'resnetv2_50d_gn': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), + 'resnetv2_50d_evob': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), + 'resnetv2_50d_evos': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), } @@ -639,19 +646,27 @@ def resnetv2_152d(pretrained=False, **kwargs): stem_type='deep', avg_down=True, **kwargs) -# @register_model -# def resnetv2_50ebd(pretrained=False, **kwargs): -# # FIXME for testing w/ TPU + PyTorch XLA -# return _create_resnetv2( -# 'resnetv2_50d', pretrained=pretrained, -# layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d, -# stem_type='deep', avg_down=True, **kwargs) -# -# -# @register_model -# def resnetv2_50esd(pretrained=False, **kwargs): -# # FIXME for testing w/ TPU + PyTorch XLA -# return _create_resnetv2( -# 'resnetv2_50d', pretrained=pretrained, -# layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d, -# stem_type='deep', avg_down=True, **kwargs) +# Experimental configs (may change / be removed) + +@register_model +def resnetv2_50d_gn(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50d_gn', pretrained=pretrained, + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=GroupNormAct, + stem_type='deep', avg_down=True, **kwargs) + + +@register_model +def resnetv2_50d_evob(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50d_evob', pretrained=pretrained, + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d, + stem_type='deep', avg_down=True, **kwargs) + + +@register_model +def resnetv2_50d_evos(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50d_evos', pretrained=pretrained, + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d, + stem_type='deep', avg_down=True, **kwargs) From 93cc08fdc5a3f6716c183150b8370621788a13f0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 20 Nov 2021 15:50:51 -0800 Subject: [PATCH 24/26] Make evonorm variables 1d to match other PyTorch norm layers, will break weight compat for any existing use (likely minimal, easy to fix). --- timm/models/layers/evo_norm.py | 37 ++++++++++++++++------------------ 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py index 50367f9b..8c08e49f 100644 --- a/timm/models/layers/evo_norm.py +++ b/timm/models/layers/evo_norm.py @@ -21,12 +21,10 @@ class EvoNormBatch2d(nn.Module): self.apply_act = apply_act # apply activation (non-linearity) self.momentum = momentum self.eps = eps - param_shape = (1, num_features, 1, 1) - self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) - self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) - if apply_act: - self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) - self.register_buffer('running_var', torch.ones(1, num_features, 1, 1)) + self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True) + self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True) + self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None + self.register_buffer('running_var', torch.ones(num_features)) self.reset_parameters() def reset_parameters(self): @@ -38,20 +36,21 @@ class EvoNormBatch2d(nn.Module): def forward(self, x): assert x.dim() == 4, 'expected 4D input' x_type = x.dtype + running_var = self.running_var.view(1, -1, 1, 1) if self.training: var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) n = x.numel() / x.shape[1] - self.running_var.copy_( - var.detach() * self.momentum * (n / (n - 1)) + self.running_var * (1 - self.momentum)) + running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum) + self.running_var.copy_(running_var.view(self.running_var.shape)) else: - var = self.running_var + var = running_var - if self.apply_act: - v = self.v.to(dtype=x_type) + if self.v is not None: + v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1) d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type) d = d.max((var + self.eps).sqrt().to(dtype=x_type)) x = x / d - return x * self.weight + self.bias + return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1) class EvoNormSample2d(nn.Module): @@ -60,11 +59,9 @@ class EvoNormSample2d(nn.Module): self.apply_act = apply_act # apply activation (non-linearity) self.groups = groups self.eps = eps - param_shape = (1, num_features, 1, 1) - self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) - self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) - if apply_act: - self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) + self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True) + self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True) + self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None self.reset_parameters() def reset_parameters(self): @@ -77,9 +74,9 @@ class EvoNormSample2d(nn.Module): _assert(x.dim() == 4, 'expected 4D input') B, C, H, W = x.shape _assert(C % self.groups == 0, '') - if self.apply_act: - n = x * (x * self.v).sigmoid() + if self.v is not None: + n = x * (x * self.v.view(1, -1, 1, 1)).sigmoid() x = x.reshape(B, self.groups, -1) x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt() x = x.reshape(B, C, H, W) - return x * self.weight + self.bias + return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1) From 05092e2fbeafa034b363094876c54eec04c342ad Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 20 Nov 2021 15:51:48 -0800 Subject: [PATCH 25/26] Add more models to FX filter --- tests/test_models.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 5fde43da..f55247ee 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -337,7 +337,12 @@ def _create_fx_model(model, train=False): EXCLUDE_FX_FILTERS = [] # not enough memory to run fx on more models than other tests if 'GITHUB_ACTIONS' in os.environ: - EXCLUDE_FX_FILTERS += ['beit_large*', 'swin_large*'] + EXCLUDE_FX_FILTERS += [ + 'beit_large*', + 'swin_large*', + '*resnext101_32x32d', + 'resnetv2_152x2*', + ] @pytest.mark.timeout(120) From 734b2244fe5962d55f0a49492be1f17bf549f9f9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 20 Nov 2021 15:52:04 -0800 Subject: [PATCH 26/26] Add RegNetZ-D8 (83.5 @ 256, 84 @ 320) and RegNetZ-E8 (84.5 @ 256, 85 @ 320) weights. Update names of existing RegZ models to include group size. --- timm/models/byobnet.py | 87 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 72 insertions(+), 15 deletions(-) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index d7253bdf..fa57943a 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -35,7 +35,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, named_apply from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ - create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple + create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple, EvoNormSample2d from .registry import register_model __all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block'] @@ -136,20 +136,26 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext50ts_256-3e0f515e.pth'), # experimental models, likely to change ot be removed - 'regnetz_b': _cfgr( + 'regnetz_b16': _cfgr( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_b_raa-677d9606.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv', crop_pct=0.94), - 'regnetz_c': _cfgr( + 'regnetz_c16': _cfgr( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_c_rab2_256-a54bf36a.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), first_conv='stem.conv', crop_pct=0.94), - 'regnetz_d': _cfgr( + 'regnetz_d32': _cfgr( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_d_rab_256-b8073a89.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95), 'regnetz_d8': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_d8_bh-afc03c55.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=1.0), + 'regnetz_e8': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_e8_bh-aace8e6e.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=1.0), + 'regnetz_d8_evob': _cfgr( url='', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95), - 'regnetz_e8': _cfgr( + 'regnetz_d8_evos': _cfgr( url='', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95), } @@ -506,7 +512,7 @@ model_cfgs = dict( ), # experimental models, closer to a RegNetZ than a ResNet. Similar to EfficientNets but w/ groups instead of DW - regnetz_b=ByoModelCfg( + regnetz_b16=ByoModelCfg( blocks=( ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3), ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3), @@ -522,7 +528,7 @@ model_cfgs = dict( attn_kwargs=dict(rd_ratio=0.25), block_kwargs=dict(bottle_in=True, linear_out=True), ), - regnetz_c=ByoModelCfg( + regnetz_c16=ByoModelCfg( blocks=( ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=4), ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=4), @@ -538,7 +544,7 @@ model_cfgs = dict( attn_kwargs=dict(rd_ratio=0.25), block_kwargs=dict(bottle_in=True, linear_out=True), ), - regnetz_d=ByoModelCfg( + regnetz_d32=ByoModelCfg( blocks=( ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=32, br=4), ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=32, br=4), @@ -589,8 +595,45 @@ model_cfgs = dict( attn_kwargs=dict(rd_ratio=0.25), block_kwargs=dict(bottle_in=True, linear_out=True), ), -) + # experimental EvoNorm configs + regnetz_d8_evob=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4), + ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=8, br=4), + ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=8, br=4), + ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=8, br=4), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='', + downsample='', + num_features=1792, + act_layer='silu', + norm_layer='evonormbatch', + attn_layer='se', + attn_kwargs=dict(rd_ratio=0.25), + block_kwargs=dict(bottle_in=True, linear_out=True), + ), + regnetz_d8_evos=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4), + ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=8, br=4), + ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=8, br=4), + ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=8, br=4), + ), + stem_chs=64, + stem_type='deep', + stem_pool='', + downsample='', + num_features=1792, + act_layer='silu', + norm_layer=partial(EvoNormSample2d, groups=32), + attn_layer='se', + attn_kwargs=dict(rd_ratio=0.25), + block_kwargs=dict(bottle_in=True, linear_out=True), + ), +) @register_model def gernet_l(pretrained=False, **kwargs): @@ -779,24 +822,24 @@ def gcresnext50ts(pretrained=False, **kwargs): @register_model -def regnetz_b(pretrained=False, **kwargs): +def regnetz_b16(pretrained=False, **kwargs): """ """ - return _create_byobnet('regnetz_b', pretrained=pretrained, **kwargs) + return _create_byobnet('regnetz_b16', pretrained=pretrained, **kwargs) @register_model -def regnetz_c(pretrained=False, **kwargs): +def regnetz_c16(pretrained=False, **kwargs): """ """ - return _create_byobnet('regnetz_c', pretrained=pretrained, **kwargs) + return _create_byobnet('regnetz_c16', pretrained=pretrained, **kwargs) @register_model -def regnetz_d(pretrained=False, **kwargs): +def regnetz_d32(pretrained=False, **kwargs): """ """ - return _create_byobnet('regnetz_d', pretrained=pretrained, **kwargs) + return _create_byobnet('regnetz_d32', pretrained=pretrained, **kwargs) @register_model @@ -813,6 +856,20 @@ def regnetz_e8(pretrained=False, **kwargs): return _create_byobnet('regnetz_e8', pretrained=pretrained, **kwargs) +@register_model +def regnetz_d8_evob(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('regnetz_d8_evob', pretrained=pretrained, **kwargs) + + +@register_model +def regnetz_d8_evos(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('regnetz_d8_evos', pretrained=pretrained, **kwargs) + + def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]: if not isinstance(stage_blocks_cfg, Sequence): stage_blocks_cfg = (stage_blocks_cfg,)