Add some models, remove a model, tweak some models

pull/1/head
Ross Wightman 6 years ago
parent 31055466fc
commit e0cfeb7d8e

@@ -14,21 +14,17 @@ import torch.nn as nn
import torch.nn.functional as F
def adaptive_avgmax_pool2d(x, pool_type='avg', padding=0, count_include_pad=False):
def adaptive_avgmax_pool2d(x, pool_type='avg', output_size=1):
"""Selectable global pooling function with dynamic input kernel size
"""
if pool_type == 'avgmax':
x_avg = F.avg_pool2d(
x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad)
x_max = F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding)
x_avg = F.adaptive_avg_pool2d(x, output_size)
x_max = F.adaptive_max_pool2d(x, output_size)
x = 0.5 * (x_avg + x_max)
elif pool_type == 'max':
x = F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding)
x = F.adaptive_max_pool2d(x, output_size)
else:
if pool_type != 'avg':
print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
x = F.avg_pool2d(
x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad)
x = F.adaptive_avg_pool2d(x, output_size)
return x
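For orientation (not part of the commit): with output_size=1 the adaptive calls reproduce the old explicit kernel_size=(H, W) global pooling exactly, which is what makes this a safe simplification. A minimal sketch:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 7, 5)  # arbitrary N, C, H, W

# old style: kernel sized to the full feature map
old_avg = F.avg_pool2d(x, kernel_size=(x.size(2), x.size(3)))
old_max = F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)))

# new style: adaptive pooling to a 1x1 output
new_avg = F.adaptive_avg_pool2d(x, 1)
new_max = F.adaptive_max_pool2d(x, 1)

assert torch.allclose(old_avg, new_avg)
assert torch.allclose(old_max, new_max)

# 'avgmax' halves and sums, so the combined output keeps the same scale
print((0.5 * (new_avg + new_max)).shape)  # torch.Size([2, 8, 1, 1])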

@@ -8,6 +8,7 @@ import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from collections import OrderedDict
from .adaptive_avgmax_pool import *
import re
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
@@ -20,6 +21,19 @@ model_urls = {
}
def _filter_pretrained(state_dict):
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
return state_dict
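To see what _filter_pretrained is for (a sketch, not in the commit): the model zoo checkpoints still use the legacy dotted child names ('norm.1') that this commit renames, so each matching key is rewritten to the new dotless form before loading. The sample key below is hypothetical but follows the DenseNet checkpoint layout:

import re

pattern = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
key = 'features.denseblock1.denselayer1.norm.1.weight'
m = pattern.match(key)
# group(1) ends at '...norm', group(2) starts at '1.weight'; joining them
# drops the dot between the module name and its index
print(m.group(1) + m.group(2))
# features.denseblock1.denselayer1.norm1.weight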
def densenet121(pretrained=False, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
@@ -29,7 +43,8 @@ def densenet121(pretrained=False, **kwargs):
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['densenet121']))
state_dict = model_zoo.load_url(model_urls['densenet121'])
model.load_state_dict(_filter_pretrained(state_dict))
return model
@@ -42,7 +57,8 @@ def densenet169(pretrained=False, **kwargs):
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['densenet169']))
state_dict = model_zoo.load_url(model_urls['densenet169'])
model.load_state_dict(_filter_pretrained(state_dict))
return model
@@ -55,7 +71,8 @@ def densenet201(pretrained=False, **kwargs):
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['densenet201']))
state_dict = model_zoo.load_url(model_urls['densenet201'])
model.load_state_dict(_filter_pretrained(state_dict))
return model
@@ -69,20 +86,21 @@ def densenet161(pretrained=False, **kwargs):
print(kwargs)
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['densenet161']))
state_dict = model_zoo.load_url(model_urls['densenet161'])
model.load_state_dict(_filter_pretrained(state_dict))
return model
class _DenseLayer(nn.Sequential):
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
super(_DenseLayer, self).__init__()
self.add_module('norm.1', nn.BatchNorm2d(num_input_features)),
self.add_module('relu.1', nn.ReLU(inplace=True)),
self.add_module('conv.1', nn.Conv2d(num_input_features, bn_size *
self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
self.add_module('relu1', nn.ReLU(inplace=True)),
self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
growth_rate, kernel_size=1, stride=1, bias=False)),
self.add_module('norm.2', nn.BatchNorm2d(bn_size * growth_rate)),
self.add_module('relu.2', nn.ReLU(inplace=True)),
self.add_module('conv.2', nn.Conv2d(bn_size * growth_rate, growth_rate,
self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
self.add_module('relu2', nn.ReLU(inplace=True)),
self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
kernel_size=3, stride=1, padding=1, bias=False)),
self.drop_rate = drop_rate
@@ -172,12 +190,12 @@ class DenseNet(nn.Module):
self.classifier = None
def forward_features(self, x, pool=True):
features = self.features(x)
out = F.relu(features, inplace=True)
x = self.features(x)
x = F.relu(x, inplace=True)
if pool:
out = adaptive_avgmax_pool2d(out, self.global_pool)
out = x.view(out.size(0), -1)
return out
x = adaptive_avgmax_pool2d(x, self.global_pool)
x = x.view(x.size(0), -1)
return x
def forward(self, x):
return self.classifier(self.forward_features(x, pool=True))
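A usage sketch (assumptions: the densenet121 factory above, random input) showing why forward_features takes a pool flag, so the model doubles as a feature extractor:

import torch

model = densenet121(pretrained=False)
img = torch.randn(1, 3, 224, 224)

pooled = model.forward_features(img, pool=True)   # flattened (N, num_features) vector
fmap = model.forward_features(img, pool=False)    # raw (N, C, H, W) conv feature map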

@@ -6,13 +6,13 @@ import os
from .inception_v4 import inception_v4
from .inception_resnet_v2 import inception_resnet_v2
from .wrn50_2 import wrn50_2
from .my_densenet import densenet161, densenet121, densenet169, densenet201
from .my_resnet import resnet18, resnet34, resnet50, resnet101, resnet152
from .densenet import densenet161, densenet121, densenet169, densenet201
from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152
from .fbresnet200 import fbresnet200
from .dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107
from .senet import se_resnet18, se_resnet34, se_resnet50, se_resnet101, se_resnet152,\
se_resnext50_32x4d, se_resnext101_32x4d
from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152,\
seresnext50_32x4d, seresnext101_32x4d
from .resnext import resnext50, resnext101, resnext152
model_config_dict = {
@@ -100,14 +100,28 @@ def create_model(
model = inception_resnet_v2(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'inception_v4':
model = inception_v4(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'wrn50':
model = wrn50_2(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'fbresnet200':
model = fbresnet200(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'seresnet18':
model = se_resnet18(num_classes=num_classes, pretrained=pretrained)
model = seresnet18(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'seresnet34':
model = se_resnet34(num_classes=num_classes, pretrained=pretrained)
model = seresnet34(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'seresnet50':
model = seresnet50(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'seresnet101':
model = seresnet101(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'seresnet152':
model = seresnet152(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'seresnext50_32x4d':
model = seresnext50_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'seresnext101_32x4d':
model = seresnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'resnext50':
model = resnext50(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'resnext101':
model = resnext101(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'resnext152':
model = resnext152(num_classes=num_classes, pretrained=pretrained, **kwargs)
else:
assert False, "Invalid model"
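A hedged usage sketch of the factory (not in the commit); keyword arguments are used because the full positional signature is elided in this hunk:

# model_name, num_classes and pretrained are the only parameters this
# hunk confirms; anything else rides through **kwargs
model = create_model(model_name='seresnext50_32x4d', num_classes=10, pretrained=False)
print(type(model).__name__)  # SENet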

@@ -9,102 +9,22 @@ import math
import torch.nn as nn
from torch.utils import model_zoo
__all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152',
'se_resnext50_32x4d', 'se_resnext101_32x4d']
pretrained_config = {
'senet154': {
'imagenet': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
'input_space': 'RGB',
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
},
'se_resnet18': {
'imagenet': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
'input_space': 'RGB',
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
},
'se_resnet34': {
'imagenet': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
'input_space': 'RGB',
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
},
'se_resnet50': {
'imagenet': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
'input_space': 'RGB',
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
},
'se_resnet101': {
'imagenet': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
'input_space': 'RGB',
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
},
'se_resnet152': {
'imagenet': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
'input_space': 'RGB',
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
},
'se_resnext50_32x4d': {
'imagenet': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
'input_space': 'RGB',
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
},
'se_resnext101_32x4d': {
'imagenet': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
'input_space': 'RGB',
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
},
__all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152',
'seresnext50_32x4d', 'seresnext101_32x4d']
model_urls = {
'senet154': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
'seresnet18': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
'seresnet34': 'https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1',
'seresnet50': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
'seresnet101': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
'seresnet152': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
'seresnext50_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
'seresnext101_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
}
def _weight_init(m, n='', ll=''):
def _weight_init(m):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
@@ -138,6 +58,7 @@ class Bottleneck(nn.Module):
"""
Base class for bottlenecks that implements `forward()` method.
"""
def forward(self, x):
residual = x
@@ -273,7 +194,7 @@ class SEResNetBlock(nn.Module):
class SENet(nn.Module):
def __init__(self, block, layers, groups, reduction, dropout_p=0.2,
inch=3, inplanes=128, input_3x3=True, downsample_kernel_size=3,
inchans=3, inplanes=128, input_3x3=True, downsample_kernel_size=3,
downsample_padding=1, num_classes=1000):
"""
Parameters
@@ -320,9 +241,10 @@ class SENet(nn.Module):
"""
super(SENet, self).__init__()
self.inplanes = inplanes
self.num_classes = num_classes
if input_3x3:
layer0_modules = [
('conv1', nn.Conv2d(inch, 64, 3, stride=2, padding=1, bias=False)),
('conv1', nn.Conv2d(inchans, 64, 3, stride=2, padding=1, bias=False)),
('bn1', nn.BatchNorm2d(64)),
('relu1', nn.ReLU(inplace=True)),
('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)),
@@ -335,7 +257,7 @@ class SENet(nn.Module):
else:
layer0_modules = [
('conv1', nn.Conv2d(
inch, inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
inchans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
('bn1', nn.BatchNorm2d(inplanes)),
('relu1', nn.ReLU(inplace=True)),
]
@@ -384,7 +306,8 @@ class SENet(nn.Module):
)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
self.last_linear = nn.Linear(512 * block.expansion, num_classes)
self.num_features = 512 * block.expansion
self.last_linear = nn.Linear(self.num_features, num_classes)
for m in self.modules():
_weight_init(m)
@@ -408,19 +331,31 @@ class SENet(nn.Module):
return nn.Sequential(*layers)
def forward_features(self, x):
def get_classifier(self):
return self.last_linear
def reset_classifier(self, num_classes):
self.num_classes = num_classes
del self.last_linear
if num_classes:
self.last_linear = nn.Linear(self.num_features, num_classes)
else:
self.last_linear = None
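A sketch of the new classifier hooks in a transfer-learning flow (not in the commit; seresnet50 is defined below, the 10-way head is arbitrary):

import torch

model = seresnet50(num_classes=1000, pretrained=False)
model.reset_classifier(10)                 # swap in a fresh 10-way head
x = torch.randn(1, 3, 224, 224)
feats = model.forward_features(x)          # pooled and flattened by default
out = model.get_classifier()(feats)        # (1, 10) logits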
def forward_features(self, x, pool=True):
x = self.layer0(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
if pool:
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
return x
def logits(self, x):
x = self.avg_pool(x)
if self.dropout is not None:
x = self.dropout(x)
x = x.view(x.size(0), -1)
x = self.last_linear(x)
return x
@@ -430,99 +365,89 @@ class SENet(nn.Module):
return x
def initialize_pretrained_model(model, num_classes, config):
assert num_classes == config['num_classes'], \
'num_classes should be {}, but is {}'.format(
config['num_classes'], num_classes)
model.load_state_dict(model_zoo.load_url(config['url']))
model.input_space = config['input_space']
model.input_size = config['input_size']
model.input_range = config['input_range']
model.mean = config['mean']
model.std = config['std']
def _load_pretrained(model, url, inchans=3):
state_dict = model_zoo.load_url(url)
if inchans == 1:
conv1_weight = state_dict['conv1.weight']
state_dict['conv1.weight'] = conv1_weight.sum(dim=1, keepdim=True)
elif inchans != 3:
assert False, "Invalid inchans for pretrained weights"
model.load_state_dict(state_dict)
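Why the inchans == 1 branch works (a sketch, not in the commit): convolution is linear in its input channels, so summing the pretrained RGB filters over dim=1 yields a single-channel filter with the same response to a grayscale replica of the input. With a stand-in tensor shaped like the checkpoint's conv1.weight:

import torch

conv1_weight = torch.randn(64, 3, 7, 7)              # stand-in for the checkpoint tensor
gray_weight = conv1_weight.sum(dim=1, keepdim=True)  # (64, 1, 7, 7)
print(gray_weight.shape)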
def senet154(num_classes=1000, pretrained='imagenet'):
def senet154(num_classes=1000, inchans=3, pretrained='imagenet'):
model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
dropout_p=0.2, num_classes=num_classes)
if pretrained:
config = pretrained_config['senet154'][pretrained]
initialize_pretrained_model(model, num_classes, config)
_load_pretrained(model, model_urls['senet154'], inchans)
return model
def se_resnet18(num_classes=1000, pretrained='imagenet'):
def seresnet18(num_classes=1000, inchans=3, pretrained='imagenet'):
model = SENet(SEResNetBlock, [2, 2, 2, 2], groups=1, reduction=16,
dropout_p=None, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes)
if pretrained:
config = pretrained_config['se_resnet18'][pretrained]
initialize_pretrained_model(model, num_classes, config)
_load_pretrained(model, model_urls['seresnet18'], inchans)
return model
def se_resnet34(num_classes=1000, pretrained='imagenet'):
def seresnet34(num_classes=1000, inchans=3, pretrained='imagenet'):
model = SENet(SEResNetBlock, [3, 4, 6, 3], groups=1, reduction=16,
dropout_p=None, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes)
if pretrained:
config = pretrained_config['se_resnet34'][pretrained]
initialize_pretrained_model(model, num_classes, config)
_load_pretrained(model, model_urls['seresnet34'], inchans)
return model
def se_resnet50(num_classes=1000, pretrained='imagenet'):
def seresnet50(num_classes=1000, inchans=3, pretrained='imagenet'):
model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16,
dropout_p=None, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes)
if pretrained:
config = pretrained_config['se_resnet50'][pretrained]
initialize_pretrained_model(model, num_classes, config)
_load_pretrained(model, model_urls['seresnet50'], inchans)
return model
def se_resnet101(num_classes=1000, pretrained='imagenet'):
def seresnet101(num_classes=1000, inchans=3, pretrained='imagenet'):
model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16,
dropout_p=None, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes)
if pretrained:
config = pretrained_config['se_resnet101'][pretrained]
initialize_pretrained_model(model, num_classes, config)
_load_pretrained(model, model_urls['seresnet101'], inchans)
return model
def se_resnet152(num_classes=1000, pretrained='imagenet'):
def seresnet152(num_classes=1000, inchans=3, pretrained='imagenet'):
model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16,
dropout_p=None, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes)
if pretrained:
config = pretrained_config['se_resnet152'][pretrained]
initialize_pretrained_model(model, num_classes, config)
_load_pretrained(model, model_urls['seresnet152'], inchans)
return model
def se_resnext50_32x4d(num_classes=1000, pretrained='imagenet'):
def seresnext50_32x4d(num_classes=1000, inchans=3, pretrained='imagenet'):
model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
dropout_p=None, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes)
if pretrained:
config = pretrained_config['se_resnext50_32x4d'][pretrained]
initialize_pretrained_model(model, num_classes, config)
_load_pretrained(model, model_urls['seresnext50_32x4d'], inchans)
return model
def se_resnext101_32x4d(num_classes=1000, pretrained='imagenet'):
def seresnext101_32x4d(num_classes=1000, inchans=3, pretrained='imagenet'):
model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
dropout_p=None, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes)
if pretrained:
config = pretrained_config['se_resnext101_32x4d'][pretrained]
initialize_pretrained_model(model, num_classes, config)
_load_pretrained(model, model_urls['seresnext101_32x4d'], inchans)
return model

@@ -1,393 +0,0 @@
""" Pytorch Wide-Resnet-50-2
Sourced by running https://github.com/clcarwin/convert_torch_to_pytorch (MIT) on
https://github.com/szagoruyko/wide-residual-networks/blob/master/pretrained/README.md
License of above is, as of yet, unclear.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from functools import reduce
from collections import OrderedDict
from .adaptive_avgmax_pool import *
model_urls = {
'wrn50_2': 'https://www.dropbox.com/s/fe7rj3okz9rctn0/wrn50_2-d98ded61.pth?dl=1',
}
class LambdaBase(nn.Sequential):
def __init__(self, fn, *args):
super(LambdaBase, self).__init__(*args)
self.lambda_func = fn
def forward_prepare(self, input):
output = []
for module in self._modules.values():
output.append(module(input))
return output if output else input
class Lambda(LambdaBase):
def forward(self, input):
return self.lambda_func(self.forward_prepare(input))
class LambdaMap(LambdaBase):
def forward(self, input):
return list(map(self.lambda_func, self.forward_prepare(input)))
class LambdaReduce(LambdaBase):
def forward(self, input):
return reduce(self.lambda_func, self.forward_prepare(input))
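How these helpers recreate Torch7 graph nodes (a sketch, not part of the file): LambdaMap with an identity function plays the role of ConcatTable (run every child on the same input), and LambdaReduce with addition plays CAddTable, which together express a residual connection:

import torch
import torch.nn as nn

branch = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.BatchNorm2d(8))
block = nn.Sequential(
    LambdaMap(lambda x: x, branch, Lambda(lambda x: x)),  # [branch(x), x]
    LambdaReduce(lambda x, y: x + y),                     # branch(x) + x
)
print(block(torch.randn(1, 8, 16, 16)).shape)  # torch.Size([1, 8, 16, 16])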
def wrn_50_2_features(activation_fn=nn.ReLU()):
features = nn.Sequential( # Sequential,
nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3), 1, 1, bias=False),
nn.BatchNorm2d(64),
activation_fn,
nn.MaxPool2d((3, 3), (2, 2), (1, 1)),
nn.Sequential( # Sequential,
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(64, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(128),
activation_fn,
nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(128),
activation_fn,
nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
),
nn.Sequential( # Sequential,
nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
),
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(128),
activation_fn,
nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(128),
activation_fn,
nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(128),
activation_fn,
nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(128),
activation_fn,
nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
),
nn.Sequential( # Sequential,
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
),
nn.Sequential( # Sequential,
nn.Conv2d(256, 512, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
),
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
),
nn.Sequential( # Sequential,
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
nn.Sequential( # Sequential,
nn.Conv2d(512, 1024, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
),
nn.Sequential( # Sequential,
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
activation_fn,
nn.Conv2d(1024, 1024, (3, 3), (2, 2), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(1024),
activation_fn,
nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(2048),
),
nn.Sequential( # Sequential,
nn.Conv2d(1024, 2048, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(2048),
),
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
activation_fn,
nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(1024),
activation_fn,
nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(2048),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
activation_fn,
nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(1024),
activation_fn,
nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(2048),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
),
)
return features
class Wrn50_2(nn.Module):
def __init__(self, num_classes=1000, activation_fn=nn.ReLU(), drop_rate=0., global_pool='avg'):
super(Wrn50_2, self).__init__()
self.drop_rate = drop_rate
self.num_classes = num_classes
self.num_features = 2048
self.global_pool = global_pool
self.features = wrn_50_2_features(activation_fn=activation_fn)
self.fc = nn.Linear(2048, num_classes)
def get_classifier(self):
return self.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = global_pool
self.fc = nn.Linear(2048, num_classes)
def forward_features(self, x, pool=True):
x = self.features(x)
if pool:
x = adaptive_avgmax_pool2d(x, self.global_pool)
x = x.view(x.size(0), -1)
return x
def forward(self, x):
x = self.forward_features(x, pool=True)
if self.drop_rate > 0:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
return x
def wrn50_2(pretrained=False, num_classes=1000, **kwargs):
model = Wrn50_2(num_classes=num_classes, **kwargs)
if pretrained:
# Remap pretrained weights to match our class module with features + fc
pretrained_weights = model_zoo.load_url(model_urls['wrn50_2'])
feature_keys = filter(lambda k: '10.1.' not in k, pretrained_weights.keys())
remapped_weights = OrderedDict()
for k in feature_keys:
remapped_weights['features.' + k] = pretrained_weights[k]
remapped_weights['fc.weight'] = pretrained_weights['10.1.weight']
remapped_weights['fc.bias'] = pretrained_weights['10.1.bias']
model.load_state_dict(remapped_weights)
return model
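For reference (a sketch, not in the file, with placeholder values standing in for tensors): the remap above moves every converted-checkpoint key under 'features.' except the '10.1.*' entries, which were the final Linear layer and become 'fc.*':

from collections import OrderedDict

pretrained_weights = {'0.weight': 'w0', '10.1.weight': 'wfc', '10.1.bias': 'bfc'}
remapped = OrderedDict(
    ('features.' + k, v) for k, v in pretrained_weights.items() if '10.1.' not in k)
remapped['fc.weight'] = pretrained_weights['10.1.weight']
remapped['fc.bias'] = pretrained_weights['10.1.bias']
print(list(remapped))  # ['features.0.weight', 'fc.weight', 'fc.bias']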