@ -25,9 +25,9 @@ above nets that include attention.
Hacked together by / copyright Ross Wightman, 2021.
import math
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from collections import OrderedDict
from typing import Tuple, Dict, Optional, Union, Any, Callable
from typing import Tuple, List, Optional, Union, Any, Callable, Sequence
from functools import partial
import torch
@ -35,11 +35,11 @@ import torch.nn as nn
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, ConvBnAct, DropPath, AvgPool2dSame, \
create_conv2d, get_act_layer, get_attn, convert_norm_act, make_divisible
from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible
from .registry import register_model
__all__ = ['ByobNet', 'ByobCfg', 'BlocksCfg']
__all__ = ['ByobNet', 'ByobCfg', 'BlocksCfg', 'create_byob_stem', 'create_block']
def _cfg(url='', **kwargs):
@ -98,20 +98,22 @@ class BlocksCfg:
s: int = 2 # stride of stage (first block)
gs: Optional[Union[int, Callable]] = None # group-size of blocks in stage, conv is depthwise if gs == 1
br: float = 1. # bottleneck-ratio of blocks in stage
no_attn: bool = True # disable channel attn (ie SE) when layer is set for model
class ByobCfg:
blocks: Tuple[BlocksCfg, ...]
blocks: Tuple[Union[BlocksCfg, Tuple[BlocksCfg, ...]], ...]
downsample: str = 'conv1x1'
stem_type: str = '3x3'
stem_pool: str = ''
stem_chs: int = 32
width_factor: float = 1.0
num_features: int = 0 # num out_channels for final conv, no final 1x1 conv if 0
zero_init_last_bn: bool = True
act_layer: str = 'relu'
norm_layer: nn.Module = nn.BatchNorm2d
norm_layer: str = 'batchnorm'
attn_layer: Optional[str] = None
attn_kwargs: dict = field(default_factory=lambda: dict())
@ -201,17 +203,29 @@ model_cfgs = dict(
def _na_args(cfg: dict):
return dict(
norm_layer=cfg.get('norm_layer', nn.BatchNorm2d),
act_layer=cfg.get('act_layer', nn.ReLU))
BlocksCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
BlocksCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
BlocksCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
BlocksCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
def _ex_tuple(cfg: dict, *names):
return tuple([cfg.get(n, None) for n in names])
def expand_blocks_cfg(stage_blocks_cfg: Union['BlocksCfg', Sequence['BlocksCfg']]) -> List['BlocksCfg']:
    """ Flatten a stage config into one config entry per block.

    Args:
        stage_blocks_cfg: a single block config or a sequence of them. Each config's
            `d` field is the number of times it is repeated.

    Returns:
        List containing, for every input config in order, `d` copies of it with `d` reset to 1.
    """
    if not isinstance(stage_blocks_cfg, Sequence):
        # a lone config is treated as a one-element stage
        stage_blocks_cfg = (stage_blocks_cfg,)
    block_cfgs = []
    for cfg in stage_blocks_cfg:
        block_cfgs.extend(replace(cfg, d=1) for _ in range(cfg.d))
    return block_cfgs
def num_groups(group_size, channels):
@ -223,27 +237,36 @@ def num_groups(group_size, channels):
return channels // group_size
@dataclass
class LayerFn:
    """ Bundle of layer factories handed to blocks and stems.

    Declared as a dataclass so that it can be built both with no arguments (all defaults)
    and with keyword overrides, which is how the rest of this file constructs it.
    """
    # factory for a conv + norm + activation layer
    conv_norm_act: Callable = ConvBnAct
    # factory for a norm + activation layer
    norm_act: Callable = BatchNormAct2d
    # activation layer class, called with `inplace=True` by the blocks
    act: Callable = nn.ReLU
    # optional attention layer factory, called with a channel count; None disables attention
    attn: Optional[Callable] = None
class DownsampleAvg(nn.Module):
    """ Average-pool followed by a 1x1 conv-norm(-act) layer. """

    def __init__(self, in_chs, out_chs, stride=1, dilation=1, apply_act=False, layers: 'LayerFn' = None):
        """ AvgPool Downsampling as in 'D' ResNet variants.

        Args:
            in_chs: input channels of the 1x1 conv.
            out_chs: output channels of the 1x1 conv.
            stride: pooling stride; only used for pooling when `dilation == 1`.
            dilation: when > 1 the pool runs with stride 1 instead of `stride`.
            apply_act: forwarded to the conv-norm-act factory.
            layers: layer factories; a default `LayerFn()` is used when None.
        """
        super(DownsampleAvg, self).__init__()
        layers = layers or LayerFn()
        avg_stride = stride if dilation == 1 else 1
        if stride > 1 or dilation > 1:
            # stride-1 pooling under dilation uses the 'same'-padded pool variant
            avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
            self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
        else:
            self.pool = nn.Identity()
        self.conv = layers.conv_norm_act(in_chs, out_chs, 1, apply_act=apply_act)

    def forward(self, x):
        """ Pool `x`, then apply the 1x1 conv layer. """
        return self.conv(self.pool(x))
def create_downsample(downsample_type, layers: 'LayerFn', **kwargs):
    """ Build the shortcut / downsample layer for a block.

    Args:
        downsample_type: 'avg' selects `DownsampleAvg`; any other value selects a 1x1
            conv-norm-act layer from `layers`.
        layers: layer factories used to build the conv layers.
        **kwargs: must contain 'in_chs' and 'out_chs'; everything else is forwarded to the
            chosen layer.

    Returns:
        The constructed downsample module.
    """
    if downsample_type == 'avg':
        # pass the factories through so the avg path honours the configured norm / act
        # instead of silently falling back to the LayerFn defaults
        return DownsampleAvg(layers=layers, **kwargs)
    in_chs = kwargs.pop('in_chs')
    out_chs = kwargs.pop('out_chs')
    return layers.conv_norm_act(in_chs, out_chs, kernel_size=1, **kwargs)
class BasicBlock(nn.Module):
@ -252,28 +275,25 @@ class BasicBlock(nn.Module):
def __init__(
        self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), group_size=None, bottle_ratio=1.0,
        downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
    """ Build the two kxk conv layers, shortcut, attention, drop-path and output activation.

    Args:
        in_chs: input channels.
        out_chs: output channels.
        kernel_size: kernel size of both conv layers.
        stride: stride of the first conv and of the shortcut.
        dilation: (first, second) dilation; a mismatch forces a non-identity shortcut.
        group_size: passed to `num_groups` to derive the group count of the second conv.
        bottle_ratio: scales `out_chs` to obtain the intermediate channel count.
        downsample: downsample type forwarded to `create_downsample`.
        linear_out: when True the output activation is an identity.
        layers: layer factories; a default `LayerFn()` is used when None.
        drop_block: forwarded to the second conv layer.
        drop_path_rate: a `DropPath` is created only when this is > 0.
    """
    super(BasicBlock, self).__init__()
    layers = layers or LayerFn()
    mid_chs = make_divisible(out_chs * bottle_ratio)
    groups = num_groups(group_size, mid_chs)

    if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
        self.shortcut = create_downsample(
            downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
            apply_act=False, layers=layers)
    else:
        self.shortcut = nn.Identity()

    self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0])
    self.conv2_kxk = layers.conv_norm_act(
        mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block, apply_act=False)
    self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
    self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
    self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
