A bit more ResNet cleanup.

* add inplace=True back
* minor comment improvements
* few clarity changes
pull/19/head
Ross Wightman 5 years ago
parent 33436fafad
commit 3d9be78fc6

@ -79,20 +79,17 @@ class SEModule(nn.Module):
#self.avg_pool = nn.AdaptiveAvgPool2d(1) #self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Conv2d( self.fc1 = nn.Conv2d(
channels, reduction_channels, kernel_size=1, padding=0, bias=True) channels, reduction_channels, kernel_size=1, padding=0, bias=True)
self.relu = nn.ReLU() self.relu = nn.ReLU(inplace=True)
self.fc2 = nn.Conv2d( self.fc2 = nn.Conv2d(
reduction_channels, channels, kernel_size=1, padding=0, bias=True) reduction_channels, channels, kernel_size=1, padding=0, bias=True)
self.sigmoid = nn.Sigmoid()
def forward(self, x): def forward(self, x):
module_input = x #x_se = self.avg_pool(x)
#x = self.avg_pool(x) x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
x = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) x_se = self.fc1(x_se)
x = self.fc1(x) x_se = self.relu(x_se)
x = self.relu(x) x_se = self.fc2(x_se)
x = self.fc2(x) return x * x_se.sigmoid()
x = self.sigmoid(x)
return module_input * x
class BasicBlock(nn.Module): class BasicBlock(nn.Module):
@ -112,7 +109,7 @@ class BasicBlock(nn.Module):
inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation, inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation,
dilation=dilation, bias=False) dilation=dilation, bias=False)
self.bn1 = norm_layer(first_planes) self.bn1 = norm_layer(first_planes)
self.relu = nn.ReLU() self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d( self.conv2 = nn.Conv2d(
first_planes, outplanes, kernel_size=3, padding=previous_dilation, first_planes, outplanes, kernel_size=3, padding=previous_dilation,
dilation=previous_dilation, bias=False) dilation=previous_dilation, bias=False)
@ -164,7 +161,7 @@ class Bottleneck(nn.Module):
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes) self.bn3 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None self.se = SEModule(outplanes, planes // 4) if use_se else None
self.relu = nn.ReLU() self.relu = nn.ReLU(inplace=True)
self.downsample = downsample self.downsample = downsample
self.stride = stride self.stride = stride
self.dilation = dilation self.dilation = dilation
@ -203,13 +200,14 @@ class ResNet(nn.Module):
* have conv-bn-act ordering * have conv-bn-act ordering
This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s
variants included in the MXNet Gluon ResNetV1b model variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the
'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to torchvision default.
ResNet variants: ResNet variants:
* normal - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b * normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b
* c - 3 layer deep 3x3 stem, stem_width = 32 * c - 3 layer deep 3x3 stem, stem_width = 32
* d - 3 layer deep 3x3 stem, stem_width = 32, average pool in downsample * d - 3 layer deep 3x3 stem, stem_width = 32, average pool in downsample
* e - 3 layer deep 3x3 stem, stem_width = 64, average pool in downsample *no pretrained weights available * e - 3 layer deep 3x3 stem, stem_width = 64, average pool in downsample
* s - 3 layer deep 3x3 stem, stem_width = 64 * s - 3 layer deep 3x3 stem, stem_width = 64
ResNeXt ResNeXt
@ -275,31 +273,25 @@ class ResNet(nn.Module):
self.conv1 = nn.Sequential(*[ self.conv1 = nn.Sequential(*[
nn.Conv2d(in_chans, stem_width, 3, stride=2, padding=1, bias=False), nn.Conv2d(in_chans, stem_width, 3, stride=2, padding=1, bias=False),
norm_layer(stem_width), norm_layer(stem_width),
nn.ReLU(), nn.ReLU(inplace=True),
nn.Conv2d(stem_width, stem_width, 3, stride=1, padding=1, bias=False), nn.Conv2d(stem_width, stem_width, 3, stride=1, padding=1, bias=False),
norm_layer(stem_width), norm_layer(stem_width),
nn.ReLU(), nn.ReLU(inplace=True),
nn.Conv2d(stem_width, self.inplanes, 3, stride=1, padding=1, bias=False)]) nn.Conv2d(stem_width, self.inplanes, 3, stride=1, padding=1, bias=False)])
else: else:
self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=7, stride=2, padding=3, bias=False) self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = norm_layer(self.inplanes) self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU() self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
stride_3_4 = 1 if self.dilated else 2 stride_3_4 = 1 if self.dilated else 2
dilation_3 = 2 if self.dilated else 1 dilation_3 = 2 if self.dilated else 1
dilation_4 = 4 if self.dilated else 1 dilation_4 = 4 if self.dilated else 1
self.layer1 = self._make_layer( largs = dict(use_se=use_se, reduce_first=block_reduce_first, norm_layer=norm_layer,
block, 64, layers[0], stride=1, reduce_first=block_reduce_first, avg_down=avg_down, down_kernel_size=down_kernel_size)
use_se=use_se, avg_down=avg_down, down_kernel_size=1, norm_layer=norm_layer) self.layer1 = self._make_layer(block, 64, layers[0], stride=1, **largs)
self.layer2 = self._make_layer( self.layer2 = self._make_layer(block, 128, layers[1], stride=2, **largs)
block, 128, layers[1], stride=2, reduce_first=block_reduce_first, self.layer3 = self._make_layer(block, 256, layers[2], stride=stride_3_4, dilation=dilation_3, **largs)
use_se=use_se, avg_down=avg_down, down_kernel_size=down_kernel_size, norm_layer=norm_layer) self.layer4 = self._make_layer(block, 512, layers[3], stride=stride_3_4, dilation=dilation_4, **largs)
self.layer3 = self._make_layer(
block, 256, layers[2], stride=stride_3_4, dilation=dilation_3, reduce_first=block_reduce_first,
use_se=use_se, avg_down=avg_down, down_kernel_size=down_kernel_size, norm_layer=norm_layer)
self.layer4 = self._make_layer(
block, 512, layers[3], stride=stride_3_4, dilation=dilation_4, reduce_first=block_reduce_first,
use_se=use_se, avg_down=avg_down, down_kernel_size=down_kernel_size, norm_layer=norm_layer)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_features = 512 * block.expansion self.num_features = 512 * block.expansion
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@ -314,6 +306,7 @@ class ResNet(nn.Module):
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1, def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
use_se=False, avg_down=False, down_kernel_size=1, norm_layer=nn.BatchNorm2d): use_se=False, avg_down=False, down_kernel_size=1, norm_layer=nn.BatchNorm2d):
downsample = None downsample = None
down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size
if stride != 1 or self.inplanes != planes * block.expansion: if stride != 1 or self.inplanes != planes * block.expansion:
downsample_padding = _get_padding(down_kernel_size, stride) downsample_padding = _get_padding(down_kernel_size, stride)
downsample_layers = [] downsample_layers = []

Loading…
Cancel
Save