Allow last random seed to be reused

pull/28/head
Pedro Cuenca 2 years ago
parent 367df55d96
commit 9c54eb40f6

@ -66,7 +66,8 @@ struct ContentView: View {
@StateObject var generation = GenerationContext() @StateObject var generation = GenerationContext()
func toolbar() -> any View { func toolbar() -> any View {
if case .complete(let prompt, let cgImage, let seed, _) = generation.state, let cgImage = cgImage { if case .complete(let prompt, let cgImage, _, _) = generation.state, let cgImage = cgImage {
// TODO: share seed too
return ShareButtons(image: cgImage, name: prompt) return ShareButtons(image: cgImage, name: prompt)
} else { } else {
let prompt = DEFAULT_PROMPT let prompt = DEFAULT_PROMPT

@ -70,6 +70,9 @@ struct ControlsView: View {
@State private var showGuidanceHelp = false @State private var showGuidanceHelp = false
@State private var showStepsHelp = false @State private var showStepsHelp = false
@State private var showSeedHelp = false @State private var showSeedHelp = false
// Reasonable range for the slider
let maxSeed: UInt32 = 1000
func updateSafetyCheckerState() { func updateSafetyCheckerState() {
mustShowSafetyCheckerDisclaimer = generation.disableSafety && !Settings.shared.safetyCheckerDisclaimerShown mustShowSafetyCheckerDisclaimer = generation.disableSafety && !Settings.shared.safetyCheckerDisclaimerShown
@ -82,7 +85,7 @@ struct ControlsView: View {
pipelineLoader?.cancel() pipelineLoader?.cancel()
pipelineState = .downloading(0) pipelineState = .downloading(0)
Task.init { Task.init {
let loader = PipelineLoader(model: model) let loader = PipelineLoader(model: model, maxSeed: maxSeed)
self.pipelineLoader = loader self.pipelineLoader = loader
stateSubscriber = loader.statePublisher.sink { state in stateSubscriber = loader.statePublisher.sink { state in
DispatchQueue.main.async { DispatchQueue.main.async {
@ -245,7 +248,7 @@ struct ControlsView: View {
DisclosureGroup(isExpanded: $disclosedSeed) { DisclosureGroup(isExpanded: $disclosedSeed) {
let sliderLabel = generation.seed < 0 ? "Random Seed" : "Seed" let sliderLabel = generation.seed < 0 ? "Random Seed" : "Seed"
CompactSlider(value: $generation.seed, in: -1...1000, step: 1) { CompactSlider(value: $generation.seed, in: -1...Double(maxSeed), step: 1) {
Text(sliderLabel) Text(sliderLabel)
Spacer() Spacer()
Text("\(Int(generation.seed))") Text("\(Int(generation.seed))")

@ -83,6 +83,12 @@ struct StatusView: View {
let intervalString = String(format: "Time: %.1fs", interval ?? 0) let intervalString = String(format: "Time: %.1fs", interval ?? 0)
Text(intervalString) Text(intervalString)
Spacer() Spacer()
if generation.seed != Double(lastSeed) {
Text("Seed: \(lastSeed)")
Button("Set") {
generation.seed = Double(lastSeed)
}
}
}.frame(maxHeight: 25) }.frame(maxHeight: 25)
case .failed(let error): case .failed(let error):
return errorWithDetails("Generation error", error: error) return errorWithDetails("Generation error", error: error)

@ -22,6 +22,7 @@ struct GenerationResult {
class Pipeline { class Pipeline {
let pipeline: StableDiffusionPipeline let pipeline: StableDiffusionPipeline
let maxSeed: UInt32
var progress: StableDiffusionProgress? = nil { var progress: StableDiffusionProgress? = nil {
didSet { didSet {
@ -31,8 +32,9 @@ class Pipeline {
lazy private(set) var progressPublisher: CurrentValueSubject<StableDiffusionProgress?, Never> = CurrentValueSubject(progress) lazy private(set) var progressPublisher: CurrentValueSubject<StableDiffusionProgress?, Never> = CurrentValueSubject(progress)
init(_ pipeline: StableDiffusionPipeline) { init(_ pipeline: StableDiffusionPipeline, maxSeed: UInt32 = UInt32.max) {
self.pipeline = pipeline self.pipeline = pipeline
self.maxSeed = maxSeed
} }
func generate( func generate(
@ -46,7 +48,7 @@ class Pipeline {
) throws -> GenerationResult { ) throws -> GenerationResult {
let beginDate = Date() let beginDate = Date()
print("Generating...") print("Generating...")
let theSeed = seed ?? UInt32.random(in: 0..<UInt32.max) let theSeed = seed ?? UInt32.random(in: 0...maxSeed)
let images = try pipeline.generateImages( let images = try pipeline.generateImages(
prompt: prompt, prompt: prompt,
negativePrompt: negativePrompt, negativePrompt: negativePrompt,

@ -18,10 +18,13 @@ class PipelineLoader {
static let models = Path.applicationSupport / "hf-diffusion-models" static let models = Path.applicationSupport / "hf-diffusion-models"
let model: ModelInfo let model: ModelInfo
let maxSeed: UInt32
private var downloadSubscriber: Cancellable? private var downloadSubscriber: Cancellable?
init(model: ModelInfo) { init(model: ModelInfo, maxSeed: UInt32 = UInt32.max) {
self.model = model self.model = model
self.maxSeed = maxSeed
state = .undetermined state = .undetermined
setInitialState() setInitialState()
} }
@ -100,7 +103,7 @@ extension PipelineLoader {
try await download() try await download()
try await unzip() try await unzip()
let pipeline = try await load(url: compiledPath.url) let pipeline = try await load(url: compiledPath.url)
return Pipeline(pipeline) return Pipeline(pipeline, maxSeed: maxSeed)
} catch { } catch {
state = .failed(error) state = .failed(error)
throw error throw error

Loading…
Cancel
Save