Move guidanceScale as generation parameter

pull/40/head
Wanaldino Antimonio 3 years ago
parent 583cc04ece
commit dacda3e156

@ -24,9 +24,6 @@ public struct StableDiffusionPipeline {
/// Optional model for checking safety of generated image
var safetyChecker: SafetyChecker? = nil
/// Controls the influence of the text prompt on sampling process (0=random images)
var guidanceScale: Float = 7.5
/// Reports whether this pipeline can perform safety checks
public var canSafetyCheck: Bool {
safetyChecker != nil
@ -41,16 +38,16 @@ public struct StableDiffusionPipeline {
/// - safetyChecker: Optional model for checking safety of generated images
/// - guidanceScale: Influence of the text prompt on generation process
/// - Returns: Pipeline ready for image generation
public init(textEncoder: TextEncoder,
unet: Unet,
decoder: Decoder,
safetyChecker: SafetyChecker? = nil,
guidanceScale: Float = 7.5) {
public init(
textEncoder: TextEncoder,
unet: Unet,
decoder: Decoder,
safetyChecker: SafetyChecker? = nil
) {
self.textEncoder = textEncoder
self.unet = unet
self.decoder = decoder
self.safetyChecker = safetyChecker
self.guidanceScale = guidanceScale
}
/// Text to image generation using stable diffusion
@ -60,6 +57,7 @@ public struct StableDiffusionPipeline {
/// - stepCount: Number of inference steps to perform
/// - imageCount: Number of samples/images to generate for the input prompt
/// - seed: Random seed which
/// - guidanceScale: Controls the influence of the text prompt on sampling process (0=random images)
/// - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
/// - progressHandler: Callback to perform after each step, stops on receiving false response
/// - Returns: An array of `imageCount` optional images.
@ -69,6 +67,7 @@ public struct StableDiffusionPipeline {
imageCount: Int = 1,
stepCount: Int = 50,
seed: Int = 0,
guidanceScale: Float = 7.5,
disableSafety: Bool = false,
progressHandler: (Progress) -> Bool = { _ in true }
) throws -> [CGImage?] {
@ -109,7 +108,7 @@ public struct StableDiffusionPipeline {
hiddenStates: hiddenStates
)
noise = performGuidance(noise)
noise = performGuidance(noise, guidanceScale)
// Have the scheduler compute the previous (t-1) latent
// sample given the predicted noise and current sample
@ -168,11 +167,11 @@ public struct StableDiffusionPipeline {
return states
}
func performGuidance(_ noise: [MLShapedArray<Float32>]) -> [MLShapedArray<Float32>] {
noise.map { performGuidance($0) }
func performGuidance(_ noise: [MLShapedArray<Float32>], _ guidanceScale: Float) -> [MLShapedArray<Float32>] {
noise.map { performGuidance($0, guidanceScale) }
}
func performGuidance(_ noise: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
func performGuidance(_ noise: MLShapedArray<Float32>, _ guidanceScale: Float) -> MLShapedArray<Float32> {
let blankNoiseScalars = noise[0].scalars
let textNoiseScalars = noise[1].scalars

Loading…
Cancel
Save