if zero_init_last_bn:
@ -297,29 +317,27 @@ class BottleneckBlock(nn.Module):
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
             downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
    """ Build the 1x1 -> kxk -> 1x1 conv stack, shortcut, attention, drop-path and output activation.

    Args:
        in_chs: input channels.
        out_chs: output channels.
        kernel_size: kernel size of the middle conv.
        stride: stride of the middle conv and of the shortcut.
        dilation: (first, second) dilation; a mismatch forces a non-identity shortcut.
        bottle_ratio: scales `out_chs` to obtain the intermediate channel count.
        group_size: passed to `num_groups` to derive the group count of the middle conv.
        downsample: downsample type forwarded to `create_downsample`.
        linear_out: when True the output activation is an identity.
        layers: layer factories; a default `LayerFn()` is used when None.
        drop_block: forwarded to the middle conv layer.
        drop_path_rate: a `DropPath` is created only when this is > 0.
    """
    super(BottleneckBlock, self).__init__()
    layers = layers or LayerFn()
    mid_chs = make_divisible(out_chs * bottle_ratio)
    groups = num_groups(group_size, mid_chs)

    if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
        self.shortcut = create_downsample(
            downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
            apply_act=False, layers=layers)
    else:
        self.shortcut = nn.Identity()

    self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
    self.conv2_kxk = layers.conv_norm_act(
        mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
        groups=groups, drop_block=drop_block)
    # attention is sized for the bottleneck width
    self.attn = nn.Identity() if layers.attn is None else layers.attn(mid_chs)
    self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
    self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
    self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
