From 321435e6b43b2e365f82b67c0120660fb54d415d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 10 Mar 2019 14:24:39 -0700 Subject: [PATCH] Update resnext init --- models/resnext.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/models/resnext.py b/models/resnext.py index ee4ea51f..aa435de6 100644 --- a/models/resnext.py +++ b/models/resnext.py @@ -80,11 +80,10 @@ class ResNeXt(nn.Module): for m in self.modules(): if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. / n)) + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() + nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.bias, 0.) def _make_layer(self, block, planes, blocks, stride=1): downsample = None