Streamline SelecSLS model without breaking checkpoint compat. Move cfg handling out of model class. Update feature/pooling behaviour to match current.

pull/66/head
Ross Wightman 5 years ago
parent d59a756c16
commit b5315e66b5

@ -53,60 +53,30 @@ default_cfgs = {
}
def conv_bn(inp, oup, stride):
def conv_bn(in_chs, out_chs, k=3, stride=1, padding=None, dilation=1):
if padding is None:
padding = ((stride - 1) + dilation * (k - 1)) // 2
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True)
)
def conv_1x1_bn(inp, oup):
return nn.Sequential(
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.Conv2d(in_chs, out_chs, k, stride, padding=padding, dilation=dilation, bias=False),
nn.BatchNorm2d(out_chs),
nn.ReLU(inplace=True)
)
class SelecSLSBlock(nn.Module):
def __init__(self, inp, skip, k, oup, is_first, stride):
def __init__(self, in_chs, skip_chs, mid_chs, out_chs, is_first, stride, dilation=1):
super(SelecSLSBlock, self).__init__()
self.stride = stride
self.is_first = is_first
assert stride in [1, 2]
# Process input with 4 conv blocks with the same number of input and output channels
self.conv1 = nn.Sequential(
nn.Conv2d(inp, k, 3, stride, 1, groups=1, bias=False, dilation=1),
nn.BatchNorm2d(k),
nn.ReLU(inplace=True)
)
self.conv2 = nn.Sequential(
nn.Conv2d(k, k, 1, 1, 0, groups=1, bias=False, dilation=1),
nn.BatchNorm2d(k),
nn.ReLU(inplace=True)
)
self.conv3 = nn.Sequential(
nn.Conv2d(k, k // 2, 3, 1, 1, groups=1, bias=False, dilation=1),
nn.BatchNorm2d(k // 2),
nn.ReLU(inplace=True)
)
self.conv4 = nn.Sequential(
nn.Conv2d(k // 2, k, 1, 1, 0, groups=1, bias=False, dilation=1),
nn.BatchNorm2d(k),
nn.ReLU(inplace=True)
)
self.conv5 = nn.Sequential(
nn.Conv2d(k, k // 2, 3, 1, 1, groups=1, bias=False, dilation=1),
nn.BatchNorm2d(k // 2),
nn.ReLU(inplace=True)
)
self.conv6 = nn.Sequential(
nn.Conv2d(2 * k + (0 if is_first else skip), oup, 1, 1, 0, groups=1, bias=False, dilation=1),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True)
)
self.conv1 = conv_bn(in_chs, mid_chs, 3, stride, dilation=dilation)
self.conv2 = conv_bn(mid_chs, mid_chs, 1)
self.conv3 = conv_bn(mid_chs, mid_chs // 2, 3)
self.conv4 = conv_bn(mid_chs // 2, mid_chs, 1)
self.conv5 = conv_bn(mid_chs, mid_chs // 2, 3)
self.conv6 = conv_bn(2 * mid_chs + (0 if is_first else skip_chs), out_chs, 1)
def forward(self, x):
assert isinstance(x, list)
@ -127,8 +97,7 @@ class SelecSLS(nn.Module):
Parameters
----------
cfg : network config
String indicating the network config
cfg : network config dictionary specifying block type, feature, and head args
num_classes : int, default 1000
Number of classification classes.
in_chans : int, default 3
@ -139,134 +108,16 @@ class SelecSLS(nn.Module):
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
"""
def __init__(self, cfg='selecsls60', num_classes=1000, in_chans=3,
drop_rate=0.0, global_pool='avg'):
def __init__(self, cfg, num_classes=1000, in_chans=3, drop_rate=0.0, global_pool='avg'):
self.num_classes = num_classes
self.drop_rate = drop_rate
super(SelecSLS, self).__init__()
self.stem = conv_bn(in_chans, 32, 2)
# Core Network
self.features = []
if cfg == 'selecsls42':
self.block = SelecSLSBlock
# Define configuration of the network after the initial neck
self.selecsls_config = [
# inp,skip, k, oup, is_first, stride
[32, 0, 64, 64, True, 2],
[64, 64, 64, 128, False, 1],
[128, 0, 144, 144, True, 2],
[144, 144, 144, 288, False, 1],
[288, 0, 304, 304, True, 2],
[304, 304, 304, 480, False, 1],
]
# Head can be replaced with alternative configurations depending on the problem
self.head = nn.Sequential(
conv_bn(480, 960, 2),
conv_bn(960, 1024, 1),
conv_bn(1024, 1024, 2),
conv_1x1_bn(1024, 1280),
)
self.num_features = 1280
elif cfg == 'selecsls42b':
self.block = SelecSLSBlock
# Define configuration of the network after the initial neck
self.selecsls_config = [
# inp,skip, k, oup, is_first, stride
[32, 0, 64, 64, True, 2],
[64, 64, 64, 128, False, 1],
[128, 0, 144, 144, True, 2],
[144, 144, 144, 288, False, 1],
[288, 0, 304, 304, True, 2],
[304, 304, 304, 480, False, 1],
]
# Head can be replaced with alternative configurations depending on the problem
self.head = nn.Sequential(
conv_bn(480, 960, 2),
conv_bn(960, 1024, 1),
conv_bn(1024, 1280, 2),
conv_1x1_bn(1280, 1024),
)
self.num_features = 1024
elif cfg == 'selecsls60':
self.block = SelecSLSBlock
# Define configuration of the network after the initial neck
self.selecsls_config = [
# inp,skip, k, oup, is_first, stride
[32, 0, 64, 64, True, 2],
[64, 64, 64, 128, False, 1],
[128, 0, 128, 128, True, 2],
[128, 128, 128, 128, False, 1],
[128, 128, 128, 288, False, 1],
[288, 0, 288, 288, True, 2],
[288, 288, 288, 288, False, 1],
[288, 288, 288, 288, False, 1],
[288, 288, 288, 416, False, 1],
]
# Head can be replaced with alternative configurations depending on the problem
self.head = nn.Sequential(
conv_bn(416, 756, 2),
conv_bn(756, 1024, 1),
conv_bn(1024, 1024, 2),
conv_1x1_bn(1024, 1280),
)
self.num_features = 1280
elif cfg == 'selecsls60b':
self.block = SelecSLSBlock
# Define configuration of the network after the initial neck
self.selecsls_config = [
# inp,skip, k, oup, is_first, stride
[32, 0, 64, 64, True, 2],
[64, 64, 64, 128, False, 1],
[128, 0, 128, 128, True, 2],
[128, 128, 128, 128, False, 1],
[128, 128, 128, 288, False, 1],
[288, 0, 288, 288, True, 2],
[288, 288, 288, 288, False, 1],
[288, 288, 288, 288, False, 1],
[288, 288, 288, 416, False, 1],
]
# Head can be replaced with alternative configurations depending on the problem
self.head = nn.Sequential(
conv_bn(416, 756, 2),
conv_bn(756, 1024, 1),
conv_bn(1024, 1280, 2),
conv_1x1_bn(1280, 1024),
)
self.num_features = 1024
elif cfg == 'selecsls84':
self.block = SelecSLSBlock
# Define configuration of the network after the initial neck
self.selecsls_config = [
# inp,skip, k, oup, is_first, stride
[32, 0, 64, 64, True, 2],
[64, 64, 64, 144, False, 1],
[144, 0, 144, 144, True, 2],
[144, 144, 144, 144, False, 1],
[144, 144, 144, 144, False, 1],
[144, 144, 144, 144, False, 1],
[144, 144, 144, 304, False, 1],
[304, 0, 304, 304, True, 2],
[304, 304, 304, 304, False, 1],
[304, 304, 304, 304, False, 1],
[304, 304, 304, 304, False, 1],
[304, 304, 304, 304, False, 1],
[304, 304, 304, 512, False, 1],
]
# Head can be replaced with alternative configurations depending on the problem
self.head = nn.Sequential(
conv_bn(512, 960, 2),
conv_bn(960, 1024, 1),
conv_bn(1024, 1024, 2),
conv_1x1_bn(1024, 1280),
)
self.num_features = 1280
else:
raise ValueError('Invalid net configuration ' + cfg + ' !!!')
self.stem = conv_bn(in_chans, 32, stride=2)
self.features = nn.Sequential(*[cfg['block'](*block_args) for block_args in cfg['features']])
self.head = nn.Sequential(*[conv_bn(*conv_args) for conv_args in cfg['head']])
self.num_features = cfg['num_features']
for inp, skip, k, oup, is_first, stride in self.selecsls_config:
self.features.append(self.block(inp, skip, k, oup, is_first, stride))
self.features = nn.Sequential(*self.features)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@ -289,84 +140,155 @@ class SelecSLS(nn.Module):
else:
self.fc = None
def forward_features(self, x, pool=True):
def forward_features(self, x):
x = self.stem(x)
x = self.features([x])
x = self.head(x[0])
if pool:
x = self.global_pool(x)
x = x.view(x.size(0), -1)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x).flatten(1)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
return x
def _create_model(variant, pretrained, model_kwargs):
cfg = {}
if variant.startswith('selecsls42'):
cfg['block'] = SelecSLSBlock
# Define configuration of the network after the initial neck
cfg['features'] = [
# in_chs, skip_chs, mid_chs, out_chs, is_first, stride
(32, 0, 64, 64, True, 2),
(64, 64, 64, 128, False, 1),
(128, 0, 144, 144, True, 2),
(144, 144, 144, 288, False, 1),
(288, 0, 304, 304, True, 2),
(304, 304, 304, 480, False, 1),
]
# Head can be replaced with alternative configurations depending on the problem
if variant == 'selecsls42b':
cfg['head'] = [
(480, 960, 3, 2),
(960, 1024, 3, 1),
(1024, 1280, 3, 2),
(1280, 1024, 1, 1),
]
cfg['num_features'] = 1024
else:
cfg['head'] = [
(480, 960, 3, 2),
(960, 1024, 3, 1),
(1024, 1024, 3, 2),
(1024, 1280, 1, 1),
]
cfg['num_features'] = 1280
elif variant.startswith('selecsls60'):
cfg['block'] = SelecSLSBlock
# Define configuration of the network after the initial neck
cfg['features'] = [
# in_chs, skip_chs, mid_chs, out_chs, is_first, stride
(32, 0, 64, 64, True, 2),
(64, 64, 64, 128, False, 1),
(128, 0, 128, 128, True, 2),
(128, 128, 128, 128, False, 1),
(128, 128, 128, 288, False, 1),
(288, 0, 288, 288, True, 2),
(288, 288, 288, 288, False, 1),
(288, 288, 288, 288, False, 1),
(288, 288, 288, 416, False, 1),
]
# Head can be replaced with alternative configurations depending on the problem
if variant == 'selecsls60b':
cfg['head'] = [
(416, 756, 3, 2),
(756, 1024, 3, 1),
(1024, 1280, 3, 2),
(1280, 1024, 1, 1),
]
cfg['num_features'] = 1024
else:
cfg['head'] = [
(416, 756, 3, 2),
(756, 1024, 3, 1),
(1024, 1024, 3, 2),
(1024, 1280, 1, 1),
]
cfg['num_features'] = 1280
elif variant == 'selecsls84':
cfg['block'] = SelecSLSBlock
# Define configuration of the network after the initial neck
cfg['features'] = [
# in_chs, skip_chs, mid_chs, out_chs, is_first, stride
(32, 0, 64, 64, True, 2),
(64, 64, 64, 144, False, 1),
(144, 0, 144, 144, True, 2),
(144, 144, 144, 144, False, 1),
(144, 144, 144, 144, False, 1),
(144, 144, 144, 144, False, 1),
(144, 144, 144, 304, False, 1),
(304, 0, 304, 304, True, 2),
(304, 304, 304, 304, False, 1),
(304, 304, 304, 304, False, 1),
(304, 304, 304, 304, False, 1),
(304, 304, 304, 304, False, 1),
(304, 304, 304, 512, False, 1),
]
# Head can be replaced with alternative configurations depending on the problem
cfg['head'] = [
(512, 960, 3, 2),
(960, 1024, 3, 1),
(1024, 1024, 3, 2),
(1024, 1280, 3, 1),
]
cfg['num_features'] = 1280
else:
raise ValueError('Invalid net configuration ' + variant + ' !!!')
model = SelecSLS(cfg, **model_kwargs)
model.default_cfg = default_cfgs[variant]
if pretrained:
load_pretrained(
model,
num_classes=model_kwargs.get('num_classes', 0),
in_chans=model_kwargs.get('in_chans', 3),
strict=True)
return model
@register_model
def selecsls42(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def selecsls42(pretrained=False, **kwargs):
"""Constructs a SelecSLS42 model.
"""
default_cfg = default_cfgs['selecsls42']
model = SelecSLS(
cfg='selecsls42', num_classes=1000, in_chans=3, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _create_model('selecsls42', pretrained, kwargs)
@register_model
def selecsls42b(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def selecsls42b(pretrained=False, **kwargs):
"""Constructs a SelecSLS42_B model.
"""
default_cfg = default_cfgs['selecsls42b']
model = SelecSLS(
cfg='selecsls42b', num_classes=1000, in_chans=3, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _create_model('selecsls42b', pretrained, kwargs)
@register_model
def selecsls60(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def selecsls60(pretrained=False, **kwargs):
"""Constructs a SelecSLS60 model.
"""
default_cfg = default_cfgs['selecsls60']
model = SelecSLS(
cfg='selecsls60', num_classes=1000, in_chans=3, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _create_model('selecsls60', pretrained, kwargs)
@register_model
def selecsls60b(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def selecsls60b(pretrained=False, **kwargs):
"""Constructs a SelecSLS60_B model.
"""
default_cfg = default_cfgs['selecsls60b']
model = SelecSLS(
cfg='selecsls60b', num_classes=1000, in_chans=3, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _create_model('selecsls60b', pretrained, kwargs)
@register_model
def selecsls84(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def selecsls84(pretrained=False, **kwargs):
"""Constructs a SelecSLS84 model.
"""
default_cfg = default_cfgs['selecsls84']
model = SelecSLS(
cfg='selecsls84', num_classes=1000, in_chans=3, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _create_model('selecsls84', pretrained, kwargs)

Loading…
Cancel
Save