You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

81 lines
3.0 KiB

// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
import Foundation
import CoreML
/// Encoder that runs a Core ML model on an input image to produce a latent sample.
/// Currently supports image2image.
@available(iOS 16.0, macOS 13.0, *)
public struct Encoder {

    /// Failures raised while reading the model's prediction output.
    public enum Error: String, Swift.Error {
        /// A prediction result exposed no output feature, or that feature was not a multi-array.
        case latentOutputNotValid
        /// The batch prediction returned no results at all.
        case batchLatentOutputEmpty
    }

    /// VAE encoder model + post math and adding noise from scheduler
    var model: MLModel

    /// Create encoder from Core ML model
    ///
    /// - Parameters:
    ///   - model: Core ML model for VAE encoder
    public init(model: MLModel) {
        self.model = model
    }

    /// Prediction queue; predictions are dispatched onto it synchronously.
    let queue = DispatchQueue(label: "encoder.predict")

    /// Encode an image into latent space for image2image
    ///
    /// The image and the supplied noise/scheduler values are passed to the model under the
    /// input names `sample`, `diagonalNoise`, `noise`, `sqrtAlphasCumprod` and
    /// `sqrtOneMinusAlphasCumprod`.
    ///
    /// - Parameters:
    ///   - image: image used for image2image
    ///   - diagonalNoise: random noise for `DiagonalGaussianDistribution` operation
    ///   - noise: random noise for initial latent space based on strength argument
    ///   - alphasCumprodStep: calculations using the scheduler traditionally calculated in the pipeline in pyTorch Diffusers library.
    /// - Returns: The encoded latent space as MLShapedArray
    /// - Throws: `Error.latentOutputNotValid` if a prediction result has no multi-array output,
    ///   `Error.batchLatentOutputEmpty` if the prediction returns no results, and any error
    ///   thrown while converting the image, building the feature provider, or running the model.
    public func encode(
        image: CGImage,
        diagonalNoise: MLShapedArray<Float32>,
        noise: MLShapedArray<Float32>,
        alphasCumprodStep: AlphasCumprodCalculation
    ) throws -> MLShapedArray<Float32> {
        // `plannerRGBShapedArray` is declared outside this block — presumably a planar RGB
        // conversion of the image; confirm against its definition.
        let sample = try image.plannerRGBShapedArray

        // The scheduler scalars are wrapped as 1x1 arrays before being handed to the model.
        let sqrtAlphasCumprod = MLShapedArray(
            scalars: [alphasCumprodStep.sqrtAlphasCumprod],
            shape: [1, 1]
        )
        let sqrtOneMinusAlphasCumprod = MLShapedArray(
            scalars: [alphasCumprodStep.sqrtOneMinusAlphasCumprod],
            shape: [1, 1]
        )

        let inputs: [String: Any] = [
            "sample": MLMultiArray(sample),
            "diagonalNoise": MLMultiArray(diagonalNoise),
            "noise": MLMultiArray(noise),
            "sqrtAlphasCumprod": MLMultiArray(sqrtAlphasCumprod),
            "sqrtOneMinusAlphasCumprod": MLMultiArray(sqrtOneMinusAlphasCumprod),
        ]
        let featureProvider = try MLDictionaryFeatureProvider(dictionary: inputs)
        let batch = MLArrayBatchProvider(array: [featureProvider])

        // Batch predict with model
        let results = try queue.sync { try model.predictions(fromBatch: batch) }

        // Every result either yields a latent or throws, so a plain `map` is sufficient.
        let batchLatents: [MLShapedArray<Float32>] = try (0..<results.count).map {
            index -> MLShapedArray<Float32> in
            let result = results.features(at: index)
            // NOTE(review): `featureNames` is a set, so `first` is only well-defined when the
            // model exposes a single output — confirm against the model description.
            guard
                let outputName = result.featureNames.first,
                let output = result.featureValue(for: outputName)?.multiArrayValue
            else {
                throw Error.latentOutputNotValid
            }
            return MLShapedArray<Float32>(output)
        }

        guard let latents = batchLatents.first else {
            throw Error.batchLatentOutputEmpty
        }
        return latents
    }
}