@@ -25,9 +25,9 @@ above nets that include attention.
 Hacked together by / copyright Ross Wightman, 2021.
 """
 import math
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from collections import OrderedDict
-from typing import Tuple, Dict, Optional, Union, Any, Callable
+from typing import Tuple, List, Optional, Union, Any, Callable, Sequence
 from functools import partial

 import torch
@@ -35,11 +35,11 @@ import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import ClassifierHead, ConvBnAct, DropPath, AvgPool2dSame, \
-    create_conv2d, get_act_layer, get_attn, convert_norm_act, make_divisible
+from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
+    create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible
 from .registry import register_model

-__all__ = ['ByobNet', 'ByobCfg', 'BlocksCfg']
+__all__ = ['ByobNet', 'ByobCfg', 'BlocksCfg', 'create_byob_stem', 'create_block']


 def _cfg(url='', **kwargs):
@@ -98,20 +98,22 @@ class BlocksCfg:
     s: int = 2  # stride of stage (first block)
     gs: Optional[Union[int, Callable]] = None  # group-size of blocks in stage, conv is depthwise if gs == 1
     br: float = 1.  # bottleneck-ratio of blocks in stage
     no_attn: bool = True  # disable channel attn (ie SE) when layer is set for model


 @dataclass
 class ByobCfg:
-    blocks: Tuple[BlocksCfg, ...]
+    blocks: Tuple[Union[BlocksCfg, Tuple[BlocksCfg, ...]], ...]
     downsample: str = 'conv1x1'
     stem_type: str = '3x3'
+    stem_pool: str = ''
     stem_chs: int = 32
     width_factor: float = 1.0
     num_features: int = 0  # num out_channels for final conv, no final 1x1 conv if 0
     zero_init_last_bn: bool = True

     act_layer: str = 'relu'
-    norm_layer: nn.Module = nn.BatchNorm2d
+    norm_layer: str = 'batchnorm'
     attn_layer: Optional[str] = None
     attn_kwargs: dict = field(default_factory=lambda: dict())
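Note: a `blocks` entry may now be a single `BlocksCfg` or a tuple of them (a mixed-block stage), and `norm_layer` is a string resolved later via `convert_norm_act`. A minimal sketch of a config under the new schema (the values are illustrative, not part of this patch):

    cfg = ByobCfg(
        blocks=(
            BlocksCfg(type='basic', d=2, c=64, s=1),
            (BlocksCfg(type='basic', d=1, c=128, s=2),  # mixed stage, stride taken from first cfg
             BlocksCfg(type='bottle', d=2, c=128, s=1, br=0.25)),
        ),
        stem_type='tiered',
        stem_pool='maxpool',
        norm_layer='batchnorm',
    )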
@@ -201,17 +203,29 @@ model_cfgs = dict(
         stem_type='rep',
         stem_chs=64,
     ),
-)
-
-
-def _na_args(cfg: dict):
-    return dict(
-        norm_layer=cfg.get('norm_layer', nn.BatchNorm2d),
-        act_layer=cfg.get('act_layer', nn.ReLU))
+    resnet52q=ByobCfg(
+        blocks=(
+            BlocksCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
+            BlocksCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
+            BlocksCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
+            BlocksCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
+        ),
+        stem_chs=128,
+        stem_type='quad',
+        num_features=2048,
+        act_layer='silu',
+    ),
+)


-def _ex_tuple(cfg: dict, *names):
-    return tuple([cfg.get(n, None) for n in names])
+def expand_blocks_cfg(stage_blocks_cfg: Union[BlocksCfg, Sequence[BlocksCfg]]) -> List[BlocksCfg]:
+    if not isinstance(stage_blocks_cfg, Sequence):
+        stage_blocks_cfg = (stage_blocks_cfg,)
+    block_cfgs = []
+    for i, cfg in enumerate(stage_blocks_cfg):
+        block_cfgs += [replace(cfg, d=1) for _ in range(cfg.d)]
+    return block_cfgs


 def num_groups(group_size, channels):
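Note: `expand_blocks_cfg` replaces the `_na_args`/`_ex_tuple` dict helpers on the config side; it flattens a stage's depth into one cfg per block using `dataclasses.replace`. A quick sketch of the expected behaviour, assuming the definitions above:

    stage = BlocksCfg(type='bottle', d=3, c=256, s=2, gs=32, br=0.25)
    per_block = expand_blocks_cfg(stage)
    assert len(per_block) == 3                 # one entry per unit of depth
    assert all(bc.d == 1 for bc in per_block)  # depth collapsed to 1 each
    assert per_block[0].c == 256               # other fields carried over unchanged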
@@ -223,27 +237,36 @@ def num_groups(group_size, channels):
         return channels // group_size


+@dataclass
+class LayerFn:
+    conv_norm_act: Callable = ConvBnAct
+    norm_act: Callable = BatchNormAct2d
+    act: Callable = nn.ReLU
+    attn: Optional[Callable] = None
+
+
 class DownsampleAvg(nn.Module):
-    def __init__(self, in_chs, out_chs, stride=1, dilation=1, apply_act=False, norm_layer=None, act_layer=None):
+    def __init__(self, in_chs, out_chs, stride=1, dilation=1, apply_act=False, layers: LayerFn = None):
         """ AvgPool Downsampling as in 'D' ResNet variants."""
         super(DownsampleAvg, self).__init__()
+        layers = layers or LayerFn()
         avg_stride = stride if dilation == 1 else 1
         if stride > 1 or dilation > 1:
             avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
             self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
         else:
             self.pool = nn.Identity()
-        self.conv = ConvBnAct(in_chs, out_chs, 1, apply_act=apply_act, norm_layer=norm_layer, act_layer=act_layer)
+        self.conv = layers.conv_norm_act(in_chs, out_chs, 1, apply_act=apply_act)

     def forward(self, x):
         return self.conv(self.pool(x))


-def create_downsample(type, **kwargs):
-    if type == 'avg':
-        return DownsampleAvg(**kwargs)
+def create_downsample(downsample_type, layers: LayerFn, **kwargs):
+    if downsample_type == 'avg':
+        return DownsampleAvg(layers=layers, **kwargs)
     else:
-        return ConvBnAct(kwargs.pop('in_chs'), kwargs.pop('out_chs'), kernel_size=1, **kwargs)
+        return layers.conv_norm_act(kwargs.pop('in_chs'), kwargs.pop('out_chs'), kernel_size=1, **kwargs)


 class BasicBlock(nn.Module):
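Note: `LayerFn` is now the single carrier for layer constructors; blocks receive it instead of unpacking a `layer_cfg` dict. A hedged sketch of building one by hand, mirroring what `get_layer_fns` does later in this patch (the SiLU/SE choices are illustrative):

    from functools import partial

    layers = LayerFn(
        conv_norm_act=partial(ConvBnAct, norm_layer=nn.BatchNorm2d, act_layer=nn.SiLU),
        norm_act=BatchNormAct2d,
        act=nn.SiLU,
        attn=get_attn('se'),  # or None to disable channel attn
    )
    shortcut = create_downsample(
        'avg', layers=layers, in_chs=64, out_chs=128, stride=2, dilation=1, apply_act=False)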
@@ -252,28 +275,25 @@ class BasicBlock(nn.Module):

     def __init__(
             self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), group_size=None, bottle_ratio=1.0,
-            downsample='avg', linear_out=False, layer_cfg=None, drop_block=None, drop_path_rate=0.):
+            downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
         super(BasicBlock, self).__init__()
-        layer_cfg = layer_cfg or {}
-        act_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'attn_layer')
-        layer_args = _na_args(layer_cfg)
+        layers = layers or LayerFn()
         mid_chs = make_divisible(out_chs * bottle_ratio)
         groups = num_groups(group_size, mid_chs)

         if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
             self.shortcut = create_downsample(
                 downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
-                apply_act=False, **layer_args)
+                apply_act=False, layers=layers)
         else:
             self.shortcut = nn.Identity()

-        self.conv1_kxk = ConvBnAct(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], **layer_args)
-        self.conv2_kxk = ConvBnAct(
-            mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups,
-            drop_block=drop_block, apply_act=False, **layer_args)
-        self.attn = nn.Identity() if attn_layer is None else attn_layer(out_chs)
+        self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0])
+        self.conv2_kxk = layers.conv_norm_act(
+            mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block, apply_act=False)
+        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
-        self.act = nn.Identity() if linear_out else act_layer(inplace=True)
+        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

     def init_weights(self, zero_init_last_bn=False):
         if zero_init_last_bn:
@@ -297,29 +317,27 @@ class BottleneckBlock(nn.Module):
     """

     def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
-                 downsample='avg', linear_out=False, layer_cfg=None, drop_block=None, drop_path_rate=0.):
+                 downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
         super(BottleneckBlock, self).__init__()
-        layer_cfg = layer_cfg or {}
-        act_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'attn_layer')
-        layer_args = _na_args(layer_cfg)
+        layers = layers or LayerFn()
         mid_chs = make_divisible(out_chs * bottle_ratio)
         groups = num_groups(group_size, mid_chs)

         if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
             self.shortcut = create_downsample(
                 downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
-                apply_act=False, **layer_args)
+                apply_act=False, layers=layers)
         else:
             self.shortcut = nn.Identity()

-        self.conv1_1x1 = ConvBnAct(in_chs, mid_chs, 1, **layer_args)
-        self.conv2_kxk = ConvBnAct(
+        self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
+        self.conv2_kxk = layers.conv_norm_act(
             mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
-            groups=groups, drop_block=drop_block, **layer_args)
-        self.attn = nn.Identity() if attn_layer is None else attn_layer(mid_chs)
-        self.conv3_1x1 = ConvBnAct(mid_chs, out_chs, 1, apply_act=False, **layer_args)
+            groups=groups, drop_block=drop_block)
+        self.attn = nn.Identity() if layers.attn is None else layers.attn(mid_chs)
+        self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
-        self.act = nn.Identity() if linear_out else act_layer(inplace=True)
+        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

     def init_weights(self, zero_init_last_bn=False):
         if zero_init_last_bn:
@@ -350,28 +368,26 @@ class DarkBlock(nn.Module):
     """

     def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-                 downsample='avg', linear_out=False, layer_cfg=None, drop_block=None, drop_path_rate=0.):
+                 downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
         super(DarkBlock, self).__init__()
-        layer_cfg = layer_cfg or {}
-        act_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'attn_layer')
-        layer_args = _na_args(layer_cfg)
+        layers = layers or LayerFn()
         mid_chs = make_divisible(out_chs * bottle_ratio)
         groups = num_groups(group_size, mid_chs)

         if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
             self.shortcut = create_downsample(
                 downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
-                apply_act=False, **layer_args)
+                apply_act=False, layers=layers)
         else:
             self.shortcut = nn.Identity()

-        self.conv1_1x1 = ConvBnAct(in_chs, mid_chs, 1, **layer_args)
-        self.conv2_kxk = ConvBnAct(
+        self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
+        self.conv2_kxk = layers.conv_norm_act(
             mid_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
-            groups=groups, drop_block=drop_block, apply_act=False, **layer_args)
-        self.attn = nn.Identity() if attn_layer is None else attn_layer(out_chs)
+            groups=groups, drop_block=drop_block, apply_act=False)
+        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
-        self.act = nn.Identity() if linear_out else act_layer(inplace=True)
+        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

     def init_weights(self, zero_init_last_bn=False):
         if zero_init_last_bn:
@@ -399,28 +415,26 @@ class EdgeBlock(nn.Module):
     """

     def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-                 downsample='avg', linear_out=False, layer_cfg=None, drop_block=None, drop_path_rate=0.):
+                 downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
         super(EdgeBlock, self).__init__()
-        layer_cfg = layer_cfg or {}
-        act_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'attn_layer')
-        layer_args = _na_args(layer_cfg)
+        layers = layers or LayerFn()
         mid_chs = make_divisible(out_chs * bottle_ratio)
         groups = num_groups(group_size, mid_chs)

         if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
             self.shortcut = create_downsample(
                 downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
-                apply_act=False, **layer_args)
+                apply_act=False, layers=layers)
         else:
             self.shortcut = nn.Identity()

-        self.conv1_kxk = ConvBnAct(
+        self.conv1_kxk = layers.conv_norm_act(
             in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
-            groups=groups, drop_block=drop_block, **layer_args)
-        self.attn = nn.Identity() if attn_layer is None else attn_layer(out_chs)
-        self.conv2_1x1 = ConvBnAct(mid_chs, out_chs, 1, apply_act=False, **layer_args)
+            groups=groups, drop_block=drop_block)
+        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
+        self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
-        self.act = nn.Identity() if linear_out else act_layer(inplace=True)
+        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

     def init_weights(self, zero_init_last_bn=False):
         if zero_init_last_bn:
@@ -446,23 +460,20 @@ class RepVggBlock(nn.Module):
     """

     def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-                 downsample='', layer_cfg=None, drop_block=None, drop_path_rate=0.):
+                 downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
         super(RepVggBlock, self).__init__()
-        layer_cfg = layer_cfg or {}
-        act_layer, norm_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'norm_layer', 'attn_layer')
-        norm_layer = convert_norm_act(norm_layer=norm_layer, act_layer=act_layer)
-        layer_args = _na_args(layer_cfg)
+        layers = layers or LayerFn()
         groups = num_groups(group_size, in_chs)

         use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1]
-        self.identity = norm_layer(out_chs, apply_act=False) if use_ident else None
-        self.conv_kxk = ConvBnAct(
+        self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None
+        self.conv_kxk = layers.conv_norm_act(
             in_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
-            groups=groups, drop_block=drop_block, apply_act=False, **layer_args)
-        self.conv_1x1 = ConvBnAct(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False, **layer_args)
-        self.attn = nn.Identity() if attn_layer is None else attn_layer(out_chs)
+            groups=groups, drop_block=drop_block, apply_act=False)
+        self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False)
+        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
-        self.act = act_layer(inplace=True)
+        self.act = layers.act(inplace=True)

     def init_weights(self, zero_init_last_bn=False):
         # NOTE this init overrides that base model init with specific changes for the block type
@@ -504,33 +515,200 @@ def create_block(block: Union[str, nn.Module], **kwargs):
     return _block_registry[block](**kwargs)


-def create_stem(in_chs, out_chs, stem_type='', layer_cfg=None):
-    layer_cfg = layer_cfg or {}
-    layer_args = _na_args(layer_cfg)
-    assert stem_type in ('', 'deep', 'deep_tiered', '3x3', '7x7', 'rep')
-    if 'deep' in stem_type:
-        # 3 deep 3x3 conv stack
-        stem = OrderedDict()
-        stem_chs = (out_chs // 2, out_chs // 2)
-        if 'tiered' in stem_type:
-            stem_chs = (3 * stem_chs[0] // 4, stem_chs[1])
-        norm_layer, act_layer = _ex_tuple(layer_args, 'norm_layer', 'act_layer')
-        stem['conv1'] = create_conv2d(in_chs, stem_chs[0], kernel_size=3, stride=2)
-        stem['conv2'] = create_conv2d(stem_chs[0], stem_chs[1], kernel_size=3, stride=1)
-        stem['conv3'] = create_conv2d(stem_chs[1], out_chs, kernel_size=3, stride=1)
-        norm_act_layer = convert_norm_act(norm_layer=norm_layer, act_layer=act_layer)
-        stem['na'] = norm_act_layer(out_chs)
-        stem = nn.Sequential(stem)
+class Stem(nn.Sequential):
+
+    def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
+                 num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
+        super().__init__()
+        assert stride in (2, 4)
+        layers = layers or LayerFn()
+
+        if isinstance(out_chs, (list, tuple)):
+            num_rep = len(out_chs)
+            stem_chs = out_chs
+        else:
+            stem_chs = [round(out_chs * chs_decay ** i) for i in range(num_rep)][::-1]
+
+        self.stride = stride
+        self.feature_info = []  # track intermediate features
+        prev_feat = ''
+        stem_strides = [2] + [1] * (num_rep - 1)
+        if stride == 4 and not pool:
+            # set last conv in stack to be strided if stride == 4 and no pooling layer
+            stem_strides[-1] = 2
+
+        num_act = num_rep if num_act is None else num_act
+        # if num_act < num_rep, first convs in stack won't have bn + act
+        stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
+        prev_chs = in_chs
+        curr_stride = 1
+        for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
+            layer_fn = layers.conv_norm_act if na else create_conv2d
+            conv_name = f'conv{i + 1}'
+            if i > 0 and s > 1:
+                self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+            self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
+            prev_chs = ch
+            curr_stride *= s
+            prev_feat = conv_name
+
+        if 'max' in pool.lower():
+            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+            self.add_module('pool', nn.MaxPool2d(3, 2, 1))
+            curr_stride *= 2
+            prev_feat = 'pool'
+
+        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+        assert curr_stride == stride
+
+
+def create_byob_stem(in_chs, out_chs, stem_type='', pool_type='', feat_prefix='stem', layers: LayerFn = None):
+    layers = layers or LayerFn()
+    assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', '7x7', '3x3')
+    if 'quad' in stem_type:
+        # based on NFNet stem, stack of 4 3x3 convs
+        num_act = 2 if 'quad2' in stem_type else None
+        stem = Stem(in_chs, out_chs, num_rep=4, num_act=num_act, pool=pool_type, layers=layers)
+    elif 'tiered' in stem_type:
+        # 3x3 stack of 3 convs as in my ResNet-T
+        stem = Stem(in_chs, (3 * out_chs // 8, out_chs // 2, out_chs), pool=pool_type, layers=layers)
+    elif 'deep' in stem_type:
+        # 3x3 stack of 3 convs as in ResNet-D
+        stem = Stem(in_chs, out_chs, num_rep=3, chs_decay=1.0, pool=pool_type, layers=layers)
+    elif 'rep' in stem_type:
+        stem = RepVggBlock(in_chs, out_chs, stride=2, layers=layers)
     elif '7x7' in stem_type:
         # 7x7 stem conv as in ResNet
-        stem = ConvBnAct(in_chs, out_chs, 7, stride=2, **layer_args)
-    elif 'rep' in stem_type:
-        stem = RepVggBlock(in_chs, out_chs, stride=2, layer_cfg=layer_cfg)
+        if pool_type:
+            stem = Stem(in_chs, out_chs, 7, num_rep=1, pool=pool_type, layers=layers)
+        else:
+            stem = layers.conv_norm_act(in_chs, out_chs, 7, stride=2)
     else:
-        # 3x3 stem conv as in RegNet
-        stem = ConvBnAct(in_chs, out_chs, 3, stride=2, **layer_args)
+        # 3x3 stem conv as in RegNet is the default
+        if pool_type:
+            stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers)
+        else:
+            stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2)

-    return stem
+    if isinstance(stem, Stem):
+        feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info]
+    else:
+        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix)]
+    return stem, feature_info
+
+
+def reduce_feat_size(feat_size, stride=2):
+    return None if feat_size is None else tuple([s // stride for s in feat_size])
+
+
+def create_byob_stages(
+        cfg, drop_path_rate, output_stride, stem_feat,
+        feat_size=None, layers=None, extra_args_fn=None):
+    layers = layers or LayerFn()
+    feature_info = []
+    block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks]
+    depths = [sum([bc.d for bc in stage_bcs]) for stage_bcs in block_cfgs]
+    dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+    dilation = 1
+    net_stride = stem_feat['reduction']
+    prev_chs = stem_feat['num_chs']
+    prev_feat = stem_feat
+    stages = []
+    for stage_idx, stage_block_cfgs in enumerate(block_cfgs):
+        stride = stage_block_cfgs[0].s
+        if stride != 1 and prev_feat:
+            feature_info.append(prev_feat)
+        if net_stride >= output_stride and stride > 1:
+            dilation *= stride
+            stride = 1
+        net_stride *= stride
+        first_dilation = 1 if dilation in (1, 2) else 2
+
+        blocks = []
+        for block_idx, block_cfg in enumerate(stage_block_cfgs):
+            out_chs = make_divisible(block_cfg.c * cfg.width_factor)
+            group_size = block_cfg.gs
+            if isinstance(group_size, Callable):
+                group_size = group_size(out_chs, block_idx)
+            block_kwargs = dict(  # Blocks used in this model must accept these arguments
+                in_chs=prev_chs,
+                out_chs=out_chs,
+                stride=stride if block_idx == 0 else 1,
+                dilation=(first_dilation, dilation),
+                group_size=group_size,
+                bottle_ratio=block_cfg.br,
+                downsample=cfg.downsample,
+                drop_path_rate=dpr[stage_idx][block_idx],
+                layers=layers,
+            )
+            if extra_args_fn is not None:
+                extra_args_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg, feat_size=feat_size)
+            blocks += [create_block(block_cfg.type, **block_kwargs)]
+            first_dilation = dilation
+            prev_chs = out_chs
+            if stride > 1 and block_idx == 0:
+                feat_size = reduce_feat_size(feat_size, stride)
+
+        stages += [nn.Sequential(*blocks)]
+        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')
+
+    feature_info.append(prev_feat)
+    return nn.Sequential(*stages), feature_info
+
+
+def get_layer_fns(cfg: ByobCfg):
+    act = get_act_layer(cfg.act_layer)
+    norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act)
+    conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act)
+    attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
+    layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn)
+    return layer_fn


 class ByobNet(nn.Module):
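Note: stems are now built via a helper that returns both the module and its feature metadata. A small usage sketch (assumes a 224x224 input and the default `LayerFn`, i.e. ConvBnAct + BatchNorm + ReLU):

    stem, feat_info = create_byob_stem(3, 128, stem_type='quad')
    x = torch.randn(1, 3, 224, 224)
    print(stem(x).shape)  # torch.Size([1, 128, 56, 56]) -- the 'quad' stem reduces by stride 4
    print(feat_info[-1])  # {'num_chs': 128, 'reduction': 4, 'module': 'stem.conv4'}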
@@ -546,79 +724,30 @@ class ByobNet(nn.Module):
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
-        norm_layer = cfg.norm_layer
-        act_layer = get_act_layer(cfg.act_layer)
-        attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
-        layer_cfg = dict(norm_layer=norm_layer, act_layer=act_layer, attn_layer=attn_layer)
+        layers = get_layer_fns(cfg)

+        self.feature_info = []
         stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
-        self.stem = create_stem(in_chans, stem_chs, cfg.stem_type, layer_cfg=layer_cfg)
+        self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
+        self.feature_info.extend(stem_feat[:-1])

-        self.feature_info = []
-        depths = [bc.d for bc in cfg.blocks]
-        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
-        prev_name = 'stem'
-        prev_chs = stem_chs
-        net_stride = 2
-        dilation = 1
-        stages = []
-        for stage_idx, block_cfg in enumerate(cfg.blocks):
-            stride = block_cfg.s
-            if stride != 1:
-                self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=prev_name))
-            if net_stride >= output_stride and stride > 1:
-                dilation *= stride
-                stride = 1
-            net_stride *= stride
-            first_dilation = 1 if dilation in (1, 2) else 2
-
-            blocks = []
-            for block_idx in range(block_cfg.d):
-                out_chs = make_divisible(block_cfg.c * cfg.width_factor)
-                group_size = block_cfg.gs
-                if isinstance(group_size, Callable):
-                    group_size = group_size(out_chs, block_idx)
-                block_kwargs = dict(  # Blocks used in this model must accept these arguments
-                    in_chs=prev_chs,
-                    out_chs=out_chs,
-                    stride=stride if block_idx == 0 else 1,
-                    dilation=(first_dilation, dilation),
-                    group_size=group_size,
-                    bottle_ratio=block_cfg.br,
-                    downsample=cfg.downsample,
-                    drop_path_rate=dpr[stage_idx][block_idx],
-                    layer_cfg=layer_cfg,
-                )
-                blocks += [create_block(block_cfg.type, **block_kwargs)]
-                first_dilation = dilation
-                prev_chs = out_chs
-            stages += [nn.Sequential(*blocks)]
-            prev_name = f'stages.{stage_idx}'
-        self.stages = nn.Sequential(*stages)
+        self.stages, stage_feat = create_byob_stages(cfg, drop_path_rate, output_stride, stem_feat[-1], layers=layers)
+        self.feature_info.extend(stage_feat[:-1])

+        prev_chs = stage_feat[-1]['num_chs']
         if cfg.num_features:
             self.num_features = int(round(cfg.width_factor * cfg.num_features))
-            self.final_conv = ConvBnAct(prev_chs, self.num_features, 1, **_na_args(layer_cfg))
+            self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1)
         else:
             self.num_features = prev_chs
             self.final_conv = nn.Identity()
-        self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_conv')]
+        self.feature_info += [
+            dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]

         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

         for n, m in self.named_modules():
-            if isinstance(m, nn.Conv2d):
-                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                fan_out //= m.groups
-                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
-                if m.bias is not None:
-                    m.bias.data.zero_()
-            elif isinstance(m, nn.Linear):
-                nn.init.normal_(m.weight, mean=0.0, std=0.01)
-                nn.init.zeros_(m.bias)
-            elif isinstance(m, nn.BatchNorm2d):
-                nn.init.ones_(m.weight)
-                nn.init.zeros_(m.bias)
+            _init_weights(m, n)
         for m in self.modules():
             # call each block's weight init for block-specific overrides to init above
             if hasattr(m, 'init_weights'):
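Note: `feature_info` is now stitched together from `stem_feat[:-1]`, `stage_feat[:-1]` and a final `final_conv` entry rather than tracked inline. A hedged inspection sketch (`resnet52q` is defined earlier in this patch; any `model_cfgs` entry should work):

    model = ByobNet(model_cfgs['resnet52q'])
    for f in model.feature_info:
        print(f['module'], f['num_chs'], f['reduction'])
    # stem entries first, then 'stages.0' up to the second-to-last stage, then 'final_conv'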
@@ -642,6 +771,22 @@ class ByobNet(nn.Module):
         return x


+def _init_weights(m, n=''):
+    if isinstance(m, nn.Conv2d):
+        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+        fan_out //= m.groups
+        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+        if m.bias is not None:
+            m.bias.data.zero_()
+    elif isinstance(m, nn.Linear):
+        nn.init.normal_(m.weight, mean=0.0, std=0.01)
+        if m.bias is not None:
+            nn.init.zeros_(m.bias)
+    elif isinstance(m, nn.BatchNorm2d):
+        nn.init.ones_(m.weight)
+        nn.init.zeros_(m.bias)
+
+
 def _create_byobnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         ByobNet, variant, pretrained,