Allow user to cancel generation.

This will be more useful when we allow multiple images.
main
Pedro Cuenca 2 years ago
parent d78d430065
commit 1084ea79f4

@ -22,7 +22,15 @@ struct GeneratedImageView: View {
let step = Int(progress.step) + 1 let step = Int(progress.step) + 1
let fraction = Double(step) / Double(progress.stepCount) let fraction = Double(step) / Double(progress.stepCount)
let label = "Step \(step) of \(progress.stepCount)" let label = "Step \(step) of \(progress.stepCount)"
return AnyView(ProgressView(label, value: fraction, total: 1).padding()) return AnyView(HStack {
ProgressView(label, value: fraction, total: 1).padding()
Button {
generation.cancelGeneration()
} label: {
Image(systemName: "x.circle.fill").foregroundColor(.gray)
}
.buttonStyle(.plain)
})
case .complete(_, let image, _, _): case .complete(_, let image, _, _):
guard let theImage = image else { guard let theImage = image else {
return AnyView(Image(systemName: "exclamationmark.triangle").resizable()) return AnyView(Image(systemName: "exclamationmark.triangle").resizable())
@ -34,6 +42,8 @@ struct GeneratedImageView: View {
) )
case .failed(_): case .failed(_):
return AnyView(Image(systemName: "exclamationmark.triangle").resizable()) return AnyView(Image(systemName: "exclamationmark.triangle").resizable())
case .userCanceled:
return AnyView(Text("Generation canceled"))
} }
} }
} }

@ -20,7 +20,11 @@ struct StatusView: View {
generation.state = .running(nil) generation.state = .running(nil)
do { do {
let result = try await generation.generate() let result = try await generation.generate()
generation.state = .complete(generation.positivePrompt, result.image, result.lastSeed, result.interval) if result.userCanceled {
generation.state = .userCanceled
} else {
generation.state = .complete(generation.positivePrompt, result.image, result.lastSeed, result.interval)
}
} catch { } catch {
generation.state = .failed(error) generation.state = .failed(error)
} }
@ -92,6 +96,11 @@ struct StatusView: View {
}.frame(maxHeight: 25) }.frame(maxHeight: 25)
case .failed(let error): case .failed(let error):
return errorWithDetails("Generation error", error: error) return errorWithDetails("Generation error", error: error)
case .userCanceled:
return HStack {
Text("Generation canceled.")
Spacer()
}
} }
} }

@ -18,6 +18,7 @@ struct GenerationResult {
var image: CGImage? var image: CGImage?
var lastSeed: UInt32 var lastSeed: UInt32
var interval: TimeInterval? var interval: TimeInterval?
var userCanceled: Bool
} }
class Pipeline { class Pipeline {
@ -30,7 +31,8 @@ class Pipeline {
} }
} }
lazy private(set) var progressPublisher: CurrentValueSubject<StableDiffusionProgress?, Never> = CurrentValueSubject(progress) lazy private(set) var progressPublisher: CurrentValueSubject<StableDiffusionProgress?, Never> = CurrentValueSubject(progress)
private var canceled = false
init(_ pipeline: StableDiffusionPipeline, maxSeed: UInt32 = UInt32.max) { init(_ pipeline: StableDiffusionPipeline, maxSeed: UInt32 = UInt32.max) {
self.pipeline = pipeline self.pipeline = pipeline
@ -47,6 +49,7 @@ class Pipeline {
disableSafety: Bool = false disableSafety: Bool = false
) throws -> GenerationResult { ) throws -> GenerationResult {
let beginDate = Date() let beginDate = Date()
canceled = false
print("Generating...") print("Generating...")
let theSeed = seed ?? UInt32.random(in: 0...maxSeed) let theSeed = seed ?? UInt32.random(in: 0...maxSeed)
let images = try pipeline.generateImages( let images = try pipeline.generateImages(
@ -60,17 +63,21 @@ class Pipeline {
scheduler: scheduler scheduler: scheduler
) { progress in ) { progress in
handleProgress(progress) handleProgress(progress)
return true return !canceled
} }
let interval = Date().timeIntervalSince(beginDate) let interval = Date().timeIntervalSince(beginDate)
print("Got images: \(images) in \(interval)") print("Got images: \(images) in \(interval)")
// Unwrap the 1 image we asked for, nil means safety checker triggered // Unwrap the 1 image we asked for, nil means safety checker triggered
let image = images.compactMap({ $0 }).first let image = images.compactMap({ $0 }).first
return GenerationResult(image: image, lastSeed: theSeed, interval: interval) return GenerationResult(image: image, lastSeed: theSeed, interval: interval, userCanceled: canceled)
} }
func handleProgress(_ progress: StableDiffusionPipeline.Progress) { func handleProgress(_ progress: StableDiffusionPipeline.Progress) {
self.progress = progress self.progress = progress
} }
func setCancelled() {
canceled = true
}
} }

@ -17,6 +17,7 @@ enum GenerationState {
case startup case startup
case running(StableDiffusionProgress?) case running(StableDiffusionProgress?)
case complete(String, CGImage?, UInt32, TimeInterval?) case complete(String, CGImage?, UInt32, TimeInterval?)
case userCanceled
case failed(Error) case failed(Error)
} }
@ -63,6 +64,10 @@ class GenerationContext: ObservableObject {
disableSafety: disableSafety disableSafety: disableSafety
) )
} }
func cancelGeneration() {
pipeline?.setCancelled()
}
} }
class Settings { class Settings {

Loading…
Cancel
Save