# Layers for GNN model
# Reference: https://github.com/lightaime/deep_gcns_torch
import torch
from torch import nn
import torch.nn.functional as F

from .drop import DropPath
from .pos_embed import build_sincos2d_pos_embed


def pairwise_distance(x, y):
    """Compute the pairwise squared Euclidean distance between two point clouds."""
    with torch.no_grad():
        xy_inner = -2 * torch.matmul(x, y.transpose(2, 1))
        x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
        y_square = torch.sum(torch.mul(y, y), dim=-1, keepdim=True)
        return x_square + xy_inner + y_square.transpose(2, 1)


def dense_knn_matrix(x, y, k=16, relative_pos=None):
    """Get KNN indices based on the pairwise distance."""
    with torch.no_grad():
        x = x.transpose(2, 1).squeeze(-1)  # (B, C, N, 1) -> (B, N, C)
        y = y.transpose(2, 1).squeeze(-1)
        batch_size, n_points, n_dims = x.shape
        dist = pairwise_distance(x.detach(), y.detach())
        if relative_pos is not None:
            dist += relative_pos
        _, nn_idx = torch.topk(-dist, k=k)  # indices of the k nearest neighbors in y
        center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1)
    return torch.stack((nn_idx, center_idx), dim=0)


class DenseDilated(nn.Module):
    """
    Find dilated neighbors from the neighbor list
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super(DenseDilated, self).__init__()
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon
        self.k = k

    def forward(self, edge_index):
        if self.stochastic:
            if torch.rand(1) < self.epsilon and self.training:
                num = self.k * self.dilation
                randnum = torch.randperm(num)[:self.k]
                edge_index = edge_index[:, :, :, randnum]
            else:
                edge_index = edge_index[:, :, :, ::self.dilation]
        else:
            edge_index = edge_index[:, :, :, ::self.dilation]
        return edge_index


class DenseDilatedKnnGraph(nn.Module):
    """
    Find the neighbors' indices based on dilated knn
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super(DenseDilatedKnnGraph, self).__init__()
        self.dilation = dilation
        self.k = k
        self._dilated = DenseDilated(k, dilation, stochastic, epsilon)

    def forward(self, x, y=None, relative_pos=None):
        x = F.normalize(x, p=2.0, dim=1)
        if y is not None:
            y = F.normalize(y, p=2.0, dim=1)
            edge_index = dense_knn_matrix(x, y, self.k * self.dilation, relative_pos)
        else:
            edge_index = dense_knn_matrix(x, x, self.k * self.dilation, relative_pos)
        return self._dilated(edge_index)


def batched_index_select(x, idx):
    """Fetch neighbor features from a given neighbor index tensor."""
    batch_size, num_dims, num_vertices_reduced = x.shape[:3]
    _, num_vertices, k = idx.shape
    idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices_reduced
    idx = idx + idx_base
    idx = idx.contiguous().view(-1)

    x = x.transpose(2, 1)
    feature = x.contiguous().view(batch_size * num_vertices_reduced, -1)[idx, :]
    feature = feature.view(batch_size, num_vertices, k, num_dims).permute(0, 3, 1, 2).contiguous()
    return feature


def norm_layer(norm, nc):
    """2D normalization layer; returns an identity module when no norm is requested."""
    if norm is None or norm.lower() == 'none':
        return nn.Identity()
    norm = norm.lower()
    if norm == 'batch':
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm == 'instance':
        layer = nn.InstanceNorm2d(nc, affine=False)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm)
    return layer


class MRConv2d(nn.Module):
    """
    Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type
    """
    def __init__(self, in_channels, out_channels, act_layer=nn.GELU, norm=None, bias=True):
        super(MRConv2d, self).__init__()
        # self.nn = BasicConv([in_channels*2, out_channels], act_layer, norm, bias)
        self.nn = nn.Sequential(
            nn.Conv2d(in_channels * 2, out_channels, 1, bias=bias, groups=4),
            norm_layer(norm, out_channels),
            act_layer(),
        )
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                if m.weight is not None:
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()

    def forward(self, x, edge_index, y=None):
        x_i = batched_index_select(x, edge_index[1])
        if y is not None:
            x_j = batched_index_select(y, edge_index[0])
        else:
            x_j = batched_index_select(x, edge_index[0])
        # max-relative aggregation: keep the largest relative feature over the k neighbors
        x_j, _ = torch.max(x_j - x_i, -1, keepdim=True)
        b, c, n, _ = x.shape
        x = torch.cat([x.unsqueeze(2), x_j.unsqueeze(2)], dim=2).reshape(b, 2 * c, n, _)
        return self.nn(x)


class EdgeConv2d(nn.Module):
    """
    Edge convolution layer (with activation, batch normalization) for dense data type
    """
    def __init__(self, in_channels, out_channels, act_layer=nn.GELU, norm=None, bias=True):
        super(EdgeConv2d, self).__init__()
        self.nn = nn.Sequential(
            nn.Conv2d(in_channels * 2, out_channels, 1, bias=bias, groups=4),
            norm_layer(norm, out_channels),
            act_layer(),
        )
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                if m.weight is not None:
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()

    def forward(self, x, edge_index, y=None):
        x_i = batched_index_select(x, edge_index[1])
        if y is not None:
            x_j = batched_index_select(y, edge_index[0])
        else:
            x_j = batched_index_select(x, edge_index[0])
        max_value, _ = torch.max(self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True)
        return max_value


class GraphConv2d(nn.Module):
    """
    Static graph convolution layer
    """
    def __init__(self, in_channels, out_channels, conv='mr', act_layer=nn.GELU, norm=None, bias=True):
        super(GraphConv2d, self).__init__()
        if conv == 'edge':
            self.gconv = EdgeConv2d(in_channels, out_channels, act_layer, norm, bias)
        elif conv == 'mr':
            self.gconv = MRConv2d(in_channels, out_channels, act_layer, norm, bias)
        else:
            raise NotImplementedError('conv:{} is not supported'.format(conv))

    def forward(self, x, edge_index, y=None):
        return self.gconv(x, edge_index, y)


class DyGraphConv2d(GraphConv2d):
    """
    Dynamic graph convolution layer
    """
    def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='mr',
                 act_layer=nn.GELU, norm=None, bias=True, stochastic=False, epsilon=0.0, r=1):
        super(DyGraphConv2d, self).__init__(in_channels, out_channels, conv, act_layer, norm, bias)
        self.k = kernel_size
        self.d = dilation
        self.r = r
        self.dilated_knn_graph = DenseDilatedKnnGraph(kernel_size, dilation, stochastic, epsilon)

    def forward(self, x, relative_pos=None):
        B, C, H, W = x.shape
        y = None
        if self.r > 1:
            # reduced node set for the neighbor search when r > 1
            y = F.avg_pool2d(x, self.r, self.r)
            y = y.reshape(B, C, -1, 1).contiguous()
        x = x.reshape(B, C, -1, 1).contiguous()
        edge_index = self.dilated_knn_graph(x, y, relative_pos)
        x = super(DyGraphConv2d, self).forward(x, edge_index, y)
        return x.reshape(B, -1, H, W).contiguous()


def get_2d_relative_pos_embed(embed_dim, grid_size):
    """
    Relative position embedding
    References: https://arxiv.org/abs/2009.13658
    """
    pos_embed = build_sincos2d_pos_embed([grid_size, grid_size], embed_dim)
    relative_pos = 2 * torch.matmul(pos_embed, pos_embed.transpose(0, 1)) / pos_embed.shape[1]
    return relative_pos


class Grapher(nn.Module):
    """
    Grapher module with graph convolution and fc layers
    """
    def __init__(self, in_channels, kernel_size=9, dilation=1, conv='mr', act_layer=nn.GELU,
                 norm=None, bias=True, stochastic=False, epsilon=0.0, r=1, n=196,
                 drop_path=0.0, relative_pos=False):
        super(Grapher, self).__init__()
        self.channels = in_channels
        self.n = n
        self.r = r
        self.fc1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels),
        )
        self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, kernel_size, dilation, conv,
                                        act_layer, norm, bias, stochastic, epsilon, r)
        self.fc2 = nn.Sequential(
            nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels),
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        if relative_pos:
            relative_pos_tensor = get_2d_relative_pos_embed(in_channels, int(n ** 0.5)).unsqueeze(0).unsqueeze(1)
            relative_pos_tensor = F.interpolate(
                relative_pos_tensor, size=(n, n // (r * r)), mode='bicubic', align_corners=False)
            # self.relative_pos = nn.Parameter(-relative_pos_tensor.squeeze(1))
            self.register_buffer('relative_pos', -relative_pos_tensor.squeeze(1))
        else:
            self.relative_pos = None

    def _get_relative_pos(self, relative_pos, H, W):
        # rescale the precomputed relative position bias when the input resolution differs from `n`
        if relative_pos is None or H * W == self.n:
            return relative_pos
        else:
            N = H * W
            N_reduced = N // (self.r * self.r)
            return F.interpolate(relative_pos.unsqueeze(0), size=(N, N_reduced), mode="bicubic").squeeze(0)

    def forward(self, x):
        _tmp = x  # residual
        x = self.fc1(x)
        B, C, H, W = x.shape
        relative_pos = self._get_relative_pos(self.relative_pos, H, W)
        x = self.graph_conv(x, relative_pos)
        x = self.fc2(x)
        x = self.drop_path(x) + _tmp
        return x
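

# ---------------------------------------------------------------------------
# Usage sketch (not part of the model definition). A minimal example of how a
# Grapher block might be applied to a dense feature map; the shapes below are
# illustrative assumptions (channels divisible by the grouped-conv count, and
# H * W equal to `n` when `relative_pos=True`), and the relative-position path
# assumes `build_sincos2d_pos_embed` returns a (grid_size * grid_size, embed_dim)
# table.
#
#   feat = torch.randn(2, 96, 14, 14)                     # (B, C, H, W)
#   block = Grapher(in_channels=96, kernel_size=9, conv='mr', norm='batch',
#                   r=1, n=14 * 14, relative_pos=True)
#   out = block(feat)                                      # same shape as feat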