You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
112 lines
3.3 KiB
112 lines
3.3 KiB
//
|
|
// TextToImageView.swift
|
|
// Diffusion
|
|
//
|
|
// Created by Pedro Cuenca on December 2022.
|
|
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
|
|
//
|
|
|
|
import SwiftUI
|
|
import Combine
|
|
import StableDiffusion
|
|
|
|
/// Lifecycle of an image generation as tracked by `TextToImageView`.
enum GenerationState {
    /// Initial value, before any generation has been requested.
    case startup

    /// A generation is in flight. The payload is the latest progress report,
    /// or `nil` when none has been received yet.
    case running(StableDiffusionProgress?)

    /// No generation is in flight. The payload is the prompt text recorded
    /// when the last generation finished.
    case idle(String)
}
|
|
|
|
/// Text-to-image screen: a prompt field with a Generate button, a column of
/// generation controls, and a preview of the most recent result.
///
/// The pipeline used for generation and progress reporting is read from the
/// `DiffusionGlobals` environment object.
struct TextToImageView: View {
    @EnvironmentObject var context: DiffusionGlobals

    /// Most recent generation result; `nil` until a generation succeeds.
    @State private var image: CGImage? = nil
    @State private var state: GenerationState = .startup
    @State private var prompt = "Labrador in the style of Vermeer"
    @State private var scheduler = StableDiffusionScheduler.dpmpp
    // NOTE(review): `width`, `height` and `numImages` are bound to sliders below
    // but are never forwarded to `pipeline.generate` — confirm whether the
    // pipeline can accept them, or hide the controls.
    @State private var width = 512.0
    @State private var height = 512.0
    @State private var steps = 25.0
    @State private var numImages = 1.0
    @State private var seed: UInt32? = nil
    @State private var safetyOn: Bool = true

    /// Keeps the progress subscription created in `onAppear` alive.
    @State private var progressSubscriber: Cancellable?

    /// Starts a generation for the current prompt unless one is already running.
    func submit() {
        if case .running = state { return }

        // Enter `.running` synchronously: if this were done inside the task, a
        // second submit arriving before the task body starts would pass the
        // check above and launch a duplicate generation.
        state = .running(nil)

        // Capture the prompt now so the `.idle` state reports the text that
        // actually produced the image, even if the field is edited meanwhile.
        let submittedPrompt = prompt
        Task {
            image = await generate(pipeline: context.pipeline, prompt: submittedPrompt)
            state = .idle(submittedPrompt)
        }
    }

    var body: some View {
        VStack(alignment: .leading) {
            HStack {
                TextField("Prompt", text: $prompt)
                    .textFieldStyle(.roundedBorder)
                    .onSubmit {
                        submit()
                    }
                Button("Generate") {
                    submit()
                }
                .padding()
                .buttonStyle(.borderedProminent)
            }
            Spacer()
            HStack(alignment: .top) {
                VStack(alignment: .leading) {
                    Group {
                        Text("Image Width")
                        Slider(value: $width, in: 64...2048, step: 8, label: {},
                               minimumValueLabel: {Text("64")},
                               maximumValueLabel: {Text("2048")})
                        Text("Image Height")
                        Slider(value: $height, in: 64...2048, step: 8, label: {},
                               minimumValueLabel: {Text("64")},
                               maximumValueLabel: {Text("2048")})
                    }
                    Text("Number of Inference Steps")
                    Slider(value: $steps, in: 1...300, step: 1, label: {},
                           minimumValueLabel: {Text("1")},
                           maximumValueLabel: {Text("300")})
                    Text("Number of Images")
                    Slider(value: $numImages, in: 1...8, step: 1, label: {},
                           minimumValueLabel: {Text("1")},
                           maximumValueLabel: {Text("8")})
                    // Labels corrected: the toggle drives `safetyOn`, and the
                    // numeric field edits `seed`.
                    Text("Safety Check On?")
                    Toggle("", isOn: $safetyOn)
                    Text("Seed")
                    TextField("", value: $seed, format: .number)
                }
                Spacer()
                VStack {
                    PreviewView(image: $image, state: $state)
                        .scaledToFit()
                }
            }
            Spacer()
        }
        .padding()
        .onAppear {
            // A subscription is only created if the pipeline already exists
            // when the view appears; otherwise `progressSubscriber` stays nil.
            progressSubscriber = context.pipeline?.progressPublisher.sink { progress in
                guard let progress = progress else { return }
                state = .running(progress)
            }
        }
    }

    /// Runs the pipeline with the current scheduler, step count, safety flag and seed.
    ///
    /// - Parameters:
    ///   - pipeline: Pipeline to run; when `nil` no work is done.
    ///   - prompt: Text prompt forwarded to the pipeline.
    /// - Returns: The generated image, or `nil` when there is no pipeline or
    ///   generation throws.
    func generate(pipeline: Pipeline?, prompt: String) async -> CGImage? {
        guard let pipeline = pipeline else { return nil }
        do {
            return try pipeline.generate(prompt: prompt, scheduler: scheduler, numInferenceSteps: Int(steps), safetyOn: safetyOn, seed: seed)
        } catch {
            // Keep the best-effort contract (nil on failure) but surface the
            // cause instead of discarding it.
            print("Image generation failed: \(error)")
            return nil
        }
    }
}
|
|
|
|
/// Xcode canvas preview for `TextToImageView`, supplied with a freshly created
/// `DiffusionGlobals` as its environment object.
struct TextToImageView_Previews: PreviewProvider {
    static var previews: some View {
        let globals = DiffusionGlobals()
        return TextToImageView()
            .environmentObject(globals)
    }
}
|