Cleanup weight init for byob/byoanet and related

pull/821/head
Ross Wightman 3 years ago
parent 8642401e88
commit 5bd04714e4

@ -33,7 +33,7 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .helpers import build_model_with_cfg, named_apply
from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple
from .registry import register_model
@ -166,7 +166,7 @@ class ByoModelCfg:
stem_chs: int = 32
width_factor: float = 1.0
num_features: int = 0 # num out_channels for final conv, no final 1x1 conv if 0
zero_init_last_bn: bool = True
zero_init_last: bool = True # zero init last weight (usually bn) in residual path
fixed_input_size: bool = False # model constrained to a fixed-input size / img_size must be provided on creation
act_layer: str = 'relu'
@ -757,8 +757,8 @@ class BasicBlock(nn.Module):
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn: bool = False):
if zero_init_last_bn:
def init_weights(self, zero_init_last: bool = False):
if zero_init_last:
nn.init.zeros_(self.conv2_kxk.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
@ -814,8 +814,8 @@ class BottleneckBlock(nn.Module):
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn: bool = False):
if zero_init_last_bn:
def init_weights(self, zero_init_last: bool = False):
if zero_init_last:
nn.init.zeros_(self.conv3_1x1.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
@ -871,8 +871,8 @@ class DarkBlock(nn.Module):
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn: bool = False):
if zero_init_last_bn:
def init_weights(self, zero_init_last: bool = False):
if zero_init_last:
nn.init.zeros_(self.conv2_kxk.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
@ -924,8 +924,8 @@ class EdgeBlock(nn.Module):
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn: bool = False):
if zero_init_last_bn:
def init_weights(self, zero_init_last: bool = False):
if zero_init_last:
nn.init.zeros_(self.conv2_1x1.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
@ -967,7 +967,7 @@ class RepVggBlock(nn.Module):
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
self.act = layers.act(inplace=True)
def init_weights(self, zero_init_last_bn: bool = False):
def init_weights(self, zero_init_last: bool = False):
# NOTE this init overrides that base model init with specific changes for the block type
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
@ -1024,8 +1024,8 @@ class SelfAttnBlock(nn.Module):
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn: bool = False):
if zero_init_last_bn:
def init_weights(self, zero_init_last: bool = False):
if zero_init_last:
nn.init.zeros_(self.conv3_1x1.bn.weight)
if hasattr(self.self_attn, 'reset_parameters'):
self.self_attn.reset_parameters()
@ -1278,7 +1278,7 @@ class ByobNet(nn.Module):
Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
"""
def __init__(self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.):
zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0.):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
@ -1309,12 +1309,8 @@ class ByobNet(nn.Module):
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
for n, m in self.named_modules():
_init_weights(m, n)
for m in self.modules():
# call each block's weight init for block-specific overrides to init above
if hasattr(m, 'init_weights'):
m.init_weights(zero_init_last_bn=zero_init_last_bn)
# init weights
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
def get_classifier(self):
return self.head.fc
@ -1334,20 +1330,22 @@ class ByobNet(nn.Module):
return x
def _init_weights(m, n=''):
if isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, mean=0.0, std=0.01)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def _init_weights(module, name='', zero_init_last=False):
if isinstance(module, nn.Conv2d):
fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
fan_out //= module.groups
module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.01)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'):
module.init_weights(zero_init_last=zero_init_last)
def _create_byobnet(variant, pretrained=False, **kwargs):

@ -102,6 +102,8 @@ class BottleneckAttn(nn.Module):
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
self.reset_parameters()
def reset_parameters(self):
trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
trunc_normal_(self.pos_embed.height_rel, std=self.scale)

@ -123,6 +123,8 @@ class HaloAttn(nn.Module):
self.pos_embed = PosEmbedRel(
block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)
self.reset_parameters()
def reset_parameters(self):
std = self.q.weight.shape[1] ** -0.5 # fan-in
trunc_normal_(self.q.weight, std=std)

@ -57,6 +57,8 @@ class LambdaLayer(nn.Module):
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
self.reset_parameters()
def reset_parameters(self):
trunc_normal_(self.qkv.weight, std=self.dim ** -0.5)
trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)

Loading…
Cancel
Save