@ -108,6 +108,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
// /
// /
// / - P a r a m e t e r s :
// / - P a r a m e t e r s :
// / - p r o m p t : T e x t p r o m p t t o g u i d e s a m p l i n g
// / - p r o m p t : T e x t p r o m p t t o g u i d e s a m p l i n g
// / - n e g a t i v e P r o m p t : N e g a t i v e t e x t p r o m p t t o g u i d e s a m p l i n g
// / - s t e p C o u n t : N u m b e r o f i n f e r e n c e s t e p s t o p e r f o r m
// / - s t e p C o u n t : N u m b e r o f i n f e r e n c e s t e p s t o p e r f o r m
// / - i m a g e C o u n t : N u m b e r o f s a m p l e s / i m a g e s t o g e n e r a t e f o r t h e i n p u t p r o m p t
// / - i m a g e C o u n t : N u m b e r o f s a m p l e s / i m a g e s t o g e n e r a t e f o r t h e i n p u t p r o m p t
// / - s e e d : R a n d o m s e e d w h i c h
// / - s e e d : R a n d o m s e e d w h i c h
@ -117,6 +118,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
// / T h e i m a g e s w i l l b e n i l i f s a f e t y c h e c k s w e r e p e r f o r m e d a n d f o u n d t h e r e s u l t t o b e u n - s a f e
// / T h e i m a g e s w i l l b e n i l i f s a f e t y c h e c k s w e r e p e r f o r m e d a n d f o u n d t h e r e s u l t t o b e u n - s a f e
public func generateImages (
public func generateImages (
prompt : String ,
prompt : String ,
negativePrompt : String = " " ,
imageCount : Int = 1 ,
imageCount : Int = 1 ,
stepCount : Int = 50 ,
stepCount : Int = 50 ,
seed : UInt32 = 0 ,
seed : UInt32 = 0 ,
@ -125,17 +127,18 @@ public struct StableDiffusionPipeline: ResourceManaging {
progressHandler : ( Progress ) -> Bool = { _ in true }
progressHandler : ( Progress ) -> Bool = { _ in true }
) throws -> [ CGImage ? ] {
) throws -> [ CGImage ? ] {
// E n c o d e t h e i n p u t p r o m p t a s w e l l a s a b l a n k u n c o n d i t i o n e d i n p u t
// E n c o d e t h e i n p u t p r o m p t a nd n e g a t i v e p r o m p t
let promptEmbedding = try textEncoder . encode ( prompt )
let promptEmbedding = try textEncoder . encode ( prompt )
let blankEmbedding = try textEncoder . encode ( " " )
let negativePromptEmbedding = try textEncoder . encode ( negativePrompt )
if reduceMemory {
if reduceMemory {
textEncoder . unloadResources ( )
textEncoder . unloadResources ( )
}
}
// C o n v e r t t o U n e t h i d d e n s t a t e r e p r e s e n t a t i o n
// C o n v e r t t o U n e t h i d d e n s t a t e r e p r e s e n t a t i o n
// C o n c a t e n a t e t h e p r o m p t a n d n e g a t i v e p r o m p t e m b e d d i n g s
let concatEmbedding = MLShapedArray < Float32 > (
let concatEmbedding = MLShapedArray < Float32 > (
concatenating : [ blank Embedding, promptEmbedding ] ,
concatenating : [ negativePrompt Embedding, promptEmbedding ] ,
alongAxis : 0
alongAxis : 0
)
)