Image2image - swift (#116)
* Image2Image Encoder Encoder * Scheduler and pipeline * fix scheduler * cli * remove CLI comment * disable dpm multistep solver with image2image * clamp initial timestamp * Store timesteps in reverse order for consistency. * Report actual number of steps. * uint32 * PRComments * remove old initializer * pr comments * change name and add error handling also fix 512 hard coded * Add fix for Jpegs --------- Co-authored-by: Pedro Cuenca <pedro@huggingface.co>pull/125/head
parent
f0fc326477
commit
fa7bbdc255
@ -0,0 +1,29 @@
|
|||||||
|
// For licensing see accompanying LICENSE.md file.
|
||||||
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// Pair of coefficients derived from a cumulative-alpha-product table at a single
/// timestep. `Encoder.encode` feeds them to the model as the `sqrt_alphas_cumprod`
/// and `sqrt_one_minus_alphas_cumprod` inputs.
public struct AlphasCumprodCalculation {
    /// Square root of the cumulative alpha product at the selected timestep.
    public var sqrtAlphasCumprod: Float
    /// Square root of one minus the cumulative alpha product at the selected timestep.
    public var sqrtOneMinusAlphasCumprod: Float

    /// Creates a value from coefficients that were already computed by the caller.
    ///
    /// - Parameters:
    ///   - sqrtAlphasCumprod: Stored as-is.
    ///   - sqrtOneMinusAlphasCumprod: Stored as-is.
    public init(
        sqrtAlphasCumprod: Float,
        sqrtOneMinusAlphasCumprod: Float
    ) {
        self.sqrtAlphasCumprod = sqrtAlphasCumprod
        self.sqrtOneMinusAlphasCumprod = sqrtOneMinusAlphasCumprod
    }

    /// Selects the initial timestep implied by `strength` and reads both coefficients
    /// from `alphasCumprod` at that index.
    ///
    /// The selected index is
    /// `timesteps - (timesteps / steps) * (steps - Int(strength * steps)) + 1`,
    /// clamped so that it is always a valid index into `alphasCumprod` and never
    /// exceeds `timesteps - 1`.
    ///
    /// - Parameters:
    ///   - alphasCumprod: Cumulative alpha products indexed by timestep. Must not be empty.
    ///   - timesteps: Number of training timesteps the table is expected to cover.
    ///   - steps: Number of inference steps. Must be greater than zero.
    ///   - strength: Fraction of `steps` to apply; values outside `0...1` are absorbed by the clamp.
    /// - Precondition: `steps > 0` and `alphasCumprod` is not empty.
    public init(
        alphasCumprod: [Float],
        timesteps: Int = 1_000,
        steps: Int,
        strength: Float
    ) {
        // Fail with a readable message instead of an anonymous division-by-zero
        // or index-out-of-range trap.
        precondition(steps > 0, "AlphasCumprodCalculation requires steps > 0")
        precondition(!alphasCumprod.isEmpty, "AlphasCumprodCalculation requires a non-empty alphasCumprod table")

        let tEnc = Int(strength * Float(steps))
        let unclamped = timesteps - timesteps / steps * (steps - tEnc) + 1

        // The last usable index is bounded by both the nominal timestep count and
        // the real size of the table, so a short table can no longer overrun.
        let lastIndex = max(0, min(timesteps, alphasCumprod.count) - 1)
        let initTimestep = max(0, min(unclamped, lastIndex))

        let alpha = alphasCumprod[initTimestep]
        self.sqrtAlphasCumprod = alpha.squareRoot()
        self.sqrtOneMinusAlphasCumprod = (1 - alpha).squareRoot()
    }
}
|
@ -0,0 +1,120 @@
|
|||||||
|
// For licensing see accompanying LICENSE.md file.
|
||||||
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import Accelerate
|
||||||
|
import CoreML
|
||||||
|
|
||||||
|
@available(iOS 16.0, macOS 13.0, *)
extension CGImage {

    typealias PixelBufferPFx1 = vImage.PixelBuffer<vImage.PlanarF>
    typealias PixelBufferP8x3 = vImage.PixelBuffer<vImage.Planar8x3>
    typealias PixelBufferIFx3 = vImage.PixelBuffer<vImage.InterleavedFx3>
    typealias PixelBufferI8x3 = vImage.PixelBuffer<vImage.Interleaved8x3>

    /// Errors thrown while converting between `CGImage` and `MLShapedArray`.
    public enum ShapedArrayError: String, Swift.Error {
        /// The shaped array is not laid out as `[N, 3, H, W]`.
        case wrongNumberOfChannels
        /// An image format could not be described, or a vImage conversion failed.
        case incorrectFormatsConvertingToShapedArray
        /// The vImage converter between the source and intermediate formats could not be created.
        case vImageConverterNotInitialized
    }

    /// Builds an 8-bit RGB image (no alpha) from the first element of a `[N, C, H, W]` array.
    ///
    /// Values are mapped from `[-1.0, 1.0]` to `[0, 255]` and clipped.
    ///
    /// - Parameter array: Shaped array laid out as `[N, 3, H, W]`; only index 0 of `N` is read.
    /// - Returns: The decoded image.
    /// - Throws: `ShapedArrayError.wrongNumberOfChannels` when the array does not have
    ///   four dimensions with exactly three channels.
    public static func fromShapedArray(_ array: MLShapedArray<Float32>) throws -> CGImage {

        // array is [N,C,H,W], where C==3. Check the rank first so the indexing
        // below cannot run past the end of `shape`.
        guard array.shape.count == 4, array.shape[1] == 3 else {
            throw ShapedArrayError.wrongNumberOfChannels
        }
        let channelCount = array.shape[1]

        let height = array.shape[2]
        let width = array.shape[3]

        // Normalize each channel into a float between 0 and 1.0
        let floatChannels = (0..<channelCount).map { i in

            // Normalized channel output
            let cOut = PixelBufferPFx1(width: width, height: height)

            // Reference this channel in the array and normalize
            array[0][i].withUnsafeShapedBufferPointer { ptr, _, strides in
                let cIn = PixelBufferPFx1(data: .init(mutating: ptr.baseAddress!),
                                          width: width, height: height,
                                          byteCountPerRow: strides[0] * 4)
                // Map [-1.0 1.0] -> [0.0 1.0]
                cIn.multiply(by: 0.5, preBias: 1.0, postBias: 0.0, destination: cOut)
            }
            return cOut
        }

        // Convert to interleaved and then to UInt8
        let floatImage = PixelBufferIFx3(planarBuffers: floatChannels)
        let uint8Image = PixelBufferI8x3(width: width, height: height)
        floatImage.convert(to: uint8Image) // maps [0.0 1.0] -> [0 255] and clips

        // Convert to uint8x3 to RGB CGImage (no alpha).
        // The format below is built from constants only, so the force unwraps
        // are kept from the original implementation.
        let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue)
        let cgImage = uint8Image.makeCGImage(cgImageFormat:
                .init(bitsPerComponent: 8,
                      bitsPerPixel: 3 * 8,
                      colorSpace: CGColorSpaceCreateDeviceRGB(),
                      bitmapInfo: bitmapInfo)!)!

        return cgImage
    }

    /// The image as a planar RGB shaped array with values scaled to `[-1.0, 1.0]`.
    ///
    /// The image is first converted to 8-bit ARGB, then split into float planes;
    /// the alpha plane is discarded. The result has shape `[1, 3, height, width]`,
    /// matching the `[N, C, H, W]` layout read by `fromShapedArray(_:)`.
    ///
    /// - Throws: `ShapedArrayError.incorrectFormatsConvertingToShapedArray` when a
    ///   format cannot be described or a vImage conversion reports an error,
    ///   `ShapedArrayError.vImageConverterNotInitialized` when the converter cannot
    ///   be created, and any error thrown while allocating vImage buffers.
    public var plannerRGBShapedArray: MLShapedArray<Float32> {
        get throws {
            guard
                var sourceFormat = vImage_CGImageFormat(cgImage: self),
                var mediumFormat = vImage_CGImageFormat(
                    bitsPerComponent: 8 * MemoryLayout<UInt8>.size,
                    bitsPerPixel: 8 * MemoryLayout<UInt8>.size * 4,
                    colorSpace: CGColorSpaceCreateDeviceRGB(),
                    bitmapInfo: CGBitmapInfo(rawValue: CGImageAlphaInfo.first.rawValue)),
                let width = vImagePixelCount(exactly: self.width),
                let height = vImagePixelCount(exactly: self.height)
            else {
                throw ShapedArrayError.incorrectFormatsConvertingToShapedArray
            }

            let pixelWidth = Int(width)
            let pixelHeight = Int(height)
            let floatBitsPerPixel = 8 * UInt32(MemoryLayout<Float>.size)
            let diagnostics = vImage_Flags(kvImagePrintDiagnosticsToConsole)

            // Every vImage_Buffer created here owns heap memory; release each one
            // on all exit paths, including the throwing ones.
            var sourceImageBuffer = try vImage_Buffer(cgImage: self)
            defer { sourceImageBuffer.free() }

            var mediumDestination = try vImage_Buffer(width: pixelWidth, height: pixelHeight, bitsPerPixel: mediumFormat.bitsPerPixel)
            defer { mediumDestination.free() }

            let converter = vImageConverter_CreateWithCGImageFormat(
                &sourceFormat,
                &mediumFormat,
                nil,
                diagnostics,
                nil)

            guard let converter = converter?.takeRetainedValue() else {
                throw ShapedArrayError.vImageConverterNotInitialized
            }

            // Source format -> 8-bit ARGB. A failed conversion would leave the
            // intermediate buffer undefined, so surface it instead of continuing.
            guard vImageConvert_AnyToAny(converter, &sourceImageBuffer, &mediumDestination, nil, diagnostics) == kvImageNoError else {
                throw ShapedArrayError.incorrectFormatsConvertingToShapedArray
            }

            var destinationA = try vImage_Buffer(width: pixelWidth, height: pixelHeight, bitsPerPixel: floatBitsPerPixel)
            defer { destinationA.free() }
            var destinationR = try vImage_Buffer(width: pixelWidth, height: pixelHeight, bitsPerPixel: floatBitsPerPixel)
            defer { destinationR.free() }
            var destinationG = try vImage_Buffer(width: pixelWidth, height: pixelHeight, bitsPerPixel: floatBitsPerPixel)
            defer { destinationG.free() }
            var destinationB = try vImage_Buffer(width: pixelWidth, height: pixelHeight, bitsPerPixel: floatBitsPerPixel)
            defer { destinationB.free() }

            // 8-bit 0...255 is mapped linearly onto minFloat...maxFloat per channel.
            var minFloat: [Float] = [-1.0, -1.0, -1.0, -1.0]
            var maxFloat: [Float] = [1.0, 1.0, 1.0, 1.0]

            guard vImageConvert_ARGB8888toPlanarF(&mediumDestination, &destinationA, &destinationR, &destinationG, &destinationB, &maxFloat, &minFloat, vImage_Flags(kvImageNoFlags)) == kvImageNoError else {
                throw ShapedArrayError.incorrectFormatsConvertingToShapedArray
            }

            // Copies one plane into tightly packed rows. vImage may pad each row
            // (rowBytes >= width * 4), so a single contiguous copy is not safe.
            func packedRows(of plane: vImage_Buffer) -> Data {
                let rowLength = pixelWidth * MemoryLayout<Float>.size
                var packed = Data(capacity: rowLength * pixelHeight)
                for row in 0..<pixelHeight {
                    let rowStart = plane.data.advanced(by: row * plane.rowBytes)
                    packed.append(contentsOf: UnsafeRawBufferPointer(start: rowStart, count: rowLength))
                }
                return packed
            }

            // Alpha is intentionally dropped; planes are concatenated as R, G, B.
            let imageData = packedRows(of: destinationR) + packedRows(of: destinationG) + packedRows(of: destinationB)

            // Planes are row-major, so the layout is [N, C, H, W].
            let shapedArray = MLShapedArray<Float32>(data: imageData, shape: [1, 3, pixelHeight, pixelWidth])

            return shapedArray
        }
    }
}
|
||||||
|
|
@ -0,0 +1,131 @@
|
|||||||
|
// For licensing see accompanying LICENSE.md file.
|
||||||
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import CoreML
|
||||||
|
|
||||||
|
/// Wraps the VAE encoder Core ML model used for image2image.
@available(iOS 16.0, macOS 13.1, *)
public struct Encoder: ResourceManaging {

