|
|
@ -33,9 +33,6 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
/// Optional model for checking safety of generated image
|
|
|
|
/// Optional model for checking safety of generated image
|
|
|
|
var safetyChecker: SafetyChecker? = nil
|
|
|
|
var safetyChecker: SafetyChecker? = nil
|
|
|
|
|
|
|
|
|
|
|
|
/// Controls the influence of the text prompt on sampling process (0=random images)
|
|
|
|
|
|
|
|
var guidanceScale: Float = 7.5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// Reports whether this pipeline can perform safety checks
|
|
|
|
/// Reports whether this pipeline can perform safety checks
|
|
|
|
public var canSafetyCheck: Bool {
|
|
|
|
public var canSafetyCheck: Bool {
|
|
|
|
safetyChecker != nil
|
|
|
|
safetyChecker != nil
|
|
|
@ -56,20 +53,17 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
/// - unet: Model for noise prediction on latent samples
|
|
|
|
/// - unet: Model for noise prediction on latent samples
|
|
|
|
/// - decoder: Model for decoding latent sample to image
|
|
|
|
/// - decoder: Model for decoding latent sample to image
|
|
|
|
/// - safetyChecker: Optional model for checking safety of generated images
|
|
|
|
/// - safetyChecker: Optional model for checking safety of generated images
|
|
|
|
/// - guidanceScale: Influence of the text prompt on generation process
|
|
|
|
|
|
|
|
/// - reduceMemory: Option to enable reduced memory mode
|
|
|
|
/// - reduceMemory: Option to enable reduced memory mode
|
|
|
|
/// - Returns: Pipeline ready for image generation
|
|
|
|
/// - Returns: Pipeline ready for image generation
|
|
|
|
public init(textEncoder: TextEncoder,
|
|
|
|
public init(textEncoder: TextEncoder,
|
|
|
|
unet: Unet,
|
|
|
|
unet: Unet,
|
|
|
|
decoder: Decoder,
|
|
|
|
decoder: Decoder,
|
|
|
|
safetyChecker: SafetyChecker? = nil,
|
|
|
|
safetyChecker: SafetyChecker? = nil,
|
|
|
|
guidanceScale: Float = 7.5,
|
|
|
|
|
|
|
|
reduceMemory: Bool = false) {
|
|
|
|
reduceMemory: Bool = false) {
|
|
|
|
self.textEncoder = textEncoder
|
|
|
|
self.textEncoder = textEncoder
|
|
|
|
self.unet = unet
|
|
|
|
self.unet = unet
|
|
|
|
self.decoder = decoder
|
|
|
|
self.decoder = decoder
|
|
|
|
self.safetyChecker = safetyChecker
|
|
|
|
self.safetyChecker = safetyChecker
|
|
|
|
self.guidanceScale = guidanceScale
|
|
|
|
|
|
|
|
self.reduceMemory = reduceMemory
|
|
|
|
self.reduceMemory = reduceMemory
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -112,6 +106,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
/// - stepCount: Number of inference steps to perform
|
|
|
|
/// - stepCount: Number of inference steps to perform
|
|
|
|
/// - imageCount: Number of samples/images to generate for the input prompt
|
|
|
|
/// - imageCount: Number of samples/images to generate for the input prompt
|
|
|
|
/// - seed: Random seed which
|
|
|
|
/// - seed: Random seed which
|
|
|
|
|
|
|
|
/// - guidanceScale: Controls the influence of the text prompt on sampling process (0=random images)
|
|
|
|
/// - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
|
|
|
|
/// - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
|
|
|
|
/// - progressHandler: Callback to perform after each step, stops on receiving false response
|
|
|
|
/// - progressHandler: Callback to perform after each step, stops on receiving false response
|
|
|
|
/// - Returns: An array of `imageCount` optional images.
|
|
|
|
/// - Returns: An array of `imageCount` optional images.
|
|
|
@ -122,6 +117,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
imageCount: Int = 1,
|
|
|
|
imageCount: Int = 1,
|
|
|
|
stepCount: Int = 50,
|
|
|
|
stepCount: Int = 50,
|
|
|
|
seed: UInt32 = 0,
|
|
|
|
seed: UInt32 = 0,
|
|
|
|
|
|
|
|
guidanceScale: Float = 7.5,
|
|
|
|
disableSafety: Bool = false,
|
|
|
|
disableSafety: Bool = false,
|
|
|
|
scheduler: StableDiffusionScheduler = .pndmScheduler,
|
|
|
|
scheduler: StableDiffusionScheduler = .pndmScheduler,
|
|
|
|
progressHandler: (Progress) -> Bool = { _ in true }
|
|
|
|
progressHandler: (Progress) -> Bool = { _ in true }
|
|
|
@ -173,7 +169,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
hiddenStates: hiddenStates
|
|
|
|
hiddenStates: hiddenStates
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
noise = performGuidance(noise)
|
|
|
|
noise = performGuidance(noise, guidanceScale)
|
|
|
|
|
|
|
|
|
|
|
|
// Have the scheduler compute the previous (t-1) latent
|
|
|
|
// Have the scheduler compute the previous (t-1) latent
|
|
|
|
// sample given the predicted noise and current sample
|
|
|
|
// sample given the predicted noise and current sample
|
|
|
@ -236,11 +232,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
|
|
|
|
return states
|
|
|
|
return states
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func performGuidance(_ noise: [MLShapedArray<Float32>]) -> [MLShapedArray<Float32>] {
|
|
|
|
func performGuidance(_ noise: [MLShapedArray<Float32>], _ guidanceScale: Float) -> [MLShapedArray<Float32>] {
|
|
|
|
noise.map { performGuidance($0) }
|
|
|
|
noise.map { performGuidance($0, guidanceScale) }
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func performGuidance(_ noise: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
|
|
|
|
func performGuidance(_ noise: MLShapedArray<Float32>, _ guidanceScale: Float) -> MLShapedArray<Float32> {
|
|
|
|
|
|
|
|
|
|
|
|
let blankNoiseScalars = noise[0].scalars
|
|
|
|
let blankNoiseScalars = noise[0].scalars
|
|
|
|
let textNoiseScalars = noise[1].scalars
|
|
|
|
let textNoiseScalars = noise[1].scalars
|
|
|
|