From 1084ea79f4527d921989e34aff3066abecc7240e Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 20 Feb 2023 20:45:16 +0100 Subject: [PATCH] Allow user to cancel generation. This will be more useful when we allow multiple images. --- Diffusion-macOS/GeneratedImageView.swift | 12 +++++++++++- Diffusion-macOS/StatusView.swift | 11 ++++++++++- Diffusion/Pipeline/Pipeline.swift | 13 ++++++++++--- Diffusion/State.swift | 5 +++++ 4 files changed, 36 insertions(+), 5 deletions(-) diff --git a/Diffusion-macOS/GeneratedImageView.swift b/Diffusion-macOS/GeneratedImageView.swift index 7b987ac..47d2271 100644 --- a/Diffusion-macOS/GeneratedImageView.swift +++ b/Diffusion-macOS/GeneratedImageView.swift @@ -22,7 +22,15 @@ struct GeneratedImageView: View { let step = Int(progress.step) + 1 let fraction = Double(step) / Double(progress.stepCount) let label = "Step \(step) of \(progress.stepCount)" - return AnyView(ProgressView(label, value: fraction, total: 1).padding()) + return AnyView(HStack { + ProgressView(label, value: fraction, total: 1).padding() + Button { + generation.cancelGeneration() + } label: { + Image(systemName: "x.circle.fill").foregroundColor(.gray) + } + .buttonStyle(.plain) + }) case .complete(_, let image, _, _): guard let theImage = image else { return AnyView(Image(systemName: "exclamationmark.triangle").resizable()) @@ -34,6 +42,8 @@ struct GeneratedImageView: View { ) case .failed(_): return AnyView(Image(systemName: "exclamationmark.triangle").resizable()) + case .userCanceled: + return AnyView(Text("Generation canceled")) } } } diff --git a/Diffusion-macOS/StatusView.swift b/Diffusion-macOS/StatusView.swift index b01461d..ac02f42 100644 --- a/Diffusion-macOS/StatusView.swift +++ b/Diffusion-macOS/StatusView.swift @@ -20,7 +20,11 @@ struct StatusView: View { generation.state = .running(nil) do { let result = try await generation.generate() - generation.state = .complete(generation.positivePrompt, result.image, result.lastSeed, result.interval) + if result.userCanceled { + generation.state = .userCanceled + } else { + generation.state = .complete(generation.positivePrompt, result.image, result.lastSeed, result.interval) + } } catch { generation.state = .failed(error) } @@ -92,6 +96,11 @@ struct StatusView: View { }.frame(maxHeight: 25) case .failed(let error): return errorWithDetails("Generation error", error: error) + case .userCanceled: + return HStack { + Text("Generation canceled.") + Spacer() + } } } diff --git a/Diffusion/Pipeline/Pipeline.swift b/Diffusion/Pipeline/Pipeline.swift index e706010..90e17bd 100644 --- a/Diffusion/Pipeline/Pipeline.swift +++ b/Diffusion/Pipeline/Pipeline.swift @@ -18,6 +18,7 @@ struct GenerationResult { var image: CGImage? var lastSeed: UInt32 var interval: TimeInterval? + var userCanceled: Bool } class Pipeline { @@ -30,7 +31,8 @@ class Pipeline { } } lazy private(set) var progressPublisher: CurrentValueSubject = CurrentValueSubject(progress) - + + private var canceled = false init(_ pipeline: StableDiffusionPipeline, maxSeed: UInt32 = UInt32.max) { self.pipeline = pipeline @@ -47,6 +49,7 @@ class Pipeline { disableSafety: Bool = false ) throws -> GenerationResult { let beginDate = Date() + canceled = false print("Generating...") let theSeed = seed ?? UInt32.random(in: 0...maxSeed) let images = try pipeline.generateImages( @@ -60,17 +63,21 @@ class Pipeline { scheduler: scheduler ) { progress in handleProgress(progress) - return true + return !canceled } let interval = Date().timeIntervalSince(beginDate) print("Got images: \(images) in \(interval)") // Unwrap the 1 image we asked for, nil means safety checker triggered let image = images.compactMap({ $0 }).first - return GenerationResult(image: image, lastSeed: theSeed, interval: interval) + return GenerationResult(image: image, lastSeed: theSeed, interval: interval, userCanceled: canceled) } func handleProgress(_ progress: StableDiffusionPipeline.Progress) { self.progress = progress } + + func setCancelled() { + canceled = true + } } diff --git a/Diffusion/State.swift b/Diffusion/State.swift index 967b3a9..12381c5 100644 --- a/Diffusion/State.swift +++ b/Diffusion/State.swift @@ -17,6 +17,7 @@ enum GenerationState { case startup case running(StableDiffusionProgress?) case complete(String, CGImage?, UInt32, TimeInterval?) + case userCanceled case failed(Error) } @@ -63,6 +64,10 @@ class GenerationContext: ObservableObject { disableSafety: disableSafety ) } + + func cancelGeneration() { + pipeline?.setCancelled() + } } class Settings {