// For licensing see accompanying LICENSE.md file. // Copyright (C) 2022 Apple Inc. All Rights Reserved. import XCTest import CoreML @testable import StableDiffusion final class StableDiffusionTests: XCTestCase { var vocabFileInBundleURL: URL { let fileName = "vocab" guard let url = Bundle.module.url(forResource: fileName, withExtension: "json") else { fatalError("BPE tokenizer vocabulary file is missing from bundle") } return url } var mergesFileInBundleURL: URL { let fileName = "merges" guard let url = Bundle.module.url(forResource: fileName, withExtension: "txt") else { fatalError("BPE tokenizer merges file is missing from bundle") } return url } func testBPETokenizer() throws { let tokenizer = try BPETokenizer(mergesAt: mergesFileInBundleURL, vocabularyAt: vocabFileInBundleURL) func testPrompt(prompt: String, expectedIds: [Int]) { let (tokens, ids) = tokenizer.tokenize(input: prompt) print("Tokens = \(tokens)\n") print("Expected tokens = \(expectedIds.map({ tokenizer.token(id: $0) }))") print("ids = \(ids)\n") print("Expected Ids = \(expectedIds)\n") XCTAssertEqual(ids,expectedIds) } testPrompt(prompt: "a photo of an astronaut riding a horse on mars", expectedIds: [49406, 320, 1125, 539, 550, 18376, 6765, 320, 4558, 525, 7496, 49407]) testPrompt(prompt: "Apple CoreML developer tools on a Macbook Air are fast", expectedIds: [49406, 3055, 19622, 5780, 10929, 5771, 525, 320, 20617, 1922, 631, 1953, 49407]) } func test_randomNormalValues_matchNumPyRandom() { var random = NumPyRandomSource(seed: 12345) let samples = random.normalArray(count: 10_000) let last5 = samples.suffix(5) // numpy.random.seed(12345); print(numpy.random.randn(10000)[-5:]) let expected = [-0.86285345, 2.15229409, -0.00670556, -1.21472309, 0.65498866] for (value, expected) in zip(last5, expected) { XCTAssertEqual(value, expected, accuracy: .ulpOfOne.squareRoot()) } } }