// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.

import Foundation
import CoreML
import Accelerate

/// A decoder model which produces RGB images from latent samples
public struct Decoder {

    /// VAE decoder model
    var model: MLModel

    /// Create decoder from Core ML model
    ///
    /// - Parameters:
    ///   - model: Core ML model for VAE decoder
    public init(model: MLModel) {
        self.model = model
    }

    /// Prediction queue; `decode(_:)` runs every batch prediction synchronously on it
    let queue = DispatchQueue(label: "decoder.predict")

    /// Batch decode latent samples into images
    ///
    /// Each latent sample is divided by 0.18215, wrapped as a model input,
    /// run through the model as a single batch, and the first output feature
    /// of each result is converted to an RGB `CGImage`.
    ///
    /// - Parameters:
    ///   - latents: Batch of latent samples to decode
    /// - Returns: decoded images, one per latent sample, in input order
    /// - Throws: Any error raised while building the feature providers or
    ///   while running the batch prediction.
    public func decode(_ latents: [MLShapedArray<Float32>]) throws -> [CGImage] {

        // Form batch inputs for model
        let inputs: [MLFeatureProvider] = try latents.map { sample in
            // Reference pipeline scales the latent samples before decoding
            let sampleScaled = MLShapedArray<Float32>(
                scalars: sample.scalars.map { $0 / 0.18215 },
                shape: sample.shape)

            let dict = [inputName: MLMultiArray(sampleScaled)]
            return try MLDictionaryFeatureProvider(dictionary: dict)
        }
        let batch = MLArrayBatchProvider(array: inputs)

        // Batch predict with model
        let results = try queue.sync { try model.predictions(fromBatch: batch) }

        // Transform the outputs to CGImages
        // NOTE(review): the per-result extraction below was truncated in the
        // reviewed copy; it assumes each result exposes its image as the first
        // (multi-array) output feature — confirm against the model description.
        let images: [CGImage] = (0..<results.count).map { i in
            let result = results.features(at: i)
            let outputName = result.featureNames.first!
            let output = result.featureValue(for: outputName)!.multiArrayValue!

            return toRGBCGImage(MLShapedArray<Float32>(output))
        }

        return images
    }

    /// Name of the first input feature declared by the model description.
    ///
    /// Traps if the model declares no inputs.
    var inputName: String {
        model.modelDescription.inputDescriptionsByName.first!.key
    }

    // NOTE(review): pixel formats inferred from the alias names
    // (P = planar, I = interleaved, F = Float, 8 = UInt8, xN = channel count)
    // and from how each buffer is used below — confirm.
    typealias PixelBufferPFx1 = vImage.PixelBuffer<vImage.PlanarF>
    typealias PixelBufferP8x3 = vImage.PixelBuffer<vImage.Planar8x3>
    typealias PixelBufferIFx3 = vImage.PixelBuffer<vImage.InterleavedFx3>
    typealias PixelBufferI8x3 = vImage.PixelBuffer<vImage.Interleaved8x3>

    /// Convert a decoded model output into an 8-bit RGB image without alpha.
    ///
    /// - Parameters:
    ///   - array: Model output indexed as [N,C,H,W] with C == 3; only the
    ///     first batch element (`array[0]`) is read.
    /// - Returns: An RGB `CGImage` of size W x H.
    func toRGBCGImage(_ array: MLShapedArray<Float32>) -> CGImage {

        // array is [N,C,H,W], where C==3
        let channelCount = array.shape[1]
        assert(channelCount == 3,
               "Decoding model output has \(channelCount) channels, expected 3")
        let height = array.shape[2]
        let width = array.shape[3]

        // Normalize each channel into a float between 0 and 1.0
        let floatChannels = (0..<channelCount).map { i -> PixelBufferPFx1 in

            // Normalized channel output
            let cOut = PixelBufferPFx1(width: width, height: height)

            // Reference this channel in the array and normalize
            // NOTE(review): buffer setup was truncated in the reviewed copy;
            // this assumes each channel slice is stored contiguously with
            // tightly packed Float32 rows — confirm.
            array[0][i].withUnsafeShapedBufferPointer { ptr, _, _ in
                let cIn = PixelBufferPFx1(
                    data: .init(mutating: ptr.baseAddress!),
                    width: width,
                    height: height,
                    byteCountPerRow: width * MemoryLayout<Float32>.stride)
                // Map [-1.0 1.0] -> [0.0 1.0]
                cIn.multiply(by: 0.5, preBias: 1.0, postBias: 0.0, destination: cOut)
            }
            return cOut
        }

        // Convert to interleaved and then to UInt8
        let floatImage = PixelBufferIFx3(planarBuffers: floatChannels)
        let uint8Image = PixelBufferI8x3(width: width, height: height)

        floatImage.convert(to: uint8Image) // maps [0.0 1.0] -> [0 255] and clips

        // Convert to uint8x3 to RGB CGImage (no alpha)
        let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue)
        let cgImage = uint8Image.makeCGImage(cgImageFormat:
                .init(bitsPerComponent: 8,
                      bitsPerPixel: 3*8,
                      colorSpace: CGColorSpaceCreateDeviceRGB(),
                      bitmapInfo: bitmapInfo)!)!

        return cgImage
    }
}