Pass seed back, report generation errors.

pull/28/head
Pedro Cuenca 2 years ago
parent e2cc04fce3
commit 367df55d96

@ -66,7 +66,7 @@ struct ContentView: View {
@StateObject var generation = GenerationContext() @StateObject var generation = GenerationContext()
func toolbar() -> any View { func toolbar() -> any View {
if case .complete(let prompt, let cgImage, _) = generation.state, let cgImage = cgImage { if case .complete(let prompt, let cgImage, let seed, _) = generation.state, let cgImage = cgImage {
return ShareButtons(image: cgImage, name: prompt) return ShareButtons(image: cgImage, name: prompt)
} else { } else {
let prompt = DEFAULT_PROMPT let prompt = DEFAULT_PROMPT

@ -23,7 +23,7 @@ struct GeneratedImageView: View {
let fraction = Double(step) / Double(progress.stepCount) let fraction = Double(step) / Double(progress.stepCount)
let label = "Step \(step) of \(progress.stepCount)" let label = "Step \(step) of \(progress.stepCount)"
return AnyView(ProgressView(label, value: fraction, total: 1).padding()) return AnyView(ProgressView(label, value: fraction, total: 1).padding())
case .complete(_, let image, _): case .complete(_, let image, _, _):
guard let theImage = image else { guard let theImage = image else {
return AnyView(Image(systemName: "exclamationmark.triangle").resizable()) return AnyView(Image(systemName: "exclamationmark.triangle").resizable())
} }
@ -32,6 +32,8 @@ struct GeneratedImageView: View {
.resizable() .resizable()
.clipShape(RoundedRectangle(cornerRadius: 20)) .clipShape(RoundedRectangle(cornerRadius: 20))
) )
case .failed(_):
return AnyView(Image(systemName: "exclamationmark.triangle").resizable())
} }
} }
} }

@ -18,10 +18,39 @@ struct StatusView: View {
if case .running = generation.state { return } if case .running = generation.state { return }
Task { Task {
generation.state = .running(nil) generation.state = .running(nil)
let interval: TimeInterval? do {
let image: CGImage? let result = try await generation.generate()
(image, interval) = await generation.generate() ?? (nil, nil) generation.state = .complete(generation.positivePrompt, result.image, result.lastSeed, result.interval)
generation.state = .complete(generation.positivePrompt, image, interval) } catch {
generation.state = .failed(error)
}
}
}
func errorWithDetails(_ message: String, error: Error) -> any View {
HStack {
Text(message)
Spacer()
Button {
showErrorPopover.toggle()
} label: {
Image(systemName: "info.circle")
}.buttonStyle(.plain)
.popover(isPresented: $showErrorPopover) {
VStack {
Text(verbatim: "\(error)")
.lineLimit(nil)
.padding(.all, 5)
Button {
showErrorPopover.toggle()
} label: {
Text("Dismiss").frame(maxWidth: 200)
}
.padding(.bottom)
}
.frame(minWidth: 400, idealWidth: 400, maxWidth: 400)
.fixedSize()
}
} }
} }
@ -42,7 +71,7 @@ struct StatusView: View {
Text("Generating \(Int(round(100*fraction)))%") Text("Generating \(Int(round(100*fraction)))%")
Spacer() Spacer()
} }
case .complete(_, let image, let interval): case .complete(_, let image, let lastSeed, let interval):
guard let _ = image else { guard let _ = image else {
return HStack { return HStack {
Text("Safety checker triggered, please try a different prompt or seed") Text("Safety checker triggered, please try a different prompt or seed")
@ -55,6 +84,8 @@ struct StatusView: View {
Text(intervalString) Text(intervalString)
Spacer() Spacer()
}.frame(maxHeight: 25) }.frame(maxHeight: 25)
case .failed(let error):
return errorWithDetails("Generation error", error: error)
} }
} }
@ -80,30 +111,7 @@ struct StatusView: View {
AnyView(generationStatusView()) AnyView(generationStatusView())
} }
case .failed(let error): case .failed(let error):
HStack { AnyView(errorWithDetails("Pipeline loading error", error: error))
Text("Pipeline loading error")
Spacer()
Button {
showErrorPopover.toggle()
} label: {
Image(systemName: "info.circle")
}.buttonStyle(.plain)
.popover(isPresented: $showErrorPopover) {
VStack {
Text(verbatim: "\(error)")
.lineLimit(nil)
.padding(.all, 5)
Button {
showErrorPopover.toggle()
} label: {
Text("Dismiss").frame(maxWidth: 200)
}
.padding(.bottom)
}
.frame(minWidth: 400, idealWidth: 400, maxWidth: 400)
.fixedSize()
}
}
} }
} }
} }

