diff --git a/examples/command/command.cpp b/examples/command/command.cpp index 74a14f9..4558a67 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -781,7 +781,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi std::string prompt; std::string command; - for (int i = 0; i < words.size(); ++i) { + for (int i = 0; i < (int) words.size(); ++i) { if (i < k_prompt_length) { prompt += words[i] + " "; } else { diff --git a/ggml.c b/ggml.c index eefdcdd..ccbd6c7 100644 --- a/ggml.c +++ b/ggml.c @@ -84,7 +84,7 @@ typedef void* thread_ret_t; #define GGML_GELU_FP16 #define GGML_SOFT_MAX_UNROLL 4 -#define GGML_VEC_DOT_UNROLL 4 +#define GGML_VEC_DOT_UNROLL 2 #ifdef GGML_USE_ACCELERATE // uncomment to use vDSP for soft max computation @@ -923,9 +923,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 }; - const ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL] = { xv }; + ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL]; - for (int i = 1; i < GGML_VEC_DOT_UNROLL; ++i) { + for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { x[i] = (ggml_fp16_t *) ((char *) xv + i*xs); } @@ -6158,40 +6158,37 @@ static void ggml_compute_forward_flash_attn_f16( S[i] = -INFINITY; } - // looks like unrolling here does not help -#if 1 - for (int ic = 0; ic < nek1; ++ic) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2; - const int ik1 = ic; - - // S indices - const int i1 = ik1; - - ggml_vec_dot_f16(neq0, - S + i1, - (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } -#else - GGML_ASSERT(nek1 % GGML_VEC_DOT_UNROLL == 0); - - for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2; - const int ik1 = ic; + if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { + for (int ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; - // S indices - const int i1 = ik1; + // S indices + const int i1 = ik1; - ggml_vec_dot_f16_unroll(neq0, nbk1, - S + i1, - ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + ggml_vec_dot_f16(neq0, + S + i1, + (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } else { + for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16_unroll(neq0, nbk1, + S + i1, + ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } } -#endif // scale ggml_vec_scale_f32(nek1, S, scale); @@ -6261,18 +6258,30 @@ static void ggml_compute_forward_flash_attn_f16( S16[i] = GGML_FP32_TO_FP16(S[i]); } - GGML_ASSERT(nev1 % GGML_VEC_DOT_UNROLL == 0); + if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { + for (int ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; - for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + ggml_vec_dot_f16(nek1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + S16); + } + } else { + for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; - ggml_vec_dot_f16_unroll(nek1, nbv1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), - S16); + ggml_vec_dot_f16_unroll(nek1, nbv1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + S16); + } } } }