diff --git a/timm/models/vision_gnn.py b/timm/models/vision_gnn.py index fa8b0d72..85a01a57 100755 --- a/timm/models/vision_gnn.py +++ b/timm/models/vision_gnn.py @@ -13,6 +13,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import load_pretrained, build_model_with_cfg from .layers import DropPath, Grapher from .registry import register_model +from .fx_features import register_notrace_function def _cfg(url='', **kwargs): @@ -172,6 +173,7 @@ class DeepGCN(torch.nn.Module): m.bias.data.zero_() m.bias.requires_grad = True + @register_notrace_function # reason: int argument is a Proxy def _get_pos_embed(self, pos_embed, H, W): if pos_embed is None or (H == pos_embed.size(-2) and W == pos_embed.size(-1)): return pos_embed