|
|
|
@ -10,7 +10,7 @@ from .dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107
|
|
|
|
|
from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \
|
|
|
|
|
seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d
|
|
|
|
|
from .resnext import resnext50, resnext101, resnext152
|
|
|
|
|
|
|
|
|
|
from .xception import xception
|
|
|
|
|
|
|
|
|
|
model_config_dict = {
|
|
|
|
|
'resnet18': {
|
|
|
|
@ -45,6 +45,8 @@ model_config_dict = {
|
|
|
|
|
'model_name': 'dpn68b', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
|
|
|
|
|
'inception_resnet_v2': {
|
|
|
|
|
'model_name': 'inception_resnet_v2', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
|
|
|
|
|
'xception': {
|
|
|
|
|
'model_name': 'xception', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -121,6 +123,8 @@ def create_model(
|
|
|
|
|
model = resnext101(num_classes=num_classes, pretrained=pretrained, **kwargs)
|
|
|
|
|
elif model_name == 'resnext152':
|
|
|
|
|
model = resnext152(num_classes=num_classes, pretrained=pretrained, **kwargs)
|
|
|
|
|
elif model_name == 'xception':
|
|
|
|
|
model = xception(num_classes=num_classes, pretrained=pretrained)
|
|
|
|
|
else:
|
|
|
|
|
assert False and "Invalid model"
|
|
|
|
|
|
|
|
|
|