Compare commits

...

5 Commits
4bit ... master

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| katsu560 | 4c2f924553 | cmake : update CMakeLists.txt to add correct flags (#26) | 1 year ago |
| Georgi Gerganov | ba3e8a3d7f | readme : update Roadmap | 1 year ago |
| Georgi Gerganov | 2546cb7780 | readme : add Roadmap section | 1 year ago |
| Georgi Gerganov | 8f8a5aca99 | sync : latest whisper.cpp | 1 year ago |
| Georgi Gerganov | efa2cc36a2 | tests : fix cblas_sgemm call | 1 year ago |

@@ -2,6 +2,9 @@
Tensor library for machine learning
***Note that this project is under development and not ready for production use. \
Some of the development is currently happening in the [whisper.cpp](https://github.com/ggerganov/whisper.cpp) repo***
## Features
- Written in C
@@ -13,9 +16,15 @@ Tensor library for machine learning
- No third-party dependencies
- Zero memory allocations during runtime
***Note that this project is under development and not ready for production use.
Most of the development is currently happening in the [whisper.cpp](https://github.com/ggerganov/whisper.cpp) repo,
so if you are interested in this project, make sure to follow what is happening there.***
## Roadmap
- [X] Example of GPT-2 inference [examples/gpt-2](https://github.com/ggerganov/ggml/tree/master/examples/gpt-2)
- [X] Example of GPT-J inference [examples/gpt-j](https://github.com/ggerganov/ggml/tree/master/examples/gpt-j)
- [X] Example of Whisper inference [examples/whisper](https://github.com/ggerganov/ggml/tree/master/examples/whisper)
- [ ] Support 4-bit integer quantization https://github.com/ggerganov/ggml/pull/27
- [ ] Example of FLAN-T5 inference https://github.com/ggerganov/ggml/pull/12
- [ ] Example of LLaMA inference
- [ ] Example of RWKV inference
## Whisper inference (example)

@@ -91,12 +91,12 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-"){
params.fname_inp.push_back(arg);
continue;
}
if (arg[0] != '-') {
params.fname_inp.push_back(arg);
continue;

@@ -592,16 +592,16 @@ struct whisper_context {
mutable std::mt19937 rng; // used for sampling at t > 0.0
int lang_id;
int lang_id = 0; // english by default
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg;
int64_t t_last;
int64_t t_beg = 0;
int64_t t_last = 0;
whisper_token tid_last;
std::vector<float> energy; // PCM signal energy
// [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx; // 0 - use default
int32_t exp_n_audio_ctx = 0; // 0 - use default
void use_buf(struct ggml_context * ctx, int i) {
#if defined(WHISPER_USE_SCRATCH)
@@ -805,7 +805,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
MEM_REQ_SCRATCH3.at (model.type) +
scale*MEM_REQ_MODEL.at (model.type) +
scale*MEM_REQ_KV_CROSS.at(model.type) +
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
// this is the memory required by one decoder
const size_t mem_required_decoder =
@@ -2936,7 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.language =*/ "en",
/*.suppress_blank =*/ true,
/*.suppress_non_speech_tokens =*/true,
/*.suppress_non_speech_tokens =*/ false,
/*.temperature =*/ 0.0f,
/*.max_initial_ts =*/ 1.0f,
@@ -2962,6 +2962,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr,
/*.logits_filter_callback =*/ nullptr,
/*.logits_filter_callback_user_data =*/ nullptr,
};
switch (strategy) {
@@ -3078,8 +3081,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
return res;
}
static const std::vector<std::string> non_speech_tokens
{
static const std::vector<std::string> non_speech_tokens = {
"\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
"_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
"---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
@@ -3090,7 +3092,7 @@ static const std::vector<std::string> non_speech_tokens
// - applies logit filters
// - computes logprobs and probs
static void whisper_process_logits(
const struct whisper_context & ctx,
struct whisper_context & ctx,
const struct whisper_full_params params,
struct whisper_decoder & decoder,
float temperature) {
@@ -3146,29 +3148,27 @@ static void whisper_process_logits(
logits[vocab.token_translate] = -INFINITY;
logits[vocab.token_transcribe] = -INFINITY;
if (params.logits_filter_callback) {
params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
}
// suppress non-speech tokens
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
if (params.suppress_non_speech_tokens)
{
for (const std::string &token : non_speech_tokens)
{
std::string suppress_tokens[] = {token, " " + token};
for (const std::string &suppress_token : suppress_tokens)
{
if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end())
{
if (params.suppress_non_speech_tokens) {
for (const std::string & token : non_speech_tokens) {
const std::string suppress_tokens[] = {token, " " + token};
for (const std::string & suppress_token : suppress_tokens) {
if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) {
logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
}
}
}
// allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
if (vocab.token_to_id.find(" -") != vocab.token_to_id.end())
{
if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) {
logits[vocab.token_to_id.at(" -")] = -INFINITY;
}
if (vocab.token_to_id.find(" '") != vocab.token_to_id.end())
{
if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) {
logits[vocab.token_to_id.at(" '")] = -INFINITY;
}
}
@@ -3854,7 +3854,7 @@ int whisper_full(
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
});
unsigned int cur_c = 0;
uint32_t cur_c = 0;
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
@@ -4339,7 +4339,7 @@ int whisper_full_n_segments(struct whisper_context * ctx) {
}
int whisper_full_lang_id(struct whisper_context * ctx) {
return ctx->lang_id;
return ctx->lang_id;
}
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {

@@ -243,6 +243,16 @@ extern "C" {
// If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
// Logits filter callback
// Can be used to modify the logits before sampling
// If not NULL, called after applying temperature to logits
typedef void (*whisper_logits_filter_callback)(
struct whisper_context * ctx,
const whisper_token_data * tokens,
int n_tokens,
float * logits,
void * user_data);
// Parameters for the whisper_full() function
// If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
// whisper_full_default_params()
@@ -315,6 +325,10 @@ extern "C" {
// called each time before the encoder starts
whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;
// called by each decoder to filter obtained logits
whisper_logits_filter_callback logits_filter_callback;
void * logits_filter_callback_user_data;
};
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
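For orientation, here is a minimal sketch of how a caller might use the new logits filter hook. The callback signature and the two `logits_filter_callback*` fields come directly from the hunks above; the helper names and the idea of suppressing a single hard-coded token are illustrative assumptions, not code from this change set.

```cpp
// Sketch: suppress one token id on every decode step via the new callback.
// Only the whisper.h declarations shown in this diff are relied upon.
#include <cmath>        // INFINITY
#include "whisper.h"

// Signature matches the whisper_logits_filter_callback typedef above.
// By the convention chosen for this sketch, user_data points at the
// whisper_token to suppress.
static void suppress_one_token(
        struct whisper_context * /*ctx*/,
        const whisper_token_data * /*tokens*/,
        int /*n_tokens*/,
        float * logits,
        void * user_data) {
    const whisper_token banned = *static_cast<const whisper_token *>(user_data);
    logits[banned] = -INFINITY;   // the sampler will never pick this token
}

// 'banned' must outlive the whisper_full() call that uses these params,
// since the callback reads it through user_data on every decode step.
struct whisper_full_params make_params(whisper_token * banned) {
    struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.logits_filter_callback           = suppress_one_token;
    wparams.logits_filter_callback_user_data = banned;
    return wparams;
}
```

As the `whisper_process_logits` hunk shows, the callback fires right before the built-in non-speech-token suppression, so whatever it writes into `logits` is what the sampler sees.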

@@ -25,12 +25,94 @@ endif()
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
if (NOT UNAME_S)
execute_process(COMMAND uname -s OUTPUT_VARIABLE UNAME_S)
endif()
if (NOT UNAME_P)
execute_process(COMMAND uname -p OUTPUT_VARIABLE UNAME_P)
endif()
if (NOT UNAME_M)
execute_process(COMMAND uname -m OUTPUT_VARIABLE UNAME_M)
endif()
message(STATUS "UNAME_S: ${UNAME_S} UNAME_P: ${UNAME_P} UNAME_M: ${UNAME_M}")
# Mac OS + Arm can report x86_64
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
if (UNAME_S MATCHES "Darwin")
if (NOT UNAME_P MATCHES "arm")
execute_process(COMMAND sysctl -n hw.optional.arm64 OUTPUT_VARIABLE SYSCTL_M)
if (SYSCTL_M MATCHES "1")
#set(UNAME_P "arm")
#set(UNAME_M "arm64")
message(WARNING "Your arch is announced as x86_64, but it seems to actually be ARM64. Not fixing that can lead to bad performance. For more info see: https://github.com/ggerganov/whisper.cpp/issues/66\#issuecomment-1282546789")
endif()
endif()
endif()
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
message(STATUS "ARM detected")
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mcpu=apple-m1")
else()
message(STATUS "x86 detected")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mavx2 -mfma -mf16c")
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mavx2 -mfma -mf16c")
if (UNAME_S MATCHES "Darwin")
execute_process(COMMAND sysctl machdep.cpu.features OUTPUT_VARIABLE AVX1_M)
if (AVX1_M MATCHES "AVX1.0")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx")
endif()
execute_process(COMMAND sysctl machdep.cpu.leaf7_features OUTPUT_VARIABLE AVX2_M)
if (AVX2_M MATCHES "AVX2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
endif()
if (AVX1_M MATCHES "FMA")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
endif()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
elseif (UNAME_S MATCHES "Linux")
message(STATUS "Linux detected")
execute_process(COMMAND grep "avx " /proc/cpuinfo OUTPUT_VARIABLE AVX1_M)
if (AVX1_M MATCHES "avx")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx")
endif()
execute_process(COMMAND grep "avx2 " /proc/cpuinfo OUTPUT_VARIABLE AVX2_M)
if (AVX2_M MATCHES "avx2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
endif()
execute_process(COMMAND grep "fma " /proc/cpuinfo OUTPUT_VARIABLE FMA_M)
if (FMA_M MATCHES "fma")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
endif()
execute_process(COMMAND grep "f16c " /proc/cpuinfo OUTPUT_VARIABLE F16C_M)
if (F16C_M MATCHES "f16c")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
endif()
execute_process(COMMAND grep "sse3 " /proc/cpuinfo OUTPUT_VARIABLE SSE3_M)
if (SSE3_M MATCHES "sse3")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse3")
endif()
message(STATUS "CMAKE_C_FLAGS: ${CMAKE_C_FLAGS}")
elseif (UNAME_S MATCHES "Haiku")
message(STATUS "Haiku detected")
execute_process(COMMAND sysinfo -cpu | grep "AVX " OUTPUT_VARIABLE AVX1_M)
if (AVX1_M MATCHES "avx")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx")
endif()
execute_process(COMMAND sysinfo -cpu | grep "AVX2 " OUTPUT_VARIABLE AVX2_M)
if (AVX2_M MATCHES "avx2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
endif()
execute_process(COMMAND sysinfo -cpu | grep "FMA " OUTPUT_VARIABLE FMA_M)
if (FMA_M MATCHES "fma")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
endif()
execute_process(COMMAND sysinfo -cpu | grep "F16C " OUTPUT_VARIABLE F16C_M)
if (F16C_M MATCHES "f16c")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
endif()
message(STATUS "CMAKE_C_FLAGS: ${CMAKE_C_FLAGS}")
else()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma -mf16c -mavx -mavx2")
endif()
endif()
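A quick way to confirm which of the flags this detection actually handed to the compiler is to check the corresponding predefined macros: GCC and Clang define `__AVX__`, `__AVX2__`, `__FMA__`, `__F16C__` and `__SSE3__` exactly when the matching `-m...` option is in effect. A small standalone check (not part of the change set):

```cpp
// Reports which SIMD feature macros were defined for this translation unit,
// i.e. which of the -mavx/-mavx2/-mfma/-mf16c/-msse3 flags reached the compiler.
#include <cstdio>

#ifdef __AVX__
static const bool has_avx  = true;
#else
static const bool has_avx  = false;
#endif
#ifdef __AVX2__
static const bool has_avx2 = true;
#else
static const bool has_avx2 = false;
#endif
#ifdef __FMA__
static const bool has_fma  = true;
#else
static const bool has_fma  = false;
#endif
#ifdef __F16C__
static const bool has_f16c = true;
#else
static const bool has_f16c = false;
#endif
#ifdef __SSE3__
static const bool has_sse3 = true;
#else
static const bool has_sse3 = false;
#endif

int main() {
    std::printf("AVX: %d  AVX2: %d  FMA: %d  F16C: %d  SSE3: %d\n",
                has_avx, has_avx2, has_fma, has_f16c, has_sse3);
    return 0;
}
```

Compiling this file with the same flags the script ends up appending should print 1 for every feature the detection enabled.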

@@ -13,7 +13,7 @@
#include <Accelerate/Accelerate.h>
const int M = 1280;
const int N = 1500;
const int N = 1536;
const int K = 1280;
uint64_t get_time_us() {
@@ -284,7 +284,7 @@ int main(int argc, const char ** argv) {
if (method == 4) {
// Use BLAS sgemm from Accelerate framework
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, K, N, 1.0f, src0, N, src1, N, 0.0f, dst, N);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, M, N, K, 1.0f, src0, K, src1, K, 0.0f, dst, N);
}
}
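The corrected call treats `src0` as an M×K row-major matrix, `src1` as an N×K row-major matrix, and computes `dst = src0 · src1ᵀ` into an M×N buffer with leading dimension N. A tiny standalone check of that parameter mapping against a naive triple loop (assumes a CBLAS implementation such as Accelerate or OpenBLAS is available to link against; matrix sizes and fill values are arbitrary):

```cpp
// Verifies the argument mapping of the fixed call:
//   cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
//               M, N, K, 1.0f, src0, K, src1, K, 0.0f, dst, N)
// i.e. dst[MxN] = src0[MxK] * transpose(src1[NxK]), all row-major.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>
#include <cblas.h>   // with Accelerate, include <Accelerate/Accelerate.h> instead

int main() {
    const int M = 2, N = 3, K = 4;

    std::vector<float> src0(M*K), src1(N*K), dst(M*N), ref(M*N);
    for (int i = 0; i < M*K; ++i) src0[i] = 0.1f*i;
    for (int i = 0; i < N*K; ++i) src1[i] = 0.2f*i - 1.0f;

    // reference: ref[i][j] = sum_k src0[i][k] * src1[j][k]
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            float sum = 0.0f;
            for (int k = 0; k < K; ++k) {
                sum += src0[i*K + k]*src1[j*K + k];
            }
            ref[i*N + j] = sum;
        }
    }

    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                M, N, K, 1.0f, src0.data(), K, src1.data(), K, 0.0f, dst.data(), N);

    float max_err = 0.0f;
    for (int i = 0; i < M*N; ++i) {
        max_err = std::max(max_err, std::fabs(dst[i] - ref[i]));
    }
    std::printf("max abs error vs. naive loop: %g\n", max_err);
    return max_err < 1e-4f ? 0 : 1;
}
```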
