diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline+SampleInput.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline+SampleInput.swift index 0efe0ec..8933557 100644 --- a/swift/StableDiffusion/pipeline/StableDiffusionPipeline+SampleInput.swift +++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline+SampleInput.swift @@ -1,9 +1,5 @@ -// -// File.swift -// -// -// Created by Timothy Kautz on 1/30/23. -// +// For licensing see accompanying LICENSE.md file. +// Copyright (C) 2022 Apple Inc. All Rights Reserved. import Foundation import CoreGraphics @@ -19,7 +15,7 @@ extension StableDiffusionPipeline { } /// Image generation configuration - public struct SampleInput: Hashable { + public struct Configuration: Hashable { /// Text prompt to guide sampling public var prompt: String @@ -54,27 +50,9 @@ extension StableDiffusionPipeline { } public init( - prompt: String, - negativePrompt: String = "", - startingImage: CGImage? = nil, - strength: Float = 1.0, - imageCount: Int = 1, - stepCount: Int = 50, - seed: UInt32 = 0, - guidanceScale: Float = 7.5, - disableSafety: Bool = false, - schedulerType: StableDiffusionScheduler = .pndmScheduler + prompt: String ) { self.prompt = prompt - self.negativePrompt = negativePrompt - self.startingImage = startingImage - self.strength = strength - self.imageCount = imageCount - self.stepCount = stepCount - self.seed = seed - self.guidanceScale = guidanceScale - self.disableSafety = disableSafety - self.schedulerType = schedulerType } } diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift index 8a9c350..f279247 100644 --- a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift +++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift @@ -114,13 +114,13 @@ public struct StableDiffusionPipeline: ResourceManaging { /// - Returns: An array of `imageCount` optional images. /// The images will be nil if safety checks were performed and found the result to be un-safe public func generateImages( - input: SampleInput, + configuration config: Configuration, progressHandler: (Progress) -> Bool = { _ in true } ) throws -> [CGImage?] { // Encode the input prompt and negative prompt - let promptEmbedding = try textEncoder.encode(input.prompt) - let negativePromptEmbedding = try textEncoder.encode(input.negativePrompt) + let promptEmbedding = try textEncoder.encode(config.prompt) + let negativePromptEmbedding = try textEncoder.encode(config.negativePrompt) if reduceMemory { textEncoder.unloadResources() @@ -136,10 +136,10 @@ public struct StableDiffusionPipeline: ResourceManaging { let hiddenStates = toHiddenStates(concatEmbedding) /// Setup schedulers - let scheduler: [Scheduler] = (0.. [MLShapedArray] { diff --git a/swift/StableDiffusionCLI/main.swift b/swift/StableDiffusionCLI/main.swift index 338891c..6b46046 100644 --- a/swift/StableDiffusionCLI/main.swift +++ b/swift/StableDiffusionCLI/main.swift @@ -118,25 +118,27 @@ struct StableDiffusionSample: ParsableCommand { let sampleTimer = SampleTimer() sampleTimer.start() + var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: prompt) + + pipelineConfig.negativePrompt = negativePrompt + pipelineConfig.startingImage = startingImage + pipelineConfig.strength = strength + pipelineConfig.imageCount = imageCount + pipelineConfig.stepCount = stepCount + pipelineConfig.seed = seed + pipelineConfig.guidanceScale = guidanceScale + pipelineConfig.schedulerType = scheduler.stableDiffusionScheduler + let images = try pipeline.generateImages( - input: StableDiffusionPipeline.SampleInput( - prompt: prompt, - negativePrompt: negativePrompt, - startingImage: startingImage, - strength: strength, - imageCount: imageCount, - stepCount: stepCount, - seed: seed, - guidanceScale: guidanceScale, - schedulerType: scheduler.stableDiffusionScheduler) - ) { progress in - sampleTimer.stop() - handleProgress(progress,sampleTimer) - if progress.stepCount != progress.step { - sampleTimer.start() - } - return true - } + configuration: pipelineConfig, + progressHandler: { progress in + sampleTimer.stop() + handleProgress(progress,sampleTimer) + if progress.stepCount != progress.step { + sampleTimer.start() + } + return true + }) _ = try saveImages(images, logNames: true) }