From e051dce35451451b8ac7eee7b8abab38325a26b0 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 12 Aug 2021 15:31:24 +0100 Subject: [PATCH] Make all models FX traceable --- timm/models/layers/bottleneck_attn.py | 1 + timm/models/layers/halo_attn.py | 1 - timm/models/layers/non_local_attn.py | 1 + timm/models/tnt.py | 1 + 4 files changed, 3 insertions(+), 1 deletion(-) diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py index c56c5821..305f9de3 100644 --- a/timm/models/layers/bottleneck_attn.py +++ b/timm/models/layers/bottleneck_attn.py @@ -22,6 +22,7 @@ import torch.nn.functional as F from .helpers import to_2tuple, make_divisible from .weight_init import trunc_normal_ +from timm.models.fx_helpers import fx_and def rel_logits_1d(q, rel_k, permute_mask: List[int]): diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index babfcb06..ec93474f 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -24,7 +24,6 @@ import torch.nn.functional as F from .helpers import make_divisible from .weight_init import trunc_normal_ -from timm.models.fx_helpers import def rel_logits_1d(q, rel_k, permute_mask: List[int]): diff --git a/timm/models/layers/non_local_attn.py b/timm/models/layers/non_local_attn.py index f933ece2..5f83005c 100644 --- a/timm/models/layers/non_local_attn.py +++ b/timm/models/layers/non_local_attn.py @@ -10,6 +10,7 @@ from torch.nn import functional as F from .conv_bn_act import ConvBnAct from .helpers import make_divisible +from timm.models.fx_helpers import fx_and class NonLocalAttn(nn.Module): diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 92108fe5..298808c3 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -12,6 +12,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.helpers import build_model_with_cfg +from timm.models.fx_helpers import fx_and from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers.helpers import to_2tuple from timm.models.registry import register_model