Adds Negative Prompts (#61)

* Synced to main branch and minimizes line changes

* Adds negative prompt argument to CLI

Co-authored-by: Wanaldino Antimonio <carlosmm31093@gmail.com>
pull/55/head
JustinMeans 2 years ago committed by GitHub
parent 4c00b32010
commit c90b705334
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -108,6 +108,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// ///
/// - Parameters: /// - Parameters:
/// - prompt: Text prompt to guide sampling /// - prompt: Text prompt to guide sampling
/// - negativePrompt: Negative text prompt to guide sampling
/// - stepCount: Number of inference steps to perform /// - stepCount: Number of inference steps to perform
/// - imageCount: Number of samples/images to generate for the input prompt /// - imageCount: Number of samples/images to generate for the input prompt
/// - seed: Random seed which /// - seed: Random seed which
@ -117,6 +118,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// The images will be nil if safety checks were performed and found the result to be un-safe /// The images will be nil if safety checks were performed and found the result to be un-safe
public func generateImages( public func generateImages(
prompt: String, prompt: String,
negativePrompt: String = "",
imageCount: Int = 1, imageCount: Int = 1,
stepCount: Int = 50, stepCount: Int = 50,
seed: UInt32 = 0, seed: UInt32 = 0,
@ -125,17 +127,18 @@ public struct StableDiffusionPipeline: ResourceManaging {
progressHandler: (Progress) -> Bool = { _ in true } progressHandler: (Progress) -> Bool = { _ in true }
) throws -> [CGImage?] { ) throws -> [CGImage?] {
// Encode the input prompt as well as a blank unconditioned input // Encode the input prompt and negative prompt
let promptEmbedding = try textEncoder.encode(prompt) let promptEmbedding = try textEncoder.encode(prompt)
let blankEmbedding = try textEncoder.encode("") let negativePromptEmbedding = try textEncoder.encode(negativePrompt)
if reduceMemory { if reduceMemory {
textEncoder.unloadResources() textEncoder.unloadResources()
} }
// Convert to Unet hidden state representation // Convert to Unet hidden state representation
// Concatenate the prompt and negative prompt embeddings
let concatEmbedding = MLShapedArray<Float32>( let concatEmbedding = MLShapedArray<Float32>(
concatenating: [blankEmbedding, promptEmbedding], concatenating: [negativePromptEmbedding, promptEmbedding],
alongAxis: 0 alongAxis: 0
) )

@ -19,6 +19,9 @@ struct StableDiffusionSample: ParsableCommand {
@Argument(help: "Input string prompt") @Argument(help: "Input string prompt")
var prompt: String var prompt: String
@Option(help: "Input string negative prompt")
var negativePrompt: String
@Option( @Option(
help: ArgumentHelp( help: ArgumentHelp(
"Path to stable diffusion resources.", "Path to stable diffusion resources.",
@ -85,6 +88,7 @@ struct StableDiffusionSample: ParsableCommand {
let images = try pipeline.generateImages( let images = try pipeline.generateImages(
prompt: prompt, prompt: prompt,
negativePrompt: negativePrompt,
imageCount: imageCount, imageCount: imageCount,
stepCount: stepCount, stepCount: stepCount,
seed: seed, seed: seed,

Loading…
Cancel
Save