pull/1330/merge
swenkel 1 year ago committed by GitHub
commit 58d9f09c86

@@ -0,0 +1,183 @@
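"""Export timm training checkpoints (.pth.tar) to ONNX models."""
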
import argparse
import os

import torch
import yaml
from timm.models import create_model


def argument_parser() -> dict:
"""Argument Parser
Returns
-------
config : dict
Python dict containing all settings
"""
parser = argparse.ArgumentParser()
    parser.add_argument('--single_model',
                        action='store_true',
                        help='Flag to convert a single model '
                             '(use full path + filename as '
                             '--folder_path; input dimensions '
                             '(--input_dim) and number of '
                             'classes (--num_classes) are required)')
    parser.add_argument('--model_name',
                        default=None,
                        type=str,
                        help='Model name (e.g. mobilenetv3_large_100); '
                             'required if --single_model')
    parser.add_argument('--folder_path',
                        type=str,
                        help='Path of folder containing training '
                             'output (incl. yaml files)')
    parser.add_argument('--best_model_only',
                        action='store_true',
                        help='Flag to export only model_best.pth.tar '
                             'from the training folder')
    parser.add_argument('--last_model_only',
                        action='store_true',
                        help='Flag to export only last.pth.tar '
                             'from the training folder')
    parser.add_argument('--input_dim',
                        default=None,
                        type=int,
                        help='Input shape of model assuming square '
                             'input (example: --input_dim 224); '
                             'if not provided, input dim is derived '
                             'from folder name')
parser.add_argument('--num_classes',
type=int,
default=None,
help='Number of output classes')
args = parser.parse_args()
config = {}
config['single_model'] = args.single_model
config['best_model_only'] = args.best_model_only
config['last_model_only'] = args.last_model_only
if config['best_model_only'] and config['last_model_only']:
        raise Exception('invalid options: best_model_only and '
                        'last_model_only - choose one')
config['folder_path'] = args.folder_path
if not config['single_model']:
if config['folder_path'][-1] != '/':
config['folder_path'] += '/'
if not os.path.exists(config['folder_path']):
raise Exception("Folder containing training output\
does not exist")
if not os.path.exists(config['folder_path']+'args.yaml'):
raise Exception('args.yaml does not exist in folder')
else:
# derive data from folder name
# training folder name format:
# YYYYMMDD-hh:mm:ss-modelName_epochs-inputSize
folder_name = config['folder_path'].split('/')[-2]
timm_train_data = folder_name.split('-')
config['input_dim'] = int(timm_train_data[3])
            with open(config['folder_path'] + 'args.yaml', 'r') as yaml_file:
                config['timm_cfg'] = yaml.safe_load(yaml_file)
config['timm_cfg_keys'] = tuple(config['timm_cfg'].keys())
if 'model' in config['timm_cfg_keys']:
config['model_name'] = config['timm_cfg']['model']
else:
config['model_name'] = timm_train_data[2]
files_in_folder = os.listdir(config['folder_path'])
            config['models_in_folder'] = \
                [file for file in files_in_folder
                 if file.endswith('.pth.tar')]
if len(config['models_in_folder']) == 0:
raise Exception('No models found in folder')
if config['last_model_only']:
                if 'last.pth.tar' not in config['models_in_folder']:
raise Exception('model last.pth.tar not found')
else:
config['models_in_folder'] = ['last.pth.tar']
if config['best_model_only']:
                if 'model_best.pth.tar' not in config['models_in_folder']:
raise Exception('model model_best.pth.tar not found')
else:
config['models_in_folder'] = ['model_best.pth.tar']
else:
if not os.path.exists(config['folder_path']):
raise Exception("Input model not found")
config['num_classes'] = args.num_classes
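    # If --num_classes was not given, fall back to the value stored in
    # args.yaml (folder mode); in single-model mode it must be passed
    # explicitly since there is no args.yaml to read from.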
    if not config['single_model']:
        if config['num_classes'] is None:
            if 'num_classes' in config['timm_cfg_keys']:
                if config['timm_cfg']['num_classes'] is None:
                    raise Exception('Number of classes required')
                else:
                    config['num_classes'] = config['timm_cfg']['num_classes']
            else:
                raise Exception('Number of classes required')
    else:
        if config['num_classes'] is None:
            raise Exception('Number of classes required')
if args.input_dim is None:
if config['single_model']:
raise Exception('Input dimension missing')
else:
config['input_dim'] = args.input_dim
if args.model_name is not None:
config['model_name'] = args.model_name
else:
if config['single_model']:
raise Exception('Model name missing')
    return config


def convert_model(model_in: str,
                  model_out: str,
                  config: dict):
"""Loads timm model file and converts it to onnx
Parameters
----------
model_in : str
filepath of input model
model_out : str
filepath of converted model
config : dict
Python dict containing all settings
Returns
-------
None
"""
random_input = torch.randn(1,
3,
config['input_dim'],
config['input_dim'],
requires_grad=True)
model = create_model(config['model_name'],
checkpoint_path=model_in,
exportable=True,
num_classes=config['num_classes'])
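    # eval() disables dropout and puts batch norm into inference mode so the
    # exported graph matches inference behaviour.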
model.eval()
torch.onnx.export(model,
random_input,
model_out,
export_params=True,
opset_version=13,
do_constant_folding=True,
input_names=['input'],
output_names=['output'])
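
# Optional post-export sanity check (sketch; assumes the separate onnxruntime
# package is installed - it is not imported or required by this script; the
# file name and input size below are placeholders):
#
#     import numpy as np
#     import onnxruntime as ort
#     session = ort.InferenceSession('model_best.onnx')
#     dummy = np.random.randn(1, 3, 224, 224).astype(np.float32)
#     onnx_logits = session.run(None, {'input': dummy})[0]
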
def main():
config = argument_parser()
print('=' * 72)
print('Exporting models to ONNX')
if not config['single_model']:
for model_file in config['models_in_folder']:
            print('Converting ' + model_file + ' to '
                  + model_file.split('.')[0] + '.onnx')
model_in = config['folder_path']+model_file
model_out = model_in.replace('pth.tar', 'onnx')
convert_model(model_in, model_out, config)
else:
model_in = config['folder_path']
model_out = model_in.replace('pth.tar', 'onnx')
convert_model(model_in, model_out, config)
    print('=' * 72)


if __name__ == '__main__':
main()
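
# Usage sketch (illustrative; the script filename is not shown in this diff,
# so 'onnx_export.py' below is a placeholder):
#
#   Convert every .pth.tar checkpoint in a timm training output folder:
#     python onnx_export.py --folder_path output/train/20220101-10:15:30-mobilenetv3_large_100-224/
#
#   Convert a single checkpoint (input dim and class count must be given):
#     python onnx_export.py --single_model \
#         --model_name mobilenetv3_large_100 \
#         --folder_path /path/to/model_best.pth.tar \
#         --input_dim 224 --num_classes 1000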