Changed seed type into UInt32 (#47)

pull/61/head
Wanaldino Antimonio 2 years ago committed by GitHub
parent 37ff2a430f
commit 4c00b32010
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -119,7 +119,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
prompt: String, prompt: String,
imageCount: Int = 1, imageCount: Int = 1,
stepCount: Int = 50, stepCount: Int = 50,
seed: Int = 0, seed: UInt32 = 0,
disableSafety: Bool = false, disableSafety: Bool = false,
scheduler: StableDiffusionScheduler = .pndmScheduler, scheduler: StableDiffusionScheduler = .pndmScheduler,
progressHandler: (Progress) -> Bool = { _ in true } progressHandler: (Progress) -> Bool = { _ in true }
@ -205,11 +205,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
return try decodeToImages(latents, disableSafety: disableSafety) return try decodeToImages(latents, disableSafety: disableSafety)
} }
func generateLatentSamples(_ count: Int, stdev: Float, seed: Int) -> [MLShapedArray<Float32>] { func generateLatentSamples(_ count: Int, stdev: Float, seed: UInt32) -> [MLShapedArray<Float32>] {
var sampleShape = unet.latentSampleShape var sampleShape = unet.latentSampleShape
sampleShape[0] = 1 sampleShape[0] = 1
var random = NumPyRandomSource(seed: UInt32(truncatingIfNeeded: seed)) var random = NumPyRandomSource(seed: seed)
let samples = (0..<count).map { _ in let samples = (0..<count).map { _ in
MLShapedArray<Float32>( MLShapedArray<Float32>(
converting: random.normalShapedArray(sampleShape, mean: 0.0, stdev: Double(stdev))) converting: random.normalShapedArray(sampleShape, mean: 0.0, stdev: Double(stdev)))

@ -48,7 +48,7 @@ struct StableDiffusionSample: ParsableCommand {
var outputPath: String = "./" var outputPath: String = "./"
@Option(help: "Random seed") @Option(help: "Random seed")
var seed: Int = 93 var seed: UInt32 = 93
@Option(help: "Compute units to load model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}") @Option(help: "Compute units to load model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}")
var computeUnits: ComputeUnits = .all var computeUnits: ComputeUnits = .all

Loading…
Cancel
Save