@ -953,7 +953,7 @@ class BasicBlock(nn.Module):
self . act = nn . Identity ( ) if linear_out else layers . act ( inplace = True )
def init_weights ( self , zero_init_last : bool = False ) :
if zero_init_last and self . shortcut is not None :
if zero_init_last and self . shortcut is not None and self . conv2_kxk . bn . weight is not None :
nn . init . zeros_ ( self . conv2_kxk . bn . weight )
for attn in ( self . attn , self . attn_last ) :
if hasattr ( attn , ' reset_parameters ' ) :
@ -1002,7 +1002,7 @@ class BottleneckBlock(nn.Module):
self . act = nn . Identity ( ) if linear_out else layers . act ( inplace = True )
def init_weights ( self , zero_init_last : bool = False ) :
if zero_init_last and self . shortcut is not None :
if zero_init_last and self . shortcut is not None and self . conv3_1x1 . bn . weigh is not None :
nn . init . zeros_ ( self . conv3_1x1 . bn . weight )
for attn in ( self . attn , self . attn_last ) :
if hasattr ( attn , ' reset_parameters ' ) :
@ -1055,7 +1055,7 @@ class DarkBlock(nn.Module):
self . act = nn . Identity ( ) if linear_out else layers . act ( inplace = True )
def init_weights ( self , zero_init_last : bool = False ) :
if zero_init_last and self . shortcut is not None :
if zero_init_last and self . shortcut is not None and self . conv2_kxk . bn . weight is not None :
nn . init . zeros_ ( self . conv2_kxk . bn . weight )
for attn in ( self . attn , self . attn_last ) :
if hasattr ( attn , ' reset_parameters ' ) :
@ -1516,8 +1516,10 @@ def _init_weights(module, name='', zero_init_last=False):
if module . bias is not None :
nn . init . zeros_ ( module . bias )
elif isinstance ( module , nn . BatchNorm2d ) :
nn . init . ones_ ( module . weight )
nn . init . zeros_ ( module . bias )
if module . weight is not None :
nn . init . ones_ ( module . weight )
if module . bias is not None :
nn . init . zeros_ ( module . bias )
elif hasattr ( module , ' init_weights ' ) :
module . init_weights ( zero_init_last = zero_init_last )