Spell out scheduler names in the enum.

pull/59/head
Pedro Cuenca 3 years ago
parent bd60010baa
commit a44a91095d

@ -8,8 +8,10 @@ import CoreGraphics
/// Schedulers compatible with StableDiffusionPipeline /// Schedulers compatible with StableDiffusionPipeline
public enum StableDiffusionScheduler { public enum StableDiffusionScheduler {
case pndm /// Scheduler that uses a pseudo-linear multi-step (PLMS) method
case dpmpp case pndmScheduler
/// Scheduler that uses a second order DPM-Solver++ algorithm
case dpmSolverMultistepScheduler
} }
/// A pipeline used to generate image samples from text input using stable diffusion /// A pipeline used to generate image samples from text input using stable diffusion
@ -76,7 +78,7 @@ public struct StableDiffusionPipeline {
stepCount: Int = 50, stepCount: Int = 50,
seed: Int = 0, seed: Int = 0,
disableSafety: Bool = false, disableSafety: Bool = false,
scheduler: StableDiffusionScheduler = .pndm, scheduler: StableDiffusionScheduler = .pndmScheduler,
progressHandler: (Progress) -> Bool = { _ in true } progressHandler: (Progress) -> Bool = { _ in true }
) throws -> [CGImage?] { ) throws -> [CGImage?] {
@ -95,8 +97,8 @@ public struct StableDiffusionPipeline {
/// Setup schedulers /// Setup schedulers
let scheduler: [Scheduler] = (0..<imageCount).map { _ in let scheduler: [Scheduler] = (0..<imageCount).map { _ in
switch scheduler { switch scheduler {
case .pndm: return PNDMScheduler(stepCount: stepCount) case .pndmScheduler: return PNDMScheduler(stepCount: stepCount)
case .dpmpp: return DPMSolverMultistepScheduler(stepCount: stepCount) case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: stepCount)
} }
} }
let stdev = scheduler[0].initNoiseSigma let stdev = scheduler[0].initNoiseSigma

Loading…
Cancel
Save