|
|
@ -84,7 +84,7 @@ typedef void* thread_ret_t;
|
|
|
|
#define GGML_GELU_FP16
|
|
|
|
#define GGML_GELU_FP16
|
|
|
|
|
|
|
|
|
|
|
|
#define GGML_SOFT_MAX_UNROLL 4
|
|
|
|
#define GGML_SOFT_MAX_UNROLL 4
|
|
|
|
#define GGML_VEC_DOT_UNROLL 4
|
|
|
|
#define GGML_VEC_DOT_UNROLL 2
|
|
|
|
|
|
|
|
|
|
|
|
#ifdef GGML_USE_ACCELERATE
|
|
|
|
#ifdef GGML_USE_ACCELERATE
|
|
|
|
// uncomment to use vDSP for soft max computation
|
|
|
|
// uncomment to use vDSP for soft max computation
|
|
|
@ -923,9 +923,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
|
|
|
|
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
|
|
|
|
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
|
|
|
|
ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
|
|
|
|
ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
|
|
|
|
|
|
|
|
|
|
|
|
const ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL] = { xv };
|
|
|
|
ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 1; i < GGML_VEC_DOT_UNROLL; ++i) {
|
|
|
|
for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
|
|
|
|
x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
|
|
|
|
x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -6158,8 +6158,7 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
|
|
S[i] = -INFINITY;
|
|
|
|
S[i] = -INFINITY;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// looks like unrolling here does not help
|
|
|
|
if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
|
|
|
|
#if 1
|
|
|
|
|
|
|
|
for (int ic = 0; ic < nek1; ++ic) {
|
|
|
|
for (int ic = 0; ic < nek1; ++ic) {
|
|
|
|
// k indices
|
|
|
|
// k indices
|
|
|
|
const int ik3 = iq3;
|
|
|
|
const int ik3 = iq3;
|
|
|
@ -6174,9 +6173,7 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
|
|
(ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
|
|
|
(ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
|
|
|
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
|
|
|
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#else
|
|
|
|
} else {
|
|
|
|
GGML_ASSERT(nek1 % GGML_VEC_DOT_UNROLL == 0);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
|
|
|
|
for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
|
|
|
|
// k indices
|
|
|
|
// k indices
|
|
|
|
const int ik3 = iq3;
|
|
|
|
const int ik3 = iq3;
|
|
|
@ -6191,7 +6188,7 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
|
|
((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
|
|
|
((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
|
|
|
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
|
|
|
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#endif
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// scale
|
|
|
|
// scale
|
|
|
|
ggml_vec_scale_f32(nek1, S, scale);
|
|
|
|
ggml_vec_scale_f32(nek1, S, scale);
|
|
|
@ -6261,8 +6258,19 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
|
|
S16[i] = GGML_FP32_TO_FP16(S[i]);
|
|
|
|
S16[i] = GGML_FP32_TO_FP16(S[i]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(nev1 % GGML_VEC_DOT_UNROLL == 0);
|
|
|
|
if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
|
|
|
|
|
|
|
|
for (int ic = 0; ic < nev1; ++ic) {
|
|
|
|
|
|
|
|
// dst indices
|
|
|
|
|
|
|
|
const int i1 = iq1;
|
|
|
|
|
|
|
|
const int i2 = iq2;
|
|
|
|
|
|
|
|
const int i3 = iq3;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ggml_vec_dot_f16(nek1,
|
|
|
|
|
|
|
|
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
|
|
|
|
|
|
|
(ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
|
|
|
|
|
|
|
|
S16);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} else {
|
|
|
|
for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
|
|
|
|
for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
|
|
|
|
// dst indices
|
|
|
|
// dst indices
|
|
|
|
const int i1 = iq1;
|
|
|
|
const int i1 = iq1;
|
|
|
@ -6275,6 +6283,7 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
|
|
S16);
|
|
|
|
S16);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
static void ggml_compute_forward_flash_attn(
|
|
|
|
static void ggml_compute_forward_flash_attn(
|
|
|
|