You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pytorch-image-models/models/model_factory.py

156 lines
7.6 KiB

import torch
import os
from collections import OrderedDict
from .inception_v4 import inception_v4
from .inception_resnet_v2 import inception_resnet_v2
from .densenet import densenet161, densenet121, densenet169, densenet201
from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152, \
resnext50_32x4d, resnext101_32x4d, resnext101_64x4d, resnext152_32x4d
from .dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107
from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \
seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d
#from .resnext import resnext50, resnext101, resnext152
from .xception import xception
from .pnasnet import pnasnet5large
# Per-configuration settings, keyed by configuration name.
#
# Every entry is a fresh dict with the keys 'model_name', 'num_classes',
# 'input_size' and 'normalizer'. All configurations use 1000 classes. The
# '*_extra' keys reuse the 'model_name' of their base entry.
#
# Rows below are: (config key, model_name, input_size, normalizer).
model_config_dict = {
    config_key: {
        'model_name': arch,
        'num_classes': 1000,
        'input_size': side,
        'normalizer': norm,
    }
    for config_key, arch, side, norm in (
        ('resnet18', 'resnet18', 224, 'tv'),
        ('resnet34', 'resnet34', 224, 'tv'),
        ('resnet50', 'resnet50', 224, 'tv'),
        ('resnet101', 'resnet101', 224, 'tv'),
        ('resnet152', 'resnet152', 224, 'tv'),
        ('densenet121', 'densenet121', 224, 'tv'),
        ('densenet169', 'densenet169', 224, 'tv'),
        ('densenet201', 'densenet201', 224, 'tv'),
        ('densenet161', 'densenet161', 224, 'tv'),
        ('dpn107', 'dpn107', 299, 'dpn'),
        ('dpn92_extra', 'dpn92', 299, 'dpn'),
        ('dpn92', 'dpn92', 299, 'dpn'),
        ('dpn68', 'dpn68', 299, 'dpn'),
        ('dpn68b', 'dpn68b', 299, 'dpn'),
        ('dpn68b_extra', 'dpn68b', 299, 'dpn'),
        ('inception_resnet_v2', 'inception_resnet_v2', 299, 'le'),
        ('xception', 'xception', 299, 'le'),
        ('pnasnet5large', 'pnasnet5large', 331, 'le'),
    )
}
def create_model(
        model_name='resnet50',
        pretrained=None,
        num_classes=1000,
        checkpoint_path='',
        **kwargs):
    """Instantiate a model by name and optionally load a checkpoint into it.

    Args:
        model_name: Name of the architecture to build; must match one of the
            names handled below.
        pretrained: Forwarded unchanged to the model constructor. When it is
            truthy, ``checkpoint_path`` is ignored.
        num_classes: Forwarded to the model constructor.
        checkpoint_path: If non-empty and ``pretrained`` is falsy, the path is
            printed and passed to ``load_checkpoint``. The result of that call
            is not checked, so a missing file still yields a model.
        **kwargs: Extra keyword arguments forwarded to the resnet, densenet,
            inception, seresnet/seresnext and resnext constructors. They are
            NOT forwarded to the dpn, xception or pnasnet5large constructors.

    Returns:
        The constructed model.

    Raises:
        AssertionError: If ``model_name`` is not recognised.
    """
    if model_name == 'dpn68':
        model = dpn68(num_classes=num_classes, pretrained=pretrained)
    elif model_name == 'dpn68b':
        model = dpn68b(num_classes=num_classes, pretrained=pretrained)
    elif model_name == 'dpn92':
        model = dpn92(num_classes=num_classes, pretrained=pretrained)
    elif model_name == 'dpn98':
        model = dpn98(num_classes=num_classes, pretrained=pretrained)
    elif model_name == 'dpn131':
        model = dpn131(num_classes=num_classes, pretrained=pretrained)
    elif model_name == 'dpn107':
        model = dpn107(num_classes=num_classes, pretrained=pretrained)
    elif model_name == 'resnet18':
        model = resnet18(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'resnet34':
        model = resnet34(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'resnet50':
        model = resnet50(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'resnet101':
        model = resnet101(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'resnet152':
        model = resnet152(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'densenet121':
        model = densenet121(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'densenet161':
        model = densenet161(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'densenet169':
        model = densenet169(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'densenet201':
        model = densenet201(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'inception_resnet_v2':
        model = inception_resnet_v2(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'inception_v4':
        model = inception_v4(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'seresnet18':
        model = seresnet18(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'seresnet34':
        model = seresnet34(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'seresnet50':
        model = seresnet50(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'seresnet101':
        model = seresnet101(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'seresnet152':
        model = seresnet152(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'seresnext26_32x4d':
        model = seresnext26_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'seresnext50_32x4d':
        model = seresnext50_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'seresnext101_32x4d':
        model = seresnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'resnext50_32x4d':
        model = resnext50_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'resnext101_32x4d':
        model = resnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'resnext101_64x4d':
        # Fixed: this branch previously built resnext101_32x4d by mistake.
        model = resnext101_64x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'resnext152_32x4d':
        model = resnext152_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'xception':
        model = xception(num_classes=num_classes, pretrained=pretrained)
    elif model_name == 'pnasnet5large':
        model = pnasnet5large(num_classes=num_classes, pretrained=pretrained)
    else:
        # Raised explicitly rather than via `assert`, which is stripped under
        # `python -O` and would let execution reach an unbound `model`.
        # AssertionError is kept so existing callers see the same type.
        raise AssertionError("Invalid model: {!r}".format(model_name))

    # An explicit checkpoint is only applied when pretrained weights were not
    # requested from the constructor.
    if checkpoint_path and not pretrained:
        print(checkpoint_path)
        load_checkpoint(model, checkpoint_path)

    return model
def load_checkpoint(model, checkpoint_path):
    """Load weights from a checkpoint file into ``model``.

    The file may contain either a bare state dict or a dict holding one under
    the ``'state_dict'`` key. In the latter case a leading ``'module.'`` on
    any key is removed before loading.

    Args:
        model: Object exposing ``load_state_dict``.
        checkpoint_path: Path of the checkpoint file.

    Returns:
        True if the checkpoint was loaded, False if ``checkpoint_path`` is
        empty or does not point to an existing file.
    """
    if not (checkpoint_path and os.path.isfile(checkpoint_path)):
        print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
        return False

    print("=> Loading checkpoint '{}'".format(checkpoint_path))
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from sources you trust.
    # map_location='cpu' lets checkpoints saved from GPU load on hosts without
    # one; load_state_dict then copies the values into the model's own tensors.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        # Strip only the exact 'module.' prefix. The old check matched any key
        # beginning with 'module' and then cut 7 characters, which corrupted
        # names such as 'module_a.weight'.
        prefix = 'module.'
        new_state_dict = OrderedDict()
        for k, v in checkpoint['state_dict'].items():
            name = k[len(prefix):] if k.startswith(prefix) else k
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
    else:
        model.load_state_dict(checkpoint)

    print("=> Loaded checkpoint '{}'".format(checkpoint_path))
    return True