You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
216 lines
7.8 KiB
216 lines
7.8 KiB
# Layers for GNN model
|
|
# Reference: https://github.com/lightaime/deep_gcns_torch
|
|
import numpy as np
|
|
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
from .drop import DropPath
|
|
|
|
|
|
def pairwise_distance(x, y):
    """Return squared Euclidean distances between two batched point sets.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2<a, b> + ||b||^2, so no square
    root is taken. Everything runs under ``torch.no_grad()``, so the result
    never carries a gradient.

    Args:
        x: Query points, expected shape (B, N, D).
        y: Candidate points, expected shape (B, M, D).

    Returns:
        Tensor of shape (B, N, M); entry [b, i, j] is ||x[b, i] - y[b, j]||^2.
        Because of floating-point cancellation in the expansion, entries for
        (near-)identical points can come out slightly negative.
    """
    with torch.no_grad():
        sq_norm_x = (x * x).sum(dim=-1, keepdim=True)        # (B, N, 1)
        sq_norm_y = (y * y).sum(dim=-1, keepdim=True)        # (B, M, 1)
        cross_term = -2 * torch.matmul(x, y.transpose(2, 1))  # (B, N, M)
        return sq_norm_x + cross_term + sq_norm_y.transpose(2, 1)
|
|
|
|
|
|
def dense_knn_matrix(x, y, k=16, relative_pos=None):
    """Find, for every point of ``x``, its ``k`` nearest points in ``y``.

    Distances are squared Euclidean, computed by ``pairwise_distance`` on
    detached inputs inside ``torch.no_grad()``.

    Args:
        x: Query features. Dims 1 and 2 are swapped and a trailing singleton
            dim is squeezed, so the expected shape is (B, C, N, 1).
        y: Candidate features, expected shape (B, C, M, 1).
        k: Neighbours per query point (``torch.topk`` requires k <= M).
        relative_pos: Optional bias added in place to the (B, N, M) distance
            matrix before ranking; it must broadcast to that shape.

    Returns:
        Integer tensor of shape (2, B, N, k). Slice [0] holds neighbour
        indices into ``y`` ordered closest-first; slice [1] holds the index of
        the query point itself, repeated k times.
    """
    with torch.no_grad():
        queries = x.transpose(2, 1).squeeze(-1)      # (B, N, C)
        candidates = y.transpose(2, 1).squeeze(-1)   # (B, M, C)
        batch_size, n_points, _ = queries.shape

        dist = pairwise_distance(queries.detach(), candidates.detach())
        if relative_pos is not None:
            dist += relative_pos

        # The largest values of -dist are the smallest distances.
        _, neighbour_idx = torch.topk(-dist, k=k)
        centre_idx = (
            torch.arange(n_points, device=queries.device)
            .view(1, n_points, 1)
            .repeat(batch_size, 1, k)
        )
    return torch.stack((neighbour_idx, centre_idx), dim=0)
|
|
|
|
|
|
class DenseDilated(nn.Module):
    """Keep ``k`` neighbours out of a list of ``k * dilation`` candidates.

    By default every ``dilation``-th column of the neighbour list is kept.
    When ``stochastic`` is enabled and the module is in training mode, then
    with probability ``epsilon`` ``k`` distinct columns are instead sampled at
    random from the first ``k * dilation`` columns.

    Args:
        k: Number of neighbours to keep.
        dilation: Stride applied along the neighbour axis.
        stochastic: Enable random neighbour sampling during training.
        epsilon: Probability of random sampling on a given training call.
    """

    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super().__init__()
        self.k = k
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon

    def forward(self, edge_index):
        """Subsample the last (neighbour) axis of ``edge_index``."""
        if self.stochastic:
            # torch.rand is drawn before checking self.training, so the global
            # RNG advances on every stochastic call, even in eval mode.
            use_random = torch.rand(1) < self.epsilon and self.training
            if use_random:
                pool_size = self.k * self.dilation
                chosen = torch.randperm(pool_size)[:self.k]
                return edge_index[:, :, :, chosen]
        return edge_index[:, :, :, ::self.dilation]
