From 367df55d96faf931d3fe23e6998dd0107defe0b4 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 8 Feb 2023 16:00:59 +0100 Subject: [PATCH] Pass seed back, report generation errors. --- Diffusion-macOS/ContentView.swift | 2 +- Diffusion-macOS/GeneratedImageView.swift | 4 +- Diffusion-macOS/StatusView.swift | 70 +++++++++++++----------- Diffusion/Pipeline/Pipeline.swift | 10 +++- Diffusion/State.swift | 9 +-- Diffusion/Views/TextToImage.swift | 2 +- 6 files changed, 57 insertions(+), 40 deletions(-) diff --git a/Diffusion-macOS/ContentView.swift b/Diffusion-macOS/ContentView.swift index 004aacf..3d25826 100644 --- a/Diffusion-macOS/ContentView.swift +++ b/Diffusion-macOS/ContentView.swift @@ -66,7 +66,7 @@ struct ContentView: View { @StateObject var generation = GenerationContext() func toolbar() -> any View { - if case .complete(let prompt, let cgImage, _) = generation.state, let cgImage = cgImage { + if case .complete(let prompt, let cgImage, let seed, _) = generation.state, let cgImage = cgImage { return ShareButtons(image: cgImage, name: prompt) } else { let prompt = DEFAULT_PROMPT diff --git a/Diffusion-macOS/GeneratedImageView.swift b/Diffusion-macOS/GeneratedImageView.swift index bea93d8..7b987ac 100644 --- a/Diffusion-macOS/GeneratedImageView.swift +++ b/Diffusion-macOS/GeneratedImageView.swift @@ -23,7 +23,7 @@ struct GeneratedImageView: View { let fraction = Double(step) / Double(progress.stepCount) let label = "Step \(step) of \(progress.stepCount)" return AnyView(ProgressView(label, value: fraction, total: 1).padding()) - case .complete(_, let image, _): + case .complete(_, let image, _, _): guard let theImage = image else { return AnyView(Image(systemName: "exclamationmark.triangle").resizable()) } @@ -32,6 +32,8 @@ struct GeneratedImageView: View { .resizable() .clipShape(RoundedRectangle(cornerRadius: 20)) ) + case .failed(_): + return AnyView(Image(systemName: "exclamationmark.triangle").resizable()) } } } diff --git a/Diffusion-macOS/StatusView.swift b/Diffusion-macOS/StatusView.swift index 122874f..456279f 100644 --- a/Diffusion-macOS/StatusView.swift +++ b/Diffusion-macOS/StatusView.swift @@ -18,13 +18,42 @@ struct StatusView: View { if case .running = generation.state { return } Task { generation.state = .running(nil) - let interval: TimeInterval? - let image: CGImage? - (image, interval) = await generation.generate() ?? (nil, nil) - generation.state = .complete(generation.positivePrompt, image, interval) + do { + let result = try await generation.generate() + generation.state = .complete(generation.positivePrompt, result.image, result.lastSeed, result.interval) + } catch { + generation.state = .failed(error) + } } } - + + func errorWithDetails(_ message: String, error: Error) -> any View { + HStack { + Text(message) + Spacer() + Button { + showErrorPopover.toggle() + } label: { + Image(systemName: "info.circle") + }.buttonStyle(.plain) + .popover(isPresented: $showErrorPopover) { + VStack { + Text(verbatim: "\(error)") + .lineLimit(nil) + .padding(.all, 5) + Button { + showErrorPopover.toggle() + } label: { + Text("Dismiss").frame(maxWidth: 200) + } + .padding(.bottom) + } + .frame(minWidth: 400, idealWidth: 400, maxWidth: 400) + .fixedSize() + } + } + } + func generationStatusView() -> any View { switch generation.state { case .startup: return EmptyView() @@ -42,7 +71,7 @@ struct StatusView: View { Text("Generating \(Int(round(100*fraction)))%") Spacer() } - case .complete(_, let image, let interval): + case .complete(_, let image, let lastSeed, let interval): guard let _ = image else { return HStack { Text("Safety checker triggered, please try a different prompt or seed") @@ -55,9 +84,11 @@ struct StatusView: View { Text(intervalString) Spacer() }.frame(maxHeight: 25) + case .failed(let error): + return errorWithDetails("Generation error", error: error) } } - + var body: some View { switch pipelineState.wrappedValue { case .downloading(let progress): @@ -80,30 +111,7 @@ struct StatusView: View { AnyView(generationStatusView()) } case .failed(let error): - HStack { - Text("Pipeline loading error") - Spacer() - Button { - showErrorPopover.toggle() - } label: { - Image(systemName: "info.circle") - }.buttonStyle(.plain) - .popover(isPresented: $showErrorPopover) { - VStack { - Text(verbatim: "\(error)") - .lineLimit(nil) - .padding(.all, 5) - Button { - showErrorPopover.toggle() - } label: { - Text("Dismiss").frame(maxWidth: 200) - } - .padding(.bottom) - } - .frame(minWidth: 400, idealWidth: 400, maxWidth: 400) - .fixedSize() - } - } + AnyView(errorWithDetails("Pipeline loading error", error: error)) } } } diff --git a/Diffusion/Pipeline/Pipeline.swift b/Diffusion/Pipeline/Pipeline.swift index 1226ac6..3f04aa6 100644 --- a/Diffusion/Pipeline/Pipeline.swift +++ b/Diffusion/Pipeline/Pipeline.swift @@ -14,6 +14,12 @@ import StableDiffusion typealias StableDiffusionProgress = StableDiffusionPipeline.Progress +struct GenerationResult { + var image: CGImage? + var lastSeed: UInt32 + var interval: TimeInterval? +} + class Pipeline { let pipeline: StableDiffusionPipeline @@ -37,7 +43,7 @@ class Pipeline { seed: UInt32? = nil, guidanceScale: Float = 7.5, disableSafety: Bool = false - ) throws -> (CGImage, TimeInterval) { + ) throws -> GenerationResult { let beginDate = Date() print("Generating...") let theSeed = seed ?? UInt32.random(in: 0.. (CGImage, TimeInterval)? { - guard let pipeline = pipeline else { return nil } + func generate() async throws -> GenerationResult { + guard let pipeline = pipeline else { throw "No pipeline" } let seed = self.seed >= 0 ? UInt32(self.seed) : nil - return try? pipeline.generate( + return try pipeline.generate( prompt: positivePrompt, negativePrompt: negativePrompt, scheduler: scheduler, diff --git a/Diffusion/Views/TextToImage.swift b/Diffusion/Views/TextToImage.swift index e05fdc5..b425839 100644 --- a/Diffusion/Views/TextToImage.swift +++ b/Diffusion/Views/TextToImage.swift @@ -99,7 +99,7 @@ struct TextToImage: View { generation.state = .running(nil) let interval: TimeInterval? let image: CGImage? - (image, interval) = await generation.generate() ?? (nil, nil) + let result = await generation.generate() generation.state = .complete(generation.positivePrompt, image, interval) } }