@@ -11,9 +11,10 @@ from torch.nn import Sequential as Seq
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained, build_model_with_cfg
-from .layers import DropPath, Grapher
+from .layers import DropPath, DyGraphConv2d
+from .layers.pos_embed import build_sincos2d_pos_embed
 from .registry import register_model
-from .fx_features import register_notrace_function
+from .fx_features import register_notrace_function, register_notrace_module


 def _cfg(url='', **kwargs):
@@ -48,6 +49,68 @@ default_cfgs = {
 }
+
+
+def get_2d_relative_pos_embed(embed_dim, grid_size):
+    """
+    relative position embedding
+    References: https://arxiv.org/abs/2009.13658
+    """
+    pos_embed = build_sincos2d_pos_embed([grid_size, grid_size], embed_dim)
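+    # Pairwise dot products scaled by 2 / embed_dim. Since each sin-cos embedding has
+    # squared norm embed_dim / 2, this is effectively a cosine-similarity matrix
+    # between all grid positions.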
+    relative_pos = 2 * torch.matmul(pos_embed, pos_embed.transpose(0, 1)) / pos_embed.shape[1]
+    return relative_pos
+
+
+@register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
+class Grapher(nn.Module):
+    """
+    Grapher module with graph convolution and fc layers
+    """
+    def __init__(self, in_channels, kernel_size=9, dilation=1, conv='mr', act_layer=nn.GELU, norm=None,
+                 bias=True, stochastic=False, epsilon=0.0, r=1, n=196, drop_path=0.0, relative_pos=False):
+        super(Grapher, self).__init__()
+        self.channels = in_channels
+        self.n = n
+        self.r = r
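+        # fc1 and fc2 are 1x1-conv "linear" layers (with BN) applied before and after
+        # the graph convolution; fc2 also halves the channels doubled by graph_conv.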
+        self.fc1 = nn.Sequential(
+            nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
+            nn.BatchNorm2d(in_channels),
+        )
+        self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, kernel_size, dilation, conv,
+                                        act_layer, norm, bias, stochastic, epsilon, r)
+        self.fc2 = nn.Sequential(
+            nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0),
+            nn.BatchNorm2d(in_channels),
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
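+        # Optionally precompute a fixed positional bias. r appears to be the spatial
+        # reduction used inside DyGraphConv2d, so the bias covers the n query tokens
+        # against n // (r * r) reduced positions.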
+        if relative_pos:
+            relative_pos_tensor = get_2d_relative_pos_embed(in_channels,
+                                                            int(n**0.5)).unsqueeze(0).unsqueeze(1)
+            relative_pos_tensor = F.interpolate(
+                relative_pos_tensor, size=(n, n//(r*r)), mode='bicubic', align_corners=False)
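+            # Stored negated; presumably the downstream KNN graph builder adds this to
+            # pairwise feature distances, so spatially similar tokens look closer.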
+            self.register_buffer('relative_pos', -relative_pos_tensor.squeeze(1))
+        else:
+            self.relative_pos = None
+
+    @register_notrace_function  # reason: int argument is a Proxy
+    def _get_relative_pos(self, relative_pos, H, W):
+        if relative_pos is None or H * W == self.n:
+            return relative_pos
+        else:
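+            # Input resolution differs from the n the bias was built for, so rescale
+            # the bias to the current token count on the fly.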
+            N = H * W
+            N_reduced = N // (self.r * self.r)
+            return F.interpolate(relative_pos.unsqueeze(0), size=(N, N_reduced), mode="bicubic").squeeze(0)
+
+    def forward(self, x):
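+        # Residual connection around the fc1 -> graph_conv -> fc2 pipeline.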
+        _tmp = x
+        x = self.fc1(x)
+        B, C, H, W = x.shape
+        relative_pos = self._get_relative_pos(self.relative_pos, H, W)
+        x = self.graph_conv(x, relative_pos)
+        x = self.fc2(x)
+        x = self.drop_path(x) + _tmp
+        return x
+
+
+class FFN(nn.Module):
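+    # Feed-forward block paired with Grapher in each ViG building block
+    # (definition continues past this section).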
+    def __init__(self, in_features, hidden_features=None, out_features=None,
+                 act_layer=nn.GELU, drop_path=0.0):