    /// Names of the model inputs supplied by `encode`.
    public enum FeatureName: String {
        case sample
        case diagonalNoise = "diagonal_noise"
        case noise
        case sqrtAlphasCumprod = "sqrt_alphas_cumprod"
        case sqrtOneMinusAlphasCumprod = "sqrt_one_minus_alphas_cumprod"
    }

    /// Failures reported by `encode`.
    public enum Error: String, Swift.Error {
        case latentOutputNotValid
        case batchLatentOutputEmpty
        case sampleInputShapeNotCorrect
        case noiseInputShapeNotCorrect
    }

    /// VAE encoder model + post math and adding noise from scheduler
    var model: ManagedMLModel

    /// Create encoder from Core ML model
    ///
    /// - Parameters:
    ///     - url: Location of compiled VAE encoder Core ML model
    ///     - configuration: configuration to be used when the model is loaded
    /// - Returns: An encoder that will lazily load its required resources when needed or requested
    public init(modelAt url: URL, configuration: MLModelConfiguration) {
        model = ManagedMLModel(modelAt: url, configuration: configuration)
    }

    /// Ensure the model has been loaded into memory
    public func loadResources() throws {
        try model.loadResources()
    }

    /// Unload the underlying model to free up memory
    public func unloadResources() {
        model.unloadResources()
    }

