fix test errors

pull/1578/head
iamhankai 3 years ago
parent df75883615
commit 7e71058c88

@ -21,7 +21,7 @@ from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
from .gather_excite import GatherExcite from .gather_excite import GatherExcite
from .global_context import GlobalContext from .global_context import GlobalContext
from .gnn_layers import Grapher from .gnn_layers import DyGraphConv2d
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
from .inplace_abn import InplaceAbn from .inplace_abn import InplaceAbn
from .linear import Linear from .linear import Linear

@ -4,9 +4,7 @@ import numpy as np
import torch import torch
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from timm.models.fx_features import register_notrace_module
from .drop import DropPath from .drop import DropPath
from .pos_embed import build_sincos2d_pos_embed
def pairwise_distance(x, y): def pairwise_distance(x, y):
@ -215,64 +213,3 @@ class DyGraphConv2d(GraphConv2d):
edge_index = self.dilated_knn_graph(x, y, relative_pos) edge_index = self.dilated_knn_graph(x, y, relative_pos)
x = super(DyGraphConv2d, self).forward(x, edge_index, y) x = super(DyGraphConv2d, self).forward(x, edge_index, y)
return x.reshape(B, -1, H, W).contiguous() return x.reshape(B, -1, H, W).contiguous()
def get_2d_relative_pos_embed(embed_dim, grid_size):
"""
relative position embedding
References: https://arxiv.org/abs/2009.13658
"""
pos_embed = build_sincos2d_pos_embed([grid_size, grid_size], embed_dim)
relative_pos = 2 * torch.matmul(pos_embed, pos_embed.transpose(0, 1)) / pos_embed.shape[1]
return relative_pos
@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
class Grapher(nn.Module):
"""
Grapher module with graph convolution and fc layers
"""
def __init__(self, in_channels, kernel_size=9, dilation=1, conv='mr', act_layer=nn.GELU, norm=None,
bias=True, stochastic=False, epsilon=0.0, r=1, n=196, drop_path=0.0, relative_pos=False):
super(Grapher, self).__init__()
self.channels = in_channels
self.n = n
self.r = r
self.fc1 = nn.Sequential(
nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
nn.BatchNorm2d(in_channels),
)
self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, kernel_size, dilation, conv,
act_layer, norm, bias, stochastic, epsilon, r)
self.fc2 = nn.Sequential(
nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0),
nn.BatchNorm2d(in_channels),
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
if relative_pos:
relative_pos_tensor = get_2d_relative_pos_embed(in_channels,
int(n**0.5)).unsqueeze(0).unsqueeze(1)
relative_pos_tensor = F.interpolate(
relative_pos_tensor, size=(n, n//(r*r)), mode='bicubic', align_corners=False)
# self.relative_pos = nn.Parameter(-relative_pos_tensor.squeeze(1))
self.register_buffer('relative_pos', -relative_pos_tensor.squeeze(1))
else:
self.relative_pos = None
def _get_relative_pos(self, relative_pos, H, W):
if relative_pos is None or H * W == self.n:
return relative_pos
else:
N = H * W
N_reduced = N // (self.r * self.r)
return F.interpolate(relative_pos.unsqueeze(0), size=(N, N_reduced), mode="bicubic").squeeze(0)
def forward(self, x):
_tmp = x
x = self.fc1(x)
B, C, H, W = x.shape
relative_pos = self._get_relative_pos(self.relative_pos, H, W)
x = self.graph_conv(x, relative_pos)
x = self.fc2(x)
x = self.drop_path(x) + _tmp
return x

@ -11,9 +11,10 @@ from torch.nn import Sequential as Seq
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained, build_model_with_cfg from .helpers import load_pretrained, build_model_with_cfg
from .layers import DropPath, Grapher from .layers import DropPath, DyGraphConv2d
from .layers.pos_embed import build_sincos2d_pos_embed
from .registry import register_model from .registry import register_model
from .fx_features import register_notrace_function from .fx_features import register_notrace_function, register_notrace_module
def _cfg(url='', **kwargs): def _cfg(url='', **kwargs):
@ -48,6 +49,68 @@ default_cfgs = {
} }
def get_2d_relative_pos_embed(embed_dim, grid_size):
"""
relative position embedding
References: https://arxiv.org/abs/2009.13658
"""
pos_embed = build_sincos2d_pos_embed([grid_size, grid_size], embed_dim)
relative_pos = 2 * torch.matmul(pos_embed, pos_embed.transpose(0, 1)) / pos_embed.shape[1]
return relative_pos
@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
class Grapher(nn.Module):
"""
Grapher module with graph convolution and fc layers
"""
def __init__(self, in_channels, kernel_size=9, dilation=1, conv='mr', act_layer=nn.GELU, norm=None,
bias=True, stochastic=False, epsilon=0.0, r=1, n=196, drop_path=0.0, relative_pos=False):
super(Grapher, self).__init__()
self.channels = in_channels
self.n = n
self.r = r
self.fc1 = nn.Sequential(
nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
nn.BatchNorm2d(in_channels),
)
self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, kernel_size, dilation, conv,
act_layer, norm, bias, stochastic, epsilon, r)
self.fc2 = nn.Sequential(
nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0),
nn.BatchNorm2d(in_channels),
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
if relative_pos:
relative_pos_tensor = get_2d_relative_pos_embed(in_channels,
int(n**0.5)).unsqueeze(0).unsqueeze(1)
relative_pos_tensor = F.interpolate(
relative_pos_tensor, size=(n, n//(r*r)), mode='bicubic', align_corners=False)
self.register_buffer('relative_pos', -relative_pos_tensor.squeeze(1))
else:
self.relative_pos = None
@register_notrace_function # reason: int argument is a Proxy
def _get_relative_pos(self, relative_pos, H, W):
if relative_pos is None or H * W == self.n:
return relative_pos
else:
N = H * W
N_reduced = N // (self.r * self.r)
return F.interpolate(relative_pos.unsqueeze(0), size=(N, N_reduced), mode="bicubic").squeeze(0)
def forward(self, x):
_tmp = x
x = self.fc1(x)
B, C, H, W = x.shape
relative_pos = self._get_relative_pos(self.relative_pos, H, W)
x = self.graph_conv(x, relative_pos)
x = self.fc2(x)
x = self.drop_path(x) + _tmp
return x
class FFN(nn.Module): class FFN(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, def __init__(self, in_features, hidden_features=None, out_features=None,
act_layer=nn.GELU, drop_path=0.0): act_layer=nn.GELU, drop_path=0.0):

Loading…
Cancel
Save