diff --git a/swift/StableDiffusion/pipeline/DPMSolverMultistepScheduler.swift b/swift/StableDiffusion/pipeline/DPMSolverMultistepScheduler.swift new file mode 100644 index 0000000..704472a --- /dev/null +++ b/swift/StableDiffusion/pipeline/DPMSolverMultistepScheduler.swift @@ -0,0 +1,182 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright (C) 2022 Apple Inc. and The HuggingFace Team. All Rights Reserved. + +import Accelerate +import CoreML + +/// A scheduler used to compute a de-noised image +/// +/// This implementation matches: +/// [Hugging Face Diffusers DPMSolverMultistepScheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py) +/// +/// It uses the DPM-Solver++ algorithm: [code](https://github.com/LuChengTHU/dpm-solver) [paper](https://arxiv.org/abs/2211.01095). +/// Limitations: +/// - Only implemented for DPM-Solver++ algorithm (not DPM-Solver). +/// - Second order only. +/// - Assumes the model predicts epsilon. +/// - No dynamic thresholding. +/// - `midpoint` solver algorithm. +@available(iOS 16.2, macOS 13.1, *) +public final class DPMSolverMultistepScheduler: Scheduler { + public let trainStepCount: Int + public let inferenceStepCount: Int + public let betas: [Float] + public let alphas: [Float] + public let alphasCumProd: [Float] + public let timeSteps: [Int] + + public let alpha_t: [Float] + public let sigma_t: [Float] + public let lambda_t: [Float] + + public let solverOrder = 2 + private(set) var lowerOrderStepped = 0 + + /// Whether to use lower-order solvers in the final steps. Only valid for less than 15 inference steps. + /// We empirically find this trick can stabilize the sampling of DPM-Solver, especially with 10 or fewer steps. + public let useLowerOrderFinal = true + + // Stores solverOrder (2) items + private(set) var modelOutputs: [MLShapedArray] = [] + + /// Create a scheduler that uses a second order DPM-Solver++ algorithm. + /// + /// - Parameters: + /// - stepCount: Number of inference steps to schedule + /// - trainStepCount: Number of training diffusion steps + /// - betaSchedule: Method to schedule betas from betaStart to betaEnd + /// - betaStart: The starting value of beta for inference + /// - betaEnd: The end value for beta for inference + /// - Returns: A scheduler ready for its first step + public init( + stepCount: Int = 50, + trainStepCount: Int = 1000, + betaSchedule: BetaSchedule = .scaledLinear, + betaStart: Float = 0.00085, + betaEnd: Float = 0.012 + ) { + self.trainStepCount = trainStepCount + self.inferenceStepCount = stepCount + + switch betaSchedule { + case .linear: + self.betas = linspace(betaStart, betaEnd, trainStepCount) + case .scaledLinear: + self.betas = linspace(pow(betaStart, 0.5), pow(betaEnd, 0.5), trainStepCount).map({ $0 * $0 }) + } + + self.alphas = betas.map({ 1.0 - $0 }) + var alphasCumProd = self.alphas + for i in 1.., timestep: Int, sample: MLShapedArray) -> MLShapedArray { + assert(modelOutput.scalars.count == sample.scalars.count) + let (alpha_t, sigma_t) = (self.alpha_t[timestep], self.sigma_t[timestep]) + + // This could be optimized with a Metal kernel if we find we need to + let x0_scalars = zip(modelOutput.scalars, sample.scalars).map { m, s in + (s - m * sigma_t) / alpha_t + } + return MLShapedArray(scalars: x0_scalars, shape: modelOutput.shape) + } + + /// One step for the first-order DPM-Solver (equivalent to DDIM). + /// See https://arxiv.org/abs/2206.00927 for the detailed derivation. + /// var names and code structure mostly follow https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py + func firstOrderUpdate( + modelOutput: MLShapedArray, + timestep: Int, + prevTimestep: Int, + sample: MLShapedArray + ) -> MLShapedArray { + let (p_lambda_t, lambda_s) = (Double(lambda_t[prevTimestep]), Double(lambda_t[timestep])) + let p_alpha_t = Double(alpha_t[prevTimestep]) + let (p_sigma_t, sigma_s) = (Double(sigma_t[prevTimestep]), Double(sigma_t[timestep])) + let h = p_lambda_t - lambda_s + // x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + let x_t = weightedSum( + [p_sigma_t / sigma_s, -p_alpha_t * (exp(-h) - 1)], + [sample, modelOutput] + ) + return x_t + } + + /// One step for the second-order multistep DPM-Solver++ algorithm, using the midpoint method. + /// var names and code structure mostly follow https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py + func secondOrderUpdate( + modelOutputs: [MLShapedArray], + timesteps: [Int], + prevTimestep t: Int, + sample: MLShapedArray + ) -> MLShapedArray { + let (s0, s1) = (timesteps[back: 1], timesteps[back: 2]) + let (m0, m1) = (modelOutputs[back: 1], modelOutputs[back: 2]) + let (p_lambda_t, lambda_s0, lambda_s1) = (Double(lambda_t[t]), Double(lambda_t[s0]), Double(lambda_t[s1])) + let p_alpha_t = Double(alpha_t[t]) + let (p_sigma_t, sigma_s0) = (Double(sigma_t[t]), Double(sigma_t[s0])) + let (h, h_0) = (p_lambda_t - lambda_s0, lambda_s0 - lambda_s1) + let r0 = h_0 / h + let D0 = m0 + + // D1 = (1.0 / r0) * (m0 - m1) + let D1 = weightedSum( + [1/r0, -1/r0], + [m0, m1] + ) + + // See https://arxiv.org/abs/2211.01095 for detailed derivations + // x_t = ( + // (sigma_t / sigma_s0) * sample + // - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + // - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + // ) + let x_t = weightedSum( + [p_sigma_t/sigma_s0, -p_alpha_t * (exp(-h) - 1), -0.5 * p_alpha_t * (exp(-h) - 1)], + [sample, D0, D1] + ) + return x_t + } + + public func step(output: MLShapedArray, timeStep t: Int, sample: MLShapedArray) -> MLShapedArray { + let stepIndex = timeSteps.firstIndex(of: t) ?? timeSteps.count - 1 + let prevTimestep = stepIndex == timeSteps.count - 1 ? 0 : timeSteps[stepIndex + 1] + + let lowerOrderFinal = useLowerOrderFinal && stepIndex == timeSteps.count - 1 && timeSteps.count < 15 + let lowerOrderSecond = useLowerOrderFinal && stepIndex == timeSteps.count - 2 && timeSteps.count < 15 + let lowerOrder = lowerOrderStepped < 1 || lowerOrderFinal || lowerOrderSecond + + let modelOutput = convertModelOutput(modelOutput: output, timestep: t, sample: sample) + if modelOutputs.count == solverOrder { modelOutputs.removeFirst() } + modelOutputs.append(modelOutput) + + let prevSample: MLShapedArray + if lowerOrder { + prevSample = firstOrderUpdate(modelOutput: modelOutput, timestep: t, prevTimestep: prevTimestep, sample: sample) + } else { + prevSample = secondOrderUpdate( + modelOutputs: modelOutputs, + timesteps: [timeSteps[stepIndex - 1], t], + prevTimestep: prevTimestep, + sample: sample + ) + } + if lowerOrderStepped < solverOrder { + lowerOrderStepped += 1 + } + + return prevSample + } +} diff --git a/swift/StableDiffusion/pipeline/Scheduler.swift b/swift/StableDiffusion/pipeline/Scheduler.swift index 8ddabf0..0bd9284 100644 --- a/swift/StableDiffusion/pipeline/Scheduler.swift +++ b/swift/StableDiffusion/pipeline/Scheduler.swift @@ -3,34 +3,98 @@ import CoreML -/// A scheduler used to compute a de-noised image -/// -/// This implementation matches: -/// [Hugging Face Diffusers PNDMScheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py) -/// -/// It uses the pseudo linear multi-step (PLMS) method only, skipping pseudo Runge-Kutta (PRK) steps @available(iOS 16.2, macOS 13.1, *) -public final class Scheduler { +public protocol Scheduler { /// Number of diffusion steps performed during training - public let trainStepCount: Int + var trainStepCount: Int { get } /// Number of inference steps to be performed - public let inferenceStepCount: Int + var inferenceStepCount: Int { get } /// Training diffusion time steps index by inference time step - public let timeSteps: [Int] + var timeSteps: [Int] { get } /// Schedule of betas which controls the amount of noise added at each timestep - public let betas: [Float] + var betas: [Float] { get } /// 1 - betas - let alphas: [Float] + var alphas: [Float] { get } /// Cached cumulative product of alphas - let alphasCumProd: [Float] + var alphasCumProd: [Float] { get } /// Standard deviation of the initial noise distribution - public let initNoiseSigma: Float + var initNoiseSigma: Float { get } + + /// Compute a de-noised image sample and step scheduler state + /// + /// - Parameters: + /// - output: The predicted residual noise output of learned diffusion model + /// - timeStep: The current time step in the diffusion chain + /// - sample: The current input sample to the diffusion model + /// - Returns: Predicted de-noised sample at the previous time step + /// - Postcondition: The scheduler state is updated. + /// The state holds the current sample and history of model output noise residuals + func step( + output: MLShapedArray, + timeStep t: Int, + sample s: MLShapedArray + ) -> MLShapedArray +} + +@available(iOS 16.2, macOS 13.1, *) +public extension Scheduler { + var initNoiseSigma: Float { 1 } +} + +@available(iOS 16.2, macOS 13.1, *) +public extension Scheduler { + /// Compute weighted sum of shaped arrays of equal shapes + /// + /// - Parameters: + /// - weights: The weights each array is multiplied by + /// - values: The arrays to be weighted and summed + /// - Returns: sum_i weights[i]*values[i] + func weightedSum(_ weights: [Double], _ values: [MLShapedArray]) -> MLShapedArray { + assert(weights.count > 1 && values.count == weights.count) + assert(values.allSatisfy({ $0.scalarCount == values.first!.scalarCount })) + var w = Float(weights.first!) + var scalars = values.first!.scalars.map({ $0 * w }) + for next in 1 ..< values.count { + w = Float(weights[next]) + let nextScalars = values[next].scalars + for i in 0 ..< scalars.count { + scalars[i] += w * nextScalars[i] + } + } + return MLShapedArray(scalars: scalars, shape: values.first!.shape) + } +} + +/// How to map a beta range to a sequence of betas to step over +@available(iOS 16.2, macOS 13.1, *) +public enum BetaSchedule { + /// Linear stepping between start and end + case linear + /// Steps using linspace(sqrt(start),sqrt(end))^2 + case scaledLinear +} + + +/// A scheduler used to compute a de-noised image +/// +/// This implementation matches: +/// [Hugging Face Diffusers PNDMScheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py) +/// +/// This scheduler uses the pseudo linear multi-step (PLMS) method only, skipping pseudo Runge-Kutta (PRK) steps +@available(iOS 16.2, macOS 13.1, *) +public final class PNDMScheduler: Scheduler { + public let trainStepCount: Int + public let inferenceStepCount: Int + public let betas: [Float] + public let alphas: [Float] + public let alphasCumProd: [Float] + public let timeSteps: [Int] // Internal state var counter: Int @@ -62,15 +126,12 @@ public final class Scheduler { case .scaledLinear: self.betas = linspace(pow(betaStart, 0.5), pow(betaEnd, 0.5), trainStepCount).map({ $0 * $0 }) } - self.alphas = betas.map({ 1.0 - $0 }) - self.initNoiseSigma = 1.0 var alphasCumProd = self.alphas for i in 1..]) -> MLShapedArray { - assert(weights.count > 1 && values.count == weights.count) - assert(values.allSatisfy({$0.scalarCount == values.first!.scalarCount})) - var w = Float(weights.first!) - var scalars = values.first!.scalars.map({ $0 * w }) - for next in 1 ..< values.count { - w = Float(weights[next]) - let nextScalars = values[next].scalars - for i in 0 ..< scalars.count { - scalars[i] += w * nextScalars[i] - } - } - return MLShapedArray(scalars: scalars, shape: values.first!.shape) - } - /// Compute sample (denoised image) at previous step given a current time step /// /// - Parameters: @@ -225,17 +265,6 @@ public final class Scheduler { } } -@available(iOS 16.2, macOS 13.1, *) -extension Scheduler { - /// How to map a beta range to a sequence of betas to step over - public enum BetaSchedule { - /// Linear stepping between start and end - case linear - /// Steps using linspace(sqrt(start),sqrt(end))^2 - case scaledLinear - } -} - /// Evenly spaced floats between specified interval /// /// - Parameters: diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift index 2d55572..eddbda4 100644 --- a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift +++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift @@ -6,6 +6,14 @@ import CoreML import Accelerate import CoreGraphics +/// Schedulers compatible with StableDiffusionPipeline +public enum StableDiffusionScheduler { + /// Scheduler that uses a pseudo-linear multi-step (PLMS) method + case pndmScheduler + /// Scheduler that uses a second order DPM-Solver++ algorithm + case dpmSolverMultistepScheduler +} + /// A pipeline used to generate image samples from text input using stable diffusion /// /// This implementation matches: @@ -113,6 +121,7 @@ public struct StableDiffusionPipeline: ResourceManaging { stepCount: Int = 50, seed: Int = 0, disableSafety: Bool = false, + scheduler: StableDiffusionScheduler = .pndmScheduler, progressHandler: (Progress) -> Bool = { _ in true } ) throws -> [CGImage?] { @@ -133,7 +142,12 @@ public struct StableDiffusionPipeline: ResourceManaging { let hiddenStates = toHiddenStates(concatEmbedding) /// Setup schedulers - let scheduler = (0..