From d5473c17f77d608ee150ef09b0a7c8d590f77aee Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 4 May 2021 21:27:15 -0700
Subject: [PATCH] Fix incorrect name of shortcut/identity paths in many
 residual nets. Inherited from naming in old old torchvision, long fixed
 there.

---
 timm/models/dla.py                 | 54 +++++++++++++++---------------
 timm/models/efficientnet_blocks.py | 16 ++++-----
 timm/models/ghostnet.py            |  4 +--
 timm/models/res2net.py             |  6 ++--
 timm/models/resnest.py             |  6 ++--
 timm/models/resnet.py              | 12 +++---
 timm/models/senet.py               | 12 +++---
 timm/models/sknet.py               | 12 +++---
 timm/models/tresnet.py             | 12 +++---
 9 files changed, 67 insertions(+), 67 deletions(-)

diff --git a/timm/models/dla.py b/timm/models/dla.py
index 64ad61d6..f0f25b0b 100644
--- a/timm/models/dla.py
+++ b/timm/models/dla.py
@@ -62,9 +62,9 @@ class DlaBasic(nn.Module):
         self.bn2 = nn.BatchNorm2d(planes)
         self.stride = stride
 
-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x
 
         out = self.conv1(x)
         out = self.bn1(out)
@@ -73,7 +73,7 @@ class DlaBasic(nn.Module):
         out = self.conv2(out)
         out = self.bn2(out)
 
-        out += residual
+        out += shortcut
         out = self.relu(out)
 
         return out
@@ -99,9 +99,9 @@ class DlaBottleneck(nn.Module):
         self.bn3 = nn.BatchNorm2d(outplanes)
         self.relu = nn.ReLU(inplace=True)
 
-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x
 
         out = self.conv1(x)
         out = self.bn1(out)
@@ -114,7 +114,7 @@ class DlaBottleneck(nn.Module):
         out = self.conv3(out)
         out = self.bn3(out)
 
-        out += residual
+        out += shortcut
         out = self.relu(out)
 
         return out
@@ -154,9 +154,9 @@ class DlaBottle2neck(nn.Module):
         self.bn3 = nn.BatchNorm2d(outplanes)
         self.relu = nn.ReLU(inplace=True)
 
-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x
 
         out = self.conv1(x)
         out = self.bn1(out)
@@ -177,26 +177,26 @@ class DlaBottle2neck(nn.Module):
         out = self.conv3(out)
         out = self.bn3(out)
 
-        out += residual
+        out += shortcut
         out = self.relu(out)
 
         return out
 
 
 class DlaRoot(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, residual):
+    def __init__(self, in_channels, out_channels, kernel_size, shortcut):
         super(DlaRoot, self).__init__()
         self.conv = nn.Conv2d(
             in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2)
         self.bn = nn.BatchNorm2d(out_channels)
         self.relu = nn.ReLU(inplace=True)
-        self.residual = residual
+        self.shortcut = shortcut
 
     def forward(self, *x):
         children = x
         x = self.conv(torch.cat(x, 1))
         x = self.bn(x)
-        if self.residual:
+        if self.shortcut:
             x += children[0]
         x = self.relu(x)
 
@@ -206,7 +206,7 @@ class DlaRoot(nn.Module):
 class DlaTree(nn.Module):
     def __init__(self, levels, block, in_channels, out_channels, stride=1,
                  dilation=1, cardinality=1, base_width=64,
-                 level_root=False, root_dim=0, root_kernel_size=1, root_residual=False):
+                 level_root=False, root_dim=0, root_kernel_size=1, root_shortcut=False):
         super(DlaTree, self).__init__()
         if root_dim == 0:
             root_dim = 2 * out_channels
@@ -226,24 +226,24 @@ class DlaTree(nn.Module):
                 nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                 nn.BatchNorm2d(out_channels))
         else:
-            cargs.update(dict(root_kernel_size=root_kernel_size, root_residual=root_residual))
+            cargs.update(dict(root_kernel_size=root_kernel_size, root_shortcut=root_shortcut))
             self.tree1 = DlaTree(
                 levels - 1, block, in_channels, out_channels, stride, root_dim=0, **cargs)
             self.tree2 = DlaTree(
                 levels - 1, block, out_channels, out_channels, root_dim=root_dim + out_channels, **cargs)
         if levels == 1:
-            self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_residual)
+            self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut)
         self.level_root = level_root
         self.root_dim = root_dim
         self.levels = levels
 
-    def forward(self, x, residual=None, children=None):
+    def forward(self, x, shortcut=None, children=None):
         children = [] if children is None else children
         bottom = self.downsample(x)
-        residual = self.project(bottom)
+        shortcut = self.project(bottom)
         if self.level_root:
             children.append(bottom)
-        x1 = self.tree1(x, residual)
+        x1 = self.tree1(x, shortcut)
         if self.levels == 1:
             x2 = self.tree2(x1)
             x = self.root(x2, x1, *children)
@@ -255,7 +255,7 @@ class DlaTree(nn.Module):
 
 class DLA(nn.Module):
     def __init__(self, levels, channels, output_stride=32, num_classes=1000, in_chans=3,
-                 cardinality=1, base_width=64, block=DlaBottle2neck, residual_root=False,
+                 cardinality=1, base_width=64, block=DlaBottle2neck, shortcut_root=False,
                  drop_rate=0.0, global_pool='avg'):
         super(DLA, self).__init__()
         self.channels = channels
@@ -271,7 +271,7 @@ class DLA(nn.Module):
             nn.ReLU(inplace=True))
         self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
         self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
-        cargs = dict(cardinality=cardinality, base_width=base_width, root_residual=residual_root)
+        cargs = dict(cardinality=cardinality, base_width=base_width, root_shortcut=shortcut_root)
         self.level2 = DlaTree(levels[2], block, channels[1], channels[2], 2, level_root=False, **cargs)
         self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs)
         self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True, **cargs)
@@ -413,7 +413,7 @@ def dla60x(pretrained=False, **kwargs):  # DLA-X-60
 def dla102(pretrained=False, **kwargs):  # DLA-102
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, residual_root=True, **kwargs)
+        block=DlaBottleneck, shortcut_root=True, **kwargs)
     return _create_dla('dla102', pretrained, **model_kwargs)
 
 
@@ -421,7 +421,7 @@ def dla102(pretrained=False, **kwargs):  # DLA-102
 def dla102x(pretrained=False, **kwargs):  # DLA-X-102
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, cardinality=32, base_width=4, residual_root=True, **kwargs)
+        block=DlaBottleneck, cardinality=32, base_width=4, shortcut_root=True, **kwargs)
     return _create_dla('dla102x', pretrained, **model_kwargs)
 
 
@@ -429,7 +429,7 @@ def dla102x(pretrained=False, **kwargs):  # DLA-X-102
 def dla102x2(pretrained=False, **kwargs):  # DLA-X-102 64
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, cardinality=64, base_width=4, residual_root=True, **kwargs)
+        block=DlaBottleneck, cardinality=64, base_width=4, shortcut_root=True, **kwargs)
     return _create_dla('dla102x2', pretrained, **model_kwargs)
 
 
@@ -437,5 +437,5 @@ def dla102x2(pretrained=False, **kwargs):  # DLA-X-102 64
 def dla169(pretrained=False, **kwargs):  # DLA-169
     model_kwargs = dict(
         levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, residual_root=True, **kwargs)
+        block=DlaBottleneck, shortcut_root=True, **kwargs)
     return _create_dla('dla169', pretrained, **model_kwargs)
diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py
index 114533cf..040785f6 100644
--- a/timm/models/efficientnet_blocks.py
+++ b/timm/models/efficientnet_blocks.py
@@ -184,7 +184,7 @@ class DepthwiseSeparableConv(nn.Module):
         return info
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         x = self.conv_dw(x)
         x = self.bn1(x)
@@ -200,7 +200,7 @@ class DepthwiseSeparableConv(nn.Module):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x
 
 
@@ -258,7 +258,7 @@ class InvertedResidual(nn.Module):
         return info
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         # Point-wise expansion
         x = self.conv_pw(x)
@@ -281,7 +281,7 @@ class InvertedResidual(nn.Module):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x
 
 
@@ -308,7 +308,7 @@ class CondConvResidual(InvertedResidual):
         self.routing_fn = nn.Linear(in_chs, self.num_experts)
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         # CondConv routing
         pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
@@ -335,7 +335,7 @@ class CondConvResidual(InvertedResidual):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x
 
 
@@ -390,7 +390,7 @@ class EdgeResidual(nn.Module):
         return info
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         # Expansion convolution
         x = self.conv_exp(x)
@@ -408,6 +408,6 @@ class EdgeResidual(nn.Module):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x
 
diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py
index 76761d1c..358fb4c7 100644
--- a/timm/models/ghostnet.py
+++ b/timm/models/ghostnet.py
@@ -112,7 +112,7 @@ class GhostBottleneck(nn.Module):
 
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         # 1st ghost bottleneck
         x = self.ghost1(x)
@@ -129,7 +129,7 @@ class GhostBottleneck(nn.Module):
         # 2nd ghost bottleneck
         x = self.ghost2(x)
 
-        x += self.shortcut(residual)
+        x += self.shortcut(shortcut)
         return x
 
 
diff --git a/timm/models/res2net.py b/timm/models/res2net.py
index 977d872f..282baba3 100644
--- a/timm/models/res2net.py
+++ b/timm/models/res2net.py
@@ -91,7 +91,7 @@ class Bottle2neck(nn.Module):
         nn.init.zeros_(self.bn3.weight)
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         out = self.conv1(x)
         out = self.bn1(out)
@@ -124,9 +124,9 @@ class Bottle2neck(nn.Module):
         out = self.se(out)
 
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
 
-        out += residual
+        out += shortcut
         out = self.relu(out)
 
         return out
diff --git a/timm/models/resnest.py b/timm/models/resnest.py
index 154e250c..ac3b2559 100644
--- a/timm/models/resnest.py
+++ b/timm/models/resnest.py
@@ -105,7 +105,7 @@ class ResNestBottleneck(nn.Module):
         nn.init.zeros_(self.bn3.weight)
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         out = self.conv1(x)
         out = self.bn1(out)
@@ -132,9 +132,9 @@ class ResNestBottleneck(nn.Module):
         out = self.drop_block(out)
 
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
 
-        out += residual
+        out += shortcut
         out = self.act3(out)
         return out
 
diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index 7fd47057..491d9acb 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -315,7 +315,7 @@ class BasicBlock(nn.Module):
         nn.init.zeros_(self.bn2.weight)
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         x = self.conv1(x)
         x = self.bn1(x)
@@ -337,8 +337,8 @@ class BasicBlock(nn.Module):
             x = self.drop_path(x)
 
         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act2(x)
 
         return x
@@ -385,7 +385,7 @@ class Bottleneck(nn.Module):
         nn.init.zeros_(self.bn3.weight)
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         x = self.conv1(x)
         x = self.bn1(x)
@@ -413,8 +413,8 @@ class Bottleneck(nn.Module):
             x = self.drop_path(x)
 
         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act3(x)
 
         return x
diff --git a/timm/models/senet.py b/timm/models/senet.py
index 8227a453..3d0ba7b3 100644
--- a/timm/models/senet.py
+++ b/timm/models/senet.py
@@ -92,7 +92,7 @@ class Bottleneck(nn.Module):
     """
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         out = self.conv1(x)
         out = self.bn1(out)
@@ -106,9 +106,9 @@ class Bottleneck(nn.Module):
         out = self.bn3(out)
 
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
 
-        out = self.se_module(out) + residual
+        out = self.se_module(out) + shortcut
         out = self.relu(out)
 
         return out
@@ -204,7 +204,7 @@ class SEResNetBlock(nn.Module):
         self.stride = stride
 
     def forward(self, x):
-        residual = x
+        shortcut = x
 
         out = self.conv1(x)
         out = self.bn1(out)
@@ -215,9 +215,9 @@ class SEResNetBlock(nn.Module):
         out = self.relu(out)
 
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
 
-        out = self.se_module(out) + residual
+        out = self.se_module(out) + shortcut
         out = self.relu(out)
 
         return out
diff --git a/timm/models/sknet.py b/timm/models/sknet.py
index bd9dd393..eb7ad8c3 100644
--- a/timm/models/sknet.py
+++ b/timm/models/sknet.py
@@ -76,7 +76,7 @@ class SelectiveKernelBasic(nn.Module):
         nn.init.zeros_(self.conv2.bn.weight)
 
     def forward(self, x):
-        residual = x
+        shortcut = x
         x = self.conv1(x)
         x = self.conv2(x)
         if self.se is not None:
@@ -84,8 +84,8 @@ class SelectiveKernelBasic(nn.Module):
             x = self.se(x)
         if self.drop_path is not None:
             x = self.drop_path(x)
         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act(x)
         return x
@@ -124,7 +124,7 @@ class SelectiveKernelBottleneck(nn.Module):
         nn.init.zeros_(self.conv3.bn.weight)
 
     def forward(self, x):
-        residual = x
+        shortcut = x
         x = self.conv1(x)
         x = self.conv2(x)
         x = self.conv3(x)
@@ -133,8 +133,8 @@ class SelectiveKernelBottleneck(nn.Module):
         if self.drop_path is not None:
             x = self.drop_path(x)
         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act(x)
         return x
 
diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py
index cec51cf4..9fb34c20 100644
--- a/timm/models/tresnet.py
+++ b/timm/models/tresnet.py
@@ -89,9 +89,9 @@ class BasicBlock(nn.Module):
 
     def forward(self, x):
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
         else:
-            residual = x
+            shortcut = x
 
         out = self.conv1(x)
         out = self.conv2(out)
@@ -99,7 +99,7 @@ class BasicBlock(nn.Module):
         if self.se is not None:
             out = self.se(out)
 
-        out += residual
+        out += shortcut
         out = self.relu(out)
 
         return out
@@ -136,9 +136,9 @@ class Bottleneck(nn.Module):
 
     def forward(self, x):
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
         else:
-            residual = x
+            shortcut = x
 
         out = self.conv1(x)
         out = self.conv2(out)
@@ -146,7 +146,7 @@ class Bottleneck(nn.Module):
             out = self.se(out)
 
         out = self.conv3(out)
-        out = out + residual  # no inplace
+        out = out + shortcut  # no inplace
         out = self.relu(out)
 
         return out
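
For reference, the convention the rename settles on is the one from the original ResNet formulation (He et al.): in y = F(x) + x, F(x) is the residual function computed by the block's branch, while x is the shortcut (identity) path. The old code stored the shortcut under the name `residual`, conflating the two. Below is a minimal sketch of the corrected naming, for illustration only; `TinyBasicBlock` is hypothetical and not part of this patch or of timm, and it assumes nothing beyond torch:

import torch
import torch.nn as nn


class TinyBasicBlock(nn.Module):
    """Illustrative residual block showing shortcut vs. residual naming."""

    def __init__(self, chs):
        super().__init__()
        self.conv = nn.Conv2d(chs, chs, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(chs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        shortcut = x        # identity/shortcut path (previously misnamed `residual`)
        out = self.conv(x)
        out = self.bn(out)  # `out` is the residual F(x) the branch learns
        out += shortcut     # y = F(x) + x
        return self.relu(out)


x = torch.randn(2, 8, 16, 16)
assert TinyBasicBlock(8)(x).shape == x.shape

The same reading applies to the renamed constructor arguments: DlaRoot's `shortcut` flag and the `shortcut_root`/`root_shortcut` kwargs indicate whether a root node adds an identity shortcut to its output, rather than implying the tensor being added is itself a residual.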