From dacda3e156b79a3f0f375b274c411c3e9b0cdfc7 Mon Sep 17 00:00:00 2001 From: Wanaldino Antimonio Date: Mon, 5 Dec 2022 16:56:30 +0100 Subject: [PATCH] Move guidanceScale as generation parameter --- .../pipeline/StableDiffusionPipeline.swift | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift index bc888cc..33dee2f 100644 --- a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift +++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift @@ -24,9 +24,6 @@ public struct StableDiffusionPipeline { /// Optional model for checking safety of generated image var safetyChecker: SafetyChecker? = nil - /// Controls the influence of the text prompt on sampling process (0=random images) - var guidanceScale: Float = 7.5 - /// Reports whether this pipeline can perform safety checks public var canSafetyCheck: Bool { safetyChecker != nil @@ -41,16 +38,16 @@ public struct StableDiffusionPipeline { /// - safetyChecker: Optional model for checking safety of generated images /// - guidanceScale: Influence of the text prompt on generation process /// - Returns: Pipeline ready for image generation - public init(textEncoder: TextEncoder, - unet: Unet, - decoder: Decoder, - safetyChecker: SafetyChecker? = nil, - guidanceScale: Float = 7.5) { + public init( + textEncoder: TextEncoder, + unet: Unet, + decoder: Decoder, + safetyChecker: SafetyChecker? = nil + ) { self.textEncoder = textEncoder self.unet = unet self.decoder = decoder self.safetyChecker = safetyChecker - self.guidanceScale = guidanceScale } /// Text to image generation using stable diffusion @@ -60,6 +57,7 @@ public struct StableDiffusionPipeline { /// - stepCount: Number of inference steps to perform /// - imageCount: Number of samples/images to generate for the input prompt /// - seed: Random seed which + /// - guidanceScale: Controls the influence of the text prompt on sampling process (0=random images) /// - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety` /// - progressHandler: Callback to perform after each step, stops on receiving false response /// - Returns: An array of `imageCount` optional images. @@ -69,6 +67,7 @@ public struct StableDiffusionPipeline { imageCount: Int = 1, stepCount: Int = 50, seed: Int = 0, + guidanceScale: Float = 7.5, disableSafety: Bool = false, progressHandler: (Progress) -> Bool = { _ in true } ) throws -> [CGImage?] { @@ -109,7 +108,7 @@ public struct StableDiffusionPipeline { hiddenStates: hiddenStates ) - noise = performGuidance(noise) + noise = performGuidance(noise, guidanceScale) // Have the scheduler compute the previous (t-1) latent // sample given the predicted noise and current sample @@ -168,11 +167,11 @@ public struct StableDiffusionPipeline { return states } - func performGuidance(_ noise: [MLShapedArray]) -> [MLShapedArray] { - noise.map { performGuidance($0) } + func performGuidance(_ noise: [MLShapedArray], _ guidanceScale: Float) -> [MLShapedArray] { + noise.map { performGuidance($0, guidanceScale) } } - func performGuidance(_ noise: MLShapedArray) -> MLShapedArray { + func performGuidance(_ noise: MLShapedArray, _ guidanceScale: Float) -> MLShapedArray { let blankNoiseScalars = noise[0].scalars let textNoiseScalars = noise[1].scalars