Merge pull request #66 from rwightman/selecsls_updates

SelecSLS updates
pull/82/head
Ross Wightman 5 years ago committed by GitHub
commit e728d70831
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -316,6 +316,13 @@ model_list = [
_entry('hrnet_w44', 'HRNet-W44-C', '1908.07919'),
_entry('hrnet_w48', 'HRNet-W48-C', '1908.07919'),
_entry('hrnet_w64', 'HRNet-W64-C', '1908.07919'),
## SelecSLS official impl weights
_entry('selecsls42b', 'SelecSLS-42_B', '1907.00837',
model_desc='Originally from https://github.com/mehtadushy/SelecSLS-Pytorch'),
_entry('selecsls60b', 'SelecSLS-60_B', '1907.00837',
model_desc='Originally from https://github.com/mehtadushy/SelecSLS-Pytorch'),
]
for m in model_list:

@ -20,7 +20,6 @@ from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this
@ -39,14 +38,14 @@ default_cfgs = {
'selecsls42': _cfg(
url='',
interpolation='bicubic'),
'selecsls42_B': _cfg(
url='http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS42_B.pth',
'selecsls42b': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-selecsls/selecsls42b-8af30141.pth',
interpolation='bicubic'),
'selecsls60': _cfg(
url='',
interpolation='bicubic'),
'selecsls60_B': _cfg(
url='http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS60_B.pth',
'selecsls60b': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-selecsls/selecsls60b-94e619b5.pth',
interpolation='bicubic'),
'selecsls84': _cfg(
url='',
@ -54,59 +53,30 @@ default_cfgs = {
}
def conv_bn(inp, oup, stride):
def conv_bn(in_chs, out_chs, k=3, stride=1, padding=None, dilation=1):
if padding is None:
padding = ((stride - 1) + dilation * (k - 1)) // 2
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.Conv2d(in_chs, out_chs, k, stride, padding=padding, dilation=dilation, bias=False),
nn.BatchNorm2d(out_chs),
nn.ReLU(inplace=True)
)
def conv_1x1_bn(inp, oup):
return nn.Sequential(
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True)
)
class SelecSLSBlock(nn.Module):
def __init__(self, inp, skip, k, oup, isFirst, stride):
def __init__(self, in_chs, skip_chs, mid_chs, out_chs, is_first, stride, dilation=1):
super(SelecSLSBlock, self).__init__()
self.stride = stride
self.isFirst = isFirst
self.is_first = is_first
assert stride in [1, 2]
# Process input with 4 conv blocks with the same number of input and output channels
self.conv1 = nn.Sequential(
nn.Conv2d(inp, k, 3, stride, 1,groups= 1, bias=False, dilation=1),
nn.BatchNorm2d(k),
nn.ReLU(inplace=True)
)
self.conv2 = nn.Sequential(
nn.Conv2d(k, k, 1, 1, 0,groups= 1, bias=False, dilation=1),
nn.BatchNorm2d(k),
nn.ReLU(inplace=True)
)
self.conv3 = nn.Sequential(
nn.Conv2d(k, k//2, 3, 1, 1,groups= 1, bias=False, dilation=1),
nn.BatchNorm2d(k//2),
nn.ReLU(inplace=True)
)
self.conv4 = nn.Sequential(
nn.Conv2d(k//2, k, 1, 1, 0,groups= 1, bias=False, dilation=1),
nn.BatchNorm2d(k),
nn.ReLU(inplace=True)
)
self.conv5 = nn.Sequential(
nn.Conv2d(k, k//2, 3, 1, 1,groups= 1, bias=False, dilation=1),
nn.BatchNorm2d(k//2),
nn.ReLU(inplace=True)
)
self.conv6 = nn.Sequential(
nn.Conv2d(2*k + (0 if isFirst else skip), oup, 1, 1, 0,groups= 1, bias=False, dilation=1),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True)
)
self.conv1 = conv_bn(in_chs, mid_chs, 3, stride, dilation=dilation)
self.conv2 = conv_bn(mid_chs, mid_chs, 1)
self.conv3 = conv_bn(mid_chs, mid_chs // 2, 3)
self.conv4 = conv_bn(mid_chs // 2, mid_chs, 1)
self.conv5 = conv_bn(mid_chs, mid_chs // 2, 3)
self.conv6 = conv_bn(2 * mid_chs + (0 if is_first else skip_chs), out_chs, 1)
def forward(self, x):
assert isinstance(x, list)
@ -115,19 +85,19 @@ class SelecSLSBlock(nn.Module):
d1 = self.conv1(x[0])
d2 = self.conv3(self.conv2(d1))
d3 = self.conv5(self.conv4(d2))
if self.isFirst:
if self.is_first:
out = self.conv6(torch.cat([d1, d2, d3], 1))
return [out, out]
else:
return [self.conv6(torch.cat([d1, d2, d3, x[1]], 1)), x[1]]
class SelecSLS(nn.Module):
"""SelecSLS42 / SelecSLS60 / SelecSLS84
Parameters
----------
cfg : network config
String indicating the network config
cfg : network config dictionary specifying block type, feature, and head args
num_classes : int, default 1000
Number of classification classes.
in_chans : int, default 3
@ -137,134 +107,17 @@ class SelecSLS(nn.Module):
global_pool : str, default 'avg'
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
"""
def __init__(self, cfg='selecsls60', num_classes=1000, in_chans=3,
drop_rate=0.0, global_pool='avg'):
def __init__(self, cfg, num_classes=1000, in_chans=3, drop_rate=0.0, global_pool='avg'):
self.num_classes = num_classes
self.drop_rate = drop_rate
super(SelecSLS, self).__init__()
self.stem = conv_bn(in_chans, 32, 2)
#Core Network
self.features = []
if cfg=='selecsls42':
self.block = SelecSLSBlock
#Define configuration of the network after the initial neck
self.selecSLS_config = [
#inp,skip, k, oup, isFirst, stride
[ 32, 0, 64, 64, True, 2],
[ 64, 64, 64, 128, False, 1],
[128, 0, 144, 144, True, 2],
[144, 144, 144, 288, False, 1],
[288, 0, 304, 304, True, 2],
[304, 304, 304, 480, False, 1],
]
#Head can be replaced with alternative configurations depending on the problem
self.head = nn.Sequential(
conv_bn(480, 960, 2),
conv_bn(960, 1024, 1),
conv_bn(1024, 1024, 2),
conv_1x1_bn(1024, 1280),
)
self.num_features = 1280
elif cfg=='selecsls42_B':
self.block = SelecSLSBlock
#Define configuration of the network after the initial neck
self.selecSLS_config = [
#inp,skip, k, oup, isFirst, stride
[ 32, 0, 64, 64, True, 2],
[ 64, 64, 64, 128, False, 1],
[128, 0, 144, 144, True, 2],
[144, 144, 144, 288, False, 1],
[288, 0, 304, 304, True, 2],
[304, 304, 304, 480, False, 1],
]
#Head can be replaced with alternative configurations depending on the problem
self.head = nn.Sequential(
conv_bn(480, 960, 2),
conv_bn(960, 1024, 1),
conv_bn(1024, 1280, 2),
conv_1x1_bn(1280, 1024),
)
self.num_features = 1024
elif cfg=='selecsls60':
self.block = SelecSLSBlock
#Define configuration of the network after the initial neck
self.selecSLS_config = [
#inp,skip, k, oup, isFirst, stride
[ 32, 0, 64, 64, True, 2],
[ 64, 64, 64, 128, False, 1],
[128, 0, 128, 128, True, 2],
[128, 128, 128, 128, False, 1],
[128, 128, 128, 288, False, 1],
[288, 0, 288, 288, True, 2],
[288, 288, 288, 288, False, 1],
[288, 288, 288, 288, False, 1],
[288, 288, 288, 416, False, 1],
]
#Head can be replaced with alternative configurations depending on the problem
self.head = nn.Sequential(
conv_bn(416, 756, 2),
conv_bn(756, 1024, 1),
conv_bn(1024, 1024, 2),
conv_1x1_bn(1024, 1280),
)
self.num_features = 1280
elif cfg=='selecsls60_B':
self.block = SelecSLSBlock
#Define configuration of the network after the initial neck
self.selecSLS_config = [
#inp,skip, k, oup, isFirst, stride
[ 32, 0, 64, 64, True, 2],
[ 64, 64, 64, 128, False, 1],
[128, 0, 128, 128, True, 2],
[128, 128, 128, 128, False, 1],
[128, 128, 128, 288, False, 1],
[288, 0, 288, 288, True, 2],
[288, 288, 288, 288, False, 1],
[288, 288, 288, 288, False, 1],
[288, 288, 288, 416, False, 1],
]
#Head can be replaced with alternative configurations depending on the problem
self.head = nn.Sequential(
conv_bn(416, 756, 2),
conv_bn(756, 1024, 1),
conv_bn(1024, 1280, 2),
conv_1x1_bn(1280, 1024),
)
self.num_features = 1024
elif cfg=='selecsls84':
self.block = SelecSLSBlock
#Define configuration of the network after the initial neck
self.selecSLS_config = [
#inp,skip, k, oup, isFirst, stride
[ 32, 0, 64, 64, True, 2],
[ 64, 64, 64, 144, False, 1],
[144, 0, 144, 144, True, 2],
[144, 144, 144, 144, False, 1],
[144, 144, 144, 144, False, 1],
[144, 144, 144, 144, False, 1],
[144, 144, 144, 304, False, 1],
[304, 0, 304, 304, True, 2],
[304, 304, 304, 304, False, 1],
[304, 304, 304, 304, False, 1],
[304, 304, 304, 304, False, 1],
[304, 304, 304, 304, False, 1],
[304, 304, 304, 512, False, 1],
]
#Head can be replaced with alternative configurations depending on the problem
self.head = nn.Sequential(
conv_bn(512, 960, 2),
conv_bn(960, 1024, 1),
conv_bn(1024, 1024, 2),
conv_1x1_bn(1024, 1280),
)
self.num_features = 1280
else:
raise ValueError('Invalid net configuration '+cfg+' !!!')
self.stem = conv_bn(in_chans, 32, stride=2)
self.features = nn.Sequential(*[cfg['block'](*block_args) for block_args in cfg['features']])
self.head = nn.Sequential(*[conv_bn(*conv_args) for conv_args in cfg['head']])
self.num_features = cfg['num_features']
for inp, skip, k, oup, isFirst, stride in self.selecSLS_config:
self.features.append(self.block(inp, skip, k, oup, isFirst, stride))
self.features = nn.Sequential(*self.features)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@ -287,82 +140,155 @@ class SelecSLS(nn.Module):
else:
self.fc = None
def forward_features(self, x, pool=True):
def forward_features(self, x):
x = self.stem(x)
x = self.features([x])
x = self.head(x[0])
if pool:
x = self.global_pool(x)
x = x.view(x.size(0), -1)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x).flatten(1)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
return x
def _create_model(variant, pretrained, model_kwargs):
cfg = {}
if variant.startswith('selecsls42'):
cfg['block'] = SelecSLSBlock
# Define configuration of the network after the initial neck
cfg['features'] = [
# in_chs, skip_chs, mid_chs, out_chs, is_first, stride
(32, 0, 64, 64, True, 2),
(64, 64, 64, 128, False, 1),
(128, 0, 144, 144, True, 2),
(144, 144, 144, 288, False, 1),
(288, 0, 304, 304, True, 2),
(304, 304, 304, 480, False, 1),
]
# Head can be replaced with alternative configurations depending on the problem
if variant == 'selecsls42b':
cfg['head'] = [
(480, 960, 3, 2),
(960, 1024, 3, 1),
(1024, 1280, 3, 2),
(1280, 1024, 1, 1),
]
cfg['num_features'] = 1024
else:
cfg['head'] = [
(480, 960, 3, 2),
(960, 1024, 3, 1),
(1024, 1024, 3, 2),
(1024, 1280, 1, 1),
]
cfg['num_features'] = 1280
elif variant.startswith('selecsls60'):
cfg['block'] = SelecSLSBlock
# Define configuration of the network after the initial neck
cfg['features'] = [
# in_chs, skip_chs, mid_chs, out_chs, is_first, stride
(32, 0, 64, 64, True, 2),
(64, 64, 64, 128, False, 1),
(128, 0, 128, 128, True, 2),
(128, 128, 128, 128, False, 1),
(128, 128, 128, 288, False, 1),
(288, 0, 288, 288, True, 2),
(288, 288, 288, 288, False, 1),
(288, 288, 288, 288, False, 1),
(288, 288, 288, 416, False, 1),
]
# Head can be replaced with alternative configurations depending on the problem
if variant == 'selecsls60b':
cfg['head'] = [
(416, 756, 3, 2),
(756, 1024, 3, 1),
(1024, 1280, 3, 2),
(1280, 1024, 1, 1),
]
cfg['num_features'] = 1024
else:
cfg['head'] = [
(416, 756, 3, 2),
(756, 1024, 3, 1),
(1024, 1024, 3, 2),
(1024, 1280, 1, 1),
]
cfg['num_features'] = 1280
elif variant == 'selecsls84':
cfg['block'] = SelecSLSBlock
# Define configuration of the network after the initial neck
cfg['features'] = [
# in_chs, skip_chs, mid_chs, out_chs, is_first, stride
(32, 0, 64, 64, True, 2),
(64, 64, 64, 144, False, 1),
(144, 0, 144, 144, True, 2),
(144, 144, 144, 144, False, 1),
(144, 144, 144, 144, False, 1),
(144, 144, 144, 144, False, 1),
(144, 144, 144, 304, False, 1),
(304, 0, 304, 304, True, 2),
(304, 304, 304, 304, False, 1),
(304, 304, 304, 304, False, 1),
(304, 304, 304, 304, False, 1),
(304, 304, 304, 304, False, 1),
(304, 304, 304, 512, False, 1),
]
# Head can be replaced with alternative configurations depending on the problem
cfg['head'] = [
(512, 960, 3, 2),
(960, 1024, 3, 1),
(1024, 1024, 3, 2),
(1024, 1280, 3, 1),
]
cfg['num_features'] = 1280
else:
raise ValueError('Invalid net configuration ' + variant + ' !!!')
model = SelecSLS(cfg, **model_kwargs)
model.default_cfg = default_cfgs[variant]
if pretrained:
load_pretrained(
model,
num_classes=model_kwargs.get('num_classes', 0),
in_chans=model_kwargs.get('in_chans', 3),
strict=True)
return model
@register_model
def selecsls42(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def selecsls42(pretrained=False, **kwargs):
"""Constructs a SelecSLS42 model.
"""
default_cfg = default_cfgs['selecsls42']
model = SelecSLS(
cfg='selecsls42', num_classes=1000, in_chans=3, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _create_model('selecsls42', pretrained, kwargs)
@register_model
def selecsls42_B(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def selecsls42b(pretrained=False, **kwargs):
"""Constructs a SelecSLS42_B model.
"""
default_cfg = default_cfgs['selecsls42_B']
model = SelecSLS(
cfg='selecsls42_B', num_classes=1000, in_chans=3,**kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _create_model('selecsls42b', pretrained, kwargs)
@register_model
def selecsls60(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def selecsls60(pretrained=False, **kwargs):
"""Constructs a SelecSLS60 model.
"""
default_cfg = default_cfgs['selecsls60']
model = SelecSLS(
cfg='selecsls60', num_classes=1000, in_chans=3,**kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _create_model('selecsls60', pretrained, kwargs)
@register_model
def selecsls60_B(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def selecsls60b(pretrained=False, **kwargs):
"""Constructs a SelecSLS60_B model.
"""
default_cfg = default_cfgs['selecsls60_B']
model = SelecSLS(
cfg='selecsls60_B', num_classes=1000, in_chans=3,**kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _create_model('selecsls60b', pretrained, kwargs)
@register_model
def selecsls84(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def selecsls84(pretrained=False, **kwargs):
"""Constructs a SelecSLS84 model.
"""
default_cfg = default_cfgs['selecsls84']
model = SelecSLS(
cfg='selecsls84', num_classes=1000, in_chans=3, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
return _create_model('selecsls84', pretrained, kwargs)

Loading…
Cancel
Save