wip : unsuccessful attempts speeding mul_mat using blocking

Performance is slightly worse compared to the no-blocking
approach.  Not sure what I am doing wrong.
experiments/blocking
Georgi Gerganov 2 years ago
parent 67ac34fcfa
commit 3afb833f84
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -1307,7 +1307,7 @@ bool whisper_encode(
ggml_build_forward_expand(&gf, inpO); ggml_build_forward_expand(&gf, inpO);
ggml_graph_compute (ctxL, &gf); ggml_graph_compute (ctxL, &gf);
//ggml_graph_print(&gf); ggml_graph_print(&gf);
} }
// TODO: this is a hack to have per-layer computation graphs - need to come up with something better // TODO: this is a hack to have per-layer computation graphs - need to come up with something better

@ -195,6 +195,7 @@ int64_t ggml_cycles_per_ms(void) {
return CLOCKS_PER_SEC/1000; return CLOCKS_PER_SEC/1000;
} }
#define GGML_PERF
#ifdef GGML_PERF #ifdef GGML_PERF
#define ggml_perf_time_ms() ggml_time_ms() #define ggml_perf_time_ms() ggml_time_ms()
#define ggml_perf_time_us() ggml_time_us() #define ggml_perf_time_us() ggml_time_us()
@ -495,6 +496,154 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
*s = sumf; *s = sumf;
} }
// blocking
inline static void ggml_vec_dot_f16_blck32(const int n0, const int n1, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
float16x8_t x0 = vld1q_f16(x + 0 );
float16x8_t x1 = vld1q_f16(x + 8 );
float16x8_t x2 = vld1q_f16(x + 16);
float16x8_t x3 = vld1q_f16(x + 24);
float16x8_t y0, y1, y2, y3;
for (int i = 0; i < n1; ++i) {
y0 = vld1q_f16(y + 32*i + 0 );
y1 = vld1q_f16(y + 32*i + 8 );
y2 = vld1q_f16(y + 32*i + 16);
y3 = vld1q_f16(y + 32*i + 24);
y0 = vmulq_f16(x0, y0);
y1 = vmulq_f16(x1, y1);
y2 = vmulq_f16(x2, y2);
y3 = vmulq_f16(x3, y3);
y0 = vaddq_f16(y0, y1);
y2 = vaddq_f16(y2, y3);
y0 = vaddq_f16(y0, y2);
s[i*n0] += y0[0] + y0[1] + y0[2] + y0[3] + y0[4] + y0[5] + y0[6] + y0[7];
}
}
inline static void ggml_vec_dot_f16_blck64(const int n0, const int n1, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
float16x8_t x0 = vld1q_f16(x + 0 );
float16x8_t x1 = vld1q_f16(x + 8 );
float16x8_t x2 = vld1q_f16(x + 16);
float16x8_t x3 = vld1q_f16(x + 24);
float16x8_t x4 = vld1q_f16(x + 32);
float16x8_t x5 = vld1q_f16(x + 40);
float16x8_t x6 = vld1q_f16(x + 48);
float16x8_t x7 = vld1q_f16(x + 56);
float16x8_t y0, y1, y2, y3, y4, y5, y6, y7;
for (int i = 0; i < n1; ++i) {
y0 = vld1q_f16(y + 0 );
y1 = vld1q_f16(y + 8 );
y2 = vld1q_f16(y + 16);
y3 = vld1q_f16(y + 24);
y4 = vld1q_f16(y + 32);
y5 = vld1q_f16(y + 40);
y6 = vld1q_f16(y + 48);
y7 = vld1q_f16(y + 56);
y0 = vmulq_f16(x0, y0);
y1 = vmulq_f16(x1, y1);
y2 = vmulq_f16(x2, y2);
y3 = vmulq_f16(x3, y3);
y4 = vmulq_f16(x4, y4);
y5 = vmulq_f16(x5, y5);
y6 = vmulq_f16(x6, y6);
y7 = vmulq_f16(x7, y7);
y0 = vaddq_f16(y0, y1);
y2 = vaddq_f16(y2, y3);
y4 = vaddq_f16(y4, y5);
y6 = vaddq_f16(y6, y7);
y0 = vaddq_f16(y0, y2);
y4 = vaddq_f16(y4, y6);
y0 = vaddq_f16(y0, y4);
s[i*n0] += y0[0] + y0[1] + y0[2] + y0[3] + y0[4] + y0[5] + y0[6] + y0[7];
y += 64;
}
}
inline static void ggml_vec_dot_f16_blck128(const int n0, const int n1, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
float16x8_t x0 = vld1q_f16(x + 0 );
float16x8_t x1 = vld1q_f16(x + 8 );
float16x8_t x2 = vld1q_f16(x + 16 );
float16x8_t x3 = vld1q_f16(x + 24 );
float16x8_t x4 = vld1q_f16(x + 32 );
float16x8_t x5 = vld1q_f16(x + 40 );
float16x8_t x6 = vld1q_f16(x + 48 );
float16x8_t x7 = vld1q_f16(x + 56 );
float16x8_t x8 = vld1q_f16(x + 64 );
float16x8_t x9 = vld1q_f16(x + 72 );
float16x8_t x10 = vld1q_f16(x + 80 );
float16x8_t x11 = vld1q_f16(x + 88 );
float16x8_t x12 = vld1q_f16(x + 96 );
float16x8_t x13 = vld1q_f16(x + 104);
float16x8_t x14 = vld1q_f16(x + 112);
float16x8_t x15 = vld1q_f16(x + 120);
float16x8_t y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11, y12, y13, y14, y15;
for (int i = 0; i < n1; ++i) {
y0 = vld1q_f16(y + 0 );
y1 = vld1q_f16(y + 8 );
y2 = vld1q_f16(y + 16 );
y3 = vld1q_f16(y + 24 );
y4 = vld1q_f16(y + 32 );
y5 = vld1q_f16(y + 40 );
y6 = vld1q_f16(y + 48 );
y7 = vld1q_f16(y + 56 );
y8 = vld1q_f16(y + 64 );
y9 = vld1q_f16(y + 72 );
y10 = vld1q_f16(y + 80 );
y11 = vld1q_f16(y + 88 );
y12 = vld1q_f16(y + 96 );
y13 = vld1q_f16(y + 104);
y14 = vld1q_f16(y + 112);
y15 = vld1q_f16(y + 120);
y0 = vmulq_f16(x0, y0);
y1 = vmulq_f16(x1, y1);
y2 = vmulq_f16(x2, y2);
y3 = vmulq_f16(x3, y3);
y4 = vmulq_f16(x4, y4);
y5 = vmulq_f16(x5, y5);
y6 = vmulq_f16(x6, y6);
y7 = vmulq_f16(x7, y7);
y8 = vmulq_f16(x8, y8);
y9 = vmulq_f16(x9, y9);
y10 = vmulq_f16(x10, y10);
y11 = vmulq_f16(x11, y11);
y12 = vmulq_f16(x12, y12);
y13 = vmulq_f16(x13, y13);
y14 = vmulq_f16(x14, y14);
y15 = vmulq_f16(x15, y15);
y0 = vaddq_f16(y0, y1);
y2 = vaddq_f16(y2, y3);
y4 = vaddq_f16(y4, y5);
y6 = vaddq_f16(y6, y7);
y8 = vaddq_f16(y8, y9);
y10 = vaddq_f16(y10, y11);
y12 = vaddq_f16(y12, y13);
y14 = vaddq_f16(y14, y15);
y0 = vaddq_f16(y0, y2);
y4 = vaddq_f16(y4, y6);
y8 = vaddq_f16(y8, y10);
y12 = vaddq_f16(y12, y14);
y0 = vaddq_f16(y0, y4);
y8 = vaddq_f16(y8, y12);
y0 = vaddq_f16(y0, y8);
s[i*n0] += (y0[0] + y0[1] + y0[2] + y0[3]) + (y0[4] + y0[5] + y0[6] + y0[7]);
y += 128;
}
}
inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) { inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
#ifdef __ARM_NEON #ifdef __ARM_NEON
// NEON 128-bit // NEON 128-bit
@ -4083,6 +4232,9 @@ void ggml_compute_forward_mul_mat_f16_f32(
assert(ne2 == ne02); assert(ne2 == ne02);
assert(ne3 == ne03); assert(ne3 == ne03);
// blocking
const int bs = 128;
// nb01 >= nb00 - src0 is not transposed // nb01 >= nb00 - src0 is not transposed
// compute by src0 rows // compute by src0 rows
// //
@ -4094,6 +4246,7 @@ void ggml_compute_forward_mul_mat_f16_f32(
ggml_fp16_t * const wdata = params->wdata; ggml_fp16_t * const wdata = params->wdata;
int id = 0; int id = 0;
if (ne00 < bs) {
for (int i13 = 0; i13 < ne13; ++i13) { for (int i13 = 0; i13 < ne13; ++i13) {
for (int i12 = 0; i12 < ne12; ++i12) { for (int i12 = 0; i12 < ne12; ++i12) {
for (int i11 = 0; i11 < ne11; ++i11) { for (int i11 = 0; i11 < ne11; ++i11) {
@ -4103,6 +4256,22 @@ void ggml_compute_forward_mul_mat_f16_f32(
} }
} }
} }
} else {
// blocking
for (int i13 = 0; i13 < ne13; ++i13) {
for (int i12 = 0; i12 < ne12; ++i12) {
for (int i10 = 0; i10 < ne10; i10 += bs) {
for (int i11 = 0; i11 < ne11; ++i11) {
for (int i = 0; i < bs; ++i) {
wdata[id++] = ggml_fp32_to_fp16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + (i10+i)*nb10));
}
}
}
}
}
memset(dst->data, 0, ne*sizeof(float));
}
GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize); GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize);
@ -4181,11 +4350,27 @@ void ggml_compute_forward_mul_mat_f16_f32(
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
if (ne00 < bs) {
for (int ic = 0; ic < ne11; ++ic) { for (int ic = 0; ic < ne11; ++ic) {
assert(ne00 % 32 == 0); assert(ne00 % 32 == 0);
ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00); ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
} }
} else {
// blocking
//for (int k = 0; k < ne00/bs; ++k) {
// for (int ic = 0; ic < ne11; ++ic) {
// float d = 0.0f;
// ggml_vec_dot_f16(bs, &d, src0_row + k*bs, src1_col + k*ne11*bs + ic*bs);
// dst_col[ic*ne0] += d;
// }
//}
for (int k = 0; k < ne00/bs; ++k) {
//ggml_vec_dot_f16_blck32(ne0, ne11, dst_col, src0_row + k*bs, src1_col + k*ne11*bs);
//ggml_vec_dot_f16_blck64(ne0, ne11, dst_col, src0_row + k*bs, src1_col + k*ne11*bs);
ggml_vec_dot_f16_blck128(ne0, ne11, dst_col, src0_row + k*bs, src1_col + k*ne11*bs);
}
}
} }
} else { } else {
// parallelize by src1 columns using ggml_vec_mad_f32 // parallelize by src1 columns using ggml_vec_mad_f32
@ -5750,6 +5935,9 @@ void ggml_compute_forward_flash_ff_f16(
const int ir0 = dr*ith; const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr); const int ir1 = MIN(ir0 + dr, nr);
// blocking
const int bs = 128;
for (int ir = ir0; ir < ir1; ++ir) { for (int ir = ir0; ir < ir1; ++ir) {
// a indices // a indices
const int ia3 = ir/(nea2*nea1); const int ia3 = ir/(nea2*nea1);
@ -5758,6 +5946,9 @@ void ggml_compute_forward_flash_ff_f16(
float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32); float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
memset(S, 0, neb01*sizeof(float));
if (bs == 1) {
for (int ic = 0; ic < neb01; ++ic) { for (int ic = 0; ic < neb01; ++ic) {
// b0 indices // b0 indices
const int ib03 = ia3; const int ib03 = ia3;
@ -5772,6 +5963,25 @@ void ggml_compute_forward_flash_ff_f16(
(ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)),
(ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3))); (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)));
} }
} else {
for (int k = 0; k < nea0/bs; ++k) {
for (int ic = 0; ic < neb01; ++ic) {
// b0 indices
const int ib03 = ia3;
const int ib02 = ia2;
const int ib01 = ic;
// S indices
const int i1 = ib01;
float d = 0.0f;
ggml_vec_dot_f16(bs, &d,
(ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03 + k*bs*nbb00)),
(ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3 + k*bs*nba0)));
S[i1] += d;
}
}
}
ggml_vec_add_f32(neb01, S, S, (float *) b1->data); ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
//ggml_vec_gelu_f32(neb01, S, S); //ggml_vec_gelu_f32(neb01, S, S);
@ -5791,7 +6001,6 @@ void ggml_compute_forward_flash_ff_f16(
const int i3 = ia3; const int i3 = ia3;
for (int ic = 0; ic < nec01; ++ic) { for (int ic = 0; ic < nec01; ++ic) {
ggml_vec_dot_f16(neb01, ggml_vec_dot_f16(neb01,
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
(ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)),

Loading…
Cancel
Save