Make Generate button work.

(Hardcoded prompt)
pull/12/head
Pedro Cuenca 1 year ago
parent 79e98f4457
commit e14e0119cb

@ -8,22 +8,25 @@
import SwiftUI
struct ContentView: View {
@StateObject var context = GenerationContext()
var body: some View {
NavigationSplitView {
PromptView()
.navigationSplitViewColumnWidth(min: 250, ideal: 300)
} detail: {
Image("placeholder")
.resizable()
GeneratedImageView()
.aspectRatio(contentMode: .fit)
.frame(width: 512, height: 512)
.cornerRadius(15)
.toolbar {
Button(action: {}) {
Label("share", systemImage: "square.and.arrow.up")
}
}
Label("share", systemImage: "square.and.arrow.up")
}
}
}
.environmentObject(context)
}
}

@ -0,0 +1,47 @@
//
// GeneratedImageView.swift
// Diffusion
//
// Created by Pedro Cuenca on 18/1/23.
//
import SwiftUI
struct GeneratedImageView: View {
@EnvironmentObject var generation: GenerationContext
var body: some View {
switch generation.state {
case .startup: return AnyView(Image("placeholder").resizable())
case .running(let progress):
guard let progress = progress, progress.stepCount > 0 else {
// The first time it takes a little bit before generation starts
return AnyView(ProgressView())
}
let step = Int(progress.step) + 1
let fraction = Double(step) / Double(progress.stepCount)
let label = "Step \(step) of \(progress.stepCount)"
return AnyView(ProgressView(label, value: fraction, total: 1).padding())
case .complete(let lastPrompt, let image, let interval):
guard let theImage = image else {
return AnyView(Image(systemName: "exclamationmark.triangle").resizable())
}
let imageView = Image(theImage, scale: 1, label: Text("generated"))
return AnyView(
VStack {
imageView.resizable().clipShape(RoundedRectangle(cornerRadius: 20))
// HStack {
// let intervalString = String(format: "Time: %.1fs", interval ?? 0)
// Rectangle().fill(.clear).overlay(Text(intervalString).frame(maxWidth: .infinity, alignment: .leading).padding(.leading))
// Rectangle().fill(.clear).overlay(
// HStack {
// Spacer()
// ShareButtons(image: theImage, name: lastPrompt).padding(.trailing)
// }
// )
// }.frame(maxHeight: 25)
})
}
}
}

