|
|
|
@ -6,6 +6,7 @@ from torch import nn
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
from .drop import DropPath
|
|
|
|
|
from .pos_embed import build_sincos2d_pos_embed
|
|
|
|
|
from .fx_features import register_notrace_module
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pairwise_distance(x, y):
|
|
|
|
@ -226,6 +227,7 @@ def get_2d_relative_pos_embed(embed_dim, grid_size):
|
|
|
|
|
return relative_pos
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
|
|
|
|
|
class Grapher(nn.Module):
|
|
|
|
|
"""
|
|
|
|
|
Grapher module with graph convolution and fc layers
|
|
|
|
|