From 8e3c634b271af2f3b8cc2b3a3f8d3b4e098a28be Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Wed, 5 Oct 2022 23:15:10 +0300
Subject: [PATCH] whisper : various improvements

---
 examples/whisper/main.cpp    |   7 +-
 examples/whisper/whisper.cpp |  27 ++++----
 include/ggml/ggml.h          |   2 +-
 src/ggml.c                   | 130 +++++++++++++++++++++++++++++++++--
 4 files changed, 142 insertions(+), 24 deletions(-)

diff --git a/examples/whisper/main.cpp b/examples/whisper/main.cpp
index 562559a..6d1c55d 100644
--- a/examples/whisper/main.cpp
+++ b/examples/whisper/main.cpp
@@ -149,11 +149,11 @@ int main(int argc, char ** argv) {
         // convert to mono, float
         pcmf32.resize(n);
         if (wav.channels == 1) {
-            for (size_t i = 0; i < n; i++) {
+            for (int i = 0; i < n; i++) {
                 pcmf32[i] = float(pcm16[i])/32768.0f;
             }
         } else {
-            for (size_t i = 0; i < n; i++) {
+            for (int i = 0; i < n; i++) {
                 pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
             }
         }
@@ -185,6 +185,9 @@ int main(int argc, char ** argv) {
     wparams.print_progress       = false;
     wparams.print_timestamps     = !params.no_timestamps;
     wparams.print_special_tokens = params.print_special_tokens;
+    wparams.translate            = params.translate;
+    wparams.language             = params.language.c_str();
+    wparams.n_threads            = params.n_threads;
 
     if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
         fprintf(stderr, "%s: failed to process audio\n", argv[0]);
diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp
index 4f105ee..46a4caa 100644
--- a/examples/whisper/whisper.cpp
+++ b/examples/whisper/whisper.cpp
@@ -1031,8 +1031,6 @@ bool whisper_encode(
     const auto & mel_inp = wctx.mel;
     const auto & hparams = model.hparams;
 
-    const int n_vocab = hparams.n_vocab;
-
     const int n_ctx   = hparams.n_audio_ctx;
     const int n_state = hparams.n_audio_state;
     const int n_head  = hparams.n_audio_head;
@@ -1293,7 +1291,8 @@ bool whisper_encode(
         struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
 
         {
-            struct ggml_cgraph gf = { .n_threads = n_threads };
+            struct ggml_cgraph gf = {};
+            gf.n_threads = n_threads;
 
             ggml_build_forward_expand(&gf, inpO);
             ggml_graph_compute       (ctxL, &gf);
@@ -1329,7 +1328,8 @@ bool whisper_encode(
 
     // run the computation
     {
-        struct ggml_cgraph gf = { .n_threads = n_threads };
+        struct ggml_cgraph gf = {};
+        gf.n_threads = n_threads;
 
         ggml_build_forward_expand(&gf, cur);
         ggml_graph_compute       (ctx0, &gf);
@@ -1353,7 +1353,8 @@ bool whisper_encode(
 
     // pre-compute cross-attention memory
    {
-        struct ggml_cgraph gf = { .n_threads = n_threads };
+        struct ggml_cgraph gf = {};
+        gf.n_threads = n_threads;
 
         // TODO: hack to disconnect the encoded features from the previous graph
         cur->op = GGML_OP_NONE;
@@ -1463,7 +1464,8 @@ bool whisper_decode(
     };
 
     struct ggml_context * ctxL = ggml_init(paramsL);
-    struct ggml_cgraph gf = { .n_threads = n_threads };
+    struct ggml_cgraph gf = {};
+    gf.n_threads = n_threads;
 
     // norm
     {
@@ -1746,7 +1748,8 @@ bool whisper_decode(
 
     // run the computation
    {
-        struct ggml_cgraph gf = { .n_threads = n_threads };
+        struct ggml_cgraph gf = {};
+        gf.n_threads = n_threads;
 
         ggml_build_forward_expand(&gf, cur);
         ggml_graph_compute       (ctx0, &gf);
@@ -2336,7 +2339,7 @@ int whisper_full(
             }
         }
 
-        if (seek >= whisper_n_len(ctx)) {
+        if (seek + 100 >= whisper_n_len(ctx)) {
             break;
         }
@@ -2365,7 +2368,6 @@ int whisper_full(
 
        bool done = false;
        int seek_delta = 100*WHISPER_CHUNK_SIZE;
-       whisper_token last_id = 0;
 
        // print the prompt
        //printf("\n\n");
@@ -2395,8 +2397,6 @@ int whisper_full(
        // feel free to experiment!
        //
        {
-           const int n_vocab = whisper_n_vocab(ctx);
-
            whisper_token id  = 0;
            whisper_token tid = whisper_token_beg(ctx);
@@ -2410,7 +2410,6 @@ int whisper_full(
                    seek_delta = 2*(id - whisper_token_beg(ctx));
                    result_len = i + 1;
                }
-               last_id = id;
 
                // add it to the context
                prompt.push_back(id);
@@ -2444,7 +2443,7 @@ int whisper_full(
 
        std::string text = "";
 
-       for (int i = 0; i < result_cur.size(); i++) {
+       for (int i = 0; i < (int) result_cur.size(); i++) {
           if (params.print_special_tokens == false && result_cur[i].id >= whisper_token_eot(ctx)) {
           } else {
               text += whisper_token_to_str(ctx, result_cur[i].id);
@@ -2464,7 +2463,7 @@ int whisper_full(
                   result_all.push_back({ t0, t1, text });
               }
               text = "";
-              while (result_cur[i].id > whisper_token_beg(ctx) && i < result_cur.size()) {
+              while (result_cur[i].id > whisper_token_beg(ctx) && i < (int) result_cur.size()) {
                   i++;
               }
               i--;
diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h
index 465a9b6..5b7b258 100644
--- a/include/ggml/ggml.h
+++ b/include/ggml/ggml.h
@@ -108,7 +108,7 @@ struct ggml_tensor {
     int64_t perf_time_us;
 
     void * data;
-    char pad[8];
+    char padding[8];
 };
 
 // computation graph
diff --git a/src/ggml.c b/src/ggml.c
index 9b18d81..5894489 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -1,5 +1,6 @@
 #include "ggml.h"
 
+#include <stdint.h>
 #include <assert.h>
 #include <time.h>
 #include <math.h>
@@ -12,7 +13,12 @@
 #include <pthread.h>
 
 #define GGML_DEBUG 0
-#define GGML_MEM_ALIGN 16
+
+#if UINTPTR_MAX == 0xFFFFFFFF
+    #define GGML_MEM_ALIGN 4
+#else
+    #define GGML_MEM_ALIGN 16
+#endif
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -305,6 +311,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 #ifdef __ARM_NEON
     const int n32 = (n & ~31);
 
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
     float16x8_t sum0 = vdupq_n_f16(0);
     float16x8_t sum1 = vdupq_n_f16(0);
     float16x8_t sum2 = vdupq_n_f16(0);
@@ -344,6 +351,61 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 
     float32x2_t sumf32 = vadd_f32(vget_low_f32(sum0f32), vget_high_f32(sum0f32));
     sumf = vget_lane_f32(sumf32, 0) + vget_lane_f32(sumf32, 1);
+#else
+    float32x4_t sum0 = vdupq_n_f32(0);
+    float32x4_t sum1 = vdupq_n_f32(0);
+    float32x4_t sum2 = vdupq_n_f32(0);
+    float32x4_t sum3 = vdupq_n_f32(0);
+    float32x4_t sum4 = vdupq_n_f32(0);
+    float32x4_t sum5 = vdupq_n_f32(0);
+    float32x4_t sum6 = vdupq_n_f32(0);
+    float32x4_t sum7 = vdupq_n_f32(0);
+
+    float32x4_t x0, x1, x2, x3, x4, x5, x6, x7;
+    float32x4_t y0, y1, y2, y3, y4, y5, y6, y7;
+
+    for (int i = 0; i < n32; i += 32) {
+        x0 = vcvt_f32_f16(vld1_f16(x + i + 0 ));
+        x1 = vcvt_f32_f16(vld1_f16(x + i + 4 ));
+        x2 = vcvt_f32_f16(vld1_f16(x + i + 8 ));
+        x3 = vcvt_f32_f16(vld1_f16(x + i + 12));
+        x4 = vcvt_f32_f16(vld1_f16(x + i + 16));
+        x5 = vcvt_f32_f16(vld1_f16(x + i + 20));
+        x6 = vcvt_f32_f16(vld1_f16(x + i + 24));
+        x7 = vcvt_f32_f16(vld1_f16(x + i + 28));
+
+        y0 = vcvt_f32_f16(vld1_f16(y + i + 0 ));
+        y1 = vcvt_f32_f16(vld1_f16(y + i + 4 ));
+        y2 = vcvt_f32_f16(vld1_f16(y + i + 8 ));
+        y3 = vcvt_f32_f16(vld1_f16(y + i + 12));
+        y4 = vcvt_f32_f16(vld1_f16(y + i + 16));
+        y5 = vcvt_f32_f16(vld1_f16(y + i + 20));
+        y6 = vcvt_f32_f16(vld1_f16(y + i + 24));
+        y7 = vcvt_f32_f16(vld1_f16(y + i + 28));
+
+        sum0 = vfmaq_f32(sum0, x0, y0);
+        sum1 = vfmaq_f32(sum1, x1, y1);
+        sum2 = vfmaq_f32(sum2, x2, y2);
+        sum3 = vfmaq_f32(sum3, x3, y3);
+        sum4 = vfmaq_f32(sum4, x4, y4);
+        sum5 = vfmaq_f32(sum5, x5, y5);
+        sum6 = vfmaq_f32(sum6, x6, y6);
+        sum7 = vfmaq_f32(sum7, x7, y7);
+    }
+
+    // reduce sum0..sum7 to sum0
+    sum0 = vaddq_f32(sum0, sum1);
+    sum2 = vaddq_f32(sum2, sum3);
+    sum4 = vaddq_f32(sum4, sum5);
+    sum6 = vaddq_f32(sum6, sum7);
+    sum0 = vaddq_f32(sum0, sum2);
+    sum4 = vaddq_f32(sum4, sum6);
+    sum0 = vaddq_f32(sum0, sum4);
+
+    // reduce sum0 to sumf
+    float32x2_t sumf32 = vadd_f32(vget_low_f32(sum0), vget_high_f32(sum0));
+    sumf = vget_lane_f32(sumf32, 0) + vget_lane_f32(sumf32, 1);
+#endif
 
     // leftovers
     for (int i = n32; i < n; ++i) {
@@ -486,6 +548,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
     // NEON 128-bit
     const int n32 = (n & ~31);
 
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
     const float16x8_t v8 = vdupq_n_f16(v);
 
     float16x8_t x0, x1, x2, x3;
@@ -512,6 +575,51 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
         vst1q_f16(y + i + 16, y2);
         vst1q_f16(y + i + 24, y3);
     }
+#else
+    const float32x4_t v40 = vdupq_n_f32(v);
+    const float32x4_t v41 = vdupq_n_f32(v);
+
+    float32x4_t x0, x1, x2, x3, x4, x5, x6, x7;
+    float32x4_t y0, y1, y2, y3, y4, y5, y6, y7;
+
+    for (int i = 0; i < n32; i += 32) {
+        y0 = vcvt_f32_f16(vld1_f16(y + i + 0 ));
+        y1 = vcvt_f32_f16(vld1_f16(y + i + 4 ));
+        y2 = vcvt_f32_f16(vld1_f16(y + i + 8 ));
+        y3 = vcvt_f32_f16(vld1_f16(y + i + 12));
+        y4 = vcvt_f32_f16(vld1_f16(y + i + 16));
+        y5 = vcvt_f32_f16(vld1_f16(y + i + 20));
+        y6 = vcvt_f32_f16(vld1_f16(y + i + 24));
+        y7 = vcvt_f32_f16(vld1_f16(y + i + 28));
+
+        x0 = vcvt_f32_f16(vld1_f16(x + i + 0 ));
+        x1 = vcvt_f32_f16(vld1_f16(x + i + 4 ));
+        x2 = vcvt_f32_f16(vld1_f16(x + i + 8 ));
+        x3 = vcvt_f32_f16(vld1_f16(x + i + 12));
+        x4 = vcvt_f32_f16(vld1_f16(x + i + 16));
+        x5 = vcvt_f32_f16(vld1_f16(x + i + 20));
+        x6 = vcvt_f32_f16(vld1_f16(x + i + 24));
+        x7 = vcvt_f32_f16(vld1_f16(x + i + 28));
+
+        y0 = vfmaq_f32(y0, x0, v40);
+        y1 = vfmaq_f32(y1, x1, v40);
+        y2 = vfmaq_f32(y2, x2, v40);
+        y3 = vfmaq_f32(y3, x3, v40);
+        y4 = vfmaq_f32(y4, x4, v41);
+        y5 = vfmaq_f32(y5, x5, v41);
+        y6 = vfmaq_f32(y6, x6, v41);
+        y7 = vfmaq_f32(y7, x7, v41);
+
+        vst1_f16(y + i + 0 , vcvt_f16_f32(y0));
+        vst1_f16(y + i + 4 , vcvt_f16_f32(y1));
+        vst1_f16(y + i + 8 , vcvt_f16_f32(y2));
+        vst1_f16(y + i + 12, vcvt_f16_f32(y3));
+        vst1_f16(y + i + 16, vcvt_f16_f32(y4));
+        vst1_f16(y + i + 20, vcvt_f16_f32(y5));
+        vst1_f16(y + i + 24, vcvt_f16_f32(y6));
+        vst1_f16(y + i + 28, vcvt_f16_f32(y7));
+    }
+#endif
 
     // leftovers
     for (int i = n32; i < n; ++i) {
@@ -911,16 +1019,18 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
     if (is_first_call) {
         const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
 
+        ggml_fp16_t ii;
         for (int i = 0; i < (1 << 16); ++i) {
-            uint16_t ii = (uint16_t) i;
-            const float f = ggml_fp16_to_fp32(*(ggml_fp16_t *)(&ii));
+            uint16_t ui = i;
+            memcpy(&ii, &ui, sizeof(ii));
+            const float f = ggml_fp16_to_fp32(ii);
             table_gelu_f16[i] = ggml_fp32_to_fp16(ggml_gelu_f32(f));
             table_exp_f16[i]  = ggml_fp32_to_fp16(exp(f));
         }
 
         const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
 
-        GGML_PRINT_DEBUG("%s: GELU table initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
+        GGML_PRINT_DEBUG("%s: GELU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
 
         is_first_call = false;
     }
@@ -4427,13 +4537,15 @@ void ggml_compute_forward_soft_max_f32(
 
         ggml_float sum = 0.0;
 
+        uint16_t ss;
         for (int i = 0; i < nc; i++) {
             if (p[i] == -INFINITY) {
                 p[i] = 0.0;
             } else {
                 //const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max);
                 ggml_fp16_t s = ggml_fp32_to_fp16(p[i] - max);
-                const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]);
+                memcpy(&ss, &s, sizeof(ss));
+                const float val = ggml_fp16_to_fp32(table_exp_f16[ss]);
                 sum += val;
                 p[i] = val;
             }
@@ -5234,13 +5346,15 @@ void ggml_compute_forward_flash_attn_f32(
 
         ggml_float sum = 0.0;
 
+        uint16_t ss;
         for (int i = 0; i < M; i++) {
             if (S[i] == -INFINITY) {
                 S[i] = 0.0;
             } else {
                 //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
                 ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max);
-                const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]);
+                memcpy(&ss, &s, sizeof(ss));
+                const float val = ggml_fp16_to_fp32(table_exp_f16[ss]);
                 sum += val;
                 S[i] = val;
             }
@@ -5413,13 +5527,15 @@ void ggml_compute_forward_flash_attn_f16(
 
         ggml_float sum = 0.0;
 
+        uint16_t ss;
         for (int i = 0; i < M; i++) {
             if (S[i] == -INFINITY) {
                 S[i] = 0.0;
             } else {
                 //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
                 ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max);
-                const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]);
+                memcpy(&ss, &s, sizeof(ss));
+                const float val = ggml_fp16_to_fp32(table_exp_f16[ss]);
                 sum += val;
                 S[i] = val;
             }
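
Notes on the changes:

- main.cpp now forwards the translate, language, and n_threads options into
  whisper_full_params, so the corresponding command-line flags actually take
  effect; the sample-conversion loops also switch to int indices.
- whisper_full now breaks out of its loop once fewer than 100 mel frames
  remain (one second of audio, at 100 frames per second), and the unused
  n_vocab and last_id variables are removed.
- char pad[8] in struct ggml_tensor is renamed to padding[8].

The repeated rewrite of `struct ggml_cgraph gf = { .n_threads = n_threads };`
into `gf = {}; gf.n_threads = n_threads;` is a portability fix: designated
initializers are a C99 feature that C++ only adopted in C++20, and
whisper.cpp is compiled as C++. Value-initializing the struct and then
assigning the field is the portable spelling. A minimal sketch with a
stand-in struct (not the real ggml_cgraph):

    struct graph {
        int n_threads;
        int n_nodes;
    };

    int main() {
        // graph gf = { .n_threads = 4 };  // C99 style; needs C++20 or a compiler extension
        graph gf = {};    // value-initialization: every member zeroed
        gf.n_threads = 4; // then set the one field that matters
        return (gf.n_threads == 4 && gf.n_nodes == 0) ? 0 : 1;
    }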
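
The GGML_MEM_ALIGN change keys the allocation alignment off UINTPTR_MAX: on
a 32-bit target a pointer is 32 bits wide, so UINTPTR_MAX equals 0xFFFFFFFF
and 4-byte alignment suffices, while 64-bit targets keep 16-byte alignment.
The check is resolved entirely by the preprocessor, so it costs nothing at
runtime. A self-contained sketch of the same detection (MEM_ALIGN is a
stand-in name for the real macro):

    #include <cstdint>
    #include <cstdio>

    // pick the allocation alignment from the pointer width,
    // decided at preprocessing time
    #if UINTPTR_MAX == 0xFFFFFFFF
        #define MEM_ALIGN 4   // 32-bit target
    #else
        #define MEM_ALIGN 16  // 64-bit target
    #endif

    int main() {
        printf("sizeof(void *) = %zu, MEM_ALIGN = %d\n", sizeof(void *), MEM_ALIGN);
        return 0;
    }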
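
The `*(uint16_t *) &s` casts in ggml_init, the softmax, and both
flash-attention kernels are replaced by memcpy. Reading an object through an
incompatible pointer type violates strict aliasing and is undefined
behavior; copying the bytes is well defined, and compilers reduce a
fixed-size memcpy like this to a plain register move. The same idiom,
sketched here with float and uint32_t since standard C++ has no built-in
fp16 type:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // return the raw bit pattern of a float without an aliasing
    // violation: memcpy the bytes instead of casting the pointer
    static uint32_t float_bits(float f) {
        uint32_t u;
        memcpy(&u, &f, sizeof(u)); // optimized to a single move
        return u;
    }

    int main() {
        printf("bits of 1.0f = 0x%08x\n", float_bits(1.0f)); // prints 0x3f800000
        return 0;
    }

The same trick runs in the other direction in ggml_init, where each of the
65536 possible fp16 bit patterns is converted to float once so that exp()
and GELU later become single table lookups.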
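
The new #else branches in ggml_vec_dot_f16 and ggml_vec_mad_f16 cover ARM
NEON cores without native fp16 vector arithmetic (no
__ARM_FEATURE_FP16_VECTOR_ARITHMETIC): groups of four half floats are
widened with vcvt_f32_f16, accumulated in fp32 with vfmaq_f32, and, for the
mad, narrowed back with vcvt_f16_f32. A reduced sketch of the dot product,
processing 4 elements per step instead of the patch's 32 (compiles only on
ARM NEON; __fp16 is the storage-only half type):

    #if defined(__ARM_NEON)
    #include <arm_neon.h>

    // fp16 dot product with fp32 accumulation, for cores lacking
    // native fp16 vector arithmetic
    float dot_f16(const __fp16 * x, const __fp16 * y, int n) {
        float32x4_t acc = vdupq_n_f32(0);
        int i = 0;
        for (; i + 4 <= n; i += 4) {
            float32x4_t xf = vcvt_f32_f16(vld1_f16(x + i)); // widen 4 halves to fp32
            float32x4_t yf = vcvt_f32_f16(vld1_f16(y + i));
            acc = vfmaq_f32(acc, xf, yf);                   // acc += x*y, fused
        }
        // horizontal reduction, same shape as in the patch
        float32x2_t r = vadd_f32(vget_low_f32(acc), vget_high_f32(acc));
        float sum = vget_lane_f32(r, 0) + vget_lane_f32(r, 1);
        for (; i < n; ++i) {
            sum += (float) x[i] * (float) y[i]; // scalar leftovers
        }
        return sum;
    }
    #endif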