Merge pull request #556 from rwightman/byoanet-self_attn

ByoaNet - Self Attn Networks - Bottleneck Transformers, Lambda ResNet, HaloNet
Ross Wightman 4 years ago committed by GitHub
commit ce6585f533
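The diffs below wire the new self-attention variants (Bottleneck Transformer, Lambda ResNet, HaloNet) into timm's registry/factory machinery and add default_cfg query helpers that the updated tests rely on. A minimal usage sketch, assuming only the entrypoints and helpers added in this PR (none of the new models ship pretrained weights yet, so the values in comments are expectations, not verified results):

import torch
import timm
from timm import create_model, list_models, has_model_default_key, get_model_default_value

# the new self-attn entrypoints registered by this PR
print(list_models('botnet*') + list_models('halonet*') + list_models('lambda_resnet*'))

# BotNet variants use a fixed-size relative position embedding, advertised via default_cfg
if has_model_default_key('botnet50t_224', 'fixed_input_size'):
    size = get_model_default_value('botnet50t_224', 'input_size')  # expected (3, 224, 224)

model = create_model('halonet26t', pretrained=False)
out = model(torch.randn(2, *model.default_cfg['input_size']))
print(out.shape)  # torch.Size([2, 1000])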

@@ -5,7 +5,8 @@ import os
import fnmatch

import timm
-from timm import list_models, create_model, set_scriptable
+from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
+    get_model_default_value

if hasattr(torch._C, '_jit_set_profiling_executor'):
    # legacy executor is too slow to compile large models for unit tests
@@ -60,9 +61,15 @@ def test_model_backward(model_name, batch_size):
    model.eval()

    input_size = model.default_cfg['input_size']
-    if any([x > MAX_BWD_SIZE for x in input_size]):
-        # cap backward test at 128 * 128 to keep resource usage down
-        input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
+    if not is_model_default_key(model_name, 'fixed_input_size'):
+        min_input_size = get_model_default_value(model_name, 'min_input_size')
+        if min_input_size is not None:
+            input_size = min_input_size
+        else:
+            if any([x > MAX_BWD_SIZE for x in input_size]):
+                # cap backward test at 128 * 128 to keep resource usage down
+                input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])

    inputs = torch.randn((batch_size, *input_size))
    outputs = model(inputs)
    outputs.mean().backward()
@@ -155,7 +162,14 @@ def test_model_forward_torchscript(model_name, batch_size):
    with set_scriptable(True):
        model = create_model(model_name, pretrained=False)
    model.eval()

-    input_size = (3, 128, 128)  # jit compile is already a bit slow and we've tested normal res already...
+    if has_model_default_key(model_name, 'fixed_input_size'):
+        input_size = get_model_default_value(model_name, 'input_size')
+    elif has_model_default_key(model_name, 'min_input_size'):
+        input_size = get_model_default_value(model_name, 'min_input_size')
+    else:
+        input_size = (3, 128, 128)  # jit compile is already a bit slow and we've tested normal res already...

    model = torch.jit.script(model)
    outputs = model(torch.randn((batch_size, *input_size)))
@@ -180,7 +194,14 @@ def test_model_forward_features(model_name, batch_size):
    model.eval()
    expected_channels = model.feature_info.channels()
    assert len(expected_channels) >= 4  # all models here should have at least 4 feature levels by default, some 5 or 6

-    input_size = (3, 96, 96)  # jit compile is already a bit slow and we've tested normal res already...
+    if has_model_default_key(model_name, 'fixed_input_size'):
+        input_size = get_model_default_value(model_name, 'input_size')
+    elif has_model_default_key(model_name, 'min_input_size'):
+        input_size = get_model_default_value(model_name, 'min_input_size')
+    else:
+        input_size = (3, 96, 96)  # jit compile is already a bit slow and we've tested normal res already...

    outputs = model(torch.randn((batch_size, *input_size)))
    assert len(expected_channels) == len(outputs)
    for e, o in zip(expected_channels, outputs):

@@ -1,3 +1,4 @@
from .version import __version__
from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
-    is_scriptable, is_exportable, set_scriptable, set_exportable
+    is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \
+    get_model_default_value, is_model_pretrained

@@ -1,3 +1,4 @@
+from .byoanet import *
from .byobnet import *
from .cspnet import *
from .densenet import *
@@ -39,5 +40,5 @@ from .helpers import load_checkpoint, resume_checkpoint, model_parameters
from .layers import TestTimePoolHead, apply_test_time_pool
from .layers import convert_splitbn_model
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
-from .registry import *
+from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
+    has_model_default_key, is_model_default_key, get_model_default_value, is_model_pretrained

@@ -0,0 +1,430 @@
""" Bring-Your-Own-Attention Network
A flexible network w/ dataclass based config for stacking NN blocks including
self-attention (or similar) layers.
Currently used to implement experimental variants of:
* Bottleneck Transformers
* Lambda ResNets
* HaloNets
Consider all of the model definitions here as experimental WIP, likely to change.
Hacked together by / copyright Ross Wightman, 2021.
"""
import math
from dataclasses import dataclass, field
from collections import OrderedDict
from typing import Tuple, List, Optional, Union, Any, Callable
from functools import partial
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .byobnet import BlocksCfg, ByobCfg, create_byob_stem, create_byob_stages, create_downsample,\
reduce_feat_size, register_block, num_groups, LayerFn, _init_weights
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, ConvBnAct, DropPath, get_act_layer, convert_norm_act, get_attn, get_self_attn,\
make_divisible, to_2tuple
from .registry import register_model
__all__ = ['ByoaNet']
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
'fixed_input_size': False, 'min_input_size': (3, 224, 224),
**kwargs
}
default_cfgs = {
# experimental self-attn model defs, no pretrained weights available yet (urls empty)
'botnet50t_224': _cfg(url='', fixed_input_size=True),
'botnet50t_c4c5_224': _cfg(url='', fixed_input_size=True),
'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
'halonet26t': _cfg(url=''),
'halonet50t': _cfg(url=''),
'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128)),
'lambda_resnet50t': _cfg(url='', min_input_size=(3, 128, 128)),
}
@dataclass
class ByoaBlocksCfg(BlocksCfg):
# FIXME allow overriding self_attn layer or args per block/stage,
pass
@dataclass
class ByoaCfg(ByobCfg):
blocks: Tuple[Union[ByoaBlocksCfg, Tuple[ByoaBlocksCfg, ...]], ...] = None
self_attn_layer: Optional[str] = None
self_attn_fixed_size: bool = False
self_attn_kwargs: dict = field(default_factory=lambda: dict())
def interleave_attn(
types: Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs
) -> Tuple[ByoaBlocksCfg]:
""" interleave attn blocks
"""
assert len(types) == 2
if isinstance(every, int):
every = list(range(0 if first else every, d, every))
if not every:
every = [d - 1]
every = set(every)
blocks = []
for i in range(d):
block_type = types[1] if i in every else types[0]
blocks += [ByoaBlocksCfg(type=block_type, d=1, **kwargs)]
return tuple(blocks)
model_cfgs = dict(
botnet50t=ByoaCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=6, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='',
num_features=0,
self_attn_layer='bottleneck',
self_attn_fixed_size=True,
self_attn_kwargs=dict()
),
botnet50t_c4c5=ByoaCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
(
ByoaBlocksCfg(type='self_attn', d=1, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=5, c=1024, s=1, gs=0, br=0.25),
),
(
ByoaBlocksCfg(type='self_attn', d=1, c=2048, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=2, c=2048, s=1, gs=0, br=0.25),
)
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
self_attn_layer='bottleneck',
self_attn_fixed_size=True,
self_attn_kwargs=dict()
),
halonet_h1=ByoaCfg(
blocks=(
ByoaBlocksCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
),
stem_chs=64,
stem_type='7x7',
stem_pool='maxpool',
num_features=0,
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=3),
),
halonet_h1_c4c5=ByoaCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=64, s=1, gs=0, br=1.0),
ByoaBlocksCfg(type='bottle', d=3, c=128, s=2, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=3),
),
halonet26t=ByoaCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=7, halo_size=2)
),
halonet50t=ByoaCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=6, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=7, halo_size=2)
),
lambda_resnet26t=ByoaCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
self_attn_layer='lambda',
self_attn_kwargs=dict()
),
lambda_resnet50t=ByoaCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=3, d=6, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
self_attn_layer='lambda',
self_attn_kwargs=dict()
),
)
@dataclass
class ByoaLayerFn(LayerFn):
self_attn: Optional[Callable] = None
class SelfAttnBlock(nn.Module):
""" ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', extra_conv=False, linear_out=False, post_attn_na=True, feat_size=None,
layers: ByoaLayerFn = None, drop_block=None, drop_path_rate=0.):
super(SelfAttnBlock, self).__init__()
assert layers is not None
mid_chs = make_divisible(out_chs * bottle_ratio)
groups = num_groups(group_size, mid_chs)
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
self.shortcut = create_downsample(
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
apply_act=False, layers=layers)
else:
self.shortcut = nn.Identity()
self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
if extra_conv:
self.conv2_kxk = layers.conv_norm_act(
mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block)
stride = 1 # striding done via conv if enabled
else:
self.conv2_kxk = nn.Identity()
opt_kwargs = {} if feat_size is None else dict(feat_size=feat_size)
# FIXME need to dilate self attn to have dilated network support, moop moop
self.self_attn = layers.self_attn(mid_chs, stride=stride, **opt_kwargs)
self.post_attn = layers.norm_act(mid_chs) if post_attn_na else nn.Identity()
self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
if zero_init_last_bn:
nn.init.zeros_(self.conv3_1x1.bn.weight)
def forward(self, x):
shortcut = self.shortcut(x)
x = self.conv1_1x1(x)
x = self.conv2_kxk(x)
x = self.self_attn(x)
x = self.post_attn(x)
x = self.conv3_1x1(x)
x = self.drop_path(x)
x = self.act(x + shortcut)
return x
register_block('self_attn', SelfAttnBlock)
def _byoa_block_args(block_kwargs, block_cfg: ByoaBlocksCfg, model_cfg: ByoaCfg, feat_size=None):
if block_cfg.type == 'self_attn' and model_cfg.self_attn_fixed_size:
assert feat_size is not None
block_kwargs['feat_size'] = feat_size
return block_kwargs
def get_layer_fns(cfg: ByoaCfg):
act = get_act_layer(cfg.act_layer)
norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act)
conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act)
attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
self_attn = partial(get_self_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
layer_fn = ByoaLayerFn(
conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn)
return layer_fn
class ByoaNet(nn.Module):
""" 'Bring-your-own-attention' Net
A ResNet inspired backbone that supports interleaving traditional residual blocks with
'Self Attention' bottleneck blocks that replace the bottleneck kxk conv w/ a self-attention
or similar module.
FIXME This class network definition is almost the same as ByobNet, I'd like to merge them but
torchscript limitations prevent sensible inheritance overrides.
"""
def __init__(self, cfg: ByoaCfg, num_classes=1000, in_chans=3, output_stride=32, global_pool='avg',
zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
layers = get_layer_fns(cfg)
feat_size = to_2tuple(img_size) if img_size is not None else None
self.feature_info = []
stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
self.feature_info.extend(stem_feat[:-1])
feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction'])
self.stages, stage_feat = create_byob_stages(
cfg, drop_path_rate, output_stride, stem_feat[-1],
feat_size=feat_size, layers=layers, extra_args_fn=_byoa_block_args)
self.feature_info.extend(stage_feat[:-1])
prev_chs = stage_feat[-1]['num_chs']
if cfg.num_features:
self.num_features = int(round(cfg.width_factor * cfg.num_features))
self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1)
else:
self.num_features = prev_chs
self.final_conv = nn.Identity()
self.feature_info += [
dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
for n, m in self.named_modules():
_init_weights(m, n)
for m in self.modules():
# call each block's weight init for block-specific overrides to init above
if hasattr(m, 'init_weights'):
m.init_weights(zero_init_last_bn=zero_init_last_bn)
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
x = self.final_conv(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
return build_model_with_cfg(
ByoaNet, variant, pretrained,
default_cfg=default_cfgs[variant],
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
feature_cfg=dict(flatten_sequential=True),
**kwargs)
@register_model
def botnet50t_224(pretrained=False, **kwargs):
""" Bottleneck Transformer w/ ResNet50-T backbone. Bottleneck attn in final stage.
"""
kwargs.setdefault('img_size', 224)
return _create_byoanet('botnet50t_224', 'botnet50t', pretrained=pretrained, **kwargs)
@register_model
def botnet50t_c4c5_224(pretrained=False, **kwargs):
""" Bottleneck Transformer w/ ResNet50-T backbone. Bottleneck attn in last two stages.
"""
kwargs.setdefault('img_size', 224)
return _create_byoanet('botnet50t_c4c5_224', 'botnet50t_c4c5', pretrained=pretrained, **kwargs)
@register_model
def halonet_h1(pretrained=False, **kwargs):
""" HaloNet-H1. Halo attention in all stages as per the paper.
This runs very slowly and the param count is lower than in the paper --> something is wrong.
"""
return _create_byoanet('halonet_h1', pretrained=pretrained, **kwargs)
@register_model
def halonet_h1_c4c5(pretrained=False, **kwargs):
""" HaloNet-H1 config w/ attention in last two stages.
"""
return _create_byoanet('halonet_h1_c4c5', pretrained=pretrained, **kwargs)
@register_model
def halonet26t(pretrained=False, **kwargs):
""" HaloNet w/ a ResNet26-t backbone, Halo attention in final stage
"""
return _create_byoanet('halonet26t', pretrained=pretrained, **kwargs)
@register_model
def halonet50t(pretrained=False, **kwargs):
""" HaloNet w/ a ResNet50-t backbone, Halo attention in final stage
"""
return _create_byoanet('halonet50t', pretrained=pretrained, **kwargs)
@register_model
def lambda_resnet26t(pretrained=False, **kwargs):
""" Lambda-ResNet-26T. Lambda layers in one C4 stage and all C5.
"""
return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs)
@register_model
def lambda_resnet50t(pretrained=False, **kwargs):
""" Lambda-ResNet-50T. Lambda layers in one C4 stage and all C5.
"""
return _create_byoanet('lambda_resnet50t', pretrained=pretrained, **kwargs)
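Illustrative only (not part of the diff): how the interleave_attn helper defined above expands a stage spec into per-block configs. The import path assumes this file lands at timm.models.byoanet.

from timm.models.byoanet import interleave_attn

# every=3, d=6: block index 3 becomes a 'self_attn' block, the rest stay 'bottle'
stage = interleave_attn(types=('bottle', 'self_attn'), every=3, d=6, c=1024, s=2, gs=0, br=0.25)
print([b.type for b in stage])  # ['bottle', 'bottle', 'bottle', 'self_attn', 'bottle', 'bottle']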

@@ -25,9 +25,9 @@ above nets that include attention.
Hacked together by / copyright Ross Wightman, 2021.
"""
import math
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
from collections import OrderedDict
-from typing import Tuple, Dict, Optional, Union, Any, Callable
+from typing import Tuple, List, Optional, Union, Any, Callable, Sequence
from functools import partial

import torch
@@ -35,11 +35,11 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
-from .layers import ClassifierHead, ConvBnAct, DropPath, AvgPool2dSame, \
-    create_conv2d, get_act_layer, get_attn, convert_norm_act, make_divisible
+from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
+    create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible
from .registry import register_model

-__all__ = ['ByobNet', 'ByobCfg', 'BlocksCfg']
+__all__ = ['ByobNet', 'ByobCfg', 'BlocksCfg', 'create_byob_stem', 'create_block']


def _cfg(url='', **kwargs):
@@ -98,20 +98,22 @@ class BlocksCfg:
    s: int = 2  # stride of stage (first block)
    gs: Optional[Union[int, Callable]] = None  # group-size of blocks in stage, conv is depthwise if gs == 1
    br: float = 1.  # bottleneck-ratio of blocks in stage
+    no_attn: bool = True  # disable channel attn (ie SE) when layer is set for model


@dataclass
class ByobCfg:
-    blocks: Tuple[BlocksCfg, ...]
+    blocks: Tuple[Union[BlocksCfg, Tuple[BlocksCfg, ...]], ...]
    downsample: str = 'conv1x1'
    stem_type: str = '3x3'
+    stem_pool: str = ''
    stem_chs: int = 32
    width_factor: float = 1.0
    num_features: int = 0  # num out_channels for final conv, no final 1x1 conv if 0
    zero_init_last_bn: bool = True

    act_layer: str = 'relu'
-    norm_layer: nn.Module = nn.BatchNorm2d
+    norm_layer: str = 'batchnorm'
    attn_layer: Optional[str] = None
    attn_kwargs: dict = field(default_factory=lambda: dict())
@@ -201,17 +203,29 @@ model_cfgs = dict(
        stem_type='rep',
        stem_chs=64,
    ),
-)
-
-
-def _na_args(cfg: dict):
-    return dict(
-        norm_layer=cfg.get('norm_layer', nn.BatchNorm2d),
-        act_layer=cfg.get('act_layer', nn.ReLU))
+
+    resnet52q=ByobCfg(
+        blocks=(
+            BlocksCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
+            BlocksCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
+            BlocksCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
+            BlocksCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
+        ),
+        stem_chs=128,
+        stem_type='quad',
+        num_features=2048,
+        act_layer='silu',
+    ),
+)


-def _ex_tuple(cfg: dict, *names):
-    return tuple([cfg.get(n, None) for n in names])
+def expand_blocks_cfg(stage_blocks_cfg: Union[BlocksCfg, Sequence[BlocksCfg]]) -> List[BlocksCfg]:
+    if not isinstance(stage_blocks_cfg, Sequence):
+        stage_blocks_cfg = (stage_blocks_cfg,)
+    block_cfgs = []
+    for i, cfg in enumerate(stage_blocks_cfg):
+        block_cfgs += [replace(cfg, d=1) for _ in range(cfg.d)]
+    return block_cfgs


def num_groups(group_size, channels):
@@ -223,27 +237,36 @@ def num_groups(group_size, channels):
        return channels // group_size


+@dataclass
+class LayerFn:
+    conv_norm_act: Callable = ConvBnAct
+    norm_act: Callable = BatchNormAct2d
+    act: Callable = nn.ReLU
+    attn: Optional[Callable] = None
+
+
class DownsampleAvg(nn.Module):
-    def __init__(self, in_chs, out_chs, stride=1, dilation=1, apply_act=False, norm_layer=None, act_layer=None):
+    def __init__(self, in_chs, out_chs, stride=1, dilation=1, apply_act=False, layers: LayerFn = None):
        """ AvgPool Downsampling as in 'D' ResNet variants."""
        super(DownsampleAvg, self).__init__()
+        layers = layers or LayerFn()
        avg_stride = stride if dilation == 1 else 1
        if stride > 1 or dilation > 1:
            avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
            self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
        else:
            self.pool = nn.Identity()
-        self.conv = ConvBnAct(in_chs, out_chs, 1, apply_act=apply_act, norm_layer=norm_layer, act_layer=act_layer)
+        self.conv = layers.conv_norm_act(in_chs, out_chs, 1, apply_act=apply_act)

    def forward(self, x):
        return self.conv(self.pool(x))


-def create_downsample(type, **kwargs):
-    if type == 'avg':
+def create_downsample(downsample_type, layers: LayerFn, **kwargs):
+    if downsample_type == 'avg':
        return DownsampleAvg(**kwargs)
    else:
-        return ConvBnAct(kwargs.pop('in_chs'), kwargs.pop('out_chs'), kernel_size=1, **kwargs)
+        return layers.conv_norm_act(kwargs.pop('in_chs'), kwargs.pop('out_chs'), kernel_size=1, **kwargs)


class BasicBlock(nn.Module):
@@ -252,28 +275,25 @@ class BasicBlock(nn.Module):

    def __init__(
            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), group_size=None, bottle_ratio=1.0,
-            downsample='avg', linear_out=False, layer_cfg=None, drop_block=None, drop_path_rate=0.):
+            downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
        super(BasicBlock, self).__init__()
-        layer_cfg = layer_cfg or {}
-        act_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'attn_layer')
-        layer_args = _na_args(layer_cfg)
+        layers = layers or LayerFn()
        mid_chs = make_divisible(out_chs * bottle_ratio)
        groups = num_groups(group_size, mid_chs)

        if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
            self.shortcut = create_downsample(
                downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
-                apply_act=False, **layer_args)
+                apply_act=False, layers=layers)
        else:
            self.shortcut = nn.Identity()

-        self.conv1_kxk = ConvBnAct(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], **layer_args)
-        self.conv2_kxk = ConvBnAct(
-            mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups,
-            drop_block=drop_block, apply_act=False, **layer_args)
-        self.attn = nn.Identity() if attn_layer is None else attn_layer(out_chs)
+        self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0])
+        self.conv2_kxk = layers.conv_norm_act(
+            mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block, apply_act=False)
+        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
-        self.act = nn.Identity() if linear_out else act_layer(inplace=True)
+        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

    def init_weights(self, zero_init_last_bn=False):
        if zero_init_last_bn:
@@ -297,29 +317,27 @@ class BottleneckBlock(nn.Module):
    """

    def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
-                 downsample='avg', linear_out=False, layer_cfg=None, drop_block=None, drop_path_rate=0.):
+                 downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
        super(BottleneckBlock, self).__init__()
-        layer_cfg = layer_cfg or {}
-        act_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'attn_layer')
-        layer_args = _na_args(layer_cfg)
+        layers = layers or LayerFn()
        mid_chs = make_divisible(out_chs * bottle_ratio)
        groups = num_groups(group_size, mid_chs)

        if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
            self.shortcut = create_downsample(
                downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
-                apply_act=False, **layer_args)
+                apply_act=False, layers=layers)
        else:
            self.shortcut = nn.Identity()

-        self.conv1_1x1 = ConvBnAct(in_chs, mid_chs, 1, **layer_args)
-        self.conv2_kxk = ConvBnAct(
+        self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
+        self.conv2_kxk = layers.conv_norm_act(
            mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
-            groups=groups, drop_block=drop_block, **layer_args)
-        self.attn = nn.Identity() if attn_layer is None else attn_layer(mid_chs)
-        self.conv3_1x1 = ConvBnAct(mid_chs, out_chs, 1, apply_act=False, **layer_args)
+            groups=groups, drop_block=drop_block)
+        self.attn = nn.Identity() if layers.attn is None else layers.attn(mid_chs)
+        self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
-        self.act = nn.Identity() if linear_out else act_layer(inplace=True)
+        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

    def init_weights(self, zero_init_last_bn=False):
        if zero_init_last_bn:
@@ -350,28 +368,26 @@ class DarkBlock(nn.Module):
    """

    def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-                 downsample='avg', linear_out=False, layer_cfg=None, drop_block=None, drop_path_rate=0.):
+                 downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
        super(DarkBlock, self).__init__()
-        layer_cfg = layer_cfg or {}
-        act_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'attn_layer')
-        layer_args = _na_args(layer_cfg)
+        layers = layers or LayerFn()
        mid_chs = make_divisible(out_chs * bottle_ratio)
        groups = num_groups(group_size, mid_chs)

        if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
            self.shortcut = create_downsample(
                downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
-                apply_act=False, **layer_args)
+                apply_act=False, layers=layers)
        else:
            self.shortcut = nn.Identity()

-        self.conv1_1x1 = ConvBnAct(in_chs, mid_chs, 1, **layer_args)
-        self.conv2_kxk = ConvBnAct(
+        self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
+        self.conv2_kxk = layers.conv_norm_act(
            mid_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
-            groups=groups, drop_block=drop_block, apply_act=False, **layer_args)
-        self.attn = nn.Identity() if attn_layer is None else attn_layer(out_chs)
+            groups=groups, drop_block=drop_block, apply_act=False)
+        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
-        self.act = nn.Identity() if linear_out else act_layer(inplace=True)
+        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

    def init_weights(self, zero_init_last_bn=False):
        if zero_init_last_bn:
@@ -399,28 +415,26 @@ class EdgeBlock(nn.Module):
    """

    def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-                 downsample='avg', linear_out=False, layer_cfg=None, drop_block=None, drop_path_rate=0.):
+                 downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
        super(EdgeBlock, self).__init__()
-        layer_cfg = layer_cfg or {}
-        act_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'attn_layer')
-        layer_args = _na_args(layer_cfg)
+        layers = layers or LayerFn()
        mid_chs = make_divisible(out_chs * bottle_ratio)
        groups = num_groups(group_size, mid_chs)

        if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
            self.shortcut = create_downsample(
                downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
-                apply_act=False, **layer_args)
+                apply_act=False, layers=layers)
        else:
            self.shortcut = nn.Identity()

-        self.conv1_kxk = ConvBnAct(
+        self.conv1_kxk = layers.conv_norm_act(
            in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
-            groups=groups, drop_block=drop_block, **layer_args)
-        self.attn = nn.Identity() if attn_layer is None else attn_layer(out_chs)
-        self.conv2_1x1 = ConvBnAct(mid_chs, out_chs, 1, apply_act=False, **layer_args)
+            groups=groups, drop_block=drop_block)
+        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
+        self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
-        self.act = nn.Identity() if linear_out else act_layer(inplace=True)
+        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

    def init_weights(self, zero_init_last_bn=False):
        if zero_init_last_bn:
@@ -446,23 +460,20 @@ class RepVggBlock(nn.Module):
    """

    def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-                 downsample='', layer_cfg=None, drop_block=None, drop_path_rate=0.):
+                 downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
        super(RepVggBlock, self).__init__()
-        layer_cfg = layer_cfg or {}
-        act_layer, norm_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'norm_layer', 'attn_layer')
-        norm_layer = convert_norm_act(norm_layer=norm_layer, act_layer=act_layer)
-        layer_args = _na_args(layer_cfg)
+        layers = layers or LayerFn()
        groups = num_groups(group_size, in_chs)

        use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1]
-        self.identity = norm_layer(out_chs, apply_act=False) if use_ident else None
-        self.conv_kxk = ConvBnAct(
+        self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None
+        self.conv_kxk = layers.conv_norm_act(
            in_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
-            groups=groups, drop_block=drop_block, apply_act=False, **layer_args)
-        self.conv_1x1 = ConvBnAct(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False, **layer_args)
-        self.attn = nn.Identity() if attn_layer is None else attn_layer(out_chs)
+            groups=groups, drop_block=drop_block, apply_act=False)
+        self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False)
+        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
-        self.act = act_layer(inplace=True)
+        self.act = layers.act(inplace=True)

    def init_weights(self, zero_init_last_bn=False):
        # NOTE this init overrides that base model init with specific changes for the block type
@@ -504,33 +515,154 @@ def create_block(block: Union[str, nn.Module], **kwargs):
    return _block_registry[block](**kwargs)


-def create_stem(in_chs, out_chs, stem_type='', layer_cfg=None):
-    layer_cfg = layer_cfg or {}
-    layer_args = _na_args(layer_cfg)
-    assert stem_type in ('', 'deep', 'deep_tiered', '3x3', '7x7', 'rep')
-    if 'deep' in stem_type:
-        # 3 deep 3x3 conv stack
-        stem = OrderedDict()
-        stem_chs = (out_chs // 2, out_chs // 2)
-        if 'tiered' in stem_type:
-            stem_chs = (3 * stem_chs[0] // 4, stem_chs[1])
-        norm_layer, act_layer = _ex_tuple(layer_args, 'norm_layer', 'act_layer')
-        stem['conv1'] = create_conv2d(in_chs, stem_chs[0], kernel_size=3, stride=2)
-        stem['conv2'] = create_conv2d(stem_chs[0], stem_chs[1], kernel_size=3, stride=1)
-        stem['conv3'] = create_conv2d(stem_chs[1], out_chs, kernel_size=3, stride=1)
-        norm_act_layer = convert_norm_act(norm_layer=norm_layer, act_layer=act_layer)
-        stem['na'] = norm_act_layer(out_chs)
-        stem = nn.Sequential(stem)
+class Stem(nn.Sequential):
+
+    def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
+                 num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
+        super().__init__()
+        assert stride in (2, 4)
+        layers = layers or LayerFn()
+
+        if isinstance(out_chs, (list, tuple)):
+            num_rep = len(out_chs)
+            stem_chs = out_chs
+        else:
+            stem_chs = [round(out_chs * chs_decay ** i) for i in range(num_rep)][::-1]
+
+        self.stride = stride
+        self.feature_info = []  # track intermediate features
+        prev_feat = ''
+        stem_strides = [2] + [1] * (num_rep - 1)
+        if stride == 4 and not pool:
+            # set last conv in stack to be strided if stride == 4 and no pooling layer
+            stem_strides[-1] = 2
+
+        num_act = num_rep if num_act is None else num_act
+        # if num_act < num_rep, first convs in stack won't have bn + act
+        stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
+        prev_chs = in_chs
+        curr_stride = 1
+        for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
+            layer_fn = layers.conv_norm_act if na else create_conv2d
+            conv_name = f'conv{i + 1}'
+            if i > 0 and s > 1:
+                self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+            self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
+            prev_chs = ch
+            curr_stride *= s
+            prev_feat = conv_name
+
+        if 'max' in pool.lower():
+            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+            self.add_module('pool', nn.MaxPool2d(3, 2, 1))
+            curr_stride *= 2
+            prev_feat = 'pool'
+
+        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+        assert curr_stride == stride
+
+
+def create_byob_stem(in_chs, out_chs, stem_type='', pool_type='', feat_prefix='stem', layers: LayerFn = None):
+    layers = layers or LayerFn()
+    assert stem_type in ('', 'quad', 'tiered', 'deep', 'rep', '7x7', '3x3')
+    if 'quad' in stem_type:
+        # based on NFNet stem, stack of 4 3x3 convs
+        num_act = 2 if 'quad2' in stem_type else None
+        stem = Stem(in_chs, out_chs, num_rep=4, num_act=num_act, pool=pool_type, layers=layers)
+    elif 'tiered' in stem_type:
+        # 3x3 stack of 3 convs as in my ResNet-T
+        stem = Stem(in_chs, (3 * out_chs // 8, out_chs // 2, out_chs), pool=pool_type, layers=layers)
+    elif 'deep' in stem_type:
+        # 3x3 stack of 3 convs as in ResNet-D
+        stem = Stem(in_chs, out_chs, num_rep=3, chs_decay=1.0, pool=pool_type, layers=layers)
+    elif 'rep' in stem_type:
+        stem = RepVggBlock(in_chs, out_chs, stride=2, layers=layers)
    elif '7x7' in stem_type:
        # 7x7 stem conv as in ResNet
-        stem = ConvBnAct(in_chs, out_chs, 7, stride=2, **layer_args)
-    elif 'rep' in stem_type:
-        stem = RepVggBlock(in_chs, out_chs, stride=2, layer_cfg=layer_cfg)
+        if pool_type:
+            stem = Stem(in_chs, out_chs, 7, num_rep=1, pool=pool_type, layers=layers)
+        else:
+            stem = layers.conv_norm_act(in_chs, out_chs, 7, stride=2)
    else:
-        # 3x3 stem conv as in RegNet
-        stem = ConvBnAct(in_chs, out_chs, 3, stride=2, **layer_args)
+        # 3x3 stem conv as in RegNet is the default
+        if pool_type:
+            stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers)
+        else:
+            stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2)

-    return stem
+    if isinstance(stem, Stem):
+        feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info]
+    else:
+        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix)]
+    return stem, feature_info
+
+
+def reduce_feat_size(feat_size, stride=2):
+    return None if feat_size is None else tuple([s // stride for s in feat_size])
+
+
+def create_byob_stages(
+        cfg, drop_path_rate, output_stride, stem_feat,
+        feat_size=None, layers=None, extra_args_fn=None):
+    layers = layers or LayerFn()
+    feature_info = []
+    block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks]
+    depths = [sum([bc.d for bc in stage_bcs]) for stage_bcs in block_cfgs]
+    dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+    dilation = 1
+    net_stride = stem_feat['reduction']
+    prev_chs = stem_feat['num_chs']
+    prev_feat = stem_feat
+    stages = []
+    for stage_idx, stage_block_cfgs in enumerate(block_cfgs):
+        stride = stage_block_cfgs[0].s
+        if stride != 1 and prev_feat:
+            feature_info.append(prev_feat)
+        if net_stride >= output_stride and stride > 1:
+            dilation *= stride
+            stride = 1
+        net_stride *= stride
+        first_dilation = 1 if dilation in (1, 2) else 2
+        blocks = []
+        for block_idx, block_cfg in enumerate(stage_block_cfgs):
+            out_chs = make_divisible(block_cfg.c * cfg.width_factor)
+            group_size = block_cfg.gs
+            if isinstance(group_size, Callable):
+                group_size = group_size(out_chs, block_idx)
+            block_kwargs = dict(  # Blocks used in this model must accept these arguments
+                in_chs=prev_chs,
+                out_chs=out_chs,
+                stride=stride if block_idx == 0 else 1,
+                dilation=(first_dilation, dilation),
+                group_size=group_size,
+                bottle_ratio=block_cfg.br,
+                downsample=cfg.downsample,
+                drop_path_rate=dpr[stage_idx][block_idx],
+                layers=layers,
+            )
+            if extra_args_fn is not None:
+                extra_args_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg, feat_size=feat_size)
+            blocks += [create_block(block_cfg.type, **block_kwargs)]
+            first_dilation = dilation
+            prev_chs = out_chs
+            if stride > 1 and block_idx == 0:
+                feat_size = reduce_feat_size(feat_size, stride)
+        stages += [nn.Sequential(*blocks)]
+        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')
+    feature_info.append(prev_feat)
+    return nn.Sequential(*stages), feature_info
+
+
+def get_layer_fns(cfg: ByobCfg):
+    act = get_act_layer(cfg.act_layer)
+    norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act)
+    conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act)
+    attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
+    layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn)
+    return layer_fn


class ByobNet(nn.Module):
@@ -546,79 +678,30 @@ class ByobNet(nn.Module):
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
-        norm_layer = cfg.norm_layer
-        act_layer = get_act_layer(cfg.act_layer)
-        attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
-        layer_cfg = dict(norm_layer=norm_layer, act_layer=act_layer, attn_layer=attn_layer)
+        layers = get_layer_fns(cfg)
+
+        self.feature_info = []
        stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
-        self.stem = create_stem(in_chans, stem_chs, cfg.stem_type, layer_cfg=layer_cfg)
-
-        self.feature_info = []
-        depths = [bc.d for bc in cfg.blocks]
-        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
-        prev_name = 'stem'
-        prev_chs = stem_chs
-        net_stride = 2
-        dilation = 1
-        stages = []
-        for stage_idx, block_cfg in enumerate(cfg.blocks):
-            stride = block_cfg.s
-            if stride != 1:
-                self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=prev_name))
-            if net_stride >= output_stride and stride > 1:
-                dilation *= stride
-                stride = 1
-            net_stride *= stride
-            first_dilation = 1 if dilation in (1, 2) else 2
-            blocks = []
-            for block_idx in range(block_cfg.d):
-                out_chs = make_divisible(block_cfg.c * cfg.width_factor)
-                group_size = block_cfg.gs
-                if isinstance(group_size, Callable):
-                    group_size = group_size(out_chs, block_idx)
-                block_kwargs = dict(  # Blocks used in this model must accept these arguments
-                    in_chs=prev_chs,
-                    out_chs=out_chs,
-                    stride=stride if block_idx == 0 else 1,
-                    dilation=(first_dilation, dilation),
-                    group_size=group_size,
-                    bottle_ratio=block_cfg.br,
-                    downsample=cfg.downsample,
-                    drop_path_rate=dpr[stage_idx][block_idx],
-                    layer_cfg=layer_cfg,
-                )
-                blocks += [create_block(block_cfg.type, **block_kwargs)]
-                first_dilation = dilation
-                prev_chs = out_chs
-            stages += [nn.Sequential(*blocks)]
-            prev_name = f'stages.{stage_idx}'
-        self.stages = nn.Sequential(*stages)
+        self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
+        self.feature_info.extend(stem_feat[:-1])
+
+        self.stages, stage_feat = create_byob_stages(cfg, drop_path_rate, output_stride, stem_feat[-1], layers=layers)
+        self.feature_info.extend(stage_feat[:-1])

+        prev_chs = stage_feat[-1]['num_chs']
        if cfg.num_features:
            self.num_features = int(round(cfg.width_factor * cfg.num_features))
-            self.final_conv = ConvBnAct(prev_chs, self.num_features, 1, **_na_args(layer_cfg))
+            self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1)
        else:
            self.num_features = prev_chs
            self.final_conv = nn.Identity()
-        self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_conv')]
+        self.feature_info += [
+            dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]

        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

        for n, m in self.named_modules():
-            if isinstance(m, nn.Conv2d):
-                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                fan_out //= m.groups
-                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
-                if m.bias is not None:
-                    m.bias.data.zero_()
-            elif isinstance(m, nn.Linear):
-                nn.init.normal_(m.weight, mean=0.0, std=0.01)
-                nn.init.zeros_(m.bias)
-            elif isinstance(m, nn.BatchNorm2d):
-                nn.init.ones_(m.weight)
-                nn.init.zeros_(m.bias)
+            _init_weights(m, n)
        for m in self.modules():
            # call each block's weight init for block-specific overrides to init above
            if hasattr(m, 'init_weights'):
@@ -642,6 +725,22 @@ class ByobNet(nn.Module):
        return x


+def _init_weights(m, n=''):
+    if isinstance(m, nn.Conv2d):
+        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+        fan_out //= m.groups
+        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+        if m.bias is not None:
+            m.bias.data.zero_()
+    elif isinstance(m, nn.Linear):
+        nn.init.normal_(m.weight, mean=0.0, std=0.01)
+        if m.bias is not None:
+            nn.init.zeros_(m.bias)
+    elif isinstance(m, nn.BatchNorm2d):
+        nn.init.ones_(m.weight)
+        nn.init.zeros_(m.bias)
+
+
def _create_byobnet(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
        ByobNet, variant, pretrained,

@@ -13,6 +13,7 @@ from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import get_attn, create_attn
from .create_conv2d import create_conv2d
from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
+from .create_self_attn import get_self_attn, create_self_attn
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
@@ -20,6 +21,7 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
+from .norm import GroupNorm
from .norm_act import BatchNormAct2d, GroupNormAct
from .padding import get_padding, get_same_padding, pad_same
from .pool2d_same import AvgPool2dSame, create_pool2d

@@ -0,0 +1,120 @@
""" Bottleneck Self Attention (Bottleneck Transformers)
Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605
@misc{2101.11605,
Author = {Aravind Srinivas and Tsung-Yi Lin and Niki Parmar and Jonathon Shlens and Pieter Abbeel and Ashish Vaswani},
Title = {Bottleneck Transformers for Visual Recognition},
Year = {2021},
}
Based on ref gist at: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
This impl is a WIP, but given that it is based on the ref gist it is likely not too far off.
Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from .helpers import to_2tuple
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
""" Compute relative logits along one dimension
As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
Args:
q: (batch, heads, height, width, dim)
rel_k: (2 * width - 1, dim)
permute_mask: permute output dim according to this
"""
B, H, W, dim = q.shape
x = (q @ rel_k.transpose(-1, -2))
x = x.reshape(-1, W, 2 * W - 1)
# pad to shift from relative to absolute indexing
x_pad = F.pad(x, [0, 1]).flatten(1)
x_pad = F.pad(x_pad, [0, W - 1])
# reshape and slice out the padded elements
x_pad = x_pad.reshape(-1, W + 1, 2 * W - 1)
x = x_pad[:, :W, W - 1:]
# reshape and tile
x = x.reshape(B, H, 1, W, W).expand(-1, -1, H, -1, -1)
return x.permute(permute_mask)
class PosEmbedRel(nn.Module):
""" Relative Position Embedding
As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
"""
def __init__(self, feat_size, dim_head, scale):
super().__init__()
self.height, self.width = to_2tuple(feat_size)
self.dim_head = dim_head
self.scale = scale
self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * self.scale)
self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale)
def forward(self, q):
B, num_heads, HW, _ = q.shape
# relative logits in width dimension.
q = q.reshape(B * num_heads, self.height, self.width, -1)
rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4))
# relative logits in height dimension.
q = q.transpose(1, 2)
rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))
rel_logits = rel_logits_h + rel_logits_w
rel_logits = rel_logits.reshape(B, num_heads, HW, HW)
return rel_logits
class BottleneckAttn(nn.Module):
""" Bottleneck Attention
Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605
"""
def __init__(self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, qkv_bias=False):
super().__init__()
assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required'
dim_out = dim_out or dim
assert dim_out % num_heads == 0
self.num_heads = num_heads
self.dim_out = dim_out
self.dim_head = dim_out // num_heads
self.scale = self.dim_head ** -0.5
self.qkv = nn.Conv2d(dim, self.dim_out * 3, 1, bias=qkv_bias)
# NOTE I'm only supporting relative pos embedding for now
self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head, scale=self.scale)
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
assert H == self.pos_embed.height and W == self.pos_embed.width
x = self.qkv(x) # B, 3 * num_heads * dim_head, H, W
x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2)
q, k, v = torch.split(x, self.num_heads, dim=1)
attn_logits = (q @ k.transpose(-1, -2)) * self.scale
attn_logits = attn_logits + self.pos_embed(q) # B, num_heads, H * W, H * W
attn_out = attn_logits.softmax(dim=-1)
attn_out = (attn_out @ v).transpose(1, 2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W
attn_out = self.pool(attn_out)
return attn_out
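A hedged usage sketch of the layer above, constructed through the get_self_attn factory that this PR exports from timm.models.layers; shapes follow from the forward pass as written:

import torch
from timm.models.layers import get_self_attn

BottleneckAttn = get_self_attn('bottleneck')
# feat_size is mandatory: the relative position embedding is built for a fixed (H, W)
attn = BottleneckAttn(dim=256, feat_size=(14, 14), stride=1, num_heads=4)
x = torch.randn(2, 256, 14, 14)
print(attn(x).shape)  # torch.Size([2, 256, 14, 14]); stride=2 would avg-pool the output to 7x7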

@@ -0,0 +1,17 @@
from .bottleneck_attn import BottleneckAttn
from .halo_attn import HaloAttn
from .lambda_layer import LambdaLayer
def get_self_attn(attn_type):
if attn_type == 'bottleneck':
return BottleneckAttn
elif attn_type == 'halo':
return HaloAttn
elif attn_type == 'lambda':
return LambdaLayer
def create_self_attn(attn_type, dim, stride=1, **kwargs):
attn_fn = get_self_attn(attn_type)
return attn_fn(dim, stride=stride, **kwargs)
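These helpers mirror the existing get_attn/create_attn factories; a brief sketch of the intended use (argument values here are illustrative only):

from timm.models.layers import get_self_attn, create_self_attn

HaloAttn = get_self_attn('halo')  # look the layer class up by name
halo = create_self_attn('halo', dim=256, stride=1, block_size=8, halo_size=2)  # or construct via the factory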

@@ -0,0 +1,157 @@
""" Halo Self Attention
Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
- https://arxiv.org/abs/2103.12731
@misc{2103.12731,
Author = {Ashish Vaswani and Prajit Ramachandran and Aravind Srinivas and Niki Parmar and Blake Hechtman and
Jonathon Shlens},
Title = {Scaling Local Self-Attention for Parameter Efficient Visual Backbones},
Year = {2021},
}
Status:
This impl is a WIP; there is no official ref impl and some details in the paper weren't clear to me.
Trying to match the 'H1' variant in the paper, my parameter counts are 2M lower and the model
is extremely slow. Something isn't right. However, the models do appear to train, and experimental
variants with attn in the C4 and/or C5 stages run at a tolerable speed.
Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import Tuple, List
import torch
from torch import nn
import torch.nn.functional as F
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
""" Compute relative logits along one dimension
As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
Args:
q: (batch, height, width, dim)
rel_k: (2 * window - 1, dim)
permute_mask: permute output dim according to this
"""
B, H, W, dim = q.shape
rel_size = rel_k.shape[0]
win_size = (rel_size + 1) // 2
x = (q @ rel_k.transpose(-1, -2))
x = x.reshape(-1, W, rel_size)
# pad to shift from relative to absolute indexing
x_pad = F.pad(x, [0, 1]).flatten(1)
x_pad = F.pad(x_pad, [0, rel_size - W])
# reshape and slice out the padded elements
x_pad = x_pad.reshape(-1, W + 1, rel_size)
x = x_pad[:, :W, win_size - 1:]
# reshape and tile
x = x.reshape(B, H, 1, W, win_size).expand(-1, -1, win_size, -1, -1)
return x.permute(permute_mask)
class PosEmbedRel(nn.Module):
""" Relative Position Embedding
As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
"""
def __init__(self, block_size, win_size, dim_head, scale):
"""
Args:
block_size (int): block size
win_size (int): neighbourhood window size
dim_head (int): attention head dim
scale (float): scale factor (for init)
"""
super().__init__()
self.block_size = block_size
self.dim_head = dim_head
self.scale = scale
self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale)
self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale)
def forward(self, q):
B, BB, HW, _ = q.shape
# relative logits in width dimension.
q = q.reshape(-1, self.block_size, self.block_size, self.dim_head)
rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4))
# relative logits in height dimension.
q = q.transpose(1, 2)
rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))
rel_logits = rel_logits_h + rel_logits_w
rel_logits = rel_logits.reshape(B, BB, HW, -1)
return rel_logits
class HaloAttn(nn.Module):
""" Halo Attention
Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
- https://arxiv.org/abs/2103.12731
"""
def __init__(
self, dim, dim_out=None, stride=1, num_heads=8, dim_head=16, block_size=8, halo_size=3, qkv_bias=False):
super().__init__()
dim_out = dim_out or dim
assert dim_out % num_heads == 0
self.stride = stride
self.num_heads = num_heads
self.dim_head = dim_head
self.dim_qk = num_heads * dim_head
self.dim_v = dim_out
self.block_size = block_size
self.halo_size = halo_size
self.win_size = block_size + halo_size * 2 # neighbourhood window size
self.scale = self.dim_head ** -0.5
# FIXME it's not clear whether this stride behaviour is what the paper intended
# Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
# data in unfolded block form. I haven't wrapped my head around how that'd look.
self.q = nn.Conv2d(dim, self.dim_qk, 1, stride=self.stride, bias=qkv_bias)
self.kv = nn.Conv2d(dim, self.dim_qk + self.dim_v, 1, bias=qkv_bias)
self.pos_embed = PosEmbedRel(
block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)
def forward(self, x):
B, C, H, W = x.shape
assert H % self.block_size == 0 and W % self.block_size == 0
num_h_blocks = H // self.block_size
num_w_blocks = W // self.block_size
num_blocks = num_h_blocks * num_w_blocks
q = self.q(x)
q = F.unfold(q, kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride)
# B, num_heads * dim_head * block_size ** 2, num_blocks
q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3)
# B * num_heads, num_blocks, block_size ** 2, dim_head
kv = self.kv(x)
# FIXME I 'think' this unfold does what I want it to, but I should investigate
k = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
k = k.reshape(
B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3)
k, v = torch.split(k, [self.dim_head, self.dim_v // self.num_heads], dim=-1)
attn_logits = (q @ k.transpose(-1, -2)) * self.scale # FIXME should usual attn scale be applied?
attn_logits = attn_logits + self.pos_embed(q) # B * num_heads, block_size ** 2, win_size ** 2
attn_out = attn_logits.softmax(dim=-1)
attn_out = (attn_out @ v).transpose(1, 3) # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks
attn_out = F.fold(
attn_out.reshape(B, -1, num_blocks),
(H // self.stride, W // self.stride),
kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride)
# B, dim_out, H // stride, W // stride
return attn_out
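A usage sketch for HaloAttn exercising the stride=2 path (illustrative, not part of the patch); H and W must be divisible by block_size, as asserted in forward.

import torch

attn = HaloAttn(dim=128, dim_out=128, num_heads=8, dim_head=16, block_size=8, halo_size=2, stride=2)
x = torch.randn(2, 128, 32, 32)
print(attn(x).shape)  # torch.Size([2, 128, 16, 16]); the strided q conv halves the spatial dims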

@ -0,0 +1,78 @@
""" Lambda Layer
Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
- https://arxiv.org/abs/2102.08602
@misc{2102.08602,
Author = {Irwan Bello},
Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention},
Year = {2021},
}
Status:
This impl is a WIP. Code snippets in the paper were used as a reference, but there is a
good chance some details are missing or wrong.
I've only implemented the local lambda conv based pos embeddings.
For a PyTorch impl that includes other embedding options, check out
https://github.com/lucidrains/lambda-networks
Hacked together by / Copyright 2021 Ross Wightman
"""
import torch
from torch import nn
import torch.nn.functional as F
class LambdaLayer(nn.Module):
"""Lambda Layer w/ lambda conv position embedding
Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
- https://arxiv.org/abs/2102.08602
"""
def __init__(
self,
dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=5, qkv_bias=False):
super().__init__()
self.dim_out = dim_out or dim
self.dim_k = dim_head # query depth 'k'
self.num_heads = num_heads
assert self.dim_out % num_heads == 0, 'dim_out must be divisible by num_heads'
self.dim_v = self.dim_out // num_heads # value depth 'v'
self.r = r # relative position neighbourhood (lambda conv kernel size)
self.qkv = nn.Conv2d(
dim,
num_heads * dim_head + dim_head + self.dim_v,
kernel_size=1, bias=qkv_bias)
self.norm_q = nn.BatchNorm2d(num_heads * dim_head)
self.norm_v = nn.BatchNorm2d(self.dim_v)
# NOTE currently only supporting the local lambda convolutions for positional embeddings
self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0))
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
M = H * W
qkv = self.qkv(x)
q, k, v = torch.split(qkv, [
self.num_heads * self.dim_k, self.dim_k, self.dim_v], dim=1)
q = self.norm_q(q).reshape(B, self.num_heads, self.dim_k, M).transpose(-1, -2) # B, num_heads, M, K
v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V
k = F.softmax(k.reshape(B, self.dim_k, M), dim=-1) # B, K, M
content_lam = k @ v # B, K, V
content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V
position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, K, H, W, V
position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V
position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2) # B, num_heads, M, V
out = (content_out + position_out).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out (num_heads * V), H, W
out = self.pool(out)
return out
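A similar sketch for LambdaLayer (illustrative only): r is the lambda conv kernel size and stride=2 downsamples via the trailing AvgPool2d.

import torch

lam = LambdaLayer(dim=128, num_heads=4, dim_head=16, r=7, stride=2)
x = torch.randn(2, 128, 24, 24)
print(lam(x).shape)  # torch.Size([2, 128, 12, 12])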

@ -0,0 +1,14 @@
""" Normalization layers and wrappers
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class GroupNorm(nn.GroupNorm):
def __init__(self, num_channels, num_groups, eps=1e-5, affine=True):
# NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
super().__init__(num_groups, num_channels, eps=eps, affine=affine)
def forward(self, x):
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
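The wrapper above flips the argument order relative to nn.GroupNorm so a norm layer can be constructed uniformly as norm_layer(num_channels). Illustrative usage (not part of the patch):

import torch

gn = GroupNorm(num_channels=64, num_groups=32)  # torch.nn.GroupNorm would be GroupNorm(32, 64)
x = torch.randn(2, 64, 8, 8)
print(gn(x).shape)  # torch.Size([2, 64, 8, 8])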

@ -6,13 +6,16 @@ import sys
import re
import fnmatch
from collections import defaultdict
from copy import deepcopy
__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
           'is_model_default_key', 'has_model_default_key', 'get_model_default_value', 'is_model_pretrained']
_module_to_models = defaultdict(set)  # dict of sets to check membership of model in module
_model_to_module = {}  # mapping of model names to module names
_model_entrypoints = {}  # mapping of model names to entrypoint fns
_model_has_pretrained = set()  # set of model names that have pretrained weight url present
_model_default_cfgs = dict()  # central repo for model default_cfgs
def register_model(fn):
@ -37,6 +40,7 @@ def register_model(fn):
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing
# entrypoints or non-matching combos
has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
_model_default_cfgs[model_name] = deepcopy(mod.default_cfgs[model_name])
if has_pretrained:
_model_has_pretrained.add(model_name)
return fn
@ -105,3 +109,31 @@ def is_model_in_modules(model_name, module_names):
assert isinstance(module_names, (tuple, list, set))
return any(model_name in _module_to_models[n] for n in module_names)
def has_model_default_key(model_name, cfg_key):
""" Query model default_cfgs for existence of a specific key.
"""
if model_name in _model_default_cfgs and cfg_key in _model_default_cfgs[model_name]:
return True
return False
def is_model_default_key(model_name, cfg_key):
""" Return truthy value for specified model default_cfg key, False if does not exist.
"""
if model_name in _model_default_cfgs and _model_default_cfgs[model_name].get(cfg_key, False):
return True
return False
def get_model_default_value(model_name, cfg_key):
""" Get a specific model default_cfg value by key. None if it doesn't exist.
"""
if model_name in _model_default_cfgs:
return _model_default_cfgs[model_name].get(cfg_key, None)
else:
return None
def is_model_pretrained(model_name):
return model_name in _model_has_pretrained
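The new helpers can be queried as below (editor's example; the model name is arbitrary and the import path assumes this file is timm/models/registry.py — any registered model with these default_cfg keys behaves the same).

from timm.models.registry import has_model_default_key, get_model_default_value, is_model_pretrained

if has_model_default_key('resnet50', 'input_size'):
    print(get_model_default_value('resnet50', 'input_size'))  # e.g. (3, 224, 224)
print(is_model_pretrained('resnet50'))  # True if a pretrained weight url is registered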

@ -9,4 +9,5 @@ from .metrics import AverageMeter, accuracy
from .misc import natural_key, add_bool_arg
from .model import unwrap_model, get_state_dict
from .model_ema import ModelEma, ModelEmaV2
from .random import random_seed
from .summary import update_summary, get_outdir

@ -0,0 +1,9 @@
import random
import numpy as np
import torch
def random_seed(seed=42, rank=0):
torch.manual_seed(seed + rank)
np.random.seed(seed + rank)
random.seed(seed + rank)

@ -329,7 +329,7 @@ def main():
_logger.warning("Neither APEX or native Torch AMP is available, using float32. " _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
"Install NVIDA apex or upgrade to PyTorch 1.6") "Install NVIDA apex or upgrade to PyTorch 1.6")
torch.manual_seed(args.seed + args.rank) random_seed(args.seed, args.rank)
model = create_model( model = create_model(
args.model, args.model,
