You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

121 lines
5.5 KiB

// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
import Foundation
import Accelerate
import CoreML
@available(iOS 16.0, macOS 13.0, *)
extension CGImage {

    typealias PixelBufferPFx1 = vImage.PixelBuffer<vImage.PlanarF>
    typealias PixelBufferP8x3 = vImage.PixelBuffer<vImage.Planar8x3>
    typealias PixelBufferIFx3 = vImage.PixelBuffer<vImage.InterleavedFx3>
    typealias PixelBufferI8x3 = vImage.PixelBuffer<vImage.Interleaved8x3>

    /// Errors thrown while converting between `CGImage` and `MLShapedArray`.
    ///
    /// The set of cases is kept stable for source compatibility, so
    /// `incorrectFormatsConvertingToShapedArray` is also used as the generic
    /// "format or conversion failure" error in both conversion directions.
    public enum ShapedArrayError: String, Swift.Error {
        case wrongNumberOfChannels
        case incorrectFormatsConvertingToShapedArray
        case vImageConverterNotInitialized
    }

    /// Creates an 8-bit RGB image (no alpha) from a planar float array.
    ///
    /// - Parameter array: A `[N, C, H, W]` array with `C == 3` whose values are
    ///   expected to lie in `[-1.0, 1.0]`. Only the first batch element
    ///   (`array[0]`) is converted.
    /// - Returns: A `CGImage` of size `W` x `H` in the device RGB color space.
    /// - Throws: `ShapedArrayError.wrongNumberOfChannels` when the array is not
    ///   rank 4 with exactly three channels;
    ///   `ShapedArrayError.incorrectFormatsConvertingToShapedArray` when the
    ///   array is empty or the image could not be created.
    public static func fromShapedArray(_ array: MLShapedArray<Float32>) throws -> CGImage {
        // array is [N,C,H,W], where C==3. Validate the rank before indexing
        // so a malformed array throws instead of trapping.
        let shape = array.shape
        guard shape.count == 4, shape[1] == 3 else {
            throw ShapedArrayError.wrongNumberOfChannels
        }
        let channelCount = shape[1]
        let height = shape[2]
        let width = shape[3]
        guard shape[0] > 0, height > 0, width > 0 else {
            throw ShapedArrayError.incorrectFormatsConvertingToShapedArray
        }

        // Normalize each channel into a float between 0 and 1.0
        let floatChannels = try (0..<channelCount).map { channel -> PixelBufferPFx1 in
            // Normalized channel output
            let normalized = PixelBufferPFx1(width: width, height: height)

            // Reference this channel in the array (no copy) and normalize
            try array[0][channel].withUnsafeShapedBufferPointer { ptr, _, strides in
                guard let baseAddress = ptr.baseAddress else {
                    throw ShapedArrayError.incorrectFormatsConvertingToShapedArray
                }
                // strides[0] is the row stride in elements; vImage wants bytes.
                let source = PixelBufferPFx1(
                    data: .init(mutating: baseAddress),
                    width: width,
                    height: height,
                    byteCountPerRow: strides[0] * MemoryLayout<Float32>.stride)
                // Map [-1.0 1.0] -> [0.0 1.0] as (x + 1.0) * 0.5
                source.multiply(by: 0.5, preBias: 1.0, postBias: 0.0, destination: normalized)
            }
            return normalized
        }

        // Convert to interleaved and then to UInt8
        let floatImage = PixelBufferIFx3(planarBuffers: floatChannels)
        let uint8Image = PixelBufferI8x3(width: width, height: height)
        floatImage.convert(to: uint8Image) // maps [0.0 1.0] -> [0 255] and clips

        // Convert the interleaved 8-bit buffer to an RGB CGImage (no alpha)
        let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue)
        guard
            let cgImageFormat = vImage_CGImageFormat(
                bitsPerComponent: 8,
                bitsPerPixel: 3 * 8,
                colorSpace: CGColorSpaceCreateDeviceRGB(),
                bitmapInfo: bitmapInfo),
            let cgImage = uint8Image.makeCGImage(cgImageFormat: cgImageFormat)
        else {
            throw ShapedArrayError.incorrectFormatsConvertingToShapedArray
        }
        return cgImage
    }

    /// The image as a planar RGB float array of shape `[1, 3, height, width]`.
    ///
    /// The image is first converted to 8-bit ARGB in the device RGB color
    /// space, then each color channel is mapped linearly from `[0, 255]` to
    /// `[-1.0, 1.0]`. The alpha channel is discarded.
    ///
    /// The property name (a misspelling of "planar") is kept for source
    /// compatibility.
    ///
    /// - Throws: `ShapedArrayError.incorrectFormatsConvertingToShapedArray`
    ///   when the pixel formats cannot be described or a vImage conversion
    ///   fails; `ShapedArrayError.vImageConverterNotInitialized` when the
    ///   converter cannot be created; any error thrown by `vImage_Buffer`
    ///   allocation.
    public var plannerRGBShapedArray: MLShapedArray<Float32> {
        get throws {
            guard
                var sourceFormat = vImage_CGImageFormat(cgImage: self),
                var mediumFormat = vImage_CGImageFormat(
                    bitsPerComponent: 8 * MemoryLayout<UInt8>.size,
                    bitsPerPixel: 8 * MemoryLayout<UInt8>.size * 4,
                    colorSpace: CGColorSpaceCreateDeviceRGB(),
                    bitmapInfo: CGBitmapInfo(rawValue: CGImageAlphaInfo.first.rawValue)),
                let width = vImagePixelCount(exactly: self.width),
                let height = vImagePixelCount(exactly: self.height)
            else {
                throw ShapedArrayError.incorrectFormatsConvertingToShapedArray
            }
            let pixelWidth = Int(width)
            let pixelHeight = Int(height)

            // Every vImage_Buffer below owns heap memory that vImage does not
            // release automatically; free each one on every exit path.
            var sourceImageBuffer = try vImage_Buffer(cgImage: self)
            defer { sourceImageBuffer.free() }

            var mediumDestination = try vImage_Buffer(
                width: pixelWidth,
                height: pixelHeight,
                bitsPerPixel: mediumFormat.bitsPerPixel)
            defer { mediumDestination.free() }

            let unmanagedConverter = vImageConverter_CreateWithCGImageFormat(
                &sourceFormat,
                &mediumFormat,
                nil,
                vImage_Flags(kvImagePrintDiagnosticsToConsole),
                nil)
            guard let converter = unmanagedConverter?.takeRetainedValue() else {
                throw ShapedArrayError.vImageConverterNotInitialized
            }

            // Source format -> 8-bit ARGB
            let anyToAnyStatus = vImageConvert_AnyToAny(
                converter,
                &sourceImageBuffer,
                &mediumDestination,
                nil,
                vImage_Flags(kvImagePrintDiagnosticsToConsole))
            guard anyToAnyStatus == kvImageNoError else {
                throw ShapedArrayError.incorrectFormatsConvertingToShapedArray
            }

            // One float plane per ARGB channel. The alpha plane is required by
            // the vImage call but its contents are not used afterwards.
            let floatBitsPerPixel = 8 * UInt32(MemoryLayout<Float>.size)
            var destinationA = try vImage_Buffer(width: pixelWidth, height: pixelHeight, bitsPerPixel: floatBitsPerPixel)
            defer { destinationA.free() }
            var destinationR = try vImage_Buffer(width: pixelWidth, height: pixelHeight, bitsPerPixel: floatBitsPerPixel)
            defer { destinationR.free() }
            var destinationG = try vImage_Buffer(width: pixelWidth, height: pixelHeight, bitsPerPixel: floatBitsPerPixel)
            defer { destinationG.free() }
            var destinationB = try vImage_Buffer(width: pixelWidth, height: pixelHeight, bitsPerPixel: floatBitsPerPixel)
            defer { destinationB.free() }

            // 8-bit ARGB -> planar float, mapping 0 -> -1.0 and 255 -> 1.0
            var minFloat: [Float] = [-1.0, -1.0, -1.0, -1.0]
            var maxFloat: [Float] = [1.0, 1.0, 1.0, 1.0]
            let planarStatus = vImageConvert_ARGB8888toPlanarF(
                &mediumDestination,
                &destinationA,
                &destinationR,
                &destinationG,
                &destinationB,
                &maxFloat,
                &minFloat,
                vImage_Flags(kvImageNoFlags))
            guard planarStatus == kvImageNoError else {
                throw ShapedArrayError.incorrectFormatsConvertingToShapedArray
            }

            // Concatenate the tightly packed R, G and B planes (channel-major).
            var imageData = Data()
            imageData.reserveCapacity(3 * pixelWidth * pixelHeight * MemoryLayout<Float>.size)
            for plane in [destinationR, destinationG, destinationB] {
                imageData.append(Self.packedFloatData(from: plane))
            }

            // Use the image's real dimensions rather than a fixed 512 x 512.
            return MLShapedArray<Float32>(data: imageData, shape: [1, 3, pixelHeight, pixelWidth])
        }
    }

    /// Copies a planar float buffer into tightly packed `Data`, dropping any
    /// per-row padding indicated by the buffer's `rowBytes`.
    private static func packedFloatData(from plane: vImage_Buffer) -> Data {
        let packedRowBytes = Int(plane.width) * MemoryLayout<Float>.size
        let rowCount = Int(plane.height)
        guard let base = plane.data, packedRowBytes > 0, rowCount > 0 else {
            return Data()
        }

        // Fast path: rows are already contiguous.
        if plane.rowBytes == packedRowBytes {
            return Data(bytes: base, count: packedRowBytes * rowCount)
        }

        // Rows are padded: copy only the meaningful bytes of each row.
        var packed = Data(capacity: packedRowBytes * rowCount)
        for row in 0..<rowCount {
            let rowStart = base.advanced(by: row * plane.rowBytes)
            packed.append(rowStart.assumingMemoryBound(to: UInt8.self), count: packedRowBytes)
        }
        return packed
    }
}