You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
103 lines
3.2 KiB
103 lines
3.2 KiB
2 years ago
|
#
|
||
|
# For licensing see accompanying LICENSE.md file.
|
||
|
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||
|
#
|
||
|
|
||
|
import coremltools as ct
|
||
|
|
||
|
import logging
|
||
|
|
||
|
logging.basicConfig()
|
||
|
logger = logging.getLogger(__name__)
|
||
|
logger.setLevel(logging.INFO)
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
import os
|
||
|
import time
|
||
|
|
||
|
|
||
|
class CoreMLModel:
|
||
|
""" Wrapper for running CoreML models using coremltools
|
||
|
"""
|
||
|
|
||
|
def __init__(self, model_path, compute_unit):
|
||
|
assert os.path.exists(model_path) and model_path.endswith(".mlpackage")
|
||
|
|
||
|
logger.info(f"Loading {model_path}")
|
||
|
|
||
|
start = time.time()
|
||
|
self.model = ct.models.MLModel(
|
||
|
model_path, compute_units=ct.ComputeUnit[compute_unit])
|
||
|
load_time = time.time() - start
|
||
|
logger.info(f"Done. Took {load_time:.1f} seconds.")
|
||
|
|
||
|
if load_time > LOAD_TIME_INFO_MSG_TRIGGER:
|
||
|
logger.info(
|
||
|
"Loading a CoreML model through coremltools triggers compilation every time. "
|
||
|
"The Swift package we provide uses precompiled Core ML models (.mlmodelc) to avoid compile-on-load."
|
||
|
)
|
||
|
|
||
|
|
||
|
DTYPE_MAP = {
|
||
|
65552: np.float16,
|
||
|
65568: np.float32,
|
||
|
131104: np.int32,
|
||
|
}
|
||
|
|
||
|
self.expected_inputs = {
|
||
|
input_tensor.name: {
|
||
|
"shape": tuple(input_tensor.type.multiArrayType.shape),
|
||
|
"dtype": DTYPE_MAP[input_tensor.type.multiArrayType.dataType],
|
||
|
}
|
||
|
for input_tensor in self.model._spec.description.input
|
||
|
}
|
||
|
|
||
|
def _verify_inputs(self, **kwargs):
|
||
|
for k, v in kwargs.items():
|
||
|
if k in self.expected_inputs:
|
||
|
if not isinstance(v, np.ndarray):
|
||
|
raise TypeError(
|
||
|
f"Expected numpy.ndarray, got {v} for input: {k}")
|
||
|
|
||
|
expected_dtype = self.expected_inputs[k]["dtype"]
|
||
|
if not v.dtype == expected_dtype:
|
||
|
raise TypeError(
|
||
|
f"Expected dtype {expected_dtype}, got {v.dtype} for input: {k}"
|
||
|
)
|
||
|
|
||
|
expected_shape = self.expected_inputs[k]["shape"]
|
||
|
if not v.shape == expected_shape:
|
||
|
raise TypeError(
|
||
|
f"Expected shape {expected_shape}, got {v.shape} for input: {k}"
|
||
|
)
|
||
|
else:
|
||
|
raise ValueError("Received unexpected input kwarg: {k}")
|
||
|
|
||
|
def __call__(self, **kwargs):
|
||
|
self._verify_inputs(**kwargs)
|
||
|
return self.model.predict(kwargs)
|
||
|
|
||
|
|
||
|
LOAD_TIME_INFO_MSG_TRIGGER = 10 # seconds
|
||
|
|
||
|
|
||
|
def _load_mlpackage(submodule_name, mlpackages_dir, model_version,
|
||
|
compute_unit):
|
||
|
""" Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)
|
||
|
"""
|
||
|
logger.info(f"Loading {submodule_name} mlpackage")
|
||
|
|
||
|
fname = f"Stable_Diffusion_version_{model_version}_{submodule_name}.mlpackage".replace(
|
||
|
"/", "_")
|
||
|
mlpackage_path = os.path.join(mlpackages_dir, fname)
|
||
|
|
||
|
if not os.path.exists(mlpackage_path):
|
||
|
raise FileNotFoundError(
|
||
|
f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}")
|
||
|
|
||
|
return CoreMLModel(mlpackage_path, compute_unit)
|
||
|
|
||
|
def get_available_compute_units():
|
||
|
return tuple(cu for cu in ct.ComputeUnit._member_names_)
|