diff --git a/swift/StableDiffusion/pipeline/Random.swift b/swift/StableDiffusion/pipeline/NumPyRandomSource.swift similarity index 98% rename from swift/StableDiffusion/pipeline/Random.swift rename to swift/StableDiffusion/pipeline/NumPyRandomSource.swift index a1e8d35..62c9c5b 100644 --- a/swift/StableDiffusion/pipeline/Random.swift +++ b/swift/StableDiffusion/pipeline/NumPyRandomSource.swift @@ -10,7 +10,7 @@ import CoreML /// [NumPy's older randomkit.c](https://github.com/numpy/numpy/blob/v1.0/numpy/random/mtrand/randomkit.c) /// @available(iOS 16.2, macOS 13.1, *) -struct NumPyRandomSource: RandomNumberGenerator { +struct NumPyRandomSource: RandomNumberGenerator, RandomSource { struct State { var key = [UInt32](repeating: 0, count: 624) diff --git a/swift/StableDiffusion/pipeline/RandomSource.swift b/swift/StableDiffusion/pipeline/RandomSource.swift new file mode 100644 index 0000000..8ff2e3d --- /dev/null +++ b/swift/StableDiffusion/pipeline/RandomSource.swift @@ -0,0 +1,6 @@ +import CoreML + +@available(iOS 16.2, macOS 13.1, *) +public protocol RandomSource { + mutating func normalShapedArray(_ shape: [Int], mean: Double, stdev: Double) -> MLShapedArray +} diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.Configuration.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.Configuration.swift index 8933557..f58fcf1 100644 --- a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.Configuration.swift +++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.Configuration.swift @@ -37,6 +37,8 @@ extension StableDiffusionPipeline { public var disableSafety: Bool = false /// The type of Scheduler to use. 
public var schedulerType: StableDiffusionScheduler = .pndmScheduler + /// The type of RNG used to generate the latent noise samples. + public var rngType: StableDiffusionRNG = .numpyRNG /// Given the configuration, what mode will be used for generation public var mode: Mode { diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift index f279247..2bc58c7 100644 --- a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift +++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift @@ -14,6 +14,14 @@ public enum StableDiffusionScheduler { case dpmSolverMultistepScheduler } +/// RNG implementations compatible with StableDiffusionPipeline. +public enum StableDiffusionRNG { + /// RNG that matches the NumPy implementation. + case numpyRNG + /// RNG that matches PyTorch CPU implementation. + case torchRNG +} + /// A pipeline used to generate image samples from text input using stable diffusion /// /// This implementation matches: @@ -157,7 +165,7 @@ public struct StableDiffusionPipeline: ResourceManaging { throw Error.startingImageProvidedWithoutEncoder } - let noiseTuples = generateImage2ImageLatentSamples(config.imageCount, stdev: 1, seed: config.seed) + let noiseTuples = generateImage2ImageLatentSamples(config.imageCount, rng: config.rngType, stdev: 1, seed: config.seed) latents = try noiseTuples.map({ try encoder.encode( image: startingImage, @@ -168,7 +176,7 @@ public struct StableDiffusionPipeline: ResourceManaging { } else { timestepStrength = nil // Generate random latent samples from specified seed - latents = generateLatentSamples(config.imageCount, stdev: stdev, seed: config.seed) + latents = generateLatentSamples(config.imageCount, rng: config.rngType, stdev: stdev, seed: config.seed) } // De-noising loop @@ -224,11 +232,19 @@ public struct StableDiffusionPipeline: ResourceManaging { return try decodeToImages(latents, disableSafety: config.disableSafety) } - func generateLatentSamples(_ count: Int, stdev: Float, seed: UInt32) 
-> [MLShapedArray] { + private func randomSource(from rng: StableDiffusionRNG, seed: UInt32) -> RandomSource { + switch rng { + case .numpyRNG: + return NumPyRandomSource(seed: seed) + case .torchRNG: + return TorchRandomSource(seed: seed) + } + } + + func generateLatentSamples(_ count: Int, rng: StableDiffusionRNG, stdev: Float, seed: UInt32) -> [MLShapedArray] { var sampleShape = unet.latentSampleShape sampleShape[0] = 1 - - var random = NumPyRandomSource(seed: seed) + var random = randomSource(from: rng, seed: seed) let samples = (0..( converting: random.normalShapedArray(sampleShape, mean: 0.0, stdev: Double(stdev))) @@ -245,11 +261,11 @@ public struct StableDiffusionPipeline: ResourceManaging { /// - diagonalAndLatentNoiseIsSame: Diffusions library does not seem to use the same noise for the `DiagonalGaussianDistribution` operation, /// but I have seen implementations of pipelines where it is the same. /// - Returns: An array of tuples of noise values with length of batch size. - func generateImage2ImageLatentSamples(_ count: Int, stdev: Float, seed: UInt32, diagonalAndLatentNoiseIsSame: Bool = false) -> [(diagonal: MLShapedArray, latentNoise: MLShapedArray)] { + func generateImage2ImageLatentSamples(_ count: Int, rng: StableDiffusionRNG, stdev: Float, seed: UInt32, diagonalAndLatentNoiseIsSame: Bool = false) -> [(diagonal: MLShapedArray, latentNoise: MLShapedArray)] { var sampleShape = unet.latentSampleShape sampleShape[0] = 1 - var random = NumPyRandomSource(seed: UInt32(truncatingIfNeeded: seed)) + var random = randomSource(from: rng, seed: seed) let samples = (0..( diff --git a/swift/StableDiffusion/pipeline/TorchRandomSource.swift b/swift/StableDiffusion/pipeline/TorchRandomSource.swift new file mode 100644 index 0000000..8799d18 --- /dev/null +++ b/swift/StableDiffusion/pipeline/TorchRandomSource.swift @@ -0,0 +1,152 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright (C) 2022 Apple Inc. All Rights Reserved. 
+ +import Foundation +import CoreML + +/// A random source consistent with PyTorch +/// +/// This implementation matches: +/// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/DistributionsHelper.h +/// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/DistributionTemplates.h +/// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/DistributionKernels.cpp +/// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/TransformationHelper.h +/// +@available(iOS 16.2, macOS 13.1, *) +struct TorchRandomSource: RandomNumberGenerator, RandomSource { + + struct State { + var key = [UInt32](repeating: 0, count: 624) + var pos: Int = 0 + var nextGauss: Double? = nil + } + + var state: State + + /// Initialize with a random seed + /// + /// - Parameters: + /// - seed: Seed for underlying Mersenne Twister 19937 generator + /// - Returns: random source + init(seed: UInt32) { + state = .init() + var s = seed & 0xffff_ffff + for i in 0..> 30)) + UInt64(i) + 1) & 0xffff_ffff) + } + state.pos = state.key.count + state.nextGauss = nil + } + + /// Generate next UInt32 using fast 32bit Mersenne Twister + mutating func nextUInt32() -> UInt32 { + let n = 624 + let m = 397 + let matrixA: UInt64 = 0x9908_b0df + let upperMask: UInt32 = 0x8000_0000 + let lowerMask: UInt32 = 0x7fff_ffff + + var y: UInt32 + if state.pos == state.key.count { + for i in 0..<(n - m) { + y = (state.key[i] & upperMask) | (state.key[i + 1] & lowerMask) + state.key[i] = state.key[i + m] ^ (y >> 1) ^ UInt32((UInt64(~(y & 1)) + 1) & matrixA) + } + for i in (n - m)..<(n - 1) { + y = (state.key[i] & upperMask) | (state.key[i + 1] & lowerMask) + state.key[i] = state.key[i + (m - n)] ^ (y >> 1) ^ UInt32((UInt64(~(y & 1)) + 1) & matrixA) + } + y = (state.key[n - 1] & upperMask) | (state.key[0] & lowerMask) + state.key[n - 1] = state.key[m - 1] ^ (y >> 1) ^ UInt32((UInt64(~(y & 1)) + 1) & matrixA) + state.pos = 0 + } + y = state.key[state.pos] + state.pos += 
1 + + y ^= (y >> 11) + y ^= (y << 7) & 0x9d2c_5680 + y ^= (y << 15) & 0xefc6_0000 + y ^= (y >> 18) + + return y + } + + mutating func next() -> UInt64 { + let high = nextUInt32() + let low = nextUInt32() + return (UInt64(high) << 32) | UInt64(low) + } + + /// Generate next random double value + mutating func nextDouble() -> Double { + let a = next() + return Double(a & 9_007_199_254_740_991) * (1.0 / 9007199254740992.0) + } + + /// Generate next random float value + mutating func nextFloat() -> Float { + let a = nextUInt32() + return Float(a & 16_777_215) * (1.0 / 16777216.0) + } + + /// Generate next random value from a standard normal + mutating func nextGauss() -> Double { + if let nextGauss = state.nextGauss { + state.nextGauss = nil + return nextGauss + } + // Box-Muller transform + let u1: Double = nextDouble() + let u2: Double = 1 - nextDouble() + let radius = sqrt(-2.0 * log(u2)) + let theta = 2.0 * .pi * u1 + state.nextGauss = radius * sin(theta) + return radius * cos(theta) + } + + /// Generates an array of random values from a normal distribution with given mean and standard deviation. + /// This simulates torch.randn([1, 4, 64, 64], dtype=torch.float), note that for dtype=torch.double, it + /// will be slightly different. + mutating func normalArray(count: Int, mean: Double = 0.0, stdev: Double = 1.0) -> [Double] { + // If it is smaller than 16 elements, Torch generates from Box-Muller transform directly. + // Note that even if this is used to generate Float, it will use Double underneath. + guard count >= 16 else { + return (0.. 
MLShapedArray { + let count = shape.reduce(1, *) + return .init(scalars: normalArray(count: count, mean: mean, stdev: stdev), shape: shape) + } +} diff --git a/swift/StableDiffusionCLI/main.swift b/swift/StableDiffusionCLI/main.swift index c5ae31a..b9f7192 100644 --- a/swift/StableDiffusionCLI/main.swift +++ b/swift/StableDiffusionCLI/main.swift @@ -69,6 +69,9 @@ struct StableDiffusionSample: ParsableCommand { @Option(help: "Scheduler to use, one of {pndm, dpmpp}") var scheduler: SchedulerOption = .pndm + @Option(help: "Random number generator to use, one of {numpy, torch}") + var rng: RNGOption = .numpy + @Flag(help: "Disable safety checking") var disableSafety: Bool = false @@ -126,6 +129,7 @@ struct StableDiffusionSample: ParsableCommand { pipelineConfig.seed = seed pipelineConfig.guidanceScale = guidanceScale pipelineConfig.schedulerType = scheduler.stableDiffusionScheduler + pipelineConfig.rngType = rng.stableDiffusionRNG let images = try pipeline.generateImages( configuration: pipelineConfig, @@ -250,6 +254,17 @@ enum SchedulerOption: String, ExpressibleByArgument { } } +@available(iOS 16.2, macOS 13.1, *) +enum RNGOption: String, ExpressibleByArgument { + case numpy, torch + var stableDiffusionRNG: StableDiffusionRNG { + switch self { + case .numpy: return .numpyRNG + case .torch: return .torchRNG + } + } +} + if #available(iOS 16.2, macOS 13.1, *) { StableDiffusionSample.main() } else {