pull/116/head
Timothy Kautz 2 years ago
parent c199a4225d
commit c17c80ff8b

@ -4,9 +4,9 @@
import Foundation import Foundation
import CoreML import CoreML
@available(iOS 16.0, macOS 13.0, *) @available(iOS 16.0, macOS 13.1, *)
/// Encoder, currently supports image2image /// Encoder, currently supports image2image
public struct Encoder { public struct Encoder: ResourceManaging {
public enum Error: String, Swift.Error { public enum Error: String, Swift.Error {
case latentOutputNotValid case latentOutputNotValid
@ -14,14 +14,26 @@ public struct Encoder {
} }
/// VAE encoder model + post math and adding noise from scheduler /// VAE encoder model + post math and adding noise from scheduler
var model: MLModel var model: ManagedMLModel
/// Create decoder from Core ML model /// Create encoder from Core ML model
/// ///
/// - Parameters /// - Parameters:
/// - model: Core ML model for VAE decoder /// - url: Location of compiled VAE encoder Core ML model
public init(model: MLModel) { /// - configuration: configuration to be used when the model is loaded
self.model = model /// - Returns: An encoder that will lazily load its required resources when needed or requested
public init(modelAt url: URL, configuration: MLModelConfiguration) {
    // Hand the model location and load configuration to a ManagedMLModel wrapper;
    // the encoder keeps that wrapper rather than an MLModel instance.
    model = ManagedMLModel(modelAt: url, configuration: configuration)
}
/// Ensure the model has been loaded into memory.
///
/// Forwards directly to `model.loadResources()`.
///
/// - Throws: Any error raised by `model.loadResources()`.
public func loadResources() throws {
try model.loadResources()
}
/// Unload the underlying model to free up memory
public func unloadResources() {
model.unloadResources()
} }
/// Prediction queue /// Prediction queue
@ -46,17 +58,22 @@ public struct Encoder {
let dict: [String: Any] = [ let dict: [String: Any] = [
"sample": MLMultiArray(sample), "sample": MLMultiArray(sample),
"diagonalNoise": MLMultiArray(diagonalNoise), "diagonal_noise": MLMultiArray(diagonalNoise),
"noise": MLMultiArray(noise), "noise": MLMultiArray(noise),
"sqrtAlphasCumprod": MLMultiArray(sqrtAlphasCumprod), "sqrt_alphas_cumprod": MLMultiArray(sqrtAlphasCumprod),
"sqrtOneMinusAlphasCumprod": MLMultiArray(sqrtOneMinusAlphasCumprod), "sqrt_one_minus_alphas_cumprod": MLMultiArray(sqrtOneMinusAlphasCumprod),
] ]
let featureProvider = try MLDictionaryFeatureProvider(dictionary: dict) let featureProvider = try MLDictionaryFeatureProvider(dictionary: dict)
let batch = MLArrayBatchProvider(array: [featureProvider]) let batch = MLArrayBatchProvider(array: [featureProvider])
// Batch predict with model // Batch predict with model
let results = try queue.sync { try model.predictions(fromBatch: batch) }
let results = try queue.sync {
try model.perform { model in
try model.predictions(fromBatch: batch)
}
}
let batchLatents: [MLShapedArray<Float32>] = try (0..<results.count).compactMap { i in let batchLatents: [MLShapedArray<Float32>] = try (0..<results.count).compactMap { i in
let result = results.features(at: i) let result = results.features(at: i)

@ -79,10 +79,8 @@ public extension StableDiffusionPipeline {
// Optional Image Encoder // Optional Image Encoder
let encoder: Encoder? let encoder: Encoder?
if if FileManager.default.fileExists(atPath: urls.encoderURL.path) {
let encoderModel = try? MLModel(contentsOf: urls.encoderURL, configuration: config) encoder = Encoder(modelAt: urls.encoderURL, configuration: config)
{
encoder = Encoder(model: encoderModel)
} else { } else {
encoder = nil encoder = nil
} }

@ -25,6 +25,36 @@ public struct StableDiffusionPipeline: ResourceManaging {
case startingImageProvidedWithoutEncoder case startingImageProvidedWithoutEncoder
} }
/// The kind of generation a request resolves to.
public enum Mode {
    /// Generate from the prompt alone, or from a starting image;
    /// in-painting is not an available case yet.
    case textToImage, imageToImage
    // case inPainting
}
/// All inputs for one call to `generateImages(input:progressHandler:)`.
///
/// Only `prompt` is required; every other field has a default.
public struct SampleInput: Hashable {
    /// Prompt text that is encoded by the text encoder.
    public var prompt: String
    /// Negative prompt text that is encoded alongside `prompt`.
    public var negativePrompt: String = ""
    /// Optional image to start from; only used when `mode` is `.imageToImage`.
    public var startingImage: CGImage? = nil
    //public var maskImage: CGImage? = nil
    /// Image-to-image strength; values of `1.0` or more make `mode` report `.textToImage`.
    public var strength: Float = 1.0
    /// Number of images (and schedulers / latent samples) to produce.
    public var imageCount: Int = 1
    /// Step count handed to the scheduler.
    public var stepCount: Int = 50
    /// Seed used when generating the latent noise samples.
    public var seed: UInt32 = 0
    /// Scale passed to the guidance step.
    public var guidanceScale: Float = 7.5
    /// When `true`, safety checking is disabled for this request.
    public var disableSafety: Bool = false
    /// Which scheduler implementation to instantiate.
    public var schedulerType: StableDiffusionScheduler = .pndmScheduler

    /// Creates a request description.
    ///
    /// A `public struct` only receives an *internal* synthesized memberwise
    /// initializer, so without this explicit `public init` clients outside the
    /// module could not construct the value required by the public
    /// `generateImages(input:progressHandler:)`. Parameter order and defaults
    /// mirror the synthesized initializer, so existing in-module call sites
    /// keep compiling unchanged.
    public init(
        prompt: String,
        negativePrompt: String = "",
        startingImage: CGImage? = nil,
        strength: Float = 1.0,
        imageCount: Int = 1,
        stepCount: Int = 50,
        seed: UInt32 = 0,
        guidanceScale: Float = 7.5,
        disableSafety: Bool = false,
        schedulerType: StableDiffusionScheduler = .pndmScheduler
    ) {
        self.prompt = prompt
        self.negativePrompt = negativePrompt
        self.startingImage = startingImage
        self.strength = strength
        self.imageCount = imageCount
        self.stepCount = stepCount
        self.seed = seed
        self.guidanceScale = guidanceScale
        self.disableSafety = disableSafety
        self.schedulerType = schedulerType
    }

    /// The generation mode implied by the current field values.
    ///
    /// `.imageToImage` requires both a starting image and a strength below
    /// `1.0`; every other combination is `.textToImage`.
    public var mode: Mode {
        guard startingImage != nil else {
            return .textToImage
        }
        // A strength of 1.0 or more is reported as plain text-to-image even
        // though a starting image is present.
        guard strength < 1.0 else {
            return .textToImage
        }
        return .imageToImage
    }
}
/// Model to generate embeddings for tokenized input text /// Model to generate embeddings for tokenized input text
var textEncoder: TextEncoder var textEncoder: TextEncoder
@ -133,10 +163,27 @@ public struct StableDiffusionPipeline: ResourceManaging {
scheduler schedulerType: StableDiffusionScheduler = .pndmScheduler, scheduler schedulerType: StableDiffusionScheduler = .pndmScheduler,
progressHandler: (Progress) -> Bool = { _ in true } progressHandler: (Progress) -> Bool = { _ in true }
) throws -> [CGImage?] { ) throws -> [CGImage?] {
try generateImages(input: SampleInput(
prompt: prompt,
negativePrompt: negativePrompt,
startingImage: startingImage,
strength: strength,
imageCount: imageCount,
stepCount: stepCount,
seed: seed,
guidanceScale: guidanceScale,
disableSafety: disableSafety,
schedulerType: schedulerType), progressHandler: progressHandler)
}
public func generateImages(
input: SampleInput,
progressHandler: (Progress) -> Bool = { _ in true }
) throws -> [CGImage?] {
// Encode the input prompt and negative prompt // Encode the input prompt and negative prompt
let promptEmbedding = try textEncoder.encode(prompt) let promptEmbedding = try textEncoder.encode(input.prompt)
let negativePromptEmbedding = try textEncoder.encode(negativePrompt) let negativePromptEmbedding = try textEncoder.encode(input.negativePrompt)
if reduceMemory { if reduceMemory {
textEncoder.unloadResources() textEncoder.unloadResources()
@ -152,10 +199,10 @@ public struct StableDiffusionPipeline: ResourceManaging {
let hiddenStates = toHiddenStates(concatEmbedding) let hiddenStates = toHiddenStates(concatEmbedding)
/// Setup schedulers /// Setup schedulers
let scheduler: [Scheduler] = (0..<imageCount).map { _ in let scheduler: [Scheduler] = (0..<input.imageCount).map { _ in
switch schedulerType { switch input.schedulerType {
case .pndmScheduler: return PNDMScheduler(stepCount: stepCount) case .pndmScheduler: return PNDMScheduler(stepCount: input.stepCount)
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: stepCount) case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: input.stepCount)
} }
} }
let stdev = scheduler[0].initNoiseSigma let stdev = scheduler[0].initNoiseSigma
@ -164,24 +211,27 @@ public struct StableDiffusionPipeline: ResourceManaging {
var latents: [MLShapedArray<Float32>] var latents: [MLShapedArray<Float32>]
let timestepStrength: Float? let timestepStrength: Float?
if let startingImage { if
timestepStrength = strength let startingImage = input.startingImage,
input.mode == .imageToImage
{
timestepStrength = input.strength
guard let encoder else { guard let encoder else {
throw Error.startingImageProvidedWithoutEncoder throw Error.startingImageProvidedWithoutEncoder
} }
let noiseTuples = generateImage2ImageLatentSamples(imageCount, stdev: 1, seed: seed) let noiseTuples = generateImage2ImageLatentSamples(input.imageCount, stdev: 1, seed: input.seed)
latents = try noiseTuples.map({ latents = try noiseTuples.map({
try encoder.encode( try encoder.encode(
image: startingImage, image: startingImage,
diagonalNoise: $0.diagonal, diagonalNoise: $0.diagonal,
noise: $0.latentNoise, noise: $0.latentNoise,
alphasCumprodStep: scheduler[0].calculateAlphasCumprod(strength: strength)) alphasCumprodStep: scheduler[0].calculateAlphasCumprod(strength: input.strength))
}) })
} else { } else {
timestepStrength = nil timestepStrength = nil
// Generate random latent samples from specified seed // Generate random latent samples from specified seed
latents = generateLatentSamples(imageCount, stdev: stdev, seed: seed) latents = generateLatentSamples(input.imageCount, stdev: stdev, seed: input.seed)
} }
// De-noising loop // De-noising loop
@ -202,11 +252,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
hiddenStates: hiddenStates hiddenStates: hiddenStates
) )
noise = performGuidance(noise, guidanceScale) noise = performGuidance(noise, input.guidanceScale)
// Have the scheduler compute the previous (t-1) latent // Have the scheduler compute the previous (t-1) latent
// sample given the predicted noise and current sample // sample given the predicted noise and current sample
for i in 0..<imageCount { for i in 0..<input.imageCount {
latents[i] = scheduler[i].step( latents[i] = scheduler[i].step(
output: noise[i], output: noise[i],
timeStep: t, timeStep: t,
@ -217,11 +267,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
// Report progress // Report progress
let progress = Progress( let progress = Progress(
pipeline: self, pipeline: self,
prompt: prompt, prompt: input.prompt,
step: step, step: step,
stepCount: timeSteps.count, stepCount: timeSteps.count,
currentLatentSamples: latents, currentLatentSamples: latents,
isSafetyEnabled: canSafetyCheck && !disableSafety isSafetyEnabled: canSafetyCheck && !input.disableSafety
) )
if !progressHandler(progress) { if !progressHandler(progress) {
// Stop if requested by handler // Stop if requested by handler
@ -234,7 +284,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
} }
// Decode the latent samples to images // Decode the latent samples to images
return try decodeToImages(latents, disableSafety: disableSafety) return try decodeToImages(latents, disableSafety: input.disableSafety)
} }
func generateLatentSamples(_ count: Int, stdev: Float, seed: UInt32) -> [MLShapedArray<Float32>] { func generateLatentSamples(_ count: Int, stdev: Float, seed: UInt32) -> [MLShapedArray<Float32>] {

@ -34,7 +34,7 @@ struct StableDiffusionSample: ParsableCommand {
var resourcePath: String = "./" var resourcePath: String = "./"
@Option(help: "Path to starting image.") @Option(help: "Path to starting image.")
var image: String = "none" var image: String? = nil
@Option(help: "Strength for image2image.") @Option(help: "Strength for image2image.")
var strength: Float = 0.5 var strength: Float = 0.5
@ -92,7 +92,7 @@ struct StableDiffusionSample: ParsableCommand {
try pipeline.loadResources() try pipeline.loadResources()
let startingImage: CGImage? let startingImage: CGImage?
if image != "none" { if let image {
let imageURL = URL(filePath: image) let imageURL = URL(filePath: image)
do { do {
let imageData = try Data(contentsOf: imageURL) let imageData = try Data(contentsOf: imageURL)

Loading…
Cancel
Save