Run PyCharm autoformat on selecsls and change mix cap variables and model names to all lower

pull/66/head
Ross Wightman 5 years ago
parent fb3a0f4bb8
commit d59a756c16

@ -20,7 +20,6 @@ from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d from .adaptive_avgmax_pool import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this __all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this
@ -39,13 +38,13 @@ default_cfgs = {
'selecsls42': _cfg( 'selecsls42': _cfg(
url='', url='',
interpolation='bicubic'), interpolation='bicubic'),
'selecsls42_B': _cfg( 'selecsls42b': _cfg(
url='http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS42_B.pth', url='http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS42_B.pth',
interpolation='bicubic'), interpolation='bicubic'),
'selecsls60': _cfg( 'selecsls60': _cfg(
url='', url='',
interpolation='bicubic'), interpolation='bicubic'),
'selecsls60_B': _cfg( 'selecsls60b': _cfg(
url='http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS60_B.pth', url='http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS60_B.pth',
interpolation='bicubic'), interpolation='bicubic'),
'selecsls84': _cfg( 'selecsls84': _cfg(
@ -69,57 +68,59 @@ def conv_1x1_bn(inp, oup):
nn.ReLU(inplace=True) nn.ReLU(inplace=True)
) )
class SelecSLSBlock(nn.Module): class SelecSLSBlock(nn.Module):
def __init__(self, inp, skip, k, oup, isFirst, stride): def __init__(self, inp, skip, k, oup, is_first, stride):
super(SelecSLSBlock, self).__init__() super(SelecSLSBlock, self).__init__()
self.stride = stride self.stride = stride
self.isFirst = isFirst self.is_first = is_first
assert stride in [1, 2] assert stride in [1, 2]
#Process input with 4 conv blocks with the same number of input and output channels # Process input with 4 conv blocks with the same number of input and output channels
self.conv1 = nn.Sequential( self.conv1 = nn.Sequential(
nn.Conv2d(inp, k, 3, stride, 1,groups= 1, bias=False, dilation=1), nn.Conv2d(inp, k, 3, stride, 1, groups=1, bias=False, dilation=1),
nn.BatchNorm2d(k), nn.BatchNorm2d(k),
nn.ReLU(inplace=True) nn.ReLU(inplace=True)
) )
self.conv2 = nn.Sequential( self.conv2 = nn.Sequential(
nn.Conv2d(k, k, 1, 1, 0,groups= 1, bias=False, dilation=1), nn.Conv2d(k, k, 1, 1, 0, groups=1, bias=False, dilation=1),
nn.BatchNorm2d(k), nn.BatchNorm2d(k),
nn.ReLU(inplace=True) nn.ReLU(inplace=True)
) )
self.conv3 = nn.Sequential( self.conv3 = nn.Sequential(
nn.Conv2d(k, k//2, 3, 1, 1,groups= 1, bias=False, dilation=1), nn.Conv2d(k, k // 2, 3, 1, 1, groups=1, bias=False, dilation=1),
nn.BatchNorm2d(k//2), nn.BatchNorm2d(k // 2),
nn.ReLU(inplace=True) nn.ReLU(inplace=True)
) )
self.conv4 = nn.Sequential( self.conv4 = nn.Sequential(
nn.Conv2d(k//2, k, 1, 1, 0,groups= 1, bias=False, dilation=1), nn.Conv2d(k // 2, k, 1, 1, 0, groups=1, bias=False, dilation=1),
nn.BatchNorm2d(k), nn.BatchNorm2d(k),
nn.ReLU(inplace=True) nn.ReLU(inplace=True)
) )
self.conv5 = nn.Sequential( self.conv5 = nn.Sequential(
nn.Conv2d(k, k//2, 3, 1, 1,groups= 1, bias=False, dilation=1), nn.Conv2d(k, k // 2, 3, 1, 1, groups=1, bias=False, dilation=1),
nn.BatchNorm2d(k//2), nn.BatchNorm2d(k // 2),
nn.ReLU(inplace=True) nn.ReLU(inplace=True)
) )
self.conv6 = nn.Sequential( self.conv6 = nn.Sequential(
nn.Conv2d(2*k + (0 if isFirst else skip), oup, 1, 1, 0,groups= 1, bias=False, dilation=1), nn.Conv2d(2 * k + (0 if is_first else skip), oup, 1, 1, 0, groups=1, bias=False, dilation=1),
nn.BatchNorm2d(oup), nn.BatchNorm2d(oup),
nn.ReLU(inplace=True) nn.ReLU(inplace=True)
) )
def forward(self, x): def forward(self, x):
assert isinstance(x,list) assert isinstance(x, list)
assert len(x) in [1,2] assert len(x) in [1, 2]
d1 = self.conv1(x[0]) d1 = self.conv1(x[0])
d2 = self.conv3(self.conv2(d1)) d2 = self.conv3(self.conv2(d1))
d3 = self.conv5(self.conv4(d2)) d3 = self.conv5(self.conv4(d2))
if self.isFirst: if self.is_first:
out = self.conv6(torch.cat([d1, d2, d3], 1)) out = self.conv6(torch.cat([d1, d2, d3], 1))
return [out, out] return [out, out]
else: else:
return [self.conv6(torch.cat([d1, d2, d3, x[1]], 1)) , x[1]] return [self.conv6(torch.cat([d1, d2, d3, x[1]], 1)), x[1]]
class SelecSLS(nn.Module): class SelecSLS(nn.Module):
"""SelecSLS42 / SelecSLS60 / SelecSLS84 """SelecSLS42 / SelecSLS60 / SelecSLS84
@ -137,6 +138,7 @@ class SelecSLS(nn.Module):
global_pool : str, default 'avg' global_pool : str, default 'avg'
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
""" """
def __init__(self, cfg='selecsls60', num_classes=1000, in_chans=3, def __init__(self, cfg='selecsls60', num_classes=1000, in_chans=3,
drop_rate=0.0, global_pool='avg'): drop_rate=0.0, global_pool='avg'):
self.num_classes = num_classes self.num_classes = num_classes
@ -144,126 +146,126 @@ class SelecSLS(nn.Module):
super(SelecSLS, self).__init__() super(SelecSLS, self).__init__()
self.stem = conv_bn(in_chans, 32, 2) self.stem = conv_bn(in_chans, 32, 2)
#Core Network # Core Network
self.features = [] self.features = []
if cfg=='selecsls42': if cfg == 'selecsls42':
self.block = SelecSLSBlock self.block = SelecSLSBlock
#Define configuration of the network after the initial neck # Define configuration of the network after the initial neck
self.selecSLS_config = [ self.selecsls_config = [
#inp,skip, k, oup, isFirst, stride # inp,skip, k, oup, is_first, stride
[ 32, 0, 64, 64, True, 2], [32, 0, 64, 64, True, 2],
[ 64, 64, 64, 128, False, 1], [64, 64, 64, 128, False, 1],
[128, 0, 144, 144, True, 2], [128, 0, 144, 144, True, 2],
[144, 144, 144, 288, False, 1], [144, 144, 144, 288, False, 1],
[288, 0, 304, 304, True, 2], [288, 0, 304, 304, True, 2],
[304, 304, 304, 480, False, 1], [304, 304, 304, 480, False, 1],
] ]
#Head can be replaced with alternative configurations depending on the problem # Head can be replaced with alternative configurations depending on the problem
self.head = nn.Sequential( self.head = nn.Sequential(
conv_bn(480, 960, 2), conv_bn(480, 960, 2),
conv_bn(960, 1024, 1), conv_bn(960, 1024, 1),
conv_bn(1024, 1024, 2), conv_bn(1024, 1024, 2),
conv_1x1_bn(1024, 1280), conv_1x1_bn(1024, 1280),
) )
self.num_features = 1280 self.num_features = 1280
elif cfg=='selecsls42_B': elif cfg == 'selecsls42b':
self.block = SelecSLSBlock self.block = SelecSLSBlock
#Define configuration of the network after the initial neck # Define configuration of the network after the initial neck
self.selecSLS_config = [ self.selecsls_config = [
#inp,skip, k, oup, isFirst, stride # inp,skip, k, oup, is_first, stride
[ 32, 0, 64, 64, True, 2], [32, 0, 64, 64, True, 2],
[ 64, 64, 64, 128, False, 1], [64, 64, 64, 128, False, 1],
[128, 0, 144, 144, True, 2], [128, 0, 144, 144, True, 2],
[144, 144, 144, 288, False, 1], [144, 144, 144, 288, False, 1],
[288, 0, 304, 304, True, 2], [288, 0, 304, 304, True, 2],
[304, 304, 304, 480, False, 1], [304, 304, 304, 480, False, 1],
] ]
#Head can be replaced with alternative configurations depending on the problem # Head can be replaced with alternative configurations depending on the problem
self.head = nn.Sequential( self.head = nn.Sequential(
conv_bn(480, 960, 2), conv_bn(480, 960, 2),
conv_bn(960, 1024, 1), conv_bn(960, 1024, 1),
conv_bn(1024, 1280, 2), conv_bn(1024, 1280, 2),
conv_1x1_bn(1280, 1024), conv_1x1_bn(1280, 1024),
) )
self.num_features = 1024 self.num_features = 1024
elif cfg=='selecsls60': elif cfg == 'selecsls60':
self.block = SelecSLSBlock self.block = SelecSLSBlock
#Define configuration of the network after the initial neck # Define configuration of the network after the initial neck
self.selecSLS_config = [ self.selecsls_config = [
#inp,skip, k, oup, isFirst, stride # inp,skip, k, oup, is_first, stride
[ 32, 0, 64, 64, True, 2], [32, 0, 64, 64, True, 2],
[ 64, 64, 64, 128, False, 1], [64, 64, 64, 128, False, 1],
[128, 0, 128, 128, True, 2], [128, 0, 128, 128, True, 2],
[128, 128, 128, 128, False, 1], [128, 128, 128, 128, False, 1],
[128, 128, 128, 288, False, 1], [128, 128, 128, 288, False, 1],
[288, 0, 288, 288, True, 2], [288, 0, 288, 288, True, 2],
[288, 288, 288, 288, False, 1], [288, 288, 288, 288, False, 1],
[288, 288, 288, 288, False, 1], [288, 288, 288, 288, False, 1],
[288, 288, 288, 416, False, 1], [288, 288, 288, 416, False, 1],
] ]
#Head can be replaced with alternative configurations depending on the problem # Head can be replaced with alternative configurations depending on the problem
self.head = nn.Sequential( self.head = nn.Sequential(
conv_bn(416, 756, 2), conv_bn(416, 756, 2),
conv_bn(756, 1024, 1), conv_bn(756, 1024, 1),
conv_bn(1024, 1024, 2), conv_bn(1024, 1024, 2),
conv_1x1_bn(1024, 1280), conv_1x1_bn(1024, 1280),
) )
self.num_features = 1280 self.num_features = 1280
elif cfg=='selecsls60_B': elif cfg == 'selecsls60b':
self.block = SelecSLSBlock self.block = SelecSLSBlock
#Define configuration of the network after the initial neck # Define configuration of the network after the initial neck
self.selecSLS_config = [ self.selecsls_config = [
#inp,skip, k, oup, isFirst, stride # inp,skip, k, oup, is_first, stride
[ 32, 0, 64, 64, True, 2], [32, 0, 64, 64, True, 2],
[ 64, 64, 64, 128, False, 1], [64, 64, 64, 128, False, 1],
[128, 0, 128, 128, True, 2], [128, 0, 128, 128, True, 2],
[128, 128, 128, 128, False, 1], [128, 128, 128, 128, False, 1],
[128, 128, 128, 288, False, 1], [128, 128, 128, 288, False, 1],
[288, 0, 288, 288, True, 2], [288, 0, 288, 288, True, 2],
[288, 288, 288, 288, False, 1], [288, 288, 288, 288, False, 1],
[288, 288, 288, 288, False, 1], [288, 288, 288, 288, False, 1],
[288, 288, 288, 416, False, 1], [288, 288, 288, 416, False, 1],
] ]
#Head can be replaced with alternative configurations depending on the problem # Head can be replaced with alternative configurations depending on the problem
self.head = nn.Sequential( self.head = nn.Sequential(
conv_bn(416, 756, 2), conv_bn(416, 756, 2),
conv_bn(756, 1024, 1), conv_bn(756, 1024, 1),
conv_bn(1024, 1280, 2), conv_bn(1024, 1280, 2),
conv_1x1_bn(1280, 1024), conv_1x1_bn(1280, 1024),
) )
self.num_features = 1024 self.num_features = 1024
elif cfg=='selecsls84': elif cfg == 'selecsls84':
self.block = SelecSLSBlock self.block = SelecSLSBlock
#Define configuration of the network after the initial neck # Define configuration of the network after the initial neck
self.selecSLS_config = [ self.selecsls_config = [
#inp,skip, k, oup, isFirst, stride # inp,skip, k, oup, is_first, stride
[ 32, 0, 64, 64, True, 2], [32, 0, 64, 64, True, 2],
[ 64, 64, 64, 144, False, 1], [64, 64, 64, 144, False, 1],
[144, 0, 144, 144, True, 2], [144, 0, 144, 144, True, 2],
[144, 144, 144, 144, False, 1], [144, 144, 144, 144, False, 1],
[144, 144, 144, 144, False, 1], [144, 144, 144, 144, False, 1],
[144, 144, 144, 144, False, 1], [144, 144, 144, 144, False, 1],
[144, 144, 144, 304, False, 1], [144, 144, 144, 304, False, 1],
[304, 0, 304, 304, True, 2], [304, 0, 304, 304, True, 2],
[304, 304, 304, 304, False, 1], [304, 304, 304, 304, False, 1],
[304, 304, 304, 304, False, 1], [304, 304, 304, 304, False, 1],
[304, 304, 304, 304, False, 1], [304, 304, 304, 304, False, 1],
[304, 304, 304, 304, False, 1], [304, 304, 304, 304, False, 1],
[304, 304, 304, 512, False, 1], [304, 304, 304, 512, False, 1],
] ]
#Head can be replaced with alternative configurations depending on the problem # Head can be replaced with alternative configurations depending on the problem
self.head = nn.Sequential( self.head = nn.Sequential(
conv_bn(512, 960, 2), conv_bn(512, 960, 2),
conv_bn(960, 1024, 1), conv_bn(960, 1024, 1),
conv_bn(1024, 1024, 2), conv_bn(1024, 1024, 2),
conv_1x1_bn(1024, 1280), conv_1x1_bn(1024, 1280),
) )
self.num_features = 1280 self.num_features = 1280
else: else:
raise ValueError('Invalid net configuration '+cfg+' !!!') raise ValueError('Invalid net configuration ' + cfg + ' !!!')
for inp, skip, k, oup, isFirst, stride in self.selecSLS_config: for inp, skip, k, oup, is_first, stride in self.selecsls_config:
self.features.append(self.block(inp, skip, k, oup, isFirst, stride)) self.features.append(self.block(inp, skip, k, oup, is_first, stride))
self.features = nn.Sequential(*self.features) self.features = nn.Sequential(*self.features)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@ -317,25 +319,27 @@ def selecsls42(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
load_pretrained(model, default_cfg, num_classes, in_chans) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
@register_model @register_model
def selecsls42_B(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def selecsls42b(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a SelecSLS42_B model. """Constructs a SelecSLS42_B model.
""" """
default_cfg = default_cfgs['selecsls42_B'] default_cfg = default_cfgs['selecsls42b']
model = SelecSLS( model = SelecSLS(
cfg='selecsls42_B', num_classes=1000, in_chans=3,**kwargs) cfg='selecsls42b', num_classes=1000, in_chans=3, **kwargs)
model.default_cfg = default_cfg model.default_cfg = default_cfg
if pretrained: if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
@register_model @register_model
def selecsls60(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def selecsls60(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a SelecSLS60 model. """Constructs a SelecSLS60 model.
""" """
default_cfg = default_cfgs['selecsls60'] default_cfg = default_cfgs['selecsls60']
model = SelecSLS( model = SelecSLS(
cfg='selecsls60', num_classes=1000, in_chans=3,**kwargs) cfg='selecsls60', num_classes=1000, in_chans=3, **kwargs)
model.default_cfg = default_cfg model.default_cfg = default_cfg
if pretrained: if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans) load_pretrained(model, default_cfg, num_classes, in_chans)
@ -343,17 +347,18 @@ def selecsls60(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
@register_model @register_model
def selecsls60_B(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def selecsls60b(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a SelecSLS60_B model. """Constructs a SelecSLS60_B model.
""" """
default_cfg = default_cfgs['selecsls60_B'] default_cfg = default_cfgs['selecsls60b']
model = SelecSLS( model = SelecSLS(
cfg='selecsls60_B', num_classes=1000, in_chans=3,**kwargs) cfg='selecsls60b', num_classes=1000, in_chans=3, **kwargs)
model.default_cfg = default_cfg model.default_cfg = default_cfg
if pretrained: if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
@register_model @register_model
def selecsls84(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def selecsls84(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a SelecSLS84 model. """Constructs a SelecSLS84 model.
@ -365,4 +370,3 @@ def selecsls84(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
if pretrained: if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model

Loading…
Cancel
Save