Move guidanceScale as generation parameter (#46)

* Move guidanceScale as generation parameter

* Added guidanceScale in CLI

* Reverted identation change
pull/88/head
Wanaldino Antimonio 1 year ago committed by GitHub
parent 877ccb9666
commit e07c4d00c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -33,9 +33,6 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// Optional model for checking safety of generated image /// Optional model for checking safety of generated image
var safetyChecker: SafetyChecker? = nil var safetyChecker: SafetyChecker? = nil
/// Controls the influence of the text prompt on sampling process (0=random images)
var guidanceScale: Float = 7.5
/// Reports whether this pipeline can perform safety checks /// Reports whether this pipeline can perform safety checks
public var canSafetyCheck: Bool { public var canSafetyCheck: Bool {
safetyChecker != nil safetyChecker != nil
@ -56,20 +53,17 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// - unet: Model for noise prediction on latent samples /// - unet: Model for noise prediction on latent samples
/// - decoder: Model for decoding latent sample to image /// - decoder: Model for decoding latent sample to image
/// - safetyChecker: Optional model for checking safety of generated images /// - safetyChecker: Optional model for checking safety of generated images
/// - guidanceScale: Influence of the text prompt on generation process
/// - reduceMemory: Option to enable reduced memory mode /// - reduceMemory: Option to enable reduced memory mode
/// - Returns: Pipeline ready for image generation /// - Returns: Pipeline ready for image generation
public init(textEncoder: TextEncoder, public init(textEncoder: TextEncoder,
unet: Unet, unet: Unet,
decoder: Decoder, decoder: Decoder,
safetyChecker: SafetyChecker? = nil, safetyChecker: SafetyChecker? = nil,
guidanceScale: Float = 7.5,
reduceMemory: Bool = false) { reduceMemory: Bool = false) {
self.textEncoder = textEncoder self.textEncoder = textEncoder
self.unet = unet self.unet = unet
self.decoder = decoder self.decoder = decoder
self.safetyChecker = safetyChecker self.safetyChecker = safetyChecker
self.guidanceScale = guidanceScale
self.reduceMemory = reduceMemory self.reduceMemory = reduceMemory
} }
@ -112,6 +106,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// - stepCount: Number of inference steps to perform /// - stepCount: Number of inference steps to perform
/// - imageCount: Number of samples/images to generate for the input prompt /// - imageCount: Number of samples/images to generate for the input prompt
/// - seed: Random seed which /// - seed: Random seed which
/// - guidanceScale: Controls the influence of the text prompt on sampling process (0=random images)
/// - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety` /// - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
/// - progressHandler: Callback to perform after each step, stops on receiving false response /// - progressHandler: Callback to perform after each step, stops on receiving false response
/// - Returns: An array of `imageCount` optional images. /// - Returns: An array of `imageCount` optional images.
@ -122,6 +117,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
imageCount: Int = 1, imageCount: Int = 1,
stepCount: Int = 50, stepCount: Int = 50,
seed: UInt32 = 0, seed: UInt32 = 0,
guidanceScale: Float = 7.5,
disableSafety: Bool = false, disableSafety: Bool = false,
scheduler: StableDiffusionScheduler = .pndmScheduler, scheduler: StableDiffusionScheduler = .pndmScheduler,
progressHandler: (Progress) -> Bool = { _ in true } progressHandler: (Progress) -> Bool = { _ in true }
@ -173,7 +169,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
hiddenStates: hiddenStates hiddenStates: hiddenStates
) )
noise = performGuidance(noise) noise = performGuidance(noise, guidanceScale)
// Have the scheduler compute the previous (t-1) latent // Have the scheduler compute the previous (t-1) latent
// sample given the predicted noise and current sample // sample given the predicted noise and current sample
@ -236,11 +232,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
return states return states
} }
func performGuidance(_ noise: [MLShapedArray<Float32>]) -> [MLShapedArray<Float32>] { func performGuidance(_ noise: [MLShapedArray<Float32>], _ guidanceScale: Float) -> [MLShapedArray<Float32>] {
noise.map { performGuidance($0) } noise.map { performGuidance($0, guidanceScale) }
} }
func performGuidance(_ noise: MLShapedArray<Float32>) -> MLShapedArray<Float32> { func performGuidance(_ noise: MLShapedArray<Float32>, _ guidanceScale: Float) -> MLShapedArray<Float32> {
let blankNoiseScalars = noise[0].scalars let blankNoiseScalars = noise[0].scalars
let textNoiseScalars = noise[1].scalars let textNoiseScalars = noise[1].scalars

@ -53,6 +53,9 @@ struct StableDiffusionSample: ParsableCommand {
@Option(help: "Random seed") @Option(help: "Random seed")
var seed: UInt32 = 93 var seed: UInt32 = 93
@Option(help: "Controls the influence of the text prompt on sampling process (0=random images)")
var guidanceScale: Float = 7.5
@Option(help: "Compute units to load model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}") @Option(help: "Compute units to load model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}")
var computeUnits: ComputeUnits = .all var computeUnits: ComputeUnits = .all
@ -92,6 +95,7 @@ struct StableDiffusionSample: ParsableCommand {
imageCount: imageCount, imageCount: imageCount,
stepCount: stepCount, stepCount: stepCount,
seed: seed, seed: seed,
guidanceScale: guidanceScale,
scheduler: scheduler.stableDiffusionScheduler scheduler: scheduler.stableDiffusionScheduler
) { progress in ) { progress in
sampleTimer.stop() sampleTimer.stop()

Loading…
Cancel
Save