|
|
|
@ -117,6 +117,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
/// The images will be nil if safety checks were performed and found the result to be un-safe
|
|
|
|
|
public func generateImages(
|
|
|
|
|
prompt: String,
|
|
|
|
|
negativePrompt: String = "",
|
|
|
|
|
imageCount: Int = 1,
|
|
|
|
|
stepCount: Int = 50,
|
|
|
|
|
seed: Int = 0,
|
|
|
|
@ -127,7 +128,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
|
|
|
|
|
// Encode the input prompt as well as a blank unconditioned input
|
|
|
|
|
let promptEmbedding = try textEncoder.encode(prompt)
|
|
|
|
|
let blankEmbedding = try textEncoder.encode("")
|
|
|
|
|
let negativePromptEmbedding = try textEncoder.encode(negativePrompt)
|
|
|
|
|
|
|
|
|
|
if reduceMemory {
|
|
|
|
|
textEncoder.unloadResources()
|
|
|
|
@ -135,7 +136,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
|
|
|
|
|
// Convert to Unet hidden state representation
|
|
|
|
|
let concatEmbedding = MLShapedArray<Float32>(
|
|
|
|
|
concatenating: [blankEmbedding, promptEmbedding],
|
|
|
|
|
concatenating: [negativePromptEmbedding, promptEmbedding],
|
|
|
|
|
alongAxis: 0
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -186,6 +187,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
let progress = Progress(
|
|
|
|
|
pipeline: self,
|
|
|
|
|
prompt: prompt,
|
|
|
|
|
negativePromp: negativePrompt,
|
|
|
|
|
step: step,
|
|
|
|
|
stepCount: stepCount,
|
|
|
|
|
currentLatentSamples: latents,
|
|
|
|
@ -292,6 +294,7 @@ extension StableDiffusionPipeline {
|
|
|
|
|
public struct Progress {
|
|
|
|
|
public let pipeline: StableDiffusionPipeline
|
|
|
|
|
public let prompt: String
|
|
|
|
|
public let negativePromp: String
|
|
|
|
|
public let step: Int
|
|
|
|
|
public let stepCount: Int
|
|
|
|
|
public let currentLatentSamples: [MLShapedArray<Float32>]
|
|
|
|
|