Merge branch 'fx-feature-extract-new' of https://github.com/alexander-soare/pytorch-image-models into alexander-soare-fx-feature-extract-new

pull/989/head
Ross Wightman 3 years ago
commit 32c9937dec

@@ -4,9 +4,16 @@ import platform
import os
import fnmatch

try:
    from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer
    has_fx_feature_extraction = True
except ImportError:
    has_fx_feature_extraction = False

import timm
from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
    get_model_default_value
from timm.models.fx_features import _leaf_modules, _autowrap_functions

if hasattr(torch._C, '_jit_set_profiling_executor'):
    # legacy executor is too slow to compile large models for unit tests
@@ -297,3 +304,145 @@ def test_model_forward_features(model_name, batch_size):
        assert e == o.shape[1]
    assert o.shape[0] == batch_size
    assert not torch.isnan(o).any()

@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx(model_name, batch_size):
    """
    Symbolically trace each model and run single forward pass through the resulting GraphModule
    Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
    """
    if not has_fx_feature_extraction:
        pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")

    model = create_model(model_name, pretrained=False)
    model.eval()

    input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
    if max(input_size) > MAX_FWD_SIZE:
        pytest.skip("Fixed input size model > limit.")

    # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
    # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
    # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
    tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
    graph = tracer.trace(model)
    graph_nodes = list(reversed(graph.nodes))
    output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
    graph_node_names = [n.name for n in graph_nodes]
    output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
    train_nodes, eval_nodes = get_graph_node_names(
        model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
    eval_return_nodes = [eval_nodes[ix] for ix in output_node_indices]

    fx_model = create_feature_extractor(
        model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=eval_return_nodes,
        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})

    inputs = torch.randn((batch_size, *input_size))
    outputs = model(inputs)
    if isinstance(outputs, tuple):
        outputs = torch.cat(outputs)
    fx_outputs = tuple(fx_model(inputs).values())
    if isinstance(fx_outputs, tuple):
        fx_outputs = torch.cat(fx_outputs)

    assert torch.all(fx_outputs == outputs)
    assert outputs.shape[0] == batch_size
    assert not torch.isnan(outputs).any(), 'Output included NaNs'

@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size):
    """Symbolically trace each model and run single backward pass through the resulting GraphModule"""
    if not has_fx_feature_extraction:
        pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")

    input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
    if max(input_size) > MAX_BWD_SIZE:
        pytest.skip("Fixed input size model > limit.")

    model = create_model(model_name, pretrained=False, num_classes=42)
    model.train()
    num_params = sum([x.numel() for x in model.parameters()])

    input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
    if max(input_size) > MAX_FWD_SIZE:
        pytest.skip("Fixed input size model > limit.")

    # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
    # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
    # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
    tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
    graph = tracer.trace(model)
    graph_nodes = list(reversed(graph.nodes))
    output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
    graph_node_names = [n.name for n in graph_nodes]
    output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
    train_nodes, eval_nodes = get_graph_node_names(
        model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
    train_return_nodes = [train_nodes[ix] for ix in output_node_indices]

    model = create_feature_extractor(
        model, train_return_nodes=train_return_nodes, eval_return_nodes=[eval_nodes[-1]],
        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})

    inputs = torch.randn((batch_size, *input_size))
    outputs = tuple(model(inputs).values())
    if isinstance(outputs, tuple):
        outputs = torch.cat(outputs)
    outputs.mean().backward()
    for n, x in model.named_parameters():
        assert x.grad is not None, f'No gradient for {n}'
    num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])

    assert outputs.shape[-1] == 42
    assert num_params == num_grad, 'Some parameters are missing gradients'
    assert not torch.isnan(outputs).any(), 'Output included NaNs'

# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
EXCLUDE_FX_JIT_FILTERS = [
    'beit_*',
    'deit_*_distilled_patch16_224',
    'levit*',
    'pit_*_distilled_224',
]
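
# The filter above exists because these models branch on torch.jit.is_scripting() (or training mode) inside
# forward(); once FX has traced the model into a GraphModule, scripting that graph hits control flow it can no
# longer resolve. A minimal sketch of the offending pattern (hypothetical module, not from this diff):
import torch
import torch.nn as nn

class DistilledHead(nn.Module):
    # hypothetical illustration of the kind of forward() that breaks script-after-trace
    def __init__(self, dim=192, num_classes=1000):
        super().__init__()
        self.head = nn.Linear(dim, num_classes)
        self.head_dist = nn.Linear(dim, num_classes)

    def forward(self, x, x_dist):
        x, x_dist = self.head(x), self.head_dist(x_dist)
        if self.training and not torch.jit.is_scripting():
            # mode-dependent branch: FX bakes in one path at trace time,
            # so scripting the traced GraphModule no longer sees this control flow
            return x, x_dist
        else:
            return (x + x_dist) / 2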

@pytest.mark.timeout(120)
@pytest.mark.parametrize(
    'model_name', list_models(
        exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx_torchscript(model_name, batch_size):
    """Symbolically trace each model, script it, and run single forward pass"""
    if not has_fx_feature_extraction:
        pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")

    input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
    if max(input_size) > MAX_JIT_SIZE:
        pytest.skip("Fixed input size model > limit.")

    with set_scriptable(True):
        model = create_model(model_name, pretrained=False)
    model.eval()

    input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
    if max(input_size) > MAX_FWD_SIZE:
        pytest.skip("Fixed input size model > limit.")

    train_nodes, eval_nodes = get_graph_node_names(
        model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
    model = create_feature_extractor(
        model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]],
        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})

    model = torch.jit.script(model)
    outputs = model(torch.randn((batch_size, *input_size)))[train_nodes[-1]]

    assert outputs.shape[0] == batch_size
    assert not torch.isnan(outputs).any(), 'Output included NaNs'

@@ -19,6 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from .registry import register_model
from .layers import _assert

__all__ = [
@@ -105,7 +106,7 @@ class ConvRelPosEnc(nn.Module):
    def forward(self, q, v, size: Tuple[int, int]):
        B, h, N, Ch = q.shape
        H, W = size
        _assert(N == 1 + H * W, '')

        # Convolutional relative position encoding.
        q_img = q[:, :, 1:, :]  # [B, h, H*W, Ch]
@@ -177,7 +178,7 @@ class ConvPosEnc(nn.Module):
    def forward(self, x, size: Tuple[int, int]):
        B, N, C = x.shape
        H, W = size
        _assert(N == 1 + H * W, '')

        # Extract CLS token and image tokens.
        cls_token, img_tokens = x[:, :1], x[:, 1:]  # [B, 1, C], [B, H*W, C]
@@ -275,7 +276,7 @@ class ParallelBlock(nn.Module):
        """ Feature map interpolation. """
        B, N, C = x.shape
        H, W = size
        _assert(N == 1 + H * W, '')

        cls_token = x[:, :1, :]
        img_tokens = x[:, 1:, :]
img_tokens = x[:, 1:, :] img_tokens = x[:, 1:, :]
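# Throughout this commit, bare `assert` statements inside forward() methods are swapped for `_assert` so the
# shape checks survive FX symbolic tracing (comparisons on Proxy values can't drive a python assert) as well as
# torchscript. The helper itself isn't shown in these hunks; a plausible minimal sketch of what
# timm/models/layers/trace_utils.py provides (assumption: it defers to torch._assert when available):
try:
    from torch import _assert  # traced as an ordinary call node instead of python `assert`
except ImportError:
    def _assert(condition: bool, message: str):
        assert condition, message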

@@ -30,6 +30,7 @@ from .helpers import build_model_with_cfg
from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp
from .registry import register_model
from .vision_transformer_hybrid import HybridEmbed
from .fx_features import register_notrace_module

import torch
import torch.nn as nn

@@ -56,6 +57,7 @@ default_cfgs = {
}


@register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
class GPSA(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
                 locality_strength=1.):

@@ -22,6 +22,7 @@ NOTE: model names have been renamed from originals to represent actual input res
Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from typing import Tuple

import torch
import torch.nn as nn

@@ -31,8 +32,9 @@ from functools import partial
from typing import List

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg
from .layers import DropPath, to_2tuple, trunc_normal_, _assert
from .registry import register_model
from .vision_transformer import Mlp, Block

@@ -116,8 +118,10 @@ class PatchEmbed(nn.Module):
    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        _assert(H == self.img_size[0],
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
        _assert(W == self.img_size[1],
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
@@ -255,6 +259,27 @@ def _compute_num_patches(img_size, patches):
    return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]


@register_notrace_function
def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False):  # annotations for torchscript
    """
    Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing.
    Args:
        x (Tensor): input image
        ss (tuple[int, int]): height and width to scale to
        crop_scale (bool): whether to crop instead of interpolate to achieve the desired scale. Defaults to False
    Returns:
        Tensor: the "scaled" image batch tensor
    """
    H, W = x.shape[-2:]
    if H != ss[0] or W != ss[1]:
        if crop_scale and ss[0] <= H and ss[1] <= W:
            cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
            x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]]
        else:
            x = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False)
    return x
class CrossViT(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """

@@ -342,17 +367,12 @@ class CrossViT(nn.Module):
            range(self.num_branches)])

    def forward_features(self, x):
        B = x.shape[0]
        xs = []
        for i, patch_embed in enumerate(self.patch_embed):
            x_ = x
            ss = self.img_size_scaled[i]
            x_ = scale_image(x_, ss, self.crop_scale)
            x_ = patch_embed(x_)
            cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1  # hard-coded for torch jit script
            cls_tokens = cls_tokens.expand(B, -1, -1)

@@ -0,0 +1,74 @@
""" PyTorch FX Based Feature Extraction Helpers
Using https://pytorch.org/vision/stable/feature_extraction.html
"""
from typing import Callable

from torch import nn

from .features import _get_feature_info

try:
    from torchvision.models.feature_extraction import create_feature_extractor
    has_fx_feature_extraction = True
except ImportError:
    has_fx_feature_extraction = False
# Layers we want to treat as leaf modules
from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
from .layers.non_local_attn import BilinearAttnTransform
from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame

# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
# BUT modules from timm.models should use the registration mechanism below
_leaf_modules = {
    BatchNormAct2d,  # reason: flow control for jit scripting
    BilinearAttnTransform,  # reason: flow control t <= 1
    BlurPool2d,  # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1]
    # reason: get_same_padding has a max which raises a control flow error
    Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
    CondConv2d,  # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
    DropPath,  # reason: TypeError: rand received Proxy in `size` argument
}

try:
    from .layers import InplaceAbn
    _leaf_modules.add(InplaceAbn)
except ImportError:
    pass

def register_notrace_module(module: nn.Module):
    """
    Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
    """
    _leaf_modules.add(module)
    return module


# Functions we want to autowrap (treat them as leaves)
_autowrap_functions = set()


def register_notrace_function(func: Callable):
    """
    Decorator for functions which ought not to be traced through
    """
    _autowrap_functions.add(func)
    return func

class FeatureGraphNet(nn.Module):
    def __init__(self, model, out_indices, out_map=None):
        super().__init__()
        assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
        self.feature_info = _get_feature_info(model, out_indices)
        if out_map is not None:
            assert len(out_map) == len(out_indices)
        return_nodes = {info['module']: out_map[i] if out_map is not None else info['module']
                        for i, info in enumerate(self.feature_info) if i in out_indices}
        self.graph_module = create_feature_extractor(
            model, return_nodes,
            tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})

    def forward(self, x):
        return list(self.graph_module(x).values())
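
# The two registries above feed the tracer_kwargs that FeatureGraphNet (and the tests earlier in this commit)
# pass to torchvision's feature extraction API. A minimal sketch of how a model file opts a module or function
# out of tracing (MyBlock and my_resize are hypothetical names, used only for illustration):
import torch
from torch import nn
from timm.models.fx_features import register_notrace_module, register_notrace_function

@register_notrace_module  # hypothetical block whose forward() branches on tensor shapes
class MyBlock(nn.Module):
    def forward(self, x):
        if x.shape[-1] % 2:  # data-dependent control flow: keep FX from tracing inside
            x = nn.functional.pad(x, (0, 1))
        return x

@register_notrace_function  # hypothetical helper that converts a traced value to int
def my_resize(x, size: int):
    return nn.functional.interpolate(x, size=int(size))

# Both decorators just record the object and return it unchanged, so the decorated classes and functions
# behave exactly as before outside of tracing.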

@@ -14,6 +14,7 @@ import torch.nn as nn

from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .fx_features import FeatureGraphNet
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url
from .layers import Conv2dSame, Linear

@@ -477,6 +478,8 @@ def build_model_with_cfg(
                feature_cls = feature_cls.lower()
                if 'hook' in feature_cls:
                    feature_cls = FeatureHookNet
                elif feature_cls == 'fx':
                    feature_cls = FeatureGraphNet
                else:
                    assert False, f'Unknown feature class {feature_cls}'
        model = feature_cls(model, **feature_cfg)
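
# With the 'fx' feature class wired into build_model_with_cfg, a graph-based feature pyramid can also be built
# directly from an existing timm model. A sketch (the model name and out_indices are arbitrary; assumes
# torch >= 1.10 and torchvision >= 0.11):
import torch
from timm import create_model
from timm.models.fx_features import FeatureGraphNet

backbone = create_model('resnet50', pretrained=False)
feat_net = FeatureGraphNet(backbone, out_indices=(1, 2, 3, 4))  # uses the model's feature_info metadata
with torch.no_grad():
    feats = feat_net(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)  # one tensor per requested stage, shallowest first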

@@ -22,6 +22,7 @@ import torch.nn.functional as F

from .helpers import to_2tuple, make_divisible
from .weight_init import trunc_normal_
from .trace_utils import _assert


def rel_logits_1d(q, rel_k, permute_mask: List[int]):

@@ -133,8 +134,8 @@ class BottleneckAttn(nn.Module):
    def forward(self, x):
        B, C, H, W = x.shape
        _assert(H == self.pos_embed.height, '')
        _assert(W == self.pos_embed.width, '')

        x = self.qkv(x)  # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W

@@ -154,5 +155,3 @@ class BottleneckAttn(nn.Module):
        out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W)  # B, dim_out, H, W
        out = self.pool(out)
        return out

@@ -12,6 +12,8 @@ Hacked together by / Copyright 2020 Ross Wightman
import torch
import torch.nn as nn

from .trace_utils import _assert


class EvoNormBatch2d(nn.Module):
    def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None):

@@ -72,9 +74,9 @@ class EvoNormSample2d(nn.Module):
            nn.init.ones_(self.v)

    def forward(self, x):
        _assert(x.dim() == 4, 'expected 4D input')
        B, C, H, W = x.shape
        _assert(C % self.groups == 0, '')
        if self.apply_act:
            n = x * (x * self.v).sigmoid()
            x = x.reshape(B, self.groups, -1)

@@ -16,7 +16,7 @@ The attention mechanism works but it's slow as implemented.
Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import List

import torch
from torch import nn

@@ -24,6 +24,7 @@ import torch.nn.functional as F

from .helpers import make_divisible
from .weight_init import trunc_normal_
from .trace_utils import _assert


def rel_logits_1d(q, rel_k, permute_mask: List[int]):

@@ -167,8 +168,8 @@ class HaloAttn(nn.Module):
    def forward(self, x):
        B, C, H, W = x.shape
        _assert(H % self.block_size == 0, '')
        _assert(W % self.block_size == 0, '')
        num_h_blocks = H // self.block_size
        num_w_blocks = W // self.block_size
        num_blocks = num_h_blocks * num_w_blocks

@@ -10,6 +10,7 @@ from torch.nn import functional as F

from .conv_bn_act import ConvBnAct
from .helpers import make_divisible
from .trace_utils import _assert


class NonLocalAttn(nn.Module):

@@ -83,7 +84,7 @@ class BilinearAttnTransform(nn.Module):
    def resize_mat(self, x, t: int):
        B, C, block_size, block_size1 = x.shape
        _assert(block_size == block_size1, '')
        if t <= 1:
            return x
        x = x.view(B * C, -1, 1, 1)

@@ -95,7 +96,8 @@ class BilinearAttnTransform(nn.Module):
        return x

    def forward(self, x):
        _assert(x.shape[-1] % self.block_size == 0, '')
        _assert(x.shape[-2] % self.block_size == 0, '')
        B, C, H, W = x.shape
        out = self.conv1(x)
        rp = F.adaptive_max_pool2d(out, (self.block_size, 1))

@@ -9,6 +9,7 @@ from torch import nn as nn

from .conv_bn_act import ConvBnAct
from .helpers import make_divisible
from .trace_utils import _assert


def _kernel_valid(k):

@@ -34,7 +35,7 @@ class SelectiveKernelAttn(nn.Module):
        self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)

    def forward(self, x):
        _assert(x.shape[1] == self.num_paths, '')
        x = x.sum(1).mean((2, 3), keepdim=True)
        x = self.fc_reduce(x)
        x = self.bn(x)

@@ -25,8 +25,10 @@ import torch.nn.functional as F
from torch import nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, named_apply
from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_
from .layers import _assert
from .layers import create_conv2d, create_pool2d, to_ntuple
from .registry import register_model

@@ -128,8 +130,8 @@ class ConvPool(nn.Module):
        """
        x is expected to have shape (B, C, H, W)
        """
        _assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims')
        _assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims')
        x = self.conv(x)
        # Layer norm done over channel dim only
        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

@@ -144,8 +146,8 @@ def blockify(x, block_size: int):
        block_size (int): edge length of a single square block in units of H, W
    """
    B, H, W, C = x.shape
    _assert(H % block_size == 0, '`block_size` must divide input height evenly')
    _assert(W % block_size == 0, '`block_size` must divide input width evenly')
    grid_height = H // block_size
    grid_width = W // block_size
    x = x.reshape(B, grid_height, block_size, grid_width, block_size, C)

@@ -153,6 +155,7 @@ def blockify(x, block_size: int):
    return x  # (B, T, N, C)


@register_notrace_function  # reason: int receives Proxy
def deblockify(x, block_size: int):
    """blocks to image
    Args:

@@ -26,6 +26,7 @@ import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_module
from .helpers import build_model_with_cfg
from .registry import register_model
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\

@@ -318,6 +319,7 @@ class DownsampleAvg(nn.Module):
        return self.conv(self.pool(x))


@register_notrace_module  # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301
class NormFreeBlock(nn.Module):
    """Normalization-Free pre-activation block.
    """

@@ -10,6 +10,7 @@ Changes for timm, feature extraction, and rounded channel variant hacked togethe
Copyright 2020 Ross Wightman
"""

import torch
import torch.nn as nn
from functools import partial
from math import ceil

@@ -92,7 +93,7 @@ class LinearBottleneck(nn.Module):
        if self.use_shortcut:
            if self.drop_path is not None:
                x = self.drop_path(x)
            x = torch.cat([x[:, 0:self.in_channels] + shortcut, x[:, self.in_channels:]], dim=1)
        return x

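# The ReXNet change above replaces an in-place slice assignment with an out-of-place torch.cat, since mutating
# a sliced Proxy doesn't record cleanly as FX graph nodes. A standalone sketch of the same rewrite pattern
# (arbitrary channel split, not the timm module):
import torch

def partial_residual_inplace(x, shortcut, c):
    x = x.clone()
    x[:, :c] += shortcut  # in-place mutation of a slice: hard for fx.symbolic_trace to capture faithfully
    return x

def partial_residual_functional(x, shortcut, c):
    # same result, expressed as pure ops that trace into ordinary graph nodes
    return torch.cat([x[:, :c] + shortcut, x[:, c:]], dim=1)

x, shortcut = torch.randn(2, 8, 4, 4), torch.randn(2, 3, 4, 4)
assert torch.allclose(partial_residual_inplace(x, shortcut, 3), partial_residual_functional(x, shortcut, 3))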
@@ -21,11 +21,14 @@ import torch.nn as nn
import torch.utils.checkpoint as checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from .layers import _assert
from .registry import register_model
from .vision_transformer import checkpoint_filter_fn, _init_vit_weights

_logger = logging.getLogger(__name__)

@@ -100,6 +103,7 @@ def window_partition(x, window_size: int):
    return windows


@register_notrace_function  # reason: int argument is a Proxy
def window_reverse(windows, window_size: int, H: int, W: int):
    """
    Args:

@@ -270,7 +274,7 @@ class SwinTransformerBlock(nn.Module):
    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        _assert(L == H * W, "input feature has wrong size")

        shortcut = x
        x = self.norm1(x)

@@ -329,8 +333,8 @@ class PatchMerging(nn.Module):
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        _assert(L == H * W, "input feature has wrong size")
        _assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.")

        x = x.view(B, H, W, C)

@@ -9,12 +9,12 @@ https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
import math
import torch
import torch.nn as nn
from functools import partial

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import build_model_with_cfg
from timm.models.layers import Mlp, DropPath, trunc_normal_
from timm.models.layers.helpers import to_2tuple
from timm.models.layers import _assert
from timm.models.registry import register_model
from timm.models.vision_transformer import resize_pos_embed

@@ -109,7 +109,9 @@ class Block(nn.Module):
        pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
        # outer
        B, N, C = patch_embed.size()
        patch_embed = torch.cat(
            [patch_embed[:, 0:1], patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))],
            dim=1)
        patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
        patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
        return pixel_embed, patch_embed

@@ -136,8 +138,10 @@ class PixelEmbed(nn.Module):
    def forward(self, x, pixel_pos):
        B, C, H, W = x.shape
        _assert(H == self.img_size[0],
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
        _assert(W == self.img_size[1],
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
        x = self.proj(x)
        x = self.unfold(x)
        x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])

@@ -22,9 +22,10 @@ from functools import partial

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import Mlp, DropPath, to_2tuple, trunc_normal_
from .fx_features import register_notrace_module
from .registry import register_model
from .vision_transformer import Attention
from .helpers import build_model_with_cfg


def _cfg(url='', **kwargs):

@@ -62,6 +63,7 @@ default_cfgs = {
Size_ = Tuple[int, int]


@register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
class LocallyGroupedAttn(nn.Module):
    """ LSA: self attention within a group
    """

@@ -12,7 +12,8 @@ from typing import Union, List, Dict, Any, cast

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .fx_features import register_notrace_module
from .layers import ClassifierHead
from .registry import register_model

__all__ = [

@@ -52,6 +53,7 @@ cfgs: Dict[str, List[Union[str, int]]] = {
}


@register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
class ConvMlp(nn.Module):
    def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0,

@@ -21,6 +21,7 @@ from .vision_transformer import _cfg, Mlp
from .registry import register_model
from .layers import DropPath, trunc_normal_, to_2tuple
from .cait import ClassAttn
from .fx_features import register_notrace_module


def _cfg(url='', **kwargs):

@@ -97,6 +98,7 @@ default_cfgs = {
}


@register_notrace_module  # reason: FX can't symbolically trace torch.arange in forward method
class PositionalEncodingFourier(nn.Module):
    """
    Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper.
