|
|
|
@ -5,7 +5,6 @@ import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
import torch.utils.model_zoo as model_zoo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pretrained_settings = {
|
|
|
|
|
'pnasnet5large': {
|
|
|
|
|
'imagenet': {
|
|
|
|
@ -292,6 +291,8 @@ class PNASNet5Large(nn.Module):
|
|
|
|
|
def __init__(self, num_classes=1001):
|
|
|
|
|
super(PNASNet5Large, self).__init__()
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
self.num_features = 4320
|
|
|
|
|
|
|
|
|
|
self.conv_0 = nn.Sequential(OrderedDict([
|
|
|
|
|
('conv', nn.Conv2d(3, 96, kernel_size=3, stride=2, bias=False)),
|
|
|
|
|
('bn', nn.BatchNorm2d(96, eps=0.001))
|
|
|
|
@ -335,9 +336,20 @@ class PNASNet5Large(nn.Module):
|
|
|
|
|
self.relu = nn.ReLU()
|
|
|
|
|
self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0)
|
|
|
|
|
self.dropout = nn.Dropout(0.5)
|
|
|
|
|
self.last_linear = nn.Linear(4320, num_classes)
|
|
|
|
|
self.last_linear = nn.Linear(self.num_features, num_classes)
|
|
|
|
|
|
|
|
|
|
def get_classifier(self):
|
|
|
|
|
return self.last_linear
|
|
|
|
|
|
|
|
|
|
def reset_classifier(self, num_classes):
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
del self.last_linear
|
|
|
|
|
if num_classes:
|
|
|
|
|
self.last_linear = nn.Linear(self.num_features, num_classes)
|
|
|
|
|
else:
|
|
|
|
|
self.last_linear = None
|
|
|
|
|
|
|
|
|
|
def features(self, x):
|
|
|
|
|
def forward_features(self, x, pool=True):
|
|
|
|
|
x_conv_0 = self.conv_0(x)
|
|
|
|
|
x_stem_0 = self.cell_stem_0(x_conv_0)
|
|
|
|
|
x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0)
|
|
|
|
@ -353,19 +365,16 @@ class PNASNet5Large(nn.Module):
|
|
|
|
|
x_cell_9 = self.cell_9(x_cell_7, x_cell_8)
|
|
|
|
|
x_cell_10 = self.cell_10(x_cell_8, x_cell_9)
|
|
|
|
|
x_cell_11 = self.cell_11(x_cell_9, x_cell_10)
|
|
|
|
|
return x_cell_11
|
|
|
|
|
|
|
|
|
|
def logits(self, features):
|
|
|
|
|
x = self.relu(features)
|
|
|
|
|
x = self.relu(x_cell_11)
|
|
|
|
|
if pool:
|
|
|
|
|
x = self.avg_pool(x)
|
|
|
|
|
x = x.view(x.size(0), -1)
|
|
|
|
|
x = self.dropout(x)
|
|
|
|
|
x = self.last_linear(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def forward(self, input):
|
|
|
|
|
x = self.features(input)
|
|
|
|
|
x = self.logits(x)
|
|
|
|
|
x = self.forward_features(input)
|
|
|
|
|
x = self.dropout(x)
|
|
|
|
|
x = self.last_linear(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -375,7 +384,7 @@ def pnasnet5large(num_classes=1001, pretrained='imagenet'):
|
|
|
|
|
<https://arxiv.org/abs/1712.00559>`_ paper.
|
|
|
|
|
"""
|
|
|
|
|
if pretrained:
|
|
|
|
|
settings = pretrained_settings['pnasnet5large'][pretrained]
|
|
|
|
|
settings = pretrained_settings['pnasnet5large']['imagenet']
|
|
|
|
|
assert num_classes == settings[
|
|
|
|
|
'num_classes'], 'num_classes should be {}, but is {}'.format(
|
|
|
|
|
settings['num_classes'], num_classes)
|
|
|
|
@ -384,18 +393,12 @@ def pnasnet5large(num_classes=1001, pretrained='imagenet'):
|
|
|
|
|
model = PNASNet5Large(num_classes=1001)
|
|
|
|
|
model.load_state_dict(model_zoo.load_url(settings['url']))
|
|
|
|
|
|
|
|
|
|
if pretrained == 'imagenet':
|
|
|
|
|
#if pretrained == 'imagenet':
|
|
|
|
|
new_last_linear = nn.Linear(model.last_linear.in_features, 1000)
|
|
|
|
|
new_last_linear.weight.data = model.last_linear.weight.data[1:]
|
|
|
|
|
new_last_linear.bias.data = model.last_linear.bias.data[1:]
|
|
|
|
|
model.last_linear = new_last_linear
|
|
|
|
|
|
|
|
|
|
model.input_space = settings['input_space']
|
|
|
|
|
model.input_size = settings['input_size']
|
|
|
|
|
model.input_range = settings['input_range']
|
|
|
|
|
|
|
|
|
|
model.mean = settings['mean']
|
|
|
|
|
model.std = settings['std']
|
|
|
|
|
else:
|
|
|
|
|
model = PNASNet5Large(num_classes=num_classes)
|
|
|
|
|
return model
|
|
|
|
|