Xception model working

pull/1/head
Ross Wightman 6 years ago
parent 1e23727f2f
commit 183d8e4aef

@ -10,7 +10,7 @@ from .dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107
from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \
seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d
from .resnext import resnext50, resnext101, resnext152
from .xception import xception
model_config_dict = {
'resnet18': {
@ -45,6 +45,8 @@ model_config_dict = {
'model_name': 'dpn68b', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
'inception_resnet_v2': {
'model_name': 'inception_resnet_v2', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
'xception': {
'model_name': 'xception', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
}
@ -121,6 +123,8 @@ def create_model(
model = resnext101(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'resnext152':
model = resnext152(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'xception':
model = xception(num_classes=num_classes, pretrained=pretrained)
else:
assert False and "Invalid model"

@ -162,14 +162,12 @@ class Xception(nn.Module):
self.fc = nn.Linear(2048, num_classes)
# #------- init weights --------
# for m in self.modules():
# if isinstance(m, nn.Conv2d):
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
# m.weight.data.normal_(0, math.sqrt(2. / n))
# elif isinstance(m, nn.BatchNorm2d):
# m.weight.data.fill_(1)
# m.bias.data.zero_()
# #-----------------------------
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def forward_features(self, input):
x = self.conv1(input)
@ -215,10 +213,10 @@ class Xception(nn.Module):
return x
def xception(num_classes=1000, pretrained='imagenet'):
def xception(num_classes=1000, pretrained=False):
model = Xception(num_classes=num_classes)
if pretrained:
config = pretrained_config['xception'][pretrained]
config = pretrained_config['xception']['imagenet']
assert num_classes == config['num_classes'], \
"num_classes should be {}, but is {}".format(config['num_classes'], num_classes)

Loading…
Cancel
Save