Merge 16664ea31b
into 709d5e0d9d
commit
58d9f09c86
@ -0,0 +1,183 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from timm.models import create_model
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
|
||||
def argument_parser() -> dict:
|
||||
"""Argument Parser
|
||||
|
||||
Returns
|
||||
-------
|
||||
config : dict
|
||||
Python dict containing all settings
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--single_model',
|
||||
action='store_true',
|
||||
help='Flag to converts a single model \
|
||||
(use full path + filename as \
|
||||
--folder_path; input dimensions \
|
||||
(--input_dim) and number of \
|
||||
classes (--num_classes) are required)')
|
||||
parser.add_argument('--model_name',
|
||||
default=None,
|
||||
type=str,
|
||||
help='model name (e.g. mobilenetv3_large_100)\
|
||||
required if single_model')
|
||||
parser.add_argument('--folder_path',
|
||||
type=str,
|
||||
help='Path of folder containing training \
|
||||
output (incl. yaml files)')
|
||||
parser.add_argument('--best_model_only',
|
||||
action='store_true',
|
||||
help='Flag to export model_best.pth.tar in \
|
||||
training folder only')
|
||||
parser.add_argument('--last_model_only',
|
||||
action='store_true',
|
||||
help='Flag to export last.pth.tar in \
|
||||
training folder only')
|
||||
parser.add_argument('--input_dim',
|
||||
default=None,
|
||||
type=int,
|
||||
help='Input shape of model assuming square\
|
||||
input (example: --input_dim 224); \
|
||||
if not provided, input dim is derived\
|
||||
from folder name')
|
||||
parser.add_argument('--num_classes',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Number of output classes')
|
||||
args = parser.parse_args()
|
||||
config = {}
|
||||
config['single_model'] = args.single_model
|
||||
config['best_model_only'] = args.best_model_only
|
||||
config['last_model_only'] = args.last_model_only
|
||||
if config['best_model_only'] and config['last_model_only']:
|
||||
raise Exception('invalid options: best_model_only and\
|
||||
last_model_only - choose one')
|
||||
config['folder_path'] = args.folder_path
|
||||
|
||||
if not config['single_model']:
|
||||
if config['folder_path'][-1] != '/':
|
||||
config['folder_path'] += '/'
|
||||
if not os.path.exists(config['folder_path']):
|
||||
raise Exception("Folder containing training output\
|
||||
does not exist")
|
||||
if not os.path.exists(config['folder_path']+'args.yaml'):
|
||||
raise Exception('args.yaml does not exist in folder')
|
||||
else:
|
||||
# derive data from folder name
|
||||
# training folder name format:
|
||||
# YYYYMMDD-hh:mm:ss-modelName_epochs-inputSize
|
||||
folder_name = config['folder_path'].split('/')[-2]
|
||||
timm_train_data = folder_name.split('-')
|
||||
config['input_dim'] = int(timm_train_data[3])
|
||||
config['timm_cfg'] = yaml.safe_load(open(
|
||||
config['folder_path']+'args.yaml','r'))
|
||||
config['timm_cfg_keys'] = tuple(config['timm_cfg'].keys())
|
||||
if 'model' in config['timm_cfg_keys']:
|
||||
config['model_name'] = config['timm_cfg']['model']
|
||||
else:
|
||||
config['model_name'] = timm_train_data[2]
|
||||
files_in_folder = os.listdir(config['folder_path'])
|
||||
config['models_in_folder'] = \
|
||||
[file for file in files_in_folder if '.pth.tar' in file]
|
||||
if len(config['models_in_folder']) == 0:
|
||||
raise Exception('No models found in folder')
|
||||
if config['last_model_only']:
|
||||
if not 'last.pth.tar' in config['models_in_folder']:
|
||||
raise Exception('model last.pth.tar not found')
|
||||
else:
|
||||
config['models_in_folder'] = ['last.pth.tar']
|
||||
if config['best_model_only']:
|
||||
if not 'model_best.pth.tar' in config['models_in_folder']:
|
||||
raise Exception('model model_best.pth.tar not found')
|
||||
else:
|
||||
config['models_in_folder'] = ['model_best.pth.tar']
|
||||
else:
|
||||
if not os.path.exists(config['folder_path']):
|
||||
raise Exception("Input model not found")
|
||||
config['num_classes'] = args.num_classes
|
||||
if not config['single_model']:
|
||||
if config['num_classes'] is None:
|
||||
if 'num_classes' in config['timm_cfg_keys']:
|
||||
if config['timm_cfg']['num_classes'] is None:
|
||||
raise Exception('Number of classes required')
|
||||
else:
|
||||
config['num_classes'] = config['timm_cfg']['num_classes']
|
||||
else:
|
||||
if config['num_classes'] is None:
|
||||
raise Exception('Number of classes required')
|
||||
if args.input_dim is None:
|
||||
if config['single_model']:
|
||||
raise Exception('Input dimension missing')
|
||||
else:
|
||||
config['input_dim'] = args.input_dim
|
||||
if args.model_name is not None:
|
||||
config['model_name'] = args.model_name
|
||||
else:
|
||||
if config['single_model']:
|
||||
raise Exception('Model name missing')
|
||||
return config
|
||||
|
||||
|
||||
def convert_model(model_in : str,
|
||||
model_out : str,
|
||||
config : dict):
|
||||
"""Loads timm model file and converts it to onnx
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_in : str
|
||||
filepath of input model
|
||||
model_out : str
|
||||
filepath of converted model
|
||||
config : dict
|
||||
Python dict containing all settings
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
random_input = torch.randn(1,
|
||||
3,
|
||||
config['input_dim'],
|
||||
config['input_dim'],
|
||||
requires_grad=True)
|
||||
model = create_model(config['model_name'],
|
||||
checkpoint_path=model_in,
|
||||
exportable=True,
|
||||
num_classes=config['num_classes'])
|
||||
model.eval()
|
||||
torch.onnx.export(model,
|
||||
random_input,
|
||||
model_out,
|
||||
export_params=True,
|
||||
opset_version=13,
|
||||
do_constant_folding=True,
|
||||
input_names=['input'],
|
||||
output_names=['output'])
|
||||
|
||||
|
||||
def main():
|
||||
config = argument_parser()
|
||||
print('=' * 72)
|
||||
print('Exporting models to ONNX')
|
||||
if not config['single_model']:
|
||||
for model_file in config['models_in_folder']:
|
||||
print('Converting '+model_file+' to '+
|
||||
str(model_file.split('.')[0])+'.onnx')
|
||||
model_in = config['folder_path']+model_file
|
||||
model_out = model_in.replace('pth.tar', 'onnx')
|
||||
convert_model(model_in, model_out, config)
|
||||
else:
|
||||
model_in = config['folder_path']
|
||||
model_out = model_in.replace('pth.tar', 'onnx')
|
||||
convert_model(model_in, model_out, config)
|
||||
print('=' * 72)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in new issue