|
|
|
|
|
|
class DenseDilatedKnnGraph(nn.Module):
    """Build a dilated k-NN edge index from dense feature maps.

    Features are L2-normalised along dim 1 before ``dense_knn_matrix`` is
    asked for ``k * dilation`` neighbours. ``DenseDilated`` then thins that
    list down to ``k``. With unit-norm features, ranking by squared Euclidean
    distance matches ranking by cosine similarity, unless ``relative_pos``
    adds a bias.

    Args:
        k: Neighbours per vertex after dilation.
        dilation: Dilation factor for the neighbour list.
        stochastic: Forwarded to ``DenseDilated``.
        epsilon: Forwarded to ``DenseDilated``.
    """

    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super().__init__()
        self.dilation = dilation
        self.k = k
        self._dilated = DenseDilated(k, dilation, stochastic, epsilon)

    def forward(self, x, y=None, relative_pos=None):
        """Return the (2, B, N, k) edge index of ``x`` against ``y`` (or itself)."""
        x = F.normalize(x, p=2.0, dim=1)
        # Without a separate candidate set, the graph is built within x itself.
        y = x if y is None else F.normalize(y, p=2.0, dim=1)
        edge_index = dense_knn_matrix(x, y, self.k * self.dilation, relative_pos)
        return self._dilated(edge_index)
|
|
|
|
|
|
def batched_index_select(x, idx):
    """Gather vertex features for every entry of a per-batch neighbour index.

    Args:
        x: Features of shape (B, C, N_r) or (B, C, N_r, 1). Any trailing dim is
            folded into the channel axis by ``view``, so it must be 1 for the
            final reshape to succeed.
        idx: Integer indices of shape (B, N, k) into the N_r vertices of ``x``.

    Returns:
        Contiguous tensor of shape (B, C, N, k) with
        ``out[b, c, n, j] == x[b, c, idx[b, n, j]]``.
    """
    batch_size, num_dims, num_vertices_reduced = x.shape[:3]
    _, num_vertices, k = idx.shape

    # Shift each batch's indices into its own slice of the flattened vertex table.
    batch_offset = torch.arange(batch_size, device=idx.device).view(-1, 1, 1) * num_vertices_reduced
    flat_idx = (idx + batch_offset).contiguous().view(-1)

    vertex_table = x.transpose(2, 1).contiguous().view(batch_size * num_vertices_reduced, -1)
    gathered = vertex_table[flat_idx, :]
    gathered = gathered.view(batch_size, num_vertices, k, num_dims)
    return gathered.permute(0, 3, 1, 2).contiguous()
|
|
|
|
|
|
def norm_layer(norm, nc):
    """Build a 2-D normalisation layer for ``nc`` channels.

    Args:
        norm: Case-insensitive layer name. ``'batch'`` gives an affine
            ``nn.BatchNorm2d``; ``'instance'`` gives a non-affine
            ``nn.InstanceNorm2d``. ``None`` or ``'none'`` gives an
            ``nn.Identity`` pass-through. The graph-conv layers in this module
            default to ``norm=None``, and previously that crashed on
            ``None.lower()``.
        nc: Number of channels to normalise.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        NotImplementedError: If ``norm`` names an unsupported layer.
    """
    if norm is None or norm.lower() == 'none':
        return nn.Identity()
    norm = norm.lower()
    if norm == 'batch':
        return nn.BatchNorm2d(nc, affine=True)
    if norm == 'instance':
        return nn.InstanceNorm2d(nc, affine=False)
    raise NotImplementedError('normalization layer [%s] is not found' % norm)
