Image2image - swift (#116)

* Image2Image Encoder

* Scheduler and pipeline

* fix scheduler

* cli

* remove CLI comment

* disable DPM multistep solver with image2image

* clamp initial timestep

* Store timesteps in reverse order for consistency.

* Report actual number of steps.

* uint32

* PR comments

* remove old initializer

* PR comments

* change name and add error handling

also fix hard-coded 512

* Add fix for JPEGs

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Authored by Timothy Kautz, committed via GitHub
parent f0fc326477
commit fa7bbdc255

@@ -0,0 +1,29 @@
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
import Foundation
public struct AlphasCumprodCalculation {
public var sqrtAlphasCumprod: Float
public var sqrtOneMinusAlphasCumprod: Float
public init(
sqrtAlphasCumprod: Float,
sqrtOneMinusAlphasCumprod: Float
) {
self.sqrtAlphasCumprod = sqrtAlphasCumprod
self.sqrtOneMinusAlphasCumprod = sqrtOneMinusAlphasCumprod
}
public init(
alphasCumprod: [Float],
timesteps: Int = 1_000,
steps: Int,
strength: Float
) {
let tEnc = Int(strength * Float(steps))
let initTimestep = min(max(0, timesteps - timesteps / steps * (steps - tEnc) + 1), timesteps - 1)
self.sqrtAlphasCumprod = alphasCumprod[initTimestep].squareRoot()
self.sqrtOneMinusAlphasCumprod = (1 - alphasCumprod[initTimestep]).squareRoot()
}
}
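
For reference, a minimal sketch of how the strength-based initializer resolves the starting timestep (the schedule below is a placeholder, not a real beta-derived schedule): with the default 1_000 training timesteps, 50 inference steps, and strength 0.7, tEnc = 35 and initTimestep = min(max(0, 1000 - 20 * 15 + 1), 999) = 701.

let alphasCumprod: [Float] = (0..<1_000).map { 1.0 - Float($0) / 1_000 } // placeholder schedule
let calc = AlphasCumprodCalculation(
    alphasCumprod: alphasCumprod,
    timesteps: 1_000,
    steps: 50,
    strength: 0.7)
// calc.sqrtAlphasCumprod == alphasCumprod[701].squareRoot()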

@@ -0,0 +1,120 @@
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
import Foundation
import Accelerate
import CoreML
@available(iOS 16.0, macOS 13.0, *)
extension CGImage {
typealias PixelBufferPFx1 = vImage.PixelBuffer<vImage.PlanarF>
typealias PixelBufferP8x3 = vImage.PixelBuffer<vImage.Planar8x3>
typealias PixelBufferIFx3 = vImage.PixelBuffer<vImage.InterleavedFx3>
typealias PixelBufferI8x3 = vImage.PixelBuffer<vImage.Interleaved8x3>
public enum ShapedArrayError: String, Swift.Error {
case wrongNumberOfChannels
case incorrectFormatsConvertingToShapedArray
case vImageConverterNotInitialized
}
public static func fromShapedArray(_ array: MLShapedArray<Float32>) throws -> CGImage {
// array is [N,C,H,W], where C==3
let channelCount = array.shape[1]
guard channelCount == 3 else {
throw ShapedArrayError.wrongNumberOfChannels
}
let height = array.shape[2]
let width = array.shape[3]
// Normalize each channel into a float between 0 and 1.0
let floatChannels = (0..<channelCount).map { i in
// Normalized channel output
let cOut = PixelBufferPFx1(width: width, height: height)
// Reference this channel in the array and normalize
array[0][i].withUnsafeShapedBufferPointer { ptr, _, strides in
let cIn = PixelBufferPFx1(data: .init(mutating: ptr.baseAddress!),
width: width, height: height,
byteCountPerRow: strides[0]*4)
// Map [-1.0 1.0] -> [0.0 1.0]
cIn.multiply(by: 0.5, preBias: 1.0, postBias: 0.0, destination: cOut)
}
return cOut
}
// Convert to interleaved and then to UInt8
let floatImage = PixelBufferIFx3(planarBuffers: floatChannels)
let uint8Image = PixelBufferI8x3(width: width, height: height)
floatImage.convert(to: uint8Image) // maps [0.0 1.0] -> [0 255] and clips
// Convert to uint8x3 to RGB CGImage (no alpha)
let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue)
let cgImage = uint8Image.makeCGImage(cgImageFormat:
.init(bitsPerComponent: 8,
bitsPerPixel: 3*8,
colorSpace: CGColorSpaceCreateDeviceRGB(),
bitmapInfo: bitmapInfo)!)!
return cgImage
}
public var plannerRGBShapedArray: MLShapedArray<Float32> {
get throws {
guard
var sourceFormat = vImage_CGImageFormat(cgImage: self),
var mediumFormat = vImage_CGImageFormat(
bitsPerComponent: 8 * MemoryLayout<UInt8>.size,
bitsPerPixel: 8 * MemoryLayout<UInt8>.size * 4,
colorSpace: CGColorSpaceCreateDeviceRGB(),
bitmapInfo: CGBitmapInfo(rawValue: CGImageAlphaInfo.first.rawValue)),
let width = vImagePixelCount(exactly: self.width),
let height = vImagePixelCount(exactly: self.height)
else {
throw ShapedArrayError.incorrectFormatsConvertingToShapedArray
}
var sourceImageBuffer = try vImage_Buffer(cgImage: self)
var mediumDestination = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: mediumFormat.bitsPerPixel)
let converter = vImageConverter_CreateWithCGImageFormat(
&sourceFormat,
&mediumFormat,
nil,
vImage_Flags(kvImagePrintDiagnosticsToConsole),
nil)
guard let converter = converter?.takeRetainedValue() else {
throw ShapedArrayError.vImageConverterNotInitialized
}
vImageConvert_AnyToAny(converter, &sourceImageBuffer, &mediumDestination, nil, vImage_Flags(kvImagePrintDiagnosticsToConsole))
var destinationA = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size))
var destinationR = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size))
var destinationG = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size))
var destinationB = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size))
var minFloat: [Float] = [-1.0, -1.0, -1.0, -1.0]
var maxFloat: [Float] = [1.0, 1.0, 1.0, 1.0]
vImageConvert_ARGB8888toPlanarF(&mediumDestination, &destinationA, &destinationR, &destinationG, &destinationB, &maxFloat, &minFloat, .zero)
let redData = Data(bytes: destinationR.data, count: Int(width) * Int(height) * MemoryLayout<Float>.size)
let greenData = Data(bytes: destinationG.data, count: Int(width) * Int(height) * MemoryLayout<Float>.size)
let blueData = Data(bytes: destinationB.data, count: Int(width) * Int(height) * MemoryLayout<Float>.size)
let imageData = redData + greenData + blueData
let shapedArray = MLShapedArray<Float32>(data: imageData, shape: [1, 3, self.height, self.width]) // NCHW: height precedes width
return shapedArray
}
}
}
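
A quick usage sketch of the two conversions above, assuming `inputImage` is an RGB CGImage already sized to the model's expected dimensions (both accessors throw, so call sites should handle `ShapedArrayError`):

let latentInput = try inputImage.plannerRGBShapedArray   // [1, 3, H, W], scaled to [-1.0, 1.0]
let roundTrip = try CGImage.fromShapedArray(latentInput) // back to an RGB CGImage, no alpha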

@@ -3,7 +3,6 @@
import Foundation
import CoreML
import Accelerate
/// A decoder model which produces RGB images from latent samples
@available(iOS 16.2, macOS 13.1, *)
@@ -57,12 +56,11 @@ public struct Decoder: ResourceManaging {
}
// Transform the outputs to CGImages
let images: [CGImage] = (0..<results.count).map { i in
let images: [CGImage] = try (0..<results.count).map { i in
let result = results.features(at: i)
let outputName = result.featureNames.first!
let output = result.featureValue(for: outputName)!.multiArrayValue!
return toRGBCGImage(MLShapedArray<Float32>(output))
return try CGImage.fromShapedArray(MLShapedArray<Float32>(output))
}
return images
@@ -74,50 +72,4 @@ public struct Decoder: ResourceManaging {
}
}
typealias PixelBufferPFx1 = vImage.PixelBuffer<vImage.PlanarF>
typealias PixelBufferP8x3 = vImage.PixelBuffer<vImage.Planar8x3>
typealias PixelBufferIFx3 = vImage.PixelBuffer<vImage.InterleavedFx3>
typealias PixelBufferI8x3 = vImage.PixelBuffer<vImage.Interleaved8x3>
func toRGBCGImage(_ array: MLShapedArray<Float32>) -> CGImage {
// array is [N,C,H,W], where C==3
let channelCount = array.shape[1]
assert(channelCount == 3,
"Decoding model output has \(channelCount) channels, expected 3")
let height = array.shape[2]
let width = array.shape[3]
// Normalize each channel into a float between 0 and 1.0
let floatChannels = (0..<channelCount).map { i in
// Normalized channel output
let cOut = PixelBufferPFx1(width: width, height:height)
// Reference this channel in the array and normalize
array[0][i].withUnsafeShapedBufferPointer { ptr, _, strides in
let cIn = PixelBufferPFx1(data: .init(mutating: ptr.baseAddress!),
width: width, height: height,
byteCountPerRow: strides[0]*4)
// Map [-1.0 1.0] -> [0.0 1.0]
cIn.multiply(by: 0.5, preBias: 1.0, postBias: 0.0, destination: cOut)
}
return cOut
}
// Convert to interleaved and then to UInt8
let floatImage = PixelBufferIFx3(planarBuffers: floatChannels)
let uint8Image = PixelBufferI8x3(width: width, height: height)
floatImage.convert(to:uint8Image) // maps [0.0 1.0] -> [0 255] and clips
// Convert to uint8x3 to RGB CGImage (no alpha)
let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue)
let cgImage = uint8Image.makeCGImage(cgImageFormat:
.init(bitsPerComponent: 8,
bitsPerPixel: 3*8,
colorSpace: CGColorSpaceCreateDeviceRGB(),
bitmapInfo: bitmapInfo)!)!
return cgImage
}
}

@@ -0,0 +1,131 @@
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
import Foundation
import CoreML
/// Encoder, currently supports image2image
@available(iOS 16.0, macOS 13.1, *)
public struct Encoder: ResourceManaging {
public enum FeatureName: String {
case sample = "sample"
case diagonalNoise = "diagonal_noise"
case noise = "noise"
case sqrtAlphasCumprod = "sqrt_alphas_cumprod"
case sqrtOneMinusAlphasCumprod = "sqrt_one_minus_alphas_cumprod"
}
public enum Error: String, Swift.Error {
case latentOutputNotValid
case batchLatentOutputEmpty
case sampleInputShapeNotCorrect
case noiseInputShapeNotCorrect
}
/// VAE encoder model + post math and noise addition from the scheduler
var model: ManagedMLModel
/// Create encoder from Core ML model
///
/// - Parameters:
/// - url: Location of compiled VAE encoder Core ML model
/// - configuration: configuration to be used when the model is loaded
/// - Returns: An encoder that will lazily load its required resources when needed or requested
public init(modelAt url: URL, configuration: MLModelConfiguration) {
self.model = ManagedMLModel(modelAt: url, configuration: configuration)
}
/// Ensure the model has been loaded into memory
public func loadResources() throws {
try model.loadResources()
}
/// Unload the underlying model to free up memory
public func unloadResources() {
model.unloadResources()
}
/// Prediction queue
let queue = DispatchQueue(label: "encoder.predict")
/// Encode an image into its latent representation for image2image
/// - Parameters:
/// - image: image used for image2image
/// - diagonalNoise: random noise for `DiagonalGaussianDistribution` operation
/// - noise: random noise for initial latent space based on strength argument
/// - alphasCumprodStep: scheduler-derived values traditionally computed in the pipeline in the PyTorch Diffusers library
/// - Returns: The encoded latent space as MLShapedArray
public func encode(
image: CGImage,
diagonalNoise: MLShapedArray<Float32>,
noise: MLShapedArray<Float32>,
alphasCumprodStep: AlphasCumprodCalculation
) throws -> MLShapedArray<Float32> {
let sample = try image.plannerRGBShapedArray
let sqrtAlphasCumprod = MLShapedArray(scalars: [alphasCumprodStep.sqrtAlphasCumprod], shape: [1, 1])
let sqrtOneMinusAlphasCumprod = MLShapedArray(scalars: [alphasCumprodStep.sqrtOneMinusAlphasCumprod], shape: [1, 1])
let dict: [String: Any] = [
FeatureName.sample.rawValue: MLMultiArray(sample),
FeatureName.diagonalNoise.rawValue: MLMultiArray(diagonalNoise),
FeatureName.noise.rawValue: MLMultiArray(noise),
FeatureName.sqrtAlphasCumprod.rawValue: MLMultiArray(sqrtAlphasCumprod),
FeatureName.sqrtOneMinusAlphasCumprod.rawValue: MLMultiArray(sqrtOneMinusAlphasCumprod),
]
let featureProvider = try MLDictionaryFeatureProvider(dictionary: dict)
let batch = MLArrayBatchProvider(array: [featureProvider])
// Batch predict with model
let results = try queue.sync {
try model.perform { model in
if let feature = model.modelDescription.inputDescriptionsByName[FeatureName.sample.rawValue],
let shape = feature.multiArrayConstraint?.shape as? [Int]
{
guard sample.shape == shape else {
// TODO: Consider auto resizing and cropping similar to how Vision or CoreML auto-generated Swift code can accomplish with `MLFeatureValue`
throw Error.sampleInputShapeNotCorrect
}
}
if let feature = model.modelDescription.inputDescriptionsByName[FeatureName.noise.rawValue],
let shape = feature.multiArrayConstraint?.shape as? [Int]
{
guard noise.shape == shape else {
throw Error.noiseInputShapeNotCorrect
}
}
if let feature = model.modelDescription.inputDescriptionsByName[FeatureName.diagonalNoise.rawValue],
let shape = feature.multiArrayConstraint?.shape as? [Int]
{
guard diagonalNoise.shape == shape else {
throw Error.noiseInputShapeNotCorrect
}
}
return try model.predictions(fromBatch: batch)
}
}
let batchLatents: [MLShapedArray<Float32>] = try (0..<results.count).compactMap { i in
let result = results.features(at: i)
guard
let outputName = result.featureNames.first,
let output = result.featureValue(for: outputName)?.multiArrayValue
else {
throw Error.latentOutputNotValid
}
return MLShapedArray(output)
}
guard let latents = batchLatents.first else {
throw Error.batchLatentOutputEmpty
}
return latents
}
}
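
A minimal sketch of driving the encoder directly, assuming `encoderURL` points at a compiled VAEEncoder.mlmodelc, `scheduler` is any Scheduler instance, and the noise arrays were drawn to the UNet latent shape (e.g. [1, 4, 64, 64] for 512x512 models):

let encoder = Encoder(modelAt: encoderURL, configuration: MLModelConfiguration())
let latents = try encoder.encode(
    image: startingImage,         // CGImage at the model's expected input size
    diagonalNoise: diagonalNoise, // unit-variance Gaussian noise, latent shape
    noise: latentNoise,           // unit-variance Gaussian noise, latent shape
    alphasCumprodStep: scheduler.calculateAlphasCumprod(strength: 0.75))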

@@ -14,6 +14,9 @@ public protocol Scheduler {
/// Training diffusion time steps indexed by inference time step
var timeSteps: [Int] { get }
/// Calculates the training time steps to use for a given image2image strength
func calculateTimesteps(strength: Float?) -> [Int]
/// Schedule of betas which controls the amount of noise added at each timestep
var betas: [Float] { get }
@@ -71,6 +74,34 @@ public extension Scheduler {
}
}
// MARK: - Image2Image
@available(iOS 16.2, macOS 13.1, *)
public extension Scheduler {
func calculateAlphasCumprod(strength: Float) -> AlphasCumprodCalculation {
AlphasCumprodCalculation(
alphasCumprod: alphasCumProd,
timesteps: trainStepCount,
steps: inferenceStepCount,
strength: strength)
}
}
// MARK: - Timesteps
@available(iOS 16.2, macOS 13.1, *)
public extension Scheduler {
func calculateTimesteps(strength: Float?) -> [Int] {
guard let strength else { return timeSteps }
let startStep = max(inferenceStepCount - Int(Float(inferenceStepCount) * strength), 0)
let actualTimesteps = Array(timeSteps[startStep...])
return actualTimesteps
}
}
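
Concretely, assuming the scheduler yields one timestep per inference step, 50 steps at strength 0.6 keep only the last 30 timesteps: startStep = max(50 - Int(50 * 0.6), 0) = 20. A sketch:

let scheduler = PNDMScheduler(stepCount: 50)
let full = scheduler.calculateTimesteps(strength: nil)      // 50 timesteps: text2image
let truncated = scheduler.calculateTimesteps(strength: 0.6) // last 30 timesteps: image2image
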
// MARK: - BetaSchedule
/// How to map a beta range to a sequence of betas to step over
@available(iOS 16.2, macOS 13.1, *)
public enum BetaSchedule {
@@ -80,6 +111,7 @@ public enum BetaSchedule {
case scaledLinear
}
// MARK: - PNDMScheduler
/// A scheduler used to compute a de-noised image
///

@@ -14,6 +14,7 @@ public extension StableDiffusionPipeline {
public let unetChunk1URL: URL
public let unetChunk2URL: URL
public let decoderURL: URL
public let encoderURL: URL
public let safetyCheckerURL: URL
public let vocabURL: URL
public let mergesURL: URL
@@ -24,6 +25,7 @@ public extension StableDiffusionPipeline {
unetChunk1URL = baseURL.appending(path: "UnetChunk1.mlmodelc")
unetChunk2URL = baseURL.appending(path: "UnetChunk2.mlmodelc")
decoderURL = baseURL.appending(path: "VAEDecoder.mlmodelc")
encoderURL = baseURL.appending(path: "VAEEncoder.mlmodelc")
safetyCheckerURL = baseURL.appending(path: "SafetyChecker.mlmodelc")
vocabURL = baseURL.appending(path: "vocab.json")
mergesURL = baseURL.appending(path: "merges.txt")
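
With this change a fully populated resource directory looks something like the following (the TextEncoder and Unet names come from the surrounding initializer; VAEEncoder.mlmodelc is optional and only required for image2image):

  TextEncoder.mlmodelc
  Unet.mlmodelc (or UnetChunk1.mlmodelc + UnetChunk2.mlmodelc)
  VAEDecoder.mlmodelc
  VAEEncoder.mlmodelc (new, optional)
  SafetyChecker.mlmodelc (optional)
  vocab.json
  merges.txt
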
@@ -74,11 +76,20 @@ public extension StableDiffusionPipeline {
FileManager.default.fileExists(atPath: urls.safetyCheckerURL.path) {
safetyChecker = SafetyChecker(modelAt: urls.safetyCheckerURL, configuration: config)
}
// Optional Image Encoder
let encoder: Encoder?
if FileManager.default.fileExists(atPath: urls.encoderURL.path) {
encoder = Encoder(modelAt: urls.encoderURL, configuration: config)
} else {
encoder = nil
}
// Construct pipeline
self.init(textEncoder: textEncoder,
unet: unet,
decoder: decoder,
encoder: encoder,
safetyChecker: safetyChecker,
reduceMemory: reduceMemory)
}

@@ -0,0 +1,60 @@
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
import Foundation
import CoreGraphics
@available(iOS 16.2, macOS 13.1, *)
extension StableDiffusionPipeline {
/// Type of processing that will be performed to generate an image
public enum Mode {
case textToImage
case imageToImage
// case inPainting
}
/// Image generation configuration
public struct Configuration: Hashable {
/// Text prompt to guide sampling
public var prompt: String
/// Negative text prompt to guide sampling
public var negativePrompt: String = ""
/// Starting image for image2image or in-painting
public var startingImage: CGImage? = nil
//public var maskImage: CGImage? = nil
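/// How much noise overrides the starting image for image2image; values >= 1.0 ignore the starting image (mode becomes textToImage)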
public var strength: Float = 1.0
/// Number of images to generate
public var imageCount: Int = 1
/// Number of inference steps to perform
public var stepCount: Int = 50
/// Random seed with which to start generation
public var seed: UInt32 = 0
/// Controls the influence of the text prompt on sampling process (0=random images)
public var guidanceScale: Float = 7.5
/// Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
public var disableSafety: Bool = false
/// The type of Scheduler to use.
public var schedulerType: StableDiffusionScheduler = .pndmScheduler
/// Given the configuration, what mode will be used for generation
public var mode: Mode {
guard startingImage != nil else {
return .textToImage
}
guard strength < 1.0 else {
return .textToImage
}
return .imageToImage
}
public init(
prompt: String
) {
self.prompt = prompt
}
}
}
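
A minimal sketch of configuring an image2image run with this struct (the pipeline and startingImage values are assumed to exist):

var config = StableDiffusionPipeline.Configuration(prompt: "a fantasy landscape, detailed")
config.startingImage = startingImage // CGImage at the encoder's expected size
config.strength = 0.75               // must be < 1.0 or mode falls back to .textToImage
config.seed = 42
assert(config.mode == .imageToImage)
let images = try pipeline.generateImages(configuration: config)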

@@ -20,7 +20,11 @@ public enum StableDiffusionScheduler {
/// [Hugging Face Diffusers Pipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py)
@available(iOS 16.2, macOS 13.1, *)
public struct StableDiffusionPipeline: ResourceManaging {
public enum Error: String, Swift.Error {
case startingImageProvidedWithoutEncoder
}
/// Model to generate embeddings for tokenized input text
var textEncoder: TextEncoder
@@ -29,6 +33,9 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// Model used to generate final image from latent diffusion process
var decoder: Decoder
/// Model used to encode images into latent space for image2image and, soon, in-painting
var encoder: Encoder?
/// Optional model for checking safety of generated image
var safetyChecker: SafetyChecker? = nil
@@ -58,11 +65,13 @@ public struct StableDiffusionPipeline: ResourceManaging {
public init(textEncoder: TextEncoder,
unet: Unet,
decoder: Decoder,
encoder: Encoder?,
safetyChecker: SafetyChecker? = nil,
reduceMemory: Bool = false) {
self.textEncoder = textEncoder
self.unet = unet
self.decoder = decoder
self.encoder = encoder
self.safetyChecker = safetyChecker
self.reduceMemory = reduceMemory
}
@@ -98,34 +107,20 @@ public struct StableDiffusionPipeline: ResourceManaging {
try safetyChecker?.prewarmResources()
}
/// Text to image generation using stable diffusion
///
/// Image generation using stable diffusion
/// - Parameters:
/// - prompt: Text prompt to guide sampling
/// - negativePrompt: Negative text prompt to guide sampling
/// - stepCount: Number of inference steps to perform
/// - imageCount: Number of samples/images to generate for the input prompt
/// - seed: Random seed which
/// - guidanceScale: Controls the influence of the text prompt on sampling process (0=random images)
/// - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
/// - progressHandler: Callback to perform after each step, stops on receiving false response
/// - Returns: An array of `imageCount` optional images.
/// The images will be nil if safety checks were performed and found the result to be un-safe
public func generateImages(
prompt: String,
negativePrompt: String = "",
imageCount: Int = 1,
stepCount: Int = 50,
seed: UInt32 = 0,
guidanceScale: Float = 7.5,
disableSafety: Bool = false,
scheduler: StableDiffusionScheduler = .pndmScheduler,
configuration config: Configuration,
progressHandler: (Progress) -> Bool = { _ in true }
) throws -> [CGImage?] {
// Encode the input prompt and negative prompt
let promptEmbedding = try textEncoder.encode(prompt)
let negativePromptEmbedding = try textEncoder.encode(negativePrompt)
let promptEmbedding = try textEncoder.encode(config.prompt)
let negativePromptEmbedding = try textEncoder.encode(config.negativePrompt)
if reduceMemory {
textEncoder.unloadResources()
@@ -141,19 +136,44 @@ public struct StableDiffusionPipeline: ResourceManaging {
let hiddenStates = toHiddenStates(concatEmbedding)
/// Setup schedulers
let scheduler: [Scheduler] = (0..<imageCount).map { _ in
switch scheduler {
case .pndmScheduler: return PNDMScheduler(stepCount: stepCount)
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: stepCount)
let scheduler: [Scheduler] = (0..<config.imageCount).map { _ in
switch config.schedulerType {
case .pndmScheduler: return PNDMScheduler(stepCount: config.stepCount)
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: config.stepCount)
}
}
let stdev = scheduler[0].initNoiseSigma
// Generate random latent samples from specified seed
var latents = generateLatentSamples(imageCount, stdev: stdev, seed: seed)
var latents: [MLShapedArray<Float32>]
let timestepStrength: Float?
if
let startingImage = config.startingImage,
config.mode == .imageToImage
{
timestepStrength = config.strength
guard let encoder else {
throw Error.startingImageProvidedWithoutEncoder
}
let noiseTuples = generateImage2ImageLatentSamples(config.imageCount, stdev: 1, seed: config.seed)
latents = try noiseTuples.map({
try encoder.encode(
image: startingImage,
diagonalNoise: $0.diagonal,
noise: $0.latentNoise,
alphasCumprodStep: scheduler[0].calculateAlphasCumprod(strength: config.strength))
})
} else {
timestepStrength = nil
// Generate random latent samples from specified seed
latents = generateLatentSamples(config.imageCount, stdev: stdev, seed: config.seed)
}
// De-noising loop
for (step,t) in scheduler[0].timeSteps.enumerated() {
let timeSteps: [Int] = scheduler[0].calculateTimesteps(strength: timestepStrength)
for (step,t) in timeSteps.enumerated() {
// Expand the latents for classifier-free guidance
// and input to the Unet noise prediction model
@@ -169,11 +189,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
hiddenStates: hiddenStates
)
noise = performGuidance(noise, guidanceScale)
noise = performGuidance(noise, config.guidanceScale)
// Have the scheduler compute the previous (t-1) latent
// sample given the predicted noise and current sample
for i in 0..<imageCount {
for i in 0..<config.imageCount {
latents[i] = scheduler[i].step(
output: noise[i],
timeStep: t,
@@ -184,11 +204,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
// Report progress
let progress = Progress(
pipeline: self,
prompt: prompt,
prompt: config.prompt,
step: step,
stepCount: stepCount,
stepCount: timeSteps.count,
currentLatentSamples: latents,
isSafetyEnabled: canSafetyCheck && !disableSafety
isSafetyEnabled: canSafetyCheck && !config.disableSafety
)
if !progressHandler(progress) {
// Stop if requested by handler
@@ -201,7 +221,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
}
// Decode the latent samples to images
return try decodeToImages(latents, disableSafety: disableSafety)
return try decodeToImages(latents, disableSafety: config.disableSafety)
}
func generateLatentSamples(_ count: Int, stdev: Float, seed: UInt32) -> [MLShapedArray<Float32>] {
@@ -215,6 +235,35 @@ public struct StableDiffusionPipeline: ResourceManaging {
}
return samples
}
/// For image2image:
/// - Parameters:
/// - count: batch size
/// - stdev: standard deviation of the noise (1 for image2image)
/// - seed: the provided random seed
/// - diagonalAndLatentNoiseIsSame: the Diffusers library does not seem to use the same noise for the `DiagonalGaussianDistribution` operation,
/// but I have seen implementations of pipelines where it is the same.
/// - Returns: An array of noise tuples with length of batch size.
func generateImage2ImageLatentSamples(_ count: Int, stdev: Float, seed: UInt32, diagonalAndLatentNoiseIsSame: Bool = false) -> [(diagonal: MLShapedArray<Float32>, latentNoise: MLShapedArray<Float32>)] {
var sampleShape = unet.latentSampleShape
sampleShape[0] = 1
var random = NumPyRandomSource(seed: UInt32(truncatingIfNeeded: seed))
let samples = (0..<count).map { _ in
if diagonalAndLatentNoiseIsSame {
let noise = MLShapedArray<Float32>(
converting: random.normalShapedArray(sampleShape, mean: 0.0, stdev: Double(stdev)))
return (noise, noise)
} else {
return (MLShapedArray<Float32>(
converting: random.normalShapedArray(sampleShape, mean: 0.0, stdev: Double(stdev))),
MLShapedArray<Float32>(
converting: random.normalShapedArray(sampleShape, mean: 0.0, stdev: Double(stdev))))
}
}
return samples
}
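
By default each tuple holds two independent draws, matching the Diffusers behavior noted above; passing diagonalAndLatentNoiseIsSame: true reuses a single draw for both members:

let tuples = generateImage2ImageLatentSamples(2, stdev: 1, seed: 42)
// tuples.count == 2; each element is (diagonal:, latentNoise:), independent by default
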
func toHiddenStates(_ embedding: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
// Unoptimized manual transpose [0, 2, None, 1]

@@ -7,6 +7,7 @@ import CoreML
import Foundation
import StableDiffusion
import UniformTypeIdentifiers
import Cocoa
@available(iOS 16.2, macOS 13.1, *)
struct StableDiffusionSample: ParsableCommand {
@@ -32,6 +33,12 @@ struct StableDiffusionSample: ParsableCommand {
)
)
var resourcePath: String = "./"
@Option(help: "Path to starting image.")
var image: String? = nil
@Option(help: "Strength for image2image.")
var strength: Float = 0.5
@Option(help: "Number of images to sample / generate")
var imageCount: Int = 1
@@ -51,7 +58,7 @@ struct StableDiffusionSample: ParsableCommand {
var outputPath: String = "./"
@Option(help: "Random seed")
var seed: UInt32 = 93
var seed: UInt32 = UInt32.random(in: 0...UInt32.max)
@Option(help: "Controls the influence of the text prompt on sampling process (0=random images)")
var guidanceScale: Float = 7.5
@@ -84,27 +91,52 @@ struct StableDiffusionSample: ParsableCommand {
disableSafety: disableSafety,
reduceMemory: reduceMemory)
try pipeline.loadResources()
let startingImage: CGImage?
if let image {
let imageURL = URL(filePath: image)
do {
let imageData = try Data(contentsOf: imageURL)
guard
let nsImage = NSImage(data: imageData),
let loadedImage = nsImage.cgImage(forProposedRect: nil, context: nil, hints: nil)
else {
throw RunError.resources("Starting Image not available \(resourcePath)")
}
startingImage = loadedImage
} catch let error {
throw RunError.resources("Starting image not found \(imageURL), error: \(error)")
}
} else {
startingImage = nil
}
log("Sampling ...\n")
let sampleTimer = SampleTimer()
sampleTimer.start()
var pipelineConfig = StableDiffusionPipeline.Configuration(prompt: prompt)
pipelineConfig.negativePrompt = negativePrompt
pipelineConfig.startingImage = startingImage
pipelineConfig.strength = strength
pipelineConfig.imageCount = imageCount
pipelineConfig.stepCount = stepCount
pipelineConfig.seed = seed
pipelineConfig.guidanceScale = guidanceScale
pipelineConfig.schedulerType = scheduler.stableDiffusionScheduler
let images = try pipeline.generateImages(
prompt: prompt,
negativePrompt: negativePrompt,
imageCount: imageCount,
stepCount: stepCount,
seed: seed,
guidanceScale: guidanceScale,
scheduler: scheduler.stableDiffusionScheduler
) { progress in
sampleTimer.stop()
handleProgress(progress,sampleTimer)
if progress.stepCount != progress.step {
sampleTimer.start()
}
return true
}
configuration: pipelineConfig,
progressHandler: { progress in
sampleTimer.stop()
handleProgress(progress, sampleTimer)
if progress.stepCount != progress.step {
sampleTimer.start()
}
return true
})
_ = try saveImages(images, logNames: true)
}
@@ -168,6 +200,10 @@ struct StableDiffusionSample: ParsableCommand {
if imageCount != 1 {
name += ".\(sample)"
}
if image != "none" {
name += ".str\(Int(strength * 100))"
}
name += ".\(seed)"
