Implement DPM-Solver++ scheduler (#59)

pull/66/head
Pedro Cuenca 2 years ago committed by GitHub
parent d6a54fc2b0
commit 3758abc652
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,182 @@
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. and The HuggingFace Team. All Rights Reserved.
import Accelerate
import CoreML
/// A scheduler used to compute a de-noised image
///
/// This implementation matches:
/// [Hugging Face Diffusers DPMSolverMultistepScheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py)
///
/// It uses the DPM-Solver++ algorithm: [code](https://github.com/LuChengTHU/dpm-solver) [paper](https://arxiv.org/abs/2211.01095).
/// Limitations:
/// - Only implemented for DPM-Solver++ algorithm (not DPM-Solver).
/// - Second order only.
/// - Assumes the model predicts epsilon.
/// - No dynamic thresholding.
/// - `midpoint` solver algorithm.
@available(iOS 16.2, macOS 13.1, *)
public final class DPMSolverMultistepScheduler: Scheduler {
public let trainStepCount: Int
public let inferenceStepCount: Int
public let betas: [Float]
public let alphas: [Float]
public let alphasCumProd: [Float]
public let timeSteps: [Int]
public let alpha_t: [Float]
public let sigma_t: [Float]
public let lambda_t: [Float]
public let solverOrder = 2
private(set) var lowerOrderStepped = 0
/// Whether to use lower-order solvers in the final steps. Only valid for less than 15 inference steps.
/// We empirically find this trick can stabilize the sampling of DPM-Solver, especially with 10 or fewer steps.
public let useLowerOrderFinal = true
// Stores solverOrder (2) items
private(set) var modelOutputs: [MLShapedArray<Float32>] = []
/// Create a scheduler that uses a second order DPM-Solver++ algorithm.
///
/// - Parameters:
/// - stepCount: Number of inference steps to schedule
/// - trainStepCount: Number of training diffusion steps
/// - betaSchedule: Method to schedule betas from betaStart to betaEnd
/// - betaStart: The starting value of beta for inference
/// - betaEnd: The end value for beta for inference
/// - Returns: A scheduler ready for its first step
public init(
stepCount: Int = 50,
trainStepCount: Int = 1000,
betaSchedule: BetaSchedule = .scaledLinear,
betaStart: Float = 0.00085,
betaEnd: Float = 0.012
) {
self.trainStepCount = trainStepCount
self.inferenceStepCount = stepCount
switch betaSchedule {
case .linear:
self.betas = linspace(betaStart, betaEnd, trainStepCount)
case .scaledLinear:
self.betas = linspace(pow(betaStart, 0.5), pow(betaEnd, 0.5), trainStepCount).map({ $0 * $0 })
}
self.alphas = betas.map({ 1.0 - $0 })
var alphasCumProd = self.alphas
for i in 1..<alphasCumProd.count {
alphasCumProd[i] *= alphasCumProd[i - 1]
}
self.alphasCumProd = alphasCumProd
// Currently we only support VP-type noise shedule
self.alpha_t = vForce.sqrt(self.alphasCumProd)
self.sigma_t = vForce.sqrt(vDSP.subtract([Float](repeating: 1, count: self.alphasCumProd.count), self.alphasCumProd))
self.lambda_t = zip(self.alpha_t, self.sigma_t).map { α, σ in log(α) - log(σ) }
self.timeSteps = linspace(0, Float(self.trainStepCount-1), stepCount).reversed().map { Int(round($0)) }
}
/// Convert the model output to the corresponding type the algorithm needs.
/// This implementation is for second-order DPM-Solver++ assuming epsilon prediction.
func convertModelOutput(modelOutput: MLShapedArray<Float32>, timestep: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
assert(modelOutput.scalars.count == sample.scalars.count)
let (alpha_t, sigma_t) = (self.alpha_t[timestep], self.sigma_t[timestep])
// This could be optimized with a Metal kernel if we find we need to
let x0_scalars = zip(modelOutput.scalars, sample.scalars).map { m, s in
(s - m * sigma_t) / alpha_t
}
return MLShapedArray(scalars: x0_scalars, shape: modelOutput.shape)
}
/// One step for the first-order DPM-Solver (equivalent to DDIM).
/// See https://arxiv.org/abs/2206.00927 for the detailed derivation.
/// var names and code structure mostly follow https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
func firstOrderUpdate(
modelOutput: MLShapedArray<Float32>,
timestep: Int,
prevTimestep: Int,
sample: MLShapedArray<Float32>
) -> MLShapedArray<Float32> {
let (p_lambda_t, lambda_s) = (Double(lambda_t[prevTimestep]), Double(lambda_t[timestep]))
let p_alpha_t = Double(alpha_t[prevTimestep])
let (p_sigma_t, sigma_s) = (Double(sigma_t[prevTimestep]), Double(sigma_t[timestep]))
let h = p_lambda_t - lambda_s
// x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
let x_t = weightedSum(
[p_sigma_t / sigma_s, -p_alpha_t * (exp(-h) - 1)],
[sample, modelOutput]
)
return x_t
}
/// One step for the second-order multistep DPM-Solver++ algorithm, using the midpoint method.
/// var names and code structure mostly follow https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
func secondOrderUpdate(
modelOutputs: [MLShapedArray<Float32>],
timesteps: [Int],
prevTimestep t: Int,
sample: MLShapedArray<Float32>
) -> MLShapedArray<Float32> {
let (s0, s1) = (timesteps[back: 1], timesteps[back: 2])
let (m0, m1) = (modelOutputs[back: 1], modelOutputs[back: 2])
let (p_lambda_t, lambda_s0, lambda_s1) = (Double(lambda_t[t]), Double(lambda_t[s0]), Double(lambda_t[s1]))
let p_alpha_t = Double(alpha_t[t])
let (p_sigma_t, sigma_s0) = (Double(sigma_t[t]), Double(sigma_t[s0]))
let (h, h_0) = (p_lambda_t - lambda_s0, lambda_s0 - lambda_s1)
let r0 = h_0 / h
let D0 = m0
// D1 = (1.0 / r0) * (m0 - m1)
let D1 = weightedSum(
[1/r0, -1/r0],
[m0, m1]
)
// See https://arxiv.org/abs/2211.01095 for detailed derivations
// x_t = (
// (sigma_t / sigma_s0) * sample
// - (alpha_t * (torch.exp(-h) - 1.0)) * D0
// - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
// )
let x_t = weightedSum(
[p_sigma_t/sigma_s0, -p_alpha_t * (exp(-h) - 1), -0.5 * p_alpha_t * (exp(-h) - 1)],
[sample, D0, D1]
)
return x_t
}
public func step(output: MLShapedArray<Float32>, timeStep t: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
let stepIndex = timeSteps.firstIndex(of: t) ?? timeSteps.count - 1
let prevTimestep = stepIndex == timeSteps.count - 1 ? 0 : timeSteps[stepIndex + 1]
let lowerOrderFinal = useLowerOrderFinal && stepIndex == timeSteps.count - 1 && timeSteps.count < 15
let lowerOrderSecond = useLowerOrderFinal && stepIndex == timeSteps.count - 2 && timeSteps.count < 15
let lowerOrder = lowerOrderStepped < 1 || lowerOrderFinal || lowerOrderSecond
let modelOutput = convertModelOutput(modelOutput: output, timestep: t, sample: sample)
if modelOutputs.count == solverOrder { modelOutputs.removeFirst() }
modelOutputs.append(modelOutput)
let prevSample: MLShapedArray<Float32>
if lowerOrder {
prevSample = firstOrderUpdate(modelOutput: modelOutput, timestep: t, prevTimestep: prevTimestep, sample: sample)
} else {
prevSample = secondOrderUpdate(
modelOutputs: modelOutputs,
timesteps: [timeSteps[stepIndex - 1], t],
prevTimestep: prevTimestep,
sample: sample
)
}
if lowerOrderStepped < solverOrder {
lowerOrderStepped += 1
}
return prevSample
}
}

@ -3,34 +3,98 @@
import CoreML
/// A scheduler used to compute a de-noised image
///
/// This implementation matches:
/// [Hugging Face Diffusers PNDMScheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py)
///
/// It uses the pseudo linear multi-step (PLMS) method only, skipping pseudo Runge-Kutta (PRK) steps
@available(iOS 16.2, macOS 13.1, *)
public final class Scheduler {
public protocol Scheduler {
/// Number of diffusion steps performed during training
public let trainStepCount: Int
var trainStepCount: Int { get }
/// Number of inference steps to be performed
public let inferenceStepCount: Int
var inferenceStepCount: Int { get }
/// Training diffusion time steps index by inference time step
public let timeSteps: [Int]
var timeSteps: [Int] { get }
/// Schedule of betas which controls the amount of noise added at each timestep
public let betas: [Float]
var betas: [Float] { get }
/// 1 - betas
let alphas: [Float]
var alphas: [Float] { get }
/// Cached cumulative product of alphas
let alphasCumProd: [Float]
var alphasCumProd: [Float] { get }
/// Standard deviation of the initial noise distribution
public let initNoiseSigma: Float
var initNoiseSigma: Float { get }
/// Compute a de-noised image sample and step scheduler state
///
/// - Parameters:
/// - output: The predicted residual noise output of learned diffusion model
/// - timeStep: The current time step in the diffusion chain
/// - sample: The current input sample to the diffusion model
/// - Returns: Predicted de-noised sample at the previous time step
/// - Postcondition: The scheduler state is updated.
/// The state holds the current sample and history of model output noise residuals
func step(
output: MLShapedArray<Float32>,
timeStep t: Int,
sample s: MLShapedArray<Float32>
) -> MLShapedArray<Float32>
}
@available(iOS 16.2, macOS 13.1, *)
public extension Scheduler {
var initNoiseSigma: Float { 1 }
}
@available(iOS 16.2, macOS 13.1, *)
public extension Scheduler {
/// Compute weighted sum of shaped arrays of equal shapes
///
/// - Parameters:
/// - weights: The weights each array is multiplied by
/// - values: The arrays to be weighted and summed
/// - Returns: sum_i weights[i]*values[i]
func weightedSum(_ weights: [Double], _ values: [MLShapedArray<Float32>]) -> MLShapedArray<Float32> {
assert(weights.count > 1 && values.count == weights.count)
assert(values.allSatisfy({ $0.scalarCount == values.first!.scalarCount }))
var w = Float(weights.first!)
var scalars = values.first!.scalars.map({ $0 * w })
for next in 1 ..< values.count {
w = Float(weights[next])
let nextScalars = values[next].scalars
for i in 0 ..< scalars.count {
scalars[i] += w * nextScalars[i]
}
}
return MLShapedArray(scalars: scalars, shape: values.first!.shape)
}
}
/// How to map a beta range to a sequence of betas to step over
@available(iOS 16.2, macOS 13.1, *)
public enum BetaSchedule {
/// Linear stepping between start and end
case linear
/// Steps using linspace(sqrt(start),sqrt(end))^2
case scaledLinear
}
/// A scheduler used to compute a de-noised image
///
/// This implementation matches:
/// [Hugging Face Diffusers PNDMScheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py)
///
/// This scheduler uses the pseudo linear multi-step (PLMS) method only, skipping pseudo Runge-Kutta (PRK) steps
@available(iOS 16.2, macOS 13.1, *)
public final class PNDMScheduler: Scheduler {
public let trainStepCount: Int
public let inferenceStepCount: Int
public let betas: [Float]
public let alphas: [Float]
public let alphasCumProd: [Float]
public let timeSteps: [Int]
// Internal state
var counter: Int
@ -62,15 +126,12 @@ public final class Scheduler {
case .scaledLinear:
self.betas = linspace(pow(betaStart, 0.5), pow(betaEnd, 0.5), trainStepCount).map({ $0 * $0 })
}
self.alphas = betas.map({ 1.0 - $0 })
self.initNoiseSigma = 1.0
var alphasCumProd = self.alphas
for i in 1..<alphasCumProd.count {
alphasCumProd[i] *= alphasCumProd[i - 1]
}
self.alphasCumProd = alphasCumProd
let stepsOffset = 1 // For stable diffusion
let stepRatio = Float(trainStepCount / stepCount )
let forwardSteps = (0..<stepCount).map {
@ -152,27 +213,6 @@ public final class Scheduler {
return prevSample
}
/// Compute weighted sum of shaped arrays of equal shapes
///
/// - Parameters:
/// - weights: The weights each array is multiplied by
/// - values: The arrays to be weighted and summed
/// - Returns: sum_i weights[i]*values[i]
func weightedSum(_ weights: [Double], _ values: [MLShapedArray<Float32>]) -> MLShapedArray<Float32> {
assert(weights.count > 1 && values.count == weights.count)
assert(values.allSatisfy({$0.scalarCount == values.first!.scalarCount}))
var w = Float(weights.first!)
var scalars = values.first!.scalars.map({ $0 * w })
for next in 1 ..< values.count {
w = Float(weights[next])
let nextScalars = values[next].scalars
for i in 0 ..< scalars.count {
scalars[i] += w * nextScalars[i]
}
}
return MLShapedArray(scalars: scalars, shape: values.first!.shape)
}
/// Compute sample (denoised image) at previous step given a current time step
///
/// - Parameters:
@ -225,17 +265,6 @@ public final class Scheduler {
}
}
@available(iOS 16.2, macOS 13.1, *)
extension Scheduler {
/// How to map a beta range to a sequence of betas to step over
public enum BetaSchedule {
/// Linear stepping between start and end
case linear
/// Steps using linspace(sqrt(start),sqrt(end))^2
case scaledLinear
}
}
/// Evenly spaced floats between specified interval
///
/// - Parameters:

@ -6,6 +6,14 @@ import CoreML
import Accelerate
import CoreGraphics
/// Schedulers compatible with StableDiffusionPipeline
public enum StableDiffusionScheduler {
/// Scheduler that uses a pseudo-linear multi-step (PLMS) method
case pndmScheduler
/// Scheduler that uses a second order DPM-Solver++ algorithm
case dpmSolverMultistepScheduler
}
/// A pipeline used to generate image samples from text input using stable diffusion
///
/// This implementation matches:
@ -113,6 +121,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
stepCount: Int = 50,
seed: Int = 0,
disableSafety: Bool = false,
scheduler: StableDiffusionScheduler = .pndmScheduler,
progressHandler: (Progress) -> Bool = { _ in true }
) throws -> [CGImage?] {
@ -133,7 +142,12 @@ public struct StableDiffusionPipeline: ResourceManaging {
let hiddenStates = toHiddenStates(concatEmbedding)
/// Setup schedulers
let scheduler = (0..<imageCount).map { _ in Scheduler(stepCount: stepCount) }
let scheduler: [Scheduler] = (0..<imageCount).map { _ in
switch scheduler {
case .pndmScheduler: return PNDMScheduler(stepCount: stepCount)
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: stepCount)
}
}
let stdev = scheduler[0].initNoiseSigma
// Generate random latent samples from specified seed

@ -53,6 +53,9 @@ struct StableDiffusionSample: ParsableCommand {
@Option(help: "Compute units to load model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}")
var computeUnits: ComputeUnits = .all
@Option(help: "Scheduler to use, one of {pndm, dpmpp}")
var scheduler: SchedulerOption = .pndm
@Flag(help: "Disable safety checking")
var disableSafety: Bool = false
@ -84,7 +87,8 @@ struct StableDiffusionSample: ParsableCommand {
prompt: prompt,
imageCount: imageCount,
stepCount: stepCount,
seed: seed
seed: seed,
scheduler: scheduler.stableDiffusionScheduler
) { progress in
sampleTimer.stop()
handleProgress(progress,sampleTimer)
@ -190,6 +194,17 @@ enum ComputeUnits: String, ExpressibleByArgument, CaseIterable {
}
}
@available(iOS 16.2, macOS 13.1, *)
enum SchedulerOption: String, ExpressibleByArgument {
case pndm, dpmpp
var stableDiffusionScheduler: StableDiffusionScheduler {
switch self {
case .pndm: return .pndmScheduler
case .dpmpp: return .dpmSolverMultistepScheduler
}
}
}
if #available(iOS 16.2, macOS 13.1, *) {
StableDiffusionSample.main()
} else {

Loading…
Cancel
Save