|
|
@ -18,6 +18,7 @@ struct GenerationResult {
|
|
|
|
var image: CGImage?
|
|
|
|
var image: CGImage?
|
|
|
|
var lastSeed: UInt32
|
|
|
|
var lastSeed: UInt32
|
|
|
|
var interval: TimeInterval?
|
|
|
|
var interval: TimeInterval?
|
|
|
|
|
|
|
|
var userCanceled: Bool
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
class Pipeline {
|
|
|
|
class Pipeline {
|
|
|
@ -31,6 +32,7 @@ class Pipeline {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
lazy private(set) var progressPublisher: CurrentValueSubject<StableDiffusionProgress?, Never> = CurrentValueSubject(progress)
|
|
|
|
lazy private(set) var progressPublisher: CurrentValueSubject<StableDiffusionProgress?, Never> = CurrentValueSubject(progress)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private var canceled = false
|
|
|
|
|
|
|
|
|
|
|
|
init(_ pipeline: StableDiffusionPipeline, maxSeed: UInt32 = UInt32.max) {
|
|
|
|
init(_ pipeline: StableDiffusionPipeline, maxSeed: UInt32 = UInt32.max) {
|
|
|
|
self.pipeline = pipeline
|
|
|
|
self.pipeline = pipeline
|
|
|
@ -47,6 +49,7 @@ class Pipeline {
|
|
|
|
disableSafety: Bool = false
|
|
|
|
disableSafety: Bool = false
|
|
|
|
) throws -> GenerationResult {
|
|
|
|
) throws -> GenerationResult {
|
|
|
|
let beginDate = Date()
|
|
|
|
let beginDate = Date()
|
|
|
|
|
|
|
|
canceled = false
|
|
|
|
print("Generating...")
|
|
|
|
print("Generating...")
|
|
|
|
let theSeed = seed ?? UInt32.random(in: 0...maxSeed)
|
|
|
|
let theSeed = seed ?? UInt32.random(in: 0...maxSeed)
|
|
|
|
let images = try pipeline.generateImages(
|
|
|
|
let images = try pipeline.generateImages(
|
|
|
@ -60,17 +63,21 @@ class Pipeline {
|
|
|
|
scheduler: scheduler
|
|
|
|
scheduler: scheduler
|
|
|
|
) { progress in
|
|
|
|
) { progress in
|
|
|
|
handleProgress(progress)
|
|
|
|
handleProgress(progress)
|
|
|
|
return true
|
|
|
|
return !canceled
|
|
|
|
}
|
|
|
|
}
|
|
|
|
let interval = Date().timeIntervalSince(beginDate)
|
|
|
|
let interval = Date().timeIntervalSince(beginDate)
|
|
|
|
print("Got images: \(images) in \(interval)")
|
|
|
|
print("Got images: \(images) in \(interval)")
|
|
|
|
|
|
|
|
|
|
|
|
// Unwrap the 1 image we asked for, nil means safety checker triggered
|
|
|
|
// Unwrap the 1 image we asked for, nil means safety checker triggered
|
|
|
|
let image = images.compactMap({ $0 }).first
|
|
|
|
let image = images.compactMap({ $0 }).first
|
|
|
|
return GenerationResult(image: image, lastSeed: theSeed, interval: interval)
|
|
|
|
return GenerationResult(image: image, lastSeed: theSeed, interval: interval, userCanceled: canceled)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func handleProgress(_ progress: StableDiffusionPipeline.Progress) {
|
|
|
|
func handleProgress(_ progress: StableDiffusionPipeline.Progress) {
|
|
|
|
self.progress = progress
|
|
|
|
self.progress = progress
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func setCancelled() {
|
|
|
|
|
|
|
|
canceled = true
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|