Synced to main branch and minimizes line changes

pull/61/head
JustinMeans 1 year ago committed by GitHub
parent 4c00b32010
commit e2924a8a8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -108,6 +108,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
///
/// - Parameters:
/// - prompt: Text prompt to guide sampling
/// - negativePrompt: Negative text prompt to guide sampling
/// - stepCount: Number of inference steps to perform
/// - imageCount: Number of samples/images to generate for the input prompt
/// - seed: Random seed which
@ -117,6 +118,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// The images will be nil if safety checks were performed and found the result to be un-safe
public func generateImages(
prompt: String,
negativePrompt: String = "",
imageCount: Int = 1,
stepCount: Int = 50,
seed: UInt32 = 0,
@ -125,17 +127,18 @@ public struct StableDiffusionPipeline: ResourceManaging {
progressHandler: (Progress) -> Bool = { _ in true }
) throws -> [CGImage?] {
// Encode the input prompt as well as a blank unconditioned input
// Encode the input prompt and negative prompt
let promptEmbedding = try textEncoder.encode(prompt)
let blankEmbedding = try textEncoder.encode("")
let negativePromptEmbedding = try textEncoder.encode(negativePrompt)
if reduceMemory {
textEncoder.unloadResources()
}
// Convert to Unet hidden state representation
// Concatenate the prompt and negative prompt embeddings
let concatEmbedding = MLShapedArray<Float32>(
concatenating: [blankEmbedding, promptEmbedding],
concatenating: [negativePromptEmbedding, promptEmbedding],
alongAxis: 0
)

Loading…
Cancel
Save