wip - pre-rebase

pull/800/head
Alexander Soare 3 years ago
parent e051dce354
commit b25ff96768

@ -4,10 +4,12 @@ import platform
import os
import fnmatch
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer
import timm
from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
get_model_default_value
from timm.models.fx_features import NodePathTracer
from timm.models.fx_features import _leaf_modules, _autowrap_functions
if hasattr(torch._C, '_jit_set_profiling_executor'):
# legacy executor is too slow to compile large models for unit tests
@ -312,12 +314,14 @@ def test_model_forward_fx(model_name, batch_size):
if max(input_size) > MAX_FWD_SIZE:
pytest.skip("Fixed input size model > limit.")
tracer = NodePathTracer()
graph = tracer.trace(model)
model = torch.fx.GraphModule(model, graph)
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
model = create_feature_extractor(
model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]],
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
outputs = model(inputs)[eval_nodes[-1]]
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'
@ -336,12 +340,30 @@ def test_model_backward_fx(model_name, batch_size):
model.train()
num_params = sum([x.numel() for x in model.parameters()])
tracer = NodePathTracer()
input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
if max(input_size) > MAX_FWD_SIZE:
pytest.skip("Fixed input size model > limit.")
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode.
# If so, we need to return all of them in order to check all grads.
# We trace once and inspect the graph to get the indices of the nodes that lead into the original fx output
# node, then use those indices to select from the train_nodes returned by torchvision's get_graph_node_names
# (see the standalone sketch after this test).
tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
graph = tracer.trace(model)
model = torch.fx.GraphModule(model, graph)
graph_nodes = list(reversed(graph.nodes))
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
model = create_feature_extractor(
model, train_return_nodes=train_return_nodes, eval_return_nodes=[eval_nodes[-1]],
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
outputs = tuple(model(inputs).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
outputs.mean().backward()
@ -354,9 +376,14 @@ def test_model_backward_fx(model_name, batch_size):
assert not torch.isnan(outputs).any(), 'Output included NaNs'
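The comment block in the test above describes an indexing trick; below is a minimal standalone sketch of it under assumed conditions. The `TwoHead` module and all other names in the sketch are hypothetical and not part of this commit.

import torch
import torch.fx
from torchvision.models.feature_extraction import get_graph_node_names

class TwoHead(torch.nn.Module):
    # hypothetical module that returns two outputs from forward
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(4, 4)
        self.b = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.a(x), self.b(x)

model = TwoHead()
graph = torch.fx.Tracer().trace(model)
graph_nodes = list(reversed(graph.nodes))
# names of the nodes that feed directly into the fx `output` node
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
# negative offsets from the end of the node list; counting from the end keeps them
# valid for the name list returned by torchvision, which omits the `output` node
output_node_indices = [-graph_node_names.index(name) for name in output_node_names]
train_nodes, eval_nodes = get_graph_node_names(model)
train_return_nodes = [train_nodes[ix] for ix in output_node_indices]  # e.g. ['a', 'b']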
EXCLUDE_FX_JIT_FILTERS = [
'beit_*' # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
]
@pytest.mark.timeout(120)
@pytest.mark.parametrize(
'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True))
'model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx_torchscript(model_name, batch_size):
"""Symbolically trace each model, script it, and run single forward pass"""
@ -368,12 +395,18 @@ def test_model_forward_fx_torchscript(model_name, batch_size):
model = create_model(model_name, pretrained=False)
model.eval()
tracer = NodePathTracer()
graph = tracer.trace(model)
model = torch.fx.GraphModule(model, graph)
input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
if max(input_size) > MAX_FWD_SIZE:
pytest.skip("Fixed input size model > limit.")
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
model = create_feature_extractor(
model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]],
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
model = torch.jit.script(model)
outputs = model(torch.randn((batch_size, *input_size)))
outputs = model(torch.randn((batch_size, *input_size)))[train_nodes[-1]]
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'
assert not torch.isnan(outputs).any(), 'Output included NaNs'

@ -95,11 +95,11 @@ class ClassAttn(nn.Module):
q = q * self.scale
v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
attn = torch.matmul(q, k.transpose(-2, -1))
attn = (q @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x_cls = torch.matmul(attn, v).transpose(1, 2).reshape(B, 1, C)
x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)
x_cls = self.proj(x_cls)
x_cls = self.proj_drop(x_cls)
@ -158,7 +158,7 @@ class TalkingHeadAttn(nn.Module):
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
attn = torch.matmul(q, k.transpose(-2, -1))
attn = (q @ k.transpose(-2, -1))
attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
@ -167,7 +167,7 @@ class TalkingHeadAttn(nn.Module):
attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
attn = self.attn_drop(attn)
x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x

@ -19,6 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from .registry import register_model
from .layers.trace_utils import _assert
__all__ = [
@ -105,7 +106,7 @@ class ConvRelPosEnc(nn.Module):
def forward(self, q, v, size: Tuple[int, int]):
B, h, N, Ch = q.shape
H, W = size
torch._assert(N == 1 + H * W, '')
_assert(N == 1 + H * W, '')
# Convolutional relative position encoding.
q_img = q[:, :, 1:, :] # [B, h, H*W, Ch]
@ -149,8 +150,8 @@ class FactorAtt_ConvRelPosEnc(nn.Module):
# Factorized attention.
k_softmax = k.softmax(dim=2)
factor_att = torch.matmul(k_softmax.transpose(-1, -2), v)
factor_att = torch.matmul(q, factor_att)
factor_att = k_softmax.transpose(-1, -2) @ v
factor_att = q @ factor_att
# Convolutional relative position encoding.
crpe = self.crpe(q, v, size=size) # [B, h, N, Ch]
@ -177,7 +178,7 @@ class ConvPosEnc(nn.Module):
def forward(self, x, size: Tuple[int, int]):
B, N, C = x.shape
H, W = size
torch._assert(N == 1 + H * W, '')
_assert(N == 1 + H * W, '')
# Extract CLS token and image tokens.
cls_token, img_tokens = x[:, :1], x[:, 1:] # [B, 1, C], [B, H*W, C]
@ -275,7 +276,7 @@ class ParallelBlock(nn.Module):
""" Feature map interpolation. """
B, N, C = x.shape
H, W = size
torch._assert(N == 1 + H * W, '')
_assert(N == 1 + H * W, '')
cls_token = x[:, :1, :]
img_tokens = x[:, 1:, :]

@ -57,7 +57,7 @@ default_cfgs = {
}
@register_leaf_module # FX can't symbolically trace control flow in forward method
@register_leaf_module # reason: FX can't symbolically trace control flow in forward method
class GPSA(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
locality_strength=1.):
@ -84,7 +84,7 @@ class GPSA(nn.Module):
self.rel_indices = self.get_rel_indices(N)
attn = self.get_attention(x)
v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@ -95,7 +95,7 @@ class GPSA(nn.Module):
q, k = qk[0], qk[1]
pos_score = self.rel_indices.expand(B, -1, -1, -1)
pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2)
patch_score = torch.matmul(q, k.transpose(-2, -1)) * self.scale
patch_score = (q @ k.transpose(-2, -1)) * self.scale
patch_score = patch_score.softmax(dim=-1)
pos_score = pos_score.softmax(dim=-1)
@ -180,11 +180,11 @@ class MHSA(nn.Module):
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x

@ -22,6 +22,7 @@ NOTE: model names have been renamed from originals to represent actual input res
Modified from timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from typing import Tuple
import torch
import torch.nn as nn
@ -31,8 +32,9 @@ from functools import partial
from typing import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_autowrap_function
from .helpers import build_model_with_cfg
from .layers import DropPath, to_2tuple, trunc_normal_
from .layers import DropPath, to_2tuple, trunc_normal_, _assert
from .registry import register_model
from .vision_transformer import Mlp, Block
@ -116,8 +118,10 @@ class PatchEmbed(nn.Module):
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
_assert(H == self.img_size[0],
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
_assert(W == self.img_size[1],
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
x = self.proj(x).flatten(2).transpose(1, 2)
return x
@ -255,6 +259,27 @@ def _compute_num_patches(img_size, patches):
return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]
@register_autowrap_function
def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False): # annotations for torchscript
"""
Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing.
Args:
x (Tensor): input image
ss (tuple[int, int]): height and width to scale to
crop_scale (bool): whether to crop instead of interpolate to achieve the desired scale. Defaults to False
Returns:
Tensor: the "scaled" image batch tensor
"""
H, W = x.shape[-2:]
if H != ss[0] or W != ss[1]:
if crop_scale and ss[0] <= H and ss[1] <= W:
cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]]
else:
x = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False)
return x
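A quick usage sketch of scale_image as defined above; the tensor shapes are illustrative and the sketch assumes the function is in scope.

import torch

img = torch.randn(2, 3, 240, 240)
# target fits inside the input and crop_scale is set, so a center crop is taken
cropped = scale_image(img, ss=(224, 224), crop_scale=True)
# target is larger than the input, so bicubic interpolation is used instead
resized = scale_image(img, ss=(256, 256), crop_scale=True)
assert cropped.shape[-2:] == (224, 224) and resized.shape[-2:] == (256, 256)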
class CrossViT(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
@ -342,17 +367,12 @@ class CrossViT(nn.Module):
range(self.num_branches)])
def forward_features(self, x):
B, C, H, W = x.shape
B = x.shape[0]
xs = []
for i, patch_embed in enumerate(self.patch_embed):
x_ = x
ss = self.img_size_scaled[i]
if H != ss[0] or W != ss[1]:
if self.crop_scale and ss[0] <= H and ss[1] <= W:
cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
x_ = x_[:, :, cu:cu + ss[0], cl:cl + ss[1]]
else:
x_ = torch.nn.functional.interpolate(x_, size=ss, mode='bicubic', align_corners=False)
x_ = scale_image(x_, ss, self.crop_scale)
x_ = patch_embed(x_)
cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
cls_tokens = cls_tokens.expand(B, -1, -1)

@ -1,42 +1,31 @@
""" PyTorch FX Based Feature Extraction Helpers
An extension/alternative to timm.models.features making use of PyTorch FX. Here, the idea is to:
1. Symbolically trace a model producing a graph based intermediate representation (PyTorch FX functionality with
some custom tweaks)
2. Identify desired feature extraction nodes and reconfigure them as output nodes while deleting all unnecessary
nodes. (custom - inspired by https://github.com/pytorch/vision/pull/3597)
3. Write the resulting graph into a GraphModule (PyTorch FX functionality)
Copyright 2021 Alexander Soare
Using https://pytorch.org/vision/stable/feature_extraction.html
"""
from typing import Callable, Dict, Union, List, Optional
import math
from collections import OrderedDict
from pprint import pprint
from inspect import ismethod
import re
import warnings
from copy import deepcopy
from itertools import chain
import torch
from typing import Callable
from torch import nn
from torch import fx
import torch.nn.functional as F
from torch.fx.graph_module import _copy_attr
from .features import _get_feature_info
from .fx_helpers import fx_float_to_int
# Layers we want to treat as leaf modules for FeatureGraphNet
from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame
from .layers import GatherExcite, DropPath
try:
from torchvision.models.feature_extraction import create_feature_extractor
except ImportError:
pass
# Layers we want to treat as leaf modules
from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
from .layers.non_local_attn import BilinearAttnTransform
from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
# These modules will not be traced through.
# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
# BUT modules from timm.models should use the registration mechanism below
_leaf_modules = {
Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, GatherExcite, DropPath,
BilinearAttnTransform, MaxPool2dSame, AvgPool2dSame
BatchNormAct2d, # reason: flow control for jit scripting
BilinearAttnTransform, # reason: flow control t <= 1
BlurPool2d, # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1]
# Reason: get_same_padding has a max which raises a control flow error
Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
DropPath, # reason: TypeError: rand received Proxy in `size` argument
}
try:
@ -54,425 +43,16 @@ def register_leaf_module(module: nn.Module):
return module
# These functions will not be traced through
_autowrap_functions=(fx_float_to_int,)
class TimmTracer(fx.Tracer):
"""
Temporary bridge from torch.fx.Tracer to include any general workarounds required to make FX work for us
"""
def __init__(self, autowrap_modules=(math, ), autowrap_functions=(), enable_cpatching=False):
super().__init__(autowrap_modules=autowrap_modules, enable_cpatching=enable_cpatching)
# FIXME: This is a workaround pending on a PyTorch PR https://github.com/pytorch/pytorch/pull/62106
self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions]))
def create_node(self, kind, target, args, kwargs, name=None, type_expr=None):
# FIXME: This is a workaround pending on a PyTorch PR https://github.com/pytorch/pytorch/pull/62095
if target == F.pad:
kwargs['value'] = float(kwargs['value'])
return super().create_node(kind, target, args, kwargs, name=name, type_expr=type_expr)
class LeafNodeTracer(TimmTracer):
"""
Account for desired leaf nodes according to _leaf_modules and _autowrap_functions
"""
def __init__(self):
super().__init__(autowrap_functions=_autowrap_functions)
def is_leaf_module(self, m: nn.Module, module_qualname: str) -> bool:
if isinstance(m, tuple(_leaf_modules)):
return True
return super().is_leaf_module(m, module_qualname)
def _is_subseq(x, y):
"""Check if y is a subseqence of x
https://stackoverflow.com/a/24017747/4391249
"""
iter_x = iter(x)
return all(any(x_item == y_item for x_item in iter_x) for y_item in y)
# Taken from https://github.com/pytorch/examples/blob/master/fx/module_tracer.py with modifications for storing
# qualified names for all Nodes, not just top-level Modules
class NodePathTracer(LeafNodeTracer):
"""
NodePathTracer is an FX tracer that, for each operation, also records the
qualified name of the Node from which the operation originated. A
qualified name here is a `.` separated path walking the hierarchy from top
level module down to leaf operation or leaf module. The name of the top
level module is not included as part of the qualified name. For example,
if we trace a module whose forward method applies a ReLU module, the
qualified name for that node will simply be 'relu'.
Some notes on the specifics:
- Nodes are recorded to `self.node_to_qualname` which is a dictionary
mapping a given Node object to its qualified name.
- Nodes are recorded in the order which they are executed during
tracing.
- When a duplicate qualified name is encountered, a suffix of the form
_{int} is added. The counter starts from 1.
"""
def __init__(self, *args, **kwargs):
super(NodePathTracer, self).__init__(*args, **kwargs)
# Track the qualified name of the Node being traced
self.current_module_qualname = ''
# A map from FX Node to the qualified name
self.node_to_qualname = OrderedDict()
def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs):
"""
Override of `fx.Tracer.call_module`
This override:
1) Stores away the qualified name of the caller for restoration later
2) Adds the qualified name of the caller to
`current_module_qualname` for retrieval by `create_proxy`
3) Once a leaf module is reached, calls `create_proxy`
4) Restores the caller's qualified name into current_module_qualname
"""
old_qualname = self.current_module_qualname
try:
module_qualname = self.path_of_module(m)
self.current_module_qualname = module_qualname
if not self.is_leaf_module(m, module_qualname):
out = forward(*args, **kwargs)
return out
return self.create_proxy('call_module', module_qualname, args, kwargs)
finally:
self.current_module_qualname = old_qualname
def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs,
name=None, type_expr=None) -> fx.proxy.Proxy:
"""
Override of `Tracer.create_proxy`. This override intercepts the recording
of every operation and stores away the current traced module's qualified
name in `node_to_qualname`
"""
proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr)
self.node_to_qualname[proxy.node] = self._get_node_qualname(
self.current_module_qualname, proxy.node)
return proxy
def _get_node_qualname(
self, module_qualname: str, node: fx.node.Node) -> str:
node_qualname = module_qualname
if node.op == 'call_module':
# Node terminates in a leaf module so the module_qualname is a
# complete description of the node
for existing_qualname in reversed(self.node_to_qualname.values()):
# Check to see if existing_qualname is of the form
# {node_qualname} or {node_qualname}_{int}
if re.match(rf'{node_qualname}(_[0-9]+)?$',
existing_qualname) is not None:
postfix = existing_qualname.replace(node_qualname, '')
if len(postfix):
# Existing_qualname is of the form {node_qualname}_{int}
next_index = int(postfix[1:]) + 1
else:
# existing_qualname is of the form {node_qualname}
next_index = 1
node_qualname += f'_{next_index}'
break
else:
# Node terminates in non-leaf module so the node name needs to be
# appended
if len(node_qualname) > 0:
# Only append '.' if we are deeper than the top level module
node_qualname += '.'
node_qualname += str(node)
return node_qualname
# Functions we want to autowrap (treat them as leaves)
_autowrap_functions = set()
def _warn_graph_differences(
train_tracer: NodePathTracer, eval_tracer: NodePathTracer):
def register_autowrap_function(func: Callable):
"""
Utility function for warning the user if there are differences between
the train graph and the eval graph.
Decorator for functions which ought not to be traced through
"""
train_nodes = list(train_tracer.node_to_qualname.values())
eval_nodes = list(eval_tracer.node_to_qualname.values())
if len(train_nodes) == len(eval_nodes) and [
t == e for t, e in zip(train_nodes, eval_nodes)]:
return
suggestion_msg = (
"When choosing nodes for feature extraction, you may need to specify "
"output nodes for train and eval mode separately")
if _is_subseq(train_nodes, eval_nodes):
msg = ("NOTE: The nodes obtained by tracing the model in eval mode "
"are a subsequence of those obtained in train mode. ")
elif _is_subseq(eval_nodes, train_nodes):
msg = ("NOTE: The nodes obtained by tracing the model in train mode "
"are a subsequence of those obtained in eval mode. ")
else:
msg = ("The nodes obtained by tracing the model in train mode "
"are different to those obtained in eval mode. ")
warnings.warn(msg + suggestion_msg)
def print_graph_node_qualified_names(
model: nn.Module, tracer_kwargs: Dict = {}):
"""
Dev utility to print nodes in order of execution. Useful for choosing
nodes for a FeatureGraphNet design. There are two reasons that qualified
node names can't easily be read directly from the code for a model:
1. Not all submodules are traced through. Modules from `torch.nn` all
fall within this category.
2. Node qualified names that occur more than once in the graph get a
`_{counter}` postfix.
The model will be traced twice: once in train mode, and once in eval mode.
If there are discrepancies between the graphs produced, both sets will
be printed and the user will be warned.
Args:
model (nn.Module): model on which we will extract the features
tracer_kwargs (Dict): a dictionary of keyword arguments for
`NodePathTracer` (which passes them onto its parent class
`torch.fx.Tracer`).
"""
train_tracer = NodePathTracer(**tracer_kwargs)
train_tracer.trace(model.train())
eval_tracer = NodePathTracer(**tracer_kwargs)
eval_tracer.trace(model.eval())
train_nodes = list(train_tracer.node_to_qualname.values())
eval_nodes = list(eval_tracer.node_to_qualname.values())
if len(train_nodes) == len(eval_nodes) and [
t == e for t, e in zip(train_nodes, eval_nodes)]:
# Nodes are aligned in train vs eval mode
pprint(list(train_tracer.node_to_qualname.values()))
return
print("Nodes from train mode:")
pprint(list(train_tracer.node_to_qualname.values()))
print()
print("Nodes from eval mode:")
pprint(list(eval_tracer.node_to_qualname.values()))
print()
_warn_graph_differences(train_tracer, eval_tracer)
class DualGraphModule(fx.GraphModule):
"""
A derivative of `fx.GraphModule`. Differs in the following ways:
- Requires a train and eval version of the underlying graph
- Copies submodules according to the nodes of both train and eval graphs.
- Calling train(mode) switches between train graph and eval graph.
"""
def __init__(self,
root: torch.nn.Module,
train_graph: fx.Graph,
eval_graph: fx.Graph,
class_name: str = 'GraphModule'):
"""
Args:
root (torch.nn.Module): module from which the copied module
hierarchy is built
train_graph (Graph): the graph that should be used in train mode
eval_graph (Graph): the graph that should be used in eval mode
"""
super(fx.GraphModule, self).__init__()
self.__class__.__name__ = class_name
self.train_graph = train_graph
self.eval_graph = eval_graph
# Copy all get_attr and call_module ops (indicated by BOTH train and
# eval graphs)
for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)):
if node.op in ['get_attr', 'call_module']:
assert isinstance(node.target, str)
_copy_attr(root, self, node.target)
# eval mode by default
self.eval()
self.graph = eval_graph
# (borrowed from fx.GraphModule):
# Store the Tracer class responsible for creating a Graph separately as part of the
# GraphModule state, except when the Tracer is defined in a local namespace.
# Locally defined Tracers are not pickleable. This is needed because torch.package will
# serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
# to re-create the Graph during deserialization.
# TODO uncomment this when https://github.com/pytorch/pytorch/pull/63121 is available
# assert self.eval_graph._tracer_cls == self.train_graph._tracer_cls, \
# "Train mode and eval mode should use the same tracer class"
# self._tracer_cls = None
# if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__:
# self._tracer_cls = self.graph._tracer_cls
def train(self, mode=True):
"""
Swap out the graph depending on the training mode.
NOTE this should be safe when calling model.eval() because that just
calls this with mode == False.
"""
if mode:
self.graph = self.train_graph
else:
self.graph = self.eval_graph
return super().train(mode=mode)
def build_feature_graph_net(
model: nn.Module,
return_nodes: Union[List[str], Dict[str, str]],
train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
tracer_kwargs: Dict = {}) -> fx.GraphModule:
"""
Creates a new graph module that returns intermediate nodes from a given
model as dictionary with user specified keys as strings, and the requested
outputs as values. This is achieved by re-writing the computation graph of
the model via FX to return the desired nodes as outputs. All unused nodes
are removed, together with their corresponding parameters.
A note on node specification: A node qualified name is specified as a `.`
separated path walking the hierarchy from top level module down to leaf
operation or leaf module. For instance `blocks.5.3.bn1`. The keys of the
`return_nodes` argument should point to either a node's qualified name,
or some truncated version of it. For example, one could provide `blocks.5`
as a key, and the last node with that prefix will be selected.
`print_graph_node_qualified_names` is a useful helper function for getting
a list of qualified names of a model.
An attempt is made to keep all non-parametric properties of the original
model, but existing properties of the constructed `GraphModule` are not
overwritten.
Args:
model (nn.Module): model on which we will extract the features
return_nodes (Union[List[name], Dict[name, new_name]])): either a list
or a dict containing the names (or partial names - see note above)
of the nodes for which the activations will be returned. If it is
a `Dict`, the keys are the qualified node names, and the values
are the user-specified keys for the graph module's returned
dictionary. If it is a `List`, it is treated as a `Dict` mapping
node specification strings directly to output names.
tracer_kwargs (Dict): a dictionary of keyword arguments for
`NodePathTracer` (which passes them onto its parent class
`torch.fx.Tracer`).
Examples::
>>> model = torchvision.models.resnet18()
>>> # extract layer1 and layer3, giving as names `feat1` and feat2`
>>> graph_module = torchvision.models._utils.build_feature_graph_net(m,
>>> {'layer1': 'feat1', 'layer3': 'feat2'})
>>> out = graph_module(torch.rand(1, 3, 224, 224))
>>> print([(k, v.shape) for k, v in out.items()])
>>> [('feat1', torch.Size([1, 64, 56, 56])),
>>> ('feat2', torch.Size([1, 256, 14, 14]))]
"""
is_training = model.training
if isinstance(return_nodes, list):
return_nodes = {n: n for n in return_nodes}
return_nodes = {str(k): str(v) for k, v in return_nodes.items()}
assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), \
("If any of `train_return_nodes` and `eval_return_nodes` are "
"specified, then both should be specified")
if train_return_nodes is None:
train_return_nodes = deepcopy(return_nodes)
eval_return_nodes = deepcopy(return_nodes)
# Repeat the tracing and graph rewriting for train and eval mode
tracers = {}
graphs = {}
return_nodes = {
'train': train_return_nodes,
'eval': eval_return_nodes
}
for mode in ['train', 'eval']:
if mode == 'train':
model.train()
elif mode == 'eval':
model.eval()
# Instantiate our NodePathTracer and use that to trace the model
tracer = NodePathTracer(**tracer_kwargs)
graph = tracer.trace(model)
name = model.__class__.__name__ if isinstance(
model, nn.Module) else model.__name__
graph_module = fx.GraphModule(tracer.root, graph, name)
available_nodes = [f'{v}.{k}' for k, v in tracer.node_to_qualname.items()]
# FIXME We don't know if we should expect this to happen
assert len(set(available_nodes)) == len(available_nodes), \
"There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues"
# Check that all outputs in return_nodes are present in the model
for query in return_nodes[mode].keys():
if not any([m.startswith(query) for m in available_nodes]):
raise ValueError(f"return_node: {query} is not present in model")
# Remove existing output nodes (train mode)
orig_output_nodes = []
for n in reversed(graph_module.graph.nodes):
if n.op == "output":
orig_output_nodes.append(n)
assert len(orig_output_nodes)
for n in orig_output_nodes:
graph_module.graph.erase_node(n)
# Find nodes corresponding to return_nodes and make them into output_nodes
nodes = [n for n in graph_module.graph.nodes]
output_nodes = OrderedDict()
for n in reversed(nodes):
if 'tensor_constant' in str(n):
# NOTE Without this control flow we would get a None value for
# `module_qualname = tracer.node_to_qualname.get(n)`.
# On the other hand, we can safely assume that we'll never need to
# get this as an interesting intermediate node.
continue
module_qualname = tracer.node_to_qualname.get(n)
for query in return_nodes[mode]:
depth = query.count('.')
if '.'.join(module_qualname.split('.')[:depth + 1]) == query:
output_nodes[return_nodes[mode][query]] = n
return_nodes[mode].pop(query)
break
output_nodes = OrderedDict(reversed(list(output_nodes.items())))
# And add them in the end of the graph
with graph_module.graph.inserting_after(nodes[-1]):
graph_module.graph.output(output_nodes)
# Remove unused modules / parameters
graph_module.graph.eliminate_dead_code()
graph_module.recompile()
# Keep track of the tracer and graph so we can choose the main one
tracers[mode] = tracer
graphs[mode] = graph
# Warn user if there are any discrepancies between the graphs of the
# train and eval modes
_warn_graph_differences(tracers['train'], tracers['eval'])
# Build the final graph module
graph_module = DualGraphModule(
model, graphs['train'], graphs['eval'], class_name=name)
# Keep non-parameter model properties for reference
for attr_str in model.__dir__():
attr = getattr(model, attr_str)
if (not attr_str.startswith('_')
and attr_str not in graph_module.__dir__()
and not ismethod(attr)
and not isinstance(attr, (nn.Module, nn.Parameter))):
setattr(graph_module, attr_str, attr)
# Restore original training mode
graph_module.train(is_training)
return graph_module
_autowrap_functions.add(func)
return func
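As a hedged illustration of the registration mechanism above, here is a sketch using a hypothetical module and function that are not part of timm.

import torch
import torch.nn as nn
from timm.models.fx_features import register_leaf_module, register_autowrap_function

@register_leaf_module  # reason: data-dependent control flow can't be symbolically traced
class GatedBlock(nn.Module):
    def forward(self, x):
        return x * 2 if x.sum() > 0 else x

@register_autowrap_function  # reason: inbuilt int() can't accept an fx Proxy
def float_to_int(x: float) -> int:
    return int(x)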
class FeatureGraphNet(nn.Module):
@ -483,7 +63,10 @@ class FeatureGraphNet(nn.Module):
assert len(out_map) == len(out_indices)
return_nodes = {info['module']: out_map[i] if out_map is not None else info['module']
for i, info in enumerate(self.feature_info) if i in out_indices}
self.graph_module = build_feature_graph_net(model, return_nodes)
self.graph_module = create_feature_extractor(
model, return_nodes,
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
def forward(self, x):
return list(self.graph_module(x).values())

@ -1,7 +0,0 @@
def fx_float_to_int(x: float) -> int:
"""
Symbolic tracing helper to substitute for inbuilt `int`.
Hint: Inbuilt `int` can't accept an argument of type `Proxy`
"""
return int(x)

@ -22,7 +22,7 @@ import torch.nn.functional as F
from .helpers import to_2tuple, make_divisible
from .weight_init import trunc_normal_
from timm.models.fx_helpers import fx_and
from .trace_utils import _assert
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@ -37,7 +37,7 @@ def rel_logits_1d(q, rel_k, permute_mask: List[int]):
permute_mask: permute output dim according to this
"""
B, H, W, dim = q.shape
x = torch.matmul(q, rel_k.transpose(-1, -2))
x = (q @ rel_k.transpose(-1, -2))
x = x.reshape(-1, W, 2 * W -1)
# pad to shift from relative to absolute indexing
@ -134,8 +134,8 @@ class BottleneckAttn(nn.Module):
def forward(self, x):
B, C, H, W = x.shape
torch._assert(H == self.pos_embed.height, '')
torch._assert(W == self.pos_embed.width, '')
_assert(H == self.pos_embed.height, '')
_assert(W == self.pos_embed.width, '')
x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W

@ -12,6 +12,8 @@ Hacked together by / Copyright 2020 Ross Wightman
import torch
import torch.nn as nn
from .trace_utils import _assert
class EvoNormBatch2d(nn.Module):
def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None):
@ -72,9 +74,9 @@ class EvoNormSample2d(nn.Module):
nn.init.ones_(self.v)
def forward(self, x):
torch._assert(x.dim() == 4, 'expected 4D input')
_assert(x.dim() == 4, 'expected 4D input')
B, C, H, W = x.shape
torch._assert(C % self.groups == 0, '')
_assert(C % self.groups == 0, '')
if self.apply_act:
n = x * (x * self.v).sigmoid()
x = x.reshape(B, self.groups, -1)

@ -7,7 +7,6 @@ Official code consulted as reference: https://github.com/xvjiarui/GCNet
Hacked together by / Copyright 2021 Ross Wightman
"""
import torch
from torch import nn as nn
import torch.nn.functional as F
@ -53,7 +52,7 @@ class GlobalContext(nn.Module):
if self.conv_attn is not None:
attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W)
attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1)
context = torch.matmul(x.reshape(B, C, H * W).unsqueeze(1), attn)
context = x.reshape(B, C, H * W).unsqueeze(1) @ attn
context = context.view(B, C, 1, 1)
else:
context = x.mean(dim=(2, 3), keepdim=True)

@ -16,7 +16,7 @@ The attention mechanism works but it's slow as implemented.
Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import Tuple, List
from typing import List
import torch
from torch import nn
@ -24,6 +24,7 @@ import torch.nn.functional as F
from .helpers import make_divisible
from .weight_init import trunc_normal_
from .trace_utils import _assert
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@ -41,7 +42,7 @@ def rel_logits_1d(q, rel_k, permute_mask: List[int]):
rel_size = rel_k.shape[0]
win_size = (rel_size + 1) // 2
x = torch.matmul(q, rel_k.transpose(-1, -2))
x = (q @ rel_k.transpose(-1, -2))
x = x.reshape(-1, W, rel_size)
# pad to shift from relative to absolute indexing
@ -167,8 +168,8 @@ class HaloAttn(nn.Module):
def forward(self, x):
B, C, H, W = x.shape
torch._assert(H % self.block_size == 0, '')
torch._assert(W % self.block_size == 0, '')
_assert(H % self.block_size == 0, '')
_assert(W % self.block_size == 0, '')
num_h_blocks = H // self.block_size
num_w_blocks = W // self.block_size
num_blocks = num_h_blocks * num_w_blocks

@ -116,8 +116,8 @@ class LambdaLayer(nn.Module):
v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V
k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M
content_lam = torch.matmul(k, v) # B, K, V
content_out = torch.matmul(q, content_lam.unsqueeze(1)) # B, num_heads, M, V
content_lam = k @ v # B, K, V
content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V
if self.pos_emb is None:
position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K

@ -10,7 +10,7 @@ from torch.nn import functional as F
from .conv_bn_act import ConvBnAct
from .helpers import make_divisible
from timm.models.fx_helpers import fx_and
from .trace_utils import _assert
class NonLocalAttn(nn.Module):
@ -84,7 +84,7 @@ class BilinearAttnTransform(nn.Module):
def resize_mat(self, x, t: int):
B, C, block_size, block_size1 = x.shape
torch._assert(block_size == block_size1, '')
_assert(block_size == block_size1, '')
if t <= 1:
return x
x = x.view(B * C, -1, 1, 1)
@ -96,8 +96,8 @@ class BilinearAttnTransform(nn.Module):
return x
def forward(self, x):
torch._assert(x.shape[-1] % self.block_size == 0, '')
torch._assert(x.shape[-2] % self.block_size == 0, '')
_assert(x.shape[-1] % self.block_size == 0, '')
_assert(x.shape[-2] % self.block_size == 0, '')
B, C, H, W = x.shape
out = self.conv1(x)
rp = F.adaptive_max_pool2d(out, (self.block_size, 1))

@ -9,6 +9,7 @@ from torch import nn as nn
from .conv_bn_act import ConvBnAct
from .helpers import make_divisible
from .trace_utils import _assert
def _kernel_valid(k):
@ -34,7 +35,7 @@ class SelectiveKernelAttn(nn.Module):
self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
def forward(self, x):
torch._assert(x.shape[1] == self.num_paths, '')
_assert(x.shape[1] == self.num_paths, '')
x = x.sum(1).mean((2, 3), keepdim=True)
x = self.fc_reduce(x)
x = self.bn(x)

@ -1,183 +0,0 @@
""" Shifted Window Attn
This is a WIP experiment to apply windowed attention from the Swin Transformer
to a stand-alone module for use as an attn block in conv nets.
Based on original swin window code at https://github.com/microsoft/Swin-Transformer
Swin Transformer paper: https://arxiv.org/pdf/2103.14030.pdf
"""
from typing import Optional
import torch
import torch.nn as nn
from .drop import DropPath
from .helpers import to_2tuple
from .weight_init import trunc_normal_
from timm.models.fx_helpers import fx_float_to_int
def window_partition(x, win_size: int):
"""
Args:
x: (B, H, W, C)
win_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // win_size, win_size, W // win_size, win_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C)
return windows
def window_reverse(windows, win_size: int, H: int, W: int):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
win_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = fx_float_to_int(windows.shape[0] / (H * W / win_size / win_size))
x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
win_size (int): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
"""
def __init__(
self, dim, dim_out=None, feat_size=None, stride=1, win_size=8, shift_size=None, num_heads=8,
qkv_bias=True, attn_drop=0.):
super().__init__()
self.dim_out = dim_out or dim
self.feat_size = to_2tuple(feat_size)
self.win_size = win_size
self.shift_size = shift_size or win_size // 2
if min(self.feat_size) <= win_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.win_size = min(self.feat_size)
assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-window_size"
self.num_heads = num_heads
head_dim = self.dim_out // num_heads
self.scale = head_dim ** -0.5
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.feat_size
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (
slice(0, -self.win_size),
slice(-self.win_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (
slice(0, -self.win_size),
slice(-self.win_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.win_size) # num_win, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.win_size * self.win_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
# 2 * Wh - 1 * 2 * Ww - 1, nH
torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads))
trunc_normal_(self.relative_position_bias_table, std=.02)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.win_size)
coords_w = torch.arange(self.win_size)
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.win_size - 1 # shift to start from 0
relative_coords[:, :, 1] += self.win_size - 1
relative_coords[:, :, 0] *= 2 * self.win_size - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.softmax = nn.Softmax(dim=-1)
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
def reset_parameters(self):
trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
trunc_normal_(self.relative_position_bias_table, std=.02)
def forward(self, x):
B, C, H, W = x.shape
x = x.permute(0, 2, 3, 1)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
win_size_sq = self.win_size * self.win_size
x_windows = window_partition(shifted_x, self.win_size) # num_win * B, window_size, window_size, C
x_windows = x_windows.view(-1, win_size_sq, C) # num_win * B, window_size*window_size, C
BW, N, _ = x_windows.shape
qkv = self.qkv(x_windows)
qkv = qkv.reshape(BW, N, 3, self.num_heads, self.dim_out // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = torch.matmul(q, k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(win_size_sq, win_size_sq, -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh * Ww, Wh * Ww
attn = attn + relative_position_bias.unsqueeze(0)
if self.attn_mask is not None:
num_win = self.attn_mask.shape[0]
attn = attn.view(B, num_win, self.num_heads, N, N) + self.attn_mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = torch.matmul(attn, v).transpose(1, 2).reshape(BW, N, self.dim_out)
# merge windows
x = x.view(-1, self.win_size, self.win_size, self.dim_out)
shifted_x = window_reverse(x, self.win_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H, W, self.dim_out).permute(0, 3, 1, 2)
x = self.pool(x)
return x

@ -293,10 +293,10 @@ class Attention(nn.Module):
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + self.get_attention_biases(x.device)
attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, self.dh)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
x = self.proj(x)
return x
@ -387,10 +387,10 @@ class AttentionSubsample(nn.Module):
v = v.permute(0, 2, 1, 3) # BHNC
q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + self.get_attention_biases(x.device)
attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
x = torch.matmul(attn, v).transpose(1, 2).reshape(B, -1, self.dh)
x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
x = self.proj(x)
return x

@ -25,13 +25,13 @@ import torch.nn.functional as F
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_autowrap_function
from .helpers import build_model_with_cfg, named_apply
from .fx_helpers import fx_float_to_int
from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_
from .layers.trace_utils import _assert
from .layers import create_conv2d, create_pool2d, to_ntuple
from .registry import register_model
_logger = logging.getLogger(__name__)
@ -85,12 +85,12 @@ class Attention(nn.Module):
qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale # (B, H, T, N, N)
attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, N, N)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
# (B, H, T, N, C'), permute -> (B, T, N, C', H)
x = torch.matmul(attn, v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C)
x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x # (B, T, N, C)
@ -130,8 +130,8 @@ class ConvPool(nn.Module):
"""
x is expected to have shape (B, C, H, W)
"""
torch._assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims')
torch._assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims')
_assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims')
_assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims')
x = self.conv(x)
# Layer norm done over channel dim only
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
@ -146,8 +146,8 @@ def blockify(x, block_size: int):
block_size (int): edge length of a single square block in units of H, W
"""
B, H, W, C = x.shape
torch._assert(H % block_size == 0, '`block_size` must divide input height evenly')
torch._assert(W % block_size == 0, '`block_size` must divide input width evenly')
_assert(H % block_size == 0, '`block_size` must divide input height evenly')
_assert(W % block_size == 0, '`block_size` must divide input width evenly')
grid_height = H // block_size
grid_width = W // block_size
x = x.reshape(B, grid_height, block_size, grid_width, block_size, C)
@ -155,6 +155,7 @@ def blockify(x, block_size: int):
return x # (B, T, N, C)
@register_autowrap_function # reason: int receives Proxy
def deblockify(x, block_size: int):
"""blocks to image
Args:
@ -162,7 +163,7 @@ def deblockify(x, block_size: int):
block_size (int): edge length of a single square block in units of desired H, W
"""
B, T, _, C = x.shape
grid_size = fx_float_to_int(math.sqrt(T))
grid_size = int(math.sqrt(T))
height = width = grid_size * block_size
x = x.reshape(B, grid_size, grid_size, block_size, block_size, C)
x = x.transpose(2, 3).reshape(B, height, width, C)

@ -27,7 +27,6 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from timm.models.fx_features import register_leaf_module
from .registry import register_model
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
get_act_layer, get_act_fn, get_attn, make_divisible
@ -319,7 +318,6 @@ class DownsampleAvg(nn.Module):
return self.conv(self.pool(x))
@register_leaf_module # FX feature extraction was giving different valued features. Perhaps to do with control flow?
class NormFreeBlock(nn.Module):
"""Normalization-Free pre-activation block.
"""

@ -21,9 +21,10 @@ import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_autowrap_function
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .fx_helpers import fx_float_to_int
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from .layers.trace_utils import _assert
from .registry import register_model
from .vision_transformer import checkpoint_filter_fn, _init_vit_weights
@ -102,6 +103,7 @@ def window_partition(x, window_size: int):
return windows
@register_autowrap_function # reason: int argument is a Proxy
def window_reverse(windows, window_size: int, H: int, W: int):
"""
Args:
@ -113,7 +115,7 @@ def window_reverse(windows, window_size: int, H: int, W: int):
Returns:
x: (B, H, W, C)
"""
B = fx_float_to_int(windows.shape[0] / (H * W / window_size / window_size))
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
@ -177,7 +179,7 @@ class WindowAttention(nn.Module):
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = torch.matmul(q, k.transpose(-2, -1))
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
@ -194,7 +196,7 @@ class WindowAttention(nn.Module):
attn = self.attn_drop(attn)
x = torch.matmul(attn, v).transpose(1, 2).reshape(B_, N, C)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@ -272,7 +274,7 @@ class SwinTransformerBlock(nn.Module):
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
torch._assert(L == H * W, "input feature has wrong size")
_assert(L == H * W, "input feature has wrong size")
shortcut = x
x = self.norm1(x)
@ -331,8 +333,8 @@ class PatchMerging(nn.Module):
"""
H, W = self.input_resolution
B, L, C = x.shape
torch._assert(L == H * W, "input feature has wrong size")
torch._assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.")
_assert(L == H * W, "input feature has wrong size")
_assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.")
x = x.view(B, H, W, C)

@ -12,9 +12,9 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import build_model_with_cfg
from timm.models.fx_helpers import fx_and
from timm.models.layers import Mlp, DropPath, trunc_normal_
from timm.models.layers.helpers import to_2tuple
from timm.models.layers.trace_utils import _assert
from timm.models.registry import register_model
from timm.models.vision_transformer import resize_pos_embed
@ -64,11 +64,11 @@ class Attention(nn.Module):
q, k = qk.unbind(0) # make torchscript happy (cannot use tensor as tuple)
v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, -1)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
@ -138,9 +138,9 @@ class PixelEmbed(nn.Module):
def forward(self, x, pixel_pos):
B, C, H, W = x.shape
torch._assert(H == self.img_size[0],
_assert(H == self.img_size[0],
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
torch._assert(W == self.img_size[1],
_assert(W == self.img_size[1],
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
x = self.proj(x)
x = self.unfold(x)

@ -25,7 +25,7 @@ from .layers import Mlp, DropPath, to_2tuple, trunc_normal_
from .fx_features import register_leaf_module
from .registry import register_model
from .vision_transformer import Attention
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .helpers import build_model_with_cfg
def _cfg(url='', **kwargs):
@ -63,7 +63,7 @@ default_cfgs = {
Size_ = Tuple[int, int]
@register_leaf_module # FX can't symbolically trace control flow in forward method
@register_leaf_module # reason: FX can't symbolically trace control flow in forward method
class LocallyGroupedAttn(nn.Module):
""" LSA: self attention within a group
"""
@ -100,10 +100,10 @@ class LocallyGroupedAttn(nn.Module):
qkv = self.qkv(x).reshape(
B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
attn = torch.matmul(attn, v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C)
attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C)
x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C)
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
@ -185,11 +185,11 @@ class GlobalSubSampleAttn(nn.Module):
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)

@ -13,7 +13,7 @@ from typing import Union, List, Dict, Any, cast
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .fx_features import register_leaf_module
from .layers import ClassifierHead, ConvBnAct
from .layers import ClassifierHead
from .registry import register_model
__all__ = [
@ -53,7 +53,7 @@ cfgs: Dict[str, List[Union[str, int]]] = {
}
@register_leaf_module # FX can't symbolically trace control flow in forward method
@register_leaf_module # reason: FX can't symbolically trace control flow in forward method
class ConvMlp(nn.Module):
def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0,

@ -100,10 +100,10 @@ class Attention(nn.Module):
x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3)
q, k, v = x[0], x[1], x[2]
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = torch.matmul(attn, v)
x = attn @ v
x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W)
x = self.proj(x)

@ -192,11 +192,11 @@ class Attention(nn.Module):
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x

@ -98,7 +98,7 @@ default_cfgs = {
}
@register_leaf_module # FX can't symbolically trace torch.arange in forward method
@register_leaf_module # reason: FX can't symbolically trace torch.arange in forward method
class PositionalEncodingFourier(nn.Module):
"""
Positional encoding relying on a Fourier kernel matching the one used in the "Attention Is All You Need" paper.
@ -274,12 +274,12 @@ class XCA(nn.Module):
# Paper section 3.2 l2-Normalization and temperature scaling
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
attn = torch.matmul(q, k.transpose(-2, -1)) * self.temperature
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
# (B, H, C', N), permute -> (B, N, H, C')
x = torch.matmul(attn, v).permute(0, 3, 1, 2).reshape(B, N, C)
x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
