|
|
|
@ -8,9 +8,19 @@ import CoreML
|
|
|
|
|
/// Encoder, currently supports image2image
|
|
|
|
|
public struct Encoder: ResourceManaging {
|
|
|
|
|
|
|
|
|
|
/// Names of the model input features used by the encoder.
///
/// The raw values are the strings used as keys of the input feature dictionary
/// and to look up entries in `modelDescription.inputDescriptionsByName`.
public enum FeatureName: String {
|
|
|
|
|
/// Key of the `sample` input; its shape is checked against the model's constraint before prediction.
case sample = "sample"
|
|
|
|
|
/// Key of the `diagonalNoise` input; its shape is checked against the model's constraint before prediction.
case diagonalNoise = "diagonal_noise"
|
|
|
|
|
/// Key of the `noise` input; its shape is checked against the model's constraint before prediction.
case noise = "noise"
|
|
|
|
|
/// Key of the `sqrtAlphasCumprod` input.
case sqrtAlphasCumprod = "sqrt_alphas_cumprod"
|
|
|
|
|
/// Key of the `sqrtOneMinusAlphasCumprod` input.
case sqrtOneMinusAlphasCumprod = "sqrt_one_minus_alphas_cumprod"
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Errors reported by the encoder.
///
/// Raw values are implicit, so each case's raw value is its own name.
public enum Error: String, Swift.Error {
|
|
|
|
|
// NOTE(review): the throw site for this case is not visible in this excerpt — confirm usage.
case latentOutputNotValid
|
|
|
|
|
// NOTE(review): the throw site for this case is not visible in this excerpt — confirm usage.
case batchLatentOutputEmpty
|
|
|
|
|
/// Thrown when `sample.shape` differs from the shape declared by the model's `sample` input constraint.
case sampleInputShapeNotCorrect
|
|
|
|
|
/// Thrown when the shape of `noise` or `diagonalNoise` differs from the shape declared by the
/// corresponding model input constraint.
case noiseInputShapeNotCorrect
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// VAE encoder model + post math and adding noise from scheduler
|
|
|
|
@ -57,11 +67,11 @@ public struct Encoder: ResourceManaging {
|
|
|
|
|
let sqrtOneMinusAlphasCumprod = MLShapedArray(scalars: [alphasCumprodStep.sqrtOneMinusAlphasCumprod], shape: [1, 1])
|
|
|
|
|
|
|
|
|
|
let dict: [String: Any] = [
|
|
|
|
|
"sample": MLMultiArray(sample),
|
|
|
|
|
"diagonal_noise": MLMultiArray(diagonalNoise),
|
|
|
|
|
"noise": MLMultiArray(noise),
|
|
|
|
|
"sqrt_alphas_cumprod": MLMultiArray(sqrtAlphasCumprod),
|
|
|
|
|
"sqrt_one_minus_alphas_cumprod": MLMultiArray(sqrtOneMinusAlphasCumprod),
|
|
|
|
|
FeatureName.sample.rawValue: MLMultiArray(sample),
|
|
|
|
|
FeatureName.diagonalNoise.rawValue: MLMultiArray(diagonalNoise),
|
|
|
|
|
FeatureName.noise.rawValue: MLMultiArray(noise),
|
|
|
|
|
FeatureName.sqrtAlphasCumprod.rawValue: MLMultiArray(sqrtAlphasCumprod),
|
|
|
|
|
FeatureName.sqrtOneMinusAlphasCumprod.rawValue: MLMultiArray(sqrtOneMinusAlphasCumprod),
|
|
|
|
|
]
|
|
|
|
|
let featureProvider = try MLDictionaryFeatureProvider(dictionary: dict)
|
|
|
|
|
|
|
|
|
@ -71,7 +81,32 @@ public struct Encoder: ResourceManaging {
|
|
|
|
|
|
|
|
|
|
let results = try queue.sync {
|
|
|
|
|
try model.perform { model in
|
|
|
|
|
try model.predictions(fromBatch: batch)
|
|
|
|
|
if let feature = model.modelDescription.inputDescriptionsByName[FeatureName.sample.rawValue],
|
|
|
|
|
let shape = feature.multiArrayConstraint?.shape as? [Int]
|
|
|
|
|
{
|
|
|
|
|
guard sample.shape == shape else {
|
|
|
|
|
// TODO: Consider auto resizing and cropping similar to how Vision or CoreML auto-generated Swift code can accomplish with `MLFeatureValue`
|
|
|
|
|
throw Error.sampleInputShapeNotCorrect
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if let feature = model.modelDescription.inputDescriptionsByName[FeatureName.noise.rawValue],
|
|
|
|
|
let shape = feature.multiArrayConstraint?.shape as? [Int]
|
|
|
|
|
{
|
|
|
|
|
guard noise.shape == shape else {
|
|
|
|
|
throw Error.noiseInputShapeNotCorrect
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if let feature = model.modelDescription.inputDescriptionsByName[FeatureName.diagonalNoise.rawValue],
|
|
|
|
|
let shape = feature.multiArrayConstraint?.shape as? [Int]
|
|
|
|
|
{
|
|
|
|
|
guard diagonalNoise.shape == shape else {
|
|
|
|
|
throw Error.noiseInputShapeNotCorrect
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return try model.predictions(fromBatch: batch)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|