diff --git a/data/transforms.py b/data/transforms.py index f1222e2b..90419ae0 100644 --- a/data/transforms.py +++ b/data/transforms.py @@ -20,7 +20,7 @@ def get_model_meanstd(model_name): model_name = model_name.lower() if 'dpn' in model_name: return IMAGENET_DPN_MEAN, IMAGENET_DPN_STD - elif 'ception' in model_name: + elif 'ception' in model_name or 'nasnet' in model_name: return IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD else: return IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -30,7 +30,7 @@ def get_model_mean(model_name): model_name = model_name.lower() if 'dpn' in model_name: return IMAGENET_DPN_STD - elif 'ception' in model_name: + elif 'ception' in model_name or 'nasnet' in model_name: return IMAGENET_INCEPTION_MEAN else: return IMAGENET_DEFAULT_MEAN @@ -40,7 +40,7 @@ def get_model_std(model_name): model_name = model_name.lower() if 'dpn' in model_name: return IMAGENET_DEFAULT_STD - elif 'ception' in model_name: + elif 'ception' in model_name or 'nasnet' in model_name: return IMAGENET_INCEPTION_STD else: return IMAGENET_DEFAULT_STD diff --git a/models/model_factory.py b/models/model_factory.py index 68d6bd6d..c2a47fed 100644 --- a/models/model_factory.py +++ b/models/model_factory.py @@ -11,6 +11,7 @@ from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d from .resnext import resnext50, resnext101, resnext152 from .xception import xception +from .pnasnet import pnasnet5large model_config_dict = { 'resnet18': { @@ -47,6 +48,8 @@ model_config_dict = { 'model_name': 'inception_resnet_v2', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'}, 'xception': { 'model_name': 'xception', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'}, + 'pnasnet5large': { + 'model_name': 'pnasnet5large', 'num_classes': 1000, 'input_size': 331, 'normalizer': 'le'} } @@ -125,6 +128,8 @@ def create_model( model = resnext152(num_classes=num_classes, pretrained=pretrained, **kwargs) elif model_name == 'xception': model = xception(num_classes=num_classes, pretrained=pretrained) + elif model_name == 'pnasnet5large': + model = pnasnet5large(num_classes=num_classes, pretrained=pretrained) else: assert False and "Invalid model" diff --git a/models/pnasnet.py b/models/pnasnet.py index c169c695..6aebb772 100644 --- a/models/pnasnet.py +++ b/models/pnasnet.py @@ -5,7 +5,6 @@ import torch import torch.nn as nn import torch.utils.model_zoo as model_zoo - pretrained_settings = { 'pnasnet5large': { 'imagenet': { @@ -292,6 +291,8 @@ class PNASNet5Large(nn.Module): def __init__(self, num_classes=1001): super(PNASNet5Large, self).__init__() self.num_classes = num_classes + self.num_features = 4320 + self.conv_0 = nn.Sequential(OrderedDict([ ('conv', nn.Conv2d(3, 96, kernel_size=3, stride=2, bias=False)), ('bn', nn.BatchNorm2d(96, eps=0.001)) @@ -335,9 +336,20 @@ class PNASNet5Large(nn.Module): self.relu = nn.ReLU() self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0) self.dropout = nn.Dropout(0.5) - self.last_linear = nn.Linear(4320, num_classes) + self.last_linear = nn.Linear(self.num_features, num_classes) + + def get_classifier(self): + return self.last_linear + + def reset_classifier(self, num_classes): + self.num_classes = num_classes + del self.last_linear + if num_classes: + self.last_linear = nn.Linear(self.num_features, num_classes) + else: + self.last_linear = None - def features(self, x): + def forward_features(self, x, pool=True): x_conv_0 = self.conv_0(x) x_stem_0 = self.cell_stem_0(x_conv_0) x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0) @@ -353,19 +365,16 @@ class PNASNet5Large(nn.Module): x_cell_9 = self.cell_9(x_cell_7, x_cell_8) x_cell_10 = self.cell_10(x_cell_8, x_cell_9) x_cell_11 = self.cell_11(x_cell_9, x_cell_10) - return x_cell_11 - - def logits(self, features): - x = self.relu(features) - x = self.avg_pool(x) - x = x.view(x.size(0), -1) - x = self.dropout(x) - x = self.last_linear(x) + x = self.relu(x_cell_11) + if pool: + x = self.avg_pool(x) + x = x.view(x.size(0), -1) return x def forward(self, input): - x = self.features(input) - x = self.logits(x) + x = self.forward_features(input) + x = self.dropout(x) + x = self.last_linear(x) return x @@ -375,7 +384,7 @@ def pnasnet5large(num_classes=1001, pretrained='imagenet'): `_ paper. """ if pretrained: - settings = pretrained_settings['pnasnet5large'][pretrained] + settings = pretrained_settings['pnasnet5large']['imagenet'] assert num_classes == settings[ 'num_classes'], 'num_classes should be {}, but is {}'.format( settings['num_classes'], num_classes) @@ -384,18 +393,12 @@ def pnasnet5large(num_classes=1001, pretrained='imagenet'): model = PNASNet5Large(num_classes=1001) model.load_state_dict(model_zoo.load_url(settings['url'])) - if pretrained == 'imagenet': - new_last_linear = nn.Linear(model.last_linear.in_features, 1000) - new_last_linear.weight.data = model.last_linear.weight.data[1:] - new_last_linear.bias.data = model.last_linear.bias.data[1:] - model.last_linear = new_last_linear - - model.input_space = settings['input_space'] - model.input_size = settings['input_size'] - model.input_range = settings['input_range'] + #if pretrained == 'imagenet': + new_last_linear = nn.Linear(model.last_linear.in_features, 1000) + new_last_linear.weight.data = model.last_linear.weight.data[1:] + new_last_linear.bias.data = model.last_linear.bias.data[1:] + model.last_linear = new_last_linear - model.mean = settings['mean'] - model.std = settings['std'] else: model = PNASNet5Large(num_classes=num_classes) return model diff --git a/models/resnext.py b/models/resnext.py index aafcd93b..57cb79f8 100644 --- a/models/resnext.py +++ b/models/resnext.py @@ -142,7 +142,6 @@ def resnext50(cardinality=32, base_width=4, pretrained=False, **kwargs): Args: cardinality (int): Cardinality of the aggregated transform base_width (int): Base width of the grouped convolution - shortcut ('A'|'B'|'C'): 'B' use 1x1 conv to downsample, 'C' use 1x1 conv on every residual connection """ model = ResNeXt( ResNeXtBottleneckC, [3, 4, 6, 3], cardinality=cardinality, base_width=base_width, **kwargs) @@ -155,7 +154,6 @@ def resnext101(cardinality=32, base_width=4, pretrained=False, **kwargs): Args: cardinality (int): Cardinality of the aggregated transform base_width (int): Base width of the grouped convolution - shortcut ('A'|'B'|'C'): 'B' use 1x1 conv to downsample, 'C' use 1x1 conv on every residual connection """ model = ResNeXt( ResNeXtBottleneckC, [3, 4, 23, 3], cardinality=cardinality, base_width=base_width, **kwargs) @@ -168,7 +166,6 @@ def resnext152(cardinality=32, base_width=4, pretrained=False, **kwargs): Args: cardinality (int): Cardinality of the aggregated transform base_width (int): Base width of the grouped convolution - shortcut ('A'|'B'|'C'): 'B' use 1x1 conv to downsample, 'C' use 1x1 conv on every residual connection """ model = ResNeXt( ResNeXtBottleneckC, [3, 8, 36, 3], cardinality=cardinality, base_width=base_width, **kwargs) diff --git a/models/xception.py b/models/xception.py index c4ae09fa..97b3947d 100644 --- a/models/xception.py +++ b/models/xception.py @@ -127,6 +127,7 @@ class Xception(nn.Module): """ super(Xception, self).__init__() self.num_classes = num_classes + self.num_features = 2048 self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False) self.bn1 = nn.BatchNorm2d(32) @@ -156,10 +157,10 @@ class Xception(nn.Module): self.bn3 = nn.BatchNorm2d(1536) # do relu here - self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1) - self.bn4 = nn.BatchNorm2d(2048) + self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1) + self.bn4 = nn.BatchNorm2d(self.num_features) - self.fc = nn.Linear(2048, num_classes) + self.fc = nn.Linear(self.num_features, num_classes) # #------- init weights -------- for m in self.modules(): @@ -169,7 +170,18 @@ class Xception(nn.Module): m.weight.data.fill_(1) m.bias.data.zero_() - def forward_features(self, input): + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes): + self.num_classes = num_classes + del self.fc + if num_classes: + self.fc = nn.Linear(self.num_features, num_classes) + else: + self.fc = None + + def forward_features(self, input, pool=True): x = self.conv1(input) x = self.bn1(x) x = self.relu(x) @@ -197,19 +209,16 @@ class Xception(nn.Module): x = self.conv4(x) x = self.bn4(x) - return x - - def logits(self, features): - x = self.relu(features) + x = self.relu(x) - x = F.adaptive_avg_pool2d(x, (1, 1)) - x = x.view(x.size(0), -1) - x = self.last_linear(x) + if pool: + x = F.adaptive_avg_pool2d(x, (1, 1)) + x = x.view(x.size(0), -1) return x def forward(self, input): x = self.forward_features(input) - x = self.logits(x) + x = self.fc(x) return x @@ -223,13 +232,4 @@ def xception(num_classes=1000, pretrained=False): model = Xception(num_classes=num_classes) model.load_state_dict(model_zoo.load_url(config['url'])) - model.input_space = config['input_space'] - model.input_size = config['input_size'] - model.input_range = config['input_range'] - model.mean = config['mean'] - model.std = config['std'] - - # TODO: ugly - model.last_linear = model.fc - del model.fc return model