You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

64 lines
2.3 KiB

// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
import XCTest
import CoreML
@testable import StableDiffusion
@available(iOS 16.2, macOS 13.1, *)
final class StableDiffusionTests: XCTestCase {
var vocabFileInBundleURL: URL {
let fileName = "vocab"
guard let url = Bundle.module.url(forResource: fileName, withExtension: "json") else {
fatalError("BPE tokenizer vocabulary file is missing from bundle")
}
return url
}
var mergesFileInBundleURL: URL {
let fileName = "merges"
guard let url = Bundle.module.url(forResource: fileName, withExtension: "txt") else {
fatalError("BPE tokenizer merges file is missing from bundle")
}
return url
}
func testBPETokenizer() throws {
let tokenizer = try BPETokenizer(mergesAt: mergesFileInBundleURL, vocabularyAt: vocabFileInBundleURL)
func testPrompt(prompt: String, expectedIds: [Int]) {
let (tokens, ids) = tokenizer.tokenize(input: prompt)
print("Tokens = \(tokens)\n")
print("Expected tokens = \(expectedIds.map({ tokenizer.token(id: $0) }))")
print("ids = \(ids)\n")
print("Expected Ids = \(expectedIds)\n")
XCTAssertEqual(ids,expectedIds)
}
testPrompt(prompt: "a photo of an astronaut riding a horse on mars",
expectedIds: [49406, 320, 1125, 539, 550, 18376, 6765, 320, 4558, 525, 7496, 49407])
testPrompt(prompt: "Apple CoreML developer tools on a Macbook Air are fast",
expectedIds: [49406, 3055, 19622, 5780, 10929, 5771, 525, 320, 20617,
1922, 631, 1953, 49407])
}
func test_randomNormalValues_matchNumPyRandom() {
var random = NumPyRandomSource(seed: 12345)
let samples = random.normalArray(count: 10_000)
let last5 = samples.suffix(5)
// numpy.random.seed(12345); print(numpy.random.randn(10000)[-5:])
let expected = [-0.86285345, 2.15229409, -0.00670556, -1.21472309, 0.65498866]
for (value, expected) in zip(last5, expected) {
XCTAssertEqual(value, expected, accuracy: .ulpOfOne.squareRoot())
}
}
}