|
|
|
|
|
|
class MRConv2d(nn.Module):
    """
    Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type.

    For every centre vertex, the feature-wise maximum of ``x_j - x_i`` over its
    neighbours is taken. That maximum is interleaved channel-by-channel with
    ``x`` and passed through a 1x1 grouped convolution (groups=4), the layer
    returned by ``norm_layer``, and an activation.

    Args:
        in_channels: Channels of ``x`` (and of ``y`` when given). Because of
            ``groups=4``, ``2 * in_channels`` must be divisible by 4.
        out_channels: Output channels; must also be divisible by 4.
        act_layer: Activation class, instantiated with no arguments.
        norm: Normalisation name forwarded to ``norm_layer``.
        bias: Whether the convolution has a bias term.
    """

    def __init__(self, in_channels, out_channels, act_layer=nn.GELU, norm=None, bias=True):
        super(MRConv2d, self).__init__()
        self.nn = nn.Sequential(
            nn.Conv2d(in_channels * 2, out_channels, 1, bias=bias, groups=4),
            norm_layer(norm, out_channels),
            act_layer(),
        )
        self.init_weights()

    def init_weights(self):
        """Kaiming-init convolutions; reset normalisation scale to 1 and shift to 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                # norm_layer builds InstanceNorm2d with affine=False, so its weight
                # and bias are None; accessing them unguarded raised AttributeError.
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x, edge_index, y=None):
        """Apply max-relative graph convolution.

        Args:
            x: Centre features, shape (B, C, N, 1). The trailing dim must be 1
                so that ``x`` can be concatenated with the max-reduced term.
            edge_index: (2, B, N, k) index as built by ``dense_knn_matrix``.
                [0] indexes neighbours, in ``y`` when given and otherwise in
                ``x``; [1] indexes centres in ``x``.
            y: Optional separate neighbour feature set, shape (B, C, M, 1).

        Returns:
            Output of ``self.nn``, shape (B, out_channels, N, 1).
        """
        x_i = batched_index_select(x, edge_index[1])
        x_j = batched_index_select(x if y is None else y, edge_index[0])
        rel_max, _ = torch.max(x_j - x_i, -1, keepdim=True)
        b, c, n, w = x.shape
        # Stacking on a new dim 2 and folding it into channels interleaves them:
        # channel 2*i is x[:, i] and channel 2*i + 1 is rel_max[:, i], so every
        # conv group sees both the feature and its max-relative term.
        x = torch.cat([x.unsqueeze(2), rel_max.unsqueeze(2)], dim=2).reshape(b, 2 * c, n, w)
        return self.nn(x)
|
|
|
|
|
|
class EdgeConv2d(nn.Module):
    """
    Edge convolution layer (with activation, batch normalization) for dense data type.

    For every edge (centre i, neighbour j), the features ``[x_i, x_j - x_i]`` are
    concatenated along channels and passed through a 1x1 grouped convolution
    (groups=4), the layer returned by ``norm_layer``, and an activation. The
    result is then max-reduced over the neighbour axis.

    Args:
        in_channels: Channels of ``x`` (and of ``y`` when given). Because of
            ``groups=4``, ``2 * in_channels`` must be divisible by 4.
        out_channels: Output channels; must also be divisible by 4.
        act_layer: Activation class, instantiated with no arguments.
        norm: Normalisation name forwarded to ``norm_layer``.
        bias: Whether the convolution has a bias term.
    """

    def __init__(self, in_channels, out_channels, act_layer=nn.GELU, norm=None, bias=True):
        super(EdgeConv2d, self).__init__()
        self.nn = nn.Sequential(
            nn.Conv2d(in_channels * 2, out_channels, 1, bias=bias, groups=4),
            norm_layer(norm, out_channels),
            act_layer(),
        )
        self.init_weights()

    def init_weights(self):
        """Kaiming-init convolutions; reset normalisation scale to 1 and shift to 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                # norm_layer builds InstanceNorm2d with affine=False, so its weight
                # and bias are None; accessing them unguarded raised AttributeError.
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x, edge_index, y=None):
        """Apply edge convolution and max-pool over neighbours.

        Args:
            x: Centre features, expected shape (B, C, N, 1).
            edge_index: (2, B, N, k) index as built by ``dense_knn_matrix``.
                [0] indexes neighbours, in ``y`` when given and otherwise in
                ``x``; [1] indexes centres in ``x``.
            y: Optional separate neighbour feature set.

        Returns:
            Tensor of shape (B, out_channels, N, 1).
        """
        x_i = batched_index_select(x, edge_index[1])
        x_j = batched_index_select(x if y is None else y, edge_index[0])
        # NOTE(review): this is a plain (non-interleaved) concat, so with groups=4
        # the first half of the groups sees only x_i and the second half only
        # x_j - x_i. MRConv2d interleaves instead. Kept as-is to preserve weights.
        edge_features = torch.cat([x_i, x_j - x_i], dim=1)
        max_value, _ = torch.max(self.nn(edge_features), -1, keepdim=True)
        return max_value
