Add guidance scale.

pull/28/head
Pedro Cuenca 2 years ago
parent d3c145f219
commit 1839453f3d

@ -54,6 +54,7 @@ struct ControlsView: View {
@State private var model = Settings.shared.currentModel.modelVersion
@State private var disclosedModel = true
@State private var disclosedPrompt = true
@State private var disclosedGuidance = false
@State private var disclosedSteps = false
@State private var disclosedSeed = false
@ -134,6 +135,17 @@ struct ControlsView: View {
Divider()
DisclosureGroup(isExpanded: $disclosedGuidance) {
CompactSlider(value: $generation.guidanceScale, in: 0...20, step: 0.5) {
Text("Guidance Scale")
Spacer()
Text(generation.guidanceScale.formatted("%.1f"))
}.padding(.leading, 10)
} label: {
Label("Guidance Scale", systemImage: "scalemass").foregroundColor(.secondary)
}
Divider()
DisclosureGroup(isExpanded: $disclosedSteps) {
CompactSlider(value: $generation.steps, in: 0...150, step: 5) {
Text("Steps")

@ -29,7 +29,15 @@ class Pipeline {
self.pipeline = pipeline
}
func generate(prompt: String, negativePrompt: String = "", scheduler: StableDiffusionScheduler, numInferenceSteps stepCount: Int = 50, seed: UInt32? = nil) throws -> (CGImage, TimeInterval) {
func generate(
prompt: String,
negativePrompt: String = "",
scheduler: StableDiffusionScheduler,
numInferenceSteps stepCount: Int = 50,
seed: UInt32? = nil,
guidanceScale: Float = 7.5,
disableSafety: Bool = false
) throws -> (CGImage, TimeInterval) {
let beginDate = Date()
print("Generating...")
let theSeed = seed ?? UInt32.random(in: 0..<UInt32.max)
@ -39,6 +47,8 @@ class Pipeline {
imageCount: 1,
stepCount: stepCount,
seed: theSeed,
guidanceScale: guidanceScale,
disableSafety: disableSafety,
scheduler: scheduler
) { progress in
handleProgress(progress)

@ -44,13 +44,23 @@ class GenerationContext: ObservableObject {
@Published var steps = 25.0
@Published var numImages = 1.0
@Published var seed = -1.0
@Published var guidanceScale = 7.5
@Published var disableSafety = false
private var progressSubscriber: Cancellable?
func generate() async -> (CGImage, TimeInterval)? {
guard let pipeline = pipeline else { return nil }
let seed = self.seed >= 0 ? UInt32(self.seed) : nil
return try? pipeline.generate(prompt: positivePrompt, negativePrompt: negativePrompt, scheduler: scheduler, numInferenceSteps: Int(steps), seed: seed)
return try? pipeline.generate(
prompt: positivePrompt,
negativePrompt: negativePrompt,
scheduler: scheduler,
numInferenceSteps: Int(steps),
seed: seed,
guidanceScale: Float(guidanceScale),
disableSafety: disableSafety
)
}
}

@ -9,3 +9,9 @@
import Foundation
extension String: Error {}
extension Double {
func formatted(_ format: String) -> String {
return String(format: "\(format)", self)
}
}

Loading…
Cancel
Save