From f082f5d3e7a7a01a88e7070d996a95ad7ade5ccc Mon Sep 17 00:00:00 2001 From: Wanaldino Antimonio Date: Mon, 19 Dec 2022 20:28:21 +0100 Subject: [PATCH] Added negative promp --- .../StableDiffusion/pipeline/StableDiffusionPipeline.swift | 7 +++++-- swift/StableDiffusionCLI/main.swift | 4 ++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift index eddbda4..5b6989e 100644 --- a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift +++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift @@ -117,6 +117,7 @@ public struct StableDiffusionPipeline: ResourceManaging { /// The images will be nil if safety checks were performed and found the result to be un-safe public func generateImages( prompt: String, + negativePrompt: String = "", imageCount: Int = 1, stepCount: Int = 50, seed: Int = 0, @@ -127,7 +128,7 @@ public struct StableDiffusionPipeline: ResourceManaging { // Encode the input prompt as well as a blank unconditioned input let promptEmbedding = try textEncoder.encode(prompt) - let blankEmbedding = try textEncoder.encode("") + let negativePromptEmbedding = try textEncoder.encode(negativePrompt) if reduceMemory { textEncoder.unloadResources() @@ -135,7 +136,7 @@ public struct StableDiffusionPipeline: ResourceManaging { // Convert to Unet hidden state representation let concatEmbedding = MLShapedArray( - concatenating: [blankEmbedding, promptEmbedding], + concatenating: [negativePromptEmbedding, promptEmbedding], alongAxis: 0 ) @@ -186,6 +187,7 @@ public struct StableDiffusionPipeline: ResourceManaging { let progress = Progress( pipeline: self, prompt: prompt, + negativePromp: negativePrompt, step: step, stepCount: stepCount, currentLatentSamples: latents, @@ -292,6 +294,7 @@ extension StableDiffusionPipeline { public struct Progress { public let pipeline: StableDiffusionPipeline public let prompt: String + public let negativePromp: String public let step: Int public let stepCount: Int public let currentLatentSamples: [MLShapedArray] diff --git a/swift/StableDiffusionCLI/main.swift b/swift/StableDiffusionCLI/main.swift index 4da8e62..333853c 100644 --- a/swift/StableDiffusionCLI/main.swift +++ b/swift/StableDiffusionCLI/main.swift @@ -19,6 +19,9 @@ struct StableDiffusionSample: ParsableCommand { @Argument(help: "Input string prompt") var prompt: String + @Option(help: "Input string negative prompt") + var negativePrompt: String + @Option( help: ArgumentHelp( "Path to stable diffusion resources.", @@ -85,6 +88,7 @@ struct StableDiffusionSample: ParsableCommand { let images = try pipeline.generateImages( prompt: prompt, + negativePrompt: negativePrompt, imageCount: imageCount, stepCount: stepCount, seed: seed,