diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 6558de35..056813ef 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -57,8 +57,14 @@ default_cfgs = { input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'lambda_resnet26t': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_256-b040fce6.pth', + url='', + min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), + 'lambda_resnet50ts': _cfg( + url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), + 'lambda_resnet26rpt_256': _cfg( + url='', + fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), } @@ -198,6 +204,33 @@ model_cfgs = dict( self_attn_layer='lambda', self_attn_kwargs=dict(r=9) ), + lambda_resnet50ts=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + act_layer='silu', + self_attn_layer='lambda', + self_attn_kwargs=dict(r=9) + ), + lambda_resnet26rpt_256=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + self_attn_layer='lambda', + self_attn_kwargs=dict(r=None) + ), ) @@ -275,6 +308,21 @@ def eca_halonext26ts(pretrained=False, **kwargs): @register_model def lambda_resnet26t(pretrained=False, **kwargs): - """ Lambda-ResNet-26T. Lambda layers in last two stages. + """ Lambda-ResNet-26-T. Lambda layers w/ conv pos in last two stages. """ return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs) + + +@register_model +def lambda_resnet50ts(pretrained=False, **kwargs): + """ Lambda-ResNet-50-TS. SiLU act. Lambda layers w/ conv pos in last two stages. + """ + return _create_byoanet('lambda_resnet50ts', pretrained=pretrained, **kwargs) + + +@register_model +def lambda_resnet26rpt_256(pretrained=False, **kwargs): + """ Lambda-ResNet-26-R-T. Lambda layers w/ rel pos embed in last two stages. + """ + kwargs.setdefault('img_size', 256) + return _create_byoanet('lambda_resnet26rpt_256', pretrained=pretrained, **kwargs) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index d00aeb32..515f2073 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -51,6 +51,16 @@ def _cfg(url='', **kwargs): } +def _cfgr(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8), + 'crop_pct': 0.9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc', + **kwargs + } + + default_cfgs = { # GPU-Efficient (ResNet) weights 'gernet_s': _cfg( @@ -92,65 +102,50 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet51q_ra2-d47dcc76.pth', first_conv='stem.conv1', input_size=(3, 256, 256), pool_size=(8, 8), test_input_size=(3, 288, 288), crop_pct=1.0), - 'resnet61q': _cfg( + 'resnet61q': _cfgr( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet61q_ra2-6afc536c.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), - test_input_size=(3, 288, 288), crop_pct=1.0, interpolation='bicubic'), - - 'resnext26ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnext26ts_256_ra2-8bbd9106.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'gcresnext26ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext26ts_256-e414378b.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'seresnext26ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnext26ts_256-6f0d74a3.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'eca_resnext26ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnext26ts_256-5a1d030f.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'bat_resnext26ts': _cfg( + test_input_size=(3, 288, 288), crop_pct=1.0), + + 'resnext26ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnext26ts_256_ra2-8bbd9106.pth'), + 'gcresnext26ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext26ts_256-e414378b.pth'), + 'seresnext26ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnext26ts_256-6f0d74a3.pth'), + 'eca_resnext26ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnext26ts_256-5a1d030f.pth'), + 'bat_resnext26ts': _cfgr( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/bat_resnext26ts_256-fa6fd595.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic', min_input_size=(3, 256, 256)), - 'resnet32ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet32ts_256-aacf5250.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'resnet33ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet33ts_256-e91b09a4.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'gcresnet33ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet33ts_256-0e0cd345.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'seresnet33ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnet33ts_256-f8ad44d9.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'eca_resnet33ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnet33ts_256-8f98face.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - - 'gcresnet50t': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet50t_256-96374d1c.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - - 'gcresnext50ts': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext50ts_256-3e0f515e.pth', - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - - # experimental models - 'regnetz_b': _cfg( + 'resnet32ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet32ts_256-aacf5250.pth'), + 'resnet33ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet33ts_256-e91b09a4.pth'), + 'gcresnet33ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet33ts_256-0e0cd345.pth'), + 'seresnet33ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnet33ts_256-f8ad44d9.pth'), + 'eca_resnet33ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnet33ts_256-8f98face.pth'), + + 'gcresnet50t': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet50t_256-96374d1c.pth'), + + 'gcresnext50ts': _cfgr( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext50ts_256-3e0f515e.pth'), + + # experimental models, likely to change ot be removed + 'regnetz_b': _cfgr( url='', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'regnetz_c': _cfg( + input_size=(3, 224, 224), pool_size=(7, 7), first_conv='stem.conv'), + 'regnetz_c': _cfgr( url='', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), - 'regnetz_d': _cfg( + imean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), first_conv='stem.conv'), + 'regnetz_d': _cfgr( url='', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), } @@ -507,46 +502,52 @@ model_cfgs = dict( # experimental models, closer to a RegNetZ than a ResNet. Similar to EfficientNets but w/ groups instead of DW regnetz_b=ByoModelCfg( blocks=( - ByoBlockCfg(type='bottle', d=2, c=192, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)), - ByoBlockCfg(type='bottle', d=6, c=384, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)), - ByoBlockCfg(type='bottle', d=12, c=768, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)), - ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)), + ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3), + ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3), + ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=3), + ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=3), ), stem_chs=32, stem_pool='', - num_features=1792, + downsample='', + num_features=1536, act_layer='silu', attn_layer='se', attn_kwargs=dict(rd_ratio=0.25), + block_kwargs=dict(bottle_in=True, linear_out=True), ), regnetz_c=ByoModelCfg( blocks=( - ByoBlockCfg(type='bottle', d=2, c=128, s=2, gs=16, br=0.5, block_kwargs=dict(linear_out=True)), - ByoBlockCfg(type='bottle', d=6, c=512, s=2, gs=32, br=0.25, block_kwargs=dict(linear_out=True)), - ByoBlockCfg(type='bottle', d=12, c=768, s=2, gs=32, br=0.25, block_kwargs=dict(linear_out=True)), - ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)), + ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=4), + ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=4), + ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=4), + ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=4), ), stem_chs=32, stem_pool='', - num_features=1792, + downsample='', + num_features=1536, act_layer='silu', attn_layer='se', attn_kwargs=dict(rd_ratio=0.25), + block_kwargs=dict(bottle_in=True, linear_out=True), ), regnetz_d=ByoModelCfg( blocks=( - ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=64, br=0.25, block_kwargs=dict(linear_out=True)), - ByoBlockCfg(type='bottle', d=6, c=512, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)), - ByoBlockCfg(type='bottle', d=12, c=768, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)), - ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)), + ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=32, br=4), + ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=32, br=4), + ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=32, br=4), + ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=32, br=4), ), - stem_chs=128, - stem_type='quad', + stem_chs=64, + stem_type='tiered', stem_pool='', + downsample='', num_features=1792, act_layer='silu', attn_layer='se', attn_kwargs=dict(rd_ratio=0.25), + block_kwargs=dict(bottle_in=True, linear_out=True), ), ) @@ -802,11 +803,17 @@ class DownsampleAvg(nn.Module): return self.conv(self.pool(x)) -def create_downsample(downsample_type, layers: LayerFn, **kwargs): - if downsample_type == 'avg': - return DownsampleAvg(**kwargs) +def create_shortcut(downsample_type, layers: LayerFn, in_chs, out_chs, stride, dilation, **kwargs): + assert downsample_type in ('avg', 'conv1x1', '') + if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: + if not downsample_type: + return None # no shortcut + elif downsample_type == 'avg': + return DownsampleAvg(in_chs, out_chs, stride=stride, dilation=dilation[0], **kwargs) + else: + return layers.conv_norm_act(in_chs, out_chs, kernel_size=1, stride=stride, dilation=dilation[0], **kwargs) else: - return layers.conv_norm_act(kwargs.pop('in_chs'), kwargs.pop('out_chs'), kernel_size=1, **kwargs) + return nn.Identity() # identity shortcut class BasicBlock(nn.Module): @@ -822,12 +829,9 @@ class BasicBlock(nn.Module): mid_chs = make_divisible(out_chs * bottle_ratio) groups = num_groups(group_size, mid_chs) - if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: - self.shortcut = create_downsample( - downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], - apply_act=False, layers=layers) - else: - self.shortcut = nn.Identity() + self.shortcut = create_shortcut( + downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation, + apply_act=False, layers=layers) self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0]) self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) @@ -838,23 +842,21 @@ class BasicBlock(nn.Module): self.act = nn.Identity() if linear_out else layers.act(inplace=True) def init_weights(self, zero_init_last: bool = False): - if zero_init_last: + if zero_init_last and self.shortcut is not None: nn.init.zeros_(self.conv2_kxk.bn.weight) for attn in (self.attn, self.attn_last): if hasattr(attn, 'reset_parameters'): attn.reset_parameters() def forward(self, x): - shortcut = self.shortcut(x) - - # residual path + shortcut = x x = self.conv1_kxk(x) x = self.conv2_kxk(x) x = self.attn(x) x = self.drop_path(x) - - x = self.act(x + shortcut) - return x + if self.shortcut is not None: + x = x + self.shortcut(shortcut) + return self.act(x) class BottleneckBlock(nn.Module): @@ -862,24 +864,18 @@ class BottleneckBlock(nn.Module): """ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None, - downsample='avg', attn_last=False, linear_out=False, extra_conv=False, layers: LayerFn = None, - drop_block=None, drop_path_rate=0.): + downsample='avg', attn_last=False, linear_out=False, extra_conv=False, bottle_in=False, + layers: LayerFn = None, drop_block=None, drop_path_rate=0.): super(BottleneckBlock, self).__init__() layers = layers or LayerFn() - mid_chs = make_divisible(out_chs * bottle_ratio) + mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio) groups = num_groups(group_size, mid_chs) - if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: - self.shortcut = create_downsample( - downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], - apply_act=False, layers=layers) - else: - self.shortcut = nn.Identity() + self.shortcut = create_shortcut( + downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation, + apply_act=False, layers=layers) self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) - self.conv2_kxk = layers.conv_norm_act( - mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], - groups=groups, drop_block=drop_block) self.conv2_kxk = layers.conv_norm_act( mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], groups=groups, drop_block=drop_block) @@ -895,15 +891,14 @@ class BottleneckBlock(nn.Module): self.act = nn.Identity() if linear_out else layers.act(inplace=True) def init_weights(self, zero_init_last: bool = False): - if zero_init_last: + if zero_init_last and self.shortcut is not None: nn.init.zeros_(self.conv3_1x1.bn.weight) for attn in (self.attn, self.attn_last): if hasattr(attn, 'reset_parameters'): attn.reset_parameters() def forward(self, x): - shortcut = self.shortcut(x) - + shortcut = x x = self.conv1_1x1(x) x = self.conv2_kxk(x) x = self.conv2b_kxk(x) @@ -911,9 +906,9 @@ class BottleneckBlock(nn.Module): x = self.conv3_1x1(x) x = self.attn_last(x) x = self.drop_path(x) - - x = self.act(x + shortcut) - return x + if self.shortcut is not None: + x = x + self.shortcut(shortcut) + return self.act(x) class DarkBlock(nn.Module): @@ -935,12 +930,9 @@ class DarkBlock(nn.Module): mid_chs = make_divisible(out_chs * bottle_ratio) groups = num_groups(group_size, mid_chs) - if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: - self.shortcut = create_downsample( - downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], - apply_act=False, layers=layers) - else: - self.shortcut = nn.Identity() + self.shortcut = create_shortcut( + downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation, + apply_act=False, layers=layers) self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) @@ -952,22 +944,22 @@ class DarkBlock(nn.Module): self.act = nn.Identity() if linear_out else layers.act(inplace=True) def init_weights(self, zero_init_last: bool = False): - if zero_init_last: + if zero_init_last and self.shortcut is not None: nn.init.zeros_(self.conv2_kxk.bn.weight) for attn in (self.attn, self.attn_last): if hasattr(attn, 'reset_parameters'): attn.reset_parameters() def forward(self, x): - shortcut = self.shortcut(x) - + shortcut = x x = self.conv1_1x1(x) x = self.attn(x) x = self.conv2_kxk(x) x = self.attn_last(x) x = self.drop_path(x) - x = self.act(x + shortcut) - return x + if self.shortcut is not None: + x = x + self.shortcut(shortcut) + return self.act(x) class EdgeBlock(nn.Module): @@ -988,12 +980,9 @@ class EdgeBlock(nn.Module): mid_chs = make_divisible(out_chs * bottle_ratio) groups = num_groups(group_size, mid_chs) - if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: - self.shortcut = create_downsample( - downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], - apply_act=False, layers=layers) - else: - self.shortcut = nn.Identity() + self.shortcut = create_shortcut( + downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation, + apply_act=False, layers=layers) self.conv1_kxk = layers.conv_norm_act( in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], @@ -1005,22 +994,22 @@ class EdgeBlock(nn.Module): self.act = nn.Identity() if linear_out else layers.act(inplace=True) def init_weights(self, zero_init_last: bool = False): - if zero_init_last: + if zero_init_last and self.shortcut is not None: nn.init.zeros_(self.conv2_1x1.bn.weight) for attn in (self.attn, self.attn_last): if hasattr(attn, 'reset_parameters'): attn.reset_parameters() def forward(self, x): - shortcut = self.shortcut(x) - + shortcut = x x = self.conv1_kxk(x) x = self.attn(x) x = self.conv2_1x1(x) x = self.attn_last(x) x = self.drop_path(x) - x = self.act(x + shortcut) - return x + if self.shortcut is not None: + x = x + self.shortcut(shortcut) + return self.act(x) class RepVggBlock(nn.Module): @@ -1065,8 +1054,7 @@ class RepVggBlock(nn.Module): x = self.drop_path(x) # not in the paper / official impl, experimental x = x + identity x = self.attn(x) # no attn in the paper / official impl, experimental - x = self.act(x) - return x + return self.act(x) class SelfAttnBlock(nn.Module): @@ -1074,19 +1062,16 @@ class SelfAttnBlock(nn.Module): """ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None, - downsample='avg', extra_conv=False, linear_out=False, post_attn_na=True, feat_size=None, - layers: LayerFn = None, drop_block=None, drop_path_rate=0.): + downsample='avg', extra_conv=False, linear_out=False, bottle_in=False, post_attn_na=True, + feat_size=None, layers: LayerFn = None, drop_block=None, drop_path_rate=0.): super(SelfAttnBlock, self).__init__() assert layers is not None - mid_chs = make_divisible(out_chs * bottle_ratio) + mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio) groups = num_groups(group_size, mid_chs) - if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: - self.shortcut = create_downsample( - downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], - apply_act=False, layers=layers) - else: - self.shortcut = nn.Identity() + self.shortcut = create_shortcut( + downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation, + apply_act=False, layers=layers) self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) if extra_conv: @@ -1105,7 +1090,7 @@ class SelfAttnBlock(nn.Module): self.act = nn.Identity() if linear_out else layers.act(inplace=True) def init_weights(self, zero_init_last: bool = False): - if zero_init_last: + if zero_init_last and self.shortcut is not None: nn.init.zeros_(self.conv3_1x1.bn.weight) if hasattr(self.self_attn, 'reset_parameters'): self.self_attn.reset_parameters() diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index b43f38f5..b1fec449 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -277,7 +277,6 @@ class EdgeResidual(nn.Module): mid_chs = make_divisible(force_in_chs * exp_ratio) else: mid_chs = make_divisible(in_chs * exp_ratio) - has_se = se_layer is not None and se_ratio > 0. self.has_residual = (in_chs == out_chs and stride == 1) and not noskip self.drop_path_rate = drop_path_rate diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py index c0c619cc..bf6af675 100644 --- a/timm/models/layers/bottleneck_attn.py +++ b/timm/models/layers/bottleneck_attn.py @@ -122,7 +122,7 @@ class BottleneckAttn(nn.Module): attn_logits = attn_logits + self.pos_embed(q) # B, num_heads, H * W, H * W attn_out = attn_logits.softmax(dim=-1) - attn_out = (attn_out @ v).transpose(1, 2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W + attn_out = (attn_out @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W attn_out = self.pool(attn_out) return attn_out diff --git a/timm/models/layers/lambda_layer.py b/timm/models/layers/lambda_layer.py index d298c1aa..eeb77e45 100644 --- a/timm/models/layers/lambda_layer.py +++ b/timm/models/layers/lambda_layer.py @@ -24,18 +24,30 @@ import torch from torch import nn import torch.nn.functional as F +from .helpers import to_2tuple from .weight_init import trunc_normal_ +def rel_pos_indices(size): + size = to_2tuple(size) + pos = torch.stack(torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1) + rel_pos = pos[:, None, :] - pos[:, :, None] + rel_pos[0] += size[0] - 1 + rel_pos[1] += size[1] - 1 + return rel_pos # 2, H * W, H * W + + class LambdaLayer(nn.Module): - """Lambda Layer w/ lambda conv position embedding + """Lambda Layer Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` - https://arxiv.org/abs/2102.08602 + + NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add. """ def __init__( self, - dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False): + dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False): super().__init__() self.dim = dim self.dim_out = dim_out or dim @@ -43,7 +55,6 @@ class LambdaLayer(nn.Module): self.num_heads = num_heads assert self.dim_out % num_heads == 0, ' should be divided by num_heads' self.dim_v = self.dim_out // num_heads # value depth 'v' - self.r = r # relative position neighbourhood (lambda conv kernel size) self.qkv = nn.Conv2d( dim, @@ -52,8 +63,19 @@ class LambdaLayer(nn.Module): self.norm_q = nn.BatchNorm2d(num_heads * dim_head) self.norm_v = nn.BatchNorm2d(self.dim_v) - # NOTE currently only supporting the local lambda convolutions for positional - self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0)) + if r is not None: + # local lambda convolution for pos + self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0)) + self.pos_emb = None + self.rel_pos_indices = None + else: + # relative pos embedding + assert feat_size is not None + feat_size = to_2tuple(feat_size) + rel_size = [2 * s - 1 for s in feat_size] + self.conv_lambda = None + self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_k)) + self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False) self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() @@ -61,12 +83,14 @@ class LambdaLayer(nn.Module): def reset_parameters(self): trunc_normal_(self.qkv.weight, std=self.dim ** -0.5) - trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5) + if self.conv_lambda is not None: + trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5) + if self.pos_emb is not None: + trunc_normal_(self.pos_emb, std=.02) def forward(self, x): B, C, H, W = x.shape M = H * W - qkv = self.qkv(x) q, k, v = torch.split(qkv, [ self.num_heads * self.dim_k, self.dim_k, self.dim_v], dim=1) @@ -77,10 +101,15 @@ class LambdaLayer(nn.Module): content_lam = k @ v # B, K, V content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V - position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K - position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V + if self.pos_emb is None: + position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K + position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V + else: + # FIXME relative pos embedding path not fully verified + pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1) + position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1) # B, 1, M, K, V position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2) # B, num_heads, M, V - out = (content_out + position_out).transpose(3, 1).reshape(B, C, H, W) # B, C (num_heads * V), H, W + out = (content_out + position_out).transpose(-1, -2).reshape(B, C, H, W) # B, C (num_heads * V), H, W out = self.pool(out) return out