diff --git a/Diffusion.xcodeproj/project.pbxproj b/Diffusion.xcodeproj/project.pbxproj index 4ff8217..b9c5bac 100644 --- a/Diffusion.xcodeproj/project.pbxproj +++ b/Diffusion.xcodeproj/project.pbxproj @@ -7,20 +7,21 @@ objects = { /* Begin PBXBuildFile section */ + DD6512D0294B0FF50035EC0F /* PreviewView.swift in Sources */ = {isa = PBXBuildFile; fileRef = DD6512CF294B0FF50035EC0F /* PreviewView.swift */; }; EBB5BA5329425BEE003A2A5B /* PipelineLoader.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBB5BA5229425BEE003A2A5B /* PipelineLoader.swift */; }; EBB5BA5829425E17003A2A5B /* Path in Frameworks */ = {isa = PBXBuildFile; productRef = EBB5BA5729425E17003A2A5B /* Path */; }; EBB5BA5A29426E06003A2A5B /* Downloader.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBB5BA5929426E06003A2A5B /* Downloader.swift */; }; EBB5BA5D294504DE003A2A5B /* ZIPFoundation in Frameworks */ = {isa = PBXBuildFile; productRef = EBB5BA5C294504DE003A2A5B /* ZIPFoundation */; }; EBB5BA602946365A003A2A5B /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = EBB5BA5F2946365A003A2A5B /* StableDiffusion */; }; EBE755C9293E37DD00806B32 /* DiffusionApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755C8293E37DD00806B32 /* DiffusionApp.swift */; }; - EBE755CB293E37DD00806B32 /* TextToImage.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755CA293E37DD00806B32 /* TextToImage.swift */; }; + EBE755CB293E37DD00806B32 /* TextToImageView.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755CA293E37DD00806B32 /* TextToImageView.swift */; }; EBE755CD293E37DD00806B32 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = EBE755CC293E37DD00806B32 /* Assets.xcassets */; }; EBE755D1293E37DD00806B32 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = EBE755D0293E37DD00806B32 /* Preview Assets.xcassets */; }; EBE755DB293E37DE00806B32 /* DiffusionTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755DA293E37DE00806B32 /* DiffusionTests.swift */; }; EBE755E5293E37DE00806B32 /* DiffusionUITests.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755E4293E37DE00806B32 /* DiffusionUITests.swift */; }; EBE755E7293E37DE00806B32 /* DiffusionUITestsLaunchTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755E6293E37DE00806B32 /* DiffusionUITestsLaunchTests.swift */; }; EBE75602293E91E200806B32 /* Pipeline.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE75601293E91E200806B32 /* Pipeline.swift */; }; - EBE756092941178600806B32 /* Loading.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE756082941178600806B32 /* Loading.swift */; }; + EBE756092941178600806B32 /* LoadingView.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE756082941178600806B32 /* LoadingView.swift */; }; /* End PBXBuildFile section */ /* Begin PBXContainerItemProxy section */ @@ -45,13 +46,14 @@ DD651295294ADFAB0035EC0F /* Debug.xcconfig */ = {isa = PBXFileReference; lastKnownFileType = text.xcconfig; path = Debug.xcconfig; sourceTree = ""; }; DD651296294AE18F0035EC0F /* Release.xcconfig */ = {isa = PBXFileReference; lastKnownFileType = text.xcconfig; path = Release.xcconfig; sourceTree = ""; }; DD651297294AE1A70035EC0F /* Sign-Release.xcconfig */ = {isa = PBXFileReference; lastKnownFileType = text.xcconfig; path = "Sign-Release.xcconfig"; sourceTree = ""; }; + DD6512CF294B0FF50035EC0F /* PreviewView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PreviewView.swift; sourceTree = ""; }; EBB5BA5229425BEE003A2A5B /* PipelineLoader.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PipelineLoader.swift; sourceTree = ""; }; EBB5BA5929426E06003A2A5B /* Downloader.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Downloader.swift; sourceTree = ""; }; EBE4438729488DCA00CDA605 /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; }; EBE443892948953600CDA605 /* LICENSE */ = {isa = PBXFileReference; lastKnownFileType = text; path = LICENSE; sourceTree = ""; }; EBE755C5293E37DD00806B32 /* Diffusion.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = Diffusion.app; sourceTree = BUILT_PRODUCTS_DIR; }; EBE755C8293E37DD00806B32 /* DiffusionApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DiffusionApp.swift; sourceTree = ""; }; - EBE755CA293E37DD00806B32 /* TextToImage.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TextToImage.swift; sourceTree = ""; }; + EBE755CA293E37DD00806B32 /* TextToImageView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TextToImageView.swift; sourceTree = ""; }; EBE755CC293E37DD00806B32 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; EBE755CE293E37DD00806B32 /* Diffusion.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = Diffusion.entitlements; sourceTree = ""; }; EBE755D0293E37DD00806B32 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = ""; }; @@ -61,7 +63,7 @@ EBE755E4293E37DE00806B32 /* DiffusionUITests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DiffusionUITests.swift; sourceTree = ""; }; EBE755E6293E37DE00806B32 /* DiffusionUITestsLaunchTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DiffusionUITestsLaunchTests.swift; sourceTree = ""; }; EBE75601293E91E200806B32 /* Pipeline.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Pipeline.swift; sourceTree = ""; }; - EBE756082941178600806B32 /* Loading.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Loading.swift; sourceTree = ""; }; + EBE756082941178600806B32 /* LoadingView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LoadingView.swift; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -193,8 +195,9 @@ EBE7560A29411A5E00806B32 /* Views */ = { isa = PBXGroup; children = ( - EBE756082941178600806B32 /* Loading.swift */, - EBE755CA293E37DD00806B32 /* TextToImage.swift */, + EBE756082941178600806B32 /* LoadingView.swift */, + DD6512CF294B0FF50035EC0F /* PreviewView.swift */, + EBE755CA293E37DD00806B32 /* TextToImageView.swift */, ); path = Views; sourceTree = ""; @@ -340,9 +343,10 @@ buildActionMask = 2147483647; files = ( EBE75602293E91E200806B32 /* Pipeline.swift in Sources */, - EBE755CB293E37DD00806B32 /* TextToImage.swift in Sources */, + EBE755CB293E37DD00806B32 /* TextToImageView.swift in Sources */, + DD6512D0294B0FF50035EC0F /* PreviewView.swift in Sources */, EBB5BA5A29426E06003A2A5B /* Downloader.swift in Sources */, - EBE756092941178600806B32 /* Loading.swift in Sources */, + EBE756092941178600806B32 /* LoadingView.swift in Sources */, EBB5BA5329425BEE003A2A5B /* PipelineLoader.swift in Sources */, EBE755C9293E37DD00806B32 /* DiffusionApp.swift in Sources */, ); diff --git a/Diffusion/Pipeline/Pipeline.swift b/Diffusion/Pipeline/Pipeline.swift index 7d1a7c7..3633f70 100644 --- a/Diffusion/Pipeline/Pipeline.swift +++ b/Diffusion/Pipeline/Pipeline.swift @@ -29,15 +29,16 @@ class Pipeline { self.pipeline = pipeline } - func generate(prompt: String, scheduler: StableDiffusionScheduler, numInferenceSteps stepCount: Int = 50, seed: UInt32? = nil) throws -> CGImage { + func generate(prompt: String, scheduler: StableDiffusionScheduler, numInferenceSteps stepCount: Int = 50, imageCount: Int = 1, safetyOn: Bool = false, seed: UInt32? = nil) throws -> CGImage { let beginDate = Date() print("Generating...") let theSeed = seed ?? UInt32.random(in: 0.. + var state: Binding + + var body: some View { + switch state.wrappedValue { + case .startup: return AnyView(Image("placeholder").resizable()) + case .running(let progress): + guard let progress = progress, progress.stepCount > 0 else { + // The first time it takes a little bit before generation starts + return AnyView(ProgressView()) + } + let step = Int(progress.step) + 1 + let fraction = Double(step) / Double(progress.stepCount) + let label = "Step \(step) of \(progress.stepCount)" + return AnyView(ProgressView(label, value: fraction, total: 1).padding()) + case .idle(let lastPrompt): + guard let theImage = image.wrappedValue else { + return AnyView(Image(systemName: "exclamationmark.triangle").resizable()) + } + + let imageView = Image(theImage, scale: 1, label: Text("generated")) + return AnyView( + VStack { + imageView.resizable().clipShape(RoundedRectangle(cornerRadius: 20)) + ShareLink(item: imageView, preview: SharePreview(lastPrompt, image: imageView)) + }) + } + } +} + + +struct PreviewView_Previews: PreviewProvider { + static var previews: some View { + PreviewView(image: .constant(nil), state: .constant(.startup)) + } +} diff --git a/Diffusion/Views/TextToImage.swift b/Diffusion/Views/TextToImage.swift deleted file mode 100644 index 307227b..0000000 --- a/Diffusion/Views/TextToImage.swift +++ /dev/null @@ -1,104 +0,0 @@ -// -// TextToImage.swift -// Diffusion -// -// Created by Pedro Cuenca on December 2022. -// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE -// - -import SwiftUI -import Combine -import StableDiffusion - -// TODO: bind to UI controls -let scheduler = StableDiffusionScheduler.dpmpp -let steps = 25 -let seed: UInt32? = nil - -func generate(pipeline: Pipeline?, prompt: String) async -> CGImage? { - guard let pipeline = pipeline else { return nil } - return try? pipeline.generate(prompt: prompt, scheduler: scheduler, numInferenceSteps: steps, seed: seed) -} - -enum GenerationState { - case startup - case running(StableDiffusionProgress?) - case idle(String) -} - -struct ImageWithPlaceholder: View { - var image: Binding - var state: Binding - - var body: some View { - switch state.wrappedValue { - case .startup: return AnyView(Image("placeholder").resizable()) - case .running(let progress): - guard let progress = progress, progress.stepCount > 0 else { - // The first time it takes a little bit before generation starts - return AnyView(ProgressView()) - } - let step = Int(progress.step) + 1 - let fraction = Double(step) / Double(progress.stepCount) - let label = "Step \(step) of \(progress.stepCount)" - return AnyView(ProgressView(label, value: fraction, total: 1).padding()) - case .idle(let lastPrompt): - guard let theImage = image.wrappedValue else { - return AnyView(Image(systemName: "exclamationmark.triangle").resizable()) - } - - let imageView = Image(theImage, scale: 1, label: Text("generated")) - return AnyView( - VStack { - imageView.resizable().clipShape(RoundedRectangle(cornerRadius: 20)) - ShareLink(item: imageView, preview: SharePreview(lastPrompt, image: imageView)) - }) - } - } -} - -struct TextToImage: View { - @EnvironmentObject var context: DiffusionGlobals - - @State private var prompt = "Labrador in the style of Vermeer" - @State private var image: CGImage? = nil - @State private var state: GenerationState = .startup - - @State private var progressSubscriber: Cancellable? - - func submit() { - if case .running = state { return } - Task { - state = .running(nil) - image = await generate(pipeline: context.pipeline, prompt: prompt) - state = .idle(prompt) - } - } - - var body: some View { - VStack { - HStack { - TextField("Prompt", text: $prompt) - .textFieldStyle(.roundedBorder) - .onSubmit { - submit() - } - Button("Generate") { - submit() - } - .padding() - .buttonStyle(.borderedProminent) - } - ImageWithPlaceholder(image: $image, state: $state) - .scaledToFit() - Spacer() - } - .padding() - .onAppear { - progressSubscriber = context.pipeline!.progressPublisher.sink { progress in - guard let progress = progress else { return } - state = .running(progress) - } - } - } -} diff --git a/Diffusion/Views/TextToImageView.swift b/Diffusion/Views/TextToImageView.swift new file mode 100644 index 0000000..c6d6163 --- /dev/null +++ b/Diffusion/Views/TextToImageView.swift @@ -0,0 +1,111 @@ +// +// TextToImageView.swift +// Diffusion +// +// Created by Pedro Cuenca on December 2022. +// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE +// + +import SwiftUI +import Combine +import StableDiffusion + +enum GenerationState { + case startup + case running(StableDiffusionProgress?) + case idle(String) +} + +struct TextToImageView: View { + @EnvironmentObject var context: DiffusionGlobals + + @State private var image: CGImage? = nil + @State private var state: GenerationState = .startup + @State private var prompt = "Labrador in the style of Vermeer" + @State private var scheduler = StableDiffusionScheduler.dpmpp + @State private var width = 512.0 + @State private var height = 512.0 + @State private var steps = 25.0 + @State private var numImages = 1.0 + @State private var seed: UInt32? = nil + @State private var safetyOn: Bool = true + + @State private var progressSubscriber: Cancellable? + + func submit() { + if case .running = state { return } + Task { + state = .running(nil) + image = await generate(pipeline: context.pipeline, prompt: prompt) + state = .idle(prompt) + } + } + + var body: some View { + VStack(alignment: .leading) { + HStack { + TextField("Prompt", text: $prompt) + .textFieldStyle(.roundedBorder) + .onSubmit { + submit() + } + Button("Generate") { + submit() + } + .padding() + .buttonStyle(.borderedProminent) + } + Spacer() + HStack(alignment: .top) { + VStack(alignment: .leading) { + Group { + Text("Image Width") + Slider(value: $width, in: 64...2048, step: 8, label: {}, + minimumValueLabel: {Text("64")}, + maximumValueLabel: {Text("2048")}) + Text("Image Height") + Slider(value: $height, in: 64...2048, step: 8, label: {}, + minimumValueLabel: {Text("64")}, + maximumValueLabel: {Text("2048")}) + } + Text("Number of Inference Steps") + Slider(value: $steps, in: 1...300, step: 1, label: {}, + minimumValueLabel: {Text("1")}, + maximumValueLabel: {Text("300")}) + Text("Number of Images") + Slider(value: $numImages, in: 1...8, step: 1, label: {}, + minimumValueLabel: {Text("1")}, + maximumValueLabel: {Text("8")}) + Text("Safety") + Toggle("", isOn: $safetyOn) + Text("Seed Check On?") + TextField("", value: $seed, format: .number) + } + Spacer() + VStack { + PreviewView(image: $image, state: $state) + .scaledToFit() + } + } + Spacer() + } + .padding() + .onAppear { + progressSubscriber = context.pipeline?.progressPublisher.sink { progress in + guard let progress = progress else { return } + state = .running(progress) + } + } + } + + func generate(pipeline: Pipeline?, prompt: String) async -> CGImage? { + guard let pipeline = pipeline else { return nil } + return try? pipeline.generate(prompt: prompt, scheduler: scheduler, numInferenceSteps: Int(steps), safetyOn: safetyOn, seed: seed) + } +} + +struct TextToImageView_Previews: PreviewProvider { + static var previews: some View { + TextToImageView().environmentObject(DiffusionGlobals()) + } +}