Adds Negative Prompts (#61)

* Synced to main branch and minimizes line changes

* Adds negative prompt argument to CLI

Co-authored-by: Wanaldino Antimonio <carlosmm31093@gmail.com>
pull/55/head
JustinMeans 2 years ago committed by GitHub
parent 4c00b32010
commit c90b705334
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -108,6 +108,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
///
/// - Parameters:
/// - prompt: Text prompt to guide sampling
/// - negativePrompt: Negative text prompt to guide sampling
/// - stepCount: Number of inference steps to perform
/// - imageCount: Number of samples/images to generate for the input prompt
/// - seed: Random seed which
@ -117,6 +118,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// The images will be nil if safety checks were performed and found the result to be un-safe
public func generateImages(
prompt: String,
negativePrompt: String = "",
imageCount: Int = 1,
stepCount: Int = 50,
seed: UInt32 = 0,
@ -125,17 +127,18 @@ public struct StableDiffusionPipeline: ResourceManaging {
progressHandler: (Progress) -> Bool = { _ in true }
) throws -> [CGImage?] {
// Encode the input prompt as well as a blank unconditioned input
// Encode the input prompt and negative prompt
let promptEmbedding = try textEncoder.encode(prompt)
let blankEmbedding = try textEncoder.encode("")
let negativePromptEmbedding = try textEncoder.encode(negativePrompt)
if reduceMemory {
textEncoder.unloadResources()
}
// Convert to Unet hidden state representation
// Concatenate the prompt and negative prompt embeddings
let concatEmbedding = MLShapedArray<Float32>(
concatenating: [blankEmbedding, promptEmbedding],
concatenating: [negativePromptEmbedding, promptEmbedding],
alongAxis: 0
)

@ -19,6 +19,9 @@ struct StableDiffusionSample: ParsableCommand {
@Argument(help: "Input string prompt")
var prompt: String
@Option(help: "Input string negative prompt")
var negativePrompt: String
@Option(
help: ArgumentHelp(
"Path to stable diffusion resources.",
@ -85,6 +88,7 @@ struct StableDiffusionSample: ParsableCommand {
let images = try pipeline.generateImages(
prompt: prompt,
negativePrompt: negativePrompt,
imageCount: imageCount,
stepCount: stepCount,
seed: seed,

Loading…
Cancel
Save