You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
97 lines
3.2 KiB
97 lines
3.2 KiB
// For licensing see accompanying LICENSE.md file.
|
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
|
|
|
import Foundation
|
|
import CoreML
|
|
|
|
/// A model for encoding text
@available(iOS 16.2, macOS 13.1, *)
public struct TextEncoder: ResourceManaging {

    /// Failures thrown by `encode` when the Core ML model does not have the
    /// input/output layout this encoder relies on.
    public enum Error: Swift.Error {
        /// The model declares no input features.
        case missingInput
        /// The named input is not a multi-array, or its shape has no dimensions.
        case unsupportedInput(name: String)
        /// The number of token ids does not equal the number of elements in the model's input shape.
        case tokenCountMismatch(expected: Int, actual: Int)
        /// The prediction result has no multi-array feature with this name.
        case missingOutput(name: String)
    }

    /// Text tokenizer
    var tokenizer: BPETokenizer

    /// Embedding model
    var model: ManagedMLModel

    /// Prediction queue
    // NOTE(review): not referenced anywhere in this struct; kept so any other code
    // in the module that refers to it keeps compiling.
    let queue = DispatchQueue(label: "textencoder.predict")

    /// Creates text encoder which embeds a tokenized string
    ///
    /// - Parameters:
    ///   - tokenizer: Tokenizer for input text
    ///   - url: Location of compiled text encoding Core ML model
    ///   - configuration: configuration to be used when the model is loaded
    /// - Returns: A text encoder that will lazily load its required resources when needed or requested
    public init(tokenizer: BPETokenizer,
                modelAt url: URL,
                configuration: MLModelConfiguration) {
        self.tokenizer = tokenizer
        self.model = ManagedMLModel(modelAt: url, configuration: configuration)
    }

    /// Ensure the model has been loaded into memory
    public func loadResources() throws {
        try model.loadResources()
    }

    /// Unload the underlying model to free up memory
    public func unloadResources() {
        model.unloadResources()
    }

    /// Encode input text/string
    ///
    /// The text is tokenized with padding to the model's sequence length (the last
    /// dimension of its input shape). If it produces more ids than that, the excess
    /// is dropped from the end, and a message is printed, before embedding.
    ///
    /// - Parameters:
    ///   - text: Input text to be tokenized and then embedded
    /// - Returns: Embedding representing the input text
    /// - Throws: Errors from accessing or running the model, or `TextEncoder.Error`
    ///   if the model's input/output layout is not what this encoder expects.
    public func encode(_ text: String) throws -> MLShapedArray<Float32> {
        // Resolve the input name/shape once and reuse it for the prediction below
        let input = try modelInput()

        // Get models expected input length (last dimension of its input shape)
        guard let inputLength = input.shape.last else {
            throw Error.unsupportedInput(name: input.name)
        }

        // Tokenize, padding to the expected length
        var (tokens, ids) = tokenizer.tokenize(input: text, minCount: inputLength)

        // Truncate if necessary
        if ids.count > inputLength {
            tokens = Array(tokens.prefix(inputLength))
            ids = Array(ids.prefix(inputLength))
            let truncated = tokenizer.decode(tokens: tokens)
            print("Needed to truncate input '\(text)' to '\(truncated)'")
        }

        // Use the model to generate the embedding
        return try embed(ids: ids, input: input)
    }

    /// Embed already-tokenized ids.
    ///
    /// - Parameter ids: Token ids; their count must equal the number of elements
    ///   in the model's input shape.
    /// - Returns: The model's `last_hidden_state` output.
    /// - Throws: Errors from accessing or running the model, or `TextEncoder.Error`.
    func encode(ids: [Int]) throws -> MLShapedArray<Float32> {
        try embed(ids: ids, input: modelInput())
    }

    /// Runs the model on `ids`, using an already-resolved input name and shape.
    private func embed(ids: [Int],
                       input: (name: String, shape: [Int])) throws -> MLShapedArray<Float32> {
        // MLShapedArray(scalars:shape:) requires exactly one scalar per element of the
        // shape; check up front so a mismatch is reported as a thrown error.
        let expectedCount = input.shape.reduce(1, *)
        guard ids.count == expectedCount else {
            throw Error.tokenCountMismatch(expected: expectedCount, actual: ids.count)
        }

        // The model takes the token ids as Float32 values
        let inputArray = MLShapedArray<Float32>(scalars: ids.map { Float32($0) }, shape: input.shape)
        let inputFeatures = try MLDictionaryFeatureProvider(
            dictionary: [input.name: MLMultiArray(inputArray)])

        let result = try model.perform { model in
            try model.prediction(from: inputFeatures)
        }

        let outputName = "last_hidden_state"
        guard let embedding = result.featureValue(for: outputName)?.multiArrayValue else {
            throw Error.missingOutput(name: outputName)
        }
        return MLShapedArray<Float32>(converting: embedding)
    }

    /// Reads the name and shape of the model's first input feature through
    /// `model.perform`, throwing instead of trapping on failure.
    ///
    /// With more than one input, which one is "first" is unspecified
    /// (dictionary order); this presumes a single-input model.
    private func modelInput() throws -> (name: String, shape: [Int]) {
        let feature = try model.perform { model in
            model.modelDescription.inputDescriptionsByName.first?.value
        }
        guard let feature = feature else {
            throw Error.missingInput
        }
        guard let constraint = feature.multiArrayConstraint else {
            throw Error.unsupportedInput(name: feature.name)
        }
        return (feature.name, constraint.shape.map { $0.intValue })
    }

    /// Description of the model's first input feature.
    ///
    /// Traps if `model.perform` throws or the model declares no inputs.
    /// With more than one input, which one is "first" is unspecified (dictionary order).
    var inputDescription: MLFeatureDescription {
        try! model.perform { model in
            model.modelDescription.inputDescriptionsByName.first!.value
        }
    }

    /// Shape of the model's first input feature, as `Int`s.
    ///
    /// Traps under the same conditions as `inputDescription`, or if that input
    /// is not a multi-array.
    var inputShape: [Int] {
        inputDescription.multiArrayConstraint!.shape.map { $0.intValue }
    }
}
|