diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py
index f398502b..db5033be 100644
--- a/timm/models/layers/__init__.py
+++ b/timm/models/layers/__init__.py
@@ -21,7 +21,7 @@ from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_
 from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
 from .gather_excite import GatherExcite
 from .global_context import GlobalContext
-from .gnn_layers import Grapher
+from .gnn_layers import DyGraphConv2d
 from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
 from .inplace_abn import InplaceAbn
 from .linear import Linear
diff --git a/timm/models/layers/gnn_layers.py b/timm/models/layers/gnn_layers.py
index 4275cb3c..b2645271 100755
--- a/timm/models/layers/gnn_layers.py
+++ b/timm/models/layers/gnn_layers.py
@@ -4,9 +4,7 @@ import numpy as np
 import torch
 from torch import nn
 import torch.nn.functional as F
-from timm.models.fx_features import register_notrace_module
 from .drop import DropPath
-from .pos_embed import build_sincos2d_pos_embed
 
 
 def pairwise_distance(x, y):
@@ -215,64 +213,3 @@ class DyGraphConv2d(GraphConv2d):
         edge_index = self.dilated_knn_graph(x, y, relative_pos)
         x = super(DyGraphConv2d, self).forward(x, edge_index, y)
         return x.reshape(B, -1, H, W).contiguous()
-
-
-def get_2d_relative_pos_embed(embed_dim, grid_size):
-    """
-    relative position embedding
-    References: https://arxiv.org/abs/2009.13658
-    """
-    pos_embed = build_sincos2d_pos_embed([grid_size, grid_size], embed_dim)
-    relative_pos = 2 * torch.matmul(pos_embed, pos_embed.transpose(0, 1)) / pos_embed.shape[1]
-    return relative_pos
-
-
-@register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
-class Grapher(nn.Module):
-    """
-    Grapher module with graph convolution and fc layers
-    """
-    def __init__(self, in_channels, kernel_size=9, dilation=1, conv='mr', act_layer=nn.GELU, norm=None,
-                 bias=True, stochastic=False, epsilon=0.0, r=1, n=196, drop_path=0.0, relative_pos=False):
-        super(Grapher, self).__init__()
-        self.channels = in_channels
-        self.n = n
-        self.r = r
-        self.fc1 = nn.Sequential(
-            nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
-            nn.BatchNorm2d(in_channels),
-        )
-        self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, kernel_size, dilation, conv,
-                                        act_layer, norm, bias, stochastic, epsilon, r)
-        self.fc2 = nn.Sequential(
-            nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0),
-            nn.BatchNorm2d(in_channels),
-        )
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        if relative_pos:
-            relative_pos_tensor = get_2d_relative_pos_embed(in_channels,
-                                                            int(n**0.5)).unsqueeze(0).unsqueeze(1)
-            relative_pos_tensor = F.interpolate(
-                relative_pos_tensor, size=(n, n//(r*r)), mode='bicubic', align_corners=False)
-            # self.relative_pos = nn.Parameter(-relative_pos_tensor.squeeze(1))
-            self.register_buffer('relative_pos', -relative_pos_tensor.squeeze(1))
-        else:
-            self.relative_pos = None
-
-    def _get_relative_pos(self, relative_pos, H, W):
-        if relative_pos is None or H * W == self.n:
-            return relative_pos
-        else:
-            N = H * W
-            N_reduced = N // (self.r * self.r)
-            return F.interpolate(relative_pos.unsqueeze(0), size=(N, N_reduced), mode="bicubic").squeeze(0)
-
-    def forward(self, x):
-        _tmp = x
-        x = self.fc1(x)
-        B, C, H, W = x.shape
-        relative_pos = self._get_relative_pos(self.relative_pos, H, W)
-        x = self.graph_conv(x, relative_pos)
-        x = self.fc2(x)
-        x = self.drop_path(x) + _tmp
-        return x
diff --git a/timm/models/vision_gnn.py b/timm/models/vision_gnn.py
index 85a01a57..50f9baf4 100755
--- a/timm/models/vision_gnn.py
+++ b/timm/models/vision_gnn.py
@@ -11,9 +11,10 @@
 from torch.nn import Sequential as Seq
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained, build_model_with_cfg
-from .layers import DropPath, Grapher
+from .layers import DropPath, DyGraphConv2d
+from .layers.pos_embed import build_sincos2d_pos_embed
 from .registry import register_model
-from .fx_features import register_notrace_function
+from .fx_features import register_notrace_function, register_notrace_module
 
 
 def _cfg(url='', **kwargs):
@@ -48,6 +49,68 @@ default_cfgs = {
 }
 
 
+
+def get_2d_relative_pos_embed(embed_dim, grid_size):
+    """
+    relative position embedding
+    References: https://arxiv.org/abs/2009.13658
+    """
+    pos_embed = build_sincos2d_pos_embed([grid_size, grid_size], embed_dim)
+    relative_pos = 2 * torch.matmul(pos_embed, pos_embed.transpose(0, 1)) / pos_embed.shape[1]
+    return relative_pos
+
+
+@register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
+class Grapher(nn.Module):
+    """
+    Grapher module with graph convolution and fc layers
+    """
+    def __init__(self, in_channels, kernel_size=9, dilation=1, conv='mr', act_layer=nn.GELU, norm=None,
+                 bias=True, stochastic=False, epsilon=0.0, r=1, n=196, drop_path=0.0, relative_pos=False):
+        super(Grapher, self).__init__()
+        self.channels = in_channels
+        self.n = n
+        self.r = r
+        self.fc1 = nn.Sequential(
+            nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
+            nn.BatchNorm2d(in_channels),
+        )
+        self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, kernel_size, dilation, conv,
+                                        act_layer, norm, bias, stochastic, epsilon, r)
+        self.fc2 = nn.Sequential(
+            nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0),
+            nn.BatchNorm2d(in_channels),
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        if relative_pos:
+            relative_pos_tensor = get_2d_relative_pos_embed(in_channels,
+                                                            int(n**0.5)).unsqueeze(0).unsqueeze(1)
+            relative_pos_tensor = F.interpolate(
+                relative_pos_tensor, size=(n, n//(r*r)), mode='bicubic', align_corners=False)
+            self.register_buffer('relative_pos', -relative_pos_tensor.squeeze(1))
+        else:
+            self.relative_pos = None
+
+    @register_notrace_function  # reason: int argument is a Proxy
+    def _get_relative_pos(self, relative_pos, H, W):
+        if relative_pos is None or H * W == self.n:
+            return relative_pos
+        else:
+            N = H * W
+            N_reduced = N // (self.r * self.r)
+            return F.interpolate(relative_pos.unsqueeze(0), size=(N, N_reduced), mode="bicubic").squeeze(0)
+
+    def forward(self, x):
+        _tmp = x
+        x = self.fc1(x)
+        B, C, H, W = x.shape
+        relative_pos = self._get_relative_pos(self.relative_pos, H, W)
+        x = self.graph_conv(x, relative_pos)
+        x = self.fc2(x)
+        x = self.drop_path(x) + _tmp
+        return x
+
+
 class FFN(nn.Module):
     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop_path=0.0):
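
For reviewers, a minimal usage sketch of the relocated Grapher block. This is not part of the patch: the 14x14 grid and embed dim of 192 are assumed values chosen to match the n=196 default above, and the import path reflects where this diff places the class.

    import torch
    from timm.models.vision_gnn import Grapher  # new home of the block per this patch

    # 14x14 grid of patch "nodes" -> n = 196, matching the Grapher default
    x = torch.randn(1, 192, 14, 14)  # (B, C, H, W); C=192 is an assumed embed dim
    block = Grapher(in_channels=192, kernel_size=9, n=196, relative_pos=True)
    out = block(x)  # fc1 -> dynamic graph conv -> fc2, plus a residual add
    assert out.shape == x.shape  # channels and spatial shape are preserved

Since H * W equals n here, _get_relative_pos returns the precomputed buffer unchanged; at other resolutions it would be bicubically resized, which is the FX-untraceable branch the register_notrace decorators guard.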