if zero_init_last_bn:
@ -350,28 +368,26 @@ class DarkBlock(nn.Module):
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
             downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
    """ Build the 1x1 -> kxk conv stack, shortcut, attention, drop-path and output activation.

    Args:
        in_chs: input channels.
        out_chs: output channels.
        kernel_size: kernel size of the second conv.
        stride: stride of the second conv and of the shortcut.
        dilation: (first, second) dilation; a mismatch forces a non-identity shortcut.
        bottle_ratio: scales `out_chs` to obtain the intermediate channel count.
        group_size: passed to `num_groups` to derive the group count of the second conv.
        downsample: downsample type forwarded to `create_downsample`.
        linear_out: when True the output activation is an identity.
        layers: layer factories; a default `LayerFn()` is used when None.
        drop_block: forwarded to the second conv layer.
        drop_path_rate: a `DropPath` is created only when this is > 0.
    """
    super(DarkBlock, self).__init__()
    layers = layers or LayerFn()
    mid_chs = make_divisible(out_chs * bottle_ratio)
    groups = num_groups(group_size, mid_chs)

    if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
        self.shortcut = create_downsample(
            downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
            apply_act=False, layers=layers)
    else:
        self.shortcut = nn.Identity()

    self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
    self.conv2_kxk = layers.conv_norm_act(
        mid_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
        groups=groups, drop_block=drop_block, apply_act=False)
    self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
    self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
    self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
if zero_init_last_bn:
@ -399,28 +415,26 @@ class EdgeBlock(nn.Module):
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
             downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
    """ Build the kxk -> 1x1 conv stack, shortcut, attention, drop-path and output activation.

    Args:
        in_chs: input channels.
        out_chs: output channels.
        kernel_size: kernel size of the first conv.
        stride: stride of the first conv and of the shortcut.
        dilation: (first, second) dilation; a mismatch forces a non-identity shortcut.
        bottle_ratio: scales `out_chs` to obtain the intermediate channel count.
        group_size: passed to `num_groups` to derive the group count of the first conv.
        downsample: downsample type forwarded to `create_downsample`.
        linear_out: when True the output activation is an identity.
        layers: layer factories; a default `LayerFn()` is used when None.
        drop_block: forwarded to the first conv layer.
        drop_path_rate: a `DropPath` is created only when this is > 0.
    """
    super(EdgeBlock, self).__init__()
    layers = layers or LayerFn()
    mid_chs = make_divisible(out_chs * bottle_ratio)
    groups = num_groups(group_size, mid_chs)

    if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
        self.shortcut = create_downsample(
            downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
            apply_act=False, layers=layers)
    else:
        self.shortcut = nn.Identity()

    self.conv1_kxk = layers.conv_norm_act(
        in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
        groups=groups, drop_block=drop_block)
    # NOTE(review): attn is sized for out_chs although it is declared between the
    # mid_chs-wide conv1 and conv2 -- confirm against forward() where it is applied
    self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
    self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
    self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
    self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
if zero_init_last_bn:
@ -446,23 +460,20 @@ class RepVggBlock(nn.Module):
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
             downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
    """ Build the parallel kxk / 1x1 / identity branches, attention, drop-path and activation.

    Args:
        in_chs: input channels.
        out_chs: output channels.
        kernel_size: kernel size of the kxk branch.
        stride: stride of both conv branches.
        dilation: (first, second) dilation; the identity branch requires both to be equal.
        bottle_ratio: not used here; accepted so all block types share one signature.
        group_size: passed to `num_groups` (with `in_chs`) to derive the conv group count.
        downsample: not used here; accepted so all block types share one signature.
        layers: layer factories; a default `LayerFn()` is used when None.
        drop_block: forwarded to the kxk conv layer.
        drop_path_rate: a `DropPath` is created only when this is > 0 and the identity branch exists.
    """
    super(RepVggBlock, self).__init__()
    layers = layers or LayerFn()
    groups = num_groups(group_size, in_chs)

    # the norm-only identity branch exists only when the block keeps shape and dilation
    use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1]
    self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None
    self.conv_kxk = layers.conv_norm_act(
        in_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
        groups=groups, drop_block=drop_block, apply_act=False)
    self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False)
    self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
    self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
    self.act = layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
