pr comments

pull/116/head
Timothy Kautz 2 years ago
parent d6647a4551
commit 088bdc1dc4

@ -1,9 +1,5 @@
// // For licensing see accompanying LICENSE.md file.
// File.swift // Copyright (C) 2022 Apple Inc. All Rights Reserved.
//
//
// Created by Timothy Kautz on 1/30/23.
//
import Foundation import Foundation
import CoreGraphics import CoreGraphics
@ -19,7 +15,7 @@ extension StableDiffusionPipeline {
} }
/// Image generation configuration /// Image generation configuration
public struct SampleInput: Hashable { public struct Configuration: Hashable {
/// Text prompt to guide sampling /// Text prompt to guide sampling
public var prompt: String public var prompt: String
@ -54,27 +50,9 @@ extension StableDiffusionPipeline {
} }
public init( public init(
prompt: String, prompt: String
negativePrompt: String = "",
startingImage: CGImage? = nil,
strength: Float = 1.0,
imageCount: Int = 1,
stepCount: Int = 50,
seed: UInt32 = 0,
guidanceScale: Float = 7.5,
disableSafety: Bool = false,
schedulerType: StableDiffusionScheduler = .pndmScheduler
) { ) {
self.prompt = prompt self.prompt = prompt
self.negativePrompt = negativePrompt
self.startingImage = startingImage
self.strength = strength
self.imageCount = imageCount
self.stepCount = stepCount
self.seed = seed
self.guidanceScale = guidanceScale
self.disableSafety = disableSafety
self.schedulerType = schedulerType
} }
} }

