Update resnext init

pull/1/head
Ross Wightman 6 years ago
parent 2295cf56c2
commit 321435e6b4

@ -80,11 +80,10 @@ class ResNeXt(nn.Module):
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1) nn.init.constant_(m.weight, 1.)
m.bias.data.zero_() nn.init.constant_(m.bias, 0.)
def _make_layer(self, block, planes, blocks, stride=1): def _make_layer(self, block, planes, blocks, stride=1):
downsample = None downsample = None

Loading…
Cancel
Save