diff --git a/models/model_factory.py b/models/model_factory.py index a40c7638..68d6bd6d 100644 --- a/models/model_factory.py +++ b/models/model_factory.py @@ -10,7 +10,7 @@ from .dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107 from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \ seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d from .resnext import resnext50, resnext101, resnext152 - +from .xception import xception model_config_dict = { 'resnet18': { @@ -45,6 +45,8 @@ model_config_dict = { 'model_name': 'dpn68b', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'}, 'inception_resnet_v2': { 'model_name': 'inception_resnet_v2', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'}, + 'xception': { + 'model_name': 'xception', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'}, } @@ -121,6 +123,8 @@ def create_model( model = resnext101(num_classes=num_classes, pretrained=pretrained, **kwargs) elif model_name == 'resnext152': model = resnext152(num_classes=num_classes, pretrained=pretrained, **kwargs) + elif model_name == 'xception': + model = xception(num_classes=num_classes, pretrained=pretrained) else: assert False and "Invalid model" diff --git a/models/xception.py b/models/xception.py index 8aca27d8..c4ae09fa 100644 --- a/models/xception.py +++ b/models/xception.py @@ -162,14 +162,12 @@ class Xception(nn.Module): self.fc = nn.Linear(2048, num_classes) # #------- init weights -------- - # for m in self.modules(): - # if isinstance(m, nn.Conv2d): - # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - # m.weight.data.normal_(0, math.sqrt(2. / n)) - # elif isinstance(m, nn.BatchNorm2d): - # m.weight.data.fill_(1) - # m.bias.data.zero_() - # #----------------------------- + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() def forward_features(self, input): x = self.conv1(input) @@ -215,10 +213,10 @@ class Xception(nn.Module): return x -def xception(num_classes=1000, pretrained='imagenet'): +def xception(num_classes=1000, pretrained=False): model = Xception(num_classes=num_classes) if pretrained: - config = pretrained_config['xception'][pretrained] + config = pretrained_config['xception']['imagenet'] assert num_classes == config['num_classes'], \ "num_classes should be {}, but is {}".format(config['num_classes'], num_classes)