Compare commits

...

1 Commits

Author SHA1 Message Date
Georgi Gerganov 3afb833f84
wip : unsuccessful attempts speeding mul_mat using blocking
2 years ago

@ -1307,7 +1307,7 @@ bool whisper_encode(
ggml_build_forward_expand(&gf, inpO);
ggml_graph_compute (ctxL, &gf);
//ggml_graph_print(&gf);
ggml_graph_print(&gf);
}
// TODO: this is a hack to have per-layer computation graphs - need to come up with something better

@ -195,6 +195,7 @@ int64_t ggml_cycles_per_ms(void) {
return CLOCKS_PER_SEC/1000;
}
#define GGML_PERF
#ifdef GGML_PERF
#define ggml_perf_time_ms() ggml_time_ms()
#define ggml_perf_time_us() ggml_time_us()
@ -495,6 +496,154 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
*s = sumf;
}
// blocking
inline static void ggml_vec_dot_f16_blck32(const int n0, const int n1, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
float16x8_t x0 = vld1q_f16(x + 0 );
float16x8_t x1 = vld1q_f16(x + 8 );
float16x8_t x2 = vld1q_f16(x + 16);
float16x8_t x3 = vld1q_f16(x + 24);
float16x8_t y0, y1, y2, y3;
for (int i = 0; i < n1; ++i) {
y0 = vld1q_f16(y + 32*i + 0 );
y1 = vld1q_f16(y + 32*i + 8 );
y2 = vld1q_f16(y + 32*i + 16);
y3 = vld1q_f16(y + 32*i + 24);
y0 = vmulq_f16(x0, y0);
y1 = vmulq_f16(x1, y1);
y2 = vmulq_f16(x2, y2);
y3 = vmulq_f16(x3, y3);
y0 = vaddq_f16(y0, y1);
y2 = vaddq_f16(y2, y3);
y0 = vaddq_f16(y0, y2);
s[i*n0] += y0[0] + y0[1] + y0[2] + y0[3] + y0[4] + y0[5] + y0[6] + y0[7];
}
}
inline static void ggml_vec_dot_f16_blck64(const int n0, const int n1, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
float16x8_t x0 = vld1q_f16(x + 0 );
float16x8_t x1 = vld1q_f16(x + 8 );
float16x8_t x2 = vld1q_f16(x + 16);
float16x8_t x3 = vld1q_f16(x + 24);
float16x8_t x4 = vld1q_f16(x + 32);
float16x8_t x5 = vld1q_f16(x + 40);
float16x8_t x6 = vld1q_f16(x + 48);
float16x8_t x7 = vld1q_f16(x + 56);
float16x8_t y0, y1, y2, y3, y4, y5, y6, y7;
for (int i = 0; i < n1; ++i) {
y0 = vld1q_f16(y + 0 );
y1 = vld1q_f16(y + 8 );
y2 = vld1q_f16(y + 16);
y3 = vld1q_f16(y + 24);
y4 = vld1q_f16(y + 32);
y5 = vld1q_f16(y + 40);
y6 = vld1q_f16(y + 48);
y7 = vld1q_f16(y + 56);
y0 = vmulq_f16(x0, y0);
y1 = vmulq_f16(x1, y1);
y2 = vmulq_f16(x2, y2);
y3 = vmulq_f16(x3, y3);
y4 = vmulq_f16(x4, y4);
y5 = vmulq_f16(x5, y5);
y6 = vmulq_f16(x6, y6);
y7 = vmulq_f16(x7, y7);
y0 = vaddq_f16(y0, y1);
y2 = vaddq_f16(y2, y3);
y4 = vaddq_f16(y4, y5);
y6 = vaddq_f16(y6, y7);
y0 = vaddq_f16(y0, y2);
y4 = vaddq_f16(y4, y6);
y0 = vaddq_f16(y0, y4);
s[i*n0] += y0[0] + y0[1] + y0[2] + y0[3] + y0[4] + y0[5] + y0[6] + y0[7];
y += 64;
}
}
inline static void ggml_vec_dot_f16_blck128(const int n0, const int n1, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
float16x8_t x0 = vld1q_f16(x + 0 );
float16x8_t x1 = vld1q_f16(x + 8 );
float16x8_t x2 = vld1q_f16(x + 16 );
float16x8_t x3 = vld1q_f16(x + 24 );
float16x8_t x4 = vld1q_f16(x + 32 );
float16x8_t x5 = vld1q_f16(x + 40 );
float16x8_t x6 = vld1q_f16(x + 48 );
float16x8_t x7 = vld1q_f16(x + 56 );
float16x8_t x8 = vld1q_f16(x + 64 );
float16x8_t x9 = vld1q_f16(x + 72 );
float16x8_t x10 = vld1q_f16(x + 80 );
float16x8_t x11 = vld1q_f16(x + 88 );
float16x8_t x12 = vld1q_f16(x + 96 );
float16x8_t x13 = vld1q_f16(x + 104);
float16x8_t x14 = vld1q_f16(x + 112);
float16x8_t x15 = vld1q_f16(x + 120);
float16x8_t y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11, y12, y13, y14, y15;
for (int i = 0; i < n1; ++i) {
y0 = vld1q_f16(y + 0 );
y1 = vld1q_f16(y + 8 );
y2 = vld1q_f16(y + 16 );
y3 = vld1q_f16(y + 24 );
y4 = vld1q_f16(y + 32 );
y5 = vld1q_f16(y + 40 );
y6 = vld1q_f16(y + 48 );
y7 = vld1q_f16(y + 56 );
y8 = vld1q_f16(y + 64 );
y9 = vld1q_f16(y + 72 );
y10 = vld1q_f16(y + 80 );
y11 = vld1q_f16(y + 88 );
y12 = vld1q_f16(y + 96 );
y13 = vld1q_f16(y + 104);
y14 = vld1q_f16(y + 112);
y15 = vld1q_f16(y + 120);
y0 = vmulq_f16(x0, y0);
y1 = vmulq_f16(x1, y1);
y2 = vmulq_f16(x2, y2);
y3 = vmulq_f16(x3, y3);
y4 = vmulq_f16(x4, y4);
y5 = vmulq_f16(x5, y5);
y6 = vmulq_f16(x6, y6);
y7 = vmulq_f16(x7, y7);
y8 = vmulq_f16(x8, y8);
y9 = vmulq_f16(x9, y9);
y10 = vmulq_f16(x10, y10);
y11 = vmulq_f16(x11, y11);
y12 = vmulq_f16(x12, y12);
y13 = vmulq_f16(x13, y13);
y14 = vmulq_f16(x14, y14);
y15 = vmulq_f16(x15, y15);
y0 = vaddq_f16(y0, y1);
y2 = vaddq_f16(y2, y3);
y4 = vaddq_f16(y4, y5);
y6 = vaddq_f16(y6, y7);
y8 = vaddq_f16(y8, y9);
y10 = vaddq_f16(y10, y11);
y12 = vaddq_f16(y12, y13);
y14 = vaddq_f16(y14, y15);
y0 = vaddq_f16(y0, y2);
y4 = vaddq_f16(y4, y6);
y8 = vaddq_f16(y8, y10);
y12 = vaddq_f16(y12, y14);
y0 = vaddq_f16(y0, y4);
y8 = vaddq_f16(y8, y12);
y0 = vaddq_f16(y0, y8);
s[i*n0] += (y0[0] + y0[1] + y0[2] + y0[3]) + (y0[4] + y0[5] + y0[6] + y0[7]);
y += 128;
}
}
inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
#ifdef __ARM_NEON
// NEON 128-bit
@ -4083,6 +4232,9 @@ void ggml_compute_forward_mul_mat_f16_f32(
assert(ne2 == ne02);
assert(ne3 == ne03);
// blocking
const int bs = 128;
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
//
@ -4094,14 +4246,31 @@ void ggml_compute_forward_mul_mat_f16_f32(
ggml_fp16_t * const wdata = params->wdata;
int id = 0;
for (int i13 = 0; i13 < ne13; ++i13) {
for (int i12 = 0; i12 < ne12; ++i12) {
for (int i11 = 0; i11 < ne11; ++i11) {
for (int i10 = 0; i10 < ne10; ++i10) {
wdata[id++] = ggml_fp32_to_fp16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
if (ne00 < bs) {
for (int i13 = 0; i13 < ne13; ++i13) {
for (int i12 = 0; i12 < ne12; ++i12) {
for (int i11 = 0; i11 < ne11; ++i11) {
for (int i10 = 0; i10 < ne10; ++i10) {
wdata[id++] = ggml_fp32_to_fp16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
}
}
}
}
} else {
// blocking
for (int i13 = 0; i13 < ne13; ++i13) {
for (int i12 = 0; i12 < ne12; ++i12) {
for (int i10 = 0; i10 < ne10; i10 += bs) {
for (int i11 = 0; i11 < ne11; ++i11) {
for (int i = 0; i < bs; ++i) {
wdata[id++] = ggml_fp32_to_fp16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + (i10+i)*nb10));
}
}
}
}
}
memset(dst->data, 0, ne*sizeof(float));
}
GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize);
@ -4181,10 +4350,26 @@ void ggml_compute_forward_mul_mat_f16_f32(
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
for (int ic = 0; ic < ne11; ++ic) {
assert(ne00 % 32 == 0);
if (ne00 < bs) {
for (int ic = 0; ic < ne11; ++ic) {
assert(ne00 % 32 == 0);
ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
}
} else {
// blocking
//for (int k = 0; k < ne00/bs; ++k) {
// for (int ic = 0; ic < ne11; ++ic) {
// float d = 0.0f;
// ggml_vec_dot_f16(bs, &d, src0_row + k*bs, src1_col + k*ne11*bs + ic*bs);
// dst_col[ic*ne0] += d;
// }
//}
for (int k = 0; k < ne00/bs; ++k) {
//ggml_vec_dot_f16_blck32(ne0, ne11, dst_col, src0_row + k*bs, src1_col + k*ne11*bs);
//ggml_vec_dot_f16_blck64(ne0, ne11, dst_col, src0_row + k*bs, src1_col + k*ne11*bs);
ggml_vec_dot_f16_blck128(ne0, ne11, dst_col, src0_row + k*bs, src1_col + k*ne11*bs);
}
}
}
} else {
@ -5750,6 +5935,9 @@ void ggml_compute_forward_flash_ff_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
// blocking
const int bs = 128;
for (int ir = ir0; ir < ir1; ++ir) {
// a indices
const int ia3 = ir/(nea2*nea1);
@ -5758,19 +5946,41 @@ void ggml_compute_forward_flash_ff_f16(
float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
for (int ic = 0; ic < neb01; ++ic) {
// b0 indices
const int ib03 = ia3;
const int ib02 = ia2;
const int ib01 = ic;
memset(S, 0, neb01*sizeof(float));
// S indices
const int i1 = ib01;
if (bs == 1) {
for (int ic = 0; ic < neb01; ++ic) {
// b0 indices
const int ib03 = ia3;
const int ib02 = ia2;
const int ib01 = ic;
ggml_vec_dot_f16(nea0,
S + i1,
(ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)),
(ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)));
// S indices
const int i1 = ib01;
ggml_vec_dot_f16(nea0,
S + i1,
(ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)),
(ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)));
}
} else {
for (int k = 0; k < nea0/bs; ++k) {
for (int ic = 0; ic < neb01; ++ic) {
// b0 indices
const int ib03 = ia3;
const int ib02 = ia2;
const int ib01 = ic;
// S indices
const int i1 = ib01;
float d = 0.0f;
ggml_vec_dot_f16(bs, &d,
(ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03 + k*bs*nbb00)),
(ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3 + k*bs*nba0)));
S[i1] += d;
}
}
}
ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
@ -5791,7 +6001,6 @@ void ggml_compute_forward_flash_ff_f16(
const int i3 = ia3;
for (int ic = 0; ic < nec01; ++ic) {
ggml_vec_dot_f16(neb01,
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
(ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)),

Loading…
Cancel
Save