@ -14,6 +14,12 @@ import StableDiffusion
typealias StableDiffusionProgress = StableDiffusionPipeline.Progress typealias StableDiffusionProgress = StableDiffusionPipeline.Progress
struct GenerationResult {
var image: CGImage?
var lastSeed: UInt32
var interval: TimeInterval?
}
class Pipeline { class Pipeline {
let pipeline: StableDiffusionPipeline let pipeline: StableDiffusionPipeline
@ -37,7 +43,7 @@ class Pipeline {
seed: UInt32? = nil, seed: UInt32? = nil,
guidanceScale: Float = 7.5, guidanceScale: Float = 7.5,
disableSafety: Bool = false disableSafety: Bool = false
) throws -> (CGImage, TimeInterval) { ) throws -> GenerationResult {
let beginDate = Date() let beginDate = Date()
print("Generating...") print("Generating...")
let theSeed = seed ?? UInt32.random(in: 0..<UInt32.max) let theSeed = seed ?? UInt32.random(in: 0..<UInt32.max)
@ -59,7 +65,7 @@ class Pipeline {
// unwrap the 1 image we asked for // unwrap the 1 image we asked for
guard let image = images.compactMap({ $0 }).first else { throw "Generation failed" } guard let image = images.compactMap({ $0 }).first else { throw "Generation failed" }
return (image, interval) return GenerationResult(image: image, lastSeed: theSeed, interval: interval)
} }
func handleProgress(_ progress: StableDiffusionPipeline.Progress) { func handleProgress(_ progress: StableDiffusionPipeline.Progress) {

@ -16,7 +16,8 @@ let DEFAULT_PROMPT = "Labrador in the style of Vermeer"
enum GenerationState { enum GenerationState {
case startup case startup
case running(StableDiffusionProgress?) case running(StableDiffusionProgress?)
case complete(String, CGImage?, TimeInterval?) case complete(String, CGImage?, UInt32, TimeInterval?)
case failed(Error)
} }
class GenerationContext: ObservableObject { class GenerationContext: ObservableObject {
@ -49,10 +50,10 @@ class GenerationContext: ObservableObject {
private var progressSubscriber: Cancellable? private var progressSubscriber: Cancellable?
func generate() async -> (CGImage, TimeInterval)? { func generate() async throws -> GenerationResult {
guard let pipeline = pipeline else { return nil } guard let pipeline = pipeline else { throw "No pipeline" }
let seed = self.seed >= 0 ? UInt32(self.seed) : nil let seed = self.seed >= 0 ? UInt32(self.seed) : nil
return try? pipeline.generate( return try pipeline.generate(
prompt: positivePrompt, prompt: positivePrompt,
negativePrompt: negativePrompt, negativePrompt: negativePrompt,
scheduler: scheduler, scheduler: scheduler,

@ -99,7 +99,7 @@ struct TextToImage: View {
generation.state = .running(nil) generation.state = .running(nil)
let interval: TimeInterval? let interval: TimeInterval?
let image: CGImage? let image: CGImage?
(image, interval) = await generation.generate() ?? (nil, nil) let result = await generation.generate()
generation.state = .complete(generation.positivePrompt, image, interval) generation.state = .complete(generation.positivePrompt, image, interval)
} }
} }

Loading…
Cancel
Save