You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
171 lines
5.7 KiB
// For licensing see accompanying LICENSE.md file.
|
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
|
|
|
import Foundation
|
|
import CoreML
|
|
|
|
/// U-Net noise prediction model for stable diffusion
@available(iOS 16.2, macOS 13.1, *)
public struct Unet: ResourceManaging {

    /// Model used to predict noise residuals given an input, diffusion time step, and conditional embedding
    ///
    /// It can be in the form of a single model or multiple stages (chunks executed
    /// sequentially, each stage consuming the previous stage's outputs).
    var models: [ManagedMLModel]

    /// Creates a U-Net noise prediction model
    ///
    /// - Parameters:
    ///   - url: Location of single U-Net compiled Core ML model
    ///   - configuration: Configuration to be used when the model is loaded
    /// - Returns: U-net model that will lazily load its required resources when needed or requested
    public init(modelAt url: URL,
                configuration: MLModelConfiguration) {
        self.models = [ManagedMLModel(modelAt: url, configuration: configuration)]
    }

    /// Creates a U-Net noise prediction model from a chunked (multi-stage) model
    ///
    /// - Parameters:
    ///   - urls: Locations of each compiled Core ML chunk, in execution order
    ///   - configuration: Configuration to be used when each chunk is loaded
    /// - Returns: U-net model that will lazily load its required resources when needed or requested
    public init(chunksAt urls: [URL],
                configuration: MLModelConfiguration) {
        self.models = urls.map { ManagedMLModel(modelAt: $0, configuration: configuration) }
    }

    /// Load resources for every stage.
    ///
    /// - Throws: Rethrows the first stage-load failure; earlier stages remain loaded.
    public func loadResources() throws {
        for model in models {
            try model.loadResources()
        }
    }

    /// Unload the underlying model to free up memory
    public func unloadResources() {
        for model in models {
            model.unloadResources()
        }
    }

    /// Pre-warm resources by loading then immediately unloading each stage,
    /// without keeping all stages resident in memory at once.
    public func prewarmResources() throws {
        // Override default to pre-warm each model
        for model in models {
            try model.loadResources()
            model.unloadResources()
        }
    }

    /// Feature description of the first stage's "sample" (latent) input.
    ///
    /// NOTE(review): force-unwraps assume at least one stage exists and that it
    /// declares an input named "sample" — this crashes otherwise; confirm the
    /// compiled model's input schema.
    var latentSampleDescription: MLFeatureDescription {
        try! models.first!.perform { model in
            model.modelDescription.inputDescriptionsByName["sample"]!
        }
    }

    /// The expected shape of the models latent sample input
    public var latentSampleShape: [Int] {
        latentSampleDescription.multiArrayConstraint!.shape.map { $0.intValue }
    }

    /// Batch prediction noise from latent samples
    ///
    /// - Parameters:
    ///   - latents: Batch of latent samples in an array
    ///   - timeStep: Current diffusion timestep
    ///   - hiddenStates: Hidden state to condition on (the same conditioning is
    ///     supplied to every latent in the batch)
    /// - Returns: Array of predicted noise residuals, one per input latent, as Float32
    /// - Throws: Errors from feature-provider construction or stage prediction
    func predictNoise(
        latents: [MLShapedArray<Float32>],
        timeStep: Int,
        hiddenStates: MLShapedArray<Float32>
    ) throws -> [MLShapedArray<Float32>] {

        // Match time step batch dimension to the model / latent samples.
        // NOTE(review): the timestep batch dimension is hard-coded to 2 —
        // presumably the classifier-free-guidance pair (unconditional +
        // conditional); confirm against the model's "timestep" input shape.
        let t = MLShapedArray<Float32>(scalars:[Float(timeStep), Float(timeStep)],shape:[2])

        // Form batch input to model: one feature provider per latent sample,
        // using the input names the compiled model declares.
        let inputs = try latents.map {
            let dict: [String: Any] = [
                "sample" : MLMultiArray($0),
                "timestep" : MLMultiArray(t),
                "encoder_hidden_states": MLMultiArray(hiddenStates)
            ]
            return try MLDictionaryFeatureProvider(dictionary: dict)
        }
        let batch = MLArrayBatchProvider(array: inputs)

        // Make predictions (runs the full single- or multi-stage pipeline below)
        let results = try predictions(from: batch)

        // Pull out the results in Float32 format
        let noise = (0..<results.count).map { i in

            let result = results.features(at: i)
            // NOTE(review): takes whichever feature name comes first — assumes the
            // final stage exposes the noise output as its (only) feature; verify.
            let outputName = result.featureNames.first!

            let outputNoise = result.featureValue(for: outputName)!.multiArrayValue!

            // To conform to this func return type make sure we return float32
            // Use the fact that the concatenating constructor for MLMultiArray
            // can do type conversion:
            let fp32Noise = MLMultiArray(
                concatenating: [outputNoise],
                axis: 0,
                dataType: .float32
            )
            return MLShapedArray<Float32>(fp32Noise)
        }

        return noise
    }

    /// Run a batch through every stage of the (possibly chunked) U-Net.
    ///
    /// A single-stage model is one prediction call. For a chunked model, each
    /// subsequent stage receives the previous stage's outputs merged with the
    /// original batch inputs; on key collisions the stage outputs win
    /// (`{ (out, _) in out }`).
    ///
    /// - Parameter batch: Feature providers forming the first stage's input batch
    /// - Returns: The final stage's output batch
    /// - Throws: Rethrows any stage's prediction or provider-construction error
    func predictions(from batch: MLBatchProvider) throws -> MLBatchProvider {

        // NOTE(review): force-unwrap assumes at least one stage was configured.
        var results = try models.first!.perform { model in
            try model.predictions(fromBatch: batch)
        }

        if models.count == 1 {
            return results
        }

        // Manual pipeline batch prediction
        let inputs = batch.arrayOfFeatureValueDictionaries
        for stage in models.dropFirst() {

            // Combine the original inputs with the outputs of the last stage
            let next = try results.arrayOfFeatureValueDictionaries
                .enumerated().map { (index, dict) in
                    let nextDict = dict.merging(inputs[index]) { (out, _) in out }
                    return try MLDictionaryFeatureProvider(dictionary: nextDict)
                }
            let nextBatch = MLArrayBatchProvider(array: next)

            // Predict
            results = try stage.perform { model in
                try model.predictions(fromBatch: nextBatch)
            }
        }

        return results
    }
}
|
|
|
|
extension MLFeatureProvider {
    /// All features exposed by this provider, keyed by feature name.
    ///
    /// Names whose `featureValue(for:)` returns `nil` are simply omitted.
    var featureValueDictionary: [String : MLFeatureValue] {
        var values: [String : MLFeatureValue] = [:]
        for name in featureNames {
            values[name] = featureValue(for: name)
        }
        return values
    }
}
|
|
|
|
extension MLBatchProvider {
    /// Materializes every element of the batch as a name → feature-value dictionary,
    /// preserving batch order.
    var arrayOfFeatureValueDictionaries: [[String : MLFeatureValue]] {
        var dictionaries: [[String : MLFeatureValue]] = []
        dictionaries.reserveCapacity(count)
        for index in 0..<count {
            dictionaries.append(features(at: index).featureValueDictionary)
        }
        return dictionaries
    }
}
|