@@ -0,0 +1,274 @@
# Layers for GNN model
# Reference: https://github.com/lightaime/deep_gcns_torch
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from .drop import DropPath
from .pos_embed import build_sincos2d_pos_embed


def pairwise_distance(x, y):
    """
    Compute pairwise squared distances between two point clouds x and y
    """
    with torch.no_grad():
        xy_inner = -2 * torch.matmul(x, y.transpose(2, 1))
        x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
        y_square = torch.sum(torch.mul(y, y), dim=-1, keepdim=True)
        return x_square + xy_inner + y_square.transpose(2, 1)
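

# A hedged shape sketch (illustration only, not part of the original file):
# for x of shape (B, N, C) and y of shape (B, M, C), the result is (B, N, M),
# via the expansion ||x_i - y_j||^2 = ||x_i||^2 - 2<x_i, y_j> + ||y_j||^2.
def _demo_pairwise_distance():  # hypothetical helper, for illustration only
    x = torch.randn(2, 5, 3)
    y = torch.randn(2, 7, 3)
    d = pairwise_distance(x, y)
    assert d.shape == (2, 5, 7)  # one squared distance per (x point, y point) pair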


def dense_knn_matrix(x, y, k=16, relative_pos=None):
    """Get KNN based on the pairwise distance
    """
    with torch.no_grad():
        x = x.transpose(2, 1).squeeze(-1)
        y = y.transpose(2, 1).squeeze(-1)
        batch_size, n_points, n_dims = x.shape
        dist = pairwise_distance(x.detach(), y.detach())
        if relative_pos is not None:
            dist += relative_pos
        _, nn_idx = torch.topk(-dist, k=k)  # indices of the k smallest distances
        center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1)
        return torch.stack((nn_idx, center_idx), dim=0)
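

# A hedged note (illustration only): the returned edge_index stacks neighbor
# and center indices into shape (2, batch_size, n_points, k). edge_index[0]
# indexes the k neighbors of each node (into y's vertices), edge_index[1]
# repeats the center/query index.
def _demo_dense_knn_matrix():  # hypothetical helper, for illustration only
    x = torch.randn(2, 16, 9, 1)  # (B, C, N, 1), the dense layout used in this file
    edge_index = dense_knn_matrix(x, x, k=4)
    assert edge_index.shape == (2, 2, 9, 4)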


class DenseDilated(nn.Module):
    """
    Find dilated neighbors from the neighbor list
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super(DenseDilated, self).__init__()
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon
        self.k = k

    def forward(self, edge_index):
        if self.stochastic:
            if torch.rand(1) < self.epsilon and self.training:
                # with probability epsilon, sample k neighbors at random
                # from the k * dilation candidates
                num = self.k * self.dilation
                randnum = torch.randperm(num)[:self.k]
                edge_index = edge_index[:, :, :, randnum]
            else:
                edge_index = edge_index[:, :, :, ::self.dilation]
        else:
            # keep every dilation-th neighbor
            edge_index = edge_index[:, :, :, ::self.dilation]
        return edge_index


class DenseDilatedKnnGraph(nn.Module):
    """
    Find the neighbors' indices based on dilated knn
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super(DenseDilatedKnnGraph, self).__init__()
        self.dilation = dilation
        self.k = k
        self._dilated = DenseDilated(k, dilation, stochastic, epsilon)

    def forward(self, x, y=None, relative_pos=None):
        x = F.normalize(x, p=2.0, dim=1)
        if y is not None:
            y = F.normalize(y, p=2.0, dim=1)
            edge_index = dense_knn_matrix(x, y, self.k * self.dilation, relative_pos)
        else:
            edge_index = dense_knn_matrix(x, x, self.k * self.dilation, relative_pos)
        return self._dilated(edge_index)
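

# A hedged usage sketch (illustration only): features are L2-normalized along
# the channel dim, a (k * dilation)-NN graph is built, then dilation thins it
# back to k neighbors per node.
def _demo_dilated_knn_graph():  # hypothetical helper, for illustration only
    graph = DenseDilatedKnnGraph(k=4, dilation=2)
    x = torch.randn(2, 16, 9, 1)  # (B, C, N, 1)
    edge_index = graph(x)  # 8 candidates per node, every 2nd one kept
    assert edge_index.shape == (2, 2, 9, 4)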


def batched_index_select(x, idx):
    # fetch neighbor features (B, C, N, k) from x given a neighbor index
    # tensor idx of shape (B, N, k)
    batch_size, num_dims, num_vertices_reduced = x.shape[:3]
    _, num_vertices, k = idx.shape
    # offset each batch's indices so a single flat gather works
    idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices_reduced
    idx = idx + idx_base
    idx = idx.contiguous().view(-1)

    x = x.transpose(2, 1)
    feature = x.contiguous().view(batch_size * num_vertices_reduced, -1)[idx, :]
    feature = feature.view(batch_size, num_vertices, k, num_dims).permute(0, 3, 1, 2).contiguous()
    return feature
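

# A hedged shape check (illustration only), assuming x: (B, C, N, 1) and
# idx: (B, N, k) as produced by dense_knn_matrix.
def _demo_batched_index_select():  # hypothetical helper, for illustration only
    x = torch.randn(2, 16, 9, 1)
    idx = torch.randint(0, 9, (2, 9, 4))
    assert batched_index_select(x, idx).shape == (2, 16, 9, 4)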


def norm_layer(norm, nc):
    # build a 2D normalization layer from its name; None disables normalization
    # (the graph conv layers below default to norm=None)
    if norm is None:
        return nn.Identity()
    norm = norm.lower()
    if norm == 'batch':
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm == 'instance':
        layer = nn.InstanceNorm2d(nc, affine=False)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm)
    return layer


class MRConv2d(nn.Module):
    """
    Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type
    """
    def __init__(self, in_channels, out_channels, act_layer=nn.GELU, norm=None, bias=True):
        super(MRConv2d, self).__init__()
        # self.nn = BasicConv([in_channels*2, out_channels], act_layer, norm, bias)
        self.nn = nn.Sequential(
            nn.Conv2d(in_channels * 2, out_channels, 1, bias=bias, groups=4),
            norm_layer(norm, out_channels),
            act_layer(),
        )

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                if m.weight is not None:  # InstanceNorm2d is created with affine=False
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()

    def forward(self, x, edge_index, y=None):
        x_i = batched_index_select(x, edge_index[1])
        if y is not None:
            x_j = batched_index_select(y, edge_index[0])
        else:
            x_j = batched_index_select(x, edge_index[0])
        # max over neighbors of the relative features x_j - x_i
        x_j, _ = torch.max(x_j - x_i, -1, keepdim=True)
        b, c, n, s = x.shape
        x = torch.cat([x.unsqueeze(2), x_j.unsqueeze(2)], dim=2).reshape(b, 2 * c, n, s)
        return self.nn(x)
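

# A hedged shape sketch (illustration only): each node pairs its own feature
# with the max over (neighbor - center) differences before the 1x1 grouped conv,
# doubling the channel count on the way in.
def _demo_mrconv2d():  # hypothetical helper, for illustration only
    conv = MRConv2d(16, 32, norm='batch')
    x = torch.randn(2, 16, 9, 1)
    edge_index = dense_knn_matrix(x, x, k=4)
    assert conv(x, edge_index).shape == (2, 32, 9, 1)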


class EdgeConv2d(nn.Module):
    """
    Edge convolution layer (with activation, batch normalization) for dense data type
    """
    def __init__(self, in_channels, out_channels, act_layer=nn.GELU, norm=None, bias=True):
        super(EdgeConv2d, self).__init__()
        self.nn = nn.Sequential(
            nn.Conv2d(in_channels * 2, out_channels, 1, bias=bias, groups=4),
            norm_layer(norm, out_channels),
            act_layer(),
        )

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                if m.weight is not None:  # InstanceNorm2d is created with affine=False
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()

    def forward(self, x, edge_index, y=None):
        x_i = batched_index_select(x, edge_index[1])
        if y is not None:
            x_j = batched_index_select(y, edge_index[0])
        else:
            x_j = batched_index_select(x, edge_index[0])
        # aggregate [center, neighbor - center] edge features, max over neighbors
        max_value, _ = torch.max(self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True)
        return max_value
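

# A hedged contrast with MRConv2d (illustration only): EdgeConv2d applies the
# 1x1 conv to every (center, neighbor) edge feature and then takes the max over
# neighbors, whereas MRConv2d maxes the relative features first and applies the
# conv once per node, which is cheaper.
def _demo_edgeconv2d():  # hypothetical helper, for illustration only
    conv = EdgeConv2d(16, 32, norm='batch')
    x = torch.randn(2, 16, 9, 1)
    edge_index = dense_knn_matrix(x, x, k=4)
    assert conv(x, edge_index).shape == (2, 32, 9, 1)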


class GraphConv2d(nn.Module):
    """
    Static graph convolution layer
    """
    def __init__(self, in_channels, out_channels, conv='mr', act_layer=nn.GELU, norm=None, bias=True):
        super(GraphConv2d, self).__init__()
        if conv == 'edge':
            self.gconv = EdgeConv2d(in_channels, out_channels, act_layer, norm, bias)
        elif conv == 'mr':
            self.gconv = MRConv2d(in_channels, out_channels, act_layer, norm, bias)
        else:
            raise NotImplementedError('conv:{} is not supported'.format(conv))

    def forward(self, x, edge_index, y=None):
        return self.gconv(x, edge_index, y)


class DyGraphConv2d(GraphConv2d):
    """
    Dynamic graph convolution layer
    """
    def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='mr', act_layer=nn.GELU,
                 norm=None, bias=True, stochastic=False, epsilon=0.0, r=1):
        super(DyGraphConv2d, self).__init__(in_channels, out_channels, conv, act_layer, norm, bias)
        self.k = kernel_size
        self.d = dilation
        self.r = r
        self.dilated_knn_graph = DenseDilatedKnnGraph(kernel_size, dilation, stochastic, epsilon)

    def forward(self, x, relative_pos=None):
        B, C, H, W = x.shape
        y = None
        if self.r > 1:
            # build the graph against a spatially reduced feature map
            y = F.avg_pool2d(x, self.r, self.r)
            y = y.reshape(B, C, -1, 1).contiguous()
        x = x.reshape(B, C, -1, 1).contiguous()
        edge_index = self.dilated_knn_graph(x, y, relative_pos)
        x = super(DyGraphConv2d, self).forward(x, edge_index, y)
        return x.reshape(B, -1, H, W).contiguous()
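

# A hedged shape sketch (illustration only): the H*W grid is flattened into
# N = H*W nodes, a kNN graph is rebuilt from the current features, graph
# convolution is applied, and the result is folded back to (B, C_out, H, W).
def _demo_dygraphconv2d():  # hypothetical helper, for illustration only
    layer = DyGraphConv2d(16, 32, kernel_size=4, norm='batch')
    x = torch.randn(2, 16, 7, 7)
    assert layer(x).shape == (2, 32, 7, 7)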


def get_2d_relative_pos_embed(embed_dim, grid_size):
    """
    relative position embedding
    References: https://arxiv.org/abs/2009.13658
    """
    pos_embed = build_sincos2d_pos_embed([grid_size, grid_size], embed_dim)
    relative_pos = 2 * torch.matmul(pos_embed, pos_embed.transpose(0, 1)) / pos_embed.shape[1]
    return relative_pos
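

# A hedged note (illustration only, assuming build_sincos2d_pos_embed returns an
# (N, dim) table for N = grid_size**2): the result is an (N, N) matrix of scaled
# embedding inner products, later negated and added to the kNN distance matrix
# in Grapher so that spatially close nodes are favored as neighbors.
def _demo_relative_pos_embed():  # hypothetical helper, for illustration only
    rel = get_2d_relative_pos_embed(embed_dim=64, grid_size=7)
    assert rel.shape == (49, 49)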


class Grapher(nn.Module):
    """
    Grapher module with graph convolution and fc layers
    """
    def __init__(self, in_channels, kernel_size=9, dilation=1, conv='mr', act_layer=nn.GELU, norm=None,
                 bias=True, stochastic=False, epsilon=0.0, r=1, n=196, drop_path=0.0, relative_pos=False):
        super(Grapher, self).__init__()
        self.channels = in_channels
        self.n = n
        self.r = r
        self.fc1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels),
        )
        self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, kernel_size, dilation, conv,
                                        act_layer, norm, bias, stochastic, epsilon, r)
        self.fc2 = nn.Sequential(
            nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels),
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.relative_pos = None
        if relative_pos:
            relative_pos_tensor = get_2d_relative_pos_embed(in_channels,
                                                            int(n**0.5)).unsqueeze(0).unsqueeze(1)
            relative_pos_tensor = F.interpolate(
                relative_pos_tensor, size=(n, n // (r * r)), mode='bicubic', align_corners=False)
            self.relative_pos = nn.Parameter(-relative_pos_tensor.squeeze(1), requires_grad=False)

    def _get_relative_pos(self, relative_pos, H, W):
        if relative_pos is None or H * W == self.n:
            return relative_pos
        else:
            # re-interpolate when the input resolution differs from n
            N = H * W
            N_reduced = N // (self.r * self.r)
            return F.interpolate(relative_pos.unsqueeze(0), size=(N, N_reduced), mode="bicubic").squeeze(0)

    def forward(self, x):
        _tmp = x
        x = self.fc1(x)
        B, C, H, W = x.shape
        relative_pos = self._get_relative_pos(self.relative_pos, H, W)
        x = self.graph_conv(x, relative_pos)
        x = self.fc2(x)
        x = self.drop_path(x) + _tmp
        return x
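

# A hedged end-to-end sketch for this file (illustration only): a Grapher block
# maps (B, C, H, W) -> (B, C, H, W) with a residual connection around
# fc1 -> dynamic graph conv -> fc2.
def _demo_grapher():  # hypothetical helper, for illustration only
    block = Grapher(16, kernel_size=4, norm='batch', n=49)
    x = torch.randn(2, 16, 7, 7)
    assert block(x).shape == (2, 16, 7, 7)
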
@@ -0,0 +1,280 @@
"""
An implementation of the ViG model as defined in:
Vision GNN: An Image is Worth Graph of Nodes.
https://arxiv.org/abs/2206.00272
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Sequential as Seq

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained, build_model_with_cfg
from .layers import DropPath, Grapher
from .registry import register_model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    'pvig_ti_224_gelu': _cfg(
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/pyramid-vig/pvig_ti_78.5.pth.tar',
    ),
    'pvig_s_224_gelu': _cfg(
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/pyramid-vig/pvig_s_82.1.pth.tar',
    ),
    'pvig_m_224_gelu': _cfg(
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/pyramid-vig/pvig_m_83.1.pth.tar',
    ),
    'pvig_b_224_gelu': _cfg(
        crop_pct=0.95, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/pyramid-vig/pvig_b_83.66.pth.tar',
    ),
}


class FFN(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop_path=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Sequential(
            nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0),
            nn.BatchNorm2d(hidden_features),
        )
        self.act = act_layer()
        self.fc2 = nn.Sequential(
            nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0),
            nn.BatchNorm2d(out_features),
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        x = self.drop_path(x) + shortcut
        return x


class Stem(nn.Module):
    """ Image to Visual Embedding
    Overlap: https://arxiv.org/pdf/2106.13797.pdf
    """
    def __init__(self, img_size=224, in_dim=3, out_dim=768, act_layer=nn.GELU):
        super().__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(in_dim, out_dim // 2, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_dim // 2),
            act_layer(),
            nn.Conv2d(out_dim // 2, out_dim, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_dim),
            act_layer(),
            nn.Conv2d(out_dim, out_dim, 3, stride=1, padding=1),
            nn.BatchNorm2d(out_dim),
        )

    def forward(self, x):
        x = self.convs(x)
        return x
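

# A hedged shape note (illustration only): two stride-2 convs give an overall
# 4x spatial reduction, so a 224x224 image becomes a 56x56 grid of visual tokens.
def _demo_stem():  # hypothetical helper, for illustration only
    stem = Stem(in_dim=3, out_dim=48)
    x = torch.randn(1, 3, 224, 224)
    assert stem(x).shape == (1, 48, 56, 56)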


class Downsample(nn.Module):
    """ Convolution-based downsample
    """
    def __init__(self, in_dim=3, out_dim=768):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_dim, out_dim, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_dim),
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class DeepGCN(torch.nn.Module):
    def __init__(self, opt, num_classes=1000, in_chans=3):
        super(DeepGCN, self).__init__()
        self.num_classes = num_classes
        self.in_chans = in_chans
        print(opt)
        k = opt.k
        act_layer = nn.GELU
        norm = opt.norm
        bias = opt.bias
        epsilon = opt.epsilon
        stochastic = opt.use_stochastic
        conv = opt.conv
        drop_path = opt.drop_path

        blocks = opt.blocks
        self.n_blocks = sum(blocks)
        channels = opt.channels
        reduce_ratios = [4, 2, 1, 1]
        dpr = [x.item() for x in torch.linspace(0, drop_path, self.n_blocks)]  # stochastic depth decay rule
        num_knn = [k] * self.n_blocks  # number of knn neighbors per block (constant k here)
        max_dilation = 49 // max(num_knn)

        self.stem = Stem(in_dim=in_chans, out_dim=channels[0], act_layer=act_layer)
        self.pos_embed = nn.Parameter(torch.zeros(1, channels[0], 224 // 4, 224 // 4))
        HW = 224 // 4 * 224 // 4

        self.backbone = nn.ModuleList([])
        idx = 0
        for i in range(len(blocks)):
            if i > 0:
                self.backbone.append(Downsample(channels[i - 1], channels[i]))
                HW = HW // 4  # each Downsample halves both spatial dims
            for j in range(blocks[i]):
                self.backbone += [
                    Seq(Grapher(channels[i], num_knn[idx], min(idx // 4 + 1, max_dilation), conv, act_layer,
                                norm, bias, stochastic, epsilon, reduce_ratios[i], n=HW, drop_path=dpr[idx],
                                relative_pos=True),
                        FFN(channels[i], channels[i] * 4, act_layer=act_layer, drop_path=dpr[idx])
                        )]
                idx += 1
        self.backbone = Seq(*self.backbone)

        self.prediction = Seq(nn.Conv2d(channels[-1], 1024, 1, bias=True),
                              nn.BatchNorm2d(1024),
                              act_layer(),
                              nn.Dropout(opt.dropout),
                              nn.Conv2d(1024, num_classes, 1, bias=True))
        self.model_init()

    def model_init(self):
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
                m.weight.requires_grad = True
                if m.bias is not None:
                    m.bias.data.zero_()
                    m.bias.requires_grad = True

    def forward(self, inputs):
        x = self.stem(inputs) + self.pos_embed
        B, C, H, W = x.shape
        for i in range(len(self.backbone)):
            x = self.backbone[i](x)

        x = F.adaptive_avg_pool2d(x, 1)
        return self.prediction(x).squeeze(-1).squeeze(-1)
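

# A hedged usage sketch (illustration only): constructing DeepGCN directly with
# a namespace standing in for opt, at the 224x224 resolution the position
# embedding is built for. Note the relative_pos=True blocks allocate full
# N x N position tables, so this is memory-heavy at stage 0.
def _demo_deepgcn():  # hypothetical helper, for illustration only
    from types import SimpleNamespace
    opt = SimpleNamespace(k=9, conv='mr', norm='batch', bias=True, dropout=0.0,
                          use_dilation=True, epsilon=0.2, use_stochastic=False,
                          drop_path=0.0, blocks=[2, 2, 6, 2], channels=[48, 96, 240, 384])
    model = DeepGCN(opt)
    logits = model(torch.randn(1, 3, 224, 224))
    assert logits.shape == (1, 1000)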


def _create_pvig(variant, opt, pretrained=False, **kwargs):
    """
    Constructs a Pyramid ViG model
    """
    model_kwargs = dict(
        opt=opt,
        **kwargs,
    )
    return build_model_with_cfg(
        DeepGCN, variant, pretrained,
        feature_cfg=dict(flatten_sequential=True),
        **model_kwargs)


@register_model
def pvig_ti_224_gelu(pretrained=False, **kwargs):
    class OptInit:
        def __init__(self, drop_path_rate=0.0, **kwargs):
            self.k = 9  # neighbor num (default: 9)
            self.conv = 'mr'  # graph conv layer {edge, mr}
            self.norm = 'batch'  # batch or instance normalization {batch, instance}
            self.bias = True  # bias of conv layer, True or False
            self.dropout = 0.0  # dropout rate
            self.use_dilation = True  # use dilated knn or not
            self.epsilon = 0.2  # stochastic epsilon for gcn
            self.use_stochastic = False  # stochastic for gcn, True or False
            self.drop_path = drop_path_rate
            self.blocks = [2, 2, 6, 2]  # number of basic blocks in the backbone
            self.channels = [48, 96, 240, 384]  # number of channels of deep features

    opt = OptInit(**kwargs)
    model = _create_pvig('pvig_ti_224_gelu', opt, pretrained)
    model.default_cfg = default_cfgs['pvig_ti_224_gelu']
    return model


@register_model
def pvig_s_224_gelu(pretrained=False, **kwargs):
    class OptInit:
        def __init__(self, drop_path_rate=0.0, **kwargs):
            self.k = 9  # neighbor num (default: 9)
            self.conv = 'mr'  # graph conv layer {edge, mr}
            self.norm = 'batch'  # batch or instance normalization {batch, instance}
            self.bias = True  # bias of conv layer, True or False
            self.dropout = 0.0  # dropout rate
            self.use_dilation = True  # use dilated knn or not
            self.epsilon = 0.2  # stochastic epsilon for gcn
            self.use_stochastic = False  # stochastic for gcn, True or False
            self.drop_path = drop_path_rate
            self.blocks = [2, 2, 6, 2]  # number of basic blocks in the backbone
            self.channels = [80, 160, 400, 640]  # number of channels of deep features

    opt = OptInit(**kwargs)
    model = _create_pvig('pvig_s_224_gelu', opt, pretrained)
    model.default_cfg = default_cfgs['pvig_s_224_gelu']
    return model


@register_model
def pvig_m_224_gelu(pretrained=False, **kwargs):
    class OptInit:
        def __init__(self, drop_path_rate=0.0, **kwargs):
            self.k = 9  # neighbor num (default: 9)
            self.conv = 'mr'  # graph conv layer {edge, mr}
            self.norm = 'batch'  # batch or instance normalization {batch, instance}
            self.bias = True  # bias of conv layer, True or False
            self.dropout = 0.0  # dropout rate
            self.use_dilation = True  # use dilated knn or not
            self.epsilon = 0.2  # stochastic epsilon for gcn
            self.use_stochastic = False  # stochastic for gcn, True or False
            self.drop_path = drop_path_rate
            self.blocks = [2, 2, 16, 2]  # number of basic blocks in the backbone
            self.channels = [96, 192, 384, 768]  # number of channels of deep features

    opt = OptInit(**kwargs)
    model = _create_pvig('pvig_m_224_gelu', opt, pretrained)
    model.default_cfg = default_cfgs['pvig_m_224_gelu']
    return model


@register_model
def pvig_b_224_gelu(pretrained=False, **kwargs):
    class OptInit:
        def __init__(self, drop_path_rate=0.0, **kwargs):
            self.k = 9  # neighbor num (default: 9)
            self.conv = 'mr'  # graph conv layer {edge, mr}
            self.norm = 'batch'  # batch or instance normalization {batch, instance}
            self.bias = True  # bias of conv layer, True or False
            self.dropout = 0.0  # dropout rate
            self.use_dilation = True  # use dilated knn or not
            self.epsilon = 0.2  # stochastic epsilon for gcn
            self.use_stochastic = False  # stochastic for gcn, True or False
            self.drop_path = drop_path_rate
            self.blocks = [2, 2, 18, 2]  # number of basic blocks in the backbone
            self.channels = [128, 256, 512, 1024]  # number of channels of deep features

    opt = OptInit(**kwargs)
    model = _create_pvig('pvig_b_224_gelu', opt, pretrained)
    model.default_cfg = default_cfgs['pvig_b_224_gelu']
    return model