@@ -62,9 +62,9 @@ class DlaBasic(nn.Module):
         self.bn2 = nn.BatchNorm2d(planes)
         self.stride = stride
 
-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x
 
         out = self.conv1(x)
         out = self.bn1(out)
@@ -73,7 +73,7 @@ class DlaBasic(nn.Module):
         out = self.conv2(out)
         out = self.bn2(out)
 
-        out += residual
+        out += shortcut
         out = self.relu(out)
 
         return out
@@ -99,9 +99,9 @@ class DlaBottleneck(nn.Module):
         self.bn3 = nn.BatchNorm2d(outplanes)
         self.relu = nn.ReLU(inplace=True)
 
-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x
 
         out = self.conv1(x)
         out = self.bn1(out)
@@ -114,7 +114,7 @@ class DlaBottleneck(nn.Module):
         out = self.conv3(out)
         out = self.bn3(out)
 
-        out += residual
+        out += shortcut
         out = self.relu(out)
 
         return out
@@ -154,9 +154,9 @@ class DlaBottle2neck(nn.Module):
         self.bn3 = nn.BatchNorm2d(outplanes)
         self.relu = nn.ReLU(inplace=True)
 
-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x
 
         out = self.conv1(x)
         out = self.bn1(out)
@@ -177,26 +177,26 @@ class DlaBottle2neck(nn.Module):
         out = self.conv3(out)
         out = self.bn3(out)
 
-        out += residual
+        out += shortcut
         out = self.relu(out)
 
         return out
 
 
 class DlaRoot(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, residual):
+    def __init__(self, in_channels, out_channels, kernel_size, shortcut):
         super(DlaRoot, self).__init__()
         self.conv = nn.Conv2d(
             in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2)
         self.bn = nn.BatchNorm2d(out_channels)
         self.relu = nn.ReLU(inplace=True)
-        self.residual = residual
+        self.shortcut = shortcut
 
     def forward(self, *x):
         children = x
         x = self.conv(torch.cat(x, 1))
         x = self.bn(x)
-        if self.residual:
+        if self.shortcut:
             x += children[0]
         x = self.relu(x)
 
@@ -206,7 +206,7 @@ class DlaRoot(nn.Module):
 class DlaTree(nn.Module):
     def __init__(self, levels, block, in_channels, out_channels, stride=1,
                  dilation=1, cardinality=1, base_width=64,
-                 level_root=False, root_dim=0, root_kernel_size=1, root_residual=False):
+                 level_root=False, root_dim=0, root_kernel_size=1, root_shortcut=False):
         super(DlaTree, self).__init__()
         if root_dim == 0:
             root_dim = 2 * out_channels
@@ -226,24 +226,24 @@ class DlaTree(nn.Module):
                 nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                 nn.BatchNorm2d(out_channels))
         else:
-            cargs.update(dict(root_kernel_size=root_kernel_size, root_residual=root_residual))
+            cargs.update(dict(root_kernel_size=root_kernel_size, root_shortcut=root_shortcut))
             self.tree1 = DlaTree(
                 levels - 1, block, in_channels, out_channels, stride, root_dim=0, **cargs)
             self.tree2 = DlaTree(
                 levels - 1, block, out_channels, out_channels, root_dim=root_dim + out_channels, **cargs)
         if levels == 1:
-            self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_residual)
+            self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut)
         self.level_root = level_root
         self.root_dim = root_dim
         self.levels = levels
 
-    def forward(self, x, residual=None, children=None):
+    def forward(self, x, shortcut=None, children=None):
         children = [] if children is None else children
         bottom = self.downsample(x)
-        residual = self.project(bottom)
+        shortcut = self.project(bottom)
         if self.level_root:
             children.append(bottom)
-        x1 = self.tree1(x, residual)
+        x1 = self.tree1(x, shortcut)
         if self.levels == 1:
             x2 = self.tree2(x1)
             x = self.root(x2, x1, *children)
@@ -255,7 +255,7 @@ class DlaTree(nn.Module):
 
 
 class DLA(nn.Module):
     def __init__(self, levels, channels, output_stride=32, num_classes=1000, in_chans=3,
-                 cardinality=1, base_width=64, block=DlaBottle2neck, residual_root=False,
+                 cardinality=1, base_width=64, block=DlaBottle2neck, shortcut_root=False,
                  drop_rate=0.0, global_pool='avg'):
         super(DLA, self).__init__()
         self.channels = channels
@@ -271,7 +271,7 @@ class DLA(nn.Module):
             nn.ReLU(inplace=True))
         self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
         self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
-        cargs = dict(cardinality=cardinality, base_width=base_width, root_residual=residual_root)
+        cargs = dict(cardinality=cardinality, base_width=base_width, root_shortcut=shortcut_root)
         self.level2 = DlaTree(levels[2], block, channels[1], channels[2], 2, level_root=False, **cargs)
         self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs)
         self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True, **cargs)
@@ -413,7 +413,7 @@ def dla60x(pretrained=False, **kwargs):  # DLA-X-60
 def dla102(pretrained=False, **kwargs):  # DLA-102
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, residual_root=True, **kwargs)
+        block=DlaBottleneck, shortcut_root=True, **kwargs)
     return _create_dla('dla102', pretrained, **model_kwargs)
 
 
@@ -421,7 +421,7 @@ def dla102(pretrained=False, **kwargs):  # DLA-102
 def dla102x(pretrained=False, **kwargs):  # DLA-X-102
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, cardinality=32, base_width=4, residual_root=True, **kwargs)
+        block=DlaBottleneck, cardinality=32, base_width=4, shortcut_root=True, **kwargs)
     return _create_dla('dla102x', pretrained, **model_kwargs)
 
 
@@ -429,7 +429,7 @@ def dla102x(pretrained=False, **kwargs):  # DLA-X-102
 def dla102x2(pretrained=False, **kwargs):  # DLA-X-102 64
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, cardinality=64, base_width=4, residual_root=True, **kwargs)
+        block=DlaBottleneck, cardinality=64, base_width=4, shortcut_root=True, **kwargs)
     return _create_dla('dla102x2', pretrained, **model_kwargs)
 
 
@@ -437,5 +437,5 @@ def dla102x2(pretrained=False, **kwargs):  # DLA-X-102 64
 def dla169(pretrained=False, **kwargs):  # DLA-169
     model_kwargs = dict(
         levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, residual_root=True, **kwargs)
+        block=DlaBottleneck, shortcut_root=True, **kwargs)
     return _create_dla('dla169', pretrained, **model_kwargs)