|
|
|
@@ -24,6 +24,36 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
/// Failures that can occur while sampling from the pipeline.
public enum Error: String, Swift.Error {
    /// A starting image was supplied, but this pipeline was created
    /// without an image encoder, so image-to-image cannot run.
    case startingImageProvidedWithoutEncoder
}
|
|
|
|
|
|
|
|
|
|
/// The kind of sampling task the pipeline performs, derived from its input.
public enum Mode {
    /// Generate purely from a text prompt.
    case textToImage
    /// Generate starting from a provided image plus added noise.
    case imageToImage
    // case inPainting
}
|
|
|
|
|
|
|
|
|
|
/// One generation request: the prompt plus every knob the sampler reads.
/// `Hashable` so identical requests can be compared or deduplicated.
public struct SampleInput: Hashable {

    /// Text prompt that guides generation.
    public var prompt: String
    /// Text describing content to steer away from.
    public var negativePrompt: String = ""
    /// Optional image to start from (enables image-to-image).
    public var startingImage: CGImage? = nil
    //public var maskImage: CGImage? = nil
    /// How much noise replaces the starting image; 1.0 discards it entirely.
    public var strength: Float = 1.0
    /// Number of images produced in one batch.
    public var imageCount: Int = 1
    /// Number of de-noising steps to run.
    public var stepCount: Int = 50
    /// RNG seed for reproducible output.
    public var seed: UInt32 = 0
    /// Classifier-free guidance scale.
    public var guidanceScale: Float = 7.5
    /// Skips the safety checker when true.
    public var disableSafety: Bool = false
    /// Scheduler used by the de-noising loop.
    public var schedulerType: StableDiffusionScheduler = .pndmScheduler

