diff --git a/swift/StableDiffusionCLI/main.swift b/swift/StableDiffusionCLI/main.swift index d6d9855..316ab12 100644 --- a/swift/StableDiffusionCLI/main.swift +++ b/swift/StableDiffusionCLI/main.swift @@ -53,6 +53,9 @@ struct StableDiffusionSample: ParsableCommand { @Option(help: "Compute units to load model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}") var computeUnits: ComputeUnits = .all + @Option(help: "Scheduler to use, one of {pndm, dpmpp}") + var scheduler: SchedulerOption = .pndm + @Flag(help: "Disable safety checking") var disableSafety: Bool = false @@ -84,7 +87,8 @@ struct StableDiffusionSample: ParsableCommand { prompt: prompt, imageCount: imageCount, stepCount: stepCount, - seed: seed + seed: seed, + scheduler: scheduler.stableDiffusionScheduler ) { progress in sampleTimer.stop() handleProgress(progress,sampleTimer) @@ -190,6 +194,17 @@ enum ComputeUnits: String, ExpressibleByArgument, CaseIterable { } } +@available(iOS 16.2, macOS 13.1, *) +enum SchedulerOption: String, ExpressibleByArgument { + case pndm, dpmpp + var stableDiffusionScheduler: StableDiffusionScheduler { + switch self { + case .pndm: return .pndmScheduler + case .dpmpp: return .dpmSolverMultistepScheduler + } + } +} + if #available(iOS 16.2, macOS 13.1, *) { StableDiffusionSample.main() } else {