Merge branch 'fx-feature-extract-new' of https://github.com/alexander-soare/pytorch-image-models into alexander-soare-fx-feature-extract-new

pull/989/head
Ross Wightman 3 years ago
commit 32c9937dec

@ -4,9 +4,16 @@ import platform
import os
import fnmatch
try:
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer
has_fx_feature_extraction = True
except ImportError:
has_fx_feature_extraction = False
import timm
from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
get_model_default_value
from timm.models.fx_features import _leaf_modules, _autowrap_functions
if hasattr(torch._C, '_jit_set_profiling_executor'):
# legacy executor is too slow to compile large models for unit tests
@ -297,3 +304,145 @@ def test_model_forward_features(model_name, batch_size):
assert e == o.shape[1]
assert o.shape[0] == batch_size
assert not torch.isnan(o).any()
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx(model_name, batch_size):
"""
Symbolically trace each model and run single forward pass through the resulting GraphModule
Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
model = create_model(model_name, pretrained=False)
model.eval()
input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
if max(input_size) > MAX_FWD_SIZE:
pytest.skip("Fixed input size model > limit.")
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
graph = tracer.trace(model)
graph_nodes = list(reversed(graph.nodes))
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
eval_return_nodes = [eval_nodes[ix] for ix in output_node_indices]
fx_model = create_feature_extractor(
model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=eval_return_nodes,
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
fx_outputs = tuple(fx_model(inputs).values())
if isinstance(fx_outputs, tuple):
fx_outputs = torch.cat(fx_outputs)
assert torch.all(fx_outputs == outputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size):
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
if max(input_size) > MAX_BWD_SIZE:
pytest.skip("Fixed input size model > limit.")
model = create_model(model_name, pretrained=False, num_classes=42)
model.train()
num_params = sum([x.numel() for x in model.parameters()])
input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
if max(input_size) > MAX_FWD_SIZE:
pytest.skip("Fixed input size model > limit.")
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
graph = tracer.trace(model)
graph_nodes = list(reversed(graph.nodes))
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
model = create_feature_extractor(
model, train_return_nodes=train_return_nodes, eval_return_nodes=[eval_nodes[-1]],
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
inputs = torch.randn((batch_size, *input_size))
outputs = tuple(model(inputs).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
outputs.mean().backward()
for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}'
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
assert outputs.shape[-1] == 42
assert num_params == num_grad, 'Some parameters are missing gradients'
assert not torch.isnan(outputs).any(), 'Output included NaNs'
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
EXCLUDE_FX_JIT_FILTERS = [
'beit_*',
'deit_*_distilled_patch16_224',
'levit*',
'pit_*_distilled_224',
]
@pytest.mark.timeout(120)
@pytest.mark.parametrize(
'model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx_torchscript(model_name, batch_size):
"""Symbolically trace each model, script it, and run single forward pass"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
if max(input_size) > MAX_JIT_SIZE:
pytest.skip("Fixed input size model > limit.")
with set_scriptable(True):
model = create_model(model_name, pretrained=False)
model.eval()
input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
if max(input_size) > MAX_FWD_SIZE:
pytest.skip("Fixed input size model > limit.")
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
model = create_feature_extractor(
model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]],
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
model = torch.jit.script(model)
outputs = model(torch.randn((batch_size, *input_size)))[train_nodes[-1]]
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'