    /// Serial queue on which predictions are run
    let queue = DispatchQueue(label: "encoder.predict")

    /// Encode a starting image into latent space
    ///
    /// - Parameters:
    ///   - image: image used for image2image
    ///   - diagonalNoise: random noise for `DiagonalGaussianDistribution` operation
    ///   - noise: random noise for initial latent space based on strength argument
    ///   - alphasCumprodStep: calculations using the scheduler traditionally calculated in the pipeline in pyTorch Diffusers library.
    /// - Returns: The encoded latent space as MLShapedArray
    /// - Throws: `Error.sampleInputShapeNotCorrect` or `Error.noiseInputShapeNotCorrect` when an
    ///   input shape differs from the shape the model declares, `Error.latentOutputNotValid` when a
    ///   prediction has no multi-array output, `Error.batchLatentOutputEmpty` when no prediction is
    ///   returned, plus any error raised while converting the image or running the model.
    public func encode(
        image: CGImage,
        diagonalNoise: MLShapedArray<Float32>,
        noise: MLShapedArray<Float32>,
        alphasCumprodStep: AlphasCumprodCalculation
    ) throws -> MLShapedArray<Float32> {
        let sample = try image.plannerRGBShapedArray

        // The two scheduler coefficients are passed to the model as 1x1 arrays.
        let sqrtAlphasCumprod = MLShapedArray(
            scalars: [alphasCumprodStep.sqrtAlphasCumprod],
            shape: [1, 1])
        let sqrtOneMinusAlphasCumprod = MLShapedArray(
            scalars: [alphasCumprodStep.sqrtOneMinusAlphasCumprod],
            shape: [1, 1])

        let namedInputs: [(FeatureName, MLShapedArray<Float32>)] = [
            (.sample, sample),
            (.diagonalNoise, diagonalNoise),
            (.noise, noise),
            (.sqrtAlphasCumprod, sqrtAlphasCumprod),
            (.sqrtOneMinusAlphasCumprod, sqrtOneMinusAlphasCumprod),
        ]
        var inputFeatures: [String: Any] = [:]
        for (name, value) in namedInputs {
            inputFeatures[name.rawValue] = MLMultiArray(value)
        }

        let provider = try MLDictionaryFeatureProvider(dictionary: inputFeatures)
        let batch = MLArrayBatchProvider(array: [provider])

        // Batch predict with model
        let predictions = try queue.sync {
            try model.perform { loadedModel in
                // Shape the model declares for an input, or nil when it declares none.
                let declaredShape: (FeatureName) -> [Int]? = { name in
                    loadedModel.modelDescription
                        .inputDescriptionsByName[name.rawValue]?
                        .multiArrayConstraint?.shape as? [Int]
                }

                if let expected = declaredShape(.sample), sample.shape != expected {
                    // TODO: Consider auto resizing and croping similar to how Vision or CoreML auto-generated Swift code can accomplish with `MLFeatureValue`
                    throw Error.sampleInputShapeNotCorrect
                }
                if let expected = declaredShape(.noise), noise.shape != expected {
                    throw Error.noiseInputShapeNotCorrect
                }
                if let expected = declaredShape(.diagonalNoise), diagonalNoise.shape != expected {
                    throw Error.noiseInputShapeNotCorrect
                }

                return try loadedModel.predictions(fromBatch: batch)
            }
        }

        // Take the first named output of every prediction as that item's latents.
        var batchLatents: [MLShapedArray<Float32>] = []
        for index in 0..<predictions.count {
            let features = predictions.features(at: index)
            guard
                let outputName = features.featureNames.first,
                let output = features.featureValue(for: outputName)?.multiArrayValue
            else {
                throw Error.latentOutputNotValid
            }
            batchLatents.append(MLShapedArray<Float32>(output))
        }

        guard let latents = batchLatents.first else {
            throw Error.batchLatentOutputEmpty
        }
        return latents
    }

}
|
@ -0,0 +1,60 @@
|
|||||||
|
// For licensing see accompanying LICENSE.md file.
|
||||||
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import CoreGraphics
|
||||||
|
|
||||||
|
@available(iOS 16.2, macOS 13.1, *)
extension StableDiffusionPipeline {

