You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

84 lines
2.4 KiB

//
// Pipeline.swift
// Diffusion
//
// Created by Pedro Cuenca on December 2022.
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
//
import Foundation
import CoreML
import Combine
import StableDiffusion
/// Shorthand for the progress value type nested in `StableDiffusionPipeline`,
/// used below for the `progress` property and its publisher.
typealias StableDiffusionProgress = StableDiffusionPipeline.Progress
/// Value returned by `Pipeline.generate(...)` describing the outcome of one generation run.
struct GenerationResult {
/// The first non-nil image returned by the pipeline, or `nil` if none was returned.
var image: CGImage?
/// The seed that was actually passed to the pipeline (caller-supplied or randomly drawn).
var lastSeed: UInt32
/// Wall-clock duration of the generation call, in seconds.
var interval: TimeInterval?
/// Value of the pipeline's cancellation flag at the time the result was built.
var userCanceled: Bool
}
/// Wrapper around a `StableDiffusionPipeline` that generates one image per call,
/// republishes step progress through a Combine subject, and lets callers request
/// cancellation of a run in flight.
class Pipeline {
    /// The underlying pipeline that performs the image generation.
    let pipeline: StableDiffusionPipeline

    /// Inclusive upper bound used when a random seed has to be drawn.
    let maxSeed: UInt32

    /// Most recent progress value received from the pipeline. Every assignment is
    /// mirrored into `progressPublisher`.
    var progress: StableDiffusionProgress? = nil {
        didSet {
            progressPublisher.value = progress
        }
    }

    /// Subject exposing `progress` to subscribers. It is created on first access and
    /// seeded with whatever `progress` holds at that moment.
    private(set) lazy var progressPublisher: CurrentValueSubject<StableDiffusionProgress?, Never> = CurrentValueSubject(progress)

    /// Set by `setCancelled()` and consulted after every progress callback.
    // NOTE(review): read and written without synchronization; confirm which threads
    // call `generate` and `setCancelled()`.
    private var canceled = false

    /// Creates a wrapper for `pipeline`.
    /// - Parameters:
    ///   - pipeline: The pipeline used for generation.
    ///   - maxSeed: Largest seed that may be drawn at random; defaults to `UInt32.max`.
    init(_ pipeline: StableDiffusionPipeline, maxSeed: UInt32 = UInt32.max) {
        self.pipeline = pipeline
        self.maxSeed = maxSeed
    }

    /// Generates a single image and reports how the run went.
    ///
    /// The cancellation flag is cleared at the start of every call. All arguments are
    /// forwarded unchanged to `pipeline.generateImages`, which is asked for exactly one image.
    ///
    /// - Parameters:
    ///   - prompt: Text prompt forwarded to the pipeline.
    ///   - negativePrompt: Negative prompt forwarded to the pipeline.
    ///   - scheduler: Scheduler forwarded to the pipeline.
    ///   - stepCount: Number of inference steps forwarded to the pipeline.
    ///   - seed: Seed to use; when `nil`, one is drawn uniformly from `0...maxSeed`.
    ///   - guidanceScale: Guidance scale forwarded to the pipeline.
    ///   - disableSafety: Forwarded to the pipeline's `disableSafety` argument.
    /// - Returns: The first non-nil image (if any), the seed used, the elapsed time and
    ///   the state of the cancellation flag when generation returned.
    /// - Throws: Whatever `pipeline.generateImages` throws.
    func generate(
        prompt: String,
        negativePrompt: String = "",
        scheduler: StableDiffusionScheduler,
        numInferenceSteps stepCount: Int = 50,
        seed: UInt32? = nil,
        guidanceScale: Float = 7.5,
        disableSafety: Bool = false
    ) throws -> GenerationResult {
        let startTime = Date()
        canceled = false
        print("Generating...")

        // Honor an explicit seed; otherwise pick one at random within the configured range.
        let resolvedSeed: UInt32
        if let seed = seed {
            resolvedSeed = seed
        } else {
            resolvedSeed = UInt32.random(in: 0...maxSeed)
        }

        let generated = try pipeline.generateImages(
            prompt: prompt,
            negativePrompt: negativePrompt,
            imageCount: 1,
            stepCount: stepCount,
            seed: resolvedSeed,
            guidanceScale: guidanceScale,
            disableSafety: disableSafety,
            scheduler: scheduler
        ) { update in
            handleProgress(update)
            // Presumably a `false` return tells the pipeline to stop early; confirm
            // against the StableDiffusion package documentation.
            return !canceled
        }

        let elapsed = Date().timeIntervalSince(startTime)
        print("Got images: \(generated) in \(elapsed)")

        // Unwrap the 1 image we asked for, nil means safety checker triggered
        let firstImage = generated.lazy.compactMap { $0 }.first

        return GenerationResult(
            image: firstImage,
            lastSeed: resolvedSeed,
            interval: elapsed,
            userCanceled: canceled
        )
    }

    /// Stores the latest progress value, which in turn updates `progressPublisher`.
    /// - Parameter progress: Progress value delivered by the pipeline.
    func handleProgress(_ progress: StableDiffusionPipeline.Progress) {
        self.progress = progress
    }

    /// Raises the cancellation flag that the progress handler in `generate` checks
    /// after each callback.
    func setCancelled() {
        canceled = true
    }
}