From d59a756c167c2d996b7df086df5d2905aacc6ef9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 30 Dec 2019 14:30:46 -0800 Subject: [PATCH 1/4] Run PyCharm autoformat on selecsls and change mix cap variables and model names to all lower --- timm/models/selecsls.py | 282 ++++++++++++++++++++-------------------- 1 file changed, 143 insertions(+), 139 deletions(-) diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index f37834a7..b7f2a9f0 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -20,7 +20,6 @@ from .helpers import load_pretrained from .adaptive_avgmax_pool import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD - __all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this @@ -39,13 +38,13 @@ default_cfgs = { 'selecsls42': _cfg( url='', interpolation='bicubic'), - 'selecsls42_B': _cfg( + 'selecsls42b': _cfg( url='http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS42_B.pth', interpolation='bicubic'), 'selecsls60': _cfg( url='', interpolation='bicubic'), - 'selecsls60_B': _cfg( + 'selecsls60b': _cfg( url='http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS60_B.pth', interpolation='bicubic'), 'selecsls84': _cfg( @@ -69,57 +68,59 @@ def conv_1x1_bn(inp, oup): nn.ReLU(inplace=True) ) + class SelecSLSBlock(nn.Module): - def __init__(self, inp, skip, k, oup, isFirst, stride): + def __init__(self, inp, skip, k, oup, is_first, stride): super(SelecSLSBlock, self).__init__() self.stride = stride - self.isFirst = isFirst + self.is_first = is_first assert stride in [1, 2] - #Process input with 4 conv blocks with the same number of input and output channels + # Process input with 4 conv blocks with the same number of input and output channels self.conv1 = nn.Sequential( - nn.Conv2d(inp, k, 3, stride, 1,groups= 1, bias=False, dilation=1), - nn.BatchNorm2d(k), - nn.ReLU(inplace=True) - ) + nn.Conv2d(inp, k, 3, stride, 1, groups=1, bias=False, dilation=1), + nn.BatchNorm2d(k), + nn.ReLU(inplace=True) + ) self.conv2 = nn.Sequential( - nn.Conv2d(k, k, 1, 1, 0,groups= 1, bias=False, dilation=1), - nn.BatchNorm2d(k), - nn.ReLU(inplace=True) - ) + nn.Conv2d(k, k, 1, 1, 0, groups=1, bias=False, dilation=1), + nn.BatchNorm2d(k), + nn.ReLU(inplace=True) + ) self.conv3 = nn.Sequential( - nn.Conv2d(k, k//2, 3, 1, 1,groups= 1, bias=False, dilation=1), - nn.BatchNorm2d(k//2), - nn.ReLU(inplace=True) - ) + nn.Conv2d(k, k // 2, 3, 1, 1, groups=1, bias=False, dilation=1), + nn.BatchNorm2d(k // 2), + nn.ReLU(inplace=True) + ) self.conv4 = nn.Sequential( - nn.Conv2d(k//2, k, 1, 1, 0,groups= 1, bias=False, dilation=1), - nn.BatchNorm2d(k), - nn.ReLU(inplace=True) - ) + nn.Conv2d(k // 2, k, 1, 1, 0, groups=1, bias=False, dilation=1), + nn.BatchNorm2d(k), + nn.ReLU(inplace=True) + ) self.conv5 = nn.Sequential( - nn.Conv2d(k, k//2, 3, 1, 1,groups= 1, bias=False, dilation=1), - nn.BatchNorm2d(k//2), - nn.ReLU(inplace=True) - ) + nn.Conv2d(k, k // 2, 3, 1, 1, groups=1, bias=False, dilation=1), + nn.BatchNorm2d(k // 2), + nn.ReLU(inplace=True) + ) self.conv6 = nn.Sequential( - nn.Conv2d(2*k + (0 if isFirst else skip), oup, 1, 1, 0,groups= 1, bias=False, dilation=1), - nn.BatchNorm2d(oup), - nn.ReLU(inplace=True) - ) + nn.Conv2d(2 * k + (0 if is_first else skip), oup, 1, 1, 0, groups=1, bias=False, dilation=1), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True) + ) def forward(self, x): - assert isinstance(x,list) - assert len(x) in [1,2] + assert isinstance(x, list) + assert len(x) in [1, 2] d1 = self.conv1(x[0]) d2 = self.conv3(self.conv2(d1)) d3 = self.conv5(self.conv4(d2)) - if self.isFirst: + if self.is_first: out = self.conv6(torch.cat([d1, d2, d3], 1)) return [out, out] else: - return [self.conv6(torch.cat([d1, d2, d3, x[1]], 1)) , x[1]] + return [self.conv6(torch.cat([d1, d2, d3, x[1]], 1)), x[1]] + class SelecSLS(nn.Module): """SelecSLS42 / SelecSLS60 / SelecSLS84 @@ -137,6 +138,7 @@ class SelecSLS(nn.Module): global_pool : str, default 'avg' Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' """ + def __init__(self, cfg='selecsls60', num_classes=1000, in_chans=3, drop_rate=0.0, global_pool='avg'): self.num_classes = num_classes @@ -144,126 +146,126 @@ class SelecSLS(nn.Module): super(SelecSLS, self).__init__() self.stem = conv_bn(in_chans, 32, 2) - #Core Network + # Core Network self.features = [] - if cfg=='selecsls42': + if cfg == 'selecsls42': self.block = SelecSLSBlock - #Define configuration of the network after the initial neck - self.selecSLS_config = [ - #inp,skip, k, oup, isFirst, stride - [ 32, 0, 64, 64, True, 2], - [ 64, 64, 64, 128, False, 1], - [128, 0, 144, 144, True, 2], - [144, 144, 144, 288, False, 1], - [288, 0, 304, 304, True, 2], - [304, 304, 304, 480, False, 1], + # Define configuration of the network after the initial neck + self.selecsls_config = [ + # inp,skip, k, oup, is_first, stride + [32, 0, 64, 64, True, 2], + [64, 64, 64, 128, False, 1], + [128, 0, 144, 144, True, 2], + [144, 144, 144, 288, False, 1], + [288, 0, 304, 304, True, 2], + [304, 304, 304, 480, False, 1], ] - #Head can be replaced with alternative configurations depending on the problem + # Head can be replaced with alternative configurations depending on the problem self.head = nn.Sequential( - conv_bn(480, 960, 2), - conv_bn(960, 1024, 1), - conv_bn(1024, 1024, 2), - conv_1x1_bn(1024, 1280), - ) + conv_bn(480, 960, 2), + conv_bn(960, 1024, 1), + conv_bn(1024, 1024, 2), + conv_1x1_bn(1024, 1280), + ) self.num_features = 1280 - elif cfg=='selecsls42_B': + elif cfg == 'selecsls42b': self.block = SelecSLSBlock - #Define configuration of the network after the initial neck - self.selecSLS_config = [ - #inp,skip, k, oup, isFirst, stride - [ 32, 0, 64, 64, True, 2], - [ 64, 64, 64, 128, False, 1], - [128, 0, 144, 144, True, 2], - [144, 144, 144, 288, False, 1], - [288, 0, 304, 304, True, 2], - [304, 304, 304, 480, False, 1], + # Define configuration of the network after the initial neck + self.selecsls_config = [ + # inp,skip, k, oup, is_first, stride + [32, 0, 64, 64, True, 2], + [64, 64, 64, 128, False, 1], + [128, 0, 144, 144, True, 2], + [144, 144, 144, 288, False, 1], + [288, 0, 304, 304, True, 2], + [304, 304, 304, 480, False, 1], ] - #Head can be replaced with alternative configurations depending on the problem + # Head can be replaced with alternative configurations depending on the problem self.head = nn.Sequential( - conv_bn(480, 960, 2), - conv_bn(960, 1024, 1), - conv_bn(1024, 1280, 2), - conv_1x1_bn(1280, 1024), - ) + conv_bn(480, 960, 2), + conv_bn(960, 1024, 1), + conv_bn(1024, 1280, 2), + conv_1x1_bn(1280, 1024), + ) self.num_features = 1024 - elif cfg=='selecsls60': + elif cfg == 'selecsls60': self.block = SelecSLSBlock - #Define configuration of the network after the initial neck - self.selecSLS_config = [ - #inp,skip, k, oup, isFirst, stride - [ 32, 0, 64, 64, True, 2], - [ 64, 64, 64, 128, False, 1], - [128, 0, 128, 128, True, 2], - [128, 128, 128, 128, False, 1], - [128, 128, 128, 288, False, 1], - [288, 0, 288, 288, True, 2], - [288, 288, 288, 288, False, 1], - [288, 288, 288, 288, False, 1], - [288, 288, 288, 416, False, 1], + # Define configuration of the network after the initial neck + self.selecsls_config = [ + # inp,skip, k, oup, is_first, stride + [32, 0, 64, 64, True, 2], + [64, 64, 64, 128, False, 1], + [128, 0, 128, 128, True, 2], + [128, 128, 128, 128, False, 1], + [128, 128, 128, 288, False, 1], + [288, 0, 288, 288, True, 2], + [288, 288, 288, 288, False, 1], + [288, 288, 288, 288, False, 1], + [288, 288, 288, 416, False, 1], ] - #Head can be replaced with alternative configurations depending on the problem + # Head can be replaced with alternative configurations depending on the problem self.head = nn.Sequential( - conv_bn(416, 756, 2), - conv_bn(756, 1024, 1), - conv_bn(1024, 1024, 2), - conv_1x1_bn(1024, 1280), - ) + conv_bn(416, 756, 2), + conv_bn(756, 1024, 1), + conv_bn(1024, 1024, 2), + conv_1x1_bn(1024, 1280), + ) self.num_features = 1280 - elif cfg=='selecsls60_B': + elif cfg == 'selecsls60b': self.block = SelecSLSBlock - #Define configuration of the network after the initial neck - self.selecSLS_config = [ - #inp,skip, k, oup, isFirst, stride - [ 32, 0, 64, 64, True, 2], - [ 64, 64, 64, 128, False, 1], - [128, 0, 128, 128, True, 2], - [128, 128, 128, 128, False, 1], - [128, 128, 128, 288, False, 1], - [288, 0, 288, 288, True, 2], - [288, 288, 288, 288, False, 1], - [288, 288, 288, 288, False, 1], - [288, 288, 288, 416, False, 1], + # Define configuration of the network after the initial neck + self.selecsls_config = [ + # inp,skip, k, oup, is_first, stride + [32, 0, 64, 64, True, 2], + [64, 64, 64, 128, False, 1], + [128, 0, 128, 128, True, 2], + [128, 128, 128, 128, False, 1], + [128, 128, 128, 288, False, 1], + [288, 0, 288, 288, True, 2], + [288, 288, 288, 288, False, 1], + [288, 288, 288, 288, False, 1], + [288, 288, 288, 416, False, 1], ] - #Head can be replaced with alternative configurations depending on the problem + # Head can be replaced with alternative configurations depending on the problem self.head = nn.Sequential( - conv_bn(416, 756, 2), - conv_bn(756, 1024, 1), - conv_bn(1024, 1280, 2), - conv_1x1_bn(1280, 1024), - ) + conv_bn(416, 756, 2), + conv_bn(756, 1024, 1), + conv_bn(1024, 1280, 2), + conv_1x1_bn(1280, 1024), + ) self.num_features = 1024 - elif cfg=='selecsls84': + elif cfg == 'selecsls84': self.block = SelecSLSBlock - #Define configuration of the network after the initial neck - self.selecSLS_config = [ - #inp,skip, k, oup, isFirst, stride - [ 32, 0, 64, 64, True, 2], - [ 64, 64, 64, 144, False, 1], - [144, 0, 144, 144, True, 2], - [144, 144, 144, 144, False, 1], - [144, 144, 144, 144, False, 1], - [144, 144, 144, 144, False, 1], - [144, 144, 144, 304, False, 1], - [304, 0, 304, 304, True, 2], - [304, 304, 304, 304, False, 1], - [304, 304, 304, 304, False, 1], - [304, 304, 304, 304, False, 1], - [304, 304, 304, 304, False, 1], - [304, 304, 304, 512, False, 1], + # Define configuration of the network after the initial neck + self.selecsls_config = [ + # inp,skip, k, oup, is_first, stride + [32, 0, 64, 64, True, 2], + [64, 64, 64, 144, False, 1], + [144, 0, 144, 144, True, 2], + [144, 144, 144, 144, False, 1], + [144, 144, 144, 144, False, 1], + [144, 144, 144, 144, False, 1], + [144, 144, 144, 304, False, 1], + [304, 0, 304, 304, True, 2], + [304, 304, 304, 304, False, 1], + [304, 304, 304, 304, False, 1], + [304, 304, 304, 304, False, 1], + [304, 304, 304, 304, False, 1], + [304, 304, 304, 512, False, 1], ] - #Head can be replaced with alternative configurations depending on the problem + # Head can be replaced with alternative configurations depending on the problem self.head = nn.Sequential( - conv_bn(512, 960, 2), - conv_bn(960, 1024, 1), - conv_bn(1024, 1024, 2), - conv_1x1_bn(1024, 1280), - ) + conv_bn(512, 960, 2), + conv_bn(960, 1024, 1), + conv_bn(1024, 1024, 2), + conv_1x1_bn(1024, 1280), + ) self.num_features = 1280 else: - raise ValueError('Invalid net configuration '+cfg+' !!!') + raise ValueError('Invalid net configuration ' + cfg + ' !!!') - for inp, skip, k, oup, isFirst, stride in self.selecSLS_config: - self.features.append(self.block(inp, skip, k, oup, isFirst, stride)) + for inp, skip, k, oup, is_first, stride in self.selecsls_config: + self.features.append(self.block(inp, skip, k, oup, is_first, stride)) self.features = nn.Sequential(*self.features) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) @@ -317,25 +319,27 @@ def selecsls42(pretrained=False, num_classes=1000, in_chans=3, **kwargs): load_pretrained(model, default_cfg, num_classes, in_chans) return model + @register_model -def selecsls42_B(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def selecsls42b(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a SelecSLS42_B model. """ - default_cfg = default_cfgs['selecsls42_B'] + default_cfg = default_cfgs['selecsls42b'] model = SelecSLS( - cfg='selecsls42_B', num_classes=1000, in_chans=3,**kwargs) + cfg='selecsls42b', num_classes=1000, in_chans=3, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) return model + @register_model def selecsls60(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a SelecSLS60 model. """ default_cfg = default_cfgs['selecsls60'] model = SelecSLS( - cfg='selecsls60', num_classes=1000, in_chans=3,**kwargs) + cfg='selecsls60', num_classes=1000, in_chans=3, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -343,17 +347,18 @@ def selecsls60(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model -def selecsls60_B(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def selecsls60b(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a SelecSLS60_B model. """ - default_cfg = default_cfgs['selecsls60_B'] + default_cfg = default_cfgs['selecsls60b'] model = SelecSLS( - cfg='selecsls60_B', num_classes=1000, in_chans=3,**kwargs) + cfg='selecsls60b', num_classes=1000, in_chans=3, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) return model + @register_model def selecsls84(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a SelecSLS84 model. @@ -365,4 +370,3 @@ def selecsls84(pretrained=False, num_classes=1000, in_chans=3, **kwargs): if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) return model - From b5315e66b523d32fc92825675b5370ca1ebcba6f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 30 Dec 2019 15:44:47 -0800 Subject: [PATCH 2/4] Streamline SelecSLS model without breaking checkpoint compat. Move cfg handling out of model class. Update feature/pooling behaviour to match current. --- timm/models/selecsls.py | 346 ++++++++++++++++------------------------ 1 file changed, 134 insertions(+), 212 deletions(-) diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index b7f2a9f0..ef9b85dd 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -53,60 +53,30 @@ default_cfgs = { } -def conv_bn(inp, oup, stride): +def conv_bn(in_chs, out_chs, k=3, stride=1, padding=None, dilation=1): + if padding is None: + padding = ((stride - 1) + dilation * (k - 1)) // 2 return nn.Sequential( - nn.Conv2d(inp, oup, 3, stride, 1, bias=False), - nn.BatchNorm2d(oup), - nn.ReLU(inplace=True) - ) - - -def conv_1x1_bn(inp, oup): - return nn.Sequential( - nn.Conv2d(inp, oup, 1, 1, 0, bias=False), - nn.BatchNorm2d(oup), + nn.Conv2d(in_chs, out_chs, k, stride, padding=padding, dilation=dilation, bias=False), + nn.BatchNorm2d(out_chs), nn.ReLU(inplace=True) ) class SelecSLSBlock(nn.Module): - def __init__(self, inp, skip, k, oup, is_first, stride): + def __init__(self, in_chs, skip_chs, mid_chs, out_chs, is_first, stride, dilation=1): super(SelecSLSBlock, self).__init__() self.stride = stride self.is_first = is_first assert stride in [1, 2] # Process input with 4 conv blocks with the same number of input and output channels - self.conv1 = nn.Sequential( - nn.Conv2d(inp, k, 3, stride, 1, groups=1, bias=False, dilation=1), - nn.BatchNorm2d(k), - nn.ReLU(inplace=True) - ) - self.conv2 = nn.Sequential( - nn.Conv2d(k, k, 1, 1, 0, groups=1, bias=False, dilation=1), - nn.BatchNorm2d(k), - nn.ReLU(inplace=True) - ) - self.conv3 = nn.Sequential( - nn.Conv2d(k, k // 2, 3, 1, 1, groups=1, bias=False, dilation=1), - nn.BatchNorm2d(k // 2), - nn.ReLU(inplace=True) - ) - self.conv4 = nn.Sequential( - nn.Conv2d(k // 2, k, 1, 1, 0, groups=1, bias=False, dilation=1), - nn.BatchNorm2d(k), - nn.ReLU(inplace=True) - ) - self.conv5 = nn.Sequential( - nn.Conv2d(k, k // 2, 3, 1, 1, groups=1, bias=False, dilation=1), - nn.BatchNorm2d(k // 2), - nn.ReLU(inplace=True) - ) - self.conv6 = nn.Sequential( - nn.Conv2d(2 * k + (0 if is_first else skip), oup, 1, 1, 0, groups=1, bias=False, dilation=1), - nn.BatchNorm2d(oup), - nn.ReLU(inplace=True) - ) + self.conv1 = conv_bn(in_chs, mid_chs, 3, stride, dilation=dilation) + self.conv2 = conv_bn(mid_chs, mid_chs, 1) + self.conv3 = conv_bn(mid_chs, mid_chs // 2, 3) + self.conv4 = conv_bn(mid_chs // 2, mid_chs, 1) + self.conv5 = conv_bn(mid_chs, mid_chs // 2, 3) + self.conv6 = conv_bn(2 * mid_chs + (0 if is_first else skip_chs), out_chs, 1) def forward(self, x): assert isinstance(x, list) @@ -127,8 +97,7 @@ class SelecSLS(nn.Module): Parameters ---------- - cfg : network config - String indicating the network config + cfg : network config dictionary specifying block type, feature, and head args num_classes : int, default 1000 Number of classification classes. in_chans : int, default 3 @@ -139,134 +108,16 @@ class SelecSLS(nn.Module): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' """ - def __init__(self, cfg='selecsls60', num_classes=1000, in_chans=3, - drop_rate=0.0, global_pool='avg'): + def __init__(self, cfg, num_classes=1000, in_chans=3, drop_rate=0.0, global_pool='avg'): self.num_classes = num_classes self.drop_rate = drop_rate super(SelecSLS, self).__init__() - self.stem = conv_bn(in_chans, 32, 2) - # Core Network - self.features = [] - if cfg == 'selecsls42': - self.block = SelecSLSBlock - # Define configuration of the network after the initial neck - self.selecsls_config = [ - # inp,skip, k, oup, is_first, stride - [32, 0, 64, 64, True, 2], - [64, 64, 64, 128, False, 1], - [128, 0, 144, 144, True, 2], - [144, 144, 144, 288, False, 1], - [288, 0, 304, 304, True, 2], - [304, 304, 304, 480, False, 1], - ] - # Head can be replaced with alternative configurations depending on the problem - self.head = nn.Sequential( - conv_bn(480, 960, 2), - conv_bn(960, 1024, 1), - conv_bn(1024, 1024, 2), - conv_1x1_bn(1024, 1280), - ) - self.num_features = 1280 - elif cfg == 'selecsls42b': - self.block = SelecSLSBlock - # Define configuration of the network after the initial neck - self.selecsls_config = [ - # inp,skip, k, oup, is_first, stride - [32, 0, 64, 64, True, 2], - [64, 64, 64, 128, False, 1], - [128, 0, 144, 144, True, 2], - [144, 144, 144, 288, False, 1], - [288, 0, 304, 304, True, 2], - [304, 304, 304, 480, False, 1], - ] - # Head can be replaced with alternative configurations depending on the problem - self.head = nn.Sequential( - conv_bn(480, 960, 2), - conv_bn(960, 1024, 1), - conv_bn(1024, 1280, 2), - conv_1x1_bn(1280, 1024), - ) - self.num_features = 1024 - elif cfg == 'selecsls60': - self.block = SelecSLSBlock - # Define configuration of the network after the initial neck - self.selecsls_config = [ - # inp,skip, k, oup, is_first, stride - [32, 0, 64, 64, True, 2], - [64, 64, 64, 128, False, 1], - [128, 0, 128, 128, True, 2], - [128, 128, 128, 128, False, 1], - [128, 128, 128, 288, False, 1], - [288, 0, 288, 288, True, 2], - [288, 288, 288, 288, False, 1], - [288, 288, 288, 288, False, 1], - [288, 288, 288, 416, False, 1], - ] - # Head can be replaced with alternative configurations depending on the problem - self.head = nn.Sequential( - conv_bn(416, 756, 2), - conv_bn(756, 1024, 1), - conv_bn(1024, 1024, 2), - conv_1x1_bn(1024, 1280), - ) - self.num_features = 1280 - elif cfg == 'selecsls60b': - self.block = SelecSLSBlock - # Define configuration of the network after the initial neck - self.selecsls_config = [ - # inp,skip, k, oup, is_first, stride - [32, 0, 64, 64, True, 2], - [64, 64, 64, 128, False, 1], - [128, 0, 128, 128, True, 2], - [128, 128, 128, 128, False, 1], - [128, 128, 128, 288, False, 1], - [288, 0, 288, 288, True, 2], - [288, 288, 288, 288, False, 1], - [288, 288, 288, 288, False, 1], - [288, 288, 288, 416, False, 1], - ] - # Head can be replaced with alternative configurations depending on the problem - self.head = nn.Sequential( - conv_bn(416, 756, 2), - conv_bn(756, 1024, 1), - conv_bn(1024, 1280, 2), - conv_1x1_bn(1280, 1024), - ) - self.num_features = 1024 - elif cfg == 'selecsls84': - self.block = SelecSLSBlock - # Define configuration of the network after the initial neck - self.selecsls_config = [ - # inp,skip, k, oup, is_first, stride - [32, 0, 64, 64, True, 2], - [64, 64, 64, 144, False, 1], - [144, 0, 144, 144, True, 2], - [144, 144, 144, 144, False, 1], - [144, 144, 144, 144, False, 1], - [144, 144, 144, 144, False, 1], - [144, 144, 144, 304, False, 1], - [304, 0, 304, 304, True, 2], - [304, 304, 304, 304, False, 1], - [304, 304, 304, 304, False, 1], - [304, 304, 304, 304, False, 1], - [304, 304, 304, 304, False, 1], - [304, 304, 304, 512, False, 1], - ] - # Head can be replaced with alternative configurations depending on the problem - self.head = nn.Sequential( - conv_bn(512, 960, 2), - conv_bn(960, 1024, 1), - conv_bn(1024, 1024, 2), - conv_1x1_bn(1024, 1280), - ) - self.num_features = 1280 - else: - raise ValueError('Invalid net configuration ' + cfg + ' !!!') + self.stem = conv_bn(in_chans, 32, stride=2) + self.features = nn.Sequential(*[cfg['block'](*block_args) for block_args in cfg['features']]) + self.head = nn.Sequential(*[conv_bn(*conv_args) for conv_args in cfg['head']]) + self.num_features = cfg['num_features'] - for inp, skip, k, oup, is_first, stride in self.selecsls_config: - self.features.append(self.block(inp, skip, k, oup, is_first, stride)) - self.features = nn.Sequential(*self.features) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) @@ -289,84 +140,155 @@ class SelecSLS(nn.Module): else: self.fc = None - def forward_features(self, x, pool=True): + def forward_features(self, x): x = self.stem(x) x = self.features([x]) x = self.head(x[0]) - - if pool: - x = self.global_pool(x) - x = x.view(x.size(0), -1) return x def forward(self, x): x = self.forward_features(x) + x = self.global_pool(x).flatten(1) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) x = self.fc(x) return x +def _create_model(variant, pretrained, model_kwargs): + cfg = {} + if variant.startswith('selecsls42'): + cfg['block'] = SelecSLSBlock + # Define configuration of the network after the initial neck + cfg['features'] = [ + # in_chs, skip_chs, mid_chs, out_chs, is_first, stride + (32, 0, 64, 64, True, 2), + (64, 64, 64, 128, False, 1), + (128, 0, 144, 144, True, 2), + (144, 144, 144, 288, False, 1), + (288, 0, 304, 304, True, 2), + (304, 304, 304, 480, False, 1), + ] + # Head can be replaced with alternative configurations depending on the problem + if variant == 'selecsls42b': + cfg['head'] = [ + (480, 960, 3, 2), + (960, 1024, 3, 1), + (1024, 1280, 3, 2), + (1280, 1024, 1, 1), + ] + cfg['num_features'] = 1024 + else: + cfg['head'] = [ + (480, 960, 3, 2), + (960, 1024, 3, 1), + (1024, 1024, 3, 2), + (1024, 1280, 1, 1), + ] + cfg['num_features'] = 1280 + elif variant.startswith('selecsls60'): + cfg['block'] = SelecSLSBlock + # Define configuration of the network after the initial neck + cfg['features'] = [ + # in_chs, skip_chs, mid_chs, out_chs, is_first, stride + (32, 0, 64, 64, True, 2), + (64, 64, 64, 128, False, 1), + (128, 0, 128, 128, True, 2), + (128, 128, 128, 128, False, 1), + (128, 128, 128, 288, False, 1), + (288, 0, 288, 288, True, 2), + (288, 288, 288, 288, False, 1), + (288, 288, 288, 288, False, 1), + (288, 288, 288, 416, False, 1), + ] + # Head can be replaced with alternative configurations depending on the problem + if variant == 'selecsls60b': + cfg['head'] = [ + (416, 756, 3, 2), + (756, 1024, 3, 1), + (1024, 1280, 3, 2), + (1280, 1024, 1, 1), + ] + cfg['num_features'] = 1024 + else: + cfg['head'] = [ + (416, 756, 3, 2), + (756, 1024, 3, 1), + (1024, 1024, 3, 2), + (1024, 1280, 1, 1), + ] + cfg['num_features'] = 1280 + elif variant == 'selecsls84': + cfg['block'] = SelecSLSBlock + # Define configuration of the network after the initial neck + cfg['features'] = [ + # in_chs, skip_chs, mid_chs, out_chs, is_first, stride + (32, 0, 64, 64, True, 2), + (64, 64, 64, 144, False, 1), + (144, 0, 144, 144, True, 2), + (144, 144, 144, 144, False, 1), + (144, 144, 144, 144, False, 1), + (144, 144, 144, 144, False, 1), + (144, 144, 144, 304, False, 1), + (304, 0, 304, 304, True, 2), + (304, 304, 304, 304, False, 1), + (304, 304, 304, 304, False, 1), + (304, 304, 304, 304, False, 1), + (304, 304, 304, 304, False, 1), + (304, 304, 304, 512, False, 1), + ] + # Head can be replaced with alternative configurations depending on the problem + cfg['head'] = [ + (512, 960, 3, 2), + (960, 1024, 3, 1), + (1024, 1024, 3, 2), + (1024, 1280, 3, 1), + ] + cfg['num_features'] = 1280 + else: + raise ValueError('Invalid net configuration ' + variant + ' !!!') + + model = SelecSLS(cfg, **model_kwargs) + model.default_cfg = default_cfgs[variant] + if pretrained: + load_pretrained( + model, + num_classes=model_kwargs.get('num_classes', 0), + in_chans=model_kwargs.get('in_chans', 3), + strict=True) + return model + + @register_model -def selecsls42(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def selecsls42(pretrained=False, **kwargs): """Constructs a SelecSLS42 model. """ - default_cfg = default_cfgs['selecsls42'] - model = SelecSLS( - cfg='selecsls42', num_classes=1000, in_chans=3, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + return _create_model('selecsls42', pretrained, kwargs) @register_model -def selecsls42b(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def selecsls42b(pretrained=False, **kwargs): """Constructs a SelecSLS42_B model. """ - default_cfg = default_cfgs['selecsls42b'] - model = SelecSLS( - cfg='selecsls42b', num_classes=1000, in_chans=3, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + return _create_model('selecsls42b', pretrained, kwargs) @register_model -def selecsls60(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def selecsls60(pretrained=False, **kwargs): """Constructs a SelecSLS60 model. """ - default_cfg = default_cfgs['selecsls60'] - model = SelecSLS( - cfg='selecsls60', num_classes=1000, in_chans=3, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + return _create_model('selecsls60', pretrained, kwargs) @register_model -def selecsls60b(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def selecsls60b(pretrained=False, **kwargs): """Constructs a SelecSLS60_B model. """ - default_cfg = default_cfgs['selecsls60b'] - model = SelecSLS( - cfg='selecsls60b', num_classes=1000, in_chans=3, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + return _create_model('selecsls60b', pretrained, kwargs) @register_model -def selecsls84(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def selecsls84(pretrained=False, **kwargs): """Constructs a SelecSLS84 model. """ - default_cfg = default_cfgs['selecsls84'] - model = SelecSLS( - cfg='selecsls84', num_classes=1000, in_chans=3, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + return _create_model('selecsls84', pretrained, kwargs) From 0062c15fb0d537b48bddc4861e0ec2d4ef77d093 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 30 Dec 2019 15:59:19 -0800 Subject: [PATCH 3/4] Update checkpoint url with modelzoo compatible ones. --- timm/models/selecsls.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index ef9b85dd..b2a38e36 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -39,13 +39,13 @@ default_cfgs = { url='', interpolation='bicubic'), 'selecsls42b': _cfg( - url='http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS42_B.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-selecsls/selecsls42b-8af30141.pth', interpolation='bicubic'), 'selecsls60': _cfg( url='', interpolation='bicubic'), 'selecsls60b': _cfg( - url='http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS60_B.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-selecsls/selecsls60b-94e619b5.pth', interpolation='bicubic'), 'selecsls84': _cfg( url='', From 84ca3d1f4df88f9ba5e8254924b2844154ea9b9c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 30 Dec 2019 16:04:39 -0800 Subject: [PATCH 4/4] Add SelecSLS to sotabench list --- sotabench.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sotabench.py b/sotabench.py index 533ec3e9..a9399f63 100644 --- a/sotabench.py +++ b/sotabench.py @@ -316,6 +316,13 @@ model_list = [ _entry('hrnet_w44', 'HRNet-W44-C', '1908.07919'), _entry('hrnet_w48', 'HRNet-W48-C', '1908.07919'), _entry('hrnet_w64', 'HRNet-W64-C', '1908.07919'), + + + ## SelecSLS official impl weights + _entry('selecsls42b', 'SelecSLS-42_B', '1907.00837', + model_desc='Originally from https://github.com/mehtadushy/SelecSLS-Pytorch'), + _entry('selecsls60b', 'SelecSLS-60_B', '1907.00837', + model_desc='Originally from https://github.com/mehtadushy/SelecSLS-Pytorch'), ] for m in model_list: