Merge 16664ea31b
into 709d5e0d9d
commit
58d9f09c86
@ -0,0 +1,183 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
from timm.models import create_model
|
||||||
|
import torch
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
def argument_parser() -> dict:
|
||||||
|
"""Argument Parser
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
config : dict
|
||||||
|
Python dict containing all settings
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--single_model',
|
||||||
|
action='store_true',
|
||||||
|
help='Flag to converts a single model \
|
||||||
|
(use full path + filename as \
|
||||||
|
--folder_path; input dimensions \
|
||||||
|
(--input_dim) and number of \
|
||||||
|
classes (--num_classes) are required)')
|
||||||
|
parser.add_argument('--model_name',
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help='model name (e.g. mobilenetv3_large_100)\
|
||||||
|
required if single_model')
|
||||||
|
parser.add_argument('--folder_path',
|
||||||
|
type=str,
|
||||||
|
help='Path of folder containing training \
|
||||||
|
output (incl. yaml files)')
|
||||||
|
parser.add_argument('--best_model_only',
|
||||||
|
action='store_true',
|
||||||
|
help='Flag to export model_best.pth.tar in \
|
||||||
|
training folder only')
|
||||||
|
parser.add_argument('--last_model_only',
|
||||||
|
action='store_true',
|
||||||
|
help='Flag to export last.pth.tar in \
|
||||||
|
training folder only')
|
||||||
|
parser.add_argument('--input_dim',
|
||||||
|
default=None,
|
||||||
|
type=int,
|
||||||
|
help='Input shape of model assuming square\
|
||||||
|
input (example: --input_dim 224); \
|
||||||
|
if not provided, input dim is derived\
|
||||||
|
from folder name')
|
||||||
|
parser.add_argument('--num_classes',
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help='Number of output classes')
|
||||||
|
args = parser.parse_args()
|
||||||
|
config = {}
|
||||||
|
config['single_model'] = args.single_model
|
||||||
|
config['best_model_only'] = args.best_model_only
|
||||||
|
config['last_model_only'] = args.last_model_only
|
||||||
|
if config['best_model_only'] and config['last_model_only']:
|
||||||
|
raise Exception('invalid options: best_model_only and\
|
||||||
|
last_model_only - choose one')
|
||||||
|
config['folder_path'] = args.folder_path
|
||||||
|
|
||||||
|
if not config['single_model']:
|
||||||
|
if config['folder_path'][-1] != '/':
|
||||||
|
config['folder_path'] += '/'
|
||||||
|
if not os.path.exists(config['folder_path']):
|
||||||
|
raise Exception("Folder containing training output\
|
||||||
|
does not exist")
|
||||||
|
if not os.path.exists(config['folder_path']+'args.yaml'):
|
||||||
|
raise Exception('args.yaml does not exist in folder')
|
||||||
|
else:
|
||||||
|
# derive data from folder name
|
||||||
|
# training folder name format:
|
||||||
|
# YYYYMMDD-hh:mm:ss-modelName_epochs-inputSize
|
||||||
|
folder_name = config['folder_path'].split('/')[-2]
|
||||||
|
timm_train_data = folder_name.split('-')
|
||||||
|
config['input_dim'] = int(timm_train_data[3])
|
||||||
|
config['timm_cfg'] = yaml.safe_load(open(
|
||||||
|
config['folder_path']+'args.yaml','r'))
|
||||||
|
config['timm_cfg_keys'] = tuple(config['timm_cfg'].keys())
|
||||||
|
if 'model' in config['timm_cfg_keys']:
|
||||||
|
config['model_name'] = config['timm_cfg']['model']
|
||||||
|
else:
|
||||||
|
config['model_name'] = timm_train_data[2]
|
||||||
|
files_in_folder = os.listdir(config['folder_path'])
|
||||||
|
config['models_in_folder'] = \
|
||||||
|
[file for file in files_in_folder if '.pth.tar' in file]
|
||||||
|
if len(config['models_in_folder']) == 0:
|
||||||
|
raise Exception('No models found in folder')
|
||||||
|
if config['last_model_only']:
|
||||||
|
if not 'last.pth.tar' in config['models_in_folder']:
|
||||||
|
raise Exception('model last.pth.tar not found')
|
||||||
|
else:
|
||||||
|
config['models_in_folder'] = ['last.pth.tar']
|
||||||
|
if config['best_model_only']:
|
||||||
|
if not 'model_best.pth.tar' in config['models_in_folder']:
|
||||||
|
raise Exception('model model_best.pth.tar not found')
|
||||||
|
else:
|
||||||
|
config['models_in_folder'] = ['model_best.pth.tar']
|
||||||
|
else:
|
||||||
|
if not os.path.exists(config['folder_path']):
|
||||||
|
raise Exception("Input model not found")
|
||||||
|
config['num_classes'] = args.num_classes
|
||||||
|
if not config['single_model']:
|
||||||
|
if config['num_classes'] is None:
|
||||||
|
if 'num_classes' in config['timm_cfg_keys']:
|
||||||
|
if config['timm_cfg']['num_classes'] is None:
|
||||||
|
raise Exception('Number of classes required')
|
||||||
|
else:
|
||||||
|
config['num_classes'] = config['timm_cfg']['num_classes']
|
||||||
|
else:
|
||||||
|
if config['num_classes'] is None:
|
||||||
|
raise Exception('Number of classes required')
|
||||||
|
if args.input_dim is None:
|
||||||
|
if config['single_model']:
|
||||||
|
raise Exception('Input dimension missing')
|
||||||
|
else:
|
||||||
|
config['input_dim'] = args.input_dim
|
||||||
|
if args.model_name is not None:
|
||||||
|
config['model_name'] = args.model_name
|
||||||
|
else:
|
||||||
|
if config['single_model']:
|
||||||
|
raise Exception('Model name missing')
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def convert_model(model_in : str,
|
||||||
|
model_out : str,
|
||||||
|
config : dict):
|
||||||
|
"""Loads timm model file and converts it to onnx
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model_in : str
|
||||||
|
filepath of input model
|
||||||
|
model_out : str
|
||||||
|
filepath of converted model
|
||||||
|
config : dict
|
||||||
|
Python dict containing all settings
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
random_input = torch.randn(1,
|
||||||
|
3,
|
||||||
|
config['input_dim'],
|
||||||
|
config['input_dim'],
|
||||||
|
requires_grad=True)
|
||||||
|
model = create_model(config['model_name'],
|
||||||
|
checkpoint_path=model_in,
|
||||||
|
exportable=True,
|
||||||
|
num_classes=config['num_classes'])
|
||||||
|
model.eval()
|
||||||
|
torch.onnx.export(model,
|
||||||
|
random_input,
|
||||||
|
model_out,
|
||||||
|
export_params=True,
|
||||||
|
opset_version=13,
|
||||||
|
do_constant_folding=True,
|
||||||
|
input_names=['input'],
|
||||||
|
output_names=['output'])
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
config = argument_parser()
|
||||||
|
print('=' * 72)
|
||||||
|
print('Exporting models to ONNX')
|
||||||
|
if not config['single_model']:
|
||||||
|
for model_file in config['models_in_folder']:
|
||||||
|
print('Converting '+model_file+' to '+
|
||||||
|
str(model_file.split('.')[0])+'.onnx')
|
||||||
|
model_in = config['folder_path']+model_file
|
||||||
|
model_out = model_in.replace('pth.tar', 'onnx')
|
||||||
|
convert_model(model_in, model_out, config)
|
||||||
|
else:
|
||||||
|
model_in = config['folder_path']
|
||||||
|
model_out = model_in.replace('pth.tar', 'onnx')
|
||||||
|
convert_model(model_in, model_out, config)
|
||||||
|
print('=' * 72)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
Reference in new issue