From e0cfeb7d8e9ae8f226770ed68d092fd17f8a2c9a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 1 Mar 2019 13:08:35 -0800 Subject: [PATCH] Add some models, remove a model, tweak some models --- models/adaptive_avgmax_pool.py | 14 +- models/{my_densenet.py => densenet.py} | 48 ++- models/model_factory.py | 36 ++- models/{my_resnet.py => resnet.py} | 0 models/senet.py | 197 ++++--------- models/wrn50_2.py | 393 ------------------------- 6 files changed, 124 insertions(+), 564 deletions(-) rename models/{my_densenet.py => densenet.py} (81%) rename models/{my_resnet.py => resnet.py} (100%) delete mode 100644 models/wrn50_2.py diff --git a/models/adaptive_avgmax_pool.py b/models/adaptive_avgmax_pool.py index 611b05ac..01fcb4ae 100644 --- a/models/adaptive_avgmax_pool.py +++ b/models/adaptive_avgmax_pool.py @@ -14,21 +14,17 @@ import torch.nn as nn import torch.nn.functional as F -def adaptive_avgmax_pool2d(x, pool_type='avg', padding=0, count_include_pad=False): +def adaptive_avgmax_pool2d(x, pool_type='avg', output_size=1): """Selectable global pooling function with dynamic input kernel size """ if pool_type == 'avgmax': - x_avg = F.avg_pool2d( - x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad) - x_max = F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding) + x_avg = F.adaptive_avg_pool2d(x, output_size) + x_max = F.adaptive_max_pool2d(x, output_size) x = 0.5 * (x_avg + x_max) elif pool_type == 'max': - x = F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding) + x = F.adaptive_max_pool2d(x, output_size) else: - if pool_type != 'avg': - print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type) - x = F.avg_pool2d( - x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad) + x = F.adaptive_avg_pool2d(x, output_size) return x diff --git a/models/my_densenet.py b/models/densenet.py similarity index 81% rename from models/my_densenet.py rename to models/densenet.py index 1d29f574..9a63533f 100644 --- a/models/my_densenet.py +++ b/models/densenet.py @@ -8,6 +8,7 @@ import torch.nn.functional as F import torch.utils.model_zoo as model_zoo from collections import OrderedDict from .adaptive_avgmax_pool import * +import re __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] @@ -20,6 +21,19 @@ model_urls = { } +def _filter_pretrained(state_dict): + pattern = re.compile( + r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') + + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + res.group(2) + state_dict[new_key] = state_dict[key] + del state_dict[key] + return state_dict + + def densenet121(pretrained=False, **kwargs): r"""Densenet-121 model from `"Densely Connected Convolutional Networks" ` @@ -29,7 +43,8 @@ def densenet121(pretrained=False, **kwargs): """ model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs) if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['densenet121'])) + state_dict = model_zoo.load_url(model_urls['densenet121']) + model.load_state_dict(_filter_pretrained(state_dict)) return model @@ -42,7 +57,8 @@ def densenet169(pretrained=False, **kwargs): """ model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs) if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['densenet169'])) + state_dict = 
model_zoo.load_url(model_urls['densenet169']) + model.load_state_dict(_filter_pretrained(state_dict)) return model @@ -55,7 +71,8 @@ def densenet201(pretrained=False, **kwargs): """ model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs) if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['densenet201'])) + state_dict = model_zoo.load_url(model_urls['densenet201']) + model.load_state_dict(_filter_pretrained(state_dict)) return model @@ -69,20 +86,21 @@ def densenet161(pretrained=False, **kwargs): print(kwargs) model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs) if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['densenet161'])) + state_dict = model_zoo.load_url(model_urls['densenet161']) + model.load_state_dict(_filter_pretrained(state_dict)) return model class _DenseLayer(nn.Sequential): def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): super(_DenseLayer, self).__init__() - self.add_module('norm.1', nn.BatchNorm2d(num_input_features)), - self.add_module('relu.1', nn.ReLU(inplace=True)), - self.add_module('conv.1', nn.Conv2d(num_input_features, bn_size * + self.add_module('norm1', nn.BatchNorm2d(num_input_features)), + self.add_module('relu1', nn.ReLU(inplace=True)), + self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)), - self.add_module('norm.2', nn.BatchNorm2d(bn_size * growth_rate)), - self.add_module('relu.2', nn.ReLU(inplace=True)), - self.add_module('conv.2', nn.Conv2d(bn_size * growth_rate, growth_rate, + self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), + self.add_module('relu2', nn.ReLU(inplace=True)), + self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)), self.drop_rate = drop_rate @@ -172,12 +190,12 @@ class DenseNet(nn.Module): self.classifier = None def forward_features(self, x, pool=True): - features = self.features(x) - out = F.relu(features, inplace=True) + x = self.features(x) + x = F.relu(x, inplace=True) if pool: - out = adaptive_avgmax_pool2d(out, self.global_pool) - out = x.view(out.size(0), -1) - return out + x = adaptive_avgmax_pool2d(x, self.global_pool) + x = x.view(x.size(0), -1) + return x def forward(self, x): return self.classifier(self.forward_features(x, pool=True)) diff --git a/models/model_factory.py b/models/model_factory.py index e43da7bf..806b6ee2 100644 --- a/models/model_factory.py +++ b/models/model_factory.py @@ -6,13 +6,13 @@ import os from .inception_v4 import inception_v4 from .inception_resnet_v2 import inception_resnet_v2 -from .wrn50_2 import wrn50_2 -from .my_densenet import densenet161, densenet121, densenet169, densenet201 -from .my_resnet import resnet18, resnet34, resnet50, resnet101, resnet152 +from .densenet import densenet161, densenet121, densenet169, densenet201 +from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152 from .fbresnet200 import fbresnet200 from .dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107 -from .senet import se_resnet18, se_resnet34, se_resnet50, se_resnet101, se_resnet152,\ - se_resnext50_32x4d, se_resnext101_32x4d +from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152,\ + seresnext50_32x4d, seresnext101_32x4d +from .resnext import resnext50, resnext101, resnext152 model_config_dict = { @@ -99,15 +99,29 @@ def create_model( elif model_name == 'inception_resnet_v2': model = 
inception_resnet_v2(num_classes=num_classes, pretrained=pretrained, **kwargs) elif model_name == 'inception_v4': - model = inception_v4(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'wrn50': - model = wrn50_2(num_classes=num_classes, pretrained=pretrained, **kwargs) + model = inception_v4(num_classes=num_classes, pretrained=pretrained, **kwargs) elif model_name == 'fbresnet200': - model = fbresnet200(num_classes=num_classes, pretrained=pretrained, **kwargs) + model = fbresnet200(num_classes=num_classes, pretrained=pretrained, **kwargs) elif model_name == 'seresnet18': - model = se_resnet18(num_classes=num_classes, pretrained=pretrained) + model = seresnet18(num_classes=num_classes, pretrained=pretrained, **kwargs) elif model_name == 'seresnet34': - model = se_resnet34(num_classes=num_classes, pretrained=pretrained) + model = seresnet34(num_classes=num_classes, pretrained=pretrained, **kwargs) + elif model_name == 'seresnet50': + model = seresnet50(num_classes=num_classes, pretrained=pretrained, **kwargs) + elif model_name == 'seresnet101': + model = seresnet101(num_classes=num_classes, pretrained=pretrained, **kwargs) + elif model_name == 'seresnet152': + model = seresnet152(num_classes=num_classes, pretrained=pretrained, **kwargs) + elif model_name == 'seresnext50_32x4d': + model = seresnext50_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs) + elif model_name == 'seresnext101_32x4d': + model = seresnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs) + elif model_name == 'resnext50': + model = resnext50(num_classes=num_classes, pretrained=pretrained, **kwargs) + elif model_name == 'resnext101': + model = resnext101(num_classes=num_classes, pretrained=pretrained, **kwargs) + elif model_name == 'resnext152': + model = resnext152(num_classes=num_classes, pretrained=pretrained, **kwargs) else: assert False and "Invalid model" diff --git a/models/my_resnet.py b/models/resnet.py similarity index 100% rename from models/my_resnet.py rename to models/resnet.py diff --git a/models/senet.py b/models/senet.py index d16ccf62..cf169f24 100644 --- a/models/senet.py +++ b/models/senet.py @@ -9,102 +9,22 @@ import math import torch.nn as nn from torch.utils import model_zoo -__all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', - 'se_resnext50_32x4d', 'se_resnext101_32x4d'] - -pretrained_config = { - 'senet154': { - 'imagenet': { - 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth', - 'input_space': 'RGB', - 'input_size': [3, 224, 224], - 'input_range': [0, 1], - 'mean': [0.485, 0.456, 0.406], - 'std': [0.229, 0.224, 0.225], - 'num_classes': 1000 - } - }, - 'se_resnet18': { - 'imagenet': { - 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth', - 'input_space': 'RGB', - 'input_size': [3, 224, 224], - 'input_range': [0, 1], - 'mean': [0.485, 0.456, 0.406], - 'std': [0.229, 0.224, 0.225], - 'num_classes': 1000 - } - }, - 'se_resnet34': { - 'imagenet': { - 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth', - 'input_space': 'RGB', - 'input_size': [3, 224, 224], - 'input_range': [0, 1], - 'mean': [0.485, 0.456, 0.406], - 'std': [0.229, 0.224, 0.225], - 'num_classes': 1000 - } - }, - 'se_resnet50': { - 'imagenet': { - 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth', - 'input_space': 'RGB', - 'input_size': [3, 224, 224], - 'input_range': [0, 1], - 'mean': [0.485, 0.456, 0.406], - 'std': [0.229, 0.224, 0.225], - 
'num_classes': 1000 - } - }, - 'se_resnet101': { - 'imagenet': { - 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth', - 'input_space': 'RGB', - 'input_size': [3, 224, 224], - 'input_range': [0, 1], - 'mean': [0.485, 0.456, 0.406], - 'std': [0.229, 0.224, 0.225], - 'num_classes': 1000 - } - }, - 'se_resnet152': { - 'imagenet': { - 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth', - 'input_space': 'RGB', - 'input_size': [3, 224, 224], - 'input_range': [0, 1], - 'mean': [0.485, 0.456, 0.406], - 'std': [0.229, 0.224, 0.225], - 'num_classes': 1000 - } - }, - 'se_resnext50_32x4d': { - 'imagenet': { - 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth', - 'input_space': 'RGB', - 'input_size': [3, 224, 224], - 'input_range': [0, 1], - 'mean': [0.485, 0.456, 0.406], - 'std': [0.229, 0.224, 0.225], - 'num_classes': 1000 - } - }, - 'se_resnext101_32x4d': { - 'imagenet': { - 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth', - 'input_space': 'RGB', - 'input_size': [3, 224, 224], - 'input_range': [0, 1], - 'mean': [0.485, 0.456, 0.406], - 'std': [0.229, 0.224, 0.225], - 'num_classes': 1000 - } - }, +__all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152', + 'seresnext50_32x4d', 'seresnext101_32x4d'] + +model_urls = { + 'senet154': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth', + 'seresnet18': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth', + 'seresnet34': 'https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1', + 'seresnet50': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth', + 'seresnet101': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth', + 'seresnet152': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth', + 'seresnext50_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth', + 'seresnext101_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth', } -def _weight_init(m, n='', ll=''): +def _weight_init(m): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, nn.BatchNorm2d): @@ -138,6 +58,7 @@ class Bottleneck(nn.Module): """ Base class for bottlenecks that implements `forward()` method. 
""" + def forward(self, x): residual = x @@ -236,7 +157,7 @@ class SEResNeXtBottleneck(Bottleneck): class SEResNetBlock(nn.Module): expansion = 1 - + def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None): super(SEResNetBlock, self).__init__() self.conv1 = nn.Conv2d( @@ -273,7 +194,7 @@ class SEResNetBlock(nn.Module): class SENet(nn.Module): def __init__(self, block, layers, groups, reduction, dropout_p=0.2, - inch=3, inplanes=128, input_3x3=True, downsample_kernel_size=3, + inchans=3, inplanes=128, input_3x3=True, downsample_kernel_size=3, downsample_padding=1, num_classes=1000): """ Parameters @@ -320,9 +241,10 @@ class SENet(nn.Module): """ super(SENet, self).__init__() self.inplanes = inplanes + self.num_classes = num_classes if input_3x3: layer0_modules = [ - ('conv1', nn.Conv2d(inch, 64, 3, stride=2, padding=1, bias=False)), + ('conv1', nn.Conv2d(inchans, 64, 3, stride=2, padding=1, bias=False)), ('bn1', nn.BatchNorm2d(64)), ('relu1', nn.ReLU(inplace=True)), ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)), @@ -335,7 +257,7 @@ class SENet(nn.Module): else: layer0_modules = [ ('conv1', nn.Conv2d( - inch, inplanes, kernel_size=7, stride=2, padding=3, bias=False)), + inchans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)), ('bn1', nn.BatchNorm2d(inplanes)), ('relu1', nn.ReLU(inplace=True)), ] @@ -384,7 +306,8 @@ class SENet(nn.Module): ) self.avg_pool = nn.AdaptiveAvgPool2d(1) self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None - self.last_linear = nn.Linear(512 * block.expansion, num_classes) + self.num_features = 512 * block.expansion + self.last_linear = nn.Linear(self.num_features, num_classes) for m in self.modules(): _weight_init(m) @@ -408,19 +331,31 @@ class SENet(nn.Module): return nn.Sequential(*layers) - def forward_features(self, x): + def get_classifier(self): + return self.last_linear + + def reset_classifier(self, num_classes): + self.num_classes = num_classes + del self.last_linear + if num_classes: + self.last_linear = nn.Linear(self.num_features, num_classes) + else: + self.last_linear = None + + def forward_features(self, x, pool=True): x = self.layer0(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) + if pool: + x = self.avg_pool(x) + x = x.view(x.size(0), -1) return x def logits(self, x): - x = self.avg_pool(x) if self.dropout is not None: x = self.dropout(x) - x = x.view(x.size(0), -1) x = self.last_linear(x) return x @@ -430,99 +365,89 @@ class SENet(nn.Module): return x -def initialize_pretrained_model(model, num_classes, config): - assert num_classes == config['num_classes'], \ - 'num_classes should be {}, but is {}'.format( - config['num_classes'], num_classes) - model.load_state_dict(model_zoo.load_url(config['url'])) - model.input_space = config['input_space'] - model.input_size = config['input_size'] - model.input_range = config['input_range'] - model.mean = config['mean'] - model.std = config['std'] - +def _load_pretrained(model, url, inchans=3): + state_dict = model_zoo.load_url(url) + if inchans == 1: + conv1_weight = state_dict['conv1.weight'] + state_dict['conv1.weight'] = conv1_weight.sum(dim=1, keepdim=True) + elif inchans != 3: + assert False, "Invalid inchans for pretrained weights" + model.load_state_dict(state_dict) + -def senet154(num_classes=1000, pretrained='imagenet'): +def senet154(num_classes=1000, inchans=3, pretrained='imagenet'): model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16, dropout_p=0.2, 
num_classes=num_classes) if pretrained: - config = pretrained_config['senet154'][pretrained] - initialize_pretrained_model(model, num_classes, config) + _load_pretrained(model, model_urls['senet154'], inchans) return model -def se_resnet18(num_classes=1000, pretrained='imagenet'): +def seresnet18(num_classes=1000, inchans=3, pretrained='imagenet'): model = SENet(SEResNetBlock, [2, 2, 2, 2], groups=1, reduction=16, dropout_p=None, inplanes=64, input_3x3=False, downsample_kernel_size=1, downsample_padding=0, num_classes=num_classes) if pretrained: - config = pretrained_config['se_resnet18'][pretrained] - initialize_pretrained_model(model, num_classes, config) + _load_pretrained(model, model_urls['seresnet18'], inchans) return model -def se_resnet34(num_classes=1000, pretrained='imagenet'): +def seresnet34(num_classes=1000, inchans=3, pretrained='imagenet'): model = SENet(SEResNetBlock, [3, 4, 6, 3], groups=1, reduction=16, dropout_p=None, inplanes=64, input_3x3=False, downsample_kernel_size=1, downsample_padding=0, num_classes=num_classes) if pretrained: - config = pretrained_config['se_resnet34'][pretrained] - initialize_pretrained_model(model, num_classes, config) + _load_pretrained(model, model_urls['seresnet34'], inchans) return model -def se_resnet50(num_classes=1000, pretrained='imagenet'): +def seresnet50(num_classes=1000, inchans=3, pretrained='imagenet'): model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16, dropout_p=None, inplanes=64, input_3x3=False, downsample_kernel_size=1, downsample_padding=0, num_classes=num_classes) if pretrained: - config = pretrained_config['se_resnet50'][pretrained] - initialize_pretrained_model(model, num_classes, config) + _load_pretrained(model, model_urls['seresnet50'], inchans) return model -def se_resnet101(num_classes=1000, pretrained='imagenet'): +def seresnet101(num_classes=1000, inchans=3, pretrained='imagenet'): model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16, dropout_p=None, inplanes=64, input_3x3=False, downsample_kernel_size=1, downsample_padding=0, num_classes=num_classes) if pretrained: - config = pretrained_config['se_resnet101'][pretrained] - initialize_pretrained_model(model, num_classes, config) + _load_pretrained(model, model_urls['seresnet101'], inchans) return model -def se_resnet152(num_classes=1000, pretrained='imagenet'): +def seresnet152(num_classes=1000, inchans=3, pretrained='imagenet'): model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16, dropout_p=None, inplanes=64, input_3x3=False, downsample_kernel_size=1, downsample_padding=0, num_classes=num_classes) if pretrained: - config = pretrained_config['se_resnet152'][pretrained] - initialize_pretrained_model(model, num_classes, config) + _load_pretrained(model, model_urls['seresnet152'], inchans) return model -def se_resnext50_32x4d(num_classes=1000, pretrained='imagenet'): +def seresnext50_32x4d(num_classes=1000, inchans=3, pretrained='imagenet'): model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16, dropout_p=None, inplanes=64, input_3x3=False, downsample_kernel_size=1, downsample_padding=0, num_classes=num_classes) if pretrained: - config = pretrained_config['se_resnext50_32x4d'][pretrained] - initialize_pretrained_model(model, num_classes, config) + _load_pretrained(model, model_urls['seresnext50_32x4d'], inchans) return model -def se_resnext101_32x4d(num_classes=1000, pretrained='imagenet'): +def seresnext101_32x4d(num_classes=1000, inchans=3, pretrained='imagenet'): model = 
SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16, dropout_p=None, inplanes=64, input_3x3=False, downsample_kernel_size=1, downsample_padding=0, num_classes=num_classes) if pretrained: - config = pretrained_config['se_resnext101_32x4d'][pretrained] - initialize_pretrained_model(model, num_classes, config) + _load_pretrained(model, model_urls['seresnext101_32x4d'], inchans) return model diff --git a/models/wrn50_2.py b/models/wrn50_2.py deleted file mode 100644 index 63274fd2..00000000 --- a/models/wrn50_2.py +++ /dev/null @@ -1,393 +0,0 @@ -""" Pytorch Wide-Resnet-50-2 -Sourced by running https://github.com/clcarwin/convert_torch_to_pytorch (MIT) on -https://github.com/szagoruyko/wide-residual-networks/blob/master/pretrained/README.md -License of above is, as of yet, unclear. -""" -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.model_zoo as model_zoo -from functools import reduce -from collections import OrderedDict -from .adaptive_avgmax_pool import * - -model_urls = { - 'wrn50_2': 'https://www.dropbox.com/s/fe7rj3okz9rctn0/wrn50_2-d98ded61.pth?dl=1', -} - - -class LambdaBase(nn.Sequential): - def __init__(self, fn, *args): - super(LambdaBase, self).__init__(*args) - self.lambda_func = fn - - def forward_prepare(self, input): - output = [] - for module in self._modules.values(): - output.append(module(input)) - return output if output else input - - -class Lambda(LambdaBase): - def forward(self, input): - return self.lambda_func(self.forward_prepare(input)) - - -class LambdaMap(LambdaBase): - def forward(self, input): - return list(map(self.lambda_func, self.forward_prepare(input))) - - -class LambdaReduce(LambdaBase): - def forward(self, input): - return reduce(self.lambda_func, self.forward_prepare(input)) - - -def wrn_50_2_features(activation_fn=nn.ReLU()): - features = nn.Sequential( # Sequential, - nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3), 1, 1, bias=False), - nn.BatchNorm2d(64), - activation_fn, - nn.MaxPool2d((3, 3), (2, 2), (1, 1)), - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Conv2d(64, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(256), - ), - nn.Sequential( # Sequential, - nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(256), - ), - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - activation_fn, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(256), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - activation_fn, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, 
bias=False), - nn.BatchNorm2d(256), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - activation_fn, - ), - ), - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), 1, 1, bias=False), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(512), - ), - nn.Sequential( # Sequential, - nn.Conv2d(256, 512, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(512), - ), - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - activation_fn, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(512), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - activation_fn, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(512), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - activation_fn, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(512), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - activation_fn, - ), - ), - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), 1, 1, bias=False), - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(1024), - ), - nn.Sequential( # Sequential, - nn.Conv2d(512, 1024, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(1024), - ), - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - activation_fn, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - activation_fn, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda 
x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - activation_fn, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - activation_fn, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - activation_fn, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - activation_fn, - ), - ), - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 1024, (3, 3), (2, 2), (1, 1), 1, 1, bias=False), - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(2048), - ), - nn.Sequential( # Sequential, - nn.Conv2d(1024, 2048, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(2048), - ), - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - activation_fn, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(2048), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - activation_fn, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 2048, (1, 1), 
(1, 1), (0, 0), 1, 1, bias=False), - nn.BatchNorm2d(2048), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - activation_fn, - ), - ), - ) - return features - - -class Wrn50_2(nn.Module): - def __init__(self, num_classes=1000, activation_fn=nn.ReLU(), drop_rate=0., global_pool='avg'): - super(Wrn50_2, self).__init__() - self.drop_rate = drop_rate - self.num_classes = num_classes - self.num_features = 2048 - self.global_pool = global_pool - self.features = wrn_50_2_features(activation_fn=activation_fn) - self.fc = nn.Linear(2048, num_classes) - - def get_classifier(self): - return self.fc - - def reset_classifier(self, num_classes, global_pool='avg'): - self.num_classes = num_classes - self.global_pool = global_pool - self.fc = nn.Linear(2048, num_classes) - - def forward_features(self, x, pool=True): - x = self.features(x) - if pool: - x = adaptive_avgmax_pool2d(x, self.global_pool) - x = x.view(x.size(0), -1) - return x - - def forward(self, x): - x = self.forward_features(x, pool=True) - if self.drop_rate > 0: - x = F.dropout(x, p=self.drop_rate, training=self.training) - x = self.fc(x) - return x - - -def wrn50_2(pretrained=False, num_classes=1000, **kwargs): - model = Wrn50_2(num_classes=num_classes, **kwargs) - if pretrained: - # Remap pretrained weights to match our class module with features + fc - pretrained_weights = model_zoo.load_url(model_urls['wrn50_2']) - feature_keys = filter(lambda k: '10.1.' not in k, pretrained_weights.keys()) - remapped_weights = OrderedDict() - for k in feature_keys: - remapped_weights['features.' + k] = pretrained_weights[k] - remapped_weights['fc.weight'] = pretrained_weights['10.1.weight'] - remapped_weights['fc.bias'] = pretrained_weights['10.1.bias'] - model.load_state_dict(remapped_weights) - return model \ No newline at end of file
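
The adaptive_avgmax_pool2d rewrite at the top of this patch swaps the manual kernel-size arithmetic (F.avg_pool2d with kernel_size=(x.size(2), x.size(3)) plus padding/count_include_pad plumbing) for F.adaptive_avg_pool2d / F.adaptive_max_pool2d, so any input resolution collapses to a fixed output_size. A minimal sketch of how a caller exercises each pool_type (the tensor shape is illustrative, not taken from the patch):

import torch
import torch.nn.functional as F

def adaptive_avgmax_pool2d(x, pool_type='avg', output_size=1):
    # Mirrors the patched helper: 'avgmax' averages the avg- and max-pooled maps,
    # and any unrecognized pool_type silently falls back to average pooling.
    if pool_type == 'avgmax':
        x = 0.5 * (F.adaptive_avg_pool2d(x, output_size) +
                   F.adaptive_max_pool2d(x, output_size))
    elif pool_type == 'max':
        x = F.adaptive_max_pool2d(x, output_size)
    else:
        x = F.adaptive_avg_pool2d(x, output_size)
    return x

x = torch.randn(2, 256, 7, 7)
for pool_type in ('avg', 'max', 'avgmax'):
    print(pool_type, tuple(adaptive_avgmax_pool2d(x, pool_type).shape))
    # each prints: <pool_type> (2, 256, 1, 1)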
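
The _filter_pretrained helper added to densenet.py exists because the torchvision checkpoints behind model_urls were saved when dense layers registered dotted child names ('norm.1', 'conv.2'), while this patch renames those modules to 'norm1'/'conv2'; the regex strips the inner dot so load_state_dict finds matching keys. A standalone sketch using the same pattern, with a made-up state dict (the keys are illustrative, not from a real download):

import re
from collections import OrderedDict

# Same pattern as _filter_pretrained in densenet.py.
pattern = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

state_dict = OrderedDict([
    ('features.denseblock1.denselayer1.norm.1.weight', 0),
    ('features.denseblock1.denselayer1.conv.2.weight', 1),
    ('classifier.weight', 2),
])

for key in list(state_dict):
    res = pattern.match(key)
    if res:
        # 'norm.1.weight' -> 'norm1.weight', 'conv.2.weight' -> 'conv2.weight'
        state_dict[res.group(1) + res.group(2)] = state_dict.pop(key)

print(list(state_dict))  # 'classifier.weight' is untouched; remapped keys follow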
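
_load_pretrained in senet.py takes an inchans argument and, for single-channel inputs, folds the first convolution's RGB filters into one channel by summing over the input-channel dimension; a grayscale image that would otherwise be replicated across R, G, and B then produces the same stem activations as before. A sketch of just that weight transform, assuming a 7x7 stem weight like the seresnet variants use (senet154's 3x3 stem would only change the kernel shape):

import torch

conv1_weight = torch.randn(64, 3, 7, 7)  # illustrative (out, in=3, k, k) stem weight
gray_weight = conv1_weight.sum(dim=1, keepdim=True)
print(gray_weight.shape)  # torch.Size([64, 1, 7, 7])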
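
The new get_classifier / reset_classifier / forward_features(pool=...) methods bring SENet in line with the head-swapping convention the other model classes in this repo expose, so downstream code can replace the final layer for fine-tuning or pull pooled embeddings without touching last_linear directly. A hypothetical fine-tuning sketch (the module path and pretrained=False flag are assumptions for illustration):

import torch
from models.senet import seresnet50

model = seresnet50(num_classes=1000, pretrained=False)
old_head = model.get_classifier()        # the 1000-way nn.Linear
model.reset_classifier(num_classes=10)   # fresh nn.Linear(model.num_features, 10)
emb = model.forward_features(torch.randn(1, 3, 224, 224), pool=True)
print(emb.shape)                         # torch.Size([1, 2048]), i.e. (1, model.num_features)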
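
With the model_factory additions, every SE and ResNeXt variant is reachable by name and now receives **kwargs uniformly. One caveat visible in the dispatch: the fallback `assert False and "Invalid model"` evaluates to a plain False, so an unknown name raises a bare AssertionError without the message (`assert False, "Invalid model"` would attach it). A hypothetical call, using only the parameters that appear in the diff since create_model's full signature is not shown above:

from models.model_factory import create_model

model = create_model('seresnext50_32x4d', num_classes=1000, pretrained=False)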