Allow last random seed to be reused

pull/28/head
Pedro Cuenca 2 years ago
parent 367df55d96
commit 9c54eb40f6

@ -66,7 +66,8 @@ struct ContentView: View {
@StateObject var generation = GenerationContext()
func toolbar() -> any View {
if case .complete(let prompt, let cgImage, let seed, _) = generation.state, let cgImage = cgImage {
if case .complete(let prompt, let cgImage, _, _) = generation.state, let cgImage = cgImage {
// TODO: share seed too
return ShareButtons(image: cgImage, name: prompt)
} else {
let prompt = DEFAULT_PROMPT

@ -71,6 +71,9 @@ struct ControlsView: View {
@State private var showStepsHelp = false
@State private var showSeedHelp = false
// Reasonable range for the slider
let maxSeed: UInt32 = 1000
func updateSafetyCheckerState() {
mustShowSafetyCheckerDisclaimer = generation.disableSafety && !Settings.shared.safetyCheckerDisclaimerShown
}
@ -82,7 +85,7 @@ struct ControlsView: View {
pipelineLoader?.cancel()
pipelineState = .downloading(0)
Task.init {
let loader = PipelineLoader(model: model)
let loader = PipelineLoader(model: model, maxSeed: maxSeed)
self.pipelineLoader = loader
stateSubscriber = loader.statePublisher.sink { state in
DispatchQueue.main.async {
@ -245,7 +248,7 @@ struct ControlsView: View {
DisclosureGroup(isExpanded: $disclosedSeed) {
let sliderLabel = generation.seed < 0 ? "Random Seed" : "Seed"
CompactSlider(value: $generation.seed, in: -1...1000, step: 1) {
CompactSlider(value: $generation.seed, in: -1...Double(maxSeed), step: 1) {
Text(sliderLabel)
Spacer()
Text("\(Int(generation.seed))")

@ -83,6 +83,12 @@ struct StatusView: View {
let intervalString = String(format: "Time: %.1fs", interval ?? 0)
Text(intervalString)
Spacer()
if generation.seed != Double(lastSeed) {
Text("Seed: \(lastSeed)")
Button("Set") {
generation.seed = Double(lastSeed)
}
}
}.frame(maxHeight: 25)
case .failed(let error):
return errorWithDetails("Generation error", error: error)

@ -22,6 +22,7 @@ struct GenerationResult {
class Pipeline {
let pipeline: StableDiffusionPipeline
let maxSeed: UInt32
var progress: StableDiffusionProgress? = nil {
didSet {
@ -31,8 +32,9 @@ class Pipeline {
lazy private(set) var progressPublisher: CurrentValueSubject<StableDiffusionProgress?, Never> = CurrentValueSubject(progress)
init(_ pipeline: StableDiffusionPipeline) {
init(_ pipeline: StableDiffusionPipeline, maxSeed: UInt32 = UInt32.max) {
self.pipeline = pipeline
self.maxSeed = maxSeed
}
func generate(
@ -46,7 +48,7 @@ class Pipeline {
) throws -> GenerationResult {
let beginDate = Date()
print("Generating...")
let theSeed = seed ?? UInt32.random(in: 0..<UInt32.max)
let theSeed = seed ?? UInt32.random(in: 0...maxSeed)
let images = try pipeline.generateImages(
prompt: prompt,
negativePrompt: negativePrompt,

@ -18,10 +18,13 @@ class PipelineLoader {
static let models = Path.applicationSupport / "hf-diffusion-models"
let model: ModelInfo
let maxSeed: UInt32
private var downloadSubscriber: Cancellable?
init(model: ModelInfo) {
init(model: ModelInfo, maxSeed: UInt32 = UInt32.max) {
self.model = model
self.maxSeed = maxSeed
state = .undetermined
setInitialState()
}
@ -100,7 +103,7 @@ extension PipelineLoader {
try await download()
try await unzip()
let pipeline = try await load(url: compiledPath.url)
return Pipeline(pipeline)
return Pipeline(pipeline, maxSeed: maxSeed)
} catch {
state = .failed(error)
throw error

Loading…
Cancel
Save