You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

78 lines
2.5 KiB

// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
import CoreML
/// A class to manage and gate access to a Core ML model
///
/// It will automatically load a model into memory when needed or requested
/// It allows one to request to unload the model from memory
@available(iOS 16.2, macOS 13.1, *)
public final class ManagedMLModel: ResourceManaging {
/// The location of the model
var modelURL: URL
/// The configuration to be used when the model is loaded
var configuration: MLModelConfiguration
/// The loaded model (when loaded)
var loadedModel: MLModel?
/// Queue to protect access to loaded model
var queue: DispatchQueue
/// Create a managed model given its location and desired loaded configuration
///
/// - Parameters:
/// - url: The location of the model
/// - configuration: The configuration to be used when the model is loaded/used
/// - Returns: A managed model that has not been loaded
public init(modelAt url: URL, configuration: MLModelConfiguration) {
self.modelURL = url
self.configuration = configuration
self.loadedModel = nil
self.queue = DispatchQueue(label: "managed.\(url.lastPathComponent)")
}
/// Instantiation and load model into memory
public func loadResources() throws {
try queue.sync {
try loadModel()
}
}
/// Unload the model if it was loaded
public func unloadResources() {
queue.sync {
loadedModel = nil
}
}
/// Perform an operation with the managed model via a supplied closure.
/// The model will be loaded and supplied to the closure and should only be
/// used within the closure to ensure all resource management is synchronized
///
/// - Parameters:
/// - body: Closure which performs and action on a loaded model
/// - Returns: The result of the closure
/// - Throws: An error if the model cannot be loaded or if the closure throws
public func perform<R>(_ body: (MLModel) throws -> R) throws -> R {
return try queue.sync {
try autoreleasepool {
try loadModel()
return try body(loadedModel!)
}
}
}
private func loadModel() throws {
if loadedModel == nil {
loadedModel = try MLModel(contentsOf: modelURL,
configuration: configuration)
}
}
}