@ -114,13 +114,13 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// - Returns: An array of `imageCount` optional images. /// - Returns: An array of `imageCount` optional images.
/// The images will be nil if safety checks were performed and found the result to be un-safe /// The images will be nil if safety checks were performed and found the result to be un-safe
public func generateImages( public func generateImages(
input: SampleInput, configuration config: Configuration,
progressHandler: (Progress) -> Bool = { _ in true } progressHandler: (Progress) -> Bool = { _ in true }
) throws -> [CGImage?] { ) throws -> [CGImage?] {
// Encode the input prompt and negative prompt // Encode the input prompt and negative prompt
let promptEmbedding = try textEncoder.encode(input.prompt) let promptEmbedding = try textEncoder.encode(config.prompt)
let negativePromptEmbedding = try textEncoder.encode(input.negativePrompt) let negativePromptEmbedding = try textEncoder.encode(config.negativePrompt)
if reduceMemory { if reduceMemory {
textEncoder.unloadResources() textEncoder.unloadResources()
@ -136,10 +136,10 @@ public struct StableDiffusionPipeline: ResourceManaging {
let hiddenStates = toHiddenStates(concatEmbedding) let hiddenStates = toHiddenStates(concatEmbedding)
/// Setup schedulers /// Setup schedulers
let scheduler: [Scheduler] = (0..<input.imageCount).map { _ in let scheduler: [Scheduler] = (0..<config.imageCount).map { _ in
switch input.schedulerType { switch config.schedulerType {
case .pndmScheduler: return PNDMScheduler(stepCount: input.stepCount) case .pndmScheduler: return PNDMScheduler(stepCount: config.stepCount)
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: input.stepCount) case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: config.stepCount)
} }
} }
let stdev = scheduler[0].initNoiseSigma let stdev = scheduler[0].initNoiseSigma
@ -149,26 +149,26 @@ public struct StableDiffusionPipeline: ResourceManaging {
let timestepStrength: Float? let timestepStrength: Float?
if if
let startingImage = input.startingImage, let startingImage = config.startingImage,
input.mode == .imageToImage config.mode == .imageToImage
{ {
timestepStrength = input.strength timestepStrength = config.strength
guard let encoder else { guard let encoder else {
throw Error.startingImageProvidedWithoutEncoder throw Error.startingImageProvidedWithoutEncoder
} }
let noiseTuples = generateImage2ImageLatentSamples(input.imageCount, stdev: 1, seed: input.seed) let noiseTuples = generateImage2ImageLatentSamples(config.imageCount, stdev: 1, seed: config.seed)
latents = try noiseTuples.map({ latents = try noiseTuples.map({
try encoder.encode( try encoder.encode(
image: startingImage, image: startingImage,
diagonalNoise: $0.diagonal, diagonalNoise: $0.diagonal,
noise: $0.latentNoise, noise: $0.latentNoise,
alphasCumprodStep: scheduler[0].calculateAlphasCumprod(strength: input.strength)) alphasCumprodStep: scheduler[0].calculateAlphasCumprod(strength: config.strength))
}) })
} else { } else {
timestepStrength = nil timestepStrength = nil
// Generate random latent samples from specified seed // Generate random latent samples from specified seed
latents = generateLatentSamples(input.imageCount, stdev: stdev, seed: input.seed) latents = generateLatentSamples(config.imageCount, stdev: stdev, seed: config.seed)
} }
// De-noising loop // De-noising loop
@ -189,11 +189,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
hiddenStates: hiddenStates hiddenStates: hiddenStates
) )
noise = performGuidance(noise, input.guidanceScale) noise = performGuidance(noise, config.guidanceScale)
// Have the scheduler compute the previous (t-1) latent // Have the scheduler compute the previous (t-1) latent
// sample given the predicted noise and current sample // sample given the predicted noise and current sample
for i in 0..<input.imageCount { for i in 0..<config.imageCount {
latents[i] = scheduler[i].step( latents[i] = scheduler[i].step(
output: noise[i], output: noise[i],
timeStep: t, timeStep: t,
@ -204,11 +204,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
// Report progress // Report progress
let progress = Progress( let progress = Progress(
pipeline: self, pipeline: self,
prompt: input.prompt, prompt: config.prompt,
step: step, step: step,
stepCount: timeSteps.count, stepCount: timeSteps.count,
currentLatentSamples: latents, currentLatentSamples: latents,
isSafetyEnabled: canSafetyCheck && !input.disableSafety isSafetyEnabled: canSafetyCheck && !config.disableSafety
) )
if !progressHandler(progress) { if !progressHandler(progress) {
// Stop if requested by handler // Stop if requested by handler
@ -221,7 +221,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
} }
// Decode the latent samples to images // Decode the latent samples to images
return try decodeToImages(latents, disableSafety: input.disableSafety) return try decodeToImages(latents, disableSafety: config.disableSafety)
} }
func generateLatentSamples(_ count: Int, stdev: Float, seed: UInt32) -> [MLShapedArray<Float32>] { func generateLatentSamples(_ count: Int, stdev: Float, seed: UInt32) -> [MLShapedArray<Float32>] {

@ -118,25 +118,27 @@ struct StableDiffusionSample: ParsableCommand {
let sampleTimer = SampleTimer() let sampleTimer = SampleTimer()
sampleTimer.start() sampleTimer.start()
var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: prompt)
pipelineConfig.negativePrompt = negativePrompt
pipelineConfig.startingImage = startingImage
pipelineConfig.strength = strength
pipelineConfig.imageCount = imageCount
pipelineConfig.stepCount = stepCount
pipelineConfig.seed = seed
pipelineConfig.guidanceScale = guidanceScale
pipelineConfig.schedulerType = scheduler.stableDiffusionScheduler
let images = try pipeline.generateImages( let images = try pipeline.generateImages(
input: StableDiffusionPipeline.SampleInput( configuration: pipelineConfig,
prompt: prompt, progressHandler: { progress in
negativePrompt: negativePrompt, sampleTimer.stop()
startingImage: startingImage, handleProgress(progress,sampleTimer)
strength: strength, if progress.stepCount != progress.step {
imageCount: imageCount, sampleTimer.start()
stepCount: stepCount, }
seed: seed, return true
guidanceScale: guidanceScale, })
schedulerType: scheduler.stableDiffusionScheduler)
) { progress in
sampleTimer.stop()
handleProgress(progress,sampleTimer)
if progress.stepCount != progress.step {
sampleTimer.start()
}
return true
}
_ = try saveImages(images, logNames: true) _ = try saveImages(images, logNames: true)
} }

Loading…
Cancel
Save