    /// Type of processing that will be performed to generate an image
    public enum Mode {
        case textToImage
        case imageToImage
        // case inPainting
    }

    /// Image generation configuration
    public struct Configuration: Hashable {

        /// Text prompt to guide sampling
        public var prompt: String
        /// Negative text prompt to guide sampling
        public var negativePrompt: String = ""
        /// Starting image for image2image or in-painting
        public var startingImage: CGImage?
        //public var maskImage: CGImage? = nil
        /// Strength for image2image; `mode` is `.imageToImage` only while this is below 1.0
        public var strength: Float = 1.0
        /// Number of images to generate
        public var imageCount: Int = 1
        /// Number of inference steps to perform
        public var stepCount: Int = 50
        /// Random seed which to start generation
        public var seed: UInt32 = 0
        /// Controls the influence of the text prompt on sampling process (0=random images)
        public var guidanceScale: Float = 7.5
        /// Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
        public var disableSafety: Bool = false
        /// The type of Scheduler to use.
        public var schedulerType: StableDiffusionScheduler = .pndmScheduler

        /// Given the configuration, what mode will be used for generation:
        /// `.imageToImage` when a starting image is set and `strength` is below 1.0,
        /// `.textToImage` otherwise.
        public var mode: Mode {
            let hasStartingImage = startingImage != nil
            return (hasStartingImage && strength < 1.0) ? .imageToImage : .textToImage
        }

        /// Creates a configuration for `prompt`, leaving every other setting at its default.
        public init(prompt: String) {
            self.prompt = prompt
        }

    }

}
|
Loading…
Reference in new issue