You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
84 lines
2.4 KiB
84 lines
2.4 KiB
//
// Pipeline.swift
// Diffusion
//
// Created by Pedro Cuenca on December 2022.
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
//
|
|
|
|
import Foundation
|
|
import CoreML
|
|
import Combine
|
|
|
|
import StableDiffusion
|
|
|
|
/// Shorthand for the progress value type declared by `StableDiffusionPipeline`.
typealias StableDiffusionProgress = StableDiffusionPipeline.Progress
|
|
|
|
/// Outcome of one call to `Pipeline.generate`.
///
/// Member order is significant: it defines the argument order of the
/// synthesized memberwise initializer used by callers.
struct GenerationResult {
    /// The generated image, or `nil` when no image was produced.
    var image: CGImage?
    /// The seed that was used for this generation.
    var lastSeed: UInt32
    /// Wall-clock duration of the generation, when it was measured.
    var interval: TimeInterval?
    /// Whether cancellation had been requested by the time the result was built.
    var userCanceled: Bool
}
|
|
|
|
/// Wraps a `StableDiffusionPipeline` to generate a single image per call,
/// publishing progress updates and supporting a cancellation flag.
class Pipeline {
    /// The underlying pipeline that performs the image generation.
    let pipeline: StableDiffusionPipeline
    /// Inclusive upper bound for the random seed drawn when `generate` is called without one.
    let maxSeed: UInt32

    /// Latest progress value reported during generation.
    /// Every assignment is forwarded to `progressPublisher`.
    var progress: StableDiffusionProgress? = nil {
        didSet {
            progressPublisher.value = progress
        }
    }
    /// Publishes the values assigned to `progress`.
    /// Created lazily on first access, seeded with the value `progress` holds at that moment.
    lazy private(set) var progressPublisher: CurrentValueSubject<StableDiffusionProgress?, Never> = CurrentValueSubject(progress)

    // NOTE(review): this flag is read and written without synchronization; presumably
    // `setCancelled()` is invoked from a different thread than `generate` — confirm.
    /// Set by `setCancelled()` and cleared at the start of every `generate` call.
    private var canceled = false

    /// Creates a wrapper around an existing pipeline.
    /// - Parameters:
    ///   - pipeline: The pipeline used for generation.
    ///   - maxSeed: Inclusive upper bound for randomly drawn seeds. Defaults to `UInt32.max`.
    init(_ pipeline: StableDiffusionPipeline, maxSeed: UInt32 = UInt32.max) {
        self.pipeline = pipeline
        self.maxSeed = maxSeed
    }

    /// Generates one image for `prompt` and reports how the run went.
    ///
    /// Clears any previous cancellation request, then runs the underlying pipeline,
    /// forwarding each progress report to `handleProgress(_:)`.
    ///
    /// - Parameters:
    ///   - prompt: Text prompt forwarded to the pipeline.
    ///   - negativePrompt: Negative prompt forwarded to the pipeline. Defaults to an empty string.
    ///   - scheduler: Scheduler forwarded to the pipeline.
    ///   - stepCount: Number of inference steps forwarded to the pipeline. Defaults to 50.
    ///   - seed: Seed to use; when `nil`, a random value in `0...maxSeed` is drawn.
    ///   - guidanceScale: Guidance scale forwarded to the pipeline. Defaults to 7.5.
    ///   - disableSafety: Forwarded to the pipeline's `disableSafety` argument. Defaults to `false`.
    /// - Returns: The first non-nil image (if any), the seed that was used, the elapsed
    ///   wall-clock time, and whether cancellation was requested during the run.
    /// - Throws: Any error thrown by `StableDiffusionPipeline.generateImages`.
    func generate(
        prompt: String,
        negativePrompt: String = "",
        scheduler: StableDiffusionScheduler,
        numInferenceSteps stepCount: Int = 50,
        seed: UInt32? = nil,
        guidanceScale: Float = 7.5,
        disableSafety: Bool = false
    ) throws -> GenerationResult {
        let beginDate = Date()
        // A cancellation requested for an earlier run must not affect this one.
        canceled = false
        print("Generating...")
        let theSeed = seed ?? UInt32.random(in: 0...maxSeed)
        let images = try pipeline.generateImages(
            prompt: prompt,
            negativePrompt: negativePrompt,
            imageCount: 1,
            stepCount: stepCount,
            seed: theSeed,
            guidanceScale: guidanceScale,
            disableSafety: disableSafety,
            scheduler: scheduler
        ) { progress in
            handleProgress(progress)
            // Returns `false` once cancellation was requested; presumably the pipeline
            // stops generating in that case — confirm against the StableDiffusion API.
            return !canceled
        }
        let interval = Date().timeIntervalSince(beginDate)
        print("Got images: \(images) in \(interval)")

        // Unwrap the 1 image we asked for, nil means safety checker triggered
        let image = images.compactMap { $0 }.first
        return GenerationResult(image: image, lastSeed: theSeed, interval: interval, userCanceled: canceled)
    }

    /// Stores the latest progress report, which in turn updates `progressPublisher`.
    /// - Parameter progress: Progress value reported by the underlying pipeline.
    func handleProgress(_ progress: StableDiffusionPipeline.Progress) {
        self.progress = progress
    }

    /// Requests cancellation: the progress handler installed by `generate` returns
    /// `false` on its next invocation.
    func setCancelled() {
        canceled = true
    }
}
|