""" Pytorch Wide-Resnet-50-2
Sourced by running https://github.com/clcarwin/convert_torch_to_pytorch (MIT) on
https://github.com/szagoruyko/wide-residual-networks/blob/master/pretrained/README.md
License of above is, as of yet, unclear.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from functools import reduce
from collections import OrderedDict
from .adaptive_avgmax_pool import *  # provides adaptive_avgmax_pool2d
model_urls = {
'wrn50_2': 'https://www.dropbox.com/s/fe7rj3okz9rctn0/wrn50_2-d98ded61.pth?dl=1',
}
class LambdaBase(nn.Sequential):
def __init__(self, fn, *args):
super(LambdaBase, self).__init__(*args)
self.lambda_func = fn
def forward_prepare(self, input):
output = []
for module in self._modules.values():
output.append(module(input))
return output if output else input
class Lambda(LambdaBase):
def forward(self, input):
return self.lambda_func(self.forward_prepare(input))
class LambdaMap(LambdaBase):
def forward(self, input):
return list(map(self.lambda_func, self.forward_prepare(input)))
class LambdaReduce(LambdaBase):
def forward(self, input):
return reduce(self.lambda_func, self.forward_prepare(input))
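

# The Lambda wrappers above reproduce Torch7 nn container semantics emitted by
# the converter: LambdaMap runs its child modules as parallel branches on the
# same input (ConcatTable), LambdaReduce folds the branch outputs together
# (CAddTable), and Lambda applies an arbitrary function to its input. A minimal
# sketch of a residual connection expressed in these terms (illustrative only;
# `branch` is a hypothetical sub-network, not part of the model below):
#
#   residual = nn.Sequential(
#       LambdaMap(lambda x: x, branch, Lambda(lambda x: x)),  # branch + identity shortcut
#       LambdaReduce(lambda x, y: x + y),                     # elementwise sum
#   )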
def wrn_50_2_features(activation_fn=nn.ReLU()):
features = nn.Sequential( # Sequential,
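        # stem: 7x7 stride-2 conv + 3x3 stride-2 max pool (224 -> 56 spatial)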
nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3), 1, 1, bias=False),
nn.BatchNorm2d(64),
activation_fn,
nn.MaxPool2d((3, 3), (2, 2), (1, 1)),
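        # stage 1: 3 wide bottleneck blocks, 64 -> 256 channels, stride 1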
nn.Sequential( # Sequential,
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(64, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(128),
activation_fn,
nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(128),
activation_fn,
nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
),
nn.Sequential( # Sequential,
nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
),
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(128),
activation_fn,
nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(128),
activation_fn,
nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(128),
activation_fn,
nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(128),
activation_fn,
nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
),
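        # stage 2: 4 wide bottleneck blocks, 256 -> 512 channels, stride 2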
nn.Sequential( # Sequential,
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
),
nn.Sequential( # Sequential,
nn.Conv2d(256, 512, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
),
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(256),
activation_fn,
nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
),
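        # stage 3: 6 wide bottleneck blocks, 512 -> 1024 channels, stride 2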
nn.Sequential( # Sequential,
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
nn.Sequential( # Sequential,
nn.Conv2d(512, 1024, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(512),
activation_fn,
nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
),
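        # stage 4: 3 wide bottleneck blocks, 1024 -> 2048 channels, stride 2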
nn.Sequential( # Sequential,
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
activation_fn,
nn.Conv2d(1024, 1024, (3, 3), (2, 2), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(1024),
activation_fn,
nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(2048),
),
nn.Sequential( # Sequential,
nn.Conv2d(1024, 2048, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(2048),
),
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
activation_fn,
nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(1024),
activation_fn,
nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(2048),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
nn.Sequential( # Sequential,
LambdaMap(lambda x: x, # ConcatTable,
nn.Sequential( # Sequential,
nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(1024),
activation_fn,
nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
nn.BatchNorm2d(1024),
activation_fn,
nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
nn.BatchNorm2d(2048),
),
Lambda(lambda x: x), # Identity,
),
LambdaReduce(lambda x, y: x + y), # CAddTable,
activation_fn,
),
),
)
return features
class Wrn50_2(nn.Module):
def __init__(self, num_classes=1000, activation_fn=nn.ReLU(), drop_rate=0., global_pool='avg'):
super(Wrn50_2, self).__init__()
self.drop_rate = drop_rate
self.num_classes = num_classes
self.num_features = 2048
self.global_pool = global_pool
self.features = wrn_50_2_features(activation_fn=activation_fn)
self.fc = nn.Linear(2048, num_classes)
def get_classifier(self):
return self.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = global_pool
self.fc = nn.Linear(2048, num_classes)
def forward_features(self, x, pool=True):
x = self.features(x)
if pool:
x = adaptive_avgmax_pool2d(x, self.global_pool)
x = x.view(x.size(0), -1)
return x
def forward(self, x):
x = self.forward_features(x, pool=True)
if self.drop_rate > 0:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
return x
def wrn50_2(pretrained=False, num_classes=1000, **kwargs):
model = Wrn50_2(num_classes=num_classes, **kwargs)
if pretrained:
        # Remap pretrained weights to match this module's features + fc layout.
        # The converted checkpoint is a flat Sequential in which index '10.1' is
        # the final Linear classifier; all other entries belong to features.
pretrained_weights = model_zoo.load_url(model_urls['wrn50_2'])
feature_keys = filter(lambda k: '10.1.' not in k, pretrained_weights.keys())
remapped_weights = OrderedDict()
for k in feature_keys:
remapped_weights['features.' + k] = pretrained_weights[k]
remapped_weights['fc.weight'] = pretrained_weights['10.1.weight']
remapped_weights['fc.bias'] = pretrained_weights['10.1.bias']
model.load_state_dict(remapped_weights)
return model
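

if __name__ == '__main__':
    # Minimal smoke test, assuming a standard 224x224 RGB input and skipping
    # the pretrained checkpoint download. The classifier output should be
    # (1, 1000) and the unpooled feature map should have num_features (2048)
    # channels.
    model = wrn50_2(pretrained=False)
    model.eval()
    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        logits = model(dummy)
        features = model.forward_features(dummy, pool=False)
    assert logits.shape == (1, 1000)
    assert features.shape[1] == model.num_features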