diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift index 0cd2253..ea65472 100644 --- a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift +++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift @@ -33,9 +33,6 @@ public struct StableDiffusionPipeline: ResourceManaging { /// Optional model for checking safety of generated image var safetyChecker: SafetyChecker? = nil - /// Controls the influence of the text prompt on sampling process (0=random images) - var guidanceScale: Float = 7.5 - /// Reports whether this pipeline can perform safety checks public var canSafetyCheck: Bool { safetyChecker != nil @@ -56,20 +53,17 @@ public struct StableDiffusionPipeline: ResourceManaging { /// - unet: Model for noise prediction on latent samples /// - decoder: Model for decoding latent sample to image /// - safetyChecker: Optional model for checking safety of generated images - /// - guidanceScale: Influence of the text prompt on generation process /// - reduceMemory: Option to enable reduced memory mode /// - Returns: Pipeline ready for image generation public init(textEncoder: TextEncoder, unet: Unet, decoder: Decoder, safetyChecker: SafetyChecker? = nil, - guidanceScale: Float = 7.5, reduceMemory: Bool = false) { self.textEncoder = textEncoder self.unet = unet self.decoder = decoder self.safetyChecker = safetyChecker - self.guidanceScale = guidanceScale self.reduceMemory = reduceMemory } @@ -112,6 +106,7 @@ public struct StableDiffusionPipeline: ResourceManaging { /// - stepCount: Number of inference steps to perform /// - imageCount: Number of samples/images to generate for the input prompt /// - seed: Random seed which + /// - guidanceScale: Controls the influence of the text prompt on sampling process (0=random images) /// - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety` /// - progressHandler: Callback to perform after each step, stops on receiving false response /// - Returns: An array of `imageCount` optional images. @@ -122,6 +117,7 @@ public struct StableDiffusionPipeline: ResourceManaging { imageCount: Int = 1, stepCount: Int = 50, seed: UInt32 = 0, + guidanceScale: Float = 7.5, disableSafety: Bool = false, scheduler: StableDiffusionScheduler = .pndmScheduler, progressHandler: (Progress) -> Bool = { _ in true } @@ -173,7 +169,7 @@ public struct StableDiffusionPipeline: ResourceManaging { hiddenStates: hiddenStates ) - noise = performGuidance(noise) + noise = performGuidance(noise, guidanceScale) // Have the scheduler compute the previous (t-1) latent // sample given the predicted noise and current sample @@ -236,11 +232,11 @@ public struct StableDiffusionPipeline: ResourceManaging { return states } - func performGuidance(_ noise: [MLShapedArray]) -> [MLShapedArray] { - noise.map { performGuidance($0) } + func performGuidance(_ noise: [MLShapedArray], _ guidanceScale: Float) -> [MLShapedArray] { + noise.map { performGuidance($0, guidanceScale) } } - func performGuidance(_ noise: MLShapedArray) -> MLShapedArray { + func performGuidance(_ noise: MLShapedArray, _ guidanceScale: Float) -> MLShapedArray { let blankNoiseScalars = noise[0].scalars let textNoiseScalars = noise[1].scalars diff --git a/swift/StableDiffusionCLI/main.swift b/swift/StableDiffusionCLI/main.swift index 5cbe627..7471316 100644 --- a/swift/StableDiffusionCLI/main.swift +++ b/swift/StableDiffusionCLI/main.swift @@ -53,6 +53,9 @@ struct StableDiffusionSample: ParsableCommand { @Option(help: "Random seed") var seed: UInt32 = 93 + @Option(help: "Controls the influence of the text prompt on sampling process (0=random images)") + var guidanceScale: Float = 7.5 + @Option(help: "Compute units to load model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}") var computeUnits: ComputeUnits = .all @@ -92,6 +95,7 @@ struct StableDiffusionSample: ParsableCommand { imageCount: imageCount, stepCount: stepCount, seed: seed, + guidanceScale: guidanceScale, scheduler: scheduler.stableDiffusionScheduler ) { progress in sampleTimer.stop()