// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.

import CoreML
|
|
|
|
/// Wraps a Core ML model stored at a URL and funnels every use of it through one queue.
///
/// The model is instantiated on demand: either explicitly through `loadResources()`
/// or implicitly the first time `perform(_:)` needs it. `unloadResources()` drops
/// this object's reference to the instantiated model again.
@available(iOS 16.2, macOS 13.1, *)
public final class ManagedMLModel: ResourceManaging {

    /// Where the model is read from when it is instantiated
    var modelURL: URL

    /// Configuration passed to Core ML each time the model is instantiated
    var configuration: MLModelConfiguration

    /// The instantiated model, or `nil` while it is not loaded
    var loadedModel: MLModel?

    /// Queue on which this class reads and writes `loadedModel`
    var queue: DispatchQueue

    /// Creates a managed model without loading it.
    ///
    /// - Parameters:
    ///   - url: The location of the model
    ///   - configuration: The configuration used whenever the model is loaded
    public init(modelAt url: URL, configuration: MLModelConfiguration) {
        modelURL = url
        self.configuration = configuration
        loadedModel = nil
        // The label embeds the model's file name.
        queue = DispatchQueue(label: "managed.\(url.lastPathComponent)")
    }

    /// Instantiates the model now if it is not already loaded.
    ///
    /// - Throws: Any error raised while creating the `MLModel`
    public func loadResources() throws {
        try queue.sync { _ = try loadModel() }
    }

    /// Drops the reference to the loaded model, if there is one.
    public func unloadResources() {
        queue.sync { loadedModel = nil }
    }

    /// Runs `body` with the managed model, loading the model first when necessary.
    ///
    /// The closure is executed synchronously on the internal queue, inside an
    /// autorelease pool. The model should only be used inside the closure so that
    /// loading and unloading stay synchronized. Because `body` runs within
    /// `queue.sync`, it must not call `perform(_:)`, `loadResources()` or
    /// `unloadResources()` on this same instance.
    ///
    /// - Parameters:
    ///   - body: Closure which performs an action on the loaded model
    /// - Returns: The result of the closure
    /// - Throws: An error if the model cannot be loaded or if the closure throws
    public func perform<R>(_ body: (MLModel) throws -> R) throws -> R {
        try queue.sync {
            try autoreleasepool {
                let model = try loadModel()
                return try body(model)
            }
        }
    }

    /// Returns the loaded model, instantiating and caching it first when needed.
    ///
    /// Callers in this class invoke it from within `queue.sync`. If instantiation
    /// throws, `loadedModel` is left untouched.
    @discardableResult
    private func loadModel() throws -> MLModel {
        if let existing = loadedModel {
            return existing
        }
        let created = try MLModel(contentsOf: modelURL, configuration: configuration)
        loadedModel = created
        return created
    }

}
|