From fa7bbdc255131b4106c098a75971c4dddc2c2529 Mon Sep 17 00:00:00 2001 From: Timothy Kautz Date: Thu, 9 Feb 2023 10:30:21 -0800 Subject: [PATCH] Image2image - swift (#116) * Image2Image Encoder Encoder * Scheduler and pipeline * fix scheduler * cli * remove CLI comment * disable dpm multistep solver with image2image * clamp initial timestamp * Store timesteps in reverse order for consistency. * Report actual number of steps. * uint32 * PRComments * remove old initializer * pr comments * change name and add error handling also fix 512 hard coded * Add fix for Jpegs --------- Co-authored-by: Pedro Cuenca --- .../pipeline/AlphasCumprodCalculation.swift | 29 ++++ .../pipeline/CGImage+vImage.swift | 120 ++++++++++++++++ swift/StableDiffusion/pipeline/Decoder.swift | 52 +------ swift/StableDiffusion/pipeline/Encoder.swift | 131 ++++++++++++++++++ .../StableDiffusion/pipeline/Scheduler.swift | 32 +++++ .../StableDiffusionPipeline+Resources.swift | 11 ++ ...tableDiffusionPipeline.Configuration.swift | 60 ++++++++ .../pipeline/StableDiffusionPipeline.swift | 111 ++++++++++----- swift/StableDiffusionCLI/main.swift | 68 ++++++--- 9 files changed, 517 insertions(+), 97 deletions(-) create mode 100644 swift/StableDiffusion/pipeline/AlphasCumprodCalculation.swift create mode 100644 swift/StableDiffusion/pipeline/CGImage+vImage.swift create mode 100644 swift/StableDiffusion/pipeline/Encoder.swift create mode 100644 swift/StableDiffusion/pipeline/StableDiffusionPipeline.Configuration.swift diff --git a/swift/StableDiffusion/pipeline/AlphasCumprodCalculation.swift b/swift/StableDiffusion/pipeline/AlphasCumprodCalculation.swift new file mode 100644 index 0000000..7875683 --- /dev/null +++ b/swift/StableDiffusion/pipeline/AlphasCumprodCalculation.swift @@ -0,0 +1,29 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright (C) 2022 Apple Inc. All Rights Reserved. + +import Foundation + +public struct AlphasCumprodCalculation { + public var sqrtAlphasCumprod: Float + public var sqrtOneMinusAlphasCumprod: Float + + public init( + sqrtAlphasCumprod: Float, + sqrtOneMinusAlphasCumprod: Float + ) { + self.sqrtAlphasCumprod = sqrtAlphasCumprod + self.sqrtOneMinusAlphasCumprod = sqrtOneMinusAlphasCumprod + } + + public init( + alphasCumprod: [Float], + timesteps: Int = 1_000, + steps: Int, + strength: Float + ) { + let tEnc = Int(strength * Float(steps)) + let initTimestep = min(max(0, timesteps - timesteps / steps * (steps - tEnc) + 1), timesteps - 1) + self.sqrtAlphasCumprod = alphasCumprod[initTimestep].squareRoot() + self.sqrtOneMinusAlphasCumprod = (1 - alphasCumprod[initTimestep]).squareRoot() + } +} diff --git a/swift/StableDiffusion/pipeline/CGImage+vImage.swift b/swift/StableDiffusion/pipeline/CGImage+vImage.swift new file mode 100644 index 0000000..6910b20 --- /dev/null +++ b/swift/StableDiffusion/pipeline/CGImage+vImage.swift @@ -0,0 +1,120 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright (C) 2022 Apple Inc. All Rights Reserved. + +import Foundation +import Accelerate +import CoreML + +@available(iOS 16.0, macOS 13.0, *) +extension CGImage { + + typealias PixelBufferPFx1 = vImage.PixelBuffer + typealias PixelBufferP8x3 = vImage.PixelBuffer + typealias PixelBufferIFx3 = vImage.PixelBuffer + typealias PixelBufferI8x3 = vImage.PixelBuffer + + public enum ShapedArrayError: String, Swift.Error { + case wrongNumberOfChannels + case incorrectFormatsConvertingToShapedArray + case vImageConverterNotInitialized + } + + public static func fromShapedArray(_ array: MLShapedArray) throws -> CGImage { + + // array is [N,C,H,W], where C==3 + let channelCount = array.shape[1] + guard channelCount == 3 else { + throw ShapedArrayError.wrongNumberOfChannels + } + + let height = array.shape[2] + let width = array.shape[3] + + // Normalize each channel into a float between 0 and 1.0 + let floatChannels = (0.. [0.0 1.0] + cIn.multiply(by: 0.5, preBias: 1.0, postBias: 0.0, destination: cOut) + } + return cOut + } + + // Convert to interleaved and then to UInt8 + let floatImage = PixelBufferIFx3(planarBuffers: floatChannels) + let uint8Image = PixelBufferI8x3(width: width, height: height) + floatImage.convert(to:uint8Image) // maps [0.0 1.0] -> [0 255] and clips + + // Convert to uint8x3 to RGB CGImage (no alpha) + let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue) + let cgImage = uint8Image.makeCGImage(cgImageFormat: + .init(bitsPerComponent: 8, + bitsPerPixel: 3*8, + colorSpace: CGColorSpaceCreateDeviceRGB(), + bitmapInfo: bitmapInfo)!)! + + return cgImage + } + + public var plannerRGBShapedArray: MLShapedArray { + get throws { + guard + var sourceFormat = vImage_CGImageFormat(cgImage: self), + var mediumFormat = vImage_CGImageFormat( + bitsPerComponent: 8 * MemoryLayout.size, + bitsPerPixel: 8 * MemoryLayout.size * 4, + colorSpace: CGColorSpaceCreateDeviceRGB(), + bitmapInfo: CGBitmapInfo(rawValue: CGImageAlphaInfo.first.rawValue)), + let width = vImagePixelCount(exactly: self.width), + let height = vImagePixelCount(exactly: self.height) + else { + throw ShapedArrayError.incorrectFormatsConvertingToShapedArray + } + + var sourceImageBuffer = try vImage_Buffer(cgImage: self) + + var mediumDesination = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: mediumFormat.bitsPerPixel) + + let converter = vImageConverter_CreateWithCGImageFormat( + &sourceFormat, + &mediumFormat, + nil, + vImage_Flags(kvImagePrintDiagnosticsToConsole), + nil) + + guard let converter = converter?.takeRetainedValue() else { + throw ShapedArrayError.vImageConverterNotInitialized + } + + vImageConvert_AnyToAny(converter, &sourceImageBuffer, &mediumDesination, nil, vImage_Flags(kvImagePrintDiagnosticsToConsole)) + + var destinationA = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout.size)) + var destinationR = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout.size)) + var destinationG = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout.size)) + var destinationB = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout.size)) + + var minFloat: [Float] = [-1.0, -1.0, -1.0, -1.0] + var maxFloat: [Float] = [1.0, 1.0, 1.0, 1.0] + + vImageConvert_ARGB8888toPlanarF(&mediumDesination, &destinationA, &destinationR, &destinationG, &destinationB, &maxFloat, &minFloat, .zero) + + let redData = Data(bytes: destinationR.data, count: Int(width) * Int(height) * MemoryLayout.size) + let greenData = Data(bytes: destinationG.data, count: Int(width) * Int(height) * MemoryLayout.size) + let blueData = Data(bytes: destinationB.data, count: Int(width) * Int(height) * MemoryLayout.size) + + let imageData = redData + greenData + blueData + + let shapedArray = MLShapedArray(data: imageData, shape: [1, 3, self.width, self.height]) + + return shapedArray + } + } +} + diff --git a/swift/StableDiffusion/pipeline/Decoder.swift b/swift/StableDiffusion/pipeline/Decoder.swift index 04f04ba..e9b2c70 100644 --- a/swift/StableDiffusion/pipeline/Decoder.swift +++ b/swift/StableDiffusion/pipeline/Decoder.swift @@ -3,7 +3,6 @@ import Foundation import CoreML -import Accelerate /// A decoder model which produces RGB images from latent samples @available(iOS 16.2, macOS 13.1, *) @@ -57,12 +56,11 @@ public struct Decoder: ResourceManaging { } // Transform the outputs to CGImages - let images: [CGImage] = (0..(output)) + return try CGImage.fromShapedArray(MLShapedArray(output)) } return images @@ -74,50 +72,4 @@ public struct Decoder: ResourceManaging { } } - typealias PixelBufferPFx1 = vImage.PixelBuffer - typealias PixelBufferP8x3 = vImage.PixelBuffer - typealias PixelBufferIFx3 = vImage.PixelBuffer - typealias PixelBufferI8x3 = vImage.PixelBuffer - - func toRGBCGImage(_ array: MLShapedArray) -> CGImage { - - // array is [N,C,H,W], where C==3 - let channelCount = array.shape[1] - assert(channelCount == 3, - "Decoding model output has \(channelCount) channels, expected 3") - let height = array.shape[2] - let width = array.shape[3] - - // Normalize each channel into a float between 0 and 1.0 - let floatChannels = (0.. [0.0 1.0] - cIn.multiply(by: 0.5, preBias: 1.0, postBias: 0.0, destination: cOut) - } - return cOut - } - - // Convert to interleaved and then to UInt8 - let floatImage = PixelBufferIFx3(planarBuffers: floatChannels) - let uint8Image = PixelBufferI8x3(width: width, height: height) - floatImage.convert(to:uint8Image) // maps [0.0 1.0] -> [0 255] and clips - - // Convert to uint8x3 to RGB CGImage (no alpha) - let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue) - let cgImage = uint8Image.makeCGImage(cgImageFormat: - .init(bitsPerComponent: 8, - bitsPerPixel: 3*8, - colorSpace: CGColorSpaceCreateDeviceRGB(), - bitmapInfo: bitmapInfo)!)! - - return cgImage - } } diff --git a/swift/StableDiffusion/pipeline/Encoder.swift b/swift/StableDiffusion/pipeline/Encoder.swift new file mode 100644 index 0000000..b35f5a9 --- /dev/null +++ b/swift/StableDiffusion/pipeline/Encoder.swift @@ -0,0 +1,131 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright (C) 2022 Apple Inc. All Rights Reserved. + +import Foundation +import CoreML + +@available(iOS 16.0, macOS 13.1, *) +/// Encoder, currently supports image2image +public struct Encoder: ResourceManaging { + + public enum FeatureName: String { + case sample = "sample" + case diagonalNoise = "diagonal_noise" + case noise = "noise" + case sqrtAlphasCumprod = "sqrt_alphas_cumprod" + case sqrtOneMinusAlphasCumprod = "sqrt_one_minus_alphas_cumprod" + } + + public enum Error: String, Swift.Error { + case latentOutputNotValid + case batchLatentOutputEmpty + case sampleInputShapeNotCorrect + case noiseInputShapeNotCorrect + } + + /// VAE encoder model + post math and adding noise from schedular + var model: ManagedMLModel + + /// Create encoder from Core ML model + /// + /// - Parameters: + /// - url: Location of compiled VAE encoder Core ML model + /// - configuration: configuration to be used when the model is loaded + /// - Returns: An encoder that will lazily load its required resources when needed or requested + public init(modelAt url: URL, configuration: MLModelConfiguration) { + self.model = ManagedMLModel(modelAt: url, configuration: configuration) + } + + /// Ensure the model has been loaded into memory + public func loadResources() throws { + try model.loadResources() + } + + /// Unload the underlying model to free up memory + public func unloadResources() { + model.unloadResources() + } + + /// Prediction queue + let queue = DispatchQueue(label: "encoder.predict") + + /// Batch encode latent samples into images + /// - Parameters: + /// - image: image used for image2image + /// - diagonalNoise: random noise for `DiagonalGaussianDistribution` operation + /// - noise: random noise for initial latent space based on strength argument + /// - alphasCumprodStep: calculations using the scheduler traditionally calculated in the pipeline in pyTorch Diffusers library. + /// - Returns: The encoded latent space as MLShapedArray + public func encode( + image: CGImage, + diagonalNoise: MLShapedArray, + noise: MLShapedArray, + alphasCumprodStep: AlphasCumprodCalculation + ) throws -> MLShapedArray { + let sample = try image.plannerRGBShapedArray + let sqrtAlphasCumprod = MLShapedArray(scalars: [alphasCumprodStep.sqrtAlphasCumprod], shape: [1, 1]) + let sqrtOneMinusAlphasCumprod = MLShapedArray(scalars: [alphasCumprodStep.sqrtOneMinusAlphasCumprod], shape: [1, 1]) + + let dict: [String: Any] = [ + FeatureName.sample.rawValue: MLMultiArray(sample), + FeatureName.diagonalNoise.rawValue: MLMultiArray(diagonalNoise), + FeatureName.noise.rawValue: MLMultiArray(noise), + FeatureName.sqrtAlphasCumprod.rawValue: MLMultiArray(sqrtAlphasCumprod), + FeatureName.sqrtOneMinusAlphasCumprod.rawValue: MLMultiArray(sqrtOneMinusAlphasCumprod), + ] + let featureProvider = try MLDictionaryFeatureProvider(dictionary: dict) + + let batch = MLArrayBatchProvider(array: [featureProvider]) + + // Batch predict with model + + let results = try queue.sync { + try model.perform { model in + if let feature = model.modelDescription.inputDescriptionsByName[FeatureName.sample.rawValue], + let shape = feature.multiArrayConstraint?.shape as? [Int] + { + guard sample.shape == shape else { + // TODO: Consider auto resizing and croping similar to how Vision or CoreML auto-generated Swift code can accomplish with `MLFeatureValue` + throw Error.sampleInputShapeNotCorrect + } + } + + if let feature = model.modelDescription.inputDescriptionsByName[FeatureName.noise.rawValue], + let shape = feature.multiArrayConstraint?.shape as? [Int] + { + guard noise.shape == shape else { + throw Error.noiseInputShapeNotCorrect + } + } + + if let feature = model.modelDescription.inputDescriptionsByName[FeatureName.diagonalNoise.rawValue], + let shape = feature.multiArrayConstraint?.shape as? [Int] + { + guard diagonalNoise.shape == shape else { + throw Error.noiseInputShapeNotCorrect + } + } + + return try model.predictions(fromBatch: batch) + } + } + + let batchLatents: [MLShapedArray] = try (0.. [Int] + /// Schedule of betas which controls the amount of noise added at each timestep var betas: [Float] { get } @@ -71,6 +74,34 @@ public extension Scheduler { } } +// MARK: - Image2Image + +@available(iOS 16.2, macOS 13.1, *) +public extension Scheduler { + + func calculateAlphasCumprod(strength: Float) -> AlphasCumprodCalculation { + AlphasCumprodCalculation( + alphasCumprod: alphasCumProd, + timesteps: trainStepCount, + steps: inferenceStepCount, + strength: strength) + } +} + +// MARK: - Timesteps + +@available(iOS 16.2, macOS 13.1, *) +public extension Scheduler { + func calculateTimesteps(strength: Float?) -> [Int] { + guard let strength else { return timeSteps } + let startStep = max(inferenceStepCount - Int(Float(inferenceStepCount) * strength), 0) + let actualTimesteps = Array(timeSteps[startStep...]) + return actualTimesteps + } +} + +// MARK: - BetaSchedule + /// How to map a beta range to a sequence of betas to step over @available(iOS 16.2, macOS 13.1, *) public enum BetaSchedule { @@ -80,6 +111,7 @@ public enum BetaSchedule { case scaledLinear } +// MARK: - PNDMScheduler /// A scheduler used to compute a de-noised image /// diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline+Resources.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline+Resources.swift index 65c6e03..ecb349e 100644 --- a/swift/StableDiffusion/pipeline/StableDiffusionPipeline+Resources.swift +++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline+Resources.swift @@ -14,6 +14,7 @@ public extension StableDiffusionPipeline { public let unetChunk1URL: URL public let unetChunk2URL: URL public let decoderURL: URL + public let encoderURL: URL public let safetyCheckerURL: URL public let vocabURL: URL public let mergesURL: URL @@ -24,6 +25,7 @@ public extension StableDiffusionPipeline { unetChunk1URL = baseURL.appending(path: "UnetChunk1.mlmodelc") unetChunk2URL = baseURL.appending(path: "UnetChunk2.mlmodelc") decoderURL = baseURL.appending(path: "VAEDecoder.mlmodelc") + encoderURL = baseURL.appending(path: "VAEEncoder.mlmodelc") safetyCheckerURL = baseURL.appending(path: "SafetyChecker.mlmodelc") vocabURL = baseURL.appending(path: "vocab.json") mergesURL = baseURL.appending(path: "merges.txt") @@ -74,11 +76,20 @@ public extension StableDiffusionPipeline { FileManager.default.fileExists(atPath: urls.safetyCheckerURL.path) { safetyChecker = SafetyChecker(modelAt: urls.safetyCheckerURL, configuration: config) } + + // Optional Image Encoder + let encoder: Encoder? + if FileManager.default.fileExists(atPath: urls.encoderURL.path) { + encoder = Encoder(modelAt: urls.encoderURL, configuration: config) + } else { + encoder = nil + } // Construct pipeline self.init(textEncoder: textEncoder, unet: unet, decoder: decoder, + encoder: encoder, safetyChecker: safetyChecker, reduceMemory: reduceMemory) } diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.Configuration.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.Configuration.swift new file mode 100644 index 0000000..8933557 --- /dev/null +++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.Configuration.swift @@ -0,0 +1,60 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright (C) 2022 Apple Inc. All Rights Reserved. + +import Foundation +import CoreGraphics + +@available(iOS 16.2, macOS 13.1, *) +extension StableDiffusionPipeline { + + /// Tyoe of processing that will be performed to generate an image + public enum Mode { + case textToImage + case imageToImage + // case inPainting + } + + /// Image generation configuration + public struct Configuration: Hashable { + + /// Text prompt to guide sampling + public var prompt: String + /// Negative text prompt to guide sampling + public var negativePrompt: String = "" + /// Starting image for image2image or in-painting + public var startingImage: CGImage? = nil + //public var maskImage: CGImage? = nil + public var strength: Float = 1.0 + /// Number of images to generate + public var imageCount: Int = 1 + /// Number of inference steps to perform + public var stepCount: Int = 50 + /// Random seed which to start generation + public var seed: UInt32 = 0 + /// Controls the influence of the text prompt on sampling process (0=random images) + public var guidanceScale: Float = 7.5 + /// Safety checks are only performed if `self.canSafetyCheck && !disableSafety` + public var disableSafety: Bool = false + /// The type of Scheduler to use. + public var schedulerType: StableDiffusionScheduler = .pndmScheduler + + /// Given the configuration, what mode will be used for generation + public var mode: Mode { + guard startingImage != nil else { + return .textToImage + } + guard strength < 1.0 else { + return .textToImage + } + return .imageToImage + } + + public init( + prompt: String + ) { + self.prompt = prompt + } + + } + +} diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift index ea65472..f279247 100644 --- a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift +++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift @@ -20,7 +20,11 @@ public enum StableDiffusionScheduler { /// [Hugging Face Diffusers Pipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) @available(iOS 16.2, macOS 13.1, *) public struct StableDiffusionPipeline: ResourceManaging { - + + public enum Error: String, Swift.Error { + case startingImageProvidedWithoutEncoder + } + /// Model to generate embeddings for tokenized input text var textEncoder: TextEncoder @@ -29,6 +33,9 @@ public struct StableDiffusionPipeline: ResourceManaging { /// Model used to generate final image from latent diffusion process var decoder: Decoder + + /// Model used to latent space for image2image, and soon, in-painting + var encoder: Encoder? /// Optional model for checking safety of generated image var safetyChecker: SafetyChecker? = nil @@ -58,11 +65,13 @@ public struct StableDiffusionPipeline: ResourceManaging { public init(textEncoder: TextEncoder, unet: Unet, decoder: Decoder, + encoder: Encoder?, safetyChecker: SafetyChecker? = nil, reduceMemory: Bool = false) { self.textEncoder = textEncoder self.unet = unet self.decoder = decoder + self.encoder = encoder self.safetyChecker = safetyChecker self.reduceMemory = reduceMemory } @@ -98,34 +107,20 @@ public struct StableDiffusionPipeline: ResourceManaging { try safetyChecker?.prewarmResources() } - /// Text to image generation using stable diffusion - /// + /// Image generation using stable diffusion /// - Parameters: - /// - prompt: Text prompt to guide sampling - /// - negativePrompt: Negative text prompt to guide sampling - /// - stepCount: Number of inference steps to perform - /// - imageCount: Number of samples/images to generate for the input prompt - /// - seed: Random seed which - /// - guidanceScale: Controls the influence of the text prompt on sampling process (0=random images) /// - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety` /// - progressHandler: Callback to perform after each step, stops on receiving false response /// - Returns: An array of `imageCount` optional images. /// The images will be nil if safety checks were performed and found the result to be un-safe public func generateImages( - prompt: String, - negativePrompt: String = "", - imageCount: Int = 1, - stepCount: Int = 50, - seed: UInt32 = 0, - guidanceScale: Float = 7.5, - disableSafety: Bool = false, - scheduler: StableDiffusionScheduler = .pndmScheduler, + configuration config: Configuration, progressHandler: (Progress) -> Bool = { _ in true } ) throws -> [CGImage?] { // Encode the input prompt and negative prompt - let promptEmbedding = try textEncoder.encode(prompt) - let negativePromptEmbedding = try textEncoder.encode(negativePrompt) + let promptEmbedding = try textEncoder.encode(config.prompt) + let negativePromptEmbedding = try textEncoder.encode(config.negativePrompt) if reduceMemory { textEncoder.unloadResources() @@ -141,19 +136,44 @@ public struct StableDiffusionPipeline: ResourceManaging { let hiddenStates = toHiddenStates(concatEmbedding) /// Setup schedulers - let scheduler: [Scheduler] = (0..] + let timestepStrength: Float? + + if + let startingImage = config.startingImage, + config.mode == .imageToImage + { + timestepStrength = config.strength + guard let encoder else { + throw Error.startingImageProvidedWithoutEncoder + } + + let noiseTuples = generateImage2ImageLatentSamples(config.imageCount, stdev: 1, seed: config.seed) + latents = try noiseTuples.map({ + try encoder.encode( + image: startingImage, + diagonalNoise: $0.diagonal, + noise: $0.latentNoise, + alphasCumprodStep: scheduler[0].calculateAlphasCumprod(strength: config.strength)) + }) + } else { + timestepStrength = nil + // Generate random latent samples from specified seed + latents = generateLatentSamples(config.imageCount, stdev: stdev, seed: config.seed) + } // De-noising loop - for (step,t) in scheduler[0].timeSteps.enumerated() { + let timeSteps: [Int] = scheduler[0].calculateTimesteps(strength: timestepStrength) + for (step,t) in timeSteps.enumerated() { // Expand the latents for classifier-free guidance // and input to the Unet noise prediction model @@ -169,11 +189,11 @@ public struct StableDiffusionPipeline: ResourceManaging { hiddenStates: hiddenStates ) - noise = performGuidance(noise, guidanceScale) + noise = performGuidance(noise, config.guidanceScale) // Have the scheduler compute the previous (t-1) latent // sample given the predicted noise and current sample - for i in 0.. [MLShapedArray] { @@ -215,6 +235,35 @@ public struct StableDiffusionPipeline: ResourceManaging { } return samples } + + + /// For image2image - + /// - Parameters: + /// - count: batch size + /// - stdev: 1 + /// - seed: seed provided + /// - diagonalAndLatentNoiseIsSame: Diffusions library does not seem to use the same noise for the `DiagonalGaussianDistribution` operation, + /// but I have seen implementations of pipelines where it is the same. + /// - Returns: An array of tuples of noise values with length of batch size. + func generateImage2ImageLatentSamples(_ count: Int, stdev: Float, seed: UInt32, diagonalAndLatentNoiseIsSame: Bool = false) -> [(diagonal: MLShapedArray, latentNoise: MLShapedArray)] { + var sampleShape = unet.latentSampleShape + sampleShape[0] = 1 + + var random = NumPyRandomSource(seed: UInt32(truncatingIfNeeded: seed)) + let samples = (0..( + converting: random.normalShapedArray(sampleShape, mean: 0.0, stdev: Double(stdev))) + return (noise, noise) + } else { + return (MLShapedArray( + converting: random.normalShapedArray(sampleShape, mean: 0.0, stdev: Double(stdev))), + MLShapedArray( + converting: random.normalShapedArray(sampleShape, mean: 0.0, stdev: Double(stdev)))) + } + } + return samples + } func toHiddenStates(_ embedding: MLShapedArray) -> MLShapedArray { // Unoptimized manual transpose [0, 2, None, 1] diff --git a/swift/StableDiffusionCLI/main.swift b/swift/StableDiffusionCLI/main.swift index 7343cf2..c5ae31a 100644 --- a/swift/StableDiffusionCLI/main.swift +++ b/swift/StableDiffusionCLI/main.swift @@ -7,6 +7,7 @@ import CoreML import Foundation import StableDiffusion import UniformTypeIdentifiers +import Cocoa @available(iOS 16.2, macOS 13.1, *) struct StableDiffusionSample: ParsableCommand { @@ -32,6 +33,12 @@ struct StableDiffusionSample: ParsableCommand { ) ) var resourcePath: String = "./" + + @Option(help: "Path to starting image.") + var image: String? = nil + + @Option(help: "Strength for image2image.") + var strength: Float = 0.5 @Option(help: "Number of images to sample / generate") var imageCount: Int = 1 @@ -51,7 +58,7 @@ struct StableDiffusionSample: ParsableCommand { var outputPath: String = "./" @Option(help: "Random seed") - var seed: UInt32 = 93 + var seed: UInt32 = UInt32.random(in: 0...UInt32.max) @Option(help: "Controls the influence of the text prompt on sampling process (0=random images)") var guidanceScale: Float = 7.5 @@ -84,27 +91,52 @@ struct StableDiffusionSample: ParsableCommand { disableSafety: disableSafety, reduceMemory: reduceMemory) try pipeline.loadResources() + + let startingImage: CGImage? + if let image { + let imageURL = URL(filePath: image) + do { + let imageData = try Data(contentsOf: imageURL) + guard + let nsImage = NSImage(data: imageData), + let loadedImage = nsImage.cgImage(forProposedRect: nil, context: nil, hints: nil) + else { + throw RunError.resources("Starting Image not available \(resourcePath)") + } + startingImage = loadedImage + } catch let error { + throw RunError.resources("Starting image not found \(imageURL), error: \(error)") + } + + } else { + startingImage = nil + } log("Sampling ...\n") let sampleTimer = SampleTimer() sampleTimer.start() + var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: prompt) + + pipelineConfig.negativePrompt = negativePrompt + pipelineConfig.startingImage = startingImage + pipelineConfig.strength = strength + pipelineConfig.imageCount = imageCount + pipelineConfig.stepCount = stepCount + pipelineConfig.seed = seed + pipelineConfig.guidanceScale = guidanceScale + pipelineConfig.schedulerType = scheduler.stableDiffusionScheduler + let images = try pipeline.generateImages( - prompt: prompt, - negativePrompt: negativePrompt, - imageCount: imageCount, - stepCount: stepCount, - seed: seed, - guidanceScale: guidanceScale, - scheduler: scheduler.stableDiffusionScheduler - ) { progress in - sampleTimer.stop() - handleProgress(progress,sampleTimer) - if progress.stepCount != progress.step { - sampleTimer.start() - } - return true - } + configuration: pipelineConfig, + progressHandler: { progress in + sampleTimer.stop() + handleProgress(progress,sampleTimer) + if progress.stepCount != progress.step { + sampleTimer.start() + } + return true + }) _ = try saveImages(images, logNames: true) } @@ -168,6 +200,10 @@ struct StableDiffusionSample: ParsableCommand { if imageCount != 1 { name += ".\(sample)" } + + if image != "none" { + name += ".str\(Int(strength * 100))" + } name += ".\(seed)"