From 74ffa14e1d22a34ac5e3f5751b90187f3459a7f1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 7 Jan 2023 17:32:23 +0200 Subject: [PATCH] ggml : unroll ggml_vec_dot_f16 in ggml_compute_forward_flash_attn_f16 --- ggml.c | 91 ++++++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 85 insertions(+), 6 deletions(-) diff --git a/ggml.c b/ggml.c index 7db762f..e627164 100644 --- a/ggml.c +++ b/ggml.c @@ -79,9 +79,11 @@ typedef void* thread_ret_t; #define static_assert(cond, msg) _Static_assert(cond, msg) #endif +/*#define GGML_PERF*/ #define GGML_DEBUG 0 #define GGML_GELU_FP16 #define GGML_SOFT_MAX_UNROLL 4 +#define GGML_VEC_DOT_UNROLL 4 #if UINTPTR_MAX == 0xFFFFFFFF #define GGML_MEM_ALIGN 4 @@ -909,6 +911,61 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t *s = sumf; } +// compute GGML_VEC_DOT_UNROLL dot products at once +// xs - x row stride in bytes +inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { + ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 }; + + const ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL] = { xv }; + + for (int i = 1; i < GGML_VEC_DOT_UNROLL; ++i) { + x[i] = (ggml_fp16_t *) ((char *) xv + i*xs); + } + +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } }; + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + + for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { + ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j); + + sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); + } + } + } + + // reduce sum0..sum3 to sum0 + for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { + GGML_F16_VEC_REDUCE(sumf[k], sum[k]); + } + + // leftovers + for (int i = np; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]); + } + } +#else + for (int i = 0; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]); + } + } +#endif + + for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { + s[i] = sumf[i]; + } +} + inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) { #if defined(GGML_SIMD) const int np = (n & ~(GGML_F32_STEP - 1)); @@ -4720,9 +4777,9 @@ static void ggml_compute_forward_mul_mat_f16_f32( float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); - for (int ic = 0; ic < ne11; ++ic) { - assert(ne00 % 32 == 0); + assert(ne00 % 32 == 0); + for (int ic = 0; ic < ne11; ++ic) { ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00); } } @@ -6092,6 +6149,8 @@ static void ggml_compute_forward_flash_attn_f16( S[i] = -INFINITY; } + // looks like unrolling here does not help +#if 1 for (int ic = 0; ic < nek1; ++ic) { // k indices const int ik3 = iq3; @@ -6106,6 +6165,24 @@ static void ggml_compute_forward_flash_attn_f16( (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); } +#else + GGML_ASSERT(nek1 % GGML_VEC_DOT_UNROLL == 0); + + for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16_unroll(neq0, nbk1, + S + i1, + ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } +#endif // scale ggml_vec_scale_f32(nek1, S, scale); @@ -6173,15 +6250,17 @@ static void ggml_compute_forward_flash_attn_f16( S16[i] = GGML_FP32_TO_FP16(S[i]); } - for (int ic = 0; ic < nev1; ++ic) { + GGML_ASSERT(nev1 % GGML_VEC_DOT_UNROLL == 0); + + for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { // dst indices const int i1 = iq1; const int i2 = iq2; const int i3 = iq3; - ggml_vec_dot_f16(nek1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + ggml_vec_dot_f16_unroll(nek1, nbv1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), S16); } }