You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

364 lines
13 KiB

"""PyTorch SelecSLS on ImageNet

Based on the ResNet implementation in this repository.

SelecSLS (core) network architecture as proposed in "XNect: Real-time Multi-person
3D Human Pose Estimation with a Single RGB Camera", Mehta et al.

Implementation by Dushyant Mehta (@mehtadushy)
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
# Names exported by a star-import of this module; only the model class is listed statically.
__all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (3, 3),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'first_conv': 'stem', 'classifier': 'fc',
# Pretrained-weight metadata for every entrypoint defined in this module,
# keyed by entrypoint name.
# NOTE(review): the checkpoint URLs are not visible in this view of the file,
# so every entry uses an empty URL -- fill in the real locations (and any
# per-variant overrides) before relying on ``pretrained=True``.
default_cfgs = {
    'selecsls42': _cfg(url=''),
    'selecsls42NH': _cfg(url=''),  # looked up by selecsls42NH(); must exist
    'selecsls60': _cfg(url=''),
    'selecsls60NH': _cfg(url=''),
    'selecsls84': _cfg(url=''),
}
def conv_bn(inp, oup, stride):
    """Return a 3x3 convolution -> BatchNorm -> ReLU stack.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the convolution.

    Returns:
        nn.Sequential: ``Conv2d(3x3, padding=1, bias=False)``, ``BatchNorm2d``,
        ``ReLU(inplace=True)``. The convolution has no bias because the
        following BatchNorm supplies its own shift.
    """
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True),
    )
def conv_1x1_bn(inp, oup):
    """Return a 1x1 convolution -> BatchNorm -> ReLU stack.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.

    Returns:
        nn.Sequential: ``Conv2d(1x1, stride=1, padding=0, bias=False)``,
        ``BatchNorm2d``, ``ReLU(inplace=True)``; spatial size is preserved.
    """
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True),
    )
class SelecSLSBlock(nn.Module):
    """One SelecSLS block: a short conv chain whose intermediate outputs are
    concatenated (optionally with a longer-range skip tensor) and fused by a
    final 1x1 convolution.

    Args:
        inp: Channels of the incoming feature map.
        skip: Channels of the skip tensor (ignored when ``isFirst`` is True).
        k: Base width of the internal convolutions. The concatenated
            intermediates add up to ``2 * k`` channels, which assumes ``k``
            is even.
        oup: Channels of the block output.
        isFirst: True for the first block of a stage; such a block takes no
            skip input and its own output becomes the skip tensor for the
            blocks that follow.
        stride: Stride of the first convolution; must be 1 or 2.
    """

    def __init__(self, inp, skip, k, oup, isFirst, stride):
        super(SelecSLSBlock, self).__init__()
        self.stride = stride
        self.isFirst = isFirst
        assert stride in [1, 2]

        # Conv chain: conv1 (optionally strided) feeds two 1x1 -> 3x3 pairs.
        self.conv1 = self._conv_bn_relu(inp, k, 3, stride)
        self.conv2 = self._conv_bn_relu(k, k, 1)
        self.conv3 = self._conv_bn_relu(k, k // 2, 3)
        self.conv4 = self._conv_bn_relu(k // 2, k, 1)
        self.conv5 = self._conv_bn_relu(k, k // 2, 3)
        # Fusion conv: sees d1 (k) + d2 (k/2) + d3 (k/2) = 2k channels, plus
        # the skip tensor for every block except the first of a stage.
        self.conv6 = self._conv_bn_relu(2 * k + (0 if isFirst else skip), oup, 1)

    @staticmethod
    def _conv_bn_relu(in_chs, out_chs, kernel_size, stride=1):
        """Conv2d (bias-free, 'same' padding for odd kernels) -> BatchNorm2d -> ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size // 2,
                      groups=1, bias=False, dilation=1),
            nn.BatchNorm2d(out_chs),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run the block.

        Args:
            x: A list ``[features]`` or ``[features, skip]``.

        Returns:
            list: ``[out, skip]``. For a first block ``skip`` is ``out``
            itself; otherwise the incoming skip tensor is passed through
            unchanged.
        """
        assert isinstance(x, list)
        assert len(x) in [1, 2]

        d1 = self.conv1(x[0])
        d2 = self.conv3(self.conv2(d1))
        d3 = self.conv5(self.conv4(d2))
        if self.isFirst:
            # Concatenate along the channel dimension before fusing.
            out = self.conv6(torch.cat([d1, d2, d3], 1))
            return [out, out]
        return [self.conv6(torch.cat([d1, d2, d3, x[1]], 1)), x[1]]
class SelecSLS(nn.Module):
    """SelecSLS42 / SelecSLS60 / SelecSLS84

    Parameters
    ----------
    cfg : network config
        String indicating the network config. One of 'selecsls42',
        'selecsls42NH', 'selecsls60', 'selecsls60NH', 'selecsls84'.
    num_classes : int, default 1000
        Number of classification classes.
    in_chans : int, default 3
        Number of input (color) channels.
    drop_rate : float, default 0.
        Dropout probability before classifier, for training
    global_pool : str, default 'avg'
        Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'

    Raises
    ------
    ValueError
        If ``cfg`` is not one of the supported configuration names.
    """

    # Core-network definitions after the stem, one row per SelecSLSBlock:
    # [inp, skip, k, oup, isFirst, stride]
    _FEATURE_CONFIGS = {
        'selecsls42': [
            [32, 0, 64, 64, True, 2],
            [64, 64, 64, 128, False, 1],
            [128, 0, 144, 144, True, 2],
            [144, 144, 144, 288, False, 1],
            [288, 0, 304, 304, True, 2],
            [304, 304, 304, 480, False, 1],
        ],
        'selecsls60': [
            [32, 0, 64, 64, True, 2],
            [64, 64, 64, 128, False, 1],
            [128, 0, 128, 128, True, 2],
            [128, 128, 128, 128, False, 1],
            [128, 128, 128, 288, False, 1],
            [288, 0, 288, 288, True, 2],
            [288, 288, 288, 288, False, 1],
            [288, 288, 288, 288, False, 1],
            [288, 288, 288, 416, False, 1],
        ],
        'selecsls84': [
            [32, 0, 64, 64, True, 2],
            [64, 64, 64, 144, False, 1],
            [144, 0, 144, 144, True, 2],
            [144, 144, 144, 144, False, 1],
            [144, 144, 144, 144, False, 1],
            [144, 144, 144, 144, False, 1],
            [144, 144, 144, 304, False, 1],
            [304, 0, 304, 304, True, 2],
            [304, 304, 304, 304, False, 1],
            [304, 304, 304, 304, False, 1],
            [304, 304, 304, 304, False, 1],
            [304, 304, 304, 304, False, 1],
            [304, 304, 304, 512, False, 1],
        ],
    }

    # cfg name -> (key into _FEATURE_CONFIGS, head channel progression).
    # The head is conv_bn(c0, c1, 2), conv_bn(c1, c2, 1), conv_bn(c2, c3, 2),
    # conv_1x1_bn(c3, c4); c4 is the width fed to the classifier. The "NH"
    # variants share a core network with their base model and differ only in
    # the last two head widths.
    _VARIANTS = {
        'selecsls42': ('selecsls42', (480, 960, 1024, 1024, 1280)),
        'selecsls42NH': ('selecsls42', (480, 960, 1024, 1280, 1024)),
        'selecsls60': ('selecsls60', (416, 756, 1024, 1024, 1280)),
        'selecsls60NH': ('selecsls60', (416, 756, 1024, 1280, 1024)),
        'selecsls84': ('selecsls84', (512, 960, 1024, 1024, 1280)),
    }

    def __init__(self, cfg='selecsls60', num_classes=1000, in_chans=3,
                 drop_rate=0.0, global_pool='avg'):
        super(SelecSLS, self).__init__()
        if cfg not in self._VARIANTS:
            raise ValueError('Invalid net configuration ' + str(cfg) + ' !!!')
        self.num_classes = num_classes
        self.drop_rate = drop_rate

        self.stem = conv_bn(in_chans, 32, 2)

        backbone_key, head_chs = self._VARIANTS[cfg]
        self.block = SelecSLSBlock
        # Per-instance copy so mutating it never touches the class-level table.
        self.selecSLS_config = [list(row) for row in self._FEATURE_CONFIGS[backbone_key]]

        # Head can be replaced with alternative configurations depending on the problem
        c0, c1, c2, c3, c4 = head_chs
        self.head = nn.Sequential(
            conv_bn(c0, c1, 2),
            conv_bn(c1, c2, 1),
            conv_bn(c2, c3, 2),
            conv_1x1_bn(c3, c4),
        )
        self.num_features = c4

        # Core network: each block consumes and produces a [features, skip] list,
        # so nn.Sequential can chain them directly.
        self.features = nn.Sequential(
            *[self.block(inp, skip, k, oup, isFirst, stride)
              for inp, skip, k, oup, isFirst, stride in self.selecSLS_config])

        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1.)
                nn.init.constant_(m.bias, 0.)

    def get_classifier(self):
        """Return the classifier layer (``None`` after a reset to 0 classes)."""
        return self.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the pooling layer and the classifier.

        Args:
            num_classes: New number of classes; a falsy value removes the
                classifier (``self.fc`` becomes ``None``).
            global_pool: Pooling type for the new ``SelectAdaptivePool2d``.
        """
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.num_classes = num_classes
        del self.fc
        if num_classes:
            self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
        else:
            self.fc = None

    def forward_features(self, x, pool=True):
        """Run stem, core network and head.

        Args:
            x: Input image batch.
            pool: When True, apply global pooling and flatten to
                ``(batch, features)``; when False, return the unpooled head
                output.
        """
        x = self.stem(x)
        # Blocks exchange [features, skip] lists; index 0 is the feature map.
        x = self.features([x])
        x = self.head(x[0])
        if pool:
            x = self.global_pool(x)
            x = x.view(x.size(0), -1)
        return x

    def forward(self, x):
        """Return class logits, or pooled features when the classifier was removed."""
        x = self.forward_features(x)
        if self.drop_rate > 0.:
            # training=self.training keeps dropout disabled in eval mode.
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        if self.fc is not None:
            x = self.fc(x)
        return x
