disable dpm multistep solver with image2image

pull/116/head
Timothy Kautz 3 years ago
parent 2dca60608c
commit e7eb953d7b

@ -23,16 +23,12 @@ public final class DPMSolverMultistepScheduler: Scheduler {
public let betas: [Float] public let betas: [Float]
public let alphas: [Float] public let alphas: [Float]
public let alphasCumProd: [Float] public let alphasCumProd: [Float]
private let timeSteps: [Int] public let timeSteps: [Int]
public let alpha_t: [Float] public let alpha_t: [Float]
public let sigma_t: [Float] public let sigma_t: [Float]
public let lambda_t: [Float] public let lambda_t: [Float]
public var allTimeSteps: [Int] {
timeSteps
}
public let solverOrder = 2 public let solverOrder = 2
private(set) var lowerOrderStepped = 0 private(set) var lowerOrderStepped = 0
@ -183,4 +179,9 @@ public final class DPMSolverMultistepScheduler: Scheduler {
return prevSample return prevSample
} }
/// This scheduler does not support image2image strength value.
public func calculateTimesteps(strength: Float?) -> [Int] {
timeSteps
}
} }

@ -12,7 +12,7 @@ public protocol Scheduler {
var inferenceStepCount: Int { get } var inferenceStepCount: Int { get }
/// Training diffusion time steps index by inference time step /// Training diffusion time steps index by inference time step
var allTimeSteps: [Int] { get } var timeSteps: [Int] { get }
/// Training diffusion time steps index by inference time step /// Training diffusion time steps index by inference time step
func calculateTimesteps(strength: Float?) -> [Int] func calculateTimesteps(strength: Float?) -> [Int]
@ -94,9 +94,9 @@ public extension Scheduler {
public extension Scheduler { public extension Scheduler {
func calculateTimesteps(strength: Float?) -> [Int] { func calculateTimesteps(strength: Float?) -> [Int] {
guard let strength else { return allTimeSteps.reversed() } guard let strength else { return timeSteps.reversed() }
let startStep = Int(Float(inferenceStepCount) * strength) let startStep = Int(Float(inferenceStepCount) * strength)
let acutalTimesteps = Array(allTimeSteps[0..<startStep].reversed()) let acutalTimesteps = Array(timeSteps[0..<startStep].reversed())
return acutalTimesteps return acutalTimesteps
} }
} }
@ -127,11 +127,7 @@ public final class PNDMScheduler: Scheduler {
public let betas: [Float] public let betas: [Float]
public let alphas: [Float] public let alphas: [Float]
public let alphasCumProd: [Float] public let alphasCumProd: [Float]
private let timeSteps: [Int] public let timeSteps: [Int]
public var allTimeSteps: [Int] {
timeSteps
}
// Internal state // Internal state
var counter: Int var counter: Int

@ -23,6 +23,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
public enum Error: String, Swift.Error { public enum Error: String, Swift.Error {
case startingImageProvidedWithoutEncoder case startingImageProvidedWithoutEncoder
case schedulerNotSupportedWithImage2Image
} }
/// Model to generate embeddings for tokenized input text /// Model to generate embeddings for tokenized input text
@ -130,7 +131,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
seed: UInt32 = 0, seed: UInt32 = 0,
guidanceScale: Float = 7.5, guidanceScale: Float = 7.5,
disableSafety: Bool = false, disableSafety: Bool = false,
scheduler: StableDiffusionScheduler = .pndmScheduler, scheduler schedulerType: StableDiffusionScheduler = .pndmScheduler,
progressHandler: (Progress) -> Bool = { _ in true } progressHandler: (Progress) -> Bool = { _ in true }
) throws -> [CGImage?] { ) throws -> [CGImage?] {
@ -153,7 +154,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// Setup schedulers /// Setup schedulers
let scheduler: [Scheduler] = (0..<imageCount).map { _ in let scheduler: [Scheduler] = (0..<imageCount).map { _ in
switch scheduler { switch schedulerType {
case .pndmScheduler: return PNDMScheduler(stepCount: stepCount) case .pndmScheduler: return PNDMScheduler(stepCount: stepCount)
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: stepCount) case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: stepCount)
} }
@ -169,6 +170,13 @@ public struct StableDiffusionPipeline: ResourceManaging {
guard let encoder else { guard let encoder else {
throw Error.startingImageProvidedWithoutEncoder throw Error.startingImageProvidedWithoutEncoder
} }
switch schedulerType {
case .pndmScheduler:
break
case .dpmSolverMultistepScheduler:
throw Error.schedulerNotSupportedWithImage2Image
}
let noiseTuples = generateImage2ImageLatentSamples(imageCount, stdev: 1, seed: seed) let noiseTuples = generateImage2ImageLatentSamples(imageCount, stdev: 1, seed: seed)
latents = try noiseTuples.map({ latents = try noiseTuples.map({
try encoder.encode( try encoder.encode(
@ -184,7 +192,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
} }
// De-noising loop // De-noising loop
let timeSteps = scheduler[0].calculateTimesteps(strength: timestepStrength) let timeSteps: [Int] = scheduler[0].calculateTimesteps(strength: timestepStrength)
for (step,t) in timeSteps.enumerated() { for (step,t) in timeSteps.enumerated() {
// Expand the latents for classifier-free guidance // Expand the latents for classifier-free guidance

Loading…
Cancel
Save