diff --git a/timm/models/layers/gnn_layers.py b/timm/models/layers/gnn_layers.py index ca453e73..a26220e4 100755 --- a/timm/models/layers/gnn_layers.py +++ b/timm/models/layers/gnn_layers.py @@ -6,6 +6,7 @@ from torch import nn import torch.nn.functional as F from .drop import DropPath from .pos_embed import build_sincos2d_pos_embed +from .fx_features import register_notrace_module def pairwise_distance(x, y): @@ -226,6 +227,7 @@ def get_2d_relative_pos_embed(embed_dim, grid_size): return relative_pos +@register_notrace_module # reason: FX can't symbolically trace control flow in forward method class Grapher(nn.Module): """ Grapher module with graph convolution and fc layers