diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline+SampleInput.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline+SampleInput.swift
new file mode 100644
index 0000000..0efe0ec
--- /dev/null
+++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline+SampleInput.swift
@@ -0,0 +1,82 @@
+//
+//  StableDiffusionPipeline+SampleInput.swift
+//
+//
+//  Created by Timothy Kautz on 1/30/23.
+//
+
+import Foundation
+import CoreGraphics
+
+@available(iOS 16.2, macOS 13.1, *)
+extension StableDiffusionPipeline {
+
+    /// Type of processing that will be performed to generate an image
+    public enum Mode {
+        case textToImage
+        case imageToImage
+        // case inPainting
+    }
+
+    /// Image generation configuration passed to `generateImages(input:progressHandler:)`
+    public struct SampleInput: Hashable {
+
+        /// Text prompt to guide sampling
+        public var prompt: String
+        /// Negative text prompt to guide sampling
+        public var negativePrompt: String = ""
+        /// Starting image for image2image or in-painting
+        public var startingImage: CGImage? = nil
+        //public var maskImage: CGImage? = nil
+        public var strength: Float = 1.0 // Values below 1.0 select image-to-image when `startingImage` is set (see `mode`)
+        /// Number of images to generate
+        public var imageCount: Int = 1
+        /// Number of inference steps to perform
+        public var stepCount: Int = 50
+        /// Random seed with which to start generation
+        public var seed: UInt32 = 0
+        /// Controls the influence of the text prompt on sampling process (0=random images)
+        public var guidanceScale: Float = 7.5
+        /// Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
+        public var disableSafety: Bool = false
+        /// The type of scheduler to use.
+        public var schedulerType: StableDiffusionScheduler = .pndmScheduler
+
+        /// Mode implied by this configuration: `.imageToImage` only when `startingImage` is set and `strength < 1.0`, otherwise `.textToImage`
+        public var mode: Mode {
+            guard startingImage != nil else {
+                return .textToImage
+            }
+            guard strength < 1.0 else {
+                return .textToImage
+            }
+            return .imageToImage
+        }
+
+        public init( // Explicit because Swift's synthesized memberwise initializer is never public; defaults mirror the property defaults
+            prompt: String,
+            negativePrompt: String = "",
+            startingImage: CGImage? = nil,
+            strength: Float = 1.0,
+            imageCount: Int = 1,
+            stepCount: Int = 50,
+            seed: UInt32 = 0,
+            guidanceScale: Float = 7.5,
+            disableSafety: Bool = false,
+            schedulerType: StableDiffusionScheduler = .pndmScheduler
+        ) {
+            self.prompt = prompt
+            self.negativePrompt = negativePrompt
+            self.startingImage = startingImage
+            self.strength = strength
+            self.imageCount = imageCount
+            self.stepCount = stepCount
+            self.seed = seed
+            self.guidanceScale = guidanceScale
+            self.disableSafety = disableSafety
+            self.schedulerType = schedulerType
+        }
+
+    }
+
+}
diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift
index 0157820..8a9c350 100644
--- a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift
+++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift
@@ -25,36 +25,6 @@ public struct StableDiffusionPipeline: ResourceManaging {
         case startingImageProvidedWithoutEncoder
     }
 
-    public enum Mode {
-        case textToImage
-        case imageToImage
-        // case inPainting
-    }
-
-    public struct SampleInput: Hashable {
-        public var prompt: String
-        public var negativePrompt: String = ""
-        public var startingImage: CGImage? = nil
-        //public var maskImage: CGImage? = nil
-        public var strength: Float = 1.0
-        public var imageCount: Int = 1
-        public var stepCount: Int = 50
-        public var seed: UInt32 = 0
-        public var guidanceScale: Float = 7.5
-        public var disableSafety: Bool = false
-        public var schedulerType: StableDiffusionScheduler = .pndmScheduler
-
-        public var mode: Mode {
-            guard startingImage != nil else {
-                return .textToImage
-            }
-            guard strength < 1.0 else {
-                return .textToImage
-            }
-            return .imageToImage
-        }
-    }
-
     /// Model to generate embeddings for tokenized input text
     var textEncoder: TextEncoder
 
@@ -137,45 +107,12 @@ public struct StableDiffusionPipeline: ResourceManaging {
         try safetyChecker?.prewarmResources()
     }
 
-    /// Text to image generation using stable diffusion
-    ///
+    /// Image generation using stable diffusion, configured by `input`
     /// - Parameters:
-    ///   - prompt: Text prompt to guide sampling
-    ///   - negativePrompt: Negative text prompt to guide sampling
-    ///   - stepCount: Number of inference steps to perform
-    ///   - imageCount: Number of samples/images to generate for the input prompt
-    ///   - seed: Random seed which
-    ///   - guidanceScale: Controls the influence of the text prompt on sampling process (0=random images)
     ///   - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
     ///   - progressHandler: Callback to perform after each step, stops on receiving false response
     /// - Returns: An array of `imageCount` optional images.
     ///            The images will be nil if safety checks were performed and found the result to be un-safe
-    public func generateImages(
-        prompt: String,
-        negativePrompt: String = "",
-        startingImage: CGImage? = nil,
-        strength: Float = 1.0,
-        imageCount: Int = 1,
-        stepCount: Int = 50,
-        seed: UInt32 = 0,
-        guidanceScale: Float = 7.5,
-        disableSafety: Bool = false,
-        scheduler schedulerType: StableDiffusionScheduler = .pndmScheduler,
-        progressHandler: (Progress) -> Bool = { _ in true }
-    ) throws -> [CGImage?] {
-        try generateImages(input: SampleInput(
-            prompt: prompt,
-            negativePrompt: negativePrompt,
-            startingImage: startingImage,
-            strength: strength,
-            imageCount: imageCount,
-            stepCount: stepCount,
-            seed: seed,
-            guidanceScale: guidanceScale,
-            disableSafety: disableSafety,
-            schedulerType: schedulerType), progressHandler: progressHandler)
-    }
-
     public func generateImages(
         input: SampleInput,
         progressHandler: (Progress) -> Bool = { _ in true }
diff --git a/swift/StableDiffusionCLI/main.swift b/swift/StableDiffusionCLI/main.swift
index c0ec63d..338891c 100644
--- a/swift/StableDiffusionCLI/main.swift
+++ b/swift/StableDiffusionCLI/main.swift
@@ -119,15 +119,16 @@ struct StableDiffusionSample: ParsableCommand {
         sampleTimer.start()
 
         let images = try pipeline.generateImages(
-            prompt: prompt,
-            negativePrompt: negativePrompt,
-            startingImage: startingImage,
-            strength: strength,
-            imageCount: imageCount,
-            stepCount: stepCount,
-            seed: seed,
-            guidanceScale: guidanceScale,
-            scheduler: scheduler.stableDiffusionScheduler
+            input: StableDiffusionPipeline.SampleInput(
+                prompt: prompt,
+                negativePrompt: negativePrompt,
+                startingImage: startingImage,
+                strength: strength,
+                imageCount: imageCount,
+                stepCount: stepCount,
+                seed: seed,
+                guidanceScale: guidanceScale,
+                schedulerType: scheduler.stableDiffusionScheduler)
         ) { progress in
             sampleTimer.stop()
             handleProgress(progress,sampleTimer)