rename notrace registration and standardize trace_utils imports

pull/800/head
Alexander Soare 3 years ago
parent 0262a0e8e1
commit 65d827c7a6

@@ -19,7 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from .registry import register_model
-from .layers.trace_utils import _assert
+from .layers import _assert
__all__ = [
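Aside on the second half of the commit title: `_assert` is now imported straight from `.layers` instead of reaching into `.layers.trace_utils`. A minimal sketch of why the helper exists, assuming a torch version where it maps to the symbolically traceable `torch._assert` (the module below is illustrative, not from timm):

import torch
import torch.fx
import torch.nn as nn
from timm.models.layers import _assert  # same symbol as before, now exposed at the package level

class PatchWidthCheck(nn.Module):
    # Illustrative module: the shape check survives torch.fx symbolic tracing because
    # _assert is recorded as a graph node instead of forcing bool() on a Proxy.
    def __init__(self, patch_size: int = 16):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x):
        _assert(x.shape[-1] % self.patch_size == 0, 'width must be divisible by patch size')
        return x

traced = torch.fx.symbolic_trace(PatchWidthCheck())  # traces cleanly; a bare Python `assert` here would raise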

@@ -30,7 +30,7 @@ from .helpers import build_model_with_cfg
from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp
from .registry import register_model
from .vision_transformer_hybrid import HybridEmbed
-from .fx_features import register_leaf_module
+from .fx_features import register_notrace_module
import torch
import torch.nn as nn
@@ -57,7 +57,7 @@ default_cfgs = {
}
-@register_leaf_module # reason: FX can't symbolically trace control flow in forward method
+@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
class GPSA(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
locality_strength=1.):
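The `# reason:` comment is the operative detail: torch.fx cannot symbolically trace value-dependent control flow, so GPSA is registered as a leaf and its forward is never traced into. A minimal sketch of the failure mode the decorator covers (the toy module below is made up for illustration, not timm's GPSA):

import torch.nn as nn
from timm.models.fx_features import register_notrace_module

@register_notrace_module  # registered so timm's FX feature extraction keeps this as a single call_module node
class GatedScale(nn.Module):
    # Toy module: branching on a tensor value forces bool() on a Proxy during symbolic
    # tracing, which raises a TraceError, hence the notrace registration.
    def __init__(self, threshold: float = 0.5):
        super().__init__()
        self.threshold = threshold

    def forward(self, x):
        if x.mean() > self.threshold:
            return x * 2.0
        return x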

@@ -32,7 +32,7 @@ from functools import partial
from typing import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .fx_features import register_autowrap_function
+from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg
from .layers import DropPath, to_2tuple, trunc_normal_, _assert
from .registry import register_model
@@ -259,7 +259,7 @@ def _compute_num_patches(img_size, patches):
return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]
-@register_autowrap_function
+@register_notrace_function
def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False): # annotations for torchscript
"""
Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing.
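Same idea one level down: `scale_image` buries an interpolation conditional inside a single leaf `call_function` node rather than a leaf module. A hedged sketch of that pattern (the function below is illustrative, not timm's scale_image):

import torch
import torch.nn.functional as F
from timm.models.fx_features import register_notrace_function

@register_notrace_function  # kept as one call_function node during FX feature extraction
def maybe_resize(x: torch.Tensor, size: int) -> torch.Tensor:
    # Illustrative stand-in: the shape comparison would force bool() on a Proxy under
    # symbolic tracing, so the whole function is registered as an autowrap target instead.
    if x.shape[-1] != size:
        x = F.interpolate(x, size=(size, size), mode='bicubic', align_corners=False)
    return x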

@@ -36,7 +36,7 @@ except ImportError:
pass
-def register_leaf_module(module: nn.Module):
+def register_notrace_module(module: nn.Module):
"""
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
"""
@@ -48,7 +48,7 @@ def register_leaf_module(module: nn.Module):
_autowrap_functions = set()
-def register_autowrap_function(func: Callable):
+def register_notrace_function(func: Callable):
"""
Decorator for functions which ought not to be traced through
"""

@@ -25,10 +25,10 @@ import torch.nn.functional as F
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .fx_features import register_autowrap_function
+from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, named_apply
from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_
-from .layers.trace_utils import _assert
+from .layers import _assert
from .layers import create_conv2d, create_pool2d, to_ntuple
from .registry import register_model
@@ -155,7 +155,7 @@ def blockify(x, block_size: int):
return x # (B, T, N, C)
-@register_autowrap_function # reason: int receives Proxy
+@register_notrace_function # reason: int receives Proxy
def deblockify(x, block_size: int):
"""blocks to image
Args:
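`deblockify`'s reason differs from the control-flow cases: it turns a traced dimension into a Python `int`, and `int()` cannot be applied to an FX Proxy. A minimal sketch of that failure mode (the toy function below is not timm's deblockify):

import math
import torch
from timm.models.fx_features import register_notrace_function

@register_notrace_function  # int(math.sqrt(...)) over a traced shape would fail: int() receives a Proxy
def fold_blocks(x: torch.Tensor) -> torch.Tensor:
    # Toy stand-in: collapse a (B, T, N, C) block tensor back onto a square grid of blocks.
    b, t, n, c = x.shape
    side = int(math.sqrt(t))  # plain-Python arithmetic that symbolic tracing cannot follow
    return x.reshape(b, side, side, n * c)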

@@ -26,7 +26,7 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .fx_features import register_leaf_module
+from .fx_features import register_notrace_module
from .helpers import build_model_with_cfg
from .registry import register_model
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
@@ -319,7 +319,7 @@ class DownsampleAvg(nn.Module):
return self.conv(self.pool(x))
-@register_leaf_module # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301
+@register_notrace_module # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301
class NormFreeBlock(nn.Module):
"""Normalization-Free pre-activation block.
"""

@@ -21,10 +21,10 @@ import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .fx_features import register_autowrap_function
+from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
-from .layers.trace_utils import _assert
+from .layers import _assert
from .registry import register_model
from .vision_transformer import checkpoint_filter_fn, _init_vit_weights
@@ -103,7 +103,7 @@ def window_partition(x, window_size: int):
return windows
-@register_autowrap_function # reason: int argument is a Proxy
+@register_notrace_function # reason: int argument is a Proxy
def window_reverse(windows, window_size: int, H: int, W: int):
"""
Args:

@@ -14,7 +14,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import build_model_with_cfg
from timm.models.layers import Mlp, DropPath, trunc_normal_
from timm.models.layers.helpers import to_2tuple
-from timm.models.layers.trace_utils import _assert
+from timm.models.layers import _assert
from timm.models.registry import register_model
from timm.models.vision_transformer import resize_pos_embed

@@ -22,7 +22,7 @@ from functools import partial
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import Mlp, DropPath, to_2tuple, trunc_normal_
-from .fx_features import register_leaf_module
+from .fx_features import register_notrace_module
from .registry import register_model
from .vision_transformer import Attention
from .helpers import build_model_with_cfg
@@ -63,7 +63,7 @@ default_cfgs = {
Size_ = Tuple[int, int]
-@register_leaf_module # reason: FX can't symbolically trace control flow in forward method
+@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
class LocallyGroupedAttn(nn.Module):
""" LSA: self attention within a group
"""

@@ -12,7 +12,7 @@ from typing import Union, List, Dict, Any, cast
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
-from .fx_features import register_leaf_module
+from .fx_features import register_notrace_module
from .layers import ClassifierHead
from .registry import register_model
@@ -53,7 +53,7 @@ cfgs: Dict[str, List[Union[str, int]]] = {
}
-@register_leaf_module # reason: FX can't symbolically trace control flow in forward method
+@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
class ConvMlp(nn.Module):
def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0,

@@ -21,7 +21,7 @@ from .vision_transformer import _cfg, Mlp
from .registry import register_model
from .layers import DropPath, trunc_normal_, to_2tuple
from .cait import ClassAttn
-from .fx_features import register_leaf_module
+from .fx_features import register_notrace_module
def _cfg(url='', **kwargs):
@@ -98,7 +98,7 @@ default_cfgs = {
}
-@register_leaf_module # reason: FX can't symbolically trace torch.arange in forward method
+@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method
class PositionalEncodingFourier(nn.Module):
"""
Positional encoding relying on a Fourier kernel matching the one used in the "Attention Is All You Need" paper.