# NOTE this init overrides that base model init with specific changes for the block type
@ -504,33 +515,200 @@ def create_block(block: Union[str, nn.Module], **kwargs):
return _block_registry[block](**kwargs)
def create_stem(in_chs, out_chs, stem_type='', layer_cfg=None):
layer_cfg = layer_cfg or {}
layer_args = _na_args(layer_cfg)
assert stem_type in ('', 'deep', 'deep_tiered', '3x3', '7x7', 'rep')
if 'deep' in stem_type:
# 3 deep 3x3 conv stack
stem = OrderedDict()
stem_chs = (out_chs // 2, out_chs // 2)
if 'tiered' in stem_type:
stem_chs = (3 * stem_chs[0] // 4, stem_chs[1])
norm_layer, act_layer = _ex_tuple(layer_args, 'norm_layer', 'act_layer')
stem['conv1'] = create_conv2d(in_chs, stem_chs[0], kernel_size=3, stride=2)
stem['conv2'] = create_conv2d(stem_chs[0], stem_chs[1], kernel_size=3, stride=1)
stem['conv3'] = create_conv2d(stem_chs[1], out_chs, kernel_size=3, stride=1)
norm_act_layer = convert_norm_act(norm_layer=norm_layer, act_layer=act_layer)
stem['na'] = norm_act_layer(out_chs)
stem = nn.Sequential(stem)
# class Stem(nn.Module):
# def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
# num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
# super().__init__()
# assert stride in (2, 4)
# if pool:
# assert stride == 4
# layers = layers or LayerFn()
# if isinstance(out_chs, (list, tuple)):
# num_rep = len(out_chs)
# stem_chs = out_chs
# else:
# stem_chs = [round(out_chs * chs_decay ** i) for i in range(num_rep)][::-1]
# self.stride = stride
# stem_strides = [2] + [1] * (num_rep - 1)
# if stride == 4 and not pool:
# # set last conv in stack to be strided if stride == 4 and no pooling layer
# stem_strides[-1] = 2
# num_act = num_rep if num_act is None else num_act
# # if num_act < num_rep, first convs in stack won't have bn + act
# stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
# prev_chs = in_chs
# convs = []
# for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
# layer_fn = layers.conv_norm_act if na else create_conv2d
# convs.append(layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
# prev_chs = ch
# self.conv = nn.Sequential(*convs) if len(convs) > 1 else convs[0]
# if not pool:
# self.pool = nn.Identity()
# elif 'max' in pool.lower():
# self.pool = nn.MaxPool2d(3, 2, 1) if pool else nn.Identity()
# else:
# assert False, "Unknown pooling type"
# def forward(self, x):
# x = self.conv(x)
# x = self.pool(x)
# return x
class Stem(nn.Sequential):
    """ Stack of conv layers with an optional max-pool, built as a named `nn.Sequential`.

    Sub-modules are registered as 'conv1', 'conv2', ... and optionally 'pool'.
    `feature_info` lists one dict (num_chs, reduction, module) for the last module at
    each stride level, finishing with the stem output.
    """

    def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
                 num_rep=3, num_act=None, chs_decay=0.5, layers: 'LayerFn' = None):
        """
        Args:
            in_chs: input channels.
            out_chs: output channels, or a list / tuple giving the channels of every conv
                (in which case `num_rep` is taken from its length).
            kernel_size: kernel size of every conv.
            stride: total stride of the stem, 2 or 4.
            pool: pooling type; a value containing 'max' appends a stride-2 max pool.
            num_rep: number of convs when `out_chs` is a scalar.
            num_act: number of trailing convs that get norm + act; all of them when None.
            chs_decay: per-layer channel decay applied backwards from `out_chs`.
            layers: layer factories; a default `LayerFn()` is used when None.
        """
        super().__init__()
        assert stride in (2, 4)
        layers = layers or LayerFn()

        if isinstance(out_chs, (list, tuple)):
            num_rep = len(out_chs)
            stem_chs = out_chs
        else:
            stem_chs = [round(out_chs * chs_decay ** i) for i in range(num_rep)][::-1]

        self.stride = stride
        self.feature_info = []  # track intermediate features
        prev_feat = ''
        stem_strides = [2] + [1] * (num_rep - 1)
        if stride == 4 and not pool:
            # set last conv in stack to be strided if stride == 4 and no pooling layer
            stem_strides[-1] = 2

        num_act = num_rep if num_act is None else num_act
        # if num_act < num_rep, first convs in stack won't have bn + act
        stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
        prev_chs = in_chs
        curr_stride = 1
        for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
            layer_fn = layers.conv_norm_act if na else create_conv2d
            conv_name = f'conv{i + 1}'
            if i > 0 and s > 1:
                # the recorded feature is the output of the previous module, so its width
                # is prev_chs, not the width of the conv about to be added
                self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
            self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
            prev_chs = ch
            curr_stride *= s
            prev_feat = conv_name

        # `pool` may be None or '' to disable pooling
        if pool and 'max' in pool.lower():
            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
            self.add_module('pool', nn.MaxPool2d(3, 2, 1))
            curr_stride *= 2
            prev_feat = 'pool'

        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
        assert curr_stride == stride
def create_byob_stem(in_chs, out_chs, stem_type='', pool_type='', feat_prefix='stem', layers: LayerFn = None):
    """ Create the network stem and the feature info describing it.

    Args:
        in_chs: input channels.
        out_chs: stem output channels.
        stem_type: one of '', 'quad', 'quad2', 'tiered', 'deep', 'rep', '7x7', '3x3'.
        pool_type: pooling type forwarded to `Stem`; for the '7x7' and default '3x3' stems
            an empty value selects a plain stride-2 conv layer instead of a `Stem`.
        feat_prefix: module-name prefix used in the returned feature info.
        layers: layer factories; a default `LayerFn()` is used when None.

    Returns:
        Tuple `(stem, feature_info)` where `feature_info` is a list of dicts with the keys
        'num_chs', 'reduction' and 'module'.
    """
    layers = layers or LayerFn()
    # 'quad2' must be accepted here, otherwise the quad2 handling below is unreachable
    assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', '7x7', '3x3')
    if 'quad' in stem_type:
        # based on NFNet stem, stack of 4 3x3 convs
        num_act = 2 if 'quad2' in stem_type else None
        stem = Stem(in_chs, out_chs, num_rep=4, num_act=num_act, pool=pool_type, layers=layers)
    elif 'tiered' in stem_type:
        # 3x3 stack of 3 convs as in my ResNet-T
        stem = Stem(in_chs, (3 * out_chs // 8, out_chs // 2, out_chs), pool=pool_type, layers=layers)
    elif 'deep' in stem_type:
        # 3x3 stack of 3 convs as in ResNet-D
        stem = Stem(in_chs, out_chs, num_rep=3, chs_decay=1.0, pool=pool_type, layers=layers)
    elif 'rep' in stem_type:
        stem = RepVggBlock(in_chs, out_chs, stride=2, layers=layers)
    elif '7x7' in stem_type:
        # 7x7 stem conv as in ResNet
        if pool_type:
            stem = Stem(in_chs, out_chs, 7, num_rep=1, pool=pool_type, layers=layers)
        else:
            stem = layers.conv_norm_act(in_chs, out_chs, 7, stride=2)
    else:
        # 3x3 stem conv as in RegNet is the default
        if pool_type:
            stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers)
        else:
            stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2)

    if isinstance(stem, Stem):
        # re-root the stem's own feature names under the given prefix
        feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info]
    else:
        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix)]
    return stem, feature_info
