Add some models, remove a model, tweak some models

Ross Wightman · 6 years ago · pull/1/head
parent 31055466fc · commit e0cfeb7d8e

@@ -14,21 +14,17 @@ import torch.nn as nn
 import torch.nn.functional as F


-def adaptive_avgmax_pool2d(x, pool_type='avg', padding=0, count_include_pad=False):
+def adaptive_avgmax_pool2d(x, pool_type='avg', output_size=1):
     """Selectable global pooling function with dynamic input kernel size
     """
     if pool_type == 'avgmax':
-        x_avg = F.avg_pool2d(
-            x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad)
-        x_max = F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding)
+        x_avg = F.adaptive_avg_pool2d(x, output_size)
+        x_max = F.adaptive_max_pool2d(x, output_size)
         x = 0.5 * (x_avg + x_max)
     elif pool_type == 'max':
-        x = F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding)
+        x = F.adaptive_max_pool2d(x, output_size)
     else:
-        if pool_type != 'avg':
-            print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
-        x = F.avg_pool2d(
-            x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad)
+        x = F.adaptive_avg_pool2d(x, output_size)
     return x
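
For context, the switch to the adaptive pooling ops means the kernel size is no longer derived from the input tensor, so the same pooling head works at any input resolution. A minimal sketch of the behavior (tensor shapes are illustrative):

```python
import torch
import torch.nn.functional as F

# Any spatial size collapses to the requested output_size (here 1x1),
# so kernel_size no longer needs to be computed from x.size(2), x.size(3).
for hw in (7, 9, 14):
    x = torch.randn(2, 2048, hw, hw)
    out = 0.5 * (F.adaptive_avg_pool2d(x, 1) + F.adaptive_max_pool2d(x, 1))
    print(out.shape)  # torch.Size([2, 2048, 1, 1]) regardless of hw
```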

@@ -8,6 +8,7 @@ import torch.nn.functional as F
 import torch.utils.model_zoo as model_zoo
 from collections import OrderedDict
 from .adaptive_avgmax_pool import *
+import re

 __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
@@ -20,6 +21,19 @@ model_urls = {
 }


+def _filter_pretrained(state_dict):
+    pattern = re.compile(
+        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
+    for key in list(state_dict.keys()):
+        res = pattern.match(key)
+        if res:
+            new_key = res.group(1) + res.group(2)
+            state_dict[new_key] = state_dict[key]
+            del state_dict[key]
+    return state_dict
+
+
 def densenet121(pretrained=False, **kwargs):
     r"""Densenet-121 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`

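
The key remap mirrors the fix-up torchvision applies to its own DenseNet checkpoints: older checkpoints name layers with dots ('norm.1') while the modules below register dot-free names ('norm1'). A quick illustration of what the regex does to one hypothetical checkpoint key:

```python
import re

pattern = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

key = 'features.denseblock1.denselayer1.norm.1.weight'  # hypothetical checkpoint key
m = pattern.match(key)
print(m.group(1) + m.group(2))  # features.denseblock1.denselayer1.norm1.weight
```
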
@@ -29,7 +43,8 @@ def densenet121(pretrained=False, **kwargs):
     """
     model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['densenet121']))
+        state_dict = model_zoo.load_url(model_urls['densenet121'])
+        model.load_state_dict(_filter_pretrained(state_dict))
     return model

@@ -42,7 +57,8 @@ def densenet169(pretrained=False, **kwargs):
     """
     model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['densenet169']))
+        state_dict = model_zoo.load_url(model_urls['densenet169'])
+        model.load_state_dict(_filter_pretrained(state_dict))
     return model

@@ -55,7 +71,8 @@ def densenet201(pretrained=False, **kwargs):
     """
     model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['densenet201']))
+        state_dict = model_zoo.load_url(model_urls['densenet201'])
+        model.load_state_dict(_filter_pretrained(state_dict))
     return model

@@ -69,20 +86,21 @@ def densenet161(pretrained=False, **kwargs):
     print(kwargs)
     model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['densenet161']))
+        state_dict = model_zoo.load_url(model_urls['densenet161'])
+        model.load_state_dict(_filter_pretrained(state_dict))
     return model


 class _DenseLayer(nn.Sequential):
     def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
         super(_DenseLayer, self).__init__()
-        self.add_module('norm.1', nn.BatchNorm2d(num_input_features)),
-        self.add_module('relu.1', nn.ReLU(inplace=True)),
-        self.add_module('conv.1', nn.Conv2d(num_input_features, bn_size *
-                        growth_rate, kernel_size=1, stride=1, bias=False)),
-        self.add_module('norm.2', nn.BatchNorm2d(bn_size * growth_rate)),
-        self.add_module('relu.2', nn.ReLU(inplace=True)),
-        self.add_module('conv.2', nn.Conv2d(bn_size * growth_rate, growth_rate,
-                        kernel_size=3, stride=1, padding=1, bias=False)),
+        self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
+        self.add_module('relu1', nn.ReLU(inplace=True)),
+        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
+                        growth_rate, kernel_size=1, stride=1, bias=False)),
+        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
+        self.add_module('relu2', nn.ReLU(inplace=True)),
+        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
+                        kernel_size=3, stride=1, padding=1, bias=False)),
         self.drop_rate = drop_rate
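
The rename from 'norm.1' to 'norm1' is not cosmetic: later PyTorch releases reject submodule names containing a dot, which is why the checkpoint keys need remapping at load time. A quick check of that behavior (as of recent PyTorch versions):

```python
import torch.nn as nn

seq = nn.Sequential()
seq.add_module('norm1', nn.BatchNorm2d(8))       # accepted
try:
    seq.add_module('norm.1', nn.BatchNorm2d(8))  # rejected on recent PyTorch
except KeyError as err:
    print(err)  # module name can't contain "."
```
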
@@ -172,12 +190,12 @@ class DenseNet(nn.Module):
             self.classifier = None

     def forward_features(self, x, pool=True):
-        features = self.features(x)
-        out = F.relu(features, inplace=True)
+        x = self.features(x)
+        x = F.relu(x, inplace=True)
         if pool:
-            out = adaptive_avgmax_pool2d(out, self.global_pool)
-            out = x.view(out.size(0), -1)
-        return out
+            x = adaptive_avgmax_pool2d(x, self.global_pool)
+            x = x.view(x.size(0), -1)
+        return x

     def forward(self, x):
         return self.classifier(self.forward_features(x, pool=True))
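
With pooling and flattening gated behind pool, forward_features doubles as a feature extractor: pool=False yields the convolutional map, pool=True the flattened embedding. A usage sketch (the import path is assumed; densenet121 is the factory above):

```python
import torch
from models.densenet import densenet121  # import path assumed

model = densenet121(pretrained=False)
x = torch.randn(1, 3, 224, 224)
fmap = model.forward_features(x, pool=False)  # 4D map: (1, num_features, H', W')
emb = model.forward_features(x, pool=True)    # 2D embedding: (1, num_features)
logits = model(x)                             # classifier over pooled features
```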

@@ -6,13 +6,13 @@ import os
 from .inception_v4 import inception_v4
 from .inception_resnet_v2 import inception_resnet_v2
-from .wrn50_2 import wrn50_2
-from .my_densenet import densenet161, densenet121, densenet169, densenet201
-from .my_resnet import resnet18, resnet34, resnet50, resnet101, resnet152
+from .densenet import densenet161, densenet121, densenet169, densenet201
+from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152
 from .fbresnet200 import fbresnet200
 from .dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107
-from .senet import se_resnet18, se_resnet34, se_resnet50, se_resnet101, se_resnet152,\
-    se_resnext50_32x4d, se_resnext101_32x4d
+from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152,\
+    seresnext50_32x4d, seresnext101_32x4d
+from .resnext import resnext50, resnext101, resnext152

 model_config_dict = {
@@ -99,15 +99,29 @@ def create_model(
     elif model_name == 'inception_resnet_v2':
         model = inception_resnet_v2(num_classes=num_classes, pretrained=pretrained, **kwargs)
     elif model_name == 'inception_v4':
         model = inception_v4(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'wrn50':
-        model = wrn50_2(num_classes=num_classes, pretrained=pretrained, **kwargs)
     elif model_name == 'fbresnet200':
         model = fbresnet200(num_classes=num_classes, pretrained=pretrained, **kwargs)
     elif model_name == 'seresnet18':
-        model = se_resnet18(num_classes=num_classes, pretrained=pretrained)
+        model = seresnet18(num_classes=num_classes, pretrained=pretrained, **kwargs)
     elif model_name == 'seresnet34':
-        model = se_resnet34(num_classes=num_classes, pretrained=pretrained)
+        model = seresnet34(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'seresnet50':
+        model = seresnet50(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'seresnet101':
+        model = seresnet101(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'seresnet152':
+        model = seresnet152(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'seresnext50_32x4d':
+        model = seresnext50_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'seresnext101_32x4d':
+        model = seresnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'resnext50':
+        model = resnext50(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'resnext101':
+        model = resnext101(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'resnext152':
+        model = resnext152(num_classes=num_classes, pretrained=pretrained, **kwargs)
     else:
         assert False and "Invalid model"
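
The renamed SE entry points can also be called directly; a hedged usage sketch (the models.senet import path is assumed from the relative imports above, and pretrained is truthy-gated, so None skips the weight download):

```python
from models.senet import seresnet50, seresnext50_32x4d  # import path assumed

# `if pretrained:` gates the download: None/False builds randomly
# initialized weights, 'imagenet' loads from model_urls.
model = seresnext50_32x4d(num_classes=1000, pretrained=None)
```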

@@ -9,102 +9,22 @@ import math
 import torch.nn as nn
 from torch.utils import model_zoo

-__all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152',
-           'se_resnext50_32x4d', 'se_resnext101_32x4d']
+__all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152',
+           'seresnext50_32x4d', 'seresnext101_32x4d']

-pretrained_config = {
-    'senet154': {
-        'imagenet': {
-            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
-            'input_space': 'RGB',
-            'input_size': [3, 224, 224],
-            'input_range': [0, 1],
-            'mean': [0.485, 0.456, 0.406],
-            'std': [0.229, 0.224, 0.225],
-            'num_classes': 1000
-        }
-    },
-    'se_resnet18': {
-        'imagenet': {
-            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
-            'input_space': 'RGB',
-            'input_size': [3, 224, 224],
-            'input_range': [0, 1],
-            'mean': [0.485, 0.456, 0.406],
-            'std': [0.229, 0.224, 0.225],
-            'num_classes': 1000
-        }
-    },
-    'se_resnet34': {
-        'imagenet': {
-            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
-            'input_space': 'RGB',
-            'input_size': [3, 224, 224],
-            'input_range': [0, 1],
-            'mean': [0.485, 0.456, 0.406],
-            'std': [0.229, 0.224, 0.225],
-            'num_classes': 1000
-        }
-    },
-    'se_resnet50': {
-        'imagenet': {
-            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
-            'input_space': 'RGB',
-            'input_size': [3, 224, 224],
-            'input_range': [0, 1],
-            'mean': [0.485, 0.456, 0.406],
-            'std': [0.229, 0.224, 0.225],
-            'num_classes': 1000
-        }
-    },
-    'se_resnet101': {
-        'imagenet': {
-            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
-            'input_space': 'RGB',
-            'input_size': [3, 224, 224],
-            'input_range': [0, 1],
-            'mean': [0.485, 0.456, 0.406],
-            'std': [0.229, 0.224, 0.225],
-            'num_classes': 1000
-        }
-    },
-    'se_resnet152': {
-        'imagenet': {
-            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
-            'input_space': 'RGB',
-            'input_size': [3, 224, 224],
-            'input_range': [0, 1],
-            'mean': [0.485, 0.456, 0.406],
-            'std': [0.229, 0.224, 0.225],
-            'num_classes': 1000
-        }
-    },
-    'se_resnext50_32x4d': {
-        'imagenet': {
-            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
-            'input_space': 'RGB',
-            'input_size': [3, 224, 224],
-            'input_range': [0, 1],
-            'mean': [0.485, 0.456, 0.406],
-            'std': [0.229, 0.224, 0.225],
-            'num_classes': 1000
-        }
-    },
-    'se_resnext101_32x4d': {
-        'imagenet': {
-            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
-            'input_space': 'RGB',
-            'input_size': [3, 224, 224],
-            'input_range': [0, 1],
-            'mean': [0.485, 0.456, 0.406],
-            'std': [0.229, 0.224, 0.225],
-            'num_classes': 1000
-        }
-    },
-}
+model_urls = {
+    'senet154': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
+    'seresnet18': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
+    'seresnet34': 'https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1',
+    'seresnet50': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
+    'seresnet101': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
+    'seresnet152': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
+    'seresnext50_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
+    'seresnext101_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
+}


-def _weight_init(m, n='', ll=''):
+def _weight_init(m):
     if isinstance(m, nn.Conv2d):
         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
     elif isinstance(m, nn.BatchNorm2d):

@@ -138,6 +58,7 @@ class Bottleneck(nn.Module):
     """
    Base class for bottlenecks that implements `forward()` method.
     """
+
     def forward(self, x):
         residual = x

@@ -273,7 +194,7 @@ class SEResNetBlock(nn.Module):

 class SENet(nn.Module):

     def __init__(self, block, layers, groups, reduction, dropout_p=0.2,
-                 inch=3, inplanes=128, input_3x3=True, downsample_kernel_size=3,
+                 inchans=3, inplanes=128, input_3x3=True, downsample_kernel_size=3,
                  downsample_padding=1, num_classes=1000):
         """
         Parameters

@@ -320,9 +241,10 @@ class SENet(nn.Module):
         """
         super(SENet, self).__init__()
         self.inplanes = inplanes
+        self.num_classes = num_classes
         if input_3x3:
             layer0_modules = [
-                ('conv1', nn.Conv2d(inch, 64, 3, stride=2, padding=1, bias=False)),
+                ('conv1', nn.Conv2d(inchans, 64, 3, stride=2, padding=1, bias=False)),
                 ('bn1', nn.BatchNorm2d(64)),
                 ('relu1', nn.ReLU(inplace=True)),
                 ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)),

@@ -335,7 +257,7 @@ class SENet(nn.Module):
         else:
             layer0_modules = [
                 ('conv1', nn.Conv2d(
-                    inch, inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
+                    inchans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
                 ('bn1', nn.BatchNorm2d(inplanes)),
                 ('relu1', nn.ReLU(inplace=True)),
             ]

@@ -384,7 +306,8 @@ class SENet(nn.Module):
         )
         self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
-        self.last_linear = nn.Linear(512 * block.expansion, num_classes)
+        self.num_features = 512 * block.expansion
+        self.last_linear = nn.Linear(self.num_features, num_classes)

         for m in self.modules():
             _weight_init(m)

@@ -408,19 +331,31 @@ class SENet(nn.Module):
         return nn.Sequential(*layers)

-    def forward_features(self, x):
+    def get_classifier(self):
+        return self.last_linear
+
+    def reset_classifier(self, num_classes):
+        self.num_classes = num_classes
+        del self.last_linear
+        if num_classes:
+            self.last_linear = nn.Linear(self.num_features, num_classes)
+        else:
+            self.last_linear = None
+
+    def forward_features(self, x, pool=True):
         x = self.layer0(x)
         x = self.layer1(x)
         x = self.layer2(x)
         x = self.layer3(x)
         x = self.layer4(x)
+        if pool:
+            x = self.avg_pool(x)
+            x = x.view(x.size(0), -1)
         return x

     def logits(self, x):
-        x = self.avg_pool(x)
         if self.dropout is not None:
             x = self.dropout(x)
-        x = x.view(x.size(0), -1)
         x = self.last_linear(x)
         return x
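
get_classifier and reset_classifier standardize head replacement for fine-tuning; a short sketch (import path assumed, num_features set by the constructor above):

```python
from models.senet import seresnet50  # import path assumed

model = seresnet50(num_classes=1000, pretrained=None)
head = model.get_classifier()           # nn.Linear(model.num_features, 1000)
model.reset_classifier(num_classes=10)  # fresh nn.Linear(model.num_features, 10)
model.reset_classifier(num_classes=0)   # falsy: head removed, use forward_features
```
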
@@ -430,99 +365,89 @@ class SENet(nn.Module):
         return x


-def initialize_pretrained_model(model, num_classes, config):
-    assert num_classes == config['num_classes'], \
-        'num_classes should be {}, but is {}'.format(
-            config['num_classes'], num_classes)
-    model.load_state_dict(model_zoo.load_url(config['url']))
-    model.input_space = config['input_space']
-    model.input_size = config['input_size']
-    model.input_range = config['input_range']
-    model.mean = config['mean']
-    model.std = config['std']
+def _load_pretrained(model, url, inchans=3):
+    state_dict = model_zoo.load_url(url)
+    if inchans == 1:
+        conv1_weight = state_dict['conv1.weight']
+        state_dict['conv1.weight'] = conv1_weight.sum(dim=1, keepdim=True)
+    elif inchans != 3:
+        assert False, "Invalid inchans for pretrained weights"
+    model.load_state_dict(state_dict)


-def senet154(num_classes=1000, pretrained='imagenet'):
+def senet154(num_classes=1000, inchans=3, pretrained='imagenet'):
     model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
                   dropout_p=0.2, num_classes=num_classes)
     if pretrained:
-        config = pretrained_config['senet154'][pretrained]
-        initialize_pretrained_model(model, num_classes, config)
+        _load_pretrained(model, model_urls['senet154'], inchans)
     return model


-def se_resnet18(num_classes=1000, pretrained='imagenet'):
+def seresnet18(num_classes=1000, inchans=3, pretrained='imagenet'):
     model = SENet(SEResNetBlock, [2, 2, 2, 2], groups=1, reduction=16,
                   dropout_p=None, inplanes=64, input_3x3=False,
                   downsample_kernel_size=1, downsample_padding=0,
                   num_classes=num_classes)
     if pretrained:
-        config = pretrained_config['se_resnet18'][pretrained]
-        initialize_pretrained_model(model, num_classes, config)
+        _load_pretrained(model, model_urls['seresnet18'], inchans)
     return model


-def se_resnet34(num_classes=1000, pretrained='imagenet'):
+def seresnet34(num_classes=1000, inchans=3, pretrained='imagenet'):
     model = SENet(SEResNetBlock, [3, 4, 6, 3], groups=1, reduction=16,
                   dropout_p=None, inplanes=64, input_3x3=False,
                   downsample_kernel_size=1, downsample_padding=0,
                   num_classes=num_classes)
     if pretrained:
-        config = pretrained_config['se_resnet34'][pretrained]
-        initialize_pretrained_model(model, num_classes, config)
+        _load_pretrained(model, model_urls['seresnet34'], inchans)
     return model


-def se_resnet50(num_classes=1000, pretrained='imagenet'):
+def seresnet50(num_classes=1000, inchans=3, pretrained='imagenet'):
     model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16,
                   dropout_p=None, inplanes=64, input_3x3=False,
                   downsample_kernel_size=1, downsample_padding=0,
                   num_classes=num_classes)
     if pretrained:
-        config = pretrained_config['se_resnet50'][pretrained]
-        initialize_pretrained_model(model, num_classes, config)
+        _load_pretrained(model, model_urls['seresnet50'], inchans)
     return model


-def se_resnet101(num_classes=1000, pretrained='imagenet'):
+def seresnet101(num_classes=1000, inchans=3, pretrained='imagenet'):
     model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16,
                   dropout_p=None, inplanes=64, input_3x3=False,
                   downsample_kernel_size=1, downsample_padding=0,
                   num_classes=num_classes)
     if pretrained:
-        config = pretrained_config['se_resnet101'][pretrained]
-        initialize_pretrained_model(model, num_classes, config)
+        _load_pretrained(model, model_urls['seresnet101'], inchans)
     return model


-def se_resnet152(num_classes=1000, pretrained='imagenet'):
+def seresnet152(num_classes=1000, inchans=3, pretrained='imagenet'):
     model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16,
                   dropout_p=None, inplanes=64, input_3x3=False,
                   downsample_kernel_size=1, downsample_padding=0,
                   num_classes=num_classes)
     if pretrained:
-        config = pretrained_config['se_resnet152'][pretrained]
-        initialize_pretrained_model(model, num_classes, config)
+        _load_pretrained(model, model_urls['seresnet152'], inchans)
     return model


-def se_resnext50_32x4d(num_classes=1000, pretrained='imagenet'):
+def seresnext50_32x4d(num_classes=1000, inchans=3, pretrained='imagenet'):
     model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
                   dropout_p=None, inplanes=64, input_3x3=False,
                   downsample_kernel_size=1, downsample_padding=0,
                   num_classes=num_classes)
     if pretrained:
-        config = pretrained_config['se_resnext50_32x4d'][pretrained]
-        initialize_pretrained_model(model, num_classes, config)
+        _load_pretrained(model, model_urls['seresnext50_32x4d'], inchans)
     return model


-def se_resnext101_32x4d(num_classes=1000, pretrained='imagenet'):
+def seresnext101_32x4d(num_classes=1000, inchans=3, pretrained='imagenet'):
     model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
                   dropout_p=None, inplanes=64, input_3x3=False,
                   downsample_kernel_size=1, downsample_padding=0,
                   num_classes=num_classes)
     if pretrained:
-        config = pretrained_config['se_resnext101_32x4d'][pretrained]
-        initialize_pretrained_model(model, num_classes, config)
+        _load_pretrained(model, model_urls['seresnext101_32x4d'], inchans)
     return model
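
Summing the pretrained conv1 weights over the channel dimension is a common trick for reusing RGB weights with 1-channel input: a filter's response to a gray image replicated across R, G, B equals its summed per-channel kernels applied once. A standalone sketch (the 64x3x7x7 shape follows the 7x7 stem above):

```python
import torch

conv1_weight = torch.randn(64, 3, 7, 7)            # stands in for pretrained RGB stem weights
gray_weight = conv1_weight.sum(dim=1, keepdim=True)
print(gray_weight.shape)                           # torch.Size([64, 1, 7, 7])

# Equivalence check: a gray patch replicated to RGB sees the original filter.
x = torch.randn(1, 1, 7, 7)
rgb_resp = (conv1_weight[0] * x.expand(1, 3, 7, 7)).sum()
gray_resp = (gray_weight[0] * x).sum()
print(torch.allclose(rgb_resp, gray_resp, atol=1e-5))  # True
```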

@@ -1,393 +0,0 @@ (file wrn50_2.py deleted in full)
""" Pytorch Wide-Resnet-50-2
Sourced by running https://github.com/clcarwin/convert_torch_to_pytorch (MIT) on
https://github.com/szagoruyko/wide-residual-networks/blob/master/pretrained/README.md
License of above is, as of yet, unclear.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from functools import reduce
from collections import OrderedDict
from .adaptive_avgmax_pool import *
model_urls = {
'wrn50_2': 'https://www.dropbox.com/s/fe7rj3okz9rctn0/wrn50_2-d98ded61.pth?dl=1',
}
class LambdaBase(nn.Sequential):
def __init__(self, fn, *args):
super(LambdaBase, self).__init__(*args)
self.lambda_func = fn
def forward_prepare(self, input):
output = []
for module in self._modules.values():
output.append(module(input))
return output if output else input
class Lambda(LambdaBase):
def forward(self, input):
return self.lambda_func(self.forward_prepare(input))
class LambdaMap(LambdaBase):
def forward(self, input):
return list(map(self.lambda_func, self.forward_prepare(input)))
class LambdaReduce(LambdaBase):
def forward(self, input):
return reduce(self.lambda_func, self.forward_prepare(input))
def wrn_50_2_features(activation_fn=nn.ReLU()):
features = nn.Sequential( # Sequential,
nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3), 1, 1, bias=False),
nn.BatchNorm2d(64),
activation_fn,
nn.MaxPool2d((3, 3), (2, 2), (1, 1)),
nn.Sequential( # Sequential,
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(64, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(128),
activation_fn,
nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(128),
activation_fn,
nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
),
nn.Sequential( # Sequential,
nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
),
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(128),
activation_fn,
nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(128),
activation_fn,
nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(128),
activation_fn,
nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(128),
activation_fn,
nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
),
nn.Sequential( # Sequential,
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
),
nn.Sequential( # Sequential,
nn.Conv2d(256, 512, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
),
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
),
nn.Sequential( # Sequential,
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
nn.Sequential( # Sequential,
nn.Conv2d(512, 1024, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
),
nn.Sequential( # Sequential,
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
activation_fn,
nn.Conv2d(1024, 1024, (3, 3), (2, 2), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(1024),
activation_fn,
nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(2048),
),
nn.Sequential( # Sequential,
nn.Conv2d(1024, 2048, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(2048),
),
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
activation_fn,
nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(1024),
activation_fn,
nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(2048),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
activation_fn,
nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(1024),
activation_fn,
nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(2048),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
),
)
return features
class Wrn50_2(nn.Module):
def __init__(self, num_classes=1000, activation_fn=nn.ReLU(), drop_rate=0., global_pool='avg'):
super(Wrn50_2, self).__init__()
self.drop_rate = drop_rate
self.num_classes = num_classes
self.num_features = 2048
self.global_pool = global_pool
self.features = wrn_50_2_features(activation_fn=activation_fn)
self.fc = nn.Linear(2048, num_classes)
def get_classifier(self):
return self.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = global_pool
self.fc = nn.Linear(2048, num_classes)
def forward_features(self, x, pool=True):
x = self.features(x)
if pool:
x = adaptive_avgmax_pool2d(x, self.global_pool)
x = x.view(x.size(0), -1)
return x
def forward(self, x):
x = self.forward_features(x, pool=True)
if self.drop_rate > 0:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
return x
def wrn50_2(pretrained=False, num_classes=1000, **kwargs):
model = Wrn50_2(num_classes=num_classes, **kwargs)
if pretrained:
# Remap pretrained weights to match our class module with features + fc
pretrained_weights = model_zoo.load_url(model_urls['wrn50_2'])
feature_keys = filter(lambda k: '10.1.' not in k, pretrained_weights.keys())
remapped_weights = OrderedDict()
for k in feature_keys:
remapped_weights['features.' + k] = pretrained_weights[k]
remapped_weights['fc.weight'] = pretrained_weights['10.1.weight']
remapped_weights['fc.bias'] = pretrained_weights['10.1.bias']
model.load_state_dict(remapped_weights)
return model