change name and add error handling

also fix 512 hard coded
pull/116/head
Timothy Kautz 1 year ago
parent 088bdc1dc4
commit eafd72efd9

@ -111,7 +111,7 @@ extension CGImage {
let imageData = redData + greenData + blueData
let shapedArray = MLShapedArray<Float32>(data: imageData, shape: [1, 3, 512, 512])
let shapedArray = MLShapedArray<Float32>(data: imageData, shape: [1, 3, self.width, self.height])
return shapedArray
}

@ -8,9 +8,19 @@ import CoreML
/// Encoder, currently supports image2image
public struct Encoder: ResourceManaging {
public enum FeatureName: String {
case sample = "sample"
case diagonalNoise = "diagonal_noise"
case noise = "noise"
case sqrtAlphasCumprod = "sqrt_alphas_cumprod"
case sqrtOneMinusAlphasCumprod = "sqrt_one_minus_alphas_cumprod"
}
public enum Error: String, Swift.Error {
case latentOutputNotValid
case batchLatentOutputEmpty
case sampleInputShapeNotCorrect
case noiseInputShapeNotCorrect
}
/// VAE encoder model + post math and adding noise from schedular
@ -57,11 +67,11 @@ public struct Encoder: ResourceManaging {
let sqrtOneMinusAlphasCumprod = MLShapedArray(scalars: [alphasCumprodStep.sqrtOneMinusAlphasCumprod], shape: [1, 1])
let dict: [String: Any] = [
"sample": MLMultiArray(sample),
"diagonal_noise": MLMultiArray(diagonalNoise),
"noise": MLMultiArray(noise),
"sqrt_alphas_cumprod": MLMultiArray(sqrtAlphasCumprod),
"sqrt_one_minus_alphas_cumprod": MLMultiArray(sqrtOneMinusAlphasCumprod),
FeatureName.sample.rawValue: MLMultiArray(sample),
FeatureName.diagonalNoise.rawValue: MLMultiArray(diagonalNoise),
FeatureName.noise.rawValue: MLMultiArray(noise),
FeatureName.sqrtAlphasCumprod.rawValue: MLMultiArray(sqrtAlphasCumprod),
FeatureName.sqrtOneMinusAlphasCumprod.rawValue: MLMultiArray(sqrtOneMinusAlphasCumprod),
]
let featureProvider = try MLDictionaryFeatureProvider(dictionary: dict)
@ -71,7 +81,32 @@ public struct Encoder: ResourceManaging {
let results = try queue.sync {
try model.perform { model in
try model.predictions(fromBatch: batch)
if let feature = model.modelDescription.inputDescriptionsByName[FeatureName.sample.rawValue],
let shape = feature.multiArrayConstraint?.shape as? [Int]
{
guard sample.shape == shape else {
// TODO: Consider auto resizing and croping similar to how Vision or CoreML auto-generated Swift code can accomplish with `MLFeatureValue`
throw Error.sampleInputShapeNotCorrect
}
}
if let feature = model.modelDescription.inputDescriptionsByName[FeatureName.noise.rawValue],
let shape = feature.multiArrayConstraint?.shape as? [Int]
{
guard noise.shape == shape else {
throw Error.noiseInputShapeNotCorrect
}
}
if let feature = model.modelDescription.inputDescriptionsByName[FeatureName.diagonalNoise.rawValue],
let shape = feature.multiArrayConstraint?.shape as? [Int]
{
guard diagonalNoise.shape == shape else {
throw Error.noiseInputShapeNotCorrect
}
}
return try model.predictions(fromBatch: batch)
}
}

Loading…
Cancel
Save