From e14e0119cb4dabe8da94ef02c4fdc0302fd9e1b5 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 18 Jan 2023 12:20:16 +0100 Subject: [PATCH] Make Generate button work. (Hardcoded prompt) --- Diffusion-macOS/ContentView.swift | 13 +++-- Diffusion-macOS/GeneratedImageView.swift | 47 +++++++++++++++++ Diffusion-macOS/PromptView.swift | 2 +- Diffusion-macOS/StatusView.swift | 65 +++++++++++++++++++++--- Diffusion.xcodeproj/project.pbxproj | 4 ++ Diffusion/State.swift | 25 ++++++++- 6 files changed, 142 insertions(+), 14 deletions(-) create mode 100644 Diffusion-macOS/GeneratedImageView.swift diff --git a/Diffusion-macOS/ContentView.swift b/Diffusion-macOS/ContentView.swift index 01f35a7..f84bb7a 100644 --- a/Diffusion-macOS/ContentView.swift +++ b/Diffusion-macOS/ContentView.swift @@ -8,22 +8,25 @@ import SwiftUI struct ContentView: View { + @StateObject var context = GenerationContext() + var body: some View { NavigationSplitView { PromptView() .navigationSplitViewColumnWidth(min: 250, ideal: 300) } detail: { - Image("placeholder") - .resizable() + GeneratedImageView() .aspectRatio(contentMode: .fit) .frame(width: 512, height: 512) .cornerRadius(15) .toolbar { Button(action: {}) { - Label("share", systemImage: "square.and.arrow.up") - } - } + Label("share", systemImage: "square.and.arrow.up") + } + } + } + .environmentObject(context) } } diff --git a/Diffusion-macOS/GeneratedImageView.swift b/Diffusion-macOS/GeneratedImageView.swift new file mode 100644 index 0000000..cf22dcb --- /dev/null +++ b/Diffusion-macOS/GeneratedImageView.swift @@ -0,0 +1,47 @@ +// +// GeneratedImageView.swift +// Diffusion +// +// Created by Pedro Cuenca on 18/1/23. +// + +import SwiftUI + +struct GeneratedImageView: View { + @EnvironmentObject var generation: GenerationContext + + var body: some View { + switch generation.state { + case .startup: return AnyView(Image("placeholder").resizable()) + case .running(let progress): + guard let progress = progress, progress.stepCount > 0 else { + // The first time it takes a little bit before generation starts + return AnyView(ProgressView()) + } + let step = Int(progress.step) + 1 + let fraction = Double(step) / Double(progress.stepCount) + let label = "Step \(step) of \(progress.stepCount)" + return AnyView(ProgressView(label, value: fraction, total: 1).padding()) + case .complete(let lastPrompt, let image, let interval): + guard let theImage = image else { + return AnyView(Image(systemName: "exclamationmark.triangle").resizable()) + } + + let imageView = Image(theImage, scale: 1, label: Text("generated")) + return AnyView( + VStack { + imageView.resizable().clipShape(RoundedRectangle(cornerRadius: 20)) +// HStack { +// let intervalString = String(format: "Time: %.1fs", interval ?? 0) +// Rectangle().fill(.clear).overlay(Text(intervalString).frame(maxWidth: .infinity, alignment: .leading).padding(.leading)) +// Rectangle().fill(.clear).overlay( +// HStack { +// Spacer() +// ShareButtons(image: theImage, name: lastPrompt).padding(.trailing) +// } +// ) +// }.frame(maxHeight: 25) + }) + } + } +} diff --git a/Diffusion-macOS/PromptView.swift b/Diffusion-macOS/PromptView.swift index 3b1a904..ab2d7ae 100644 --- a/Diffusion-macOS/PromptView.swift +++ b/Diffusion-macOS/PromptView.swift @@ -18,7 +18,7 @@ enum PipelineState { } struct PromptView: View { - @StateObject var context = GenerationContext() + @EnvironmentObject var context: GenerationContext static let models = ModelInfo.MODELS static let modelNames = models.map { $0.modelVersion } diff --git a/Diffusion-macOS/StatusView.swift b/Diffusion-macOS/StatusView.swift index f753d0e..cc4c90e 100644 --- a/Diffusion-macOS/StatusView.swift +++ b/Diffusion-macOS/StatusView.swift @@ -8,8 +8,54 @@ import SwiftUI struct StatusView: View { + @EnvironmentObject var generation: GenerationContext var pipelineState: Binding + func submit() { + if case .running = generation.state { return } + Task { + generation.state = .running(nil) + let interval: TimeInterval? + let image: CGImage? + let prompt = "Portrait of cat in a tuxedo, oil on canvas" + (image, interval) = await generation.generate(prompt: prompt) ?? (nil, nil) + generation.state = .complete(prompt, image, interval) + } + } + + func generationStatusView() -> any View { + switch generation.state { + case .startup: return EmptyView() + case .running(let progress): + guard let progress = progress, progress.stepCount > 0 else { + // The first time it takes a little bit before generation starts + return HStack { + Text("Preparing model…") + Spacer() + } + } + let step = Int(progress.step) + 1 + let fraction = Double(step) / Double(progress.stepCount) + return HStack { + Text("Generating \(Int(round(100*fraction)))%") + Spacer() + } + case .complete(_, let image, let interval): + guard let _ = image else { + return HStack { + Text("Safety checker triggered, please try a different prompt or seed") + Spacer() + } + } + + return HStack { + let intervalString = String(format: "Time: %.1fs", interval ?? 0) + Text(intervalString) + Spacer() + }.frame(maxHeight: 25) + } + } + var body: some View { switch pipelineState.wrappedValue { case .downloading(let progress): @@ -19,14 +65,19 @@ struct StatusView: View { case .loading: ProgressView("Loading…", value: 105, total: 110).padding() case .ready: - Button { - // Generate image here - } label: { - Text("Generate") - .frame(maxWidth: .infinity) - .frame(height: 50) + VStack { + Button { + submit() + } label: { + Text("Generate") + .frame(maxWidth: .infinity) + .frame(height: 50) + } + .buttonStyle(.borderedProminent) + + // Generation state + AnyView(generationStatusView()) } - .buttonStyle(.borderedProminent) case .failed: Text("Pipeline loading error") } diff --git a/Diffusion.xcodeproj/project.pbxproj b/Diffusion.xcodeproj/project.pbxproj index 68242c2..0e14092 100644 --- a/Diffusion.xcodeproj/project.pbxproj +++ b/Diffusion.xcodeproj/project.pbxproj @@ -23,6 +23,7 @@ EBDD7DB62973206600C1C4B2 /* Downloader.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBB5BA5929426E06003A2A5B /* Downloader.swift */; }; EBDD7DB82976AAFE00C1C4B2 /* State.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DB72976AAFE00C1C4B2 /* State.swift */; }; EBDD7DB92976AAFE00C1C4B2 /* State.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DB72976AAFE00C1C4B2 /* State.swift */; }; + EBDD7DBD2977FFB300C1C4B2 /* GeneratedImageView.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DBB2977FFB300C1C4B2 /* GeneratedImageView.swift */; }; EBE3FF4C295E1EFE00E921AA /* ModelInfo.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */; }; EBE755C9293E37DD00806B32 /* DiffusionApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755C8293E37DD00806B32 /* DiffusionApp.swift */; }; EBE755CB293E37DD00806B32 /* TextToImage.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755CA293E37DD00806B32 /* TextToImage.swift */; }; @@ -67,6 +68,7 @@ EBDD7DB22973200200C1C4B2 /* Utils.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Utils.swift; sourceTree = ""; }; EBDD7DB72976AAFE00C1C4B2 /* State.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = State.swift; sourceTree = ""; }; EBDD7DBA2976F03600C1C4B2 /* debug.xcconfig */ = {isa = PBXFileReference; lastKnownFileType = text.xcconfig; name = debug.xcconfig; path = config/debug.xcconfig; sourceTree = ""; }; + EBDD7DBB2977FFB300C1C4B2 /* GeneratedImageView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GeneratedImageView.swift; sourceTree = ""; }; EBE3FF4A295DFE2400E921AA /* common.xcconfig */ = {isa = PBXFileReference; lastKnownFileType = text.xcconfig; name = common.xcconfig; path = config/common.xcconfig; sourceTree = ""; }; EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelInfo.swift; sourceTree = ""; }; EBE4438729488DCA00CDA605 /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; }; @@ -241,6 +243,7 @@ children = ( F15520232971093300DC009B /* Diffusion_macOSApp.swift */, F15520252971093300DC009B /* ContentView.swift */, + EBDD7DBB2977FFB300C1C4B2 /* GeneratedImageView.swift */, F1552030297109C300DC009B /* PromptView.swift */, F155203329710B3600DC009B /* StatusView.swift */, F15520272971093400DC009B /* Assets.xcassets */, @@ -479,6 +482,7 @@ F155203429710B3600DC009B /* StatusView.swift in Sources */, F15520242971093300DC009B /* Diffusion_macOSApp.swift in Sources */, EBDD7DB52973201800C1C4B2 /* ModelInfo.swift in Sources */, + EBDD7DBD2977FFB300C1C4B2 /* GeneratedImageView.swift in Sources */, ); runOnlyForDeploymentPostprocessing = 0; }; diff --git a/Diffusion/State.swift b/Diffusion/State.swift index 56c149a..d90e5a9 100644 --- a/Diffusion/State.swift +++ b/Diffusion/State.swift @@ -5,7 +5,9 @@ // Created by Pedro Cuenca on 17/1/23. // +import Combine import SwiftUI +import StableDiffusion let DEFAULT_MODEL = ModelInfo.v2Base @@ -16,8 +18,29 @@ enum GenerationState { } class GenerationContext: ObservableObject { - @Published var pipeline: Pipeline? = nil + let scheduler = StableDiffusionScheduler.dpmSolverMultistepScheduler + + @Published var pipeline: Pipeline? = nil { + didSet { + if let pipeline = pipeline { + progressSubscriber = pipeline + .progressPublisher + .receive(on: DispatchQueue.main) + .sink { progress in + guard let progress = progress else { return } + self.state = .running(progress) + } + } + } + } @Published var state: GenerationState = .startup + + private var progressSubscriber: Cancellable? + + func generate(prompt: String, steps: Int = 25, seed: UInt32? = nil) async -> (CGImage, TimeInterval)? { + guard let pipeline = pipeline else { return nil } + return try? pipeline.generate(prompt: prompt, scheduler: scheduler, numInferenceSteps: steps, seed: seed) + } } class Settings {