    /// Derived sampling mode: image-to-image only when a starting image
    /// is present AND strength < 1.0 — at strength == 1.0 the image is
    /// fully replaced by noise, which is plain text-to-image.
    public var mode: Mode {
        guard startingImage != nil, strength < 1.0 else {
            return .textToImage
        }
        return .imageToImage
    }
}
|
|
|
|
|
|
|
|
|
|
/// Model to generate embeddings for tokenized input text
|
|
|
|
|
var textEncoder: TextEncoder
|
|
|
|
@@ -133,10 +163,27 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
scheduler schedulerType: StableDiffusionScheduler = .pndmScheduler,
|
|
|
|
|
progressHandler: (Progress) -> Bool = { _ in true }
|
|
|
|
|
) throws -> [CGImage?] {
|
|
|
|
|
try generateImages(input: SampleInput(
|
|
|
|
|
prompt: prompt,
|
|
|
|
|
negativePrompt: negativePrompt,
|
|
|
|
|
startingImage: startingImage,
|
|
|
|
|
strength: strength,
|
|
|
|
|
imageCount: imageCount,
|
|
|
|
|
stepCount: stepCount,
|
|
|
|
|
seed: seed,
|
|
|
|
|
guidanceScale: guidanceScale,
|
|
|
|
|
disableSafety: disableSafety,
|
|
|
|
|
schedulerType: schedulerType), progressHandler: progressHandler)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public func generateImages(
|
|
|
|
|
input: SampleInput,
|
|
|
|
|
progressHandler: (Progress) -> Bool = { _ in true }
|
|
|
|
|
) throws -> [CGImage?] {
|
|
|
|
|
|
|
|
|
|
// Encode the input prompt and negative prompt
|
|
|
|
|
let promptEmbedding = try textEncoder.encode(prompt)
|
|
|
|
|
let negativePromptEmbedding = try textEncoder.encode(negativePrompt)
|
|
|
|
|
let promptEmbedding = try textEncoder.encode(input.prompt)
|
|
|
|
|
let negativePromptEmbedding = try textEncoder.encode(input.negativePrompt)
|
|
|
|
|
|
|
|
|
|
if reduceMemory {
|
|
|
|
|
textEncoder.unloadResources()
|
|
|
|
@@ -152,10 +199,10 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
let hiddenStates = toHiddenStates(concatEmbedding)
|
|
|
|
|
|
|
|
|
|
/// Setup schedulers
|
|
|
|
|
let scheduler: [Scheduler] = (0..<imageCount).map { _ in
|
|
|
|
|
switch schedulerType {
|
|
|
|
|
case .pndmScheduler: return PNDMScheduler(stepCount: stepCount)
|
|
|
|
|
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: stepCount)
|
|
|
|
|
let scheduler: [Scheduler] = (0..<input.imageCount).map { _ in
|
|
|
|
|
switch input.schedulerType {
|
|
|
|
|
case .pndmScheduler: return PNDMScheduler(stepCount: input.stepCount)
|
|
|
|
|
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: input.stepCount)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
let stdev = scheduler[0].initNoiseSigma
|
|
|
|
@@ -164,24 +211,27 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
var latents: [MLShapedArray<Float32>]
|
|
|
|
|
let timestepStrength: Float?
|
|
|
|
|
|
|
|
|
|
if let startingImage {
|
|
|
|
|
timestepStrength = strength
|
|
|
|
|
if
|
|
|
|
|
let startingImage = input.startingImage,
|
|
|
|
|
input.mode == .imageToImage
|
|
|
|
|
{
|
|
|
|
|
timestepStrength = input.strength
|
|
|
|
|
guard let encoder else {
|
|
|
|
|
throw Error.startingImageProvidedWithoutEncoder
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let noiseTuples = generateImage2ImageLatentSamples(imageCount, stdev: 1, seed: seed)
|
|
|
|
|
let noiseTuples = generateImage2ImageLatentSamples(input.imageCount, stdev: 1, seed: input.seed)
|
|
|
|
|
latents = try noiseTuples.map({
|
|
|
|
|
try encoder.encode(
|
|
|
|
|
image: startingImage,
|
|
|
|
|
diagonalNoise: $0.diagonal,
|
|
|
|
|
noise: $0.latentNoise,
|
|
|
|
|
alphasCumprodStep: scheduler[0].calculateAlphasCumprod(strength: strength))
|
|
|
|
|
alphasCumprodStep: scheduler[0].calculateAlphasCumprod(strength: input.strength))
|
|
|
|
|
})
|
|
|
|
|
} else {
|
|
|
|
|
timestepStrength = nil
|
|
|
|
|
// Generate random latent samples from specified seed
|
|
|
|
|
latents = generateLatentSamples(imageCount, stdev: stdev, seed: seed)
|
|
|
|
|
latents = generateLatentSamples(input.imageCount, stdev: stdev, seed: input.seed)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// De-noising loop
|
|
|
|
@@ -202,11 +252,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
hiddenStates: hiddenStates
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
noise = performGuidance(noise, guidanceScale)
|
|
|
|
|
noise = performGuidance(noise, input.guidanceScale)
|
|
|
|
|
|
|
|
|
|
// Have the scheduler compute the previous (t-1) latent
|
|
|
|
|
// sample given the predicted noise and current sample
|
|
|
|
|
for i in 0..<imageCount {
|
|
|
|
|
for i in 0..<input.imageCount {
|
|
|
|
|
latents[i] = scheduler[i].step(
|
|
|
|
|
output: noise[i],
|
|
|
|
|
timeStep: t,
|
|
|
|
@@ -217,11 +267,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
// Report progress
|
|
|
|
|
let progress = Progress(
|
|
|
|
|
pipeline: self,
|
|
|
|
|
prompt: prompt,
|
|
|
|
|
prompt: input.prompt,
|
|
|
|
|
step: step,
|
|
|
|
|
stepCount: timeSteps.count,
|
|
|
|
|
currentLatentSamples: latents,
|
|
|
|
|
isSafetyEnabled: canSafetyCheck && !disableSafety
|
|
|
|
|
isSafetyEnabled: canSafetyCheck && !input.disableSafety
|
|
|
|
|
)
|
|
|
|
|
if !progressHandler(progress) {
|
|
|
|
|
// Stop if requested by handler
|
|
|
|
@@ -234,7 +284,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Decode the latent samples to images
|
|
|
|
|
return try decodeToImages(latents, disableSafety: disableSafety)
|
|
|
|
|
return try decodeToImages(latents, disableSafety: input.disableSafety)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func generateLatentSamples(_ count: Int, stdev: Float, seed: UInt32) -> [MLShapedArray<Float32>] {
|
|
|
|
|