ggml : sync with latest whisper.cpp

pull/15/head
Georgi Gerganov 1 year ago
parent 73a7916d30
commit 1af4cf0102
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -79,8 +79,11 @@ typedef void* thread_ret_t;
#define static_assert(cond, msg) _Static_assert(cond, msg)
#endif
/*#define GGML_PERF*/
#define GGML_DEBUG 0
#define GGML_GELU_FP16
#define GGML_SOFT_MAX_UNROLL 4
#define GGML_VEC_DOT_UNROLL 4
#if UINTPTR_MAX == 0xFFFFFFFF
#define GGML_MEM_ALIGN 4
@ -908,6 +911,61 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
*s = sumf;
}
// compute GGML_VEC_DOT_UNROLL dot products at once
// xs - x row stride in bytes
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
const ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL] = { xv };
for (int i = 1; i < GGML_VEC_DOT_UNROLL; ++i) {
x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
}
#if defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
GGML_F16_VEC ax[GGML_F16_ARR];
GGML_F16_VEC ay[GGML_F16_ARR];
for (int i = 0; i < np; i += GGML_F16_STEP) {
for (int j = 0; j < GGML_F16_ARR; j++) {
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
}
}
}
// reduce sum0..sum3 to sum0
for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
}
// leftovers
for (int i = np; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
}
}
#else
for (int i = 0; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
}
}
#endif
for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
s[i] = sumf[i];
}
}
inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
#if defined(GGML_SIMD)
const int np = (n & ~(GGML_F32_STEP - 1));
@ -1039,7 +1097,30 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
}
#endif
inline static void ggml_vec_sum_f32 (const int n, float * s, const float * x) { ggml_float sum = 0.0; for (int i = 0; i < n; ++i) sum += x[i]; *s += sum; }
inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
#ifndef GGML_USE_ACCELERATE
ggml_float sum = 0.0;
for (int i = 0; i < n; ++i) {
sum += x[i];
*s += sum;
}
#else
vDSP_sve(x, 1, s, n);
#endif
}
inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
#ifndef GGML_USE_ACCELERATE
ggml_float max = -INFINITY;
for (int i = 0; i < n; ++i) {
max = MAX(max, x[i]);
}
*s = max;
#else
vDSP_maxv(x, 1, s, n);
#endif
}
inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); }
//
@ -1293,25 +1374,25 @@ size_t ggml_element_size(const struct ggml_tensor * tensor) {
return GGML_TYPE_SIZE[tensor->type];
}
bool ggml_is_scalar(const struct ggml_tensor * tensor) {
static inline bool ggml_is_scalar(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
}
bool ggml_is_vector(const struct ggml_tensor * tensor) {
static inline bool ggml_is_vector(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
}
bool ggml_is_matrix(const struct ggml_tensor * tensor) {
static inline bool ggml_is_matrix(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return tensor->ne[2] == 1 && tensor->ne[3] == 1;
}
bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return
@ -1320,7 +1401,7 @@ bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor *
(t0->ne[3] == t1->ne[3]);
}
bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return
@ -1330,7 +1411,7 @@ bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
}
bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return
@ -1339,7 +1420,7 @@ bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
}
bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
static inline bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return
@ -1350,7 +1431,7 @@ bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor
}
// check if t1 can be represented as a repeatition of t0
bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return
@ -1360,14 +1441,20 @@ bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t
(t1->ne[3]%t0->ne[3] == 0);
}
int ggml_up32(int n) {
static inline int ggml_up32(int n) {
return (n + 31) & ~31;
}
int ggml_up64(int n) {
static inline int ggml_up64(int n) {
return (n + 63) & ~63;
}
static inline int ggml_up(int n, int m) {
// assert m is a power of 2
GGML_ASSERT((m & (m - 1)) == 0);
return (n + m - 1) & ~(m - 1);
}
// assert that pointer is aligned to GGML_MEM_ALIGN
#define ggml_assert_aligned(ptr) \
assert(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0)
@ -4658,7 +4745,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
// TODO: do not support transposed src1
assert(nb10/2 == sizeof(ggml_fp16_t));
// parallelize by src0 rows using ggml_vec_dot_f32
// parallelize by src0 rows using ggml_vec_dot_f16
// total rows in src0
const int nr = ne01*ne02*ne03;
@ -4686,13 +4773,13 @@ static void ggml_compute_forward_mul_mat_f16_f32(
const int i3 = i03;
ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
ggml_fp16_t * src1_col = wdata + (i13*ne12*ne11 + i12*ne11 + 0)*ne00;
ggml_fp16_t * src1_col = wdata + ( 0 + i12*ne11 + i13*ne12*ne11)*ne00;
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
for (int ic = 0; ic < ne11; ++ic) {
assert(ne00 % 32 == 0);
assert(ne00 % 32 == 0);
for (int ic = 0; ic < ne11; ++ic) {
ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
}
}
@ -5071,21 +5158,19 @@ static void ggml_compute_forward_soft_max_f32(
#endif
float max = -INFINITY;
for (int i = 0; i < nc; i++) {
max = MAX(max, p[i]);
}
ggml_vec_max_f32(nc, &max, p);
ggml_float sum = 0.0;
uint16_t ss;
uint16_t scvt;
for (int i = 0; i < nc; i++) {
if (p[i] == -INFINITY) {
p[i] = 0.0;
p[i] = 0.0f;
} else {
//const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max);
ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max);
memcpy(&ss, &s, sizeof(ss));
const float val = GGML_FP16_TO_FP32(table_exp_f16[ss]);
memcpy(&scvt, &s, sizeof(scvt));
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
sum += val;
p[i] = val;
}
@ -5797,6 +5882,8 @@ static void ggml_compute_forward_flash_attn_f32(
const int P = nek1 - N;
const int M = P + N;
const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
GGML_ASSERT(ne0 == D);
GGML_ASSERT(ne1 == N);
GGML_ASSERT(P >= 0);
@ -5849,7 +5936,11 @@ static void ggml_compute_forward_flash_attn_f32(
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
float * S = (float *) params->wdata + ith*(M + CACHE_LINE_SIZE_F32);
float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32);
for (int i = M; i < Mup; ++i) {
S[i] = -INFINITY;
}
for (int ic = 0; ic < nek1; ++ic) {
// k indices
@ -5880,30 +5971,50 @@ static void ggml_compute_forward_flash_attn_f32(
// softmax
{
float max = -INFINITY;
for (int i = 0; i < M; i++) {
max = MAX(max, S[i]);
}
ggml_vec_max_f32(M, &max, S);
ggml_float sum = 0.0;
float sum = 0.0f;
{
#ifndef GGML_USE_ACCELERATE
uint16_t scvt[GGML_SOFT_MAX_UNROLL];
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
uint16_t ss;
for (int i = 0; i < M; i++) {
if (S[i] == -INFINITY) {
S[i] = 0.0;
} else {
//const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
ggml_fp16_t s = GGML_FP32_TO_FP16(S[i] - max);
memcpy(&ss, &s, sizeof(ss));
const float val = GGML_FP16_TO_FP32(table_exp_f16[ss]);
sum += val;
S[i] = val;
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
float * SS = S + i;
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
if (SS[j] == -INFINITY) {
SS[j] = 0.0f;
} else {
ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
memcpy(&scvt[j], &s, sizeof(uint16_t));
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
sump[j] += val;
SS[j] = val;
}
}
}
for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
sum += sump[i];
}
#else
vvexpf(S, S, &Mup);
ggml_vec_sum_f32(Mup, &sum, S);
#endif
}
assert(sum > 0.0f);
sum = 1.0/sum;
ggml_vec_scale_f32(M, S, sum);
#ifndef NDEBUG
for (int i = 0; i < M; ++i) {
assert(!isnan(S[i]));
assert(!isinf(S[i]));
}
#endif
}
for (int ic = 0; ic < nev1; ++ic) {
@ -5978,6 +6089,8 @@ static void ggml_compute_forward_flash_attn_f16(
const int P = nek1 - N;
const int M = P + N;
const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
GGML_ASSERT(ne0 == D);
GGML_ASSERT(ne1 == N);
GGML_ASSERT(P >= 0);
@ -6030,8 +6143,14 @@ static void ggml_compute_forward_flash_attn_f16(
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
for (int i = M; i < Mup; ++i) {
S[i] = -INFINITY;
}
// looks like unrolling here does not help
#if 1
for (int ic = 0; ic < nek1; ++ic) {
// k indices
const int ik3 = iq3;
@ -6046,6 +6165,24 @@ static void ggml_compute_forward_flash_attn_f16(
(ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
}
#else
GGML_ASSERT(nek1 % GGML_VEC_DOT_UNROLL == 0);
for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
// k indices
const int ik3 = iq3;
const int ik2 = iq2;
const int ik1 = ic;
// S indices
const int i1 = ik1;
ggml_vec_dot_f16_unroll(neq0, nbk1,
S + i1,
((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
}
#endif
// scale
ggml_vec_scale_f32(nek1, S, scale);
@ -6061,30 +6198,50 @@ static void ggml_compute_forward_flash_attn_f16(
// softmax
{
float max = -INFINITY;
for (int i = 0; i < M; i++) {
max = MAX(max, S[i]);
}
ggml_vec_max_f32(M, &max, S);
float sum = 0.0f;
{
#ifndef GGML_USE_ACCELERATE
uint16_t scvt[GGML_SOFT_MAX_UNROLL];
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
ggml_float sum = 0.0;
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
float * SS = S + i;
uint16_t ss;
for (int i = 0; i < M; i++) {
if (S[i] == -INFINITY) {
S[i] = 0.0;
} else {
//const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
ggml_fp16_t s = GGML_FP32_TO_FP16(S[i] - max);
memcpy(&ss, &s, sizeof(ss));
const float val = GGML_FP16_TO_FP32(table_exp_f16[ss]);
sum += val;
S[i] = val;
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
if (SS[j] == -INFINITY) {
SS[j] = 0.0f;
} else {
ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
memcpy(&scvt[j], &s, sizeof(uint16_t));
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
sump[j] += val;
SS[j] = val;
}
}
}
for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
sum += sump[i];
}
#else
vvexpf(S, S, &Mup);
ggml_vec_sum_f32(Mup, &sum, S);
#endif
}
assert(sum > 0.0f);
sum = 1.0/sum;
ggml_vec_scale_f32(M, S, sum);
#ifndef NDEBUG
for (int i = 0; i < M; ++i) {
assert(!isnan(S[i]));
assert(!isinf(S[i]));
}
#endif
}
ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
@ -6093,15 +6250,17 @@ static void ggml_compute_forward_flash_attn_f16(
S16[i] = GGML_FP32_TO_FP16(S[i]);
}
for (int ic = 0; ic < nev1; ++ic) {
GGML_ASSERT(nev1 % GGML_VEC_DOT_UNROLL == 0);
for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
// dst indices
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
ggml_vec_dot_f16(nek1,
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
(ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
ggml_vec_dot_f16_unroll(nek1, nbv1,
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
S16);
}
}
@ -6983,7 +7142,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
}
if (state->node) {
ggml_compute_forward(&state->params, state->node);
if (state->params.ith < state->params.nth) {
ggml_compute_forward(&state->params, state->node);
}
state->node = NULL;
} else {
break;
@ -7077,9 +7238,15 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} break;
case GGML_OP_MUL_MAT:
{
// TODO: use different scheduling for different matrix sizes
node->n_tasks = n_threads;
// TODO: use different scheduling for different matrix sizes
//const int nr0 = ggml_nrows(node->src0);
//const int nr1 = ggml_nrows(node->src1);
//node->n_tasks = MIN(n_threads, MAX(1, nr0/128));
//printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks);
size_t cur = 0;
// TODO: better way to determine if the matrix is transposed
@ -7090,6 +7257,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
node->src1->type == GGML_TYPE_F32) {
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
node->n_tasks = 1;
cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]);
} else {
cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1);
@ -7165,14 +7333,16 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
size_t cur = 0;
const int ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
if (node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
}
if (node->src1->type == GGML_TYPE_F16) {
cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
}
work_size = MAX(work_size, cur);
@ -7261,7 +7431,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
workers[j].params = (struct ggml_compute_params) {
.type = GGML_TASK_COMPUTE,
.ith = j + 1,
.nth = n_threads,
.nth = node->n_tasks,
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
.wdata = cgraph->work ? cgraph->work->data : NULL,
};
@ -7316,7 +7486,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
workers[j].params = (struct ggml_compute_params) {
.type = GGML_TASK_FINALIZE,
.ith = j + 1,
.nth = n_threads,
.nth = node->n_tasks,
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
.wdata = cgraph->work ? cgraph->work->data : NULL,
};

Loading…
Cancel
Save