def reduce_feat_size(feat_size, stride=2):
    """ Floor-divide every dimension of a feature size by `stride`.

    Args:
        feat_size: iterable of ints, or None when the size is unknown.
        stride: divisor applied to each dimension.

    Returns:
        Tuple of reduced dimensions, or None when `feat_size` is None.
    """
    if feat_size is None:
        return None
    return tuple(s // stride for s in feat_size)
def create_byob_stages(
cfg, drop_path_rate, output_stride, stem_feat,
feat_size=None, layers=None, extra_args_fn=None):
layers = layers or LayerFn()
feature_info = []
block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks]
depths = [sum([bc.d for bc in stage_bcs]) for stage_bcs in block_cfgs]
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
dilation = 1
net_stride = stem_feat['reduction']
prev_chs = stem_feat['num_chs']
prev_feat = stem_feat
stages = []
for stage_idx, stage_block_cfgs in enumerate(block_cfgs):
stride = stage_block_cfgs[0].s
if stride != 1 and prev_feat:
if net_stride >= output_stride and stride > 1:
dilation *= stride
stride = 1
net_stride *= stride
first_dilation = 1 if dilation in (1, 2) else 2
blocks = []
for block_idx, block_cfg in enumerate(stage_block_cfgs):
out_chs = make_divisible(block_cfg.c * cfg.width_factor)
group_size = block_cfg.gs
if isinstance(group_size, Callable):
group_size = group_size(out_chs, block_idx)
block_kwargs = dict( # Blocks used in this model must accept these arguments
stride=stride if block_idx == 0 else 1,
dilation=(first_dilation, dilation),
if extra_args_fn is not None:
extra_args_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg, feat_size=feat_size)
blocks += [create_block(block_cfg.type, **block_kwargs)]
first_dilation = dilation
prev_chs = out_chs
if stride > 1 and block_idx == 0:
feat_size = reduce_feat_size(feat_size, stride)
stages += [nn.Sequential(*blocks)]
prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')
return nn.Sequential(*stages), feature_info
def get_layer_fns(cfg: ByobCfg):
    """ Build the `LayerFn` factory bundle described by a model config.

    Reads `cfg.act_layer`, `cfg.norm_layer`, `cfg.attn_layer` and `cfg.attn_kwargs`.

    Returns:
        A `LayerFn` whose conv / norm factories are bound to the configured norm and
        activation layers, and whose `attn` is None when `cfg.attn_layer` is falsy.
    """
    act = get_act_layer(cfg.act_layer)
    norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act)
    conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act)
    if cfg.attn_layer:
        attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs)
    else:
        attn = None
    return LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn)
