Streamline SelecSLS model without breaking checkpoint compat. Move cfg handling out of model class. Update feature/pooling behaviour to match current.

pull/66/head
Ross Wightman 5 years ago
parent d59a756c16
commit b5315e66b5
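
Before the diff, a quick usage sketch of what the refactor means for callers: model config is now resolved in a factory helper rather than inside the SelecSLS class, so the registered entrypoints just forward kwargs. This is a sketch only; the import path assumes this file lives at timm/models/selecsls.py, and pretrained weights are left off since the default_cfgs URLs may be empty.

import torch
from timm.models.selecsls import selecsls60

# cfg resolution now happens in the factory helper, not in SelecSLS.__init__
model = selecsls60(pretrained=False, num_classes=10)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 10])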

@@ -53,60 +53,30 @@ default_cfgs = {
 }


-def conv_bn(inp, oup, stride):
+def conv_bn(in_chs, out_chs, k=3, stride=1, padding=None, dilation=1):
+    if padding is None:
+        padding = ((stride - 1) + dilation * (k - 1)) // 2
     return nn.Sequential(
-        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
-        nn.BatchNorm2d(oup),
-        nn.ReLU(inplace=True)
-    )
-
-
-def conv_1x1_bn(inp, oup):
-    return nn.Sequential(
-        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
-        nn.BatchNorm2d(oup),
+        nn.Conv2d(in_chs, out_chs, k, stride, padding=padding, dilation=dilation, bias=False),
+        nn.BatchNorm2d(out_chs),
         nn.ReLU(inplace=True)
     )


 class SelecSLSBlock(nn.Module):
-    def __init__(self, inp, skip, k, oup, is_first, stride):
+    def __init__(self, in_chs, skip_chs, mid_chs, out_chs, is_first, stride, dilation=1):
         super(SelecSLSBlock, self).__init__()
         self.stride = stride
         self.is_first = is_first
         assert stride in [1, 2]

         # Process input with 4 conv blocks with the same number of input and output channels
-        self.conv1 = nn.Sequential(
-            nn.Conv2d(inp, k, 3, stride, 1, groups=1, bias=False, dilation=1),
-            nn.BatchNorm2d(k),
-            nn.ReLU(inplace=True)
-        )
-        self.conv2 = nn.Sequential(
-            nn.Conv2d(k, k, 1, 1, 0, groups=1, bias=False, dilation=1),
-            nn.BatchNorm2d(k),
-            nn.ReLU(inplace=True)
-        )
-        self.conv3 = nn.Sequential(
-            nn.Conv2d(k, k // 2, 3, 1, 1, groups=1, bias=False, dilation=1),
-            nn.BatchNorm2d(k // 2),
-            nn.ReLU(inplace=True)
-        )
-        self.conv4 = nn.Sequential(
-            nn.Conv2d(k // 2, k, 1, 1, 0, groups=1, bias=False, dilation=1),
-            nn.BatchNorm2d(k),
-            nn.ReLU(inplace=True)
-        )
-        self.conv5 = nn.Sequential(
-            nn.Conv2d(k, k // 2, 3, 1, 1, groups=1, bias=False, dilation=1),
-            nn.BatchNorm2d(k // 2),
-            nn.ReLU(inplace=True)
-        )
-        self.conv6 = nn.Sequential(
-            nn.Conv2d(2 * k + (0 if is_first else skip), oup, 1, 1, 0, groups=1, bias=False, dilation=1),
-            nn.BatchNorm2d(oup),
-            nn.ReLU(inplace=True)
-        )
+        self.conv1 = conv_bn(in_chs, mid_chs, 3, stride, dilation=dilation)
+        self.conv2 = conv_bn(mid_chs, mid_chs, 1)
+        self.conv3 = conv_bn(mid_chs, mid_chs // 2, 3)
+        self.conv4 = conv_bn(mid_chs // 2, mid_chs, 1)
+        self.conv5 = conv_bn(mid_chs, mid_chs // 2, 3)
+        self.conv6 = conv_bn(2 * mid_chs + (0 if is_first else skip_chs), out_chs, 1)

     def forward(self, x):
         assert isinstance(x, list)
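
Side note on the new conv_bn above: the default padding is now derived from kernel size, stride, and dilation rather than hard-coded as 1. A quick standalone check of that arithmetic (the helper name here is hypothetical; the formula is the one from the diff):

def implied_padding(k, stride=1, dilation=1):
    # same expression as the refactored conv_bn default
    return ((stride - 1) + dilation * (k - 1)) // 2

print(implied_padding(3))              # 1 -> matches the old hard-coded padding for 3x3, stride 1
print(implied_padding(1))              # 0 -> 1x1 convs need no padding
print(implied_padding(3, stride=2))    # 1 -> same as the old stride-2 3x3 convs
print(implied_padding(3, dilation=2))  # 2 -> padding now scales with dilation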
@@ -127,8 +97,7 @@ class SelecSLS(nn.Module):
     Parameters
     ----------
-    cfg : network config
-        String indicating the network config
+    cfg : network config dictionary specifying block type, feature, and head args
     num_classes : int, default 1000
         Number of classification classes.
     in_chans : int, default 3
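
To make the new cfg argument concrete, here is a minimal, hypothetical config dictionary (not one of the registered variants); the keys and tuple layouts mirror the _create_model helper added later in this diff, and the import path is assumed:

from timm.models.selecsls import SelecSLS, SelecSLSBlock  # assumed import path

tiny_cfg = {
    'block': SelecSLSBlock,
    'features': [
        # in_chs, skip_chs, mid_chs, out_chs, is_first, stride
        (32, 0, 64, 64, True, 2),
        (64, 64, 64, 128, False, 1),
    ],
    'head': [
        # conv_bn args: in_chs, out_chs, k, stride
        (128, 256, 3, 2),
        (256, 512, 1, 1),
    ],
    'num_features': 512,
}
model = SelecSLS(tiny_cfg, num_classes=10)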
@@ -139,134 +108,16 @@ class SelecSLS(nn.Module):
         Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
     """

-    def __init__(self, cfg='selecsls60', num_classes=1000, in_chans=3,
-                 drop_rate=0.0, global_pool='avg'):
+    def __init__(self, cfg, num_classes=1000, in_chans=3, drop_rate=0.0, global_pool='avg'):
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         super(SelecSLS, self).__init__()

-        self.stem = conv_bn(in_chans, 32, 2)
-        # Core Network
-        self.features = []
-        if cfg == 'selecsls42':
-            self.block = SelecSLSBlock
-            # Define configuration of the network after the initial neck
-            self.selecsls_config = [
-                # inp,skip, k, oup, is_first, stride
-                [32, 0, 64, 64, True, 2],
-                [64, 64, 64, 128, False, 1],
-                [128, 0, 144, 144, True, 2],
-                [144, 144, 144, 288, False, 1],
-                [288, 0, 304, 304, True, 2],
-                [304, 304, 304, 480, False, 1],
-            ]
-            # Head can be replaced with alternative configurations depending on the problem
-            self.head = nn.Sequential(
-                conv_bn(480, 960, 2),
-                conv_bn(960, 1024, 1),
-                conv_bn(1024, 1024, 2),
-                conv_1x1_bn(1024, 1280),
-            )
-            self.num_features = 1280
-        elif cfg == 'selecsls42b':
-            self.block = SelecSLSBlock
-            # Define configuration of the network after the initial neck
-            self.selecsls_config = [
-                # inp,skip, k, oup, is_first, stride
-                [32, 0, 64, 64, True, 2],
-                [64, 64, 64, 128, False, 1],
-                [128, 0, 144, 144, True, 2],
-                [144, 144, 144, 288, False, 1],
-                [288, 0, 304, 304, True, 2],
-                [304, 304, 304, 480, False, 1],
-            ]
-            # Head can be replaced with alternative configurations depending on the problem
-            self.head = nn.Sequential(
-                conv_bn(480, 960, 2),
-                conv_bn(960, 1024, 1),
-                conv_bn(1024, 1280, 2),
-                conv_1x1_bn(1280, 1024),
-            )
-            self.num_features = 1024
-        elif cfg == 'selecsls60':
-            self.block = SelecSLSBlock
-            # Define configuration of the network after the initial neck
-            self.selecsls_config = [
-                # inp,skip, k, oup, is_first, stride
-                [32, 0, 64, 64, True, 2],
-                [64, 64, 64, 128, False, 1],
-                [128, 0, 128, 128, True, 2],
-                [128, 128, 128, 128, False, 1],
-                [128, 128, 128, 288, False, 1],
-                [288, 0, 288, 288, True, 2],
-                [288, 288, 288, 288, False, 1],
-                [288, 288, 288, 288, False, 1],
-                [288, 288, 288, 416, False, 1],
-            ]
-            # Head can be replaced with alternative configurations depending on the problem
-            self.head = nn.Sequential(
-                conv_bn(416, 756, 2),
-                conv_bn(756, 1024, 1),
-                conv_bn(1024, 1024, 2),
-                conv_1x1_bn(1024, 1280),
-            )
-            self.num_features = 1280
-        elif cfg == 'selecsls60b':
-            self.block = SelecSLSBlock
-            # Define configuration of the network after the initial neck
-            self.selecsls_config = [
-                # inp,skip, k, oup, is_first, stride
-                [32, 0, 64, 64, True, 2],
-                [64, 64, 64, 128, False, 1],
-                [128, 0, 128, 128, True, 2],
-                [128, 128, 128, 128, False, 1],
-                [128, 128, 128, 288, False, 1],
-                [288, 0, 288, 288, True, 2],
-                [288, 288, 288, 288, False, 1],
-                [288, 288, 288, 288, False, 1],
-                [288, 288, 288, 416, False, 1],
-            ]
-            # Head can be replaced with alternative configurations depending on the problem
-            self.head = nn.Sequential(
-                conv_bn(416, 756, 2),
-                conv_bn(756, 1024, 1),
-                conv_bn(1024, 1280, 2),
-                conv_1x1_bn(1280, 1024),
-            )
-            self.num_features = 1024
-        elif cfg == 'selecsls84':
-            self.block = SelecSLSBlock
-            # Define configuration of the network after the initial neck
-            self.selecsls_config = [
-                # inp,skip, k, oup, is_first, stride
-                [32, 0, 64, 64, True, 2],
-                [64, 64, 64, 144, False, 1],
-                [144, 0, 144, 144, True, 2],
-                [144, 144, 144, 144, False, 1],
-                [144, 144, 144, 144, False, 1],
-                [144, 144, 144, 144, False, 1],
-                [144, 144, 144, 304, False, 1],
-                [304, 0, 304, 304, True, 2],
-                [304, 304, 304, 304, False, 1],
-                [304, 304, 304, 304, False, 1],
-                [304, 304, 304, 304, False, 1],
-                [304, 304, 304, 304, False, 1],
-                [304, 304, 304, 512, False, 1],
-            ]
-            # Head can be replaced with alternative configurations depending on the problem
-            self.head = nn.Sequential(
-                conv_bn(512, 960, 2),
-                conv_bn(960, 1024, 1),
-                conv_bn(1024, 1024, 2),
-                conv_1x1_bn(1024, 1280),
-            )
-            self.num_features = 1280
-        else:
-            raise ValueError('Invalid net configuration ' + cfg + ' !!!')
-
-        for inp, skip, k, oup, is_first, stride in self.selecsls_config:
-            self.features.append(self.block(inp, skip, k, oup, is_first, stride))
-        self.features = nn.Sequential(*self.features)
+        self.stem = conv_bn(in_chans, 32, stride=2)
+        self.features = nn.Sequential(*[cfg['block'](*block_args) for block_args in cfg['features']])
+        self.head = nn.Sequential(*[conv_bn(*conv_args) for conv_args in cfg['head']])
+        self.num_features = cfg['num_features']

         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
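
The two nn.Sequential comprehensions in the new __init__ are just a compact form of the explicit loop the old code used; roughly (a fragment, assuming the module scope of this file, i.e. nn, conv_bn, and a cfg dict are available):

blocks = []
for in_chs, skip_chs, mid_chs, out_chs, is_first, stride in cfg['features']:
    blocks.append(cfg['block'](in_chs, skip_chs, mid_chs, out_chs, is_first, stride))
features = nn.Sequential(*blocks)

# head tuples are simply positional args for conv_bn: in_chs, out_chs, k, stride
head = nn.Sequential(*[conv_bn(in_chs, out_chs, k, stride) for in_chs, out_chs, k, stride in cfg['head']])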
@@ -289,84 +140,155 @@ class SelecSLS(nn.Module):
         else:
             self.fc = None

-    def forward_features(self, x, pool=True):
+    def forward_features(self, x):
         x = self.stem(x)
         x = self.features([x])
         x = self.head(x[0])
-        if pool:
-            x = self.global_pool(x)
-            x = x.view(x.size(0), -1)
         return x

     def forward(self, x):
         x = self.forward_features(x)
+        x = self.global_pool(x).flatten(1)
         if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
         x = self.fc(x)
         return x
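
The pooling change above means forward_features now returns the unpooled feature map, while forward applies global pooling, flatten, dropout, and the classifier. A small shape sketch (selecsls42b picked arbitrarily; 224x224 input and the timm import path are assumptions):

import torch
from timm.models.selecsls import selecsls42b

m = selecsls42b(pretrained=False).eval()
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    feats = m.forward_features(x)  # 4D feature map, e.g. torch.Size([2, 1024, 4, 4])
    logits = m(x)                  # pooling + flatten + fc applied in forward()
print(feats.shape, logits.shape)   # ... torch.Size([2, 1000])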

+def _create_model(variant, pretrained, model_kwargs):
+    cfg = {}
+    if variant.startswith('selecsls42'):
+        cfg['block'] = SelecSLSBlock
+        # Define configuration of the network after the initial neck
+        cfg['features'] = [
+            # in_chs, skip_chs, mid_chs, out_chs, is_first, stride
+            (32, 0, 64, 64, True, 2),
+            (64, 64, 64, 128, False, 1),
+            (128, 0, 144, 144, True, 2),
+            (144, 144, 144, 288, False, 1),
+            (288, 0, 304, 304, True, 2),
+            (304, 304, 304, 480, False, 1),
+        ]
+        # Head can be replaced with alternative configurations depending on the problem
+        if variant == 'selecsls42b':
+            cfg['head'] = [
+                (480, 960, 3, 2),
+                (960, 1024, 3, 1),
+                (1024, 1280, 3, 2),
+                (1280, 1024, 1, 1),
+            ]
+            cfg['num_features'] = 1024
+        else:
+            cfg['head'] = [
+                (480, 960, 3, 2),
+                (960, 1024, 3, 1),
+                (1024, 1024, 3, 2),
+                (1024, 1280, 1, 1),
+            ]
+            cfg['num_features'] = 1280
+    elif variant.startswith('selecsls60'):
+        cfg['block'] = SelecSLSBlock
+        # Define configuration of the network after the initial neck
+        cfg['features'] = [
+            # in_chs, skip_chs, mid_chs, out_chs, is_first, stride
+            (32, 0, 64, 64, True, 2),
+            (64, 64, 64, 128, False, 1),
+            (128, 0, 128, 128, True, 2),
+            (128, 128, 128, 128, False, 1),
+            (128, 128, 128, 288, False, 1),
+            (288, 0, 288, 288, True, 2),
+            (288, 288, 288, 288, False, 1),
+            (288, 288, 288, 288, False, 1),
+            (288, 288, 288, 416, False, 1),
+        ]
+        # Head can be replaced with alternative configurations depending on the problem
+        if variant == 'selecsls60b':
+            cfg['head'] = [
+                (416, 756, 3, 2),
+                (756, 1024, 3, 1),
+                (1024, 1280, 3, 2),
+                (1280, 1024, 1, 1),
+            ]
+            cfg['num_features'] = 1024
+        else:
+            cfg['head'] = [
+                (416, 756, 3, 2),
+                (756, 1024, 3, 1),
+                (1024, 1024, 3, 2),
+                (1024, 1280, 1, 1),
+            ]
+            cfg['num_features'] = 1280
+    elif variant == 'selecsls84':
+        cfg['block'] = SelecSLSBlock
+        # Define configuration of the network after the initial neck
+        cfg['features'] = [
+            # in_chs, skip_chs, mid_chs, out_chs, is_first, stride
+            (32, 0, 64, 64, True, 2),
+            (64, 64, 64, 144, False, 1),
+            (144, 0, 144, 144, True, 2),
+            (144, 144, 144, 144, False, 1),
+            (144, 144, 144, 144, False, 1),
+            (144, 144, 144, 144, False, 1),
+            (144, 144, 144, 304, False, 1),
+            (304, 0, 304, 304, True, 2),
+            (304, 304, 304, 304, False, 1),
+            (304, 304, 304, 304, False, 1),
+            (304, 304, 304, 304, False, 1),
+            (304, 304, 304, 304, False, 1),
+            (304, 304, 304, 512, False, 1),
+        ]
+        # Head can be replaced with alternative configurations depending on the problem
+        cfg['head'] = [
+            (512, 960, 3, 2),
+            (960, 1024, 3, 1),
+            (1024, 1024, 3, 2),
+            (1024, 1280, 3, 1),
+        ]
+        cfg['num_features'] = 1280
+    else:
+        raise ValueError('Invalid net configuration ' + variant + ' !!!')
+
+    model = SelecSLS(cfg, **model_kwargs)
+    model.default_cfg = default_cfgs[variant]
+    if pretrained:
+        load_pretrained(
+            model,
+            num_classes=model_kwargs.get('num_classes', 0),
+            in_chans=model_kwargs.get('in_chans', 3),
+            strict=True)
+    return model

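
A quick check of what the factory above wires together: it builds the cfg dict, instantiates SelecSLS, and attaches the matching default_cfg. This sketch calls the private helper directly and assumes the names from this module are in scope:

m = _create_model('selecsls84', pretrained=False, model_kwargs={})
assert m.num_features == 1280                       # from cfg['num_features']
assert m.default_cfg is default_cfgs['selecsls84']  # attached by the factory
# the 'b' variants swap the last two head widths and end at num_features == 1024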
 @register_model
-def selecsls42(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def selecsls42(pretrained=False, **kwargs):
     """Constructs a SelecSLS42 model.
     """
-    default_cfg = default_cfgs['selecsls42']
-    model = SelecSLS(
-        cfg='selecsls42', num_classes=1000, in_chans=3, **kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
-    return model
+    return _create_model('selecsls42', pretrained, kwargs)


 @register_model
-def selecsls42b(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def selecsls42b(pretrained=False, **kwargs):
     """Constructs a SelecSLS42_B model.
     """
-    default_cfg = default_cfgs['selecsls42b']
-    model = SelecSLS(
-        cfg='selecsls42b', num_classes=1000, in_chans=3, **kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
-    return model
+    return _create_model('selecsls42b', pretrained, kwargs)


 @register_model
-def selecsls60(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def selecsls60(pretrained=False, **kwargs):
     """Constructs a SelecSLS60 model.
     """
-    default_cfg = default_cfgs['selecsls60']
-    model = SelecSLS(
-        cfg='selecsls60', num_classes=1000, in_chans=3, **kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
-    return model
+    return _create_model('selecsls60', pretrained, kwargs)


 @register_model
-def selecsls60b(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def selecsls60b(pretrained=False, **kwargs):
     """Constructs a SelecSLS60_B model.
     """
-    default_cfg = default_cfgs['selecsls60b']
-    model = SelecSLS(
-        cfg='selecsls60b', num_classes=1000, in_chans=3, **kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
-    return model
+    return _create_model('selecsls60b', pretrained, kwargs)


 @register_model
-def selecsls84(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def selecsls84(pretrained=False, **kwargs):
     """Constructs a SelecSLS84 model.
     """
-    default_cfg = default_cfgs['selecsls84']
-    model = SelecSLS(
-        cfg='selecsls84', num_classes=1000, in_chans=3, **kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
-    return model
+    return _create_model('selecsls84', pretrained, kwargs)
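
Finally, the checkpoint-compat claim in the commit message can be sanity checked as below. The checkpoint filename is hypothetical; the point is that parameter names (stem/features/head/fc) are unchanged by the refactor, so an old state_dict should still load without missing or unexpected keys.

import torch
from timm.models.selecsls import selecsls60b

model = selecsls60b(pretrained=False)
state_dict = torch.load('selecsls60b_old.pth', map_location='cpu')  # hypothetical local checkpoint
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print('missing keys:', missing)        # expect []
print('unexpected keys:', unexpected)  # expect []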
