You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

76 lines
2.5 KiB

// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
import Foundation
import CoreML
/// A decoder model which produces RGB images from latent samples
@available(iOS 16.2, macOS 13.1, *)
public struct Decoder: ResourceManaging {

    /// Failures thrown by `decode` when the loaded model does not expose the
    /// features the decoder needs.
    public enum DecoderError: Swift.Error {
        /// The model description lists no input features.
        case missingInputFeature
        /// The prediction at `index` has no feature holding a multi-array value.
        case missingOutputFeature(index: Int)
    }

    /// VAE decoder model
    var model: ManagedMLModel

    /// Create decoder from Core ML model
    ///
    /// - Parameters:
    ///     - url: Location of compiled VAE decoder Core ML model
    ///     - configuration: configuration to be used when the model is loaded
    /// - Returns: A decoder that will lazily load its required resources when needed or requested
    public init(modelAt url: URL, configuration: MLModelConfiguration) {
        self.model = ManagedMLModel(modelAt: url, configuration: configuration)
    }

    /// Ensure the model has been loaded into memory
    ///
    /// - Throws: Whatever error the managed model raises while loading.
    public func loadResources() throws {
        try model.loadResources()
    }

    /// Unload the underlying model to free up memory
    public func unloadResources() {
        model.unloadResources()
    }

    /// Batch decode latent samples into images
    ///
    /// - Parameters:
    ///     - latents: Batch of latent samples to decode
    ///     - scaleFactor: Divisor applied to every latent scalar before decoding.
    ///       The default is the value this method previously hard-coded, so
    ///       existing call sites behave exactly as before.
    /// - Returns: decoded images, one per latent sample, in input order.
    ///   An empty `latents` array yields an empty result without touching the model.
    /// - Throws: `DecoderError` when the model exposes no input feature or a
    ///   prediction carries no multi-array output; also rethrows any error from
    ///   model loading, feature-provider construction, prediction, or image conversion.
    public func decode(
        _ latents: [MLShapedArray<Float32>],
        scaleFactor: Float32 = 0.18215
    ) throws -> [CGImage] {
        // Nothing to decode: skip the model entirely.
        guard !latents.isEmpty else { return [] }

        // Resolve the input feature name once for the whole batch, and let a
        // failure surface as a thrown error rather than a trap.
        let featureName = try resolvedInputName()

        // Form batch inputs for model
        let inputs: [MLFeatureProvider] = try latents.map { sample in
            // Reference pipeline scales the latent samples before decoding
            let sampleScaled = MLShapedArray<Float32>(
                scalars: sample.scalars.map { $0 / scaleFactor },
                shape: sample.shape)
            return try MLDictionaryFeatureProvider(
                dictionary: [featureName: MLMultiArray(sampleScaled)])
        }
        let batch = MLArrayBatchProvider(array: inputs)

        // Batch predict with model
        let results = try model.perform { model in
            try model.predictions(fromBatch: batch)
        }

        // Transform the outputs to CGImages
        return try (0..<results.count).map { index in
            let features = results.features(at: index)
            // The first listed feature is taken as the image tensor, as before.
            guard let outputName = features.featureNames.first,
                  let output = features.featureValue(for: outputName)?.multiArrayValue
            else {
                throw DecoderError.missingOutputFeature(index: index)
            }
            return try CGImage.fromShapedArray(MLShapedArray<Float32>(output))
        }
    }

    /// Name of the first input feature listed in the model description.
    ///
    /// Accessing this property goes through `model.perform`. It traps if that
    /// call throws or if the model declares no inputs; `decode` uses a throwing
    /// lookup instead, so this accessor is kept only for existing callers.
    var inputName: String {
        try! model.perform { model in
            model.modelDescription.inputDescriptionsByName.first!.key
        }
    }

    /// Throwing counterpart of `inputName` used by `decode`.
    private func resolvedInputName() throws -> String {
        try model.perform { model in
            guard let name = model.modelDescription.inputDescriptionsByName.first?.key else {
                throw DecoderError.missingInputFeature
            }
            return name
        }
    }
}