|
|
@ -14,6 +14,7 @@
|
|
|
|
#include <string>
|
|
|
|
#include <string>
|
|
|
|
#include <thread>
|
|
|
|
#include <thread>
|
|
|
|
#include <vector>
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
#include <regex>
|
|
|
|
|
|
|
|
|
|
|
|
#define USE_FLASH_ATTN
|
|
|
|
#define USE_FLASH_ATTN
|
|
|
|
//#define USE_FLASH_FF
|
|
|
|
//#define USE_FLASH_FF
|
|
|
@ -2161,6 +2162,71 @@ static bool log_mel_spectrogram(
|
|
|
|
return true;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// split text into tokens
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
// Regex (Python):
|
|
|
|
|
|
|
|
// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
// Regex (C++):
|
|
|
|
|
|
|
|
// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
|
|
|
|
|
|
|
|
std::vector<std::string> words;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// first split the text into words
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
std::string str = text;
|
|
|
|
|
|
|
|
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::regex re(pat);
|
|
|
|
|
|
|
|
std::smatch m;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
while (std::regex_search(str, m, re)) {
|
|
|
|
|
|
|
|
for (auto x : m) {
|
|
|
|
|
|
|
|
words.push_back(x);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
str = m.suffix();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// find the longest tokens that form the words:
|
|
|
|
|
|
|
|
std::vector<whisper_vocab::id> tokens;
|
|
|
|
|
|
|
|
for (const auto & word : words) {
|
|
|
|
|
|
|
|
if (word.size() == 0) continue;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int i = 0;
|
|
|
|
|
|
|
|
int n = word.size();
|
|
|
|
|
|
|
|
while (i < n) {
|
|
|
|
|
|
|
|
int j = n;
|
|
|
|
|
|
|
|
while (j > i) {
|
|
|
|
|
|
|
|
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
|
|
|
|
|
|
|
if (it != vocab.token_to_id.end()) {
|
|
|
|
|
|
|
|
tokens.push_back(it->second);
|
|
|
|
|
|
|
|
i = j;
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
--j;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (i == n) {
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (j == i) {
|
|
|
|
|
|
|
|
auto sub = word.substr(i, 1);
|
|
|
|
|
|
|
|
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
|
|
|
|
|
|
|
tokens.push_back(vocab.token_to_id.at(sub));
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
++i;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return tokens;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
//
|
|
|
|
//
|
|
|
|
// interface implementation
|
|
|
|
// interface implementation
|
|
|
|
//
|
|
|
|
//
|
|
|
@ -2291,6 +2357,21 @@ struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx,
|
|
|
|
return res;
|
|
|
|
return res;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
|
|
|
|
|
|
|
|
const auto res = tokenize(ctx->vocab, text);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (res.size() > n_max_tokens) {
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
|
|
|
|
|
|
|
|
return -1;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < res.size(); i++) {
|
|
|
|
|
|
|
|
tokens[i] = res[i];
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return res.size();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int whisper_lang_id(const char * lang) {
|
|
|
|
int whisper_lang_id(const char * lang) {
|
|
|
|
if (!g_lang.count(lang)) {
|
|
|
|
if (!g_lang.count(lang)) {
|
|
|
|
fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
|
|
|
|
fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
|
|
|
|