@@ -33,7 +33,7 @@ import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg
+from .helpers import build_model_with_cfg, named_apply
 from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
     create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple
 from .registry import register_model
@@ -166,7 +166,7 @@ class ByoModelCfg:
     stem_chs: int = 32
     width_factor: float = 1.0
     num_features: int = 0  # num out_channels for final conv, no final 1x1 conv if 0
-    zero_init_last_bn: bool = True
+    zero_init_last: bool = True  # zero init last weight (usually bn) in residual path
     fixed_input_size: bool = False  # model constrained to a fixed-input size / img_size must be provided on creation
 
     act_layer: str = 'relu'
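
The renamed field matches its new comment: the weight that gets zeroed is the last one in the residual path, which is usually but not always a BatchNorm gamma. A purely illustrative snippet using the renamed field (the `ByoBlockCfg` values here are made up, not taken from the diff):

```python
# Hypothetical example; ByoBlockCfg / ByoModelCfg are the dataclasses defined in byobnet.py.
cfg = ByoModelCfg(
    blocks=(ByoBlockCfg(type='basic', d=2, c=64),),
    stem_chs=32,
    zero_init_last=True,  # zero the last weight (usually bn) in each residual path
)
```
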
@@ -757,8 +757,8 @@ class BasicBlock(nn.Module):
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv2_kxk.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
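
For reference, zeroing the last norm weight means each residual block initially passes its input through unchanged (the residual branch contributes nothing until training moves gamma off zero); the same rename is applied to the remaining block types below. A minimal sketch of the effect on a generic conv-bn residual branch (not timm code):

```python
import torch
import torch.nn as nn

# Zeroing the last BN gamma makes the residual branch output zero at init,
# so the block starts out as an identity mapping.
conv = nn.Conv2d(8, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
nn.init.zeros_(bn.weight)  # what zero_init_last does to the block's final norm

x = torch.randn(2, 8, 4, 4)
out = x + bn(conv(x))          # residual branch contributes zero initially
print(torch.allclose(out, x))  # True
```
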
@@ -814,8 +814,8 @@ class BottleneckBlock(nn.Module):
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -871,8 +871,8 @@ class DarkBlock(nn.Module):
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv2_kxk.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -924,8 +924,8 @@ class EdgeBlock(nn.Module):
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv2_1x1.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -967,7 +967,7 @@ class RepVggBlock(nn.Module):
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
         self.act = layers.act(inplace=True)
 
-    def init_weights(self, zero_init_last_bn: bool = False):
+    def init_weights(self, zero_init_last: bool = False):
         # NOTE this init overrides that base model init with specific changes for the block type
         for m in self.modules():
             if isinstance(m, nn.BatchNorm2d):
@@ -1024,8 +1024,8 @@ class SelfAttnBlock(nn.Module):
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
-    def init_weights(self, zero_init_last_bn: bool = False):
-        if zero_init_last_bn:
+    def init_weights(self, zero_init_last: bool = False):
+        if zero_init_last:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
         if hasattr(self.self_attn, 'reset_parameters'):
             self.self_attn.reset_parameters()
@@ -1278,7 +1278,7 @@ class ByobNet(nn.Module):
     Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
     """
     def __init__(self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
-                 zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.):
+                 zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0.):
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
@@ -1309,12 +1309,8 @@ class ByobNet(nn.Module):
 
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
 
-        for n, m in self.named_modules():
-            _init_weights(m, n)
-        for m in self.modules():
-            # call each block's weight init for block-specific overrides to init above
-            if hasattr(m, 'init_weights'):
-                m.init_weights(zero_init_last_bn=zero_init_last_bn)
+        # init weights
+        named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
 
     def get_classifier(self):
         return self.head.fc
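
The two explicit loops are replaced by a single `named_apply` pass (imported from `.helpers` in the first hunk), which visits every submodule along with its qualified name, so one `_init_weights` function can cover both the generic init and the per-block `init_weights` overrides. A rough sketch of what such a helper does, assuming a depth-first recursion like timm's (not copied from the library):

```python
import torch.nn as nn
from typing import Callable

def named_apply(fn: Callable, module: nn.Module, name: str = '', include_root: bool = False) -> nn.Module:
    # Depth-first walk: recurse into children first, then (optionally) call fn on this module.
    for child_name, child in module.named_children():
        full_name = f'{name}.{child_name}' if name else child_name
        named_apply(fn, child, name=full_name, include_root=True)
    if include_root:
        fn(module=module, name=name)
    return module
```

With `partial(_init_weights, zero_init_last=zero_init_last)` as `fn`, each call receives `module` and `name` plus the bound flag, matching the new `_init_weights` signature below.
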
@@ -1334,20 +1330,22 @@ class ByobNet(nn.Module):
         return x
 
 
-def _init_weights(m, n=''):
-    if isinstance(m, nn.Conv2d):
-        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-        fan_out //= m.groups
-        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
-        if m.bias is not None:
-            m.bias.data.zero_()
-    elif isinstance(m, nn.Linear):
-        nn.init.normal_(m.weight, mean=0.0, std=0.01)
-        if m.bias is not None:
-            nn.init.zeros_(m.bias)
-    elif isinstance(m, nn.BatchNorm2d):
-        nn.init.ones_(m.weight)
-        nn.init.zeros_(m.bias)
+def _init_weights(module, name='', zero_init_last=False):
+    if isinstance(module, nn.Conv2d):
+        fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
+        fan_out //= module.groups
+        module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+        if module.bias is not None:
+            module.bias.data.zero_()
+    elif isinstance(module, nn.Linear):
+        nn.init.normal_(module.weight, mean=0.0, std=0.01)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.BatchNorm2d):
+        nn.init.ones_(module.weight)
+        nn.init.zeros_(module.bias)
+    elif hasattr(module, 'init_weights'):
+        module.init_weights(zero_init_last=zero_init_last)
 
 
 def _create_byobnet(variant, pretrained=False, **kwargs):
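
The new `elif hasattr(module, 'init_weights')` branch is what now triggers each block's own `init_weights` override during the single `named_apply` pass. Assuming the kwarg is forwarded to `ByobNet.__init__` through `build_model_with_cfg` as usual, the renamed flag can be toggled at creation time roughly like this (illustrative, using one of the byobnet variants registered in this file):

```python
import timm

# zero_init_last (formerly zero_init_last_bn) is a ByobNet.__init__ kwarg,
# so it can be overridden when building a byobnet-based variant.
model = timm.create_model('repvgg_b0', pretrained=False, zero_init_last=False)
```
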