|
|
|
@ -13,7 +13,7 @@ def _cfg(url='', **kwargs):
|
|
|
|
|
return {
|
|
|
|
|
'url': url,
|
|
|
|
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
|
|
|
|
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
|
|
|
|
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
|
|
|
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
|
|
|
|
'first_conv': 'conv1', 'classifier': 'fc',
|
|
|
|
|
**kwargs
|
|
|
|
@ -21,11 +21,13 @@ def _cfg(url='', **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default_cfgs = {
|
|
|
|
|
'skresnet18': _cfg(url=''),
|
|
|
|
|
'skresnet26d': _cfg(),
|
|
|
|
|
'skresnet18': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth'),
|
|
|
|
|
'skresnet34': _cfg(url=''),
|
|
|
|
|
'skresnet50': _cfg(),
|
|
|
|
|
'skresnet50d': _cfg(),
|
|
|
|
|
'skresnext50_32x4d': _cfg(),
|
|
|
|
|
'skresnext50_32x4d': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth'),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -134,24 +136,10 @@ class SelectiveKernelBottleneck(nn.Module):
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
"""Constructs a ResNet-18 model.
|
|
|
|
|
"""
|
|
|
|
|
default_cfg = default_cfgs['skresnet18']
|
|
|
|
|
sk_kwargs = dict(
|
|
|
|
|
min_attn_channels=16,
|
|
|
|
|
)
|
|
|
|
|
model = ResNet(
|
|
|
|
|
SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
|
|
|
|
|
block_args=dict(sk_kwargs=sk_kwargs), **kwargs)
|
|
|
|
|
model.default_cfg = default_cfg
|
|
|
|
|
if pretrained:
|
|
|
|
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
"""Constructs a Selective Kernel ResNet-18 model.
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
"""Constructs a ResNet-18 model.
|
|
|
|
|
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
|
|
|
|
|
variation splits the input channels to the selective convolutions to keep param count down.
|
|
|
|
|
"""
|
|
|
|
|
default_cfg = default_cfgs['skresnet18']
|
|
|
|
|
sk_kwargs = dict(
|
|
|
|
@ -169,17 +157,21 @@ def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
"""Constructs a ResNet-26 model.
|
|
|
|
|
def skresnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
"""Constructs a Selective Kernel ResNet-34 model.
|
|
|
|
|
|
|
|
|
|
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
|
|
|
|
|
variation splits the input channels to the selective convolutions to keep param count down.
|
|
|
|
|
"""
|
|
|
|
|
default_cfg = default_cfgs['skresnet26d']
|
|
|
|
|
default_cfg = default_cfgs['skresnet34']
|
|
|
|
|
sk_kwargs = dict(
|
|
|
|
|
keep_3x3=False,
|
|
|
|
|
min_attn_channels=16,
|
|
|
|
|
attn_reduction=8,
|
|
|
|
|
split_input=True
|
|
|
|
|
)
|
|
|
|
|
model = ResNet(
|
|
|
|
|
SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True,
|
|
|
|
|
num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False
|
|
|
|
|
**kwargs)
|
|
|
|
|
SelectiveKernelBasic, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
|
|
|
|
|
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
|
|
|
|
|
model.default_cfg = default_cfg
|
|
|
|
|
if pretrained:
|
|
|
|
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
|
|
@ -189,11 +181,12 @@ def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
@register_model
|
|
|
|
|
def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
"""Constructs a Select Kernel ResNet-50 model.
|
|
|
|
|
Based on config in "Compounding the Performance Improvements of Assembled Techniques in a
|
|
|
|
|
Convolutional Neural Network"
|
|
|
|
|
|
|
|
|
|
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
|
|
|
|
|
variation splits the input channels to the selective convolutions to keep param count down.
|
|
|
|
|
"""
|
|
|
|
|
sk_kwargs = dict(
|
|
|
|
|
attn_reduction=2,
|
|
|
|
|
split_input=True,
|
|
|
|
|
)
|
|
|
|
|
default_cfg = default_cfgs['skresnet50']
|
|
|
|
|
model = ResNet(
|
|
|
|
@ -208,11 +201,12 @@ def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
@register_model
|
|
|
|
|
def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
"""Constructs a Select Kernel ResNet-50-D model.
|
|
|
|
|
Based on config in "Compounding the Performance Improvements of Assembled Techniques in a
|
|
|
|
|
Convolutional Neural Network"
|
|
|
|
|
|
|
|
|
|
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
|
|
|
|
|
variation splits the input channels to the selective convolutions to keep param count down.
|
|
|
|
|
"""
|
|
|
|
|
sk_kwargs = dict(
|
|
|
|
|
attn_reduction=2,
|
|
|
|
|
split_input=True,
|
|
|
|
|
)
|
|
|
|
|
default_cfg = default_cfgs['skresnet50d']
|
|
|
|
|
model = ResNet(
|
|
|
|
|