whisper : various improvements

experiments/blocking
Georgi Gerganov 2 years ago
parent 8ca553add4
commit 8e3c634b27
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -149,11 +149,11 @@ int main(int argc, char ** argv) {
// convert to mono, float // convert to mono, float
pcmf32.resize(n); pcmf32.resize(n);
if (wav.channels == 1) { if (wav.channels == 1) {
for (size_t i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f; pcmf32[i] = float(pcm16[i])/32768.0f;
} }
} else { } else {
for (size_t i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f; pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
} }
} }
@ -185,6 +185,9 @@ int main(int argc, char ** argv) {
wparams.print_progress = false; wparams.print_progress = false;
wparams.print_timestamps = !params.no_timestamps; wparams.print_timestamps = !params.no_timestamps;
wparams.print_special_tokens = params.print_special_tokens; wparams.print_special_tokens = params.print_special_tokens;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]); fprintf(stderr, "%s: failed to process audio\n", argv[0]);

@ -1031,8 +1031,6 @@ bool whisper_encode(
const auto & mel_inp = wctx.mel; const auto & mel_inp = wctx.mel;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
const int n_vocab = hparams.n_vocab;
const int n_ctx = hparams.n_audio_ctx; const int n_ctx = hparams.n_audio_ctx;
const int n_state = hparams.n_audio_state; const int n_state = hparams.n_audio_state;
const int n_head = hparams.n_audio_head; const int n_head = hparams.n_audio_head;
@ -1293,7 +1291,8 @@ bool whisper_encode(
struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF); struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
{ {
struct ggml_cgraph gf = { .n_threads = n_threads }; struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
ggml_build_forward_expand(&gf, inpO); ggml_build_forward_expand(&gf, inpO);
ggml_graph_compute (ctxL, &gf); ggml_graph_compute (ctxL, &gf);
@ -1329,7 +1328,8 @@ bool whisper_encode(
// run the computation // run the computation
{ {
struct ggml_cgraph gf = { .n_threads = n_threads }; struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
ggml_build_forward_expand(&gf, cur); ggml_build_forward_expand(&gf, cur);
ggml_graph_compute (ctx0, &gf); ggml_graph_compute (ctx0, &gf);
@ -1353,7 +1353,8 @@ bool whisper_encode(
// pre-compute cross-attention memory // pre-compute cross-attention memory
{ {
struct ggml_cgraph gf = { .n_threads = n_threads }; struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
// TODO: hack to disconnect the encoded features from the previous graph // TODO: hack to disconnect the encoded features from the previous graph
cur->op = GGML_OP_NONE; cur->op = GGML_OP_NONE;
@ -1463,7 +1464,8 @@ bool whisper_decode(
}; };
struct ggml_context * ctxL = ggml_init(paramsL); struct ggml_context * ctxL = ggml_init(paramsL);
struct ggml_cgraph gf = { .n_threads = n_threads }; struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
// norm // norm
{ {
@ -1746,7 +1748,8 @@ bool whisper_decode(
// run the computation // run the computation
{ {
struct ggml_cgraph gf = { .n_threads = n_threads }; struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
ggml_build_forward_expand(&gf, cur); ggml_build_forward_expand(&gf, cur);
ggml_graph_compute (ctx0, &gf); ggml_graph_compute (ctx0, &gf);
@ -2336,7 +2339,7 @@ int whisper_full(
} }
} }
if (seek >= whisper_n_len(ctx)) { if (seek + 100 >= whisper_n_len(ctx)) {
break; break;
} }
@ -2365,7 +2368,6 @@ int whisper_full(
bool done = false; bool done = false;
int seek_delta = 100*WHISPER_CHUNK_SIZE; int seek_delta = 100*WHISPER_CHUNK_SIZE;
whisper_token last_id = 0;
// print the prompt // print the prompt
//printf("\n\n"); //printf("\n\n");
@ -2395,8 +2397,6 @@ int whisper_full(
// feel free to experiment! // feel free to experiment!
// //
{ {
const int n_vocab = whisper_n_vocab(ctx);
whisper_token id = 0; whisper_token id = 0;
whisper_token tid = whisper_token_beg(ctx); whisper_token tid = whisper_token_beg(ctx);
@ -2410,7 +2410,6 @@ int whisper_full(
seek_delta = 2*(id - whisper_token_beg(ctx)); seek_delta = 2*(id - whisper_token_beg(ctx));
result_len = i + 1; result_len = i + 1;
} }
last_id = id;
// add it to the context // add it to the context
prompt.push_back(id); prompt.push_back(id);
@ -2444,7 +2443,7 @@ int whisper_full(
std::string text = ""; std::string text = "";
for (int i = 0; i < result_cur.size(); i++) { for (int i = 0; i < (int) result_cur.size(); i++) {
if (params.print_special_tokens == false && result_cur[i].id >= whisper_token_eot(ctx)) { if (params.print_special_tokens == false && result_cur[i].id >= whisper_token_eot(ctx)) {
} else { } else {
text += whisper_token_to_str(ctx, result_cur[i].id); text += whisper_token_to_str(ctx, result_cur[i].id);
@ -2464,7 +2463,7 @@ int whisper_full(
result_all.push_back({ t0, t1, text }); result_all.push_back({ t0, t1, text });
} }
text = ""; text = "";
while (result_cur[i].id > whisper_token_beg(ctx) && i < result_cur.size()) { while (result_cur[i].id > whisper_token_beg(ctx) && i < (int) result_cur.size()) {
i++; i++;
} }
i--; i--;

@ -108,7 +108,7 @@ struct ggml_tensor {
int64_t perf_time_us; int64_t perf_time_us;
void * data; void * data;
char pad[8]; char padding[8];
}; };
// computation graph // computation graph

@ -1,5 +1,6 @@
#include "ggml.h" #include "ggml.h"
#include <alloca.h>
#include <assert.h> #include <assert.h>
#include <time.h> #include <time.h>
#include <math.h> #include <math.h>
@ -12,7 +13,12 @@
#include <pthread.h> #include <pthread.h>
#define GGML_DEBUG 0 #define GGML_DEBUG 0
#define GGML_MEM_ALIGN 16
#if UINTPTR_MAX == 0xFFFFFFFF
#define GGML_MEM_ALIGN 4
#else
#define GGML_MEM_ALIGN 16
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
@ -305,6 +311,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
#ifdef __ARM_NEON #ifdef __ARM_NEON
const int n32 = (n & ~31); const int n32 = (n & ~31);
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
float16x8_t sum0 = vdupq_n_f16(0); float16x8_t sum0 = vdupq_n_f16(0);
float16x8_t sum1 = vdupq_n_f16(0); float16x8_t sum1 = vdupq_n_f16(0);
float16x8_t sum2 = vdupq_n_f16(0); float16x8_t sum2 = vdupq_n_f16(0);
@ -344,6 +351,61 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
float32x2_t sumf32 = vadd_f32(vget_low_f32(sum0f32), vget_high_f32(sum0f32)); float32x2_t sumf32 = vadd_f32(vget_low_f32(sum0f32), vget_high_f32(sum0f32));
sumf = vget_lane_f32(sumf32, 0) + vget_lane_f32(sumf32, 1); sumf = vget_lane_f32(sumf32, 0) + vget_lane_f32(sumf32, 1);
#else
float32x4_t sum0 = vdupq_n_f32(0);
float32x4_t sum1 = vdupq_n_f32(0);
float32x4_t sum2 = vdupq_n_f32(0);
float32x4_t sum3 = vdupq_n_f32(0);
float32x4_t sum4 = vdupq_n_f32(0);
float32x4_t sum5 = vdupq_n_f32(0);
float32x4_t sum6 = vdupq_n_f32(0);
float32x4_t sum7 = vdupq_n_f32(0);
float32x4_t x0, x1, x2, x3, x4, x5, x6, x7;
float32x4_t y0, y1, y2, y3, y4, y5, y6, y7;
for (int i = 0; i < n32; i += 32) {
x0 = vcvt_f32_f16(vld1_f16(x + i + 0 ));
x1 = vcvt_f32_f16(vld1_f16(x + i + 4 ));
x2 = vcvt_f32_f16(vld1_f16(x + i + 8 ));
x3 = vcvt_f32_f16(vld1_f16(x + i + 12));
x4 = vcvt_f32_f16(vld1_f16(x + i + 16));
x5 = vcvt_f32_f16(vld1_f16(x + i + 20));
x6 = vcvt_f32_f16(vld1_f16(x + i + 24));
x7 = vcvt_f32_f16(vld1_f16(x + i + 28));
y0 = vcvt_f32_f16(vld1_f16(y + i + 0 ));
y1 = vcvt_f32_f16(vld1_f16(y + i + 4 ));
y2 = vcvt_f32_f16(vld1_f16(y + i + 8 ));
y3 = vcvt_f32_f16(vld1_f16(y + i + 12));
y4 = vcvt_f32_f16(vld1_f16(y + i + 16));
y5 = vcvt_f32_f16(vld1_f16(y + i + 20));
y6 = vcvt_f32_f16(vld1_f16(y + i + 24));
y7 = vcvt_f32_f16(vld1_f16(y + i + 28));
sum0 = vfmaq_f32(sum0, x0, y0);
sum1 = vfmaq_f32(sum1, x1, y1);
sum2 = vfmaq_f32(sum2, x2, y2);
sum3 = vfmaq_f32(sum3, x3, y3);
sum4 = vfmaq_f32(sum4, x4, y4);
sum5 = vfmaq_f32(sum5, x5, y5);
sum6 = vfmaq_f32(sum6, x6, y6);
sum7 = vfmaq_f32(sum7, x7, y7);
}
// reduce sum0..sum7 to sum0
sum0 = vaddq_f32(sum0, sum1);
sum2 = vaddq_f32(sum2, sum3);
sum4 = vaddq_f32(sum4, sum5);
sum6 = vaddq_f32(sum6, sum7);
sum0 = vaddq_f32(sum0, sum2);
sum4 = vaddq_f32(sum4, sum6);
sum0 = vaddq_f32(sum0, sum4);
// reduce sum0 to sumf
float32x2_t sumf32 = vadd_f32(vget_low_f32(sum0), vget_high_f32(sum0));
sumf = vget_lane_f32(sumf32, 0) + vget_lane_f32(sumf32, 1);
#endif
// leftovers // leftovers
for (int i = n32; i < n; ++i) { for (int i = n32; i < n; ++i) {
@ -486,6 +548,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
// NEON 128-bit // NEON 128-bit
const int n32 = (n & ~31); const int n32 = (n & ~31);
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
const float16x8_t v8 = vdupq_n_f16(v); const float16x8_t v8 = vdupq_n_f16(v);
float16x8_t x0, x1, x2, x3; float16x8_t x0, x1, x2, x3;
@ -512,6 +575,51 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
vst1q_f16(y + i + 16, y2); vst1q_f16(y + i + 16, y2);
vst1q_f16(y + i + 24, y3); vst1q_f16(y + i + 24, y3);
} }
#else
const float32x4_t v40 = vdupq_n_f32(v);
const float32x4_t v41 = vdupq_n_f32(v);
float32x4_t x0, x1, x2, x3, x4, x5, x6, x7;
float32x4_t y0, y1, y2, y3, y4, y5, y6, y7;
for (int i = 0; i < n32; i += 32) {
y0 = vcvt_f32_f16(vld1_f16(y + i + 0 ));
y1 = vcvt_f32_f16(vld1_f16(y + i + 4 ));
y2 = vcvt_f32_f16(vld1_f16(y + i + 8 ));
y3 = vcvt_f32_f16(vld1_f16(y + i + 12));
y4 = vcvt_f32_f16(vld1_f16(y + i + 16));
y5 = vcvt_f32_f16(vld1_f16(y + i + 20));
y6 = vcvt_f32_f16(vld1_f16(y + i + 24));
y7 = vcvt_f32_f16(vld1_f16(y + i + 28));
x0 = vcvt_f32_f16(vld1_f16(x + i + 0 ));
x1 = vcvt_f32_f16(vld1_f16(x + i + 4 ));
x2 = vcvt_f32_f16(vld1_f16(x + i + 8 ));
x3 = vcvt_f32_f16(vld1_f16(x + i + 12));
x4 = vcvt_f32_f16(vld1_f16(x + i + 16));
x5 = vcvt_f32_f16(vld1_f16(x + i + 20));
x6 = vcvt_f32_f16(vld1_f16(x + i + 24));
x7 = vcvt_f32_f16(vld1_f16(x + i + 28));
y0 = vfmaq_f32(y0, x0, v40);
y1 = vfmaq_f32(y1, x1, v40);
y2 = vfmaq_f32(y2, x2, v40);
y3 = vfmaq_f32(y3, x3, v40);
y4 = vfmaq_f32(y4, x4, v41);
y5 = vfmaq_f32(y5, x5, v41);
y6 = vfmaq_f32(y6, x6, v41);
y7 = vfmaq_f32(y7, x7, v41);
vst1_f16(y + i + 0 , vcvt_f16_f32(y0));
vst1_f16(y + i + 4 , vcvt_f16_f32(y1));
vst1_f16(y + i + 8 , vcvt_f16_f32(y2));
vst1_f16(y + i + 12, vcvt_f16_f32(y3));
vst1_f16(y + i + 16, vcvt_f16_f32(y4));
vst1_f16(y + i + 20, vcvt_f16_f32(y5));
vst1_f16(y + i + 24, vcvt_f16_f32(y6));
vst1_f16(y + i + 28, vcvt_f16_f32(y7));
}
#endif
// leftovers // leftovers
for (int i = n32; i < n; ++i) { for (int i = n32; i < n; ++i) {
@ -911,16 +1019,18 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
if (is_first_call) { if (is_first_call) {
const uint64_t t_start = ggml_time_us(); UNUSED(t_start); const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
ggml_fp16_t ii;
for (int i = 0; i < (1 << 16); ++i) { for (int i = 0; i < (1 << 16); ++i) {
uint16_t ii = (uint16_t) i; uint16_t ui = i;
const float f = ggml_fp16_to_fp32(*(ggml_fp16_t *)(&ii)); memcpy(&ii, &ui, sizeof(ii));
const float f = ggml_fp16_to_fp32(ii);
table_gelu_f16[i] = ggml_fp32_to_fp16(ggml_gelu_f32(f)); table_gelu_f16[i] = ggml_fp32_to_fp16(ggml_gelu_f32(f));
table_exp_f16[i] = ggml_fp32_to_fp16(exp(f)); table_exp_f16[i] = ggml_fp32_to_fp16(exp(f));
} }
const uint64_t t_end = ggml_time_us(); UNUSED(t_end); const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
GGML_PRINT_DEBUG("%s: GELU table initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); GGML_PRINT_DEBUG("%s: GELU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
is_first_call = false; is_first_call = false;
} }
@ -4427,13 +4537,15 @@ void ggml_compute_forward_soft_max_f32(
ggml_float sum = 0.0; ggml_float sum = 0.0;
uint16_t ss;
for (int i = 0; i < nc; i++) { for (int i = 0; i < nc; i++) {
if (p[i] == -INFINITY) { if (p[i] == -INFINITY) {
p[i] = 0.0; p[i] = 0.0;
} else { } else {
//const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max); //const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max);
ggml_fp16_t s = ggml_fp32_to_fp16(p[i] - max); ggml_fp16_t s = ggml_fp32_to_fp16(p[i] - max);
const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]); memcpy(&ss, &s, sizeof(ss));
const float val = ggml_fp16_to_fp32(table_exp_f16[ss]);
sum += val; sum += val;
p[i] = val; p[i] = val;
} }
@ -5234,13 +5346,15 @@ void ggml_compute_forward_flash_attn_f32(
ggml_float sum = 0.0; ggml_float sum = 0.0;
uint16_t ss;
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
if (S[i] == -INFINITY) { if (S[i] == -INFINITY) {
S[i] = 0.0; S[i] = 0.0;
} else { } else {
//const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max); //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max); ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max);
const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]); memcpy(&ss, &s, sizeof(ss));
const float val = ggml_fp16_to_fp32(table_exp_f16[ss]);
sum += val; sum += val;
S[i] = val; S[i] = val;
} }
@ -5413,13 +5527,15 @@ void ggml_compute_forward_flash_attn_f16(
ggml_float sum = 0.0; ggml_float sum = 0.0;
uint16_t ss;
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
if (S[i] == -INFINITY) { if (S[i] == -INFINITY) {
S[i] = 0.0; S[i] = 0.0;
} else { } else {
//const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max); //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max); ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max);
const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]); memcpy(&ss, &s, sizeof(ss));
const float val = ggml_fp16_to_fp32(table_exp_f16[ss]);
sum += val; sum += val;
S[i] = val; S[i] = val;
} }

Loading…
Cancel
Save