|
|
|
|
|
|
class GraphConv2d(nn.Module):
    """Static graph convolution: apply a fixed edge index with a chosen operator.

    Args:
        in_channels: Input feature channels.
        out_channels: Output feature channels.
        conv: ``'edge'`` selects ``EdgeConv2d``; ``'mr'`` selects ``MRConv2d``.
        act_layer: Activation class forwarded to the operator.
        norm: Normalisation name forwarded to the operator.
        bias: Whether the operator's convolution has a bias term.

    Raises:
        NotImplementedError: For any other ``conv`` value.
    """

    def __init__(self, in_channels, out_channels, conv='mr', act_layer=nn.GELU, norm=None, bias=True):
        super().__init__()
        if conv not in ('edge', 'mr'):
            raise NotImplementedError('conv:{} is not supported'.format(conv))
        operator_cls = EdgeConv2d if conv == 'edge' else MRConv2d
        self.gconv = operator_cls(in_channels, out_channels, act_layer, norm, bias)

    def forward(self, x, edge_index, y=None):
        """Delegate to the selected graph-convolution operator."""
        return self.gconv(x, edge_index, y)
|
|
|
|
|
|
class DyGraphConv2d(GraphConv2d):
    """Dynamic graph convolution: rebuild a k-NN graph from the input on every call.

    The (B, C, H, W) feature map is flattened to (B, C, H*W, 1) vertices. A
    dilated k-NN graph is built with ``DenseDilatedKnnGraph``, and the
    ``GraphConv2d`` operator is applied. The output is reshaped back to H x W.

    Args:
        in_channels: Input feature channels.
        out_channels: Output feature channels.
        kernel_size: Number of neighbours k per vertex.
        dilation: Neighbour-list dilation.
        conv: Graph operator name, ``'edge'`` or ``'mr'``.
        act_layer: Activation class.
        norm: Normalisation name.
        bias: Whether the convolution has a bias term.
        stochastic: Random neighbour sampling during training.
        epsilon: Probability of stochastic sampling.
        r: When > 1, neighbour candidates come from an r x r average-pooled
            copy of the input rather than from the input itself.
    """

    def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='mr', act_layer=nn.GELU,
                 norm=None, bias=True, stochastic=False, epsilon=0.0, r=1):
        super().__init__(in_channels, out_channels, conv, act_layer, norm, bias)
        self.k = kernel_size
        self.d = dilation
        self.r = r
        self.dilated_knn_graph = DenseDilatedKnnGraph(kernel_size, dilation, stochastic, epsilon)

    def forward(self, x, relative_pos=None):
        """Build the graph from ``x`` and convolve over it.

        ``relative_pos`` is forwarded to the k-NN builder, where it is added to
        the distance matrix. Its expected shape (roughly (N, N / r**2)) is not
        fixed here; TODO confirm against callers.
        """
        B, C, H, W = x.shape
        if self.r > 1:
            # Coarser candidate set: average-pool by r, then flatten to vertices.
            y = F.avg_pool2d(x, self.r, self.r).reshape(B, C, -1, 1).contiguous()
        else:
            y = None
        vertices = x.reshape(B, C, -1, 1).contiguous()
        edge_index = self.dilated_knn_graph(vertices, y, relative_pos)
        out = super().forward(vertices, edge_index, y)
        return out.reshape(B, -1, H, W).contiguous()
|