|
|
|
@ -51,6 +51,16 @@ def _cfg(url='', **kwargs):
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cfgr(url='', **kwargs):
|
|
|
|
|
return {
|
|
|
|
|
'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
|
|
|
|
|
'crop_pct': 0.9, 'interpolation': 'bicubic',
|
|
|
|
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
|
|
|
|
'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
|
|
|
|
|
**kwargs
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default_cfgs = {
|
|
|
|
|
# GPU-Efficient (ResNet) weights
|
|
|
|
|
'gernet_s': _cfg(
|
|
|
|
@ -92,65 +102,50 @@ default_cfgs = {
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet51q_ra2-d47dcc76.pth',
|
|
|
|
|
first_conv='stem.conv1', input_size=(3, 256, 256), pool_size=(8, 8),
|
|
|
|
|
test_input_size=(3, 288, 288), crop_pct=1.0),
|
|
|
|
|
'resnet61q': _cfg(
|
|
|
|
|
'resnet61q': _cfgr(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet61q_ra2-6afc536c.pth',
|
|
|
|
|
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8),
|
|
|
|
|
test_input_size=(3, 288, 288), crop_pct=1.0, interpolation='bicubic'),
|
|
|
|
|
|
|
|
|
|
'resnext26ts': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnext26ts_256_ra2-8bbd9106.pth',
|
|
|
|
|
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
|
|
|
|
|
'gcresnext26ts': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext26ts_256-e414378b.pth',
|
|
|
|
|
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
|
|
|
|
|
'seresnext26ts': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnext26ts_256-6f0d74a3.pth',
|
|
|
|
|
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
|
|
|
|
|
'eca_resnext26ts': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnext26ts_256-5a1d030f.pth',
|
|
|
|
|
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
|
|
|
|
|
'bat_resnext26ts': _cfg(
|
|
|
|
|
test_input_size=(3, 288, 288), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
'resnext26ts': _cfgr(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnext26ts_256_ra2-8bbd9106.pth'),
|
|
|
|
|
'gcresnext26ts': _cfgr(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext26ts_256-e414378b.pth'),
|
|
|
|
|
'seresnext26ts': _cfgr(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnext26ts_256-6f0d74a3.pth'),
|
|
|
|
|
'eca_resnext26ts': _cfgr(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnext26ts_256-5a1d030f.pth'),
|
|
|
|
|
'bat_resnext26ts': _cfgr(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/bat_resnext26ts_256-fa6fd595.pth',
|
|
|
|
|
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic',
|
|
|
|
|
min_input_size=(3, 256, 256)),
|
|
|
|
|
|
|
|
|
|
'resnet32ts': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet32ts_256-aacf5250.pth',
|
|
|
|
|
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
|
|
|
|
|
'resnet33ts': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet33ts_256-e91b09a4.pth',
|
|
|
|
|
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
|
|
|
|
|
'gcresnet33ts': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet33ts_256-0e0cd345.pth',
|
|
|
|
|
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
|
|
|
|
|
'seresnet33ts': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnet33ts_256-f8ad44d9.pth',
|
|
|
|
|
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
|
|
|
|
|
'eca_resnet33ts': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnet33ts_256-8f98face.pth',
|
|
|
|
|
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
|
|
|
|
|
|
|
|
|
|
'gcresnet50t': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet50t_256-96374d1c.pth',
|
|
|
|
|
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
|
|
|
|
|
|
|
|
|
|
'gcresnext50ts': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext50ts_256-3e0f515e.pth',
|
|
|
|
|
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
|
|
|
|
|
|
|
|
|
|
# experimental models
|
|
|
|
|
'regnetz_b': _cfg(
|
|
|
|
|
'resnet32ts': _cfgr(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet32ts_256-aacf5250.pth'),
|
|
|
|
|
'resnet33ts': _cfgr(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet33ts_256-e91b09a4.pth'),
|
|
|
|
|
'gcresnet33ts': _cfgr(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet33ts_256-0e0cd345.pth'),
|
|
|
|
|
'seresnet33ts': _cfgr(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnet33ts_256-f8ad44d9.pth'),
|
|
|
|
|
'eca_resnet33ts': _cfgr(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnet33ts_256-8f98face.pth'),
|
|
|
|
|
|
|
|
|
|
'gcresnet50t': _cfgr(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet50t_256-96374d1c.pth'),
|
|
|
|
|
|
|
|
|
|
'gcresnext50ts': _cfgr(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext50ts_256-3e0f515e.pth'),
|
|
|
|
|
|
|
|
|
|
# experimental models, likely to change ot be removed
|
|
|
|
|
'regnetz_b': _cfgr(
|
|
|
|
|
url='',
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
|
|
|
|
input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
|
|
|
|
|
'regnetz_c': _cfg(
|
|
|
|
|
input_size=(3, 224, 224), pool_size=(7, 7), first_conv='stem.conv'),
|
|
|
|
|
'regnetz_c': _cfgr(
|
|
|
|
|
url='',
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
|
|
|
|
input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
|
|
|
|
|
'regnetz_d': _cfg(
|
|
|
|
|
imean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), first_conv='stem.conv'),
|
|
|
|
|
'regnetz_d': _cfgr(
|
|
|
|
|
url='',
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
|
|
|
|
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -507,46 +502,52 @@ model_cfgs = dict(
|
|
|
|
|
# experimental models, closer to a RegNetZ than a ResNet. Similar to EfficientNets but w/ groups instead of DW
|
|
|
|
|
regnetz_b=ByoModelCfg(
|
|
|
|
|
blocks=(
|
|
|
|
|
ByoBlockCfg(type='bottle', d=2, c=192, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=6, c=384, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=12, c=768, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=3),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=3),
|
|
|
|
|
),
|
|
|
|
|
stem_chs=32,
|
|
|
|
|
stem_pool='',
|
|
|
|
|
num_features=1792,
|
|
|
|
|
downsample='',
|
|
|
|
|
num_features=1536,
|
|
|
|
|
act_layer='silu',
|
|
|
|
|
attn_layer='se',
|
|
|
|
|
attn_kwargs=dict(rd_ratio=0.25),
|
|
|
|
|
block_kwargs=dict(bottle_in=True, linear_out=True),
|
|
|
|
|
),
|
|
|
|
|
regnetz_c=ByoModelCfg(
|
|
|
|
|
blocks=(
|
|
|
|
|
ByoBlockCfg(type='bottle', d=2, c=128, s=2, gs=16, br=0.5, block_kwargs=dict(linear_out=True)),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=6, c=512, s=2, gs=32, br=0.25, block_kwargs=dict(linear_out=True)),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=12, c=768, s=2, gs=32, br=0.25, block_kwargs=dict(linear_out=True)),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=4),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=4),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=4),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=4),
|
|
|
|
|
),
|
|
|
|
|
stem_chs=32,
|
|
|
|
|
stem_pool='',
|
|
|
|
|
num_features=1792,
|
|
|
|
|
downsample='',
|
|
|
|
|
num_features=1536,
|
|
|
|
|
act_layer='silu',
|
|
|
|
|
attn_layer='se',
|
|
|
|
|
attn_kwargs=dict(rd_ratio=0.25),
|
|
|
|
|
block_kwargs=dict(bottle_in=True, linear_out=True),
|
|
|
|
|
),
|
|
|
|
|
regnetz_d=ByoModelCfg(
|
|
|
|
|
blocks=(
|
|
|
|
|
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=64, br=0.25, block_kwargs=dict(linear_out=True)),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=6, c=512, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=12, c=768, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=32, br=4),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=32, br=4),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=32, br=4),
|
|
|
|
|
ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=32, br=4),
|
|
|
|
|
),
|
|
|
|
|
stem_chs=128,
|
|
|
|
|
stem_type='quad',
|
|
|
|
|
stem_chs=64,
|
|
|
|
|
stem_type='tiered',
|
|
|
|
|
stem_pool='',
|
|
|
|
|
downsample='',
|
|
|
|
|
num_features=1792,
|
|
|
|
|
act_layer='silu',
|
|
|
|
|
attn_layer='se',
|
|
|
|
|
attn_kwargs=dict(rd_ratio=0.25),
|
|
|
|
|
block_kwargs=dict(bottle_in=True, linear_out=True),
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -802,11 +803,17 @@ class DownsampleAvg(nn.Module):
|
|
|
|
|
return self.conv(self.pool(x))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_downsample(downsample_type, layers: LayerFn, **kwargs):
|
|
|
|
|
if downsample_type == 'avg':
|
|
|
|
|
return DownsampleAvg(**kwargs)
|
|
|
|
|
def create_shortcut(downsample_type, layers: LayerFn, in_chs, out_chs, stride, dilation, **kwargs):
|
|
|
|
|
assert downsample_type in ('avg', 'conv1x1', '')
|
|
|
|
|
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
|
|
|
|
|
if not downsample_type:
|
|
|
|
|
return None # no shortcut
|
|
|
|
|
elif downsample_type == 'avg':
|
|
|
|
|
return DownsampleAvg(in_chs, out_chs, stride=stride, dilation=dilation[0], **kwargs)
|
|
|
|
|
else:
|
|
|
|
|
return layers.conv_norm_act(in_chs, out_chs, kernel_size=1, stride=stride, dilation=dilation[0], **kwargs)
|
|
|
|
|
else:
|
|
|
|
|
return layers.conv_norm_act(kwargs.pop('in_chs'), kwargs.pop('out_chs'), kernel_size=1, **kwargs)
|
|
|
|
|
return nn.Identity() # identity shortcut
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BasicBlock(nn.Module):
|
|
|
|
@ -822,12 +829,9 @@ class BasicBlock(nn.Module):
|
|
|
|
|
mid_chs = make_divisible(out_chs * bottle_ratio)
|
|
|
|
|
groups = num_groups(group_size, mid_chs)
|
|
|
|
|
|
|
|
|
|
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
|
|
|
|
|
self.shortcut = create_downsample(
|
|
|
|
|
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
|
|
|
|
|
apply_act=False, layers=layers)
|
|
|
|
|
else:
|
|
|
|
|
self.shortcut = nn.Identity()
|
|
|
|
|
self.shortcut = create_shortcut(
|
|
|
|
|
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation,
|
|
|
|
|
apply_act=False, layers=layers)
|
|
|
|
|
|
|
|
|
|
self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0])
|
|
|
|
|
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
|
|
|
|
@ -838,23 +842,21 @@ class BasicBlock(nn.Module):
|
|
|
|
|
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
|
|
|
|
|
|
|
|
|
|
def init_weights(self, zero_init_last: bool = False):
|
|
|
|
|
if zero_init_last:
|
|
|
|
|
if zero_init_last and self.shortcut is not None:
|
|
|
|
|
nn.init.zeros_(self.conv2_kxk.bn.weight)
|
|
|
|
|
for attn in (self.attn, self.attn_last):
|
|
|
|
|
if hasattr(attn, 'reset_parameters'):
|
|
|
|
|
attn.reset_parameters()
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
shortcut = self.shortcut(x)
|
|
|
|
|
|
|
|
|
|
# residual path
|
|
|
|
|
shortcut = x
|
|
|
|
|
x = self.conv1_kxk(x)
|
|
|
|
|
x = self.conv2_kxk(x)
|
|
|
|
|
x = self.attn(x)
|
|
|
|
|
x = self.drop_path(x)
|
|
|
|
|
|
|
|
|
|
x = self.act(x + shortcut)
|
|
|
|
|
return x
|
|
|
|
|
if self.shortcut is not None:
|
|
|
|
|
x = x + self.shortcut(shortcut)
|
|
|
|
|
return self.act(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BottleneckBlock(nn.Module):
|
|
|
|
@ -862,24 +864,18 @@ class BottleneckBlock(nn.Module):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
|
|
|
|
|
downsample='avg', attn_last=False, linear_out=False, extra_conv=False, layers: LayerFn = None,
|
|
|
|
|
drop_block=None, drop_path_rate=0.):
|
|
|
|
|
downsample='avg', attn_last=False, linear_out=False, extra_conv=False, bottle_in=False,
|
|
|
|
|
layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
|
|
|
|
|
super(BottleneckBlock, self).__init__()
|
|
|
|
|
layers = layers or LayerFn()
|
|
|
|
|
mid_chs = make_divisible(out_chs * bottle_ratio)
|
|
|
|
|
mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
|
|
|
|
|
groups = num_groups(group_size, mid_chs)
|
|
|
|
|
|
|
|
|
|
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
|
|
|
|
|
self.shortcut = create_downsample(
|
|
|
|
|
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
|
|
|
|
|
apply_act=False, layers=layers)
|
|
|
|
|
else:
|
|
|
|
|
self.shortcut = nn.Identity()
|
|
|
|
|
self.shortcut = create_shortcut(
|
|
|
|
|
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation,
|
|
|
|
|
apply_act=False, layers=layers)
|
|
|
|
|
|
|
|
|
|
self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
|
|
|
|
|
self.conv2_kxk = layers.conv_norm_act(
|
|
|
|
|
mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
|
|
|
|
|
groups=groups, drop_block=drop_block)
|
|
|
|
|
self.conv2_kxk = layers.conv_norm_act(
|
|
|
|
|
mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
|
|
|
|
|
groups=groups, drop_block=drop_block)
|
|
|
|
@ -895,15 +891,14 @@ class BottleneckBlock(nn.Module):
|
|
|
|
|
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
|
|
|
|
|
|
|
|
|
|
def init_weights(self, zero_init_last: bool = False):
|
|
|
|
|
if zero_init_last:
|
|
|
|
|
if zero_init_last and self.shortcut is not None:
|
|
|
|
|
nn.init.zeros_(self.conv3_1x1.bn.weight)
|
|
|
|
|
for attn in (self.attn, self.attn_last):
|
|
|
|
|
if hasattr(attn, 'reset_parameters'):
|
|
|
|
|
attn.reset_parameters()
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
shortcut = self.shortcut(x)
|
|
|
|
|
|
|
|
|
|
shortcut = x
|
|
|
|
|
x = self.conv1_1x1(x)
|
|
|
|
|
x = self.conv2_kxk(x)
|
|
|
|
|
x = self.conv2b_kxk(x)
|
|
|
|
@ -911,9 +906,9 @@ class BottleneckBlock(nn.Module):
|
|
|
|
|
x = self.conv3_1x1(x)
|
|
|
|
|
x = self.attn_last(x)
|
|
|
|
|
x = self.drop_path(x)
|
|
|
|
|
|
|
|
|
|
x = self.act(x + shortcut)
|
|
|
|
|
return x
|
|
|
|
|
if self.shortcut is not None:
|
|
|
|
|
x = x + self.shortcut(shortcut)
|
|
|
|
|
return self.act(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DarkBlock(nn.Module):
|
|
|
|
@ -935,12 +930,9 @@ class DarkBlock(nn.Module):
|
|
|
|
|
mid_chs = make_divisible(out_chs * bottle_ratio)
|
|
|
|
|
groups = num_groups(group_size, mid_chs)
|
|
|
|
|
|
|
|
|
|
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
|
|
|
|
|
self.shortcut = create_downsample(
|
|
|
|
|
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
|
|
|
|
|
apply_act=False, layers=layers)
|
|
|
|
|
else:
|
|
|
|
|
self.shortcut = nn.Identity()
|
|
|
|
|
self.shortcut = create_shortcut(
|
|
|
|
|
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation,
|
|
|
|
|
apply_act=False, layers=layers)
|
|
|
|
|
|
|
|
|
|
self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
|
|
|
|
|
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
|
|
|
|
@ -952,22 +944,22 @@ class DarkBlock(nn.Module):
|
|
|
|
|
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
|
|
|
|
|
|
|
|
|
|
def init_weights(self, zero_init_last: bool = False):
|
|
|
|
|
if zero_init_last:
|
|
|
|
|
if zero_init_last and self.shortcut is not None:
|
|
|
|
|
nn.init.zeros_(self.conv2_kxk.bn.weight)
|
|
|
|
|
for attn in (self.attn, self.attn_last):
|
|
|
|
|
if hasattr(attn, 'reset_parameters'):
|
|
|
|
|
attn.reset_parameters()
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
shortcut = self.shortcut(x)
|
|
|
|
|
|
|
|
|
|
shortcut = x
|
|
|
|
|
x = self.conv1_1x1(x)
|
|
|
|
|
x = self.attn(x)
|
|
|
|
|
x = self.conv2_kxk(x)
|
|
|
|
|
x = self.attn_last(x)
|
|
|
|
|
x = self.drop_path(x)
|
|
|
|
|
x = self.act(x + shortcut)
|
|
|
|
|
return x
|
|
|
|
|
if self.shortcut is not None:
|
|
|
|
|
x = x + self.shortcut(shortcut)
|
|
|
|
|
return self.act(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EdgeBlock(nn.Module):
|
|
|
|
@ -988,12 +980,9 @@ class EdgeBlock(nn.Module):
|
|
|
|
|
mid_chs = make_divisible(out_chs * bottle_ratio)
|
|
|
|
|
groups = num_groups(group_size, mid_chs)
|
|
|
|
|
|
|
|
|
|
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
|
|
|
|
|
self.shortcut = create_downsample(
|
|
|
|
|
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
|
|
|
|
|
apply_act=False, layers=layers)
|
|
|
|
|
else:
|
|
|
|
|
self.shortcut = nn.Identity()
|
|
|
|
|
self.shortcut = create_shortcut(
|
|
|
|
|
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation,
|
|
|
|
|
apply_act=False, layers=layers)
|
|
|
|
|
|
|
|
|
|
self.conv1_kxk = layers.conv_norm_act(
|
|
|
|
|
in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
|
|
|
|
@ -1005,22 +994,22 @@ class EdgeBlock(nn.Module):
|
|
|
|
|
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
|
|
|
|
|
|
|
|
|
|
def init_weights(self, zero_init_last: bool = False):
|
|
|
|
|
if zero_init_last:
|
|
|
|
|
if zero_init_last and self.shortcut is not None:
|
|
|
|
|
nn.init.zeros_(self.conv2_1x1.bn.weight)
|
|
|
|
|
for attn in (self.attn, self.attn_last):
|
|
|
|
|
if hasattr(attn, 'reset_parameters'):
|
|
|
|
|
attn.reset_parameters()
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
shortcut = self.shortcut(x)
|
|
|
|
|
|
|
|
|
|
shortcut = x
|
|
|
|
|
x = self.conv1_kxk(x)
|
|
|
|
|
x = self.attn(x)
|
|
|
|
|
x = self.conv2_1x1(x)
|
|
|
|
|
x = self.attn_last(x)
|
|
|
|
|
x = self.drop_path(x)
|
|
|
|
|
x = self.act(x + shortcut)
|
|
|
|
|
return x
|
|
|
|
|
if self.shortcut is not None:
|
|
|
|
|
x = x + self.shortcut(shortcut)
|
|
|
|
|
return self.act(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RepVggBlock(nn.Module):
|
|
|
|
@ -1065,8 +1054,7 @@ class RepVggBlock(nn.Module):
|
|
|
|
|
x = self.drop_path(x) # not in the paper / official impl, experimental
|
|
|
|
|
x = x + identity
|
|
|
|
|
x = self.attn(x) # no attn in the paper / official impl, experimental
|
|
|
|
|
x = self.act(x)
|
|
|
|
|
return x
|
|
|
|
|
return self.act(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SelfAttnBlock(nn.Module):
|
|
|
|
@ -1074,19 +1062,16 @@ class SelfAttnBlock(nn.Module):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
|
|
|
|
|
downsample='avg', extra_conv=False, linear_out=False, post_attn_na=True, feat_size=None,
|
|
|
|
|
layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
|
|
|
|
|
downsample='avg', extra_conv=False, linear_out=False, bottle_in=False, post_attn_na=True,
|
|
|
|
|
feat_size=None, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
|
|
|
|
|
super(SelfAttnBlock, self).__init__()
|
|
|
|
|
assert layers is not None
|
|
|
|
|
mid_chs = make_divisible(out_chs * bottle_ratio)
|
|
|
|
|
mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
|
|
|
|
|
groups = num_groups(group_size, mid_chs)
|
|
|
|
|
|
|
|
|
|
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
|
|
|
|
|
self.shortcut = create_downsample(
|
|
|
|
|
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
|
|
|
|
|
apply_act=False, layers=layers)
|
|
|
|
|
else:
|
|
|
|
|
self.shortcut = nn.Identity()
|
|
|
|
|
self.shortcut = create_shortcut(
|
|
|
|
|
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation,
|
|
|
|
|
apply_act=False, layers=layers)
|
|
|
|
|
|
|
|
|
|
self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
|
|
|
|
|
if extra_conv:
|
|
|
|
@ -1105,7 +1090,7 @@ class SelfAttnBlock(nn.Module):
|
|
|
|
|
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
|
|
|
|
|
|
|
|
|
|
def init_weights(self, zero_init_last: bool = False):
|
|
|
|
|
if zero_init_last:
|
|
|
|
|
if zero_init_last and self.shortcut is not None:
|
|
|
|
|
nn.init.zeros_(self.conv3_1x1.bn.weight)
|
|
|
|
|
if hasattr(self.self_attn, 'reset_parameters'):
|
|
|
|
|
self.self_attn.reset_parameters()
|
|
|
|
|