diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index fa57943a..7a1d77ec 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -953,7 +953,7 @@ class BasicBlock(nn.Module): self.act = nn.Identity() if linear_out else layers.act(inplace=True) def init_weights(self, zero_init_last: bool = False): - if zero_init_last and self.shortcut is not None: + if zero_init_last and self.shortcut is not None and self.conv2_kxk.bn.weight is not None: nn.init.zeros_(self.conv2_kxk.bn.weight) for attn in (self.attn, self.attn_last): if hasattr(attn, 'reset_parameters'): @@ -1002,7 +1002,7 @@ class BottleneckBlock(nn.Module): self.act = nn.Identity() if linear_out else layers.act(inplace=True) def init_weights(self, zero_init_last: bool = False): - if zero_init_last and self.shortcut is not None: + if zero_init_last and self.shortcut is not None and self.conv3_1x1.bn.weight is not None: nn.init.zeros_(self.conv3_1x1.bn.weight) for attn in (self.attn, self.attn_last): if hasattr(attn, 'reset_parameters'): @@ -1055,7 +1055,7 @@ class DarkBlock(nn.Module): self.act = nn.Identity() if linear_out else layers.act(inplace=True) def init_weights(self, zero_init_last: bool = False): - if zero_init_last and self.shortcut is not None: + if zero_init_last and self.shortcut is not None and self.conv2_kxk.bn.weight is not None: nn.init.zeros_(self.conv2_kxk.bn.weight) for attn in (self.attn, self.attn_last): if hasattr(attn, 'reset_parameters'): @@ -1516,8 +1516,10 @@ def _init_weights(module, name='', zero_init_last=False): if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.BatchNorm2d): - nn.init.ones_(module.weight) - nn.init.zeros_(module.bias) + if module.weight is not None: + nn.init.ones_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) elif hasattr(module, 'init_weights'): module.init_weights(zero_init_last=zero_init_last)