pr comments

pull/116/head
Timothy Kautz 1 year ago
parent d6647a4551
commit 088bdc1dc4

@ -1,9 +1,5 @@
//
// File.swift
//
//
// Created by Timothy Kautz on 1/30/23.
//
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
import Foundation
import CoreGraphics
@ -19,7 +15,7 @@ extension StableDiffusionPipeline {
}
/// Image generation configuration
public struct SampleInput: Hashable {
public struct Configuration: Hashable {
/// Text prompt to guide sampling
public var prompt: String
@ -54,27 +50,9 @@ extension StableDiffusionPipeline {
}
public init(
prompt: String,
negativePrompt: String = "",
startingImage: CGImage? = nil,
strength: Float = 1.0,
imageCount: Int = 1,
stepCount: Int = 50,
seed: UInt32 = 0,
guidanceScale: Float = 7.5,
disableSafety: Bool = false,
schedulerType: StableDiffusionScheduler = .pndmScheduler
prompt: String
) {
self.prompt = prompt
self.negativePrompt = negativePrompt
self.startingImage = startingImage
self.strength = strength
self.imageCount = imageCount
self.stepCount = stepCount
self.seed = seed
self.guidanceScale = guidanceScale
self.disableSafety = disableSafety
self.schedulerType = schedulerType
}
}

@ -114,13 +114,13 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// - Returns: An array of `imageCount` optional images.
/// The images will be nil if safety checks were performed and found the result to be unsafe
public func generateImages(
input: SampleInput,
configuration config: Configuration,
progressHandler: (Progress) -> Bool = { _ in true }
) throws -> [CGImage?] {
// Encode the input prompt and negative prompt
let promptEmbedding = try textEncoder.encode(input.prompt)
let negativePromptEmbedding = try textEncoder.encode(input.negativePrompt)
let promptEmbedding = try textEncoder.encode(config.prompt)
let negativePromptEmbedding = try textEncoder.encode(config.negativePrompt)
if reduceMemory {
textEncoder.unloadResources()
@ -136,10 +136,10 @@ public struct StableDiffusionPipeline: ResourceManaging {
let hiddenStates = toHiddenStates(concatEmbedding)
/// Set up schedulers
let scheduler: [Scheduler] = (0..<input.imageCount).map { _ in
switch input.schedulerType {
case .pndmScheduler: return PNDMScheduler(stepCount: input.stepCount)
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: input.stepCount)
let scheduler: [Scheduler] = (0..<config.imageCount).map { _ in
switch config.schedulerType {
case .pndmScheduler: return PNDMScheduler(stepCount: config.stepCount)
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: config.stepCount)
}
}
let stdev = scheduler[0].initNoiseSigma
@ -149,26 +149,26 @@ public struct StableDiffusionPipeline: ResourceManaging {
let timestepStrength: Float?
if
let startingImage = input.startingImage,
input.mode == .imageToImage
let startingImage = config.startingImage,
config.mode == .imageToImage
{
timestepStrength = input.strength
timestepStrength = config.strength
guard let encoder else {
throw Error.startingImageProvidedWithoutEncoder
}
let noiseTuples = generateImage2ImageLatentSamples(input.imageCount, stdev: 1, seed: input.seed)
let noiseTuples = generateImage2ImageLatentSamples(config.imageCount, stdev: 1, seed: config.seed)
latents = try noiseTuples.map({
try encoder.encode(
image: startingImage,
diagonalNoise: $0.diagonal,
noise: $0.latentNoise,
alphasCumprodStep: scheduler[0].calculateAlphasCumprod(strength: input.strength))
alphasCumprodStep: scheduler[0].calculateAlphasCumprod(strength: config.strength))
})
} else {
timestepStrength = nil
// Generate random latent samples from specified seed
latents = generateLatentSamples(input.imageCount, stdev: stdev, seed: input.seed)
latents = generateLatentSamples(config.imageCount, stdev: stdev, seed: config.seed)
}
// De-noising loop
@ -189,11 +189,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
hiddenStates: hiddenStates
)
noise = performGuidance(noise, input.guidanceScale)
noise = performGuidance(noise, config.guidanceScale)
// Have the scheduler compute the previous (t-1) latent
// sample given the predicted noise and current sample
for i in 0..<input.imageCount {
for i in 0..<config.imageCount {
latents[i] = scheduler[i].step(
output: noise[i],
timeStep: t,
@ -204,11 +204,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
// Report progress
let progress = Progress(
pipeline: self,
prompt: input.prompt,
prompt: config.prompt,
step: step,
stepCount: timeSteps.count,
currentLatentSamples: latents,
isSafetyEnabled: canSafetyCheck && !input.disableSafety
isSafetyEnabled: canSafetyCheck && !config.disableSafety
)
if !progressHandler(progress) {
// Stop if requested by handler
@ -221,7 +221,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
}
// Decode the latent samples to images
return try decodeToImages(latents, disableSafety: input.disableSafety)
return try decodeToImages(latents, disableSafety: config.disableSafety)
}
func generateLatentSamples(_ count: Int, stdev: Float, seed: UInt32) -> [MLShapedArray<Float32>] {

@ -118,25 +118,27 @@ struct StableDiffusionSample: ParsableCommand {
let sampleTimer = SampleTimer()
sampleTimer.start()
var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: prompt)
pipelineConfig.negativePrompt = negativePrompt
pipelineConfig.startingImage = startingImage
pipelineConfig.strength = strength
pipelineConfig.imageCount = imageCount
pipelineConfig.stepCount = stepCount
pipelineConfig.seed = seed
pipelineConfig.guidanceScale = guidanceScale
pipelineConfig.schedulerType = scheduler.stableDiffusionScheduler
let images = try pipeline.generateImages(
input: StableDiffusionPipeline.SampleInput(
prompt: prompt,
negativePrompt: negativePrompt,
startingImage: startingImage,
strength: strength,
imageCount: imageCount,
stepCount: stepCount,
seed: seed,
guidanceScale: guidanceScale,
schedulerType: scheduler.stableDiffusionScheduler)
) { progress in
sampleTimer.stop()
handleProgress(progress,sampleTimer)
if progress.stepCount != progress.step {
sampleTimer.start()
}
return true
}
configuration: pipelineConfig,
progressHandler: { progress in
sampleTimer.stop()
handleProgress(progress,sampleTimer)
if progress.stepCount != progress.step {
sampleTimer.start()
}
return true
})
_ = try saveImages(images, logNames: true)
}

Loading…
Cancel
Save