Fix incorrect name of shortcut/identity paths in many residual nets. Inherited from naming in old torchvision, long since fixed there.

pull/612/head
Ross Wightman 4 years ago
parent 0d87650fea
commit d5473c17f7
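
Context on the rename: in the residual learning formulation (He et al.), a block computes y = F(x) + x, where F(x) is the residual function and the saved, possibly projected, input x is the shortcut (identity) path. The old code called that saved input `residual`, i.e. after the wrong branch. A minimal toy sketch of the convention this commit adopts, not a block taken from timm:

import torch.nn as nn

class ToyResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.act = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        shortcut = x                     # identity/shortcut path, previously misnamed 'residual'
        out = self.act(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))  # this branch is the actual residual F(x)
        out += shortcut
        return self.act(out)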

@@ -62,9 +62,9 @@ class DlaBasic(nn.Module):
         self.bn2 = nn.BatchNorm2d(planes)
         self.stride = stride

-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -73,7 +73,7 @@ class DlaBasic(nn.Module):
         out = self.conv2(out)
         out = self.bn2(out)

-        out += residual
+        out += shortcut
         out = self.relu(out)

         return out
@@ -99,9 +99,9 @@ class DlaBottleneck(nn.Module):
         self.bn3 = nn.BatchNorm2d(outplanes)
         self.relu = nn.ReLU(inplace=True)

-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -114,7 +114,7 @@ class DlaBottleneck(nn.Module):
         out = self.conv3(out)
         out = self.bn3(out)

-        out += residual
+        out += shortcut
         out = self.relu(out)

         return out
@@ -154,9 +154,9 @@ class DlaBottle2neck(nn.Module):
         self.bn3 = nn.BatchNorm2d(outplanes)
         self.relu = nn.ReLU(inplace=True)

-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -177,26 +177,26 @@ class DlaBottle2neck(nn.Module):
         out = self.conv3(out)
         out = self.bn3(out)

-        out += residual
+        out += shortcut
         out = self.relu(out)

         return out


 class DlaRoot(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, residual):
+    def __init__(self, in_channels, out_channels, kernel_size, shortcut):
         super(DlaRoot, self).__init__()
         self.conv = nn.Conv2d(
             in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2)
         self.bn = nn.BatchNorm2d(out_channels)
         self.relu = nn.ReLU(inplace=True)
-        self.residual = residual
+        self.shortcut = shortcut

     def forward(self, *x):
         children = x
         x = self.conv(torch.cat(x, 1))
         x = self.bn(x)
-        if self.residual:
+        if self.shortcut:
             x += children[0]
         x = self.relu(x)

@@ -206,7 +206,7 @@ class DlaRoot(nn.Module):
 class DlaTree(nn.Module):
     def __init__(self, levels, block, in_channels, out_channels, stride=1,
                  dilation=1, cardinality=1, base_width=64,
-                 level_root=False, root_dim=0, root_kernel_size=1, root_residual=False):
+                 level_root=False, root_dim=0, root_kernel_size=1, root_shortcut=False):
         super(DlaTree, self).__init__()
         if root_dim == 0:
             root_dim = 2 * out_channels
@@ -226,24 +226,24 @@ class DlaTree(nn.Module):
                     nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                     nn.BatchNorm2d(out_channels))
         else:
-            cargs.update(dict(root_kernel_size=root_kernel_size, root_residual=root_residual))
+            cargs.update(dict(root_kernel_size=root_kernel_size, root_shortcut=root_shortcut))
             self.tree1 = DlaTree(
                 levels - 1, block, in_channels, out_channels, stride, root_dim=0, **cargs)
             self.tree2 = DlaTree(
                 levels - 1, block, out_channels, out_channels, root_dim=root_dim + out_channels, **cargs)
         if levels == 1:
-            self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_residual)
+            self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut)
         self.level_root = level_root
         self.root_dim = root_dim
         self.levels = levels

-    def forward(self, x, residual=None, children=None):
+    def forward(self, x, shortcut=None, children=None):
         children = [] if children is None else children
         bottom = self.downsample(x)
-        residual = self.project(bottom)
+        shortcut = self.project(bottom)
         if self.level_root:
             children.append(bottom)
-        x1 = self.tree1(x, residual)
+        x1 = self.tree1(x, shortcut)
         if self.levels == 1:
             x2 = self.tree2(x1)
             x = self.root(x2, x1, *children)
@@ -255,7 +255,7 @@ class DlaTree(nn.Module):

 class DLA(nn.Module):
     def __init__(self, levels, channels, output_stride=32, num_classes=1000, in_chans=3,
-                 cardinality=1, base_width=64, block=DlaBottle2neck, residual_root=False,
+                 cardinality=1, base_width=64, block=DlaBottle2neck, shortcut_root=False,
                  drop_rate=0.0, global_pool='avg'):
         super(DLA, self).__init__()
         self.channels = channels
@@ -271,7 +271,7 @@ class DLA(nn.Module):
             nn.ReLU(inplace=True))
         self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
         self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
-        cargs = dict(cardinality=cardinality, base_width=base_width, root_residual=residual_root)
+        cargs = dict(cardinality=cardinality, base_width=base_width, root_shortcut=shortcut_root)
         self.level2 = DlaTree(levels[2], block, channels[1], channels[2], 2, level_root=False, **cargs)
         self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs)
         self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True, **cargs)
@@ -413,7 +413,7 @@ def dla60x(pretrained=False, **kwargs):  # DLA-X-60
 def dla102(pretrained=False, **kwargs):  # DLA-102
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, residual_root=True, **kwargs)
+        block=DlaBottleneck, shortcut_root=True, **kwargs)
     return _create_dla('dla102', pretrained, **model_kwargs)

@@ -421,7 +421,7 @@ def dla102(pretrained=False, **kwargs):  # DLA-102
 def dla102x(pretrained=False, **kwargs):  # DLA-X-102
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, cardinality=32, base_width=4, residual_root=True, **kwargs)
+        block=DlaBottleneck, cardinality=32, base_width=4, shortcut_root=True, **kwargs)
     return _create_dla('dla102x', pretrained, **model_kwargs)

@@ -429,7 +429,7 @@ def dla102x(pretrained=False, **kwargs):  # DLA-X-102
 def dla102x2(pretrained=False, **kwargs):  # DLA-X-102 64
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, cardinality=64, base_width=4, residual_root=True, **kwargs)
+        block=DlaBottleneck, cardinality=64, base_width=4, shortcut_root=True, **kwargs)
     return _create_dla('dla102x2', pretrained, **model_kwargs)

@@ -437,5 +437,5 @@ def dla102x2(pretrained=False, **kwargs):  # DLA-X-102 64
 def dla169(pretrained=False, **kwargs):  # DLA-169
     model_kwargs = dict(
         levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, residual_root=True, **kwargs)
+        block=DlaBottleneck, shortcut_root=True, **kwargs)
     return _create_dla('dla169', pretrained, **model_kwargs)
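
The DLA change is plumbing for the same rename: `shortcut_root` on the DLA constructor is passed down as `root_shortcut` to each DlaTree and finally as the boolean `shortcut` on DlaRoot, where it controls whether the aggregation node adds its first child back on top of the fused output. A hedged usage sketch of the renamed keyword, mirroring the dla102() entrypoint above (the import path is assumed here, not verified):

from timm.models.dla import DLA, DlaBottleneck

model = DLA(
    levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
    block=DlaBottleneck, shortcut_root=True)  # was residual_root=True before this commit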

@@ -184,7 +184,7 @@ class DepthwiseSeparableConv(nn.Module):
         return info

     def forward(self, x):
-        residual = x
+        shortcut = x

         x = self.conv_dw(x)
         x = self.bn1(x)
@@ -200,7 +200,7 @@ class DepthwiseSeparableConv(nn.Module):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x

@@ -258,7 +258,7 @@ class InvertedResidual(nn.Module):
         return info

     def forward(self, x):
-        residual = x
+        shortcut = x

         # Point-wise expansion
         x = self.conv_pw(x)
@@ -281,7 +281,7 @@ class InvertedResidual(nn.Module):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x

@@ -308,7 +308,7 @@ class CondConvResidual(InvertedResidual):
         self.routing_fn = nn.Linear(in_chs, self.num_experts)

     def forward(self, x):
-        residual = x
+        shortcut = x

         # CondConv routing
         pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
@@ -335,7 +335,7 @@ class CondConvResidual(InvertedResidual):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x

@@ -390,7 +390,7 @@ class EdgeResidual(nn.Module):
         return info

     def forward(self, x):
-        residual = x
+        shortcut = x

         # Expansion convolution
         x = self.conv_exp(x)
@@ -408,6 +408,6 @@ class EdgeResidual(nn.Module):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x
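
All four EfficientNet-family blocks share the same tail: the saved input is only added back when `has_residual` is set, and stochastic depth is applied to the computed branch before the add, so a dropped branch leaves the block as a pure identity. A simplified sketch of that pattern, with a local stand-in for timm's drop_path rather than the real import:

import torch

def drop_path(x, drop_prob: float = 0., training: bool = False):
    # stand-in for timm's drop_path: per-sample keep/drop, rescaled to preserve expectation
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1. - drop_prob
    mask = x.new_empty((x.shape[0], 1, 1, 1)).bernoulli_(keep_prob)
    return x.div(keep_prob) * mask

def add_shortcut(branch_out, shortcut, has_residual, drop_path_rate, training):
    # branch_out is the block's computed branch F(x); shortcut is the saved block input
    if has_residual:
        if drop_path_rate > 0.:
            branch_out = drop_path(branch_out, drop_path_rate, training)
        branch_out = branch_out + shortcut
    return branch_out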

@@ -112,7 +112,7 @@ class GhostBottleneck(nn.Module):

     def forward(self, x):
-        residual = x
+        shortcut = x

         # 1st ghost bottleneck
         x = self.ghost1(x)
@@ -129,7 +129,7 @@ class GhostBottleneck(nn.Module):
         # 2nd ghost bottleneck
         x = self.ghost2(x)

-        x += self.shortcut(residual)
+        x += self.shortcut(shortcut)
         return x
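
GhostBottleneck is the one block here where both names coexist: the local tensor `shortcut` is the saved input, while `self.shortcut` is the module that projects it when stride or channel count changes (otherwise it acts as an identity), hence `x += self.shortcut(shortcut)`. A sketch of such a projection shortcut, with illustrative layer choices rather than the exact timm module:

import torch.nn as nn

def make_shortcut(in_chs, out_chs, stride):
    if in_chs == out_chs and stride == 1:
        return nn.Identity()              # shapes match: plain identity shortcut
    return nn.Sequential(                 # shapes differ: project the shortcut path
        nn.Conv2d(in_chs, in_chs, 3, stride=stride, padding=1, groups=in_chs, bias=False),
        nn.BatchNorm2d(in_chs),
        nn.Conv2d(in_chs, out_chs, 1, bias=False),
        nn.BatchNorm2d(out_chs))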

@@ -91,7 +91,7 @@ class Bottle2neck(nn.Module):
         nn.init.zeros_(self.bn3.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -124,9 +124,9 @@ class Bottle2neck(nn.Module):
         out = self.se(out)

         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)

-        out += residual
+        out += shortcut
         out = self.relu(out)

         return out

@@ -105,7 +105,7 @@ class ResNestBottleneck(nn.Module):
         nn.init.zeros_(self.bn3.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -132,9 +132,9 @@ class ResNestBottleneck(nn.Module):
             out = self.drop_block(out)

         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)

-        out += residual
+        out += shortcut
         out = self.act3(out)
         return out

@@ -315,7 +315,7 @@ class BasicBlock(nn.Module):
         nn.init.zeros_(self.bn2.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x

         x = self.conv1(x)
         x = self.bn1(x)
@@ -337,8 +337,8 @@ class BasicBlock(nn.Module):
             x = self.drop_path(x)

         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act2(x)

         return x
@@ -385,7 +385,7 @@ class Bottleneck(nn.Module):
         nn.init.zeros_(self.bn3.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x

         x = self.conv1(x)
         x = self.bn1(x)
@@ -413,8 +413,8 @@ class Bottleneck(nn.Module):
             x = self.drop_path(x)

         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act3(x)

         return x

@@ -92,7 +92,7 @@ class Bottleneck(nn.Module):
     """

     def forward(self, x):
-        residual = x
+        shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -106,9 +106,9 @@ class Bottleneck(nn.Module):
         out = self.bn3(out)

         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)

-        out = self.se_module(out) + residual
+        out = self.se_module(out) + shortcut
         out = self.relu(out)

         return out
@@ -204,7 +204,7 @@ class SEResNetBlock(nn.Module):
         self.stride = stride

     def forward(self, x):
-        residual = x
+        shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -215,9 +215,9 @@ class SEResNetBlock(nn.Module):
         out = self.relu(out)

         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)

-        out = self.se_module(out) + residual
+        out = self.se_module(out) + shortcut
         out = self.relu(out)

         return out
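
In the SENet blocks the shortcut is added after squeeze-and-excitation, i.e. SE rescales only the computed branch: `out = self.se_module(out) + shortcut`. A minimal SE sketch to show what happens before that add (reduction and layer names are illustrative, not timm's SEModule):

import torch.nn as nn

class ToySEModule(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels // reduction, 1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(channels // reduction, channels, 1)
        self.gate = nn.Sigmoid()

    def forward(self, x):
        scale = self.gate(self.fc2(self.relu(self.fc1(self.pool(x)))))
        return x * scale  # channel-wise rescaling of the computed branch; the shortcut is added afterwards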

@@ -76,7 +76,7 @@ class SelectiveKernelBasic(nn.Module):
         nn.init.zeros_(self.conv2.bn.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x
         x = self.conv1(x)
         x = self.conv2(x)
         if self.se is not None:
@@ -84,8 +84,8 @@ class SelectiveKernelBasic(nn.Module):
         if self.drop_path is not None:
             x = self.drop_path(x)
         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act(x)
         return x
@@ -124,7 +124,7 @@ class SelectiveKernelBottleneck(nn.Module):
         nn.init.zeros_(self.conv3.bn.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x
         x = self.conv1(x)
         x = self.conv2(x)
         x = self.conv3(x)
@@ -133,8 +133,8 @@ class SelectiveKernelBottleneck(nn.Module):
         if self.drop_path is not None:
             x = self.drop_path(x)
         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act(x)
         return x

@@ -89,9 +89,9 @@ class BasicBlock(nn.Module):

     def forward(self, x):
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
         else:
-            residual = x
+            shortcut = x

         out = self.conv1(x)
         out = self.conv2(out)
@@ -99,7 +99,7 @@ class BasicBlock(nn.Module):
         if self.se is not None:
             out = self.se(out)

-        out += residual
+        out += shortcut
         out = self.relu(out)
         return out
@@ -136,9 +136,9 @@ class Bottleneck(nn.Module):

     def forward(self, x):
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
         else:
-            residual = x
+            shortcut = x

         out = self.conv1(x)
         out = self.conv2(out)
@@ -146,7 +146,7 @@ class Bottleneck(nn.Module):
             out = self.se(out)

         out = self.conv3(out)
-        out = out + residual  # no inplace
+        out = out + shortcut  # no inplace
         out = self.relu(out)
         return out
