// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.

import Foundation

/// A tokenizer based on byte pair encoding.
@available(iOS 16.2, macOS 13.1, *)
public struct BPETokenizer {
    /// A dictionary that maps pairs of tokens to the rank/order of the merge.
    let merges: [TokenPair: Int]

    /// A dictionary that maps tokens to identifiers.
    let vocabulary: [String: Int]

    /// The start token.
    let startToken: String = "<|startoftext|>"

    /// The end token.
    let endToken: String = "<|endoftext|>"

    /// The token used for padding.
    let padToken: String = "<|endoftext|>"

    /// The unknown token.
    let unknownToken: String = "<|endoftext|>"

    /// The identifier of the unknown token, falling back to 0 if it is missing from the vocabulary.
    var unknownTokenID: Int {
        vocabulary[unknownToken, default: 0]
    }

    /// Creates a tokenizer.
    ///
    /// - Parameters:
    ///   - merges: A dictionary that maps pairs of tokens to the rank/order of the merge.
    ///   - vocabulary: A dictionary that maps tokens to identifiers.
    public init(merges: [TokenPair: Int], vocabulary: [String: Int]) {
        self.merges = merges
        self.vocabulary = vocabulary
    }

    /// Creates a tokenizer by loading merges and vocabulary from URLs.
    ///
    /// - Parameters:
    ///   - mergesURL: The URL of a text file containing merges.
    ///   - vocabularyURL: The URL of a JSON file containing the vocabulary.
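    ///
    /// A minimal usage sketch (the `merges.txt` / `vocab.json` file names are
    /// assumptions about how the model files are laid out, not part of this API):
    /// ```swift
    /// let tokenizer = try BPETokenizer(
    ///     mergesAt: URL(fileURLWithPath: "merges.txt"),
    ///     vocabularyAt: URL(fileURLWithPath: "vocab.json"))
    /// ```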
    public init(mergesAt mergesURL: URL, vocabularyAt vocabularyURL: URL) throws {
        self.merges = try Self.readMerges(url: mergesURL)
        self.vocabulary = try Self.readVocabulary(url: vocabularyURL)
    }

    /// Tokenizes an input string.
    ///
    /// - Parameters:
    ///   - input: A string.
    ///   - minCount: The minimum number of tokens to return.
    /// - Returns: An array of tokens and an array of token identifiers.
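    ///
    /// A usage sketch (the prompt is illustrative; 77 is the typical CLIP context
    /// length and is an assumption here, not a requirement of this method):
    /// ```swift
    /// let (tokens, ids) = tokenizer.tokenize(input: "a photo of an astronaut", minCount: 77)
    /// // tokens = ["<|startoftext|>", "a</w>", "photo</w>", ..., "<|endoftext|>", "<|endoftext|>", ...]
    /// ```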
    public func tokenize(input: String, minCount: Int? = nil) -> (tokens: [String], tokenIDs: [Int]) {
        var tokens: [String] = []

        tokens.append(startToken)
        tokens.append(contentsOf: encode(input: input))
        tokens.append(endToken)

        // Pad if there was a min length specified
        if let minLen = minCount, minLen > tokens.count {
            tokens.append(contentsOf: repeatElement(padToken, count: minLen - tokens.count))
        }

        let ids = tokens.map({ vocabulary[$0, default: unknownTokenID] })
        return (tokens: tokens, tokenIDs: ids)
    }

    /// Returns the token identifier for a token.
    public func tokenID(for token: String) -> Int? {
        vocabulary[token]
    }

    /// Returns the token for a token identifier.
    public func token(id: Int) -> String? {
        vocabulary.first(where: { $0.value == id })?.key
    }

    /// Decodes a sequence of tokens into a fully formed string.
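    ///
    /// For example (an illustrative token sequence):
    /// ```swift
    /// tokenizer.decode(tokens: ["<|startoftext|>", "a</w>", "photo</w>", "<|endoftext|>"])
    /// // "a photo " (note the trailing space left by the final "</w>")
    /// ```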
    public func decode(tokens: [String]) -> String {
        String(tokens.joined())
            .replacingOccurrences(of: "</w>", with: " ")
            .replacingOccurrences(of: startToken, with: "")
            .replacingOccurrences(of: endToken, with: "")
    }

    /// Encode an input string to a sequence of tokens.
    func encode(input: String) -> [String] {
        let normalized = input.trimmingCharacters(in: .whitespacesAndNewlines).lowercased()
        let words = normalized.split(separator: " ")
        return words.flatMap({ encode(word: $0) })
    }

    /// Encode a single word into a sequence of tokens.
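    ///
    /// The word is split into per-character tokens, `</w>` marks the end of the word,
    /// and adjacent pairs are then merged greedily by rank. The exact merge path
    /// depends on the loaded `merges` table; a hypothetical run:
    /// ```swift
    /// encode(word: "hello")
    /// // ["h", "e", "l", "l", "o</w>"] -> ["h", "e", "ll", "o</w>"] -> ... -> ["hello</w>"]
    /// ```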
    func encode(word: Substring) -> [String] {
        var tokens = word.map { String($0) }
        if let last = tokens.indices.last {
            tokens[last] = tokens[last] + "</w>"
        }

        while true {
            let pairs = pairs(for: tokens)
            let canMerge = pairs.filter { merges[$0] != nil }

            if canMerge.isEmpty {
                break
            }

            // If multiple merges are found, use the one with the lowest rank
            let shouldMerge = canMerge.min { merges[$0]! < merges[$1]! }!
            tokens = update(tokens, merging: shouldMerge)
        }

        return tokens
    }

    /// Get the set of adjacent pairs / bigrams from a sequence of tokens.
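    ///
    /// For example (hypothetical input):
    /// ```swift
    /// pairs(for: ["h", "e", "l", "l", "o</w>"])
    /// // {("h", "e"), ("e", "l"), ("l", "l"), ("l", "o</w>")}
    /// ```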
    func pairs(for tokens: [String]) -> Set<TokenPair> {
        guard tokens.count > 1 else {
            return Set()
        }

        var pairs = Set<TokenPair>(minimumCapacity: tokens.count - 1)
        var prev = tokens.first!
        for current in tokens.dropFirst() {
            pairs.insert(TokenPair(prev, current))
            prev = current
        }
        return pairs
    }

    /// Update the sequence of tokens by greedily merging instances of a specific bigram.
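    ///
    /// For example, merging the bigram ("l", "l") (hypothetical input):
    /// ```swift
    /// update(["h", "e", "l", "l", "o</w>"], merging: TokenPair("l", "l"))
    /// // ["h", "e", "ll", "o</w>"]
    /// ```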
    func update(_ tokens: [String], merging bigram: TokenPair) -> [String] {
        guard tokens.count > 1 else {
            return []
        }

        var newTokens = [String]()
        newTokens.reserveCapacity(tokens.count - 1)

        var index = 0
        while index < tokens.count {
            let remainingTokens = tokens[index...]
            if let startMatchIndex = remainingTokens.firstIndex(of: bigram.first) {
                // Found a possible match, append everything before it
                newTokens.append(contentsOf: tokens[index..<startMatchIndex])

                // Guard on startMatchIndex (not index) so tokens[startMatchIndex + 1] can't go out of bounds
                if startMatchIndex < tokens.count - 1 && tokens[startMatchIndex + 1] == bigram.second {
                    // Full match, merge
                    newTokens.append(bigram.first + bigram.second)
                    index = startMatchIndex + 2
                } else {
                    // Only matched the first, no merge
                    newTokens.append(bigram.first)
                    index = startMatchIndex + 1
                }
            } else {
                // Didn't find any more matches, append the rest unmerged
                newTokens.append(contentsOf: remainingTokens)
                break
            }
        }

        return newTokens
    }
}

@available(iOS 16.2, macOS 13.1, *)
extension BPETokenizer {
    /// A hashable tuple of strings.
    public struct TokenPair: Hashable {
        let first: String
        let second: String

        init(_ first: String, _ second: String) {
            self.first = first
            self.second = second
        }
    }
}