@register_model
def selecsls42(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a SelecSLS42 model.

    Args:
        pretrained: When True, load weights described by the variant's default config.
        num_classes: Number of classification classes for the built model.
        in_chans: Number of input channels for the built model.
        **kwargs: Forwarded to ``SelecSLS`` (e.g. ``drop_rate``, ``global_pool``).
    """
    default_cfg = default_cfgs['selecsls42']
    # Forward the caller's num_classes / in_chans instead of fixed 1000 / 3.
    model = SelecSLS(
        cfg='selecsls42', num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
@register_model
def selecsls42NH(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a SelecSLS42NH model.

    Args:
        pretrained: When True, load weights described by the variant's default config.
        num_classes: Number of classification classes for the built model.
        in_chans: Number of input channels for the built model.
        **kwargs: Forwarded to ``SelecSLS`` (e.g. ``drop_rate``, ``global_pool``).
    """
    default_cfg = default_cfgs['selecsls42NH']
    # Forward the caller's num_classes / in_chans instead of fixed 1000 / 3.
    model = SelecSLS(
        cfg='selecsls42NH', num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
@register_model
def selecsls60(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a SelecSLS60 model.

    Args:
        pretrained: When True, load weights described by the variant's default config.
        num_classes: Number of classification classes for the built model.
        in_chans: Number of input channels for the built model.
        **kwargs: Forwarded to ``SelecSLS`` (e.g. ``drop_rate``, ``global_pool``).
    """
    default_cfg = default_cfgs['selecsls60']
    # Forward the caller's num_classes / in_chans instead of fixed 1000 / 3.
    model = SelecSLS(
        cfg='selecsls60', num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
@register_model
def selecsls60NH(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a SelecSLS60NH model.

    Args:
        pretrained: When True, load weights described by the variant's default config.
        num_classes: Number of classification classes for the built model.
        in_chans: Number of input channels for the built model.
        **kwargs: Forwarded to ``SelecSLS`` (e.g. ``drop_rate``, ``global_pool``).
    """
    default_cfg = default_cfgs['selecsls60NH']
    # Forward the caller's num_classes / in_chans instead of fixed 1000 / 3.
    model = SelecSLS(
        cfg='selecsls60NH', num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
@register_model
def selecsls84(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a SelecSLS84 model.

    Args:
        pretrained: When True, load weights described by the variant's default config.
        num_classes: Number of classification classes for the built model.
        in_chans: Number of input channels for the built model.
        **kwargs: Forwarded to ``SelecSLS`` (e.g. ``drop_rate``, ``global_pool``).
    """
    default_cfg = default_cfgs['selecsls84']
    # Forward the caller's num_classes / in_chans instead of fixed 1000 / 3.
    model = SelecSLS(
        cfg='selecsls84', num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model