You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
110 lines
4.0 KiB
110 lines
4.0 KiB
2 years ago
|
// For licensing see accompanying LICENSE.md file.
|
||
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||
|
|
||
|
import Foundation
|
||
|
import CoreML
|
||
|
import Accelerate
|
||
|
|
||
|
/// A decoder model which produces RGB images from latent samples
|
||
|
public struct Decoder {
|
||
|
|
||
|
/// VAE decoder model
|
||
|
var model: MLModel
|
||
|
|
||
|
/// Create decoder from Core ML model
|
||
|
///
|
||
|
/// - Parameters
|
||
|
/// - model: Core ML model for VAE decoder
|
||
|
public init(model: MLModel) {
|
||
|
self.model = model
|
||
|
}
|
||
|
|
||
|
/// Prediction queue
|
||
|
let queue = DispatchQueue(label: "decoder.predict")
|
||
|
|
||
|
/// Batch decode latent samples into images
|
||
|
///
|
||
|
/// - Parameters:
|
||
|
/// - latents: Batch of latent samples to decode
|
||
|
/// - Returns: decoded images
|
||
|
public func decode(_ latents: [MLShapedArray<Float32>]) throws -> [CGImage] {
|
||
|
|
||
|
// Form batch inputs for model
|
||
|
let inputs: [MLFeatureProvider] = try latents.map { sample in
|
||
|
// Reference pipeline scales the latent samples before decoding
|
||
|
let sampleScaled = MLShapedArray<Float32>(
|
||
|
scalars: sample.scalars.map { $0 / 0.18215 },
|
||
|
shape: sample.shape)
|
||
|
|
||
|
let dict = [inputName: MLMultiArray(sampleScaled)]
|
||
|
return try MLDictionaryFeatureProvider(dictionary: dict)
|
||
|
}
|
||
|
let batch = MLArrayBatchProvider(array: inputs)
|
||
|
|
||
|
// Batch predict with model
|
||
|
let results = try queue.sync { try model.predictions(fromBatch: batch) }
|
||
|
|
||
|
// Transform the outputs to CGImages
|
||
|
let images: [CGImage] = (0..<results.count).map { i in
|
||
|
let result = results.features(at: i)
|
||
|
let outputName = result.featureNames.first!
|
||
|
let output = result.featureValue(for: outputName)!.multiArrayValue!
|
||
|
|
||
|
return toRGBCGImage(MLShapedArray<Float32>(output))
|
||
|
}
|
||
|
|
||
|
return images
|
||
|
}
|
||
|
|
||
|
var inputName: String {
|
||
|
model.modelDescription.inputDescriptionsByName.first!.key
|
||
|
}
|
||
|
|
||
|
typealias PixelBufferPFx1 = vImage.PixelBuffer<vImage.PlanarF>
|
||
|
typealias PixelBufferP8x3 = vImage.PixelBuffer<vImage.Planar8x3>
|
||
|
typealias PixelBufferIFx3 = vImage.PixelBuffer<vImage.InterleavedFx3>
|
||
|
typealias PixelBufferI8x3 = vImage.PixelBuffer<vImage.Interleaved8x3>
|
||
|
|
||
|
func toRGBCGImage(_ array: MLShapedArray<Float32>) -> CGImage {
|
||
|
|
||
|
// array is [N,C,H,W], where C==3
|
||
|
let channelCount = array.shape[1]
|
||
|
assert(channelCount == 3,
|
||
|
"Decoding model output has \(channelCount) channels, expected 3")
|
||
|
let height = array.shape[2]
|
||
|
let width = array.shape[3]
|
||
|
|
||
|
// Normalize each channel into a float between 0 and 1.0
|
||
|
let floatChannels = (0..<channelCount).map { i in
|
||
|
|
||
|
// Normalized channel output
|
||
|
let cOut = PixelBufferPFx1(width: width, height:height)
|
||
|
|
||
|
// Reference this channel in the array and normalize
|
||
|
array[0][i].withUnsafeShapedBufferPointer { ptr, _, strides in
|
||
|
let cIn = PixelBufferPFx1(data: .init(mutating: ptr.baseAddress!),
|
||
|
width: width, height: height,
|
||
|
byteCountPerRow: strides[0]*4)
|
||
|
// Map [-1.0 1.0] -> [0.0 1.0]
|
||
|
cIn.multiply(by: 0.5, preBias: 1.0, postBias: 0.0, destination: cOut)
|
||
|
}
|
||
|
return cOut
|
||
|
}
|
||
|
|
||
|
// Convert to interleaved and then to UInt8
|
||
|
let floatImage = PixelBufferIFx3(planarBuffers: floatChannels)
|
||
|
let uint8Image = PixelBufferI8x3(width: width, height: height)
|
||
|
floatImage.convert(to:uint8Image) // maps [0.0 1.0] -> [0 255] and clips
|
||
|
|
||
|
// Convert to uint8x3 to RGB CGImage (no alpha)
|
||
|
let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue)
|
||
|
let cgImage = uint8Image.makeCGImage(cgImageFormat:
|
||
|
.init(bitsPerComponent: 8,
|
||
|
bitsPerPixel: 3*8,
|
||
|
colorSpace: CGColorSpaceCreateDeviceRGB(),
|
||
|
bitmapInfo: bitmapInfo)!)!
|
||
|
|
||
|
return cgImage
|
||
|
}
|
||
|
}
|