|
|
|
@ -108,6 +108,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
///
|
|
|
|
|
/// - Parameters:
|
|
|
|
|
/// - prompt: Text prompt to guide sampling
|
|
|
|
|
/// - negativePrompt: Negative text prompt to guide sampling
|
|
|
|
|
/// - stepCount: Number of inference steps to perform
|
|
|
|
|
/// - imageCount: Number of samples/images to generate for the input prompt
|
|
|
|
|
/// - seed: Random seed
|
|
|
|
@ -117,6 +118,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
/// The images will be nil if safety checks were performed and found the result to be unsafe
|
|
|
|
|
public func generateImages(
|
|
|
|
|
prompt: String,
|
|
|
|
|
negativePrompt: String = "",
|
|
|
|
|
imageCount: Int = 1,
|
|
|
|
|
stepCount: Int = 50,
|
|
|
|
|
seed: UInt32 = 0,
|
|
|
|
@ -125,17 +127,18 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
|
progressHandler: (Progress) -> Bool = { _ in true }
|
|
|
|
|
) throws -> [CGImage?] {
|
|
|
|
|
|
|
|
|
|
// Encode the input prompt as well as a blank unconditioned input
|
|
|
|
|
// Encode the input prompt and negative prompt
|
|
|
|
|
let promptEmbedding = try textEncoder.encode(prompt)
|
|
|
|
|
let blankEmbedding = try textEncoder.encode("")
|
|
|
|
|
let negativePromptEmbedding = try textEncoder.encode(negativePrompt)
|
|
|
|
|
|
|
|
|
|
if reduceMemory {
|
|
|
|
|
textEncoder.unloadResources()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Convert to Unet hidden state representation
|
|
|
|
|
// Concatenate the prompt and negative prompt embeddings
|
|
|
|
|
let concatEmbedding = MLShapedArray<Float32>(
|
|
|
|
|
concatenating: [blankEmbedding, promptEmbedding],
|
|
|
|
|
concatenating: [negativePromptEmbedding, promptEmbedding],
|
|
|
|
|
alongAxis: 0
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|