From 09602864d60f4e2908417c6c6c48061ac21bcc50 Mon Sep 17 00:00:00 2001 From: Liu Liu Date: Mon, 13 Feb 2023 17:18:16 -0500 Subject: [PATCH] Update with configuration available in SwiftDiffusionCLI --- .../{Random.swift => NumPyRandomSource.swift} | 2 +- .../pipeline/RandomSource.swift | 6 ++++ ...tableDiffusionPipeline.Configuration.swift | 2 ++ .../pipeline/StableDiffusionPipeline.swift | 30 ++++++++++++++----- .../pipeline/TorchRandomSource.swift | 2 +- swift/StableDiffusionCLI/main.swift | 14 +++++++++ 6 files changed, 47 insertions(+), 9 deletions(-) rename swift/StableDiffusion/pipeline/{Random.swift => NumPyRandomSource.swift} (98%) create mode 100644 swift/StableDiffusion/pipeline/RandomSource.swift diff --git a/swift/StableDiffusion/pipeline/Random.swift b/swift/StableDiffusion/pipeline/NumPyRandomSource.swift similarity index 98% rename from swift/StableDiffusion/pipeline/Random.swift rename to swift/StableDiffusion/pipeline/NumPyRandomSource.swift index a1e8d35..62c9c5b 100644 --- a/swift/StableDiffusion/pipeline/Random.swift +++ b/swift/StableDiffusion/pipeline/NumPyRandomSource.swift @@ -10,7 +10,7 @@ import CoreML /// [NumPy's older randomkit.c](https://github.com/numpy/numpy/blob/v1.0/numpy/random/mtrand/randomkit.c) /// @available(iOS 16.2, macOS 13.1, *) -struct NumPyRandomSource: RandomNumberGenerator { +struct NumPyRandomSource: RandomNumberGenerator, RandomSource { struct State { var key = [UInt32](repeating: 0, count: 624) diff --git a/swift/StableDiffusion/pipeline/RandomSource.swift b/swift/StableDiffusion/pipeline/RandomSource.swift new file mode 100644 index 0000000..8ff2e3d --- /dev/null +++ b/swift/StableDiffusion/pipeline/RandomSource.swift @@ -0,0 +1,6 @@ +import CoreML + +@available(iOS 16.2, macOS 13.1, *) +public protocol RandomSource { + mutating func normalShapedArray(_ shape: [Int], mean: Double, stdev: Double) -> MLShapedArray +} diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.Configuration.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.Configuration.swift index 8933557..f58fcf1 100644 --- a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.Configuration.swift +++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.Configuration.swift @@ -37,6 +37,8 @@ extension StableDiffusionPipeline { public var disableSafety: Bool = false /// The type of Scheduler to use. public var schedulerType: StableDiffusionScheduler = .pndmScheduler + /// The type of RNG to use + public var rngType: StableDiffusionRNG = .numpyRNG /// Given the configuration, what mode will be used for generation public var mode: Mode { diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift index f279247..2bc58c7 100644 --- a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift +++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift @@ -14,6 +14,14 @@ public enum StableDiffusionScheduler { case dpmSolverMultistepScheduler } +/// RNG compatible with StableDiffusionPipeline +public enum StableDiffusionRNG { + /// RNG that matches numpy implementation + case numpyRNG + /// RNG that matches PyTorch CPU implementation. + case torchRNG +} + /// A pipeline used to generate image samples from text input using stable diffusion /// /// This implementation matches: @@ -157,7 +165,7 @@ public struct StableDiffusionPipeline: ResourceManaging { throw Error.startingImageProvidedWithoutEncoder } - let noiseTuples = generateImage2ImageLatentSamples(config.imageCount, stdev: 1, seed: config.seed) + let noiseTuples = generateImage2ImageLatentSamples(config.imageCount, rng: config.rngType, stdev: 1, seed: config.seed) latents = try noiseTuples.map({ try encoder.encode( image: startingImage, @@ -168,7 +176,7 @@ public struct StableDiffusionPipeline: ResourceManaging { } else { timestepStrength = nil // Generate random latent samples from specified seed - latents = generateLatentSamples(config.imageCount, stdev: stdev, seed: config.seed) + latents = generateLatentSamples(config.imageCount, rng: config.rngType, stdev: stdev, seed: config.seed) } // De-noising loop @@ -224,11 +232,19 @@ public struct StableDiffusionPipeline: ResourceManaging { return try decodeToImages(latents, disableSafety: config.disableSafety) } - func generateLatentSamples(_ count: Int, stdev: Float, seed: UInt32) -> [MLShapedArray] { + private func randomSource(from rng: StableDiffusionRNG, seed: UInt32) -> RandomSource { + switch rng { + case .numpyRNG: + return NumPyRandomSource(seed: seed) + case .torchRNG: + return TorchRandomSource(seed: seed) + } + } + + func generateLatentSamples(_ count: Int, rng: StableDiffusionRNG, stdev: Float, seed: UInt32) -> [MLShapedArray] { var sampleShape = unet.latentSampleShape sampleShape[0] = 1 - - var random = NumPyRandomSource(seed: seed) + var random = randomSource(from: rng, seed: seed) let samples = (0..( converting: random.normalShapedArray(sampleShape, mean: 0.0, stdev: Double(stdev))) @@ -245,11 +261,11 @@ public struct StableDiffusionPipeline: ResourceManaging { /// - diagonalAndLatentNoiseIsSame: Diffusions library does not seem to use the same noise for the `DiagonalGaussianDistribution` operation, /// but I have seen implementations of pipelines where it is the same. /// - Returns: An array of tuples of noise values with length of batch size. - func generateImage2ImageLatentSamples(_ count: Int, stdev: Float, seed: UInt32, diagonalAndLatentNoiseIsSame: Bool = false) -> [(diagonal: MLShapedArray, latentNoise: MLShapedArray)] { + func generateImage2ImageLatentSamples(_ count: Int, rng: StableDiffusionRNG, stdev: Float, seed: UInt32, diagonalAndLatentNoiseIsSame: Bool = false) -> [(diagonal: MLShapedArray, latentNoise: MLShapedArray)] { var sampleShape = unet.latentSampleShape sampleShape[0] = 1 - var random = NumPyRandomSource(seed: UInt32(truncatingIfNeeded: seed)) + var random = randomSource(from: rng, seed: seed) let samples = (0..( diff --git a/swift/StableDiffusion/pipeline/TorchRandomSource.swift b/swift/StableDiffusion/pipeline/TorchRandomSource.swift index 547da64..8799d18 100644 --- a/swift/StableDiffusion/pipeline/TorchRandomSource.swift +++ b/swift/StableDiffusion/pipeline/TorchRandomSource.swift @@ -13,7 +13,7 @@ import CoreML /// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/TransformationHelper.h /// @available(iOS 16.2, macOS 13.1, *) -struct TorchRandomSource: RandomNumberGenerator { +struct TorchRandomSource: RandomNumberGenerator, RandomSource { struct State { var key = [UInt32](repeating: 0, count: 624) diff --git a/swift/StableDiffusionCLI/main.swift b/swift/StableDiffusionCLI/main.swift index c5ae31a..5bacb1b 100644 --- a/swift/StableDiffusionCLI/main.swift +++ b/swift/StableDiffusionCLI/main.swift @@ -69,6 +69,9 @@ struct StableDiffusionSample: ParsableCommand { @Option(help: "Scheduler to use, one of {pndm, dpmpp}") var scheduler: SchedulerOption = .pndm + @Option(help: "Random number generator to use, one of {numpy, torch}") + var rng: RNGOption = .numpy + @Flag(help: "Disable safety checking") var disableSafety: Bool = false @@ -250,6 +253,17 @@ enum SchedulerOption: String, ExpressibleByArgument { } } +@available(iOS 16.2, macOS 13.1, *) +enum RNGOption: String, ExpressibleByArgument { + case numpy, torch + var stableDiffusionRNG: StableDiffusionRNG { + switch self { + case .numpy: return .numpyRNG + case .torch: return .torchRNG + } + } +} + if #available(iOS 16.2, macOS 13.1, *) { StableDiffusionSample.main() } else {