Allow user to cancel generation.

This will be more useful when we allow multiple images.
main
Pedro Cuenca 2 years ago
parent d78d430065
commit 1084ea79f4

@ -22,7 +22,15 @@ struct GeneratedImageView: View {
let step = Int(progress.step) + 1
let fraction = Double(step) / Double(progress.stepCount)
let label = "Step \(step) of \(progress.stepCount)"
return AnyView(ProgressView(label, value: fraction, total: 1).padding())
return AnyView(HStack {
ProgressView(label, value: fraction, total: 1).padding()
Button {
generation.cancelGeneration()
} label: {
Image(systemName: "x.circle.fill").foregroundColor(.gray)
}
.buttonStyle(.plain)
})
case .complete(_, let image, _, _):
guard let theImage = image else {
return AnyView(Image(systemName: "exclamationmark.triangle").resizable())
@ -34,6 +42,8 @@ struct GeneratedImageView: View {
)
case .failed(_):
return AnyView(Image(systemName: "exclamationmark.triangle").resizable())
case .userCanceled:
return AnyView(Text("Generation canceled"))
}
}
}

@ -20,7 +20,11 @@ struct StatusView: View {
generation.state = .running(nil)
do {
let result = try await generation.generate()
generation.state = .complete(generation.positivePrompt, result.image, result.lastSeed, result.interval)
if result.userCanceled {
generation.state = .userCanceled
} else {
generation.state = .complete(generation.positivePrompt, result.image, result.lastSeed, result.interval)
}
} catch {
generation.state = .failed(error)
}
@ -92,6 +96,11 @@ struct StatusView: View {
}.frame(maxHeight: 25)
case .failed(let error):
return errorWithDetails("Generation error", error: error)
case .userCanceled:
return HStack {
Text("Generation canceled.")
Spacer()
}
}
}

@ -18,6 +18,7 @@ struct GenerationResult {
var image: CGImage?
var lastSeed: UInt32
var interval: TimeInterval?
var userCanceled: Bool
}
class Pipeline {
@ -30,7 +31,8 @@ class Pipeline {
}
}
lazy private(set) var progressPublisher: CurrentValueSubject<StableDiffusionProgress?, Never> = CurrentValueSubject(progress)
private var canceled = false
init(_ pipeline: StableDiffusionPipeline, maxSeed: UInt32 = UInt32.max) {
self.pipeline = pipeline
@ -47,6 +49,7 @@ class Pipeline {
disableSafety: Bool = false
) throws -> GenerationResult {
let beginDate = Date()
canceled = false
print("Generating...")
let theSeed = seed ?? UInt32.random(in: 0...maxSeed)
let images = try pipeline.generateImages(
@ -60,17 +63,21 @@ class Pipeline {
scheduler: scheduler
) { progress in
handleProgress(progress)
return true
return !canceled
}
let interval = Date().timeIntervalSince(beginDate)
print("Got images: \(images) in \(interval)")
// Unwrap the 1 image we asked for, nil means safety checker triggered
let image = images.compactMap({ $0 }).first
return GenerationResult(image: image, lastSeed: theSeed, interval: interval)
return GenerationResult(image: image, lastSeed: theSeed, interval: interval, userCanceled: canceled)
}
func handleProgress(_ progress: StableDiffusionPipeline.Progress) {
self.progress = progress
}
func setCancelled() {
canceled = true
}
}

@ -17,6 +17,7 @@ enum GenerationState {
case startup
case running(StableDiffusionProgress?)
case complete(String, CGImage?, UInt32, TimeInterval?)
case userCanceled
case failed(Error)
}
@ -63,6 +64,10 @@ class GenerationContext: ObservableObject {
disableSafety: disableSafety
)
}
func cancelGeneration() {
pipeline?.setCancelled()
}
}
class Settings {

Loading…
Cancel
Save