@@ -953,7 +953,7 @@ class BasicBlock(nn.Module):
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)

     def init_weights(self, zero_init_last: bool = False):
-        if zero_init_last and self.shortcut is not None:
+        if zero_init_last and self.shortcut is not None and self.conv2_kxk.bn.weight is not None:
             nn.init.zeros_(self.conv2_kxk.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
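The extra guard matters because a `BatchNorm2d` built with `affine=False` exposes its `weight` and `bias` as `None`, so the unconditional `nn.init.zeros_` call would fail. A minimal standalone sketch of the failure mode the patch avoids (not taken from the patched file):

```python
import torch.nn as nn

# A BatchNorm2d created with affine=False has no learnable scale/shift:
# .weight and .bias are both None, so nn.init.zeros_(bn.weight) would
# raise an AttributeError. The added `is not None` check skips that case.
bn = nn.BatchNorm2d(64, affine=False)
print(bn.weight)   # None

bn_affine = nn.BatchNorm2d(64)
nn.init.zeros_(bn_affine.weight)  # fine: weight is a real Parameter
```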
@@ -1002,7 +1002,7 @@ class BottleneckBlock(nn.Module):
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)

     def init_weights(self, zero_init_last: bool = False):
-        if zero_init_last and self.shortcut is not None:
+        if zero_init_last and self.shortcut is not None and self.conv3_1x1.bn.weight is not None:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -1055,7 +1055,7 @@ class DarkBlock(nn.Module):
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)

     def init_weights(self, zero_init_last: bool = False):
-        if zero_init_last and self.shortcut is not None:
+        if zero_init_last and self.shortcut is not None and self.conv2_kxk.bn.weight is not None:
             nn.init.zeros_(self.conv2_kxk.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
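The same one-line guard recurs in `BasicBlock`, `BottleneckBlock`, and `DarkBlock`, differing only in which conv-bn module holds the block's final norm (`conv2_kxk` vs `conv3_1x1`). A hypothetical helper (`_zero_init_bn` is not part of the patch) sketches how the pattern could be consolidated:

```python
import torch.nn as nn

def _zero_init_bn(conv_bn):
    # Hypothetical consolidation of the repeated guard: zero the last
    # BN's scale only when that BN is affine (weight is None when the
    # layer was built with affine=False).
    bn = getattr(conv_bn, 'bn', None)
    if bn is not None and bn.weight is not None:
        nn.init.zeros_(bn.weight)
```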
@@ -1516,8 +1516,10 @@ def _init_weights(module, name='', zero_init_last=False):
         if module.bias is not None:
             nn.init.zeros_(module.bias)
     elif isinstance(module, nn.BatchNorm2d):
-        nn.init.ones_(module.weight)
-        nn.init.zeros_(module.bias)
+        if module.weight is not None:
+            nn.init.ones_(module.weight)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
     elif hasattr(module, 'init_weights'):
         module.init_weights(zero_init_last=zero_init_last)

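`_init_weights` gets the same treatment for generic `BatchNorm2d` modules: each parameter is checked before being initialized. A self-contained sketch of the guarded behavior, assuming only stock PyTorch (`init_bn` is an illustrative name, not from the patch):

```python
import torch.nn as nn

def init_bn(module):
    # Mirrors the patched branch: affine=False BatchNorm2d layers have
    # weight/bias set to None, so each is checked before initialization.
    if isinstance(module, nn.BatchNorm2d):
        if module.weight is not None:
            nn.init.ones_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model = nn.Sequential(nn.BatchNorm2d(8), nn.BatchNorm2d(8, affine=False))
model.apply(init_bn)  # no longer raises on the affine=False layer
```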