|
|
|
@ -23,6 +23,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
|
|
|
|
|
/// Errors thrown by the pipeline while preparing image generation.
///
/// The enum is backed by `String`; no explicit raw values are given, so each
/// case's raw value is its own case name.
public enum Error: String, Swift.Error {
|
|
|
|
|
/// Thrown when the optional `encoder` is `nil` at the point where the
/// initial latents are about to be produced via `encoder.encode(...)`.
/// NOTE(review): presumably this path is only taken when a starting image
/// was supplied — confirm against the enclosing condition.
case startingImageProvidedWithoutEncoder
|
|
|
|
|
/// Thrown on the encoder path when the requested scheduler is
/// `.dpmSolverMultistepScheduler`; `.pndmScheduler` passes that check.
case schedulerNotSupportedWithImage2Image
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Model to generate embeddings for tokenized input text
|
|
|
|
@ -130,7 +131,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
seed: UInt32 = 0,
|
|
|
|
|
guidanceScale: Float = 7.5,
|
|
|
|
|
disableSafety: Bool = false,
|
|
|
|
|
scheduler: StableDiffusionScheduler = .pndmScheduler,
|
|
|
|
|
scheduler schedulerType: StableDiffusionScheduler = .pndmScheduler,
|
|
|
|
|
progressHandler: (Progress) -> Bool = { _ in true }
|
|
|
|
|
) throws -> [CGImage?] {
|
|
|
|
|
|
|
|
|
@ -153,7 +154,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
|
|
|
|
|
/// Set up schedulers
|
|
|
|
|
let scheduler: [Scheduler] = (0..<imageCount).map { _ in
|
|
|
|
|
switch scheduler {
|
|
|
|
|
switch schedulerType {
|
|
|
|
|
case .pndmScheduler: return PNDMScheduler(stepCount: stepCount)
|
|
|
|
|
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: stepCount)
|
|
|
|
|
}
|
|
|
|
@ -169,6 +170,13 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
guard let encoder else {
|
|
|
|
|
throw Error.startingImageProvidedWithoutEncoder
|
|
|
|
|
}
|
|
|
|
|
switch schedulerType {
|
|
|
|
|
case .pndmScheduler:
|
|
|
|
|
break
|
|
|
|
|
case .dpmSolverMultistepScheduler:
|
|
|
|
|
throw Error.schedulerNotSupportedWithImage2Image
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let noiseTuples = generateImage2ImageLatentSamples(imageCount, stdev: 1, seed: seed)
|
|
|
|
|
latents = try noiseTuples.map({
|
|
|
|
|
try encoder.encode(
|
|
|
|
@ -184,7 +192,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// De-noising loop
|
|
|
|
|
let timeSteps = scheduler[0].calculateTimesteps(strength: timestepStrength)
|
|
|
|
|
let timeSteps: [Int] = scheduler[0].calculateTimesteps(strength: timestepStrength)
|
|
|
|
|
for (step,t) in timeSteps.enumerated() {
|
|
|
|
|
|
|
|
|
|
// Expand the latents for classifier-free guidance
|
|
|
|
|