diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift index 0cd2253..d918751 100644 --- a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift +++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift @@ -33,9 +33,6 @@ public struct StableDiffusionPipeline: ResourceManaging { /// Optional model for checking safety of generated image var safetyChecker: SafetyChecker? = nil - /// Controls the influence of the text prompt on sampling process (0=random images) - var guidanceScale: Float = 7.5 - /// Reports whether this pipeline can perform safety checks public var canSafetyCheck: Bool { safetyChecker != nil @@ -56,20 +53,19 @@ public struct StableDiffusionPipeline: ResourceManaging { /// - unet: Model for noise prediction on latent samples /// - decoder: Model for decoding latent sample to image /// - safetyChecker: Optional model for checking safety of generated images - /// - guidanceScale: Influence of the text prompt on generation process /// - reduceMemory: Option to enable reduced memory mode /// - Returns: Pipeline ready for image generation - public init(textEncoder: TextEncoder, - unet: Unet, - decoder: Decoder, - safetyChecker: SafetyChecker? = nil, - guidanceScale: Float = 7.5, - reduceMemory: Bool = false) { + public init( + textEncoder: TextEncoder, + unet: Unet, + decoder: Decoder, + safetyChecker: SafetyChecker? = nil, + reduceMemory: Bool = false + ) { self.textEncoder = textEncoder self.unet = unet self.decoder = decoder self.safetyChecker = safetyChecker - self.guidanceScale = guidanceScale self.reduceMemory = reduceMemory } @@ -112,6 +108,7 @@ public struct StableDiffusionPipeline: ResourceManaging { /// - stepCount: Number of inference steps to perform /// - imageCount: Number of samples/images to generate for the input prompt /// - seed: Random seed which + /// - guidanceScale: Controls the influence of the text prompt on sampling process (0=random images) /// - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety` /// - progressHandler: Callback to perform after each step, stops on receiving false response /// - Returns: An array of `imageCount` optional images. @@ -122,6 +119,7 @@ public struct StableDiffusionPipeline: ResourceManaging { imageCount: Int = 1, stepCount: Int = 50, seed: UInt32 = 0, + guidanceScale: Float = 7.5, disableSafety: Bool = false, scheduler: StableDiffusionScheduler = .pndmScheduler, progressHandler: (Progress) -> Bool = { _ in true } @@ -173,7 +171,7 @@ public struct StableDiffusionPipeline: ResourceManaging { hiddenStates: hiddenStates ) - noise = performGuidance(noise) + noise = performGuidance(noise, guidanceScale) // Have the scheduler compute the previous (t-1) latent // sample given the predicted noise and current sample @@ -236,11 +234,11 @@ public struct StableDiffusionPipeline: ResourceManaging { return states } - func performGuidance(_ noise: [MLShapedArray]) -> [MLShapedArray] { - noise.map { performGuidance($0) } + func performGuidance(_ noise: [MLShapedArray], _ guidanceScale: Float) -> [MLShapedArray] { + noise.map { performGuidance($0, guidanceScale) } } - func performGuidance(_ noise: MLShapedArray) -> MLShapedArray { + func performGuidance(_ noise: MLShapedArray, _ guidanceScale: Float) -> MLShapedArray { let blankNoiseScalars = noise[0].scalars let textNoiseScalars = noise[1].scalars