diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py
index 99350d7c..cc293530 100644
--- a/timm/models/byobnet.py
+++ b/timm/models/byobnet.py
@@ -33,7 +33,7 @@ import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg
+from .helpers import build_model_with_cfg, named_apply
 from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
     create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple
 from .registry import register_model
@@ -166,7 +166,7 @@ class ByoModelCfg:
     stem_chs: int = 32
     width_factor: float = 1.0
     num_features: int = 0  # num out_channels for final conv, no final 1x1 conv if 0
-    zero_init_last_bn: bool = True
+    zero_init_last: bool = True  # zero init last weight (usually bn) in residual path
     fixed_input_size: bool = False  # model constrained to a fixed-input size / img_size must be provided on creation
 
     act_layer: str = 'relu'
@@ -757,8 +757,8 @@ class BasicBlock(nn.Module):
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv2_kxk.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -814,8 +814,8 @@ class BottleneckBlock(nn.Module):
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -871,8 +871,8 @@ class DarkBlock(nn.Module):
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv2_kxk.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -924,8 +924,8 @@ class EdgeBlock(nn.Module):
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv2_1x1.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -967,7 +967,7 @@ class RepVggBlock(nn.Module):
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
         self.act = layers.act(inplace=True)
 
-    def init_weights(self, zero_init_last_bn: bool = False):
+    def init_weights(self, zero_init_last: bool = False):
         # NOTE this init overrides that base model init with specific changes for the block type
         for m in self.modules():
             if isinstance(m, nn.BatchNorm2d):
@@ -1024,8 +1024,8 @@ class SelfAttnBlock(nn.Module):
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
         if hasattr(self.self_attn, 'reset_parameters'):
             self.self_attn.reset_parameters()
@@ -1278,7 +1278,7 @@ class ByobNet(nn.Module):
     Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
     """
     def __init__(self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
-                 zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.):
+                 zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0.):
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
@@ -1309,12 +1309,8 @@ class ByobNet(nn.Module):
 
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
 
-        for n, m in self.named_modules():
-            _init_weights(m, n)
-        for m in self.modules():
-            # call each block's weight init for block-specific overrides to init above
-            if hasattr(m, 'init_weights'):
-                m.init_weights(zero_init_last_bn=zero_init_last_bn)
+        # init weights
+        named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
 
     def get_classifier(self):
         return self.head.fc
@@ -1334,20 +1330,22 @@ class ByobNet(nn.Module):
         return x
 
 
-def _init_weights(m, n=''):
-    if isinstance(m, nn.Conv2d):
-        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-        fan_out //= m.groups
-        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
-        if m.bias is not None:
-            m.bias.data.zero_()
-    elif isinstance(m, nn.Linear):
-        nn.init.normal_(m.weight, mean=0.0, std=0.01)
-        if m.bias is not None:
-            nn.init.zeros_(m.bias)
-    elif isinstance(m, nn.BatchNorm2d):
-        nn.init.ones_(m.weight)
-        nn.init.zeros_(m.bias)
+def _init_weights(module, name='', zero_init_last=False):
+    if isinstance(module, nn.Conv2d):
+        fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
+        fan_out //= module.groups
+        module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+        if module.bias is not None:
+            module.bias.data.zero_()
+    elif isinstance(module, nn.Linear):
+        nn.init.normal_(module.weight, mean=0.0, std=0.01)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.BatchNorm2d):
+        nn.init.ones_(module.weight)
+        nn.init.zeros_(module.bias)
+    elif hasattr(module, 'init_weights'):
+        module.init_weights(zero_init_last=zero_init_last)
 
 
 def _create_byobnet(variant, pretrained=False, **kwargs):
diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py
index feb7decc..c0c619cc 100644
--- a/timm/models/layers/bottleneck_attn.py
+++ b/timm/models/layers/bottleneck_attn.py
@@ -102,6 +102,8 @@ class BottleneckAttn(nn.Module):
 
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
 
+        self.reset_parameters()
+
     def reset_parameters(self):
         trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
         trunc_normal_(self.pos_embed.height_rel, std=self.scale)
diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py
index 337acae8..d298fc0b 100644
--- a/timm/models/layers/halo_attn.py
+++ b/timm/models/layers/halo_attn.py
@@ -123,6 +123,8 @@ class HaloAttn(nn.Module):
         self.pos_embed = PosEmbedRel(
             block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)
 
+        self.reset_parameters()
+
     def reset_parameters(self):
         std = self.q.weight.shape[1] ** -0.5  # fan-in
         trunc_normal_(self.q.weight, std=std)
diff --git a/timm/models/layers/lambda_layer.py b/timm/models/layers/lambda_layer.py
index 2d1027a1..d298c1aa 100644
--- a/timm/models/layers/lambda_layer.py
+++ b/timm/models/layers/lambda_layer.py
@@ -57,6 +57,8 @@ class LambdaLayer(nn.Module):
 
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
 
+        self.reset_parameters()
+
     def reset_parameters(self):
         trunc_normal_(self.qkv.weight, std=self.dim ** -0.5)
         trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)
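Note: the byobnet.py change above collapses the two module loops in ByobNet.__init__ into a single named_apply pass. The sketch below is a rough, self-contained illustration of that init flow, not code from this diff: named_apply here is a simplified stand-in for the helper imported from .helpers, and DemoBlock is a hypothetical block standing in for BasicBlock/BottleneckBlock/etc.

# Rough sketch of the post-change init flow (assumptions: simplified named_apply,
# made-up DemoBlock; only the conv branch of _init_weights is reproduced).
import math
from functools import partial

import torch.nn as nn


def named_apply(fn, module, name=''):
    # depth-first walk that hands fn each submodule along with its dotted name
    for child_name, child in module.named_children():
        child_name = '.'.join((name, child_name)) if name else child_name
        named_apply(fn, child, name=child_name)
        fn(module=child, name=child_name)
    return module


class DemoBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(8, 8, 3, padding=1)
        self.bn = nn.BatchNorm2d(8)

    def init_weights(self, zero_init_last=False):
        # block-specific override: zero the last norm weight in the residual path
        if zero_init_last:
            nn.init.zeros_(self.bn.weight)


def _init_weights(module, name='', zero_init_last=False):
    # same dispatch shape as the new _init_weights: plain layers get the default
    # init, anything exposing init_weights() handles itself
    if isinstance(module, nn.Conv2d):
        fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
        fan_out //= module.groups
        nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
    elif hasattr(module, 'init_weights'):
        module.init_weights(zero_init_last=zero_init_last)


model = nn.Sequential(DemoBlock(), DemoBlock())
named_apply(partial(_init_weights, zero_init_last=True), model)
# every block's last BN weight ends up zeroed, as the zero_init_last flag requests
assert all(bool((block.bn.weight == 0).all()) for block in model)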