@ -18,7 +18,7 @@ enum PipelineState {
}
struct PromptView: View {
@StateObject var context = GenerationContext()
@EnvironmentObject var context: GenerationContext
static let models = ModelInfo.MODELS
static let modelNames = models.map { $0.modelVersion }

@ -8,8 +8,54 @@
import SwiftUI
struct StatusView: View {
@EnvironmentObject var generation: GenerationContext
var pipelineState: Binding<PipelineState>
func submit() {
if case .running = generation.state { return }
Task {
generation.state = .running(nil)
let interval: TimeInterval?
let image: CGImage?
let prompt = "Portrait of cat in a tuxedo, oil on canvas"
(image, interval) = await generation.generate(prompt: prompt) ?? (nil, nil)
generation.state = .complete(prompt, image, interval)
}
}
func generationStatusView() -> any View {
switch generation.state {
case .startup: return EmptyView()
case .running(let progress):
guard let progress = progress, progress.stepCount > 0 else {
// The first time it takes a little bit before generation starts
return HStack {
Text("Preparing model…")
Spacer()
}
}
let step = Int(progress.step) + 1
let fraction = Double(step) / Double(progress.stepCount)
return HStack {
Text("Generating \(Int(round(100*fraction)))%")
Spacer()
}
case .complete(_, let image, let interval):
guard let _ = image else {
return HStack {
Text("Safety checker triggered, please try a different prompt or seed")
Spacer()
}
}
return HStack {
let intervalString = String(format: "Time: %.1fs", interval ?? 0)
Text(intervalString)
Spacer()
}.frame(maxHeight: 25)
}
}
var body: some View {
switch pipelineState.wrappedValue {
case .downloading(let progress):
@ -19,14 +65,19 @@ struct StatusView: View {
case .loading:
ProgressView("Loading…", value: 105, total: 110).padding()
case .ready:
Button {
// Generate image here
} label: {
Text("Generate")
.frame(maxWidth: .infinity)
.frame(height: 50)
VStack {
Button {
submit()
} label: {
Text("Generate")
.frame(maxWidth: .infinity)
.frame(height: 50)
}
.buttonStyle(.borderedProminent)
// Generation state
AnyView(generationStatusView())
}
.buttonStyle(.borderedProminent)
case .failed:
Text("Pipeline loading error")
}

@ -23,6 +23,7 @@
EBDD7DB62973206600C1C4B2 /* Downloader.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBB5BA5929426E06003A2A5B /* Downloader.swift */; };
EBDD7DB82976AAFE00C1C4B2 /* State.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DB72976AAFE00C1C4B2 /* State.swift */; };
EBDD7DB92976AAFE00C1C4B2 /* State.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DB72976AAFE00C1C4B2 /* State.swift */; };
EBDD7DBD2977FFB300C1C4B2 /* GeneratedImageView.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DBB2977FFB300C1C4B2 /* GeneratedImageView.swift */; };
EBE3FF4C295E1EFE00E921AA /* ModelInfo.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */; };
EBE755C9293E37DD00806B32 /* DiffusionApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755C8293E37DD00806B32 /* DiffusionApp.swift */; };
EBE755CB293E37DD00806B32 /* TextToImage.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755CA293E37DD00806B32 /* TextToImage.swift */; };
@ -67,6 +68,7 @@
EBDD7DB22973200200C1C4B2 /* Utils.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Utils.swift; sourceTree = "<group>"; };
EBDD7DB72976AAFE00C1C4B2 /* State.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = State.swift; sourceTree = "<group>"; };
EBDD7DBA2976F03600C1C4B2 /* debug.xcconfig */ = {isa = PBXFileReference; lastKnownFileType = text.xcconfig; name = debug.xcconfig; path = config/debug.xcconfig; sourceTree = "<group>"; };
EBDD7DBB2977FFB300C1C4B2 /* GeneratedImageView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GeneratedImageView.swift; sourceTree = "<group>"; };
EBE3FF4A295DFE2400E921AA /* common.xcconfig */ = {isa = PBXFileReference; lastKnownFileType = text.xcconfig; name = common.xcconfig; path = config/common.xcconfig; sourceTree = "<group>"; };
EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelInfo.swift; sourceTree = "<group>"; };
EBE4438729488DCA00CDA605 /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
@ -241,6 +243,7 @@
children = (
F15520232971093300DC009B /* Diffusion_macOSApp.swift */,
F15520252971093300DC009B /* ContentView.swift */,
EBDD7DBB2977FFB300C1C4B2 /* GeneratedImageView.swift */,
F1552030297109C300DC009B /* PromptView.swift */,
F155203329710B3600DC009B /* StatusView.swift */,
F15520272971093400DC009B /* Assets.xcassets */,
@ -479,6 +482,7 @@
F155203429710B3600DC009B /* StatusView.swift in Sources */,
F15520242971093300DC009B /* Diffusion_macOSApp.swift in Sources */,
EBDD7DB52973201800C1C4B2 /* ModelInfo.swift in Sources */,
EBDD7DBD2977FFB300C1C4B2 /* GeneratedImageView.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};

@ -5,7 +5,9 @@
// Created by Pedro Cuenca on 17/1/23.
//
import Combine
import SwiftUI
import StableDiffusion
let DEFAULT_MODEL = ModelInfo.v2Base
@ -16,8 +18,29 @@ enum GenerationState {
}
class GenerationContext: ObservableObject {
@Published var pipeline: Pipeline? = nil
let scheduler = StableDiffusionScheduler.dpmSolverMultistepScheduler
@Published var pipeline: Pipeline? = nil {
didSet {
if let pipeline = pipeline {
progressSubscriber = pipeline
.progressPublisher
.receive(on: DispatchQueue.main)
.sink { progress in
guard let progress = progress else { return }
self.state = .running(progress)
}
}
}
}
@Published var state: GenerationState = .startup
private var progressSubscriber: Cancellable?
func generate(prompt: String, steps: Int = 25, seed: UInt32? = nil) async -> (CGImage, TimeInterval)? {
guard let pipeline = pipeline else { return nil }
return try? pipeline.generate(prompt: prompt, scheduler: scheduler, numInferenceSteps: steps, seed: seed)
}
}
class Settings {

Loading…
Cancel
Save