// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. and The HuggingFace Team. All Rights Reserved.

import Accelerate
import CoreML

/// A scheduler used to compute a de-noised image
///
/// This implementation matches:
/// [Hugging Face Diffusers DPMSolverMultistepScheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py)
///
/// It uses the DPM-Solver++ algorithm: [code](https://github.com/LuChengTHU/dpm-solver) [paper](https://arxiv.org/abs/2211.01095).
/// Limitations:
/// - Only implemented for the DPM-Solver++ algorithm (not DPM-Solver).
/// - Second order only.
/// - Assumes the model predicts epsilon.
/// - No dynamic thresholding.
/// - `midpoint` solver algorithm.
@available(iOS 16.2, macOS 13.1, *)
public final class DPMSolverMultistepScheduler: Scheduler {
    public let trainStepCount: Int
    public let inferenceStepCount: Int
    public let betas: [Float]
    public let alphas: [Float]
    public let alphasCumProd: [Float]
    public let timeSteps: [Int]

    public let alpha_t: [Float]
    public let sigma_t: [Float]
    public let lambda_t: [Float]

    public let solverOrder = 2
    private(set) var lowerOrderStepped = 0

    /// Whether to use lower-order solvers in the final steps. Only valid for fewer than 15 inference steps.
    /// We empirically find this trick can stabilize the sampling of DPM-Solver, especially with 10 or fewer steps.
    public let useLowerOrderFinal = true

    // Stores solverOrder (2) items
    private(set) var modelOutputs: [MLShapedArray<Float32>] = []

    /// Create a scheduler that uses a second order DPM-Solver++ algorithm.
    ///
    /// - Parameters:
    ///   - stepCount: Number of inference steps to schedule
    ///   - trainStepCount: Number of training diffusion steps
    ///   - betaSchedule: Method to schedule betas from betaStart to betaEnd
    ///   - betaStart: The starting value of beta for inference
    ///   - betaEnd: The end value for beta for inference
    /// - Returns: A scheduler ready for its first step
    public init(
        stepCount: Int = 50,
        trainStepCount: Int = 1000,
        betaSchedule: BetaSchedule = .scaledLinear,
        betaStart: Float = 0.00085,
        betaEnd: Float = 0.012
    ) {
        self.trainStepCount = trainStepCount
        self.inferenceStepCount = stepCount

        switch betaSchedule {
        case .linear:
            self.betas = linspace(betaStart, betaEnd, trainStepCount)
        case .scaledLinear:
            self.betas = linspace(pow(betaStart, 0.5), pow(betaEnd, 0.5), trainStepCount).map({ $0 * $0 })
        }
        self.alphas = betas.map({ 1.0 - $0 })

        // Cumulative product of the alphas
        var alphasCumProd = self.alphas
        for i in 1..<alphasCumProd.count {
            alphasCumProd[i] *= alphasCumProd[i - 1]
        }
        self.alphasCumProd = alphasCumProd

        // DPM-Solver++ parameterization: alpha_t = sqrt(alphasCumProd), sigma_t = sqrt(1 - alphasCumProd),
        // lambda_t = log(alpha_t) - log(sigma_t)
        self.alpha_t = vForce.sqrt(self.alphasCumProd)
        self.sigma_t = vForce.sqrt(vDSP.subtract([Float](repeating: 1, count: self.alphasCumProd.count), self.alphasCumProd))
        self.lambda_t = zip(self.alpha_t, self.sigma_t).map { α, σ in log(α) - log(σ) }

        // Evenly spaced inference timesteps, ordered from high noise to low noise
        self.timeSteps = linspace(0, Float(self.trainStepCount - 1), stepCount + 1).dropFirst().reversed().map { Int(round($0)) }
    }

    /// Convert the model output (predicted epsilon) to the predicted original sample (x0),
    /// which is the quantity the DPM-Solver++ updates operate on.
    func convertModelOutput(modelOutput: MLShapedArray<Float32>, timestep: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
        assert(modelOutput.scalars.count == sample.scalars.count)
        let (alpha_t, sigma_t) = (self.alpha_t[timestep], self.sigma_t[timestep])

        // This could be optimized with a Metal kernel if we find we need to
        let x0_scalars = zip(modelOutput.scalars, sample.scalars).map { m, s in
            (s - m * sigma_t) / alpha_t
        }
        return MLShapedArray(scalars: x0_scalars, shape: modelOutput.shape)
    }
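
    // Worked example for the conversion above (illustrative numbers only, not taken
    // from a real model): with alpha_t = 0.8 and sigma_t = 0.6, a sample scalar
    // s = 1.0 and a predicted-noise scalar m = 0.5 convert to
    // x0 = (s - m * sigma_t) / alpha_t = (1.0 - 0.5 * 0.6) / 0.8 = 0.875.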

    /// One step for the first-order DPM-Solver (equivalent to DDIM).
    /// See https://arxiv.org/abs/2206.00927 for the detailed derivation.
    /// var names and code structure mostly follow https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
    func firstOrderUpdate(
        modelOutput: MLShapedArray<Float32>,
        timestep: Int,
        prevTimestep: Int,
        sample: MLShapedArray<Float32>
    ) -> MLShapedArray<Float32> {
        let (p_lambda_t, lambda_s) = (Double(lambda_t[prevTimestep]), Double(lambda_t[timestep]))
        let p_alpha_t = Double(alpha_t[prevTimestep])
        let (p_sigma_t, sigma_s) = (Double(sigma_t[prevTimestep]), Double(sigma_t[timestep]))
        let h = p_lambda_t - lambda_s
        // x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
        let x_t = weightedSum(
            [p_sigma_t / sigma_s, -p_alpha_t * (exp(-h) - 1)],
            [sample, modelOutput]
        )
        return x_t
    }

    /// One step for the second-order multistep DPM-Solver++ algorithm, using the midpoint method.
    /// var names and code structure mostly follow https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
    func secondOrderUpdate(
        modelOutputs: [MLShapedArray<Float32>],
        timesteps: [Int],
        prevTimestep t: Int,
        sample: MLShapedArray<Float32>
    ) -> MLShapedArray<Float32> {
        let (s0, s1) = (timesteps[back: 1], timesteps[back: 2])
        let (m0, m1) = (modelOutputs[back: 1], modelOutputs[back: 2])
        let (p_lambda_t, lambda_s0, lambda_s1) = (Double(lambda_t[t]), Double(lambda_t[s0]), Double(lambda_t[s1]))
        let p_alpha_t = Double(alpha_t[t])
        let (p_sigma_t, sigma_s0) = (Double(sigma_t[t]), Double(sigma_t[s0]))
        let (h, h_0) = (p_lambda_t - lambda_s0, lambda_s0 - lambda_s1)
        let r0 = h_0 / h
        let D0 = m0

        // D1 = (1.0 / r0) * (m0 - m1)
        let D1 = weightedSum(
            [1 / r0, -1 / r0],
            [m0, m1]
        )

        // See https://arxiv.org/abs/2211.01095 for detailed derivations
        // x_t = (
        //     (sigma_t / sigma_s0) * sample
        //     - (alpha_t * (torch.exp(-h) - 1.0)) * D0
        //     - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
        // )
        let x_t = weightedSum(
            [p_sigma_t / sigma_s0, -p_alpha_t * (exp(-h) - 1), -0.5 * p_alpha_t * (exp(-h) - 1)],
            [sample, D0, D1]
        )
        return x_t
    }

    public func step(output: MLShapedArray<Float32>, timeStep t: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
        let stepIndex = timeSteps.firstIndex(of: t) ?? timeSteps.count - 1
        let prevTimestep = stepIndex == timeSteps.count - 1 ? 0 : timeSteps[stepIndex + 1]

        let lowerOrderFinal = useLowerOrderFinal && stepIndex == timeSteps.count - 1 && timeSteps.count < 15
        let lowerOrderSecond = useLowerOrderFinal && stepIndex == timeSteps.count - 2 && timeSteps.count < 15
        let lowerOrder = lowerOrderStepped < 1 || lowerOrderFinal || lowerOrderSecond

        let modelOutput = convertModelOutput(modelOutput: output, timestep: t, sample: sample)
        if modelOutputs.count == solverOrder { modelOutputs.removeFirst() }
        modelOutputs.append(modelOutput)

        let prevSample: MLShapedArray<Float32>
        if lowerOrder {
            prevSample = firstOrderUpdate(modelOutput: modelOutput, timestep: t, prevTimestep: prevTimestep, sample: sample)
        } else {
            prevSample = secondOrderUpdate(
                modelOutputs: modelOutputs,
                timesteps: [timeSteps[stepIndex - 1], t],
                prevTimestep: prevTimestep,
                sample: sample
            )
        }
        if lowerOrderStepped < solverOrder {
            lowerOrderStepped += 1
        }

        return prevSample
    }
}
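
// Usage sketch (illustrative only): how a denoising loop might drive this scheduler.
// It assumes a hypothetical `predictNoise` closure (the epsilon-predicting UNet) and
// an `initialLatents` value, neither of which is defined in this file.
//
//     let scheduler = DPMSolverMultistepScheduler(stepCount: 25)
//     var latents: MLShapedArray<Float32> = initialLatents
//     for t in scheduler.timeSteps {
//         // The model is assumed to predict epsilon (noise) for the current latents
//         let noisePrediction = predictNoise(latents, t)
//         latents = scheduler.step(output: noisePrediction, timeStep: t, sample: latents)
//     }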