// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.

import Foundation
import CoreML
import Accelerate
import CoreGraphics

/// Schedulers compatible with StableDiffusionPipeline
public enum StableDiffusionScheduler {
    /// Scheduler that uses a pseudo-linear multi-step (PLMS) method
    case pndmScheduler
    /// Scheduler that uses a second order DPM-Solver++ algorithm
    case dpmSolverMultistepScheduler
}

/// A pipeline used to generate image samples from text input using stable diffusion
///
/// This implementation matches:
/// [Hugging Face Diffusers Pipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py)
@available(iOS 16.2, macOS 13.1, *)
public struct StableDiffusionPipeline: ResourceManaging {

    /// Model to generate embeddings for tokenized input text
    var textEncoder: TextEncoder

    /// Model used to predict noise residuals given an input, diffusion time step, and conditional embedding
    var unet: Unet

    /// Model used to generate final image from latent diffusion process
    var decoder: Decoder

    /// Optional model for checking safety of generated image
    var safetyChecker: SafetyChecker? = nil

    /// Reports whether this pipeline can perform safety checks
    public var canSafetyCheck: Bool {
        safetyChecker != nil
    }

    /// Option to reduce memory during image generation
    ///
    /// If true, the pipeline will lazily load TextEncoder, Unet, Decoder, and SafetyChecker
    /// when needed and aggressively unload their resources after
    ///
    /// This will increase latency in favor of reducing memory
    var reduceMemory: Bool = false

    /// Creates a pipeline using the specified models
    ///
    /// - Parameters:
    ///   - textEncoder: Model for encoding tokenized text
    ///   - unet: Model for noise prediction on latent samples
    ///   - decoder: Model for decoding latent sample to image
    ///   - safetyChecker: Optional model for checking safety of generated images
    ///   - reduceMemory: Option to enable reduced memory mode
    /// - Returns: Pipeline ready for image generation
    public init(
        textEncoder: TextEncoder,
        unet: Unet,
        decoder: Decoder,
        safetyChecker: SafetyChecker? = nil,
        reduceMemory: Bool = false
    ) {
        self.textEncoder = textEncoder
        self.unet = unet
        self.decoder = decoder
        self.safetyChecker = safetyChecker
        self.reduceMemory = reduceMemory
    }
    /// Load required resources for this pipeline
    ///
    /// If reduceMemory is true this will call prewarmResources instead
    /// and let the pipeline lazily load resources as needed
    public func loadResources() throws {
        if reduceMemory {
            try prewarmResources()
        } else {
            try textEncoder.loadResources()
            try unet.loadResources()
            try decoder.loadResources()
            try safetyChecker?.loadResources()
        }
    }

    /// Unload the underlying resources to free up memory
    public func unloadResources() {
        textEncoder.unloadResources()
        unet.unloadResources()
        decoder.unloadResources()
        safetyChecker?.unloadResources()
    }

    // Prewarm resources one at a time
    public func prewarmResources() throws {
        try textEncoder.prewarmResources()
        try unet.prewarmResources()
        try decoder.prewarmResources()
        try safetyChecker?.prewarmResources()
    }

    /// Text to image generation using stable diffusion
    ///
    /// - Parameters:
    ///   - prompt: Text prompt to guide sampling
    ///   - negativePrompt: Negative text prompt to guide sampling
    ///   - imageCount: Number of samples/images to generate for the input prompt
    ///   - stepCount: Number of inference steps to perform
    ///   - seed: Random seed used to generate the initial latent samples
    ///   - guidanceScale: Controls the influence of the text prompt on sampling process (0=random images)
    ///   - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
    ///   - scheduler: Scheduler to use for the denoising process
    ///   - progressHandler: Callback to perform after each step, stops on receiving false response
    /// - Returns: An array of `imageCount` optional images.
    ///            The images will be nil if safety checks were performed and found the result to be unsafe
    public func generateImages(
        prompt: String,
        negativePrompt: String = "",
        imageCount: Int = 1,
        stepCount: Int = 50,
        seed: UInt32 = 0,
        guidanceScale: Float = 7.5,
        disableSafety: Bool = false,
        scheduler: StableDiffusionScheduler = .pndmScheduler,
        progressHandler: (Progress) -> Bool = { _ in true }
    ) throws -> [CGImage?] {

        // Encode the input prompt and negative prompt
        let promptEmbedding = try textEncoder.encode(prompt)
        let negativePromptEmbedding = try textEncoder.encode(negativePrompt)

        if reduceMemory {
            textEncoder.unloadResources()
        }

        // Convert to Unet hidden state representation
        // Concatenate the prompt and negative prompt embeddings
        let concatEmbedding = MLShapedArray<Float32>(
            concatenating: [negativePromptEmbedding, promptEmbedding],
            alongAxis: 0
        )

        let hiddenStates = toHiddenStates(concatEmbedding)

        /// Setup schedulers
        let schedulers: [Scheduler] = (0..<imageCount).map { _ in
            switch scheduler {
            case .pndmScheduler: return PNDMScheduler(stepCount: stepCount)
            case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: stepCount)
            }
        }
        let stdev = schedulers[0].initNoiseSigma

        // Generate random latent samples from specified seed
        var latents = generateLatentSamples(imageCount, stdev: stdev, seed: seed)

        // De-noising loop
        for (step, t) in schedulers[0].timeSteps.enumerated() {

            // Expand the latents for classifier-free guidance
            // and input to the Unet noise prediction model
            let latentUnetInput = latents.map {
                MLShapedArray<Float32>(concatenating: [$0, $0], alongAxis: 0)
            }

            // Predict noise residuals from latent samples
            // and current time step conditioned on hidden states
            var noise = try unet.predictNoise(
                latents: latentUnetInput,
                timeStep: t,
                hiddenStates: hiddenStates
            )

            noise = performGuidance(noise, guidanceScale)

            // Have the scheduler compute the previous (t-1) latent
            // sample given the predicted noise and current sample
            for i in 0..<imageCount {
                latents[i] = schedulers[i].step(
                    output: noise[i],
                    timeStep: t,
                    sample: latents[i]
                )
            }

            // Report progress
            let progress = Progress(
                pipeline: self,
                prompt: prompt,
                step: step,
                stepCount: stepCount,
                currentLatentSamples: latents,
                isSafetyEnabled: canSafetyCheck && !disableSafety
            )
            if !progressHandler(progress) {
                // Stop requested by the handler
                return []
            }
        }

        if reduceMemory {
            unet.unloadResources()
        }

        // Decode the latent samples to images
        return try decodeToImages(latents, disableSafety: disableSafety)
    }
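    // Illustrative usage sketch (not part of the pipeline implementation): one way
    // `generateImages(prompt:...)` above might be driven from application code.
    // `myPipeline` is a hypothetical, already constructed and loaded pipeline instance.
    //
    //     let images = try myPipeline.generateImages(
    //         prompt: "a photo of an astronaut riding a horse on mars",
    //         imageCount: 1,
    //         stepCount: 25,
    //         seed: 42,
    //         scheduler: .dpmSolverMultistepScheduler
    //     ) { progress in
    //         print("Step \(progress.step) of \(progress.stepCount)")
    //         return true // returning false stops early and generateImages returns []
    //     }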
    func generateLatentSamples(_ count: Int, stdev: Float, seed: UInt32) -> [MLShapedArray<Float32>] {
        var sampleShape = unet.latentSampleShape
        sampleShape[0] = 1

        var random = NumPyRandomSource(seed: seed)
        let samples = (0..<count).map { _ in
            MLShapedArray<Float32>(
                converting: random.normalShapedArray(sampleShape, mean: 0.0, stdev: Double(stdev)))
        }
        return samples
    }

    func toHiddenStates(_ embedding: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
        // Unoptimized manual transpose [0, 2, None, 1]
        // e.g. From [2, 77, 768] to [2, 768, 1, 77]
        let fromShape = embedding.shape
        let stateShape = [fromShape[0], fromShape[2], 1, fromShape[1]]
        var states = MLShapedArray<Float32>(repeating: 0.0, shape: stateShape)
        for i0 in 0..<fromShape[0] {
            for i1 in 0..<fromShape[1] {
                for i2 in 0..<fromShape[2] {
                    states[scalarAt: i0, i2, 0, i1] = embedding[scalarAt: i0, i1, i2]
                }
            }
        }
        return states
    }

    func performGuidance(_ noise: [MLShapedArray<Float32>], _ guidanceScale: Float) -> [MLShapedArray<Float32>] {
        noise.map { performGuidance($0, guidanceScale) }
    }

    func performGuidance(_ noise: MLShapedArray<Float32>, _ guidanceScale: Float) -> MLShapedArray<Float32> {

        let blankNoiseScalars = noise[0].scalars
        let textNoiseScalars = noise[1].scalars

        var resultScalars = blankNoiseScalars

        for i in 0..<resultScalars.count {
            // unconditioned noise + guidance * (text conditioned noise - unconditioned noise)
            resultScalars[i] += guidanceScale * (textNoiseScalars[i] - blankNoiseScalars[i])
        }

        var shape = noise.shape
        shape[0] = 1
        return MLShapedArray<Float32>(scalars: resultScalars, shape: shape)
    }

    func decodeToImages(_ latents: [MLShapedArray<Float32>], disableSafety: Bool) throws -> [CGImage?] {

        let images = try decoder.decode(latents)
        if reduceMemory {
            decoder.unloadResources()
        }

        // If safety is disabled return what was decoded
        if disableSafety {
            return images
        }

        // If there is no safety checker return what was decoded
        guard let safetyChecker = safetyChecker else {
            return images
        }

        // Otherwise change images which are not safe to nil
        let safeImages = try images.map { image in
            try safetyChecker.isSafe(image) ? image : nil
        }
        if reduceMemory {
            safetyChecker.unloadResources()
        }

        return safeImages
    }
}

@available(iOS 16.2, macOS 13.1, *)
extension StableDiffusionPipeline {
    /// Sampling progress details
    public struct Progress {
        public let pipeline: StableDiffusionPipeline
        public let prompt: String
        public let step: Int
        public let stepCount: Int
        public let currentLatentSamples: [MLShapedArray<Float32>]
        public let isSafetyEnabled: Bool
        public var currentImages: [CGImage?] {
            try! pipeline.decodeToImages(
                currentLatentSamples,
                disableSafety: !isSafetyEnabled)
        }
    }
}
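// Illustrative construction sketch (not part of the original file): assuming `textEncoder`,
// `unet`, `decoder`, and optionally `safetyChecker` have already been created from their
// compiled Core ML models (their initializers live with their respective types, not here),
// the pipeline might be assembled and warmed up like this:
//
//     let pipeline = StableDiffusionPipeline(
//         textEncoder: textEncoder,
//         unet: unet,
//         decoder: decoder,
//         safetyChecker: safetyChecker,
//         reduceMemory: true // trades latency for a lower peak memory footprint
//     )
//     try pipeline.loadResources()
//     // pipeline.canSafetyCheck is true only when a safety checker was supplied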