You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

103 lines
3.2 KiB

#
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#
import coremltools as ct
import logging
# Give the root logger logging's default handler (a no-op if the root logger
# already has handlers), then create this module's logger at INFO level so the
# load-progress messages below are emitted.
logging.basicConfig()
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)
import numpy as np
import os
import time
class CoreMLModel:
    """Wrapper for running Core ML models using coremltools.

    Loads an ``.mlpackage`` from disk, records the expected shape and numpy
    dtype of every model input, and validates keyword inputs against those
    expectations before calling ``predict``.
    """

    # Core ML ``ArrayFeatureType.ArrayDataType`` enum values (as stored in the
    # model spec) mapped to the numpy dtype an input array must have.
    _DTYPE_MAP = {
        65552: np.float16,  # FLOAT16 = 0x10000 | 16
        65568: np.float32,  # FLOAT32 = 0x10000 | 32
        65600: np.float64,  # DOUBLE  = 0x10000 | 64
        131104: np.int32,  # INT32   = 0x20000 | 32
    }

    def __init__(self, model_path, compute_unit):
        """Load the model at ``model_path`` and record its input signature.

        Args:
            model_path: Path to an existing ``.mlpackage``.
            compute_unit: Name of a ``coremltools.ComputeUnit`` member (one of
                the names returned by ``get_available_compute_units()``).

        Raises:
            AssertionError: if ``model_path`` does not exist or does not end
                with ``.mlpackage``.
            KeyError: if ``compute_unit`` is not a ``ComputeUnit`` member name,
                or a model input uses an array data type not in ``_DTYPE_MAP``.
        """
        # NOTE(review): this assert is skipped under ``python -O``; it is kept
        # (rather than raising another type) so callers see the same exception.
        assert os.path.exists(model_path) and model_path.endswith(".mlpackage"), (
            f"Expected an existing .mlpackage path, got: {model_path}")

        logger.info(f"Loading {model_path}")
        start = time.perf_counter()  # monotonic clock for elapsed time
        self.model = ct.models.MLModel(
            model_path, compute_units=ct.ComputeUnit[compute_unit])
        load_time = time.perf_counter() - start
        logger.info(f"Done. Took {load_time:.1f} seconds.")

        # Slow loads are explained: coremltools recompiles on every load.
        if load_time > LOAD_TIME_INFO_MSG_TRIGGER:
            logger.info(
                "Loading a CoreML model through coremltools triggers compilation every time. "
                "The Swift package we provide uses precompiled Core ML models (.mlmodelc) to avoid compile-on-load."
            )

        # Reads the private ``_spec`` directly to avoid the deep copy that the
        # public ``get_spec()`` makes.
        self.expected_inputs = {
            input_tensor.name: self._describe_input(input_tensor)
            for input_tensor in self.model._spec.description.input
        }

    @classmethod
    def _describe_input(cls, input_tensor):
        """Return ``{"shape": tuple, "dtype": numpy type}`` for one spec input.

        Raises:
            KeyError: if the input's multi-array data type is not in
                ``_DTYPE_MAP`` (the message names the input).
        """
        array_type = input_tensor.type.multiArrayType
        try:
            dtype = cls._DTYPE_MAP[array_type.dataType]
        except KeyError:
            raise KeyError(
                f"Unsupported Core ML array data type {array_type.dataType} "
                f"for input: {input_tensor.name}") from None
        return {"shape": tuple(array_type.shape), "dtype": dtype}

    def _verify_inputs(self, **kwargs):
        """Check each keyword input against ``self.expected_inputs``.

        Only the supplied kwargs are checked; model inputs absent from
        ``kwargs`` are not reported here.

        Raises:
            ValueError: if a keyword does not name a model input.
            TypeError: if a value is not a ``numpy.ndarray``, or its dtype or
                shape differs from the model's expectation.
        """
        for k, v in kwargs.items():
            if k not in self.expected_inputs:
                raise ValueError(f"Received unexpected input kwarg: {k}")

            if not isinstance(v, np.ndarray):
                raise TypeError(
                    f"Expected numpy.ndarray, got {v} for input: {k}")

            expected_dtype = self.expected_inputs[k]["dtype"]
            if v.dtype != expected_dtype:
                raise TypeError(
                    f"Expected dtype {expected_dtype}, got {v.dtype} for input: {k}"
                )

            expected_shape = self.expected_inputs[k]["shape"]
            if v.shape != expected_shape:
                raise TypeError(
                    f"Expected shape {expected_shape}, got {v.shape} for input: {k}"
                )

    def __call__(self, **kwargs):
        """Validate ``kwargs`` and run the model on them.

        Returns:
            Whatever ``MLModel.predict`` returns for ``kwargs``.
        """
        self._verify_inputs(**kwargs)
        return self.model.predict(kwargs)
# Threshold, in seconds: a CoreMLModel load slower than this logs an extra
# message explaining that coremltools compiles the model on every load.
LOAD_TIME_INFO_MSG_TRIGGER = 10
def _load_mlpackage(submodule_name, mlpackages_dir, model_version,
                    compute_unit):
    """Load one Core ML (.mlpackage) submodule from disk.

    The package is expected in ``mlpackages_dir`` (as exported by
    torch2coreml.py) under the file name
    ``Stable_Diffusion_version_<model_version>_<submodule_name>.mlpackage``,
    with every "/" in that file name replaced by "_".

    Args:
        submodule_name: Name of the submodule; used in the file name and logs.
        mlpackages_dir: Directory holding the exported .mlpackage files.
        model_version: Model version string; used in the file name.
        compute_unit: ``coremltools.ComputeUnit`` member name, passed through
            to ``CoreMLModel``.

    Returns:
        A ``CoreMLModel`` wrapping the loaded package.

    Raises:
        FileNotFoundError: if the expected .mlpackage does not exist.
    """
    logger.info(f"Loading {submodule_name} mlpackage")

    raw_name = (
        f"Stable_Diffusion_version_{model_version}_{submodule_name}.mlpackage"
    )
    # Slashes (e.g. in a hub-style model_version) cannot appear in a file name.
    package_name = raw_name.replace("/", "_")
    package_path = os.path.join(mlpackages_dir, package_name)

    if os.path.exists(package_path):
        return CoreMLModel(package_path, compute_unit)
    raise FileNotFoundError(
        f"{submodule_name} CoreML model doesn't exist at {package_path}")
def get_available_compute_units():
    """Return the names of all ``coremltools.ComputeUnit`` members.

    These names are the values ``CoreMLModel`` accepts as ``compute_unit``
    (it looks them up with ``ct.ComputeUnit[compute_unit]``). Names are
    returned in enum definition order.
    """
    # Public enum iteration yields the canonical members in definition order,
    # the same names as the private ``_member_names_`` attribute.
    return tuple(unit.name for unit in ct.ComputeUnit)