diff --git a/timm/models/__init__.py b/timm/models/__init__.py index cdba50e5..3ed8bdb3 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -1,3 +1,4 @@ +from .byoanet import * from .byobnet import * from .cspnet import * from .densenet import * @@ -39,5 +40,4 @@ from .helpers import load_checkpoint, resume_checkpoint, model_parameters from .layers import TestTimePoolHead, apply_test_time_pool from .layers import convert_splitbn_model from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit -from .registry import * - +from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py new file mode 100644 index 00000000..935b6309 --- /dev/null +++ b/timm/models/byoanet.py @@ -0,0 +1,427 @@ +""" Bring-Your-Own-Attention Network + +A flexible network w/ dataclass based config for stacking NN blocks including +self-attention (or similar) layers. + +Currently used to implement experimential variants of: + * Bottleneck Transformers + * Lambda ResNets + * HaloNets + +Consider all of the models here a WIP and likely to change. + +Hacked together by / copyright Ross Wightman, 2021. +""" +import math +from dataclasses import dataclass, field +from collections import OrderedDict +from typing import Tuple, List, Optional, Union, Any, Callable +from functools import partial + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .byobnet import BlocksCfg, ByobCfg, create_byob_stem, create_byob_stages, create_downsample,\ + reduce_feat_size, register_block, num_groups, LayerFn, _init_weights +from .helpers import build_model_with_cfg +from .layers import ClassifierHead, ConvBnAct, DropPath, get_act_layer, convert_norm_act, get_attn, get_self_attn,\ + make_divisible, to_2tuple +from .registry import register_model + +__all__ = ['ByoaNet'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = { + # GPU-Efficient (ResNet) weights + 'botnet50t_224': _cfg(url=''), + 'botnet50t_c4c5_224': _cfg(url=''), + + 'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'halonet26t': _cfg(url=''), + 'halonet50t': _cfg(url=''), + + 'lambda_resnet26t': _cfg(url=''), + 'lambda_resnet50t': _cfg(url=''), +} + + +@dataclass +class ByoaBlocksCfg(BlocksCfg): + # FIXME allow overriding self_attn layer or args per block/stage, + pass + + +@dataclass +class ByoaCfg(ByobCfg): + blocks: Tuple[Union[ByoaBlocksCfg, Tuple[ByoaBlocksCfg, ...]], ...] 
= None + self_attn_layer: Optional[str] = None + self_attn_fixed_size: bool = False + self_attn_kwargs: dict = field(default_factory=lambda: dict()) + + +def interleave_attn( + types : Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs +) -> Tuple[ByoaBlocksCfg]: + """ interleave attn blocks + """ + assert len(types) == 2 + if isinstance(every, int): + every = list(range(0 if first else every, d, every)) + if not every: + every = [d - 1] + set(every) + blocks = [] + for i in range(d): + block_type = types[1] if i in every else types[0] + blocks += [ByoaBlocksCfg(type=block_type, d=1, **kwargs)] + return tuple(blocks) + + +model_cfgs = dict( + + botnet50t=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='bottle', d=6, c=1024, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='', + num_features=0, + self_attn_layer='bottleneck', + self_attn_fixed_size=True, + self_attn_kwargs=dict() + ), + botnet50t_c4c5=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), + ( + ByoaBlocksCfg(type='self_attn', d=1, c=1024, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='bottle', d=5, c=1024, s=1, gs=0, br=0.25), + ), + ( + ByoaBlocksCfg(type='self_attn', d=1, c=2048, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='bottle', d=2, c=2048, s=1, gs=0, br=0.25), + ) + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + self_attn_layer='bottleneck', + self_attn_fixed_size=True, + self_attn_kwargs=dict() + ), + + halonet_h1=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0), + ByoaBlocksCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0), + ByoaBlocksCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0), + ByoaBlocksCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0), + ), + stem_chs=64, + stem_type='7x7', + stem_pool='maxpool', + num_features=0, + self_attn_layer='halo', + self_attn_kwargs=dict(block_size=8, halo_size=3), + ), + halonet_h1_c4c5=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='bottle', d=3, c=64, s=1, gs=0, br=1.0), + ByoaBlocksCfg(type='bottle', d=3, c=128, s=2, gs=0, br=1.0), + ByoaBlocksCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0), + ByoaBlocksCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + self_attn_layer='halo', + self_attn_kwargs=dict(block_size=8, halo_size=3), + ), + halonet26t=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), + ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='bottle', d=2, c=1024, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + self_attn_layer='halo', + self_attn_kwargs=dict(block_size=7, halo_size=2) + ), + halonet50t=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='bottle', d=6, c=1024, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + self_attn_layer='halo', + 
self_attn_kwargs=dict(block_size=7, halo_size=2) + ), + + lambda_resnet26t=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), + ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25), + interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + self_attn_layer='lambda', + self_attn_kwargs=dict() + ), + lambda_resnet50t=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), + interleave_attn(types=('bottle', 'self_attn'), every=3, d=6, c=1024, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + self_attn_layer='lambda', + self_attn_kwargs=dict() + ), +) + + +@dataclass +class ByoaLayerFn(LayerFn): + self_attn: Optional[Callable] = None + + +class SelfAttnBlock(nn.Module): + """ ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1 + """ + + def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None, + downsample='avg', extra_conv=False, linear_out=False, post_attn_na=True, feat_size=None, + layers: ByoaLayerFn = None, drop_block=None, drop_path_rate=0.): + super(SelfAttnBlock, self).__init__() + assert layers is not None + mid_chs = make_divisible(out_chs * bottle_ratio) + groups = num_groups(group_size, mid_chs) + + if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: + self.shortcut = create_downsample( + downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], + apply_act=False, layers=layers) + else: + self.shortcut = nn.Identity() + + self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) + if extra_conv: + self.conv2_kxk = layers.conv_norm_act( + mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], + groups=groups, drop_block=drop_block) + stride = 1 # striding done via conv if enabled + else: + self.conv2_kxk = nn.Identity() + opt_kwargs = {} if feat_size is None else dict(feat_size=feat_size) + # FIXME need to dilate self attn to have dilated network support, moop moop + self.self_attn = layers.self_attn(mid_chs, stride=stride, **opt_kwargs) + self.post_attn = layers.norm_act(mid_chs) if post_attn_na else nn.Identity() + self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. 
else nn.Identity() + self.act = nn.Identity() if linear_out else layers.act(inplace=True) + + def init_weights(self, zero_init_last_bn=False): + if zero_init_last_bn: + nn.init.zeros_(self.conv3_1x1.bn.weight) + + def forward(self, x): + shortcut = self.shortcut(x) + + x = self.conv1_1x1(x) + x = self.conv2_kxk(x) + x = self.self_attn(x) + x = self.post_attn(x) + x = self.conv3_1x1(x) + x = self.drop_path(x) + + x = self.act(x + shortcut) + return x + +register_block('self_attn', SelfAttnBlock) + + +def _byoa_block_args(block_kwargs, block_cfg: ByoaBlocksCfg, model_cfg: ByoaCfg, feat_size=None): + if block_cfg.type == 'self_attn' and model_cfg.self_attn_fixed_size: + assert feat_size is not None + block_kwargs['feat_size'] = feat_size + return block_kwargs + + +def get_layer_fns(cfg: ByoaCfg): + act = get_act_layer(cfg.act_layer) + norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act) + conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act) + attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None + self_attn = partial(get_self_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None + layer_fn = ByoaLayerFn( + conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn) + return layer_fn + + +class ByoaNet(nn.Module): + """ 'Bring-your-own-attention' Net + + A ResNet inspired backbone that supports interleaving traditional residual blocks with + 'Self Attention' bottleneck blocks that replace the bottleneck kxk conv w/ a self-attention + or similar module. + + FIXME This class network definition is almost the same as ByobNet, I'd like to merge them but + torchscript limitations prevent sensible inheritance overrides. + """ + def __init__(self, cfg: ByoaCfg, num_classes=1000, in_chans=3, output_stride=32, global_pool='avg', + zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.): + super().__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + layers = get_layer_fns(cfg) + feat_size = to_2tuple(img_size) if img_size is not None else None + + self.feature_info = [] + stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor)) + self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers) + self.feature_info.extend(stem_feat[:-1]) + feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction']) + + self.stages, stage_feat = create_byob_stages( + cfg, drop_path_rate, output_stride, stem_feat[-1], + feat_size=feat_size, layers=layers, extra_args_fn=_byoa_block_args) + self.feature_info.extend(stage_feat[:-1]) + + prev_chs = stage_feat[-1]['num_chs'] + if cfg.num_features: + self.num_features = int(round(cfg.width_factor * cfg.num_features)) + self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1) + else: + self.num_features = prev_chs + self.final_conv = nn.Identity() + self.feature_info += [ + dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')] + + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + for n, m in self.named_modules(): + _init_weights(m, n) + for m in self.modules(): + # call each block's weight init for block-specific overrides to init above + if hasattr(m, 'init_weights'): + m.init_weights(zero_init_last_bn=zero_init_last_bn) + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, 
global_pool='avg'): + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.final_conv(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs): + return build_model_with_cfg( + ByoaNet, variant, pretrained, + default_cfg=default_cfgs[variant], + model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], + feature_cfg=dict(flatten_sequential=True), + **kwargs) + + +@register_model +def botnet50t_224(pretrained=False, **kwargs): + """ + """ + kwargs.setdefault('img_size', 224) + return _create_byoanet('botnet50t_224', 'botnet50t', pretrained=pretrained, **kwargs) + + +@register_model +def botnet50t_c4c5_224(pretrained=False, **kwargs): + """ + """ + kwargs.setdefault('img_size', 224) + return _create_byoanet('botnet50t_c4c5_224', 'botnet50t_c4c5', pretrained=pretrained, **kwargs) + + +@register_model +def halonet_h1(pretrained=False, **kwargs): + """ + """ + return _create_byoanet('halonet_h1', pretrained=pretrained, **kwargs) + + +@register_model +def halonet_h1_c4c5(pretrained=False, **kwargs): + """ + """ + return _create_byoanet('halonet_h1_c4c5', pretrained=pretrained, **kwargs) + + +@register_model +def halonet26t(pretrained=False, **kwargs): + """ + """ + return _create_byoanet('halonet26t', pretrained=pretrained, **kwargs) + + +@register_model +def halonet50t(pretrained=False, **kwargs): + """ + """ + return _create_byoanet('halonet50t', pretrained=pretrained, **kwargs) + + +@register_model +def lambda_resnet26t(pretrained=False, **kwargs): + """ + """ + return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs) + + +@register_model +def lambda_resnet50t(pretrained=False, **kwargs): + """ + """ + return _create_byoanet('lambda_resnet50t', pretrained=pretrained, **kwargs) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index c5ccb70b..15c718f5 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -25,9 +25,9 @@ above nets that include attention. Hacked together by / copyright Ross Wightman, 2021. """ import math -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from collections import OrderedDict -from typing import Tuple, Dict, Optional, Union, Any, Callable +from typing import Tuple, List, Optional, Union, Any, Callable, Sequence from functools import partial import torch @@ -35,11 +35,11 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from .layers import ClassifierHead, ConvBnAct, DropPath, AvgPool2dSame, \ - create_conv2d, get_act_layer, get_attn, convert_norm_act, make_divisible +from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ + create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible from .registry import register_model -__all__ = ['ByobNet', 'ByobCfg', 'BlocksCfg'] +__all__ = ['ByobNet', 'ByobCfg', 'BlocksCfg', 'create_byob_stem', 'create_block'] def _cfg(url='', **kwargs): @@ -98,20 +98,22 @@ class BlocksCfg: s: int = 2 # stride of stage (first block) gs: Optional[Union[int, Callable]] = None # group-size of blocks in stage, conv is depthwise if gs == 1 br: float = 1. 
# bottleneck-ratio of blocks in stage + no_attn: bool = True # disable channel attn (ie SE) when layer is set for model @dataclass class ByobCfg: - blocks: Tuple[BlocksCfg, ...] + blocks: Tuple[Union[BlocksCfg, Tuple[BlocksCfg, ...]], ...] downsample: str = 'conv1x1' stem_type: str = '3x3' + stem_pool: str = '' stem_chs: int = 32 width_factor: float = 1.0 num_features: int = 0 # num out_channels for final conv, no final 1x1 conv if 0 zero_init_last_bn: bool = True act_layer: str = 'relu' - norm_layer: nn.Module = nn.BatchNorm2d + norm_layer: str = 'batchnorm' attn_layer: Optional[str] = None attn_kwargs: dict = field(default_factory=lambda: dict()) @@ -201,17 +203,29 @@ model_cfgs = dict( stem_type='rep', stem_chs=64, ), -) - -def _na_args(cfg: dict): - return dict( - norm_layer=cfg.get('norm_layer', nn.BatchNorm2d), - act_layer=cfg.get('act_layer', nn.ReLU)) + resnet52q=ByobCfg( + blocks=( + BlocksCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25), + BlocksCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25), + BlocksCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25), + BlocksCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0), + ), + stem_chs=128, + stem_type='quad', + num_features=2048, + act_layer='silu', + ), +) -def _ex_tuple(cfg: dict, *names): - return tuple([cfg.get(n, None) for n in names]) +def expand_blocks_cfg(stage_blocks_cfg: Union[BlocksCfg, Sequence[BlocksCfg]]) -> List[BlocksCfg]: + if not isinstance(stage_blocks_cfg, Sequence): + stage_blocks_cfg = (stage_blocks_cfg,) + block_cfgs = [] + for i, cfg in enumerate(stage_blocks_cfg): + block_cfgs += [replace(cfg, d=1) for _ in range(cfg.d)] + return block_cfgs def num_groups(group_size, channels): @@ -223,27 +237,36 @@ def num_groups(group_size, channels): return channels // group_size +@dataclass +class LayerFn: + conv_norm_act: Callable = ConvBnAct + norm_act: Callable = BatchNormAct2d + act: Callable = nn.ReLU + attn: Optional[Callable] = None + + class DownsampleAvg(nn.Module): - def __init__(self, in_chs, out_chs, stride=1, dilation=1, apply_act=False, norm_layer=None, act_layer=None): + def __init__(self, in_chs, out_chs, stride=1, dilation=1, apply_act=False, layers: LayerFn = None): """ AvgPool Downsampling as in 'D' ResNet variants.""" super(DownsampleAvg, self).__init__() + layers = layers or LayerFn() avg_stride = stride if dilation == 1 else 1 if stride > 1 or dilation > 1: avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) else: self.pool = nn.Identity() - self.conv = ConvBnAct(in_chs, out_chs, 1, apply_act=apply_act, norm_layer=norm_layer, act_layer=act_layer) + self.conv = layers.conv_norm_act(in_chs, out_chs, 1, apply_act=apply_act) def forward(self, x): return self.conv(self.pool(x)) -def create_downsample(type, **kwargs): - if type == 'avg': +def create_downsample(downsample_type, layers: LayerFn, **kwargs): + if downsample_type == 'avg': return DownsampleAvg(**kwargs) else: - return ConvBnAct(kwargs.pop('in_chs'), kwargs.pop('out_chs'), kernel_size=1, **kwargs) + return layers.conv_norm_act(kwargs.pop('in_chs'), kwargs.pop('out_chs'), kernel_size=1, **kwargs) class BasicBlock(nn.Module): @@ -252,28 +275,25 @@ class BasicBlock(nn.Module): def __init__( self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), group_size=None, bottle_ratio=1.0, - downsample='avg', linear_out=False, layer_cfg=None, drop_block=None, drop_path_rate=0.): + downsample='avg', linear_out=False, layers: LayerFn = 
None, drop_block=None, drop_path_rate=0.): super(BasicBlock, self).__init__() - layer_cfg = layer_cfg or {} - act_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'attn_layer') - layer_args = _na_args(layer_cfg) + layers = layers or LayerFn() mid_chs = make_divisible(out_chs * bottle_ratio) groups = num_groups(group_size, mid_chs) if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: self.shortcut = create_downsample( downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], - apply_act=False, **layer_args) + apply_act=False, layers=layers) else: self.shortcut = nn.Identity() - self.conv1_kxk = ConvBnAct(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], **layer_args) - self.conv2_kxk = ConvBnAct( - mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups, - drop_block=drop_block, apply_act=False, **layer_args) - self.attn = nn.Identity() if attn_layer is None else attn_layer(out_chs) + self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0]) + self.conv2_kxk = layers.conv_norm_act( + mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block, apply_act=False) + self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() - self.act = nn.Identity() if linear_out else act_layer(inplace=True) + self.act = nn.Identity() if linear_out else layers.act(inplace=True) def init_weights(self, zero_init_last_bn=False): if zero_init_last_bn: @@ -297,29 +317,27 @@ class BottleneckBlock(nn.Module): """ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None, - downsample='avg', linear_out=False, layer_cfg=None, drop_block=None, drop_path_rate=0.): + downsample='avg', linear_out=False, layers : LayerFn = None, drop_block=None, drop_path_rate=0.): super(BottleneckBlock, self).__init__() - layer_cfg = layer_cfg or {} - act_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'attn_layer') - layer_args = _na_args(layer_cfg) + layers = layers or LayerFn() mid_chs = make_divisible(out_chs * bottle_ratio) groups = num_groups(group_size, mid_chs) if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: self.shortcut = create_downsample( downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], - apply_act=False, **layer_args) + apply_act=False, layers=layers) else: self.shortcut = nn.Identity() - self.conv1_1x1 = ConvBnAct(in_chs, mid_chs, 1, **layer_args) - self.conv2_kxk = ConvBnAct( + self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) + self.conv2_kxk = layers.conv_norm_act( mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], - groups=groups, drop_block=drop_block, **layer_args) - self.attn = nn.Identity() if attn_layer is None else attn_layer(mid_chs) - self.conv3_1x1 = ConvBnAct(mid_chs, out_chs, 1, apply_act=False, **layer_args) + groups=groups, drop_block=drop_block) + self.attn = nn.Identity() if layers.attn is None else layers.attn(mid_chs) + self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. 
else nn.Identity() - self.act = nn.Identity() if linear_out else act_layer(inplace=True) + self.act = nn.Identity() if linear_out else layers.act(inplace=True) def init_weights(self, zero_init_last_bn=False): if zero_init_last_bn: @@ -350,28 +368,26 @@ class DarkBlock(nn.Module): """ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None, - downsample='avg', linear_out=False, layer_cfg=None, drop_block=None, drop_path_rate=0.): + downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.): super(DarkBlock, self).__init__() - layer_cfg = layer_cfg or {} - act_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'attn_layer') - layer_args = _na_args(layer_cfg) + layers = layers or LayerFn() mid_chs = make_divisible(out_chs * bottle_ratio) groups = num_groups(group_size, mid_chs) if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: self.shortcut = create_downsample( downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], - apply_act=False, **layer_args) + apply_act=False, layers=layers) else: self.shortcut = nn.Identity() - self.conv1_1x1 = ConvBnAct(in_chs, mid_chs, 1, **layer_args) - self.conv2_kxk = ConvBnAct( + self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) + self.conv2_kxk = layers.conv_norm_act( mid_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0], - groups=groups, drop_block=drop_block, apply_act=False, **layer_args) - self.attn = nn.Identity() if attn_layer is None else attn_layer(out_chs) + groups=groups, drop_block=drop_block, apply_act=False) + self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() - self.act = nn.Identity() if linear_out else act_layer(inplace=True) + self.act = nn.Identity() if linear_out else layers.act(inplace=True) def init_weights(self, zero_init_last_bn=False): if zero_init_last_bn: @@ -399,28 +415,26 @@ class EdgeBlock(nn.Module): """ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None, - downsample='avg', linear_out=False, layer_cfg=None, drop_block=None, drop_path_rate=0.): + downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.): super(EdgeBlock, self).__init__() - layer_cfg = layer_cfg or {} - act_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'attn_layer') - layer_args = _na_args(layer_cfg) + layers = layers or LayerFn() mid_chs = make_divisible(out_chs * bottle_ratio) groups = num_groups(group_size, mid_chs) if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: self.shortcut = create_downsample( downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], - apply_act=False, **layer_args) + apply_act=False, layers=layers) else: self.shortcut = nn.Identity() - self.conv1_kxk = ConvBnAct( + self.conv1_kxk = layers.conv_norm_act( in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], - groups=groups, drop_block=drop_block, **layer_args) - self.attn = nn.Identity() if attn_layer is None else attn_layer(out_chs) - self.conv2_1x1 = ConvBnAct(mid_chs, out_chs, 1, apply_act=False, **layer_args) + groups=groups, drop_block=drop_block) + self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) + self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. 
else nn.Identity() - self.act = nn.Identity() if linear_out else act_layer(inplace=True) + self.act = nn.Identity() if linear_out else layers.act(inplace=True) def init_weights(self, zero_init_last_bn=False): if zero_init_last_bn: @@ -446,23 +460,20 @@ class RepVggBlock(nn.Module): """ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None, - downsample='', layer_cfg=None, drop_block=None, drop_path_rate=0.): + downsample='', layers : LayerFn = None, drop_block=None, drop_path_rate=0.): super(RepVggBlock, self).__init__() - layer_cfg = layer_cfg or {} - act_layer, norm_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'norm_layer', 'attn_layer') - norm_layer = convert_norm_act(norm_layer=norm_layer, act_layer=act_layer) - layer_args = _na_args(layer_cfg) + layers = layers or LayerFn() groups = num_groups(group_size, in_chs) use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1] - self.identity = norm_layer(out_chs, apply_act=False) if use_ident else None - self.conv_kxk = ConvBnAct( + self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None + self.conv_kxk = layers.conv_norm_act( in_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0], - groups=groups, drop_block=drop_block, apply_act=False, **layer_args) - self.conv_1x1 = ConvBnAct(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False, **layer_args) - self.attn = nn.Identity() if attn_layer is None else attn_layer(out_chs) + groups=groups, drop_block=drop_block, apply_act=False) + self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False) + self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. 
and use_ident else nn.Identity() - self.act = act_layer(inplace=True) + self.act = layers.act(inplace=True) def init_weights(self, zero_init_last_bn=False): # NOTE this init overrides that base model init with specific changes for the block type @@ -504,33 +515,200 @@ def create_block(block: Union[str, nn.Module], **kwargs): return _block_registry[block](**kwargs) -def create_stem(in_chs, out_chs, stem_type='', layer_cfg=None): - layer_cfg = layer_cfg or {} - layer_args = _na_args(layer_cfg) - assert stem_type in ('', 'deep', 'deep_tiered', '3x3', '7x7', 'rep') - if 'deep' in stem_type: - # 3 deep 3x3 conv stack - stem = OrderedDict() - stem_chs = (out_chs // 2, out_chs // 2) - if 'tiered' in stem_type: - stem_chs = (3 * stem_chs[0] // 4, stem_chs[1]) - norm_layer, act_layer = _ex_tuple(layer_args, 'norm_layer', 'act_layer') - stem['conv1'] = create_conv2d(in_chs, stem_chs[0], kernel_size=3, stride=2) - stem['conv2'] = create_conv2d(stem_chs[0], stem_chs[1], kernel_size=3, stride=1) - stem['conv3'] = create_conv2d(stem_chs[1], out_chs, kernel_size=3, stride=1) - norm_act_layer = convert_norm_act(norm_layer=norm_layer, act_layer=act_layer) - stem['na'] = norm_act_layer(out_chs) - stem = nn.Sequential(stem) +# class Stem(nn.Module): +# +# def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool', +# num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None): +# super().__init__() +# assert stride in (2, 4) +# if pool: +# assert stride == 4 +# layers = layers or LayerFn() +# +# if isinstance(out_chs, (list, tuple)): +# num_rep = len(out_chs) +# stem_chs = out_chs +# else: +# stem_chs = [round(out_chs * chs_decay ** i) for i in range(num_rep)][::-1] +# +# self.stride = stride +# stem_strides = [2] + [1] * (num_rep - 1) +# if stride == 4 and not pool: +# # set last conv in stack to be strided if stride == 4 and no pooling layer +# stem_strides[-1] = 2 +# +# num_act = num_rep if num_act is None else num_act +# # if num_act < num_rep, first convs in stack won't have bn + act +# stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act +# prev_chs = in_chs +# convs = [] +# for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)): +# layer_fn = layers.conv_norm_act if na else create_conv2d +# convs.append(layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s)) +# prev_chs = ch +# self.conv = nn.Sequential(*convs) if len(convs) > 1 else convs[0] +# +# if not pool: +# self.pool = nn.Identity() +# elif 'max' in pool.lower(): +# self.pool = nn.MaxPool2d(3, 2, 1) if pool else nn.Identity() +# else: +# assert False, "Unknown pooling type" +# +# def forward(self, x): +# x = self.conv(x) +# x = self.pool(x) +# return x + + +class Stem(nn.Sequential): + + def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool', + num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None): + super().__init__() + assert stride in (2, 4) + layers = layers or LayerFn() + + if isinstance(out_chs, (list, tuple)): + num_rep = len(out_chs) + stem_chs = out_chs + else: + stem_chs = [round(out_chs * chs_decay ** i) for i in range(num_rep)][::-1] + + self.stride = stride + self.feature_info = [] # track intermediate features + prev_feat = '' + stem_strides = [2] + [1] * (num_rep - 1) + if stride == 4 and not pool: + # set last conv in stack to be strided if stride == 4 and no pooling layer + stem_strides[-1] = 2 + + num_act = num_rep if num_act is None else num_act + # if num_act < num_rep, first convs in stack won't have bn + act + stem_norm_acts = [False] 
* (num_rep - num_act) + [True] * num_act + prev_chs = in_chs + curr_stride = 1 + for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)): + layer_fn = layers.conv_norm_act if na else create_conv2d + conv_name = f'conv{i + 1}' + if i > 0 and s > 1: + self.feature_info.append(dict(num_chs=ch, reduction=curr_stride, module=prev_feat)) + self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s)) + prev_chs = ch + curr_stride *= s + prev_feat = conv_name + + if 'max' in pool.lower(): + self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat)) + self.add_module('pool', nn.MaxPool2d(3, 2, 1)) + curr_stride *= 2 + prev_feat = 'pool' + + self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat)) + assert curr_stride == stride + + +def create_byob_stem(in_chs, out_chs, stem_type='', pool_type='', feat_prefix='stem', layers: LayerFn = None): + layers = layers or LayerFn() + assert stem_type in ('', 'quad', 'tiered', 'deep', 'rep', '7x7', '3x3') + if 'quad' in stem_type: + # based on NFNet stem, stack of 4 3x3 convs + num_act = 2 if 'quad2' in stem_type else None + stem = Stem(in_chs, out_chs, num_rep=4, num_act=num_act, pool=pool_type, layers=layers) + elif 'tiered' in stem_type: + # 3x3 stack of 3 convs as in my ResNet-T + stem = Stem(in_chs, (3 * out_chs // 8, out_chs // 2, out_chs), pool=pool_type, layers=layers) + elif 'deep' in stem_type: + # 3x3 stack of 3 convs as in ResNet-D + stem = Stem(in_chs, out_chs, num_rep=3, chs_decay=1.0, pool=pool_type, layers=layers) + elif 'rep' in stem_type: + stem = RepVggBlock(in_chs, out_chs, stride=2, layers=layers) elif '7x7' in stem_type: # 7x7 stem conv as in ResNet - stem = ConvBnAct(in_chs, out_chs, 7, stride=2, **layer_args) - elif 'rep' in stem_type: - stem = RepVggBlock(in_chs, out_chs, stride=2, layer_cfg=layer_cfg) + if pool_type: + stem = Stem(in_chs, out_chs, 7, num_rep=1, pool=pool_type, layers=layers) + else: + stem = layers.conv_norm_act(in_chs, out_chs, 7, stride=2) else: - # 3x3 stem conv as in RegNet - stem = ConvBnAct(in_chs, out_chs, 3, stride=2, **layer_args) + # 3x3 stem conv as in RegNet is the default + if pool_type: + stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers) + else: + stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2) - return stem + if isinstance(stem, Stem): + feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info] + else: + feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix)] + return stem, feature_info + + +def reduce_feat_size(feat_size, stride=2): + return None if feat_size is None else tuple([s // stride for s in feat_size]) + + +def create_byob_stages( + cfg, drop_path_rate, output_stride, stem_feat, + feat_size=None, layers=None, extra_args_fn=None): + layers = layers or LayerFn() + feature_info = [] + block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks] + depths = [sum([bc.d for bc in stage_bcs]) for stage_bcs in block_cfgs] + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + dilation = 1 + net_stride = stem_feat['reduction'] + prev_chs = stem_feat['num_chs'] + prev_feat = stem_feat + stages = [] + for stage_idx, stage_block_cfgs in enumerate(block_cfgs): + stride = stage_block_cfgs[0].s + if stride != 1 and prev_feat: + feature_info.append(prev_feat) + if net_stride >= output_stride and stride > 1: + dilation *= stride + stride = 1 + net_stride *= stride + first_dilation = 
1 if dilation in (1, 2) else 2 + + blocks = [] + for block_idx, block_cfg in enumerate(stage_block_cfgs): + out_chs = make_divisible(block_cfg.c * cfg.width_factor) + group_size = block_cfg.gs + if isinstance(group_size, Callable): + group_size = group_size(out_chs, block_idx) + block_kwargs = dict( # Blocks used in this model must accept these arguments + in_chs=prev_chs, + out_chs=out_chs, + stride=stride if block_idx == 0 else 1, + dilation=(first_dilation, dilation), + group_size=group_size, + bottle_ratio=block_cfg.br, + downsample=cfg.downsample, + drop_path_rate=dpr[stage_idx][block_idx], + layers=layers, + ) + if extra_args_fn is not None: + extra_args_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg, feat_size=feat_size) + blocks += [create_block(block_cfg.type, **block_kwargs)] + first_dilation = dilation + prev_chs = out_chs + if stride > 1 and block_idx == 0: + feat_size = reduce_feat_size(feat_size, stride) + + stages += [nn.Sequential(*blocks)] + prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}') + + feature_info.append(prev_feat) + return nn.Sequential(*stages), feature_info + + +def get_layer_fns(cfg: ByobCfg): + act = get_act_layer(cfg.act_layer) + norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act) + conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act) + attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None + layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn) + return layer_fn class ByobNet(nn.Module): @@ -546,79 +724,30 @@ class ByobNet(nn.Module): super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate - norm_layer = cfg.norm_layer - act_layer = get_act_layer(cfg.act_layer) - attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None - layer_cfg = dict(norm_layer=norm_layer, act_layer=act_layer, attn_layer=attn_layer) + layers = get_layer_fns(cfg) + self.feature_info = [] stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor)) - self.stem = create_stem(in_chans, stem_chs, cfg.stem_type, layer_cfg=layer_cfg) + self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers) + self.feature_info.extend(stem_feat[:-1]) - self.feature_info = [] - depths = [bc.d for bc in cfg.blocks] - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] - prev_name = 'stem' - prev_chs = stem_chs - net_stride = 2 - dilation = 1 - stages = [] - for stage_idx, block_cfg in enumerate(cfg.blocks): - stride = block_cfg.s - if stride != 1: - self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=prev_name)) - if net_stride >= output_stride and stride > 1: - dilation *= stride - stride = 1 - net_stride *= stride - first_dilation = 1 if dilation in (1, 2) else 2 - - blocks = [] - for block_idx in range(block_cfg.d): - out_chs = make_divisible(block_cfg.c * cfg.width_factor) - group_size = block_cfg.gs - if isinstance(group_size, Callable): - group_size = group_size(out_chs, block_idx) - block_kwargs = dict( # Blocks used in this model must accept these arguments - in_chs=prev_chs, - out_chs=out_chs, - stride=stride if block_idx == 0 else 1, - dilation=(first_dilation, dilation), - group_size=group_size, - bottle_ratio=block_cfg.br, - downsample=cfg.downsample, - drop_path_rate=dpr[stage_idx][block_idx], - layer_cfg=layer_cfg, - ) - blocks += [create_block(block_cfg.type, 
**block_kwargs)] - first_dilation = dilation - prev_chs = out_chs - stages += [nn.Sequential(*blocks)] - prev_name = f'stages.{stage_idx}' - self.stages = nn.Sequential(*stages) + self.stages, stage_feat = create_byob_stages(cfg, drop_path_rate, output_stride, stem_feat[-1], layers=layers) + self.feature_info.extend(stage_feat[:-1]) + prev_chs = stage_feat[-1]['num_chs'] if cfg.num_features: self.num_features = int(round(cfg.width_factor * cfg.num_features)) - self.final_conv = ConvBnAct(prev_chs, self.num_features, 1, **_na_args(layer_cfg)) + self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1) else: self.num_features = prev_chs self.final_conv = nn.Identity() - self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_conv')] + self.feature_info += [ + dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')] self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) for n, m in self.named_modules(): - if isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - nn.init.normal_(m.weight, mean=0.0, std=0.01) - nn.init.zeros_(m.bias) - elif isinstance(m, nn.BatchNorm2d): - nn.init.ones_(m.weight) - nn.init.zeros_(m.bias) + _init_weights(m, n) for m in self.modules(): # call each block's weight init for block-specific overrides to init above if hasattr(m, 'init_weights'): @@ -642,6 +771,22 @@ class ByobNet(nn.Module): return x +def _init_weights(m, n=''): + if isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, mean=0.0, std=0.01) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + + def _create_byobnet(variant, pretrained=False, **kwargs): return build_model_with_cfg( ByobNet, variant, pretrained, diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 89fb859c..ac0b6b41 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -13,6 +13,7 @@ from .create_act import create_act_layer, get_act_layer, get_act_fn from .create_attn import get_attn, create_attn from .create_conv2d import create_conv2d from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act +from .create_self_attn import get_self_attn, create_self_attn from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path from .eca import EcaModule, CecaModule from .evo_norm import EvoNormBatch2d, EvoNormSample2d @@ -20,6 +21,7 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible from .inplace_abn import InplaceAbn from .linear import Linear from .mixed_conv2d import MixedConv2d +from .norm import GroupNorm from .norm_act import BatchNormAct2d, GroupNormAct from .padding import get_padding, get_same_padding, pad_same from .pool2d_same import AvgPool2dSame, create_pool2d diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py new file mode 100644 index 00000000..0bb0e27b --- /dev/null +++ b/timm/models/layers/bottleneck_attn.py @@ -0,0 +1,120 @@ +""" Bottleneck Self 
Attention (Bottleneck Transformers) + +Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605 + +@misc{2101.11605, +Author = {Aravind Srinivas and Tsung-Yi Lin and Niki Parmar and Jonathon Shlens and Pieter Abbeel and Ashish Vaswani}, +Title = {Bottleneck Transformers for Visual Recognition}, +Year = {2021}, +} + +Based on ref gist at: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + +This impl is a WIP but given that it is based on the ref gist likely not too far off. + +Hacked together by / Copyright 2021 Ross Wightman +""" +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .helpers import to_2tuple + + +def rel_logits_1d(q, rel_k, permute_mask: List[int]): + """ Compute relative logits along one dimension + + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + + Args: + q: (batch, heads, height, width, dim) + rel_k: (2 * width - 1, dim) + permute_mask: permute output dim according to this + """ + B, H, W, dim = q.shape + x = (q @ rel_k.transpose(-1, -2)) + x = x.reshape(-1, W, 2 * W -1) + + # pad to shift from relative to absolute indexing + x_pad = F.pad(x, [0, 1]).flatten(1) + x_pad = F.pad(x_pad, [0, W - 1]) + + # reshape and slice out the padded elements + x_pad = x_pad.reshape(-1, W + 1, 2 * W - 1) + x = x_pad[:, :W, W - 1:] + + # reshape and tile + x = x.reshape(B, H, 1, W, W).expand(-1, -1, H, -1, -1) + return x.permute(permute_mask) + + +class PosEmbedRel(nn.Module): + """ Relative Position Embedding + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + """ + def __init__(self, feat_size, dim_head, scale): + super().__init__() + self.height, self.width = to_2tuple(feat_size) + self.dim_head = dim_head + self.scale = scale + self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * self.scale) + self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale) + + def forward(self, q): + B, num_heads, HW, _ = q.shape + + # relative logits in width dimension. + q = q.reshape(B * num_heads, self.height, self.width, -1) + rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4)) + + # relative logits in height dimension. 
+ q = q.transpose(1, 2) + rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2)) + + rel_logits = rel_logits_h + rel_logits_w + rel_logits = rel_logits.reshape(B, num_heads, HW, HW) + return rel_logits + + +class BottleneckAttn(nn.Module): + """ Bottleneck Attention + Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605 + """ + def __init__(self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, qkv_bias=False): + super().__init__() + assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required' + dim_out = dim_out or dim + assert dim_out % num_heads == 0 + self.num_heads = num_heads + self.dim_out = dim_out + self.dim_head = dim_out // num_heads + self.scale = self.dim_head ** -0.5 + + self.qkv = nn.Conv2d(dim, self.dim_out * 3, 1, bias=qkv_bias) + + # NOTE I'm only supporting relative pos embedding for now + self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head, scale=self.scale) + + self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.pos_embed.height and W == self.pos_embed.width + + x = self.qkv(x) # B, 3 * num_heads * dim_head, H, W + x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2) + q, k, v = torch.split(x, self.num_heads, dim=1) + + attn_logits = (q @ k.transpose(-1, -2)) * self.scale + attn_logits = attn_logits + self.pos_embed(q) # B, num_heads, H * W, H * W + + attn_out = attn_logits.softmax(dim = -1) + attn_out = (attn_out @ v).transpose(1, 2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W + attn_out = self.pool(attn_out) + return attn_out + + diff --git a/timm/models/layers/create_self_attn.py b/timm/models/layers/create_self_attn.py new file mode 100644 index 00000000..8c0984c8 --- /dev/null +++ b/timm/models/layers/create_self_attn.py @@ -0,0 +1,17 @@ +from .bottleneck_attn import BottleneckAttn +from .halo_attn import HaloAttn +from .lambda_layer import LambdaLayer + + +def get_self_attn(attn_type): + if attn_type == 'bottleneck': + return BottleneckAttn + elif attn_type == 'halo': + return HaloAttn + elif attn_type == 'lambda': + return LambdaLayer + + +def create_self_attn(attn_type, dim, stride=1, **kwargs): + attn_fn = get_self_attn(attn_type) + return attn_fn(dim, stride=stride, **kwargs) diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py new file mode 100644 index 00000000..bd5d1b45 --- /dev/null +++ b/timm/models/layers/halo_attn.py @@ -0,0 +1,157 @@ +""" Halo Self Attention + +Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones` + - https://arxiv.org/abs/2103.12731 + +@misc{2103.12731, +Author = {Ashish Vaswani and Prajit Ramachandran and Aravind Srinivas and Niki Parmar and Blake Hechtman and + Jonathon Shlens}, +Title = {Scaling Local Self-Attention for Parameter Efficient Visual Backbones}, +Year = {2021}, +} + +Status: +This impl is a WIP, there is no official ref impl and some details in paper weren't clear to me. + +Trying to match the 'H1' variant in the paper, my parameter counts are 2M less and the model +is extremely slow. Something isn't right. However, the models do appear to train and experimental +variants with attn in C4 and/or C5 stages are tolerable speed. 
+ +Hacked together by / Copyright 2021 Ross Wightman +""" +from typing import Tuple, List + +import torch +from torch import nn +import torch.nn.functional as F + + +def rel_logits_1d(q, rel_k, permute_mask: List[int]): + """ Compute relative logits along one dimension + + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + + Args: + q: (batch, height, width, dim) + rel_k: (2 * window - 1, dim) + permute_mask: permute output dim according to this + """ + B, H, W, dim = q.shape + rel_size = rel_k.shape[0] + win_size = (rel_size + 1) // 2 + + x = (q @ rel_k.transpose(-1, -2)) + x = x.reshape(-1, W, rel_size) + + # pad to shift from relative to absolute indexing + x_pad = F.pad(x, [0, 1]).flatten(1) + x_pad = F.pad(x_pad, [0, rel_size - W]) + + # reshape and slice out the padded elements + x_pad = x_pad.reshape(-1, W + 1, rel_size) + x = x_pad[:, :W, win_size - 1:] + + # reshape and tile + x = x.reshape(B, H, 1, W, win_size).expand(-1, -1, win_size, -1, -1) + return x.permute(permute_mask) + + +class PosEmbedRel(nn.Module): + """ Relative Position Embedding + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + + """ + def __init__(self, block_size, win_size, dim_head, scale): + """ + Args: + block_size (int): block size + win_size (int): neighbourhood window size + dim_head (int): attention head dim + scale (float): scale factor (for init) + """ + super().__init__() + self.block_size = block_size + self.dim_head = dim_head + self.scale = scale + self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale) + self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale) + + def forward(self, q): + B, BB, HW, _ = q.shape + + # relative logits in width dimension. + q = q.reshape(-1, self.block_size, self.block_size, self.dim_head) + rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4)) + + # relative logits in height dimension. + q = q.transpose(1, 2) + rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2)) + + rel_logits = rel_logits_h + rel_logits_w + rel_logits = rel_logits.reshape(B, BB, HW, -1) + return rel_logits + + +class HaloAttn(nn.Module): + """ Halo Attention + + Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones` + - https://arxiv.org/abs/2103.12731 + """ + def __init__( + self, dim, dim_out=None, stride=1, num_heads=8, dim_head=16, block_size=8, halo_size=3, qkv_bias=False): + super().__init__() + dim_out = dim_out or dim + assert dim_out % num_heads == 0 + self.stride = stride + self.num_heads = num_heads + self.dim_head = dim_head + self.dim_qk = num_heads * dim_head + self.dim_v = dim_out + self.block_size = block_size + self.halo_size = halo_size + self.win_size = block_size + halo_size * 2 # neighbourhood window size + self.scale = self.dim_head ** -0.5 + + # FIXME not clear if this stride behaviour is what the paper intended, not really clear + # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving + # data in unfolded block form. I haven't wrapped my head around how that'd look. 
+ self.q = nn.Conv2d(dim, self.dim_qk, 1, stride=self.stride, bias=qkv_bias) + self.kv = nn.Conv2d(dim, self.dim_qk + self.dim_v, 1, bias=qkv_bias) + + self.pos_embed = PosEmbedRel( + block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale) + + def forward(self, x): + B, C, H, W = x.shape + assert H % self.block_size == 0 and W % self.block_size == 0 + num_h_blocks = H // self.block_size + num_w_blocks = W // self.block_size + num_blocks = num_h_blocks * num_w_blocks + + q = self.q(x) + q = F.unfold(q, kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride) + # B, num_heads * dim_head * block_size ** 2, num_blocks + q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3) + # B * num_heads, num_blocks, block_size ** 2, dim_head + + kv = self.kv(x) + # FIXME I 'think' this unfold does what I want it to, but I should investigate + k = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) + k = k.reshape( + B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3) + k, v = torch.split(k, [self.dim_head, self.dim_v // self.num_heads], dim=-1) + + attn_logits = (q @ k.transpose(-1, -2)) * self.scale # FIXME should usual attn scale be applied? + attn_logits = attn_logits + self.pos_embed(q) # B * num_heads, block_size ** 2, win_size ** 2 + + attn_out = attn_logits.softmax(dim=-1) + attn_out = (attn_out @ v).transpose(1, 3) # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks + attn_out = F.fold( + attn_out.reshape(B, -1, num_blocks), + (H // self.stride, W // self.stride), + kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride) + # B, dim_out, H // stride, W // stride + return attn_out diff --git a/timm/models/layers/lambda_layer.py b/timm/models/layers/lambda_layer.py new file mode 100644 index 00000000..bdaebb5d --- /dev/null +++ b/timm/models/layers/lambda_layer.py @@ -0,0 +1,78 @@ +""" Lambda Layer + +Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` + - https://arxiv.org/abs/2102.08602 + +@misc{2102.08602, +Author = {Irwan Bello}, +Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention}, +Year = {2021}, +} + +Status: +This impl is a WIP. Code snippets in the paper were used as reference but +good chance some details are missing/wrong. + +I've only implemented local lambda conv based pos embeddings. 
+ +For a PyTorch impl that includes other embedding options checkout +https://github.com/lucidrains/lambda-networks + +Hacked together by / Copyright 2021 Ross Wightman +""" +import torch +from torch import nn +import torch.nn.functional as F + + + +class LambdaLayer(nn.Module): + """Lambda Layer w/ lambda conv position embedding + + Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` + - https://arxiv.org/abs/2102.08602 + """ + def __init__( + self, + dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=5, qkv_bias=False): + super().__init__() + self.dim_out = dim_out or dim + self.dim_k = dim_head # query depth 'k' + self.num_heads = num_heads + assert self.dim_out % num_heads == 0, ' should be divided by num_heads' + self.dim_v = self.dim_out // num_heads # value depth 'v' + self.r = r # relative position neighbourhood (lambda conv kernel size) + + self.qkv = nn.Conv2d( + dim, + num_heads * dim_head + dim_head + self.dim_v, + kernel_size=1, bias=qkv_bias) + self.norm_q = nn.BatchNorm2d(num_heads * dim_head) + self.norm_v = nn.BatchNorm2d(self.dim_v) + + # NOTE currently only supporting the local lambda convolutions for positional + self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0)) + + self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + M = H * W + + qkv = self.qkv(x) + q, k, v = torch.split(qkv, [ + self.num_heads * self.dim_k, self.dim_k, self.dim_v], dim=1) + q = self.norm_q(q).reshape(B, self.num_heads, self.dim_k, M).transpose(-1, -2) # B, num_heads, M, K + v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V + k = F.softmax(k.reshape(B, self.dim_k, M), dim=-1) # B, K, M + + content_lam = k @ v # B, K, V + content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V + + position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K + position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V + position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2) # B, num_heads, M, V + + out = (content_out + position_out).transpose(3, 1).reshape(B, C, H, W) # B, C (num_heads * V), H, W + out = self.pool(out) + return out
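Both `bottleneck_attn.py` and `halo_attn.py` above rely on the same pad-flatten-reshape trick inside `rel_logits_1d` to convert logits indexed by relative offset into an absolute (query position, key position) grid. A minimal standalone sketch of that trick follows; the tensor names and the direct-gather reference are illustrative only and not part of the patch.

```python
# Standalone check of the relative->absolute indexing trick used by rel_logits_1d.
import torch
import torch.nn.functional as F

W = 5                                 # positions along one spatial dimension
x = torch.randn(2, W, 2 * W - 1)      # logits against relative offsets -(W-1)..(W-1)

# direct reference: out[b, i, j] = x[b, i, j - i + W - 1]
idx = torch.arange(W).view(1, -1) - torch.arange(W).view(-1, 1) + W - 1   # (W, W)
direct = x.gather(-1, idx.expand(x.shape[0], -1, -1))

# pad / flatten / reshape trick as in rel_logits_1d
x_pad = F.pad(x, [0, 1]).flatten(1)   # append one pad element per row, then flatten rows
x_pad = F.pad(x_pad, [0, W - 1])      # pad so the buffer reshapes to (W + 1, 2W - 1)
trick = x_pad.reshape(-1, W + 1, 2 * W - 1)[:, :W, W - 1:]

assert torch.allclose(direct, trick)
```

The two paths agree because padding each row by one element before reshaping shifts row `i` left by `i`, which is exactly the relative-to-absolute offset `j - i + (W - 1)`.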
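The `interleave_attn` helper in `byoanet.py` expands a stage of depth `d` into per-block configs, swapping in the `self_attn` block type at the indices selected by `every`. A hedged illustration of the expected output for the `lambda_resnet26t` stage above (derived by reading the helper, not from a run):

```python
from timm.models.byoanet import ByoaBlocksCfg, interleave_attn

cfgs = interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25)
# every=1, first=False -> attn at indices list(range(1, 2, 1)) == [1], so the stage expands to:
#   (ByoaBlocksCfg(type='bottle',    d=1, c=1024, s=2, gs=0, br=0.25),
#    ByoaBlocksCfg(type='self_attn', d=1, c=1024, s=2, gs=0, br=0.25))
```

Note that with an integer `every`, only the indices produced by `range(every, d, every)` get the attention type (e.g. `every=3, d=6` marks index 3 only, as in `lambda_resnet50t`); pass a list of indices to place attention blocks explicitly.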
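For reference, a usage sketch of the new model registrations, assuming this revision of the package is installed. None of the `default_cfgs` above carry weight URLs yet, so `pretrained=False` is assumed throughout, and these WIP models should be treated accordingly.

```python
import torch
import timm

# lambda_resnet26t does not fix the self-attn feature size, so the default 224x224 config applies
model = timm.create_model('lambda_resnet26t', pretrained=False, num_classes=1000)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000])

# botnet variants use self_attn_fixed_size, so the entrypoint pins img_size (224 by default)
botnet = timm.create_model('botnet50t_224', pretrained=False)
```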