pull/1063/merge
Michael Monashev 4 years ago committed by GitHub
commit 15bf46e9e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -953,7 +953,7 @@ class BasicBlock(nn.Module):
self.act = nn.Identity() if linear_out else layers.act(inplace=True) self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last: bool = False): def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None: if zero_init_last and self.shortcut is not None and self.conv2_kxk.bn.weight is not None:
nn.init.zeros_(self.conv2_kxk.bn.weight) nn.init.zeros_(self.conv2_kxk.bn.weight)
for attn in (self.attn, self.attn_last): for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'): if hasattr(attn, 'reset_parameters'):
@ -1002,7 +1002,7 @@ class BottleneckBlock(nn.Module):
self.act = nn.Identity() if linear_out else layers.act(inplace=True) self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last: bool = False): def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None: if zero_init_last and self.shortcut is not None and self.conv3_1x1.bn.weight is not None:
nn.init.zeros_(self.conv3_1x1.bn.weight) nn.init.zeros_(self.conv3_1x1.bn.weight)
for attn in (self.attn, self.attn_last): for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'): if hasattr(attn, 'reset_parameters'):
@ -1055,7 +1055,7 @@ class DarkBlock(nn.Module):
self.act = nn.Identity() if linear_out else layers.act(inplace=True) self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last: bool = False): def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None: if zero_init_last and self.shortcut is not None and self.conv2_kxk.bn.weight is not None:
nn.init.zeros_(self.conv2_kxk.bn.weight) nn.init.zeros_(self.conv2_kxk.bn.weight)
for attn in (self.attn, self.attn_last): for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'): if hasattr(attn, 'reset_parameters'):
@ -1516,8 +1516,10 @@ def _init_weights(module, name='', zero_init_last=False):
if module.bias is not None: if module.bias is not None:
nn.init.zeros_(module.bias) nn.init.zeros_(module.bias)
elif isinstance(module, nn.BatchNorm2d): elif isinstance(module, nn.BatchNorm2d):
nn.init.ones_(module.weight) if module.weight is not None:
nn.init.zeros_(module.bias) nn.init.ones_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'): elif hasattr(module, 'init_weights'):
module.init_weights(zero_init_last=zero_init_last) module.init_weights(zero_init_last=zero_init_last)

Loading…
Cancel
Save