//
// Pipeline.swift
// Diffusion
//
// Created by Pedro Cuenca on December 2022.
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
//
import Foundation
import CoreML
import Combine
import StableDiffusion
/// Shorthand for the `Progress` type nested in `StableDiffusionPipeline`.
typealias StableDiffusionProgress = StableDiffusionPipeline.Progress
/// Outcome of a single call to `Pipeline.generate`.
struct GenerationResult {
/// The generated image. Optional in the type, although `Pipeline.generate`
/// throws instead of returning a result without an image.
var image: CGImage?
/// Seed handed to the pipeline for this generation (caller-supplied or randomly chosen).
var lastSeed: UInt32
/// Elapsed time of the generation call, measured with `Date`.
var interval: TimeInterval?
}
// MARK: - Pipeline
/// Wraps a `StableDiffusionPipeline` and republishes the progress it reports
/// during generation through a Combine subject.
class Pipeline {
    /// The wrapped pipeline that performs the actual image generation.
    let pipeline: StableDiffusionPipeline

    /// Latest progress value recorded by `handleProgress(_:)`.
    /// Every assignment is mirrored into `progressPublisher`.
    var progress: StableDiffusionProgress? = nil {
        didSet {
            progressPublisher.value = progress
        }
    }

    /// Subject carrying the current `progress` value. It is created lazily and
    /// seeded with whatever `progress` holds at the time of first access.
    lazy private(set) var progressPublisher: CurrentValueSubject<StableDiffusionProgress?, Never> = CurrentValueSubject(progress)

    /// Creates a wrapper around an existing pipeline.
    init(_ pipeline: StableDiffusionPipeline) {
        self.pipeline = pipeline
    }

    /// Generates a single image and reports how long the call took.
    ///
    /// - Parameters:
    ///   - prompt: Text prompt forwarded to the pipeline.
    ///   - negativePrompt: Negative prompt forwarded to the pipeline. Defaults to an empty string.
    ///   - scheduler: Scheduler forwarded to the pipeline.
    ///   - stepCount: Number of inference steps. Defaults to 50.
    ///   - seed: Seed to use; when `nil`, a random value in `0..<UInt32.max` is picked.
    ///   - guidanceScale: Guidance scale forwarded to the pipeline. Defaults to 7.5.
    ///   - disableSafety: Forwarded to the pipeline's `disableSafety` argument. Defaults to `false`.
    /// - Returns: The first non-nil image together with the seed that was used
    ///   and the elapsed time of the pipeline call.
    /// - Throws: Any error thrown by `pipeline.generateImages`, or the string
    ///   `"Generation failed"` when the pipeline returned no non-nil image.
    func generate(
        prompt: String,
        negativePrompt: String = "",
        scheduler: StableDiffusionScheduler,
        numInferenceSteps stepCount: Int = 50,
        seed: UInt32? = nil,
        guidanceScale: Float = 7.5,
        disableSafety: Bool = false
    ) throws -> GenerationResult {
        let beginDate = Date()
        print("Generating...")

        // Fall back to a random seed so the value actually used can be reported to the caller.
        let theSeed = seed ?? UInt32.random(in: 0..<UInt32.max)

        let images = try pipeline.generateImages(
            prompt: prompt,
            negativePrompt: negativePrompt,
            imageCount: 1,
            stepCount: stepCount,
            seed: theSeed,
            guidanceScale: guidanceScale,
            disableSafety: disableSafety,
            scheduler: scheduler
        ) { progress in
            handleProgress(progress)
            // Always returns true; presumably this tells the pipeline to keep
            // going (never cancel) — confirm against the StableDiffusion package.
            return true
        }

        let interval = Date().timeIntervalSince(beginDate)
        print("Got images: \(images) in \(interval)")

        // Unwrap the 1 image we asked for.
        // NOTE(review): throwing a String assumes `String` conforms to `Error`
        // somewhere else in the project — not visible in this file.
        guard let image = images.compactMap({ $0 }).first else { throw "Generation failed" }
        return GenerationResult(image: image, lastSeed: theSeed, interval: interval)
    }

    /// Stores the progress reported by the pipeline, which in turn updates `progressPublisher`.
    func handleProgress(_ progress: StableDiffusionPipeline.Progress) {
        self.progress = progress
    }
}