rename notrace registration and standardize trace_utils imports

pull/800/head
Alexander Soare 3 years ago
parent 0262a0e8e1
commit 65d827c7a6

@ -19,7 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from .registry import register_model from .registry import register_model
from .layers.trace_utils import _assert from .layers import _assert
__all__ = [ __all__ = [

@ -30,7 +30,7 @@ from .helpers import build_model_with_cfg
from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp
from .registry import register_model from .registry import register_model
from .vision_transformer_hybrid import HybridEmbed from .vision_transformer_hybrid import HybridEmbed
from .fx_features import register_leaf_module from .fx_features import register_notrace_module
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -57,7 +57,7 @@ default_cfgs = {
} }
@register_leaf_module # reason: FX can't symbolically trace control flow in forward method @register_notrace_module # reason: FX can't symbolically trace control flow in forward method
class GPSA(nn.Module): class GPSA(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
locality_strength=1.): locality_strength=1.):

@ -32,7 +32,7 @@ from functools import partial
from typing import List from typing import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_autowrap_function from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg from .helpers import build_model_with_cfg
from .layers import DropPath, to_2tuple, trunc_normal_, _assert from .layers import DropPath, to_2tuple, trunc_normal_, _assert
from .registry import register_model from .registry import register_model
@ -259,7 +259,7 @@ def _compute_num_patches(img_size, patches):
return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)] return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]
@register_autowrap_function @register_notrace_function
def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False): # annotations for torchscript def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False): # annotations for torchscript
""" """
Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing. Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing.

@ -36,7 +36,7 @@ except ImportError:
pass pass
def register_leaf_module(module: nn.Module): def register_notrace_module(module: nn.Module):
""" """
Any module not under timm.models.layers should get this decorator if we don't want to trace through it. Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
""" """
@ -48,7 +48,7 @@ def register_leaf_module(module: nn.Module):
_autowrap_functions = set() _autowrap_functions = set()
def register_autowrap_function(func: Callable): def register_notrace_function(func: Callable):
""" """
Decorator for functions which ought not to be traced through Decorator for functions which ought not to be traced through
""" """

@ -25,10 +25,10 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_autowrap_function from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, named_apply from .helpers import build_model_with_cfg, named_apply
from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_ from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_
from .layers.trace_utils import _assert from .layers import _assert
from .layers import create_conv2d, create_pool2d, to_ntuple from .layers import create_conv2d, create_pool2d, to_ntuple
from .registry import register_model from .registry import register_model
@ -155,7 +155,7 @@ def blockify(x, block_size: int):
return x # (B, T, N, C) return x # (B, T, N, C)
@register_autowrap_function # reason: int receives Proxy @register_notrace_function # reason: int receives Proxy
def deblockify(x, block_size: int): def deblockify(x, block_size: int):
"""blocks to image """blocks to image
Args: Args:

@ -26,7 +26,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_leaf_module from .fx_features import register_notrace_module
from .helpers import build_model_with_cfg from .helpers import build_model_with_cfg
from .registry import register_model from .registry import register_model
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\ from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
@ -319,7 +319,7 @@ class DownsampleAvg(nn.Module):
return self.conv(self.pool(x)) return self.conv(self.pool(x))
@register_leaf_module # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301 @register_notrace_module # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301
class NormFreeBlock(nn.Module): class NormFreeBlock(nn.Module):
"""Normalization-Free pre-activation block. """Normalization-Free pre-activation block.
""" """

@ -21,10 +21,10 @@ import torch.nn as nn
import torch.utils.checkpoint as checkpoint import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_autowrap_function from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, overlay_external_default_cfg from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from .layers.trace_utils import _assert from .layers import _assert
from .registry import register_model from .registry import register_model
from .vision_transformer import checkpoint_filter_fn, _init_vit_weights from .vision_transformer import checkpoint_filter_fn, _init_vit_weights
@ -103,7 +103,7 @@ def window_partition(x, window_size: int):
return windows return windows
@register_autowrap_function # reason: int argument is a Proxy @register_notrace_function # reason: int argument is a Proxy
def window_reverse(windows, window_size: int, H: int, W: int): def window_reverse(windows, window_size: int, H: int, W: int):
""" """
Args: Args:

@ -14,7 +14,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import build_model_with_cfg from timm.models.helpers import build_model_with_cfg
from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers import Mlp, DropPath, trunc_normal_
from timm.models.layers.helpers import to_2tuple from timm.models.layers.helpers import to_2tuple
from timm.models.layers.trace_utils import _assert from timm.models.layers import _assert
from timm.models.registry import register_model from timm.models.registry import register_model
from timm.models.vision_transformer import resize_pos_embed from timm.models.vision_transformer import resize_pos_embed

@ -22,7 +22,7 @@ from functools import partial
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import Mlp, DropPath, to_2tuple, trunc_normal_ from .layers import Mlp, DropPath, to_2tuple, trunc_normal_
from .fx_features import register_leaf_module from .fx_features import register_notrace_module
from .registry import register_model from .registry import register_model
from .vision_transformer import Attention from .vision_transformer import Attention
from .helpers import build_model_with_cfg from .helpers import build_model_with_cfg
@ -63,7 +63,7 @@ default_cfgs = {
Size_ = Tuple[int, int] Size_ = Tuple[int, int]
@register_leaf_module # reason: FX can't symbolically trace control flow in forward method @register_notrace_module # reason: FX can't symbolically trace control flow in forward method
class LocallyGroupedAttn(nn.Module): class LocallyGroupedAttn(nn.Module):
""" LSA: self attention within a group """ LSA: self attention within a group
""" """

@ -12,7 +12,7 @@ from typing import Union, List, Dict, Any, cast
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg from .helpers import build_model_with_cfg
from .fx_features import register_leaf_module from .fx_features import register_notrace_module
from .layers import ClassifierHead from .layers import ClassifierHead
from .registry import register_model from .registry import register_model
@ -53,7 +53,7 @@ cfgs: Dict[str, List[Union[str, int]]] = {
} }
@register_leaf_module # reason: FX can't symbolically trace control flow in forward method @register_notrace_module # reason: FX can't symbolically trace control flow in forward method
class ConvMlp(nn.Module): class ConvMlp(nn.Module):
def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0, def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0,

@ -21,7 +21,7 @@ from .vision_transformer import _cfg, Mlp
from .registry import register_model from .registry import register_model
from .layers import DropPath, trunc_normal_, to_2tuple from .layers import DropPath, trunc_normal_, to_2tuple
from .cait import ClassAttn from .cait import ClassAttn
from .fx_features import register_leaf_module from .fx_features import register_notrace_module
def _cfg(url='', **kwargs): def _cfg(url='', **kwargs):
@ -98,7 +98,7 @@ default_cfgs = {
} }
@register_leaf_module # reason: FX can't symbolically trace torch.arange in forward method @register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method
class PositionalEncodingFourier(nn.Module): class PositionalEncodingFourier(nn.Module):
""" """
Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper. Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper.

Loading…
Cancel
Save