|
|
|
@ -114,13 +114,13 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
/// - Returns: An array of `imageCount` optional images.
|
|
|
|
|
/// The images will be nil if safety checks were performed and found the result to be unsafe.
|
|
|
|
|
public func generateImages(
|
|
|
|
|
input: SampleInput,
|
|
|
|
|
configuration config: Configuration,
|
|
|
|
|
progressHandler: (Progress) -> Bool = { _ in true }
|
|
|
|
|
) throws -> [CGImage?] {
|
|
|
|
|
|
|
|
|
|
// Encode the input prompt and negative prompt
|
|
|
|
|
let promptEmbedding = try textEncoder.encode(input.prompt)
|
|
|
|
|
let negativePromptEmbedding = try textEncoder.encode(input.negativePrompt)
|
|
|
|
|
let promptEmbedding = try textEncoder.encode(config.prompt)
|
|
|
|
|
let negativePromptEmbedding = try textEncoder.encode(config.negativePrompt)
|
|
|
|
|
|
|
|
|
|
if reduceMemory {
|
|
|
|
|
textEncoder.unloadResources()
|
|
|
|
@ -136,10 +136,10 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
let hiddenStates = toHiddenStates(concatEmbedding)
|
|
|
|
|
|
|
|
|
|
/// Set up schedulers
|
|
|
|
|
let scheduler: [Scheduler] = (0..<input.imageCount).map { _ in
|
|
|
|
|
switch input.schedulerType {
|
|
|
|
|
case .pndmScheduler: return PNDMScheduler(stepCount: input.stepCount)
|
|
|
|
|
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: input.stepCount)
|
|
|
|
|
let scheduler: [Scheduler] = (0..<config.imageCount).map { _ in
|
|
|
|
|
switch config.schedulerType {
|
|
|
|
|
case .pndmScheduler: return PNDMScheduler(stepCount: config.stepCount)
|
|
|
|
|
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: config.stepCount)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
let stdev = scheduler[0].initNoiseSigma
|
|
|
|
@ -149,26 +149,26 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
let timestepStrength: Float?
|
|
|
|
|
|
|
|
|
|
if
|
|
|
|
|
let startingImage = input.startingImage,
|
|
|
|
|
input.mode == .imageToImage
|
|
|
|
|
let startingImage = config.startingImage,
|
|
|
|
|
config.mode == .imageToImage
|
|
|
|
|
{
|
|
|
|
|
timestepStrength = input.strength
|
|
|
|
|
timestepStrength = config.strength
|
|
|
|
|
guard let encoder else {
|
|
|
|
|
throw Error.startingImageProvidedWithoutEncoder
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let noiseTuples = generateImage2ImageLatentSamples(input.imageCount, stdev: 1, seed: input.seed)
|
|
|
|
|
let noiseTuples = generateImage2ImageLatentSamples(config.imageCount, stdev: 1, seed: config.seed)
|
|
|
|
|
latents = try noiseTuples.map({
|
|
|
|
|
try encoder.encode(
|
|
|
|
|
image: startingImage,
|
|
|
|
|
diagonalNoise: $0.diagonal,
|
|
|
|
|
noise: $0.latentNoise,
|
|
|
|
|
alphasCumprodStep: scheduler[0].calculateAlphasCumprod(strength: input.strength))
|
|
|
|
|
alphasCumprodStep: scheduler[0].calculateAlphasCumprod(strength: config.strength))
|
|
|
|
|
})
|
|
|
|
|
} else {
|
|
|
|
|
timestepStrength = nil
|
|
|
|
|
// Generate random latent samples from specified seed
|
|
|
|
|
latents = generateLatentSamples(input.imageCount, stdev: stdev, seed: input.seed)
|
|
|
|
|
latents = generateLatentSamples(config.imageCount, stdev: stdev, seed: config.seed)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// De-noising loop
|
|
|
|
@ -189,11 +189,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
hiddenStates: hiddenStates
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
noise = performGuidance(noise, input.guidanceScale)
|
|
|
|
|
noise = performGuidance(noise, config.guidanceScale)
|
|
|
|
|
|
|
|
|
|
// Have the scheduler compute the previous (t-1) latent
|
|
|
|
|
// sample given the predicted noise and current sample
|
|
|
|
|
for i in 0..<input.imageCount {
|
|
|
|
|
for i in 0..<config.imageCount {
|
|
|
|
|
latents[i] = scheduler[i].step(
|
|
|
|
|
output: noise[i],
|
|
|
|
|
timeStep: t,
|
|
|
|
@ -204,11 +204,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
// Report progress
|
|
|
|
|
let progress = Progress(
|
|
|
|
|
pipeline: self,
|
|
|
|
|
prompt: input.prompt,
|
|
|
|
|
prompt: config.prompt,
|
|
|
|
|
step: step,
|
|
|
|
|
stepCount: timeSteps.count,
|
|
|
|
|
currentLatentSamples: latents,
|
|
|
|
|
isSafetyEnabled: canSafetyCheck && !input.disableSafety
|
|
|
|
|
isSafetyEnabled: canSafetyCheck && !config.disableSafety
|
|
|
|
|
)
|
|
|
|
|
if !progressHandler(progress) {
|
|
|
|
|
// Stop if requested by handler
|
|
|
|
@ -221,7 +221,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Decode the latent samples to images
|
|
|
|
|
return try decodeToImages(latents, disableSafety: input.disableSafety)
|
|
|
|
|
return try decodeToImages(latents, disableSafety: config.disableSafety)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func generateLatentSamples(_ count: Int, stdev: Float, seed: UInt32) -> [MLShapedArray<Float32>] {
|
|
|
|
|