From e2924a8a8d4bb10f0a86c1e2c2a7703e30f74f4e Mon Sep 17 00:00:00 2001 From: JustinMeans <46542161+JustinMeans@users.noreply.github.com> Date: Tue, 20 Dec 2022 17:54:51 -0700 Subject: [PATCH] Synced to main branch and minimizes line changes --- .../pipeline/StableDiffusionPipeline.swift | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift index 8464f8a..0cd2253 100644 --- a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift +++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift @@ -108,6 +108,7 @@ public struct StableDiffusionPipeline: ResourceManaging { /// /// - Parameters: /// - prompt: Text prompt to guide sampling + /// - negativePrompt: Negative text prompt to guide sampling /// - stepCount: Number of inference steps to perform /// - imageCount: Number of samples/images to generate for the input prompt /// - seed: Random seed which @@ -117,6 +118,7 @@ public struct StableDiffusionPipeline: ResourceManaging { /// The images will be nil if safety checks were performed and found the result to be un-safe public func generateImages( prompt: String, + negativePrompt: String = "", imageCount: Int = 1, stepCount: Int = 50, seed: UInt32 = 0, @@ -125,17 +127,18 @@ public struct StableDiffusionPipeline: ResourceManaging { progressHandler: (Progress) -> Bool = { _ in true } ) throws -> [CGImage?] 
{ - // Encode the input prompt as well as a blank unconditioned input + // Encode the input prompt and negative prompt let promptEmbedding = try textEncoder.encode(prompt) - let blankEmbedding = try textEncoder.encode("") + let negativePromptEmbedding = try textEncoder.encode(negativePrompt) if reduceMemory { textEncoder.unloadResources() } // Convert to Unet hidden state representation + // Concatenate the negative prompt and prompt embeddings (negative prompt first) let concatEmbedding = MLShapedArray( - concatenating: [blankEmbedding, promptEmbedding], + concatenating: [negativePromptEmbedding, promptEmbedding], alongAxis: 0 )