Added negative promp

pull/76/head
Wanaldino Antimonio 3 years ago
parent 66dde8da13
commit f082f5d3e7

@ -117,6 +117,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// The images will be nil if safety checks were performed and found the result to be un-safe /// The images will be nil if safety checks were performed and found the result to be un-safe
public func generateImages( public func generateImages(
prompt: String, prompt: String,
negativePrompt: String = "",
imageCount: Int = 1, imageCount: Int = 1,
stepCount: Int = 50, stepCount: Int = 50,
seed: Int = 0, seed: Int = 0,
@ -127,7 +128,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
// Encode the input prompt as well as a blank unconditioned input // Encode the input prompt as well as a blank unconditioned input
let promptEmbedding = try textEncoder.encode(prompt) let promptEmbedding = try textEncoder.encode(prompt)
let blankEmbedding = try textEncoder.encode("") let negativePromptEmbedding = try textEncoder.encode(negativePrompt)
if reduceMemory { if reduceMemory {
textEncoder.unloadResources() textEncoder.unloadResources()
@ -135,7 +136,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
// Convert to Unet hidden state representation // Convert to Unet hidden state representation
let concatEmbedding = MLShapedArray<Float32>( let concatEmbedding = MLShapedArray<Float32>(
concatenating: [blankEmbedding, promptEmbedding], concatenating: [negativePromptEmbedding, promptEmbedding],
alongAxis: 0 alongAxis: 0
) )
@ -186,6 +187,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
let progress = Progress( let progress = Progress(
pipeline: self, pipeline: self,
prompt: prompt, prompt: prompt,
negativePromp: negativePrompt,
step: step, step: step,
stepCount: stepCount, stepCount: stepCount,
currentLatentSamples: latents, currentLatentSamples: latents,
@ -292,6 +294,7 @@ extension StableDiffusionPipeline {
public struct Progress { public struct Progress {
public let pipeline: StableDiffusionPipeline public let pipeline: StableDiffusionPipeline
public let prompt: String public let prompt: String
public let negativePromp: String
public let step: Int public let step: Int
public let stepCount: Int public let stepCount: Int
public let currentLatentSamples: [MLShapedArray<Float32>] public let currentLatentSamples: [MLShapedArray<Float32>]

@ -19,6 +19,9 @@ struct StableDiffusionSample: ParsableCommand {
@Argument(help: "Input string prompt") @Argument(help: "Input string prompt")
var prompt: String var prompt: String
@Option(help: "Input string negative prompt")
var negativePrompt: String
@Option( @Option(
help: ArgumentHelp( help: ArgumentHelp(
"Path to stable diffusion resources.", "Path to stable diffusion resources.",
@ -85,6 +88,7 @@ struct StableDiffusionSample: ParsableCommand {
let images = try pipeline.generateImages( let images = try pipeline.generateImages(
prompt: prompt, prompt: prompt,
negativePrompt: negativePrompt,
imageCount: imageCount, imageCount: imageCount,
stepCount: stepCount, stepCount: stepCount,
seed: seed, seed: seed,

Loading…
Cancel
Save