parent
31055466fc
commit
e0cfeb7d8e
@ -1,393 +0,0 @@
|
||||
""" Pytorch Wide-Resnet-50-2
|
||||
Sourced by running https://github.com/clcarwin/convert_torch_to_pytorch (MIT) on
|
||||
https://github.com/szagoruyko/wide-residual-networks/blob/master/pretrained/README.md
|
||||
License of above is, as of yet, unclear.
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
from functools import reduce
|
||||
from collections import OrderedDict
|
||||
from .adaptive_avgmax_pool import *
|
||||
|
||||
model_urls = {
|
||||
'wrn50_2': 'https://www.dropbox.com/s/fe7rj3okz9rctn0/wrn50_2-d98ded61.pth?dl=1',
|
||||
}
|
||||
|
||||
|
||||
class LambdaBase(nn.Sequential):
|
||||
def __init__(self, fn, *args):
|
||||
super(LambdaBase, self).__init__(*args)
|
||||
self.lambda_func = fn
|
||||
|
||||
def forward_prepare(self, input):
|
||||
output = []
|
||||
for module in self._modules.values():
|
||||
output.append(module(input))
|
||||
return output if output else input
|
||||
|
||||
|
||||
class Lambda(LambdaBase):
|
||||
def forward(self, input):
|
||||
return self.lambda_func(self.forward_prepare(input))
|
||||
|
||||
|
||||
class LambdaMap(LambdaBase):
|
||||
def forward(self, input):
|
||||
return list(map(self.lambda_func, self.forward_prepare(input)))
|
||||
|
||||
|
||||
class LambdaReduce(LambdaBase):
|
||||
def forward(self, input):
|
||||
return reduce(self.lambda_func, self.forward_prepare(input))
|
||||
|
||||
|
||||
def wrn_50_2_features(activation_fn=nn.ReLU()):
|
||||
features = nn.Sequential( # Sequential,
|
||||
nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(64),
|
||||
activation_fn,
|
||||
nn.MaxPool2d((3, 3), (2, 2), (1, 1)),
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(64, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
activation_fn,
|
||||
nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
activation_fn,
|
||||
nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(256),
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(256),
|
||||
),
|
||||
),
|
||||
LambdaReduce(lambda x, y: x + y), # CAddTable,
|
||||
activation_fn,
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
activation_fn,
|
||||
nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
activation_fn,
|
||||
nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(256),
|
||||
),
|
||||
Lambda(lambda x: x), # Identity,
|
||||
),
|
||||
LambdaReduce(lambda x, y: x + y), # CAddTable,
|
||||
activation_fn,
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
activation_fn,
|
||||
nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
activation_fn,
|
||||
nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(256),
|
||||
),
|
||||
Lambda(lambda x: x), # Identity,
|
||||
),
|
||||
LambdaReduce(lambda x, y: x + y), # CAddTable,
|
||||
activation_fn,
|
||||
),
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(256),
|
||||
activation_fn,
|
||||
nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(256),
|
||||
activation_fn,
|
||||
nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(256, 512, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
),
|
||||
),
|
||||
LambdaReduce(lambda x, y: x + y), # CAddTable,
|
||||
activation_fn,
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(256),
|
||||
activation_fn,
|
||||
nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(256),
|
||||
activation_fn,
|
||||
nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
),
|
||||
Lambda(lambda x: x), # Identity,
|
||||
),
|
||||
LambdaReduce(lambda x, y: x + y), # CAddTable,
|
||||
activation_fn,
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(256),
|
||||
activation_fn,
|
||||
nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(256),
|
||||
activation_fn,
|
||||
nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
),
|
||||
Lambda(lambda x: x), # Identity,
|
||||
),
|
||||
LambdaReduce(lambda x, y: x + y), # CAddTable,
|
||||
activation_fn,
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(256),
|
||||
activation_fn,
|
||||
nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(256),
|
||||
activation_fn,
|
||||
nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
),
|
||||
Lambda(lambda x: x), # Identity,
|
||||
),
|
||||
LambdaReduce(lambda x, y: x + y), # CAddTable,
|
||||
activation_fn,
|
||||
),
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
activation_fn,
|
||||
nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
activation_fn,
|
||||
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(1024),
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(512, 1024, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(1024),
|
||||
),
|
||||
),
|
||||
LambdaReduce(lambda x, y: x + y), # CAddTable,
|
||||
activation_fn,
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
activation_fn,
|
||||
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
activation_fn,
|
||||
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(1024),
|
||||
),
|
||||
Lambda(lambda x: x), # Identity,
|
||||
),
|
||||
LambdaReduce(lambda x, y: x + y), # CAddTable,
|
||||
activation_fn,
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
activation_fn,
|
||||
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
activation_fn,
|
||||
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(1024),
|
||||
),
|
||||
Lambda(lambda x: x), # Identity,
|
||||
),
|
||||
LambdaReduce(lambda x, y: x + y), # CAddTable,
|
||||
activation_fn,
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
activation_fn,
|
||||
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
activation_fn,
|
||||
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(1024),
|
||||
),
|
||||
Lambda(lambda x: x), # Identity,
|
||||
),
|
||||
LambdaReduce(lambda x, y: x + y), # CAddTable,
|
||||
activation_fn,
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
activation_fn,
|
||||
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
activation_fn,
|
||||
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(1024),
|
||||
),
|
||||
Lambda(lambda x: x), # Identity,
|
||||
),
|
||||
LambdaReduce(lambda x, y: x + y), # CAddTable,
|
||||
activation_fn,
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
activation_fn,
|
||||
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
activation_fn,
|
||||
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(1024),
|
||||
),
|
||||
Lambda(lambda x: x), # Identity,
|
||||
),
|
||||
LambdaReduce(lambda x, y: x + y), # CAddTable,
|
||||
activation_fn,
|
||||
),
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(1024),
|
||||
activation_fn,
|
||||
nn.Conv2d(1024, 1024, (3, 3), (2, 2), (1, 1), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(1024),
|
||||
activation_fn,
|
||||
nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(2048),
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(1024, 2048, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(2048),
|
||||
),
|
||||
),
|
||||
LambdaReduce(lambda x, y: x + y), # CAddTable,
|
||||
activation_fn,
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(1024),
|
||||
activation_fn,
|
||||
nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(1024),
|
||||
activation_fn,
|
||||
nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(2048),
|
||||
),
|
||||
Lambda(lambda x: x), # Identity,
|
||||
),
|
||||
LambdaReduce(lambda x, y: x + y), # CAddTable,
|
||||
activation_fn,
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(1024),
|
||||
activation_fn,
|
||||
nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(1024),
|
||||
activation_fn,
|
||||
nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(2048),
|
||||
),
|
||||
Lambda(lambda x: x), # Identity,
|
||||
),
|
||||
LambdaReduce(lambda x, y: x + y), # CAddTable,
|
||||
activation_fn,
|
||||
),
|
||||
),
|
||||
)
|
||||
return features
|
||||
|
||||
|
||||
class Wrn50_2(nn.Module):
|
||||
def __init__(self, num_classes=1000, activation_fn=nn.ReLU(), drop_rate=0., global_pool='avg'):
|
||||
super(Wrn50_2, self).__init__()
|
||||
self.drop_rate = drop_rate
|
||||
self.num_classes = num_classes
|
||||
self.num_features = 2048
|
||||
self.global_pool = global_pool
|
||||
self.features = wrn_50_2_features(activation_fn=activation_fn)
|
||||
self.fc = nn.Linear(2048, num_classes)
|
||||
|
||||
def get_classifier(self):
|
||||
return self.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||
self.num_classes = num_classes
|
||||
self.global_pool = global_pool
|
||||
self.fc = nn.Linear(2048, num_classes)
|
||||
|
||||
def forward_features(self, x, pool=True):
|
||||
x = self.features(x)
|
||||
if pool:
|
||||
x = adaptive_avgmax_pool2d(x, self.global_pool)
|
||||
x = x.view(x.size(0), -1)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x, pool=True)
|
||||
if self.drop_rate > 0:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
def wrn50_2(pretrained=False, num_classes=1000, **kwargs):
|
||||
model = Wrn50_2(num_classes=num_classes, **kwargs)
|
||||
if pretrained:
|
||||
# Remap pretrained weights to match our class module with features + fc
|
||||
pretrained_weights = model_zoo.load_url(model_urls['wrn50_2'])
|
||||
feature_keys = filter(lambda k: '10.1.' not in k, pretrained_weights.keys())
|
||||
remapped_weights = OrderedDict()
|
||||
for k in feature_keys:
|
||||
remapped_weights['features.' + k] = pretrained_weights[k]
|
||||
remapped_weights['fc.weight'] = pretrained_weights['10.1.weight']
|
||||
remapped_weights['fc.bias'] = pretrained_weights['10.1.bias']
|
||||
model.load_state_dict(remapped_weights)
|
||||
return model
|
Loading…
Reference in new issue