diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index 9a76e041..310cc465 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -24,7 +24,7 @@ import torch.nn.functional as F from torch.fx.graph_module import _copy_attr from .features import _get_feature_info -from .fx_helpers import fx_and, fx_float_to_int +from .fx_helpers import fx_float_to_int # Layers we went to treat as leaf modules for FeatureGraphNet from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame @@ -55,7 +55,7 @@ def register_leaf_module(module: nn.Module): # These functions will not be traced through -_autowrap_functions=(fx_float_to_int, fx_and) +_autowrap_functions=(fx_float_to_int,) class TimmTracer(fx.Tracer): diff --git a/timm/models/fx_helpers.py b/timm/models/fx_helpers.py index 1955d5b1..878ba381 100644 --- a/timm/models/fx_helpers.py +++ b/timm/models/fx_helpers.py @@ -1,14 +1,4 @@ - -def fx_and(a: bool, b: bool) -> bool: - """ - Symbolic tracing helper to substitute for normal usage of `* and *` within `torch._assert`. - Hint: Symbolic tracing does not support control flow but since an `assert` is either a dead-end or not, this hack - is okay. - """ - return (a and b) - - def fx_float_to_int(x: float) -> int: """ Symbolic tracing helper to substitute for inbuilt `int`. diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py index 305f9de3..c56c5821 100644 --- a/timm/models/layers/bottleneck_attn.py +++ b/timm/models/layers/bottleneck_attn.py @@ -22,7 +22,6 @@ import torch.nn.functional as F from .helpers import to_2tuple, make_divisible from .weight_init import trunc_normal_ -from timm.models.fx_helpers import fx_and def rel_logits_1d(q, rel_k, permute_mask: List[int]): diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index 0bd611b1..babfcb06 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -24,7 +24,7 @@ import torch.nn.functional as F from .helpers import make_divisible from .weight_init import trunc_normal_ -from timm.models.fx_helpers import fx_and +from timm.models.fx_helpers import def rel_logits_1d(q, rel_k, permute_mask: List[int]): diff --git a/timm/models/layers/non_local_attn.py b/timm/models/layers/non_local_attn.py index 517e28a8..f933ece2 100644 --- a/timm/models/layers/non_local_attn.py +++ b/timm/models/layers/non_local_attn.py @@ -10,7 +10,6 @@ from torch.nn import functional as F from .conv_bn_act import ConvBnAct from .helpers import make_divisible -from timm.models.fx_helpers import fx_and class NonLocalAttn(nn.Module): @@ -96,7 +95,8 @@ class BilinearAttnTransform(nn.Module): return x def forward(self, x): - torch._assert(fx_and(x.shape[-1] % self.block_size == 0, x.shape[-2] % self.block_size == 0), '') + torch._assert(x.shape[-1] % self.block_size == 0, '') + torch._assert(x.shape[-2] % self.block_size == 0, '') B, C, H, W = x.shape out = self.conv1(x) rp = F.adaptive_max_pool2d(out, (self.block_size, 1)) diff --git a/timm/models/layers/patch_embed.py b/timm/models/layers/patch_embed.py index 157bc250..6a7facef 100644 --- a/timm/models/layers/patch_embed.py +++ b/timm/models/layers/patch_embed.py @@ -9,11 +9,7 @@ Hacked together by / Copyright 2020 Ross Wightman from torch import nn as nn from .helpers import to_2tuple -<<<<<<< HEAD from .trace_utils import _assert -======= -from timm.models.fx_helpers import fx_and ->>>>>>> Make all models FX traceable class PatchEmbed(nn.Module): diff --git a/timm/models/tnt.py b/timm/models/tnt.py index f9510487..92108fe5 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -12,7 +12,6 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.helpers import build_model_with_cfg -from timm.models.fx_helpers import fx_and from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers.helpers import to_2tuple from timm.models.registry import register_model @@ -138,7 +137,9 @@ class PixelEmbed(nn.Module): def forward(self, x, pixel_pos): B, C, H, W = x.shape - torch._assert(fx_and(H == self.img_size[0], W == self.img_size[1]), + torch._assert(H == self.img_size[0], + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") + torch._assert(W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") x = self.proj(x) x = self.unfold(x)