Synced to main branch and minimizes line changes

pull/61/head
JustinMeans 3 years ago committed by GitHub
parent 4c00b32010
commit e2924a8a8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -108,6 +108,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// ///
/// - Parameters: /// - Parameters:
/// - prompt: Text prompt to guide sampling /// - prompt: Text prompt to guide sampling
/// - negativePrompt: Negative text prompt to guide sampling
/// - stepCount: Number of inference steps to perform /// - stepCount: Number of inference steps to perform
/// - imageCount: Number of samples/images to generate for the input prompt /// - imageCount: Number of samples/images to generate for the input prompt
/// - seed: Random seed which /// - seed: Random seed which
@ -117,6 +118,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// The images will be nil if safety checks were performed and found the result to be un-safe /// The images will be nil if safety checks were performed and found the result to be un-safe
public func generateImages( public func generateImages(
prompt: String, prompt: String,
negativePrompt: String = "",
imageCount: Int = 1, imageCount: Int = 1,
stepCount: Int = 50, stepCount: Int = 50,
seed: UInt32 = 0, seed: UInt32 = 0,
@ -125,17 +127,18 @@ public struct StableDiffusionPipeline: ResourceManaging {
progressHandler: (Progress) -> Bool = { _ in true } progressHandler: (Progress) -> Bool = { _ in true }
) throws -> [CGImage?] { ) throws -> [CGImage?] {
// Encode the input prompt as well as a blank unconditioned input // Encode the input prompt and negative prompt
let promptEmbedding = try textEncoder.encode(prompt) let promptEmbedding = try textEncoder.encode(prompt)
let blankEmbedding = try textEncoder.encode("") let negativePromptEmbedding = try textEncoder.encode(negativePrompt)
if reduceMemory { if reduceMemory {
textEncoder.unloadResources() textEncoder.unloadResources()
} }
// Convert to Unet hidden state representation // Convert to Unet hidden state representation
// Concatenate the prompt and negative prompt embeddings
let concatEmbedding = MLShapedArray<Float32>( let concatEmbedding = MLShapedArray<Float32>(
concatenating: [blankEmbedding, promptEmbedding], concatenating: [negativePromptEmbedding, promptEmbedding],
alongAxis: 0 alongAxis: 0
) )

Loading…
Cancel
Save