Refactor & UI Changes

* Refactor code to split it up a bit more into specific classes & rename consistently
* Add extra supported parameters to generate in Pipeline
* Update UI in TextToImageView to handle all values to be supported (but some are not usable currently)
pull/2/head
Fahim Farook 1 year ago
parent d7d0039969
commit c79312fa8e

@ -7,20 +7,21 @@
objects = {
/* Begin PBXBuildFile section */
DD6512D0294B0FF50035EC0F /* PreviewView.swift in Sources */ = {isa = PBXBuildFile; fileRef = DD6512CF294B0FF50035EC0F /* PreviewView.swift */; };
EBB5BA5329425BEE003A2A5B /* PipelineLoader.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBB5BA5229425BEE003A2A5B /* PipelineLoader.swift */; };
EBB5BA5829425E17003A2A5B /* Path in Frameworks */ = {isa = PBXBuildFile; productRef = EBB5BA5729425E17003A2A5B /* Path */; };
EBB5BA5A29426E06003A2A5B /* Downloader.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBB5BA5929426E06003A2A5B /* Downloader.swift */; };
EBB5BA5D294504DE003A2A5B /* ZIPFoundation in Frameworks */ = {isa = PBXBuildFile; productRef = EBB5BA5C294504DE003A2A5B /* ZIPFoundation */; };
EBB5BA602946365A003A2A5B /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = EBB5BA5F2946365A003A2A5B /* StableDiffusion */; };
EBE755C9293E37DD00806B32 /* DiffusionApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755C8293E37DD00806B32 /* DiffusionApp.swift */; };
EBE755CB293E37DD00806B32 /* TextToImage.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755CA293E37DD00806B32 /* TextToImage.swift */; };
EBE755CB293E37DD00806B32 /* TextToImageView.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755CA293E37DD00806B32 /* TextToImageView.swift */; };
EBE755CD293E37DD00806B32 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = EBE755CC293E37DD00806B32 /* Assets.xcassets */; };
EBE755D1293E37DD00806B32 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = EBE755D0293E37DD00806B32 /* Preview Assets.xcassets */; };
EBE755DB293E37DE00806B32 /* DiffusionTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755DA293E37DE00806B32 /* DiffusionTests.swift */; };
EBE755E5293E37DE00806B32 /* DiffusionUITests.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755E4293E37DE00806B32 /* DiffusionUITests.swift */; };
EBE755E7293E37DE00806B32 /* DiffusionUITestsLaunchTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755E6293E37DE00806B32 /* DiffusionUITestsLaunchTests.swift */; };
EBE75602293E91E200806B32 /* Pipeline.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE75601293E91E200806B32 /* Pipeline.swift */; };
EBE756092941178600806B32 /* Loading.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE756082941178600806B32 /* Loading.swift */; };
EBE756092941178600806B32 /* LoadingView.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE756082941178600806B32 /* LoadingView.swift */; };
/* End PBXBuildFile section */
/* Begin PBXContainerItemProxy section */
@ -45,13 +46,14 @@
DD651295294ADFAB0035EC0F /* Debug.xcconfig */ = {isa = PBXFileReference; lastKnownFileType = text.xcconfig; path = Debug.xcconfig; sourceTree = "<group>"; };
DD651296294AE18F0035EC0F /* Release.xcconfig */ = {isa = PBXFileReference; lastKnownFileType = text.xcconfig; path = Release.xcconfig; sourceTree = "<group>"; };
DD651297294AE1A70035EC0F /* Sign-Release.xcconfig */ = {isa = PBXFileReference; lastKnownFileType = text.xcconfig; path = "Sign-Release.xcconfig"; sourceTree = "<group>"; };
DD6512CF294B0FF50035EC0F /* PreviewView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PreviewView.swift; sourceTree = "<group>"; };
EBB5BA5229425BEE003A2A5B /* PipelineLoader.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PipelineLoader.swift; sourceTree = "<group>"; };
EBB5BA5929426E06003A2A5B /* Downloader.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Downloader.swift; sourceTree = "<group>"; };
EBE4438729488DCA00CDA605 /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
EBE443892948953600CDA605 /* LICENSE */ = {isa = PBXFileReference; lastKnownFileType = text; path = LICENSE; sourceTree = "<group>"; };
EBE755C5293E37DD00806B32 /* Diffusion.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = Diffusion.app; sourceTree = BUILT_PRODUCTS_DIR; };
EBE755C8293E37DD00806B32 /* DiffusionApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DiffusionApp.swift; sourceTree = "<group>"; };
EBE755CA293E37DD00806B32 /* TextToImage.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TextToImage.swift; sourceTree = "<group>"; };
EBE755CA293E37DD00806B32 /* TextToImageView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TextToImageView.swift; sourceTree = "<group>"; };
EBE755CC293E37DD00806B32 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
EBE755CE293E37DD00806B32 /* Diffusion.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = Diffusion.entitlements; sourceTree = "<group>"; };
EBE755D0293E37DD00806B32 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = "<group>"; };
@ -61,7 +63,7 @@
EBE755E4293E37DE00806B32 /* DiffusionUITests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DiffusionUITests.swift; sourceTree = "<group>"; };
EBE755E6293E37DE00806B32 /* DiffusionUITestsLaunchTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DiffusionUITestsLaunchTests.swift; sourceTree = "<group>"; };
EBE75601293E91E200806B32 /* Pipeline.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Pipeline.swift; sourceTree = "<group>"; };
EBE756082941178600806B32 /* Loading.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Loading.swift; sourceTree = "<group>"; };
EBE756082941178600806B32 /* LoadingView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LoadingView.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
@ -193,8 +195,9 @@
EBE7560A29411A5E00806B32 /* Views */ = {
isa = PBXGroup;
children = (
EBE756082941178600806B32 /* Loading.swift */,
EBE755CA293E37DD00806B32 /* TextToImage.swift */,
EBE756082941178600806B32 /* LoadingView.swift */,
DD6512CF294B0FF50035EC0F /* PreviewView.swift */,
EBE755CA293E37DD00806B32 /* TextToImageView.swift */,
);
path = Views;
sourceTree = "<group>";
@ -340,9 +343,10 @@
buildActionMask = 2147483647;
files = (
EBE75602293E91E200806B32 /* Pipeline.swift in Sources */,
EBE755CB293E37DD00806B32 /* TextToImage.swift in Sources */,
EBE755CB293E37DD00806B32 /* TextToImageView.swift in Sources */,
DD6512D0294B0FF50035EC0F /* PreviewView.swift in Sources */,
EBB5BA5A29426E06003A2A5B /* Downloader.swift in Sources */,
EBE756092941178600806B32 /* Loading.swift in Sources */,
EBE756092941178600806B32 /* LoadingView.swift in Sources */,
EBB5BA5329425BEE003A2A5B /* PipelineLoader.swift in Sources */,
EBE755C9293E37DD00806B32 /* DiffusionApp.swift in Sources */,
);

@ -29,15 +29,16 @@ class Pipeline {
self.pipeline = pipeline
}
func generate(prompt: String, scheduler: StableDiffusionScheduler, numInferenceSteps stepCount: Int = 50, seed: UInt32? = nil) throws -> CGImage {
func generate(prompt: String, scheduler: StableDiffusionScheduler, numInferenceSteps stepCount: Int = 50, imageCount: Int = 1, safetyOn: Bool = false, seed: UInt32? = nil) throws -> CGImage {
let beginDate = Date()
print("Generating...")
let theSeed = seed ?? UInt32.random(in: 0..<UInt32.max)
let images = try pipeline.generateImages(
prompt: prompt,
imageCount: 1,
imageCount: imageCount,
stepCount: stepCount,
seed: theSeed,
disableSafety: !safetyOn,
scheduler: scheduler
) { progress in
handleProgress(progress)

@ -33,7 +33,7 @@ struct LoadingView: View {
var body: some View {
VStack {
switch currentView {
case .textToImage: TextToImage().transition(.opacity)
case .textToImage: TextToImageView().transition(.opacity)
case .error(let message): ErrorPopover(errorMessage: message).transition(.move(edge: .top))
case .loading:
// TODO: Don't present progress view if the pipeline is cached

@ -0,0 +1,46 @@
//
// PreviewView.swift
// Diffusion
//
// Created by Fahim Farook on 15/12/2022.
//
import SwiftUI
/// Renders whatever matches the current generation state: the placeholder
/// asset, a progress indicator, or the generated image with a share link.
struct PreviewView: View {
    var image: Binding<CGImage?>
    var state: Binding<GenerationState>

    var body: some View {
        switch state.wrappedValue {
        case .startup:
            return AnyView(Image("placeholder").resizable())

        case .running(let report):
            if let report = report, report.stepCount > 0 {
                // Steps are reported zero-based; show them one-based.
                let current = Int(report.step) + 1
                let caption = "Step \(current) of \(report.stepCount)"
                let completed = Double(current) / Double(report.stepCount)
                return AnyView(ProgressView(caption, value: completed, total: 1).padding())
            }
            // No usable progress yet (generation takes a moment to start the
            // first time): fall back to an indeterminate spinner.
            return AnyView(ProgressView())

        case .idle(let lastPrompt):
            if let rendered = image.wrappedValue {
                let picture = Image(rendered, scale: 1, label: Text("generated"))
                let preview = SharePreview(lastPrompt, image: picture)
                return AnyView(
                    VStack {
                        picture.resizable().clipShape(RoundedRectangle(cornerRadius: 20))
                        ShareLink(item: picture, preview: preview)
                    })
            }
            // Idle without an image: show a warning symbol instead.
            return AnyView(Image(systemName: "exclamationmark.triangle").resizable())
        }
    }
}
/// Xcode canvas preview showing `PreviewView` in its startup state with no image.
struct PreviewView_Previews: PreviewProvider {
    static var previews: some View {
        let noImage: Binding<CGImage?> = .constant(nil)
        let initialState: Binding<GenerationState> = .constant(.startup)
        return PreviewView(image: noImage, state: initialState)
    }
}

@ -1,104 +0,0 @@
//
// TextToImage.swift
// Diffusion
//
// Created by Pedro Cuenca on December 2022.
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
//
import SwiftUI
import Combine
import StableDiffusion
// Fixed generation settings read by `generate(pipeline:prompt:)` below.
// TODO: bind to UI controls
// Scheduler handed to the pipeline.
let scheduler = StableDiffusionScheduler.dpmpp
// Passed as `numInferenceSteps`.
let steps = 25
// Passed through as the pipeline's `seed` argument; currently always nil.
let seed: UInt32? = nil
/// Runs one generation on `pipeline` using the file-level `scheduler`,
/// `steps` and `seed` settings.
///
/// - Parameters:
///   - pipeline: The pipeline to run; when `nil` nothing is generated.
///   - prompt: The text prompt to generate from.
/// - Returns: The generated image, or `nil` when there is no pipeline or the
///   pipeline throws.
func generate(pipeline: Pipeline?, prompt: String) async -> CGImage? {
    guard let pipeline = pipeline else { return nil }
    do {
        return try pipeline.generate(prompt: prompt, scheduler: scheduler, numInferenceSteps: steps, seed: seed)
    } catch {
        // Callers only see `nil`, so log the reason instead of discarding it.
        print("Image generation failed: \(error)")
        return nil
    }
}
/// State of the generation flow as tracked by the text-to-image view.
enum GenerationState {
/// Initial value of the view's state, before anything has been submitted.
case startup
/// A generation has been submitted. The payload is `nil` right after
/// submission and is replaced by each progress value the pipeline publishes.
case running(StableDiffusionProgress?)
/// The generate call has returned (the image may still be `nil`).
/// The payload is the prompt that was submitted.
case idle(String)
}
/// Shows the placeholder asset, a progress indicator, or the generated image
/// (with a share link), depending on the bound generation state.
struct ImageWithPlaceholder: View {
    var image: Binding<CGImage?>
    var state: Binding<GenerationState>

    var body: some View {
        switch state.wrappedValue {
        case .startup:
            return AnyView(Image("placeholder").resizable())
        case .running(let report):
            return progressIndicator(for: report)
        case .idle(let lastPrompt):
            return outcome(sharedAs: lastPrompt)
        }
    }

    /// Determinate progress bar once a usable report exists, spinner otherwise.
    private func progressIndicator(for report: StableDiffusionProgress?) -> AnyView {
        if let report = report, report.stepCount > 0 {
            // Steps are reported zero-based; show them one-based.
            let current = Int(report.step) + 1
            let caption = "Step \(current) of \(report.stepCount)"
            let completed = Double(current) / Double(report.stepCount)
            return AnyView(ProgressView(caption, value: completed, total: 1).padding())
        }
        // The first time it takes a little while before generation starts.
        return AnyView(ProgressView())
    }

    /// The generated image plus share link, or a warning symbol when there is no image.
    private func outcome(sharedAs lastPrompt: String) -> AnyView {
        if let rendered = image.wrappedValue {
            let picture = Image(rendered, scale: 1, label: Text("generated"))
            return AnyView(
                VStack {
                    picture.resizable().clipShape(RoundedRectangle(cornerRadius: 20))
                    ShareLink(item: picture, preview: SharePreview(lastPrompt, image: picture))
                })
        }
        return AnyView(Image(systemName: "exclamationmark.triangle").resizable())
    }
}
/// Prompt entry screen: a text field and button that trigger generation, with
/// the result (or progress) shown underneath.
struct TextToImage: View {
    @EnvironmentObject var context: DiffusionGlobals

    @State private var prompt = "Labrador in the style of Vermeer"
    @State private var image: CGImage? = nil
    @State private var state: GenerationState = .startup
    // Keeps the progress subscription alive for as long as the view's state lives.
    @State private var progressSubscriber: Cancellable?

    /// Starts a generation for the current prompt unless one is already running.
    func submit() {
        if case .running = state { return }
        Task {
            state = .running(nil)
            image = await generate(pipeline: context.pipeline, prompt: prompt)
            state = .idle(prompt)
        }
    }

    var body: some View {
        VStack {
            HStack {
                TextField("Prompt", text: $prompt)
                    .textFieldStyle(.roundedBorder)
                    .onSubmit {
                        submit()
                    }
                Button("Generate") {
                    submit()
                }
                .padding()
                .buttonStyle(.borderedProminent)
            }
            ImageWithPlaceholder(image: $image, state: $state)
                .scaledToFit()
            Spacer()
        }
        .padding()
        .onAppear {
            // Optional chaining instead of a force-unwrap: a missing pipeline
            // simply means no progress updates rather than a crash.
            progressSubscriber = context.pipeline?.progressPublisher.sink { progress in
                guard let progress = progress else { return }
                state = .running(progress)
            }
        }
    }
}

@ -0,0 +1,111 @@
//
// TextToImageView.swift
// Diffusion
//
// Created by Pedro Cuenca on December 2022.
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
//
import SwiftUI
import Combine
import StableDiffusion
/// State of the generation flow as tracked by `TextToImageView`.
enum GenerationState {
/// Initial value of the view's state, before anything has been submitted.
case startup
/// A generation has been submitted. The payload is `nil` right after
/// submission and is replaced by each progress value the pipeline publishes.
case running(StableDiffusionProgress?)
/// The generate call has returned (the image may still be `nil`).
/// The payload is the prompt that was submitted.
case idle(String)
}
/// Text-to-image screen: prompt entry, generation settings on the left and
/// the preview (placeholder, progress or result) on the right.
struct TextToImageView: View {
    @EnvironmentObject var context: DiffusionGlobals

    @State private var image: CGImage? = nil
    @State private var state: GenerationState = .startup

    @State private var prompt = "Labrador in the style of Vermeer"
    @State private var scheduler = StableDiffusionScheduler.dpmpp
    // `width`, `height` and `numImages` are bound to sliders but are not yet
    // passed to the pipeline by `generate(pipeline:prompt:)`.
    @State private var width = 512.0
    @State private var height = 512.0
    @State private var steps = 25.0
    @State private var numImages = 1.0
    @State private var seed: UInt32? = nil
    @State private var safetyOn: Bool = true

    // Keeps the progress subscription alive for as long as the view's state lives.
    @State private var progressSubscriber: Cancellable?

    /// Starts a generation for the current prompt unless one is already running.
    func submit() {
        if case .running = state { return }
        Task {
            state = .running(nil)
            image = await generate(pipeline: context.pipeline, prompt: prompt)
            state = .idle(prompt)
        }
    }

    var body: some View {
        VStack(alignment: .leading) {
            HStack {
                TextField("Prompt", text: $prompt)
                    .textFieldStyle(.roundedBorder)
                    .onSubmit {
                        submit()
                    }
                Button("Generate") {
                    submit()
                }
                .padding()
                .buttonStyle(.borderedProminent)
            }
            Spacer()
            HStack(alignment: .top) {
                VStack(alignment: .leading) {
                    // Grouped to keep this stack's direct child count at the original nine.
                    Group {
                        Text("Image Width")
                        Slider(value: $width, in: 64...2048, step: 8, label: {},
                               minimumValueLabel: {Text("64")},
                               maximumValueLabel: {Text("2048")})
                        Text("Image Height")
                        Slider(value: $height, in: 64...2048, step: 8, label: {},
                               minimumValueLabel: {Text("64")},
                               maximumValueLabel: {Text("2048")})
                    }
                    Text("Number of Inference Steps")
                    Slider(value: $steps, in: 1...300, step: 1, label: {},
                           minimumValueLabel: {Text("1")},
                           maximumValueLabel: {Text("300")})
                    Text("Number of Images")
                    Slider(value: $numImages, in: 1...8, step: 1, label: {},
                           minimumValueLabel: {Text("1")},
                           maximumValueLabel: {Text("8")})
                    // Labels describe the control directly beneath them.
                    Text("Safety Check On?")
                    Toggle("", isOn: $safetyOn)
                    Text("Seed")
                    TextField("", value: $seed, format: .number)
                }
                Spacer()
                VStack {
                    PreviewView(image: $image, state: $state)
                        .scaledToFit()
                }
            }
            Spacer()
        }
        .padding()
        .onAppear {
            progressSubscriber = context.pipeline?.progressPublisher.sink { progress in
                guard let progress = progress else { return }
                state = .running(progress)
            }
        }
    }

    /// Runs one generation with the current scheduler, step count, safety
    /// setting and seed.
    ///
    /// - Parameters:
    ///   - pipeline: The pipeline to run; when `nil` nothing is generated.
    ///   - prompt: The text prompt to generate from.
    /// - Returns: The generated image, or `nil` when there is no pipeline or
    ///   the pipeline throws.
    func generate(pipeline: Pipeline?, prompt: String) async -> CGImage? {
        guard let pipeline = pipeline else { return nil }
        do {
            return try pipeline.generate(prompt: prompt, scheduler: scheduler, numInferenceSteps: Int(steps), safetyOn: safetyOn, seed: seed)
        } catch {
            // Callers only see `nil`, so log the reason instead of discarding it.
            print("Image generation failed: \(error)")
            return nil
        }
    }
}
/// Xcode canvas preview of `TextToImageView` backed by a fresh `DiffusionGlobals`.
struct TextToImageView_Previews: PreviewProvider {
    static var previews: some View {
        let globals = DiffusionGlobals()
        return TextToImageView().environmentObject(globals)
    }
}
Loading…
Cancel
Save