@ -19,6 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from .registry import register_model
from .layers import _assert
__all__ = [
@ -105,7 +106,7 @@ class ConvRelPosEnc(nn.Module):
def forward(self, q, v, size: Tuple[int, int]):
B, h, N, Ch = q.shape
H, W = size
assert N == 1 + H * W
_assert(N == 1 + H * W, '')
# Convolutional relative position encoding.
q_img = q[:, :, 1:, :] # [B, h, H*W, Ch]
@ -177,7 +178,7 @@ class ConvPosEnc(nn.Module):
def forward(self, x, size: Tuple[int, int]):
B, N, C = x.shape
H, W = size
assert N == 1 + H * W
_assert(N == 1 + H * W, '')
# Extract CLS token and image tokens.
cls_token, img_tokens = x[:, :1], x[:, 1:] # [B, 1, C], [B, H*W, C]
@ -275,7 +276,7 @@ class ParallelBlock(nn.Module):
""" Feature map interpolation. """
B, N, C = x.shape
H, W = size
assert N == 1 + H * W
_assert(N == 1 + H * W, '')
cls_token = x[:, :1, :]
img_tokens = x[:, 1:, :]

@ -30,6 +30,7 @@ from .helpers import build_model_with_cfg
from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp
from .registry import register_model
from .vision_transformer_hybrid import HybridEmbed
from .fx_features import register_notrace_module
import torch
import torch.nn as nn
@ -56,6 +57,7 @@ default_cfgs = {
}
@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
class GPSA(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
locality_strength=1.):

@ -22,6 +22,7 @@ NOTE: model names have been renamed from originals to represent actual input res
Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from typing import Tuple
import torch
import torch.nn as nn
@ -31,8 +32,9 @@ from functools import partial
from typing import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg
from .layers import DropPath, to_2tuple, trunc_normal_
from .layers import DropPath, to_2tuple, trunc_normal_, _assert
from .registry import register_model
from .vision_transformer import Mlp, Block
@ -116,8 +118,10 @@ class PatchEmbed(nn.Module):
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
_assert(H == self.img_size[0],
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
_assert(W == self.img_size[1],
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
x = self.proj(x).flatten(2).transpose(1, 2)
return x
@ -255,6 +259,27 @@ def _compute_num_patches(img_size, patches):
return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]
@register_notrace_function
def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False): # annotations for torchscript
"""
Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing.
Args:
x (Tensor): input image
ss (tuple[int, int]): height and width to scale to
crop_scale (bool): whether to crop instead of interpolate to achieve the desired scale. Defaults to False
Returns:
Tensor: the "scaled" image batch tensor
"""
H, W = x.shape[-2:]
if H != ss[0] or W != ss[1]:
if crop_scale and ss[0] <= H and ss[1] <= W:
cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]]
else:
x = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False)
return x
class CrossViT(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
@ -342,17 +367,12 @@ class CrossViT(nn.Module):
range(self.num_branches)])
def forward_features(self, x):
B, C, H, W = x.shape
B = x.shape[0]
xs = []
for i, patch_embed in enumerate(self.patch_embed):
x_ = x
ss = self.img_size_scaled[i]
if H != ss[0] or W != ss[1]:
if self.crop_scale and ss[0] <= H and ss[1] <= W:
cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
x_ = x_[:, :, cu:cu + ss[0], cl:cl + ss[1]]
else:
x_ = torch.nn.functional.interpolate(x_, size=ss, mode='bicubic', align_corners=False)
x_ = scale_image(x_, ss, self.crop_scale)
x_ = patch_embed(x_)
cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
cls_tokens = cls_tokens.expand(B, -1, -1)