class ByobNet(nn.Module):
@ -546,79 +724,30 @@ class ByobNet(nn.Module):
self.num_classes = num_classes
self.drop_rate = drop_rate
norm_layer = cfg.norm_layer
act_layer = get_act_layer(cfg.act_layer)
attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
layer_cfg = dict(norm_layer=norm_layer, act_layer=act_layer, attn_layer=attn_layer)
layers = get_layer_fns(cfg)
self.feature_info = []
stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
self.stem = create_stem(in_chans, stem_chs, cfg.stem_type, layer_cfg=layer_cfg)
self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
self.feature_info = []
depths = [bc.d for bc in cfg.blocks]
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
prev_name = 'stem'
prev_chs = stem_chs
net_stride = 2
dilation = 1
stages = []
for stage_idx, block_cfg in enumerate(cfg.blocks):
stride = block_cfg.s
if stride != 1:
self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=prev_name))
if net_stride >= output_stride and stride > 1:
dilation *= stride
stride = 1
net_stride *= stride
first_dilation = 1 if dilation in (1, 2) else 2
blocks = []
for block_idx in range(block_cfg.d):
out_chs = make_divisible(block_cfg.c * cfg.width_factor)
group_size = block_cfg.gs
if isinstance(group_size, Callable):
group_size = group_size(out_chs, block_idx)
block_kwargs = dict( # Blocks used in this model must accept these arguments
stride=stride if block_idx == 0 else 1,
dilation=(first_dilation, dilation),
blocks += [create_block(block_cfg.type, **block_kwargs)]
first_dilation = dilation
prev_chs = out_chs
stages += [nn.Sequential(*blocks)]
prev_name = f'stages.{stage_idx}'
self.stages = nn.Sequential(*stages)
self.stages, stage_feat = create_byob_stages(cfg, drop_path_rate, output_stride, stem_feat[-1], layers=layers)
prev_chs = stage_feat[-1]['num_chs']
if cfg.num_features:
self.num_features = int(round(cfg.width_factor * cfg.num_features))
self.final_conv = ConvBnAct(prev_chs, self.num_features, 1, **_na_args(layer_cfg))
self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1)
self.num_features = prev_chs
self.final_conv = nn.Identity()
self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_conv')]
self.feature_info += [
dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
for n, m in self.named_modules():
if isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, mean=0.0, std=0.01)
elif isinstance(m, nn.BatchNorm2d):
_init_weights(m, n)
for m in self.modules():
# call each block's weight init for block-specific overrides to init above
if hasattr(m, 'init_weights'):
@ -642,6 +771,22 @@ class ByobNet(nn.Module):
return x
def _init_weights(m, n=''):
if isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, mean=0.0, std=0.01)
if m.bias is not None:
elif isinstance(m, nn.BatchNorm2d):
def _create_byobnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
ByobNet, variant, pretrained,