You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
64 lines
2.3 KiB
64 lines
2.3 KiB
// For licensing see accompanying LICENSE.md file.
|
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
|
|
|
import XCTest
|
|
import CoreML
|
|
@testable import StableDiffusion
|
|
|
|
/// Unit tests for the BPE tokenizer and the NumPy-compatible random source.
@available(iOS 16.2, macOS 13.1, *)
final class StableDiffusionTests: XCTestCase {

    /// URL of the `vocab.json` resource looked up in `Bundle.module`.
    ///
    /// Traps with `fatalError` when the resource cannot be found, since the
    /// tokenizer test cannot proceed without it.
    var vocabFileInBundleURL: URL {
        let fileName = "vocab"
        guard let url = Bundle.module.url(forResource: fileName, withExtension: "json") else {
            fatalError("BPE tokenizer vocabulary file is missing from bundle")
        }
        return url
    }

    /// URL of the `merges.txt` resource looked up in `Bundle.module`.
    ///
    /// Traps with `fatalError` when the resource cannot be found, since the
    /// tokenizer test cannot proceed without it.
    var mergesFileInBundleURL: URL {
        let fileName = "merges"
        guard let url = Bundle.module.url(forResource: fileName, withExtension: "txt") else {
            fatalError("BPE tokenizer merges file is missing from bundle")
        }
        return url
    }

    /// Tokenizes fixed prompts and checks the resulting ids against hard-coded expectations.
    ///
    /// - Throws: Whatever error `BPETokenizer(mergesAt:vocabularyAt:)` raises while loading
    ///   the bundled merges and vocabulary files.
    func testBPETokenizer() throws {
        let tokenizer = try BPETokenizer(mergesAt: mergesFileInBundleURL, vocabularyAt: vocabFileInBundleURL)

        /// Tokenizes `prompt`, prints the actual and expected tokens/ids for diagnosis,
        /// and asserts that the produced ids equal `expectedIds`.
        ///
        /// `file` and `line` default to the call site so that a failure is attributed to
        /// the prompt that caused it instead of to the assertion inside this helper.
        func testPrompt(
            prompt: String,
            expectedIds: [Int],
            file: StaticString = #filePath,
            line: UInt = #line
        ) {
            let (tokens, ids) = tokenizer.tokenize(input: prompt)

            print("Tokens = \(tokens)\n")
            print("Expected tokens = \(expectedIds.map({ tokenizer.token(id: $0) }))")
            print("ids = \(ids)\n")
            print("Expected Ids = \(expectedIds)\n")

            XCTAssertEqual(
                ids,
                expectedIds,
                "Token ids mismatch for prompt \"\(prompt)\"",
                file: file,
                line: line
            )
        }

        testPrompt(prompt: "a photo of an astronaut riding a horse on mars",
                   expectedIds: [49406, 320, 1125, 539, 550, 18376, 6765, 320, 4558, 525, 7496, 49407])

        testPrompt(prompt: "Apple CoreML developer tools on a Macbook Air are fast",
                   expectedIds: [49406, 3055, 19622, 5780, 10929, 5771, 525, 320, 20617,
                                 1922, 631, 1953, 49407])
    }

    /// Checks that the last five of 10,000 normal samples drawn from `NumPyRandomSource`
    /// seeded with 12345 match the reference values recorded from NumPy.
    func test_randomNormalValues_matchNumPyRandom() {
        var random = NumPyRandomSource(seed: 12345)
        let samples = random.normalArray(count: 10_000)
        let last5 = samples.suffix(5)

        // numpy.random.seed(12345); print(numpy.random.randn(10000)[-5:])
        let expected = [-0.86285345, 2.15229409, -0.00670556, -1.21472309, 0.65498866]

        // `zip` stops at the shorter sequence, so check the lengths first; otherwise a
        // short or empty sample array would let the loop below pass without comparing
        // all (or any) of the reference values.
        XCTAssertEqual(last5.count, expected.count)

        // The reference values carry 8 decimal places; sqrt(ulpOfOne) is roughly 1.5e-8
        // for Double, which tolerates that rounding.
        for (actualValue, expectedValue) in zip(last5, expected) {
            XCTAssertEqual(actualValue, expectedValue, accuracy: .ulpOfOne.squareRoot())
        }
    }
}
|