@ -0,0 +1,74 @@
""" PyTorch FX Based Feature Extraction Helpers
Using https://pytorch.org/vision/stable/feature_extraction.html
"""
from typing import Callable
from torch import nn
from .features import _get_feature_info
try:
from torchvision.models.feature_extraction import create_feature_extractor
has_fx_feature_extraction = True
except ImportError:
has_fx_feature_extraction = False
# Layers we went to treat as leaf modules
from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
from .layers.non_local_attn import BilinearAttnTransform
from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
# BUT modules from timm.models should use the registration mechanism below
_leaf_modules = {
BatchNormAct2d, # reason: flow control for jit scripting
BilinearAttnTransform, # reason: flow control t <= 1
BlurPool2d, # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1]
# Reason: get_same_padding has a max which raises a control flow error
Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
DropPath, # reason: TypeError: rand recieved Proxy in `size` argument
}
try:
from .layers import InplaceAbn
_leaf_modules.add(InplaceAbn)
except ImportError:
pass
def register_notrace_module(module: nn.Module):
"""
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
"""
_leaf_modules.add(module)
return module
# Functions we want to autowrap (treat them as leaves)
_autowrap_functions = set()
def register_notrace_function(func: Callable):
"""
Decorator for functions which ought not to be traced through
"""
_autowrap_functions.add(func)
return func
class FeatureGraphNet(nn.Module):
def __init__(self, model, out_indices, out_map=None):
super().__init__()
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
self.feature_info = _get_feature_info(model, out_indices)
if out_map is not None:
assert len(out_map) == len(out_indices)
return_nodes = {info['module']: out_map[i] if out_map is not None else info['module']
for i, info in enumerate(self.feature_info) if i in out_indices}
self.graph_module = create_feature_extractor(
model, return_nodes,
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
def forward(self, x):
return list(self.graph_module(x).values())

@ -14,6 +14,7 @@ import torch.nn as nn
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .fx_features import FeatureGraphNet
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url
from .layers import Conv2dSame, Linear
@ -477,6 +478,8 @@ def build_model_with_cfg(
feature_cls = feature_cls.lower()
if 'hook' in feature_cls:
feature_cls = FeatureHookNet
elif feature_cls == 'fx':
feature_cls = FeatureGraphNet
else:
assert False, f'Unknown feature class {feature_cls}'
model = feature_cls(model, **feature_cfg)

@ -22,6 +22,7 @@ import torch.nn.functional as F
from .helpers import to_2tuple, make_divisible
from .weight_init import trunc_normal_
from .trace_utils import _assert
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@ -133,8 +134,8 @@ class BottleneckAttn(nn.Module):
def forward(self, x):
B, C, H, W = x.shape
assert H == self.pos_embed.height
assert W == self.pos_embed.width
_assert(H == self.pos_embed.height, '')
_assert(W == self.pos_embed.width, '')
x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
@ -154,5 +155,3 @@ class BottleneckAttn(nn.Module):
out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W
out = self.pool(out)
return out

@ -12,6 +12,8 @@ Hacked together by / Copyright 2020 Ross Wightman
import torch
import torch.nn as nn
from .trace_utils import _assert
class EvoNormBatch2d(nn.Module):
def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None):
@ -72,9 +74,9 @@ class EvoNormSample2d(nn.Module):
nn.init.ones_(self.v)
def forward(self, x):
assert x.dim() == 4, 'expected 4D input'
_assert(x.dim() == 4, 'expected 4D input')
B, C, H, W = x.shape
assert C % self.groups == 0
_assert(C % self.groups == 0, '')
if self.apply_act:
n = x * (x * self.v).sigmoid()
x = x.reshape(B, self.groups, -1)

@ -16,7 +16,7 @@ The attention mechanism works but it's slow as implemented.
Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import Tuple, List
from typing import List
import torch
from torch import nn
@ -24,6 +24,7 @@ import torch.nn.functional as F
from .helpers import make_divisible
from .weight_init import trunc_normal_
from .trace_utils import _assert
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@ -167,8 +168,8 @@ class HaloAttn(nn.Module):
def forward(self, x):
B, C, H, W = x.shape
assert H % self.block_size == 0
assert W % self.block_size == 0
_assert(H % self.block_size == 0, '')
_assert(W % self.block_size == 0, '')
num_h_blocks = H // self.block_size
num_w_blocks = W // self.block_size
num_blocks = num_h_blocks * num_w_blocks

@ -10,6 +10,7 @@ from torch.nn import functional as F
from .conv_bn_act import ConvBnAct
from .helpers import make_divisible
from .trace_utils import _assert
class NonLocalAttn(nn.Module):
@ -83,7 +84,7 @@ class BilinearAttnTransform(nn.Module):
def resize_mat(self, x, t: int):
B, C, block_size, block_size1 = x.shape
assert block_size == block_size1
_assert(block_size == block_size1, '')
if t <= 1:
return x
x = x.view(B * C, -1, 1, 1)
@ -95,7 +96,8 @@ class BilinearAttnTransform(nn.Module):
return x
def forward(self, x):
assert x.shape[-1] % self.block_size == 0 and x.shape[-2] % self.block_size == 0
_assert(x.shape[-1] % self.block_size == 0, '')
_assert(x.shape[-2] % self.block_size == 0, '')
B, C, H, W = x.shape
out = self.conv1(x)
rp = F.adaptive_max_pool2d(out, (self.block_size, 1))

@ -9,6 +9,7 @@ from torch import nn as nn
from .conv_bn_act import ConvBnAct
from .helpers import make_divisible
from .trace_utils import _assert
def _kernel_valid(k):
@ -34,7 +35,7 @@ class SelectiveKernelAttn(nn.Module):
self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
def forward(self, x):
assert x.shape[1] == self.num_paths
_assert(x.shape[1] == self.num_paths, '')
x = x.sum(1).mean((2, 3), keepdim=True)
x = self.fc_reduce(x)
x = self.bn(x)

@ -25,8 +25,10 @@ import torch.nn.functional as F
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, named_apply
from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_
from .layers import _assert
from .layers import create_conv2d, create_pool2d, to_ntuple
from .registry import register_model
@ -128,8 +130,8 @@ class ConvPool(nn.Module):
"""
x is expected to have shape (B, C, H, W)
"""
assert x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims'
assert x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims'
_assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims')
_assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims')
x = self.conv(x)
# Layer norm done over channel dim only
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
@ -144,8 +146,8 @@ def blockify(x, block_size: int):
block_size (int): edge length of a single square block in units of H, W
"""
B, H, W, C = x.shape
assert H % block_size == 0, '`block_size` must divide input height evenly'
assert W % block_size == 0, '`block_size` must divide input width evenly'
_assert(H % block_size == 0, '`block_size` must divide input height evenly')
_assert(W % block_size == 0, '`block_size` must divide input width evenly')
grid_height = H // block_size
grid_width = W // block_size
x = x.reshape(B, grid_height, block_size, grid_width, block_size, C)
@ -153,6 +155,7 @@ def blockify(x, block_size: int):
return x # (B, T, N, C)
@register_notrace_function # reason: int receives Proxy
def deblockify(x, block_size: int):
"""blocks to image
Args:

@ -26,6 +26,7 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_module
from .helpers import build_model_with_cfg
from .registry import register_model
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
@ -318,6 +319,7 @@ class DownsampleAvg(nn.Module):
return self.conv(self.pool(x))
@register_notrace_module # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301
class NormFreeBlock(nn.Module):
"""Normalization-Free pre-activation block.
"""

@ -10,6 +10,7 @@ Changes for timm, feature extraction, and rounded channel variant hacked togethe
Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
from functools import partial
from math import ceil
@ -92,7 +93,7 @@ class LinearBottleneck(nn.Module):
if self.use_shortcut:
if self.drop_path is not None:
x = self.drop_path(x)
x[:, 0:self.in_channels] += shortcut
x = torch.cat([x[:, 0:self.in_channels] + shortcut, x[:, self.in_channels:]], dim=1)
return x

@ -21,11 +21,14 @@ import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from .layers import _assert
from .registry import register_model
from .vision_transformer import checkpoint_filter_fn, _init_vit_weights
_logger = logging.getLogger(__name__)
@ -100,6 +103,7 @@ def window_partition(x, window_size: int):
return windows
@register_notrace_function # reason: int argument is a Proxy
def window_reverse(windows, window_size: int, H: int, W: int):
"""
Args:
@ -270,7 +274,7 @@ class SwinTransformerBlock(nn.Module):
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
_assert(L == H * W, "input feature has wrong size")
shortcut = x
x = self.norm1(x)
@ -329,8 +333,8 @@ class PatchMerging(nn.Module):
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
_assert(L == H * W, "input feature has wrong size")
_assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.")
x = x.view(B, H, W, C)

@ -9,12 +9,12 @@ https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
import math
import torch
import torch.nn as nn
from functools import partial
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import build_model_with_cfg
from timm.models.layers import Mlp, DropPath, trunc_normal_
from timm.models.layers.helpers import to_2tuple
from timm.models.layers import _assert
from timm.models.registry import register_model
from timm.models.vision_transformer import resize_pos_embed
@ -109,7 +109,9 @@ class Block(nn.Module):
pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
# outer
B, N, C = patch_embed.size()
patch_embed[:, 1:] = patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))
patch_embed = torch.cat(
[patch_embed[:, 0:1], patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))],
dim=1)
patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
return pixel_embed, patch_embed
@ -136,8 +138,10 @@ class PixelEmbed(nn.Module):
def forward(self, x, pixel_pos):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
_assert(H == self.img_size[0],
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
_assert(W == self.img_size[1],
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
x = self.proj(x)
x = self.unfold(x)
x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])

@ -22,9 +22,10 @@ from functools import partial
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import Mlp, DropPath, to_2tuple, trunc_normal_
from .fx_features import register_notrace_module
from .registry import register_model
from .vision_transformer import Attention
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .helpers import build_model_with_cfg
def _cfg(url='', **kwargs):
@ -62,6 +63,7 @@ default_cfgs = {
Size_ = Tuple[int, int]
@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
class LocallyGroupedAttn(nn.Module):
""" LSA: self attention within a group
"""

@ -12,7 +12,8 @@ from typing import Union, List, Dict, Any, cast
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, ConvBnAct
from .fx_features import register_notrace_module
from .layers import ClassifierHead
from .registry import register_model
__all__ = [
@ -52,6 +53,7 @@ cfgs: Dict[str, List[Union[str, int]]] = {
}
@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
class ConvMlp(nn.Module):
def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0,

@ -21,6 +21,7 @@ from .vision_transformer import _cfg, Mlp
from .registry import register_model
from .layers import DropPath, trunc_normal_, to_2tuple
from .cait import ClassAttn
from .fx_features import register_notrace_module
def _cfg(url='', **kwargs):
@ -97,6 +98,7 @@ default_cfgs = {
}
@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method
class PositionalEncodingFourier(nn.Module):
"""
Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper.

Loading…
Cancel
Save