From 43d7df3e57cf88ba0517be4d442b9306a2776d0b Mon Sep 17 00:00:00 2001
From: Michael Monashev
Date: Fri, 31 Dec 2021 12:02:39 +0300
Subject: [PATCH] Fix BatchNorm initialisation

The line `norm_layer=partial(BatchNormAct2d, affine=False)` inside the
gernet_s config causes errors like this:

`AttributeError: 'NoneType' object has no attribute 'zero_'`
---
 timm/models/byobnet.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py
index fa57943a..b1ebf0a0 100644
--- a/timm/models/byobnet.py
+++ b/timm/models/byobnet.py
@@ -953,7 +953,7 @@ class BasicBlock(nn.Module):
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
     def init_weights(self, zero_init_last: bool = False):
-        if zero_init_last and self.shortcut is not None:
+        if zero_init_last and self.shortcut is not None and self.conv2_kxk.bn.weight is not None:
             nn.init.zeros_(self.conv2_kxk.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -1002,7 +1002,7 @@ class BottleneckBlock(nn.Module):
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
     def init_weights(self, zero_init_last: bool = False):
-        if zero_init_last and self.shortcut is not None:
+        if zero_init_last and self.shortcut is not None and self.conv3_1x1.bn.weight is not None:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -1055,7 +1055,7 @@ class DarkBlock(nn.Module):
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
     def init_weights(self, zero_init_last: bool = False):
-        if zero_init_last and self.shortcut is not None:
+        if zero_init_last and self.shortcut is not None and self.conv2_kxk.bn.weight is not None:
             nn.init.zeros_(self.conv2_kxk.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -1516,8 +1516,10 @@ def _init_weights(module, name='', zero_init_last=False):
         if module.bias is not None:
             nn.init.zeros_(module.bias)
     elif isinstance(module, nn.BatchNorm2d):
-        nn.init.ones_(module.weight)
-        nn.init.zeros_(module.bias)
+        if module.weight is not None:
+            nn.init.ones_(module.weight)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
     elif hasattr(module, 'init_weights'):
         module.init_weights(zero_init_last=zero_init_last)
 
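
For reference, a minimal sketch of the failure this patch guards against. It uses plain `nn.BatchNorm2d` as a stand-in for timm's `BatchNormAct2d` (which subclasses it); with `affine=False` both skip creating learnable parameters, so the behaviour shown here should match:

```python
import torch.nn as nn

# With affine=False, BatchNorm2d registers no learnable parameters:
# both .weight and .bias are None.
bn = nn.BatchNorm2d(32, affine=False)
assert bn.weight is None and bn.bias is None

# The unguarded init path assumed affine parameters always exist,
# so it fails exactly as the commit message reports.
try:
    nn.init.zeros_(bn.weight)
except AttributeError as err:
    print(err)  # 'NoneType' object has no attribute 'zero_'

# The guarded pattern applied throughout this patch:
# only initialise parameters that actually exist.
if bn.weight is not None:
    nn.init.ones_(bn.weight)
if bn.bias is not None:
    nn.init.zeros_(bn.bias)
```

The same `weight is not None` check is what the patch adds both to the per-block `init_weights` (for the `zero_init_last` path) and to the generic `_init_weights` handler for `nn.BatchNorm2d` modules.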