diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp
index a5f79d2..dc4f3ea 100644
--- a/examples/whisper/whisper.cpp
+++ b/examples/whisper/whisper.cpp
@@ -1307,7 +1307,7 @@ bool whisper_encode(
         ggml_build_forward_expand(&gf, inpO);
         ggml_graph_compute       (ctxL, &gf);
 
-        //ggml_graph_print(&gf);
+        ggml_graph_print(&gf);
     }
 
     // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
diff --git a/src/ggml.c b/src/ggml.c
index 6608300..be9c630 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -195,6 +195,7 @@ int64_t ggml_cycles_per_ms(void) {
     return CLOCKS_PER_SEC/1000;
 }
 
+#define GGML_PERF
 #ifdef GGML_PERF
 #define ggml_perf_time_ms() ggml_time_ms()
 #define ggml_perf_time_us() ggml_time_us()
@@ -495,6 +496,154 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_
     *s = sumf;
 }
 
+// blocking
+inline static void ggml_vec_dot_f16_blck32(const int n0, const int n1, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
+    float16x8_t x0 = vld1q_f16(x + 0 );
+    float16x8_t x1 = vld1q_f16(x + 8 );
+    float16x8_t x2 = vld1q_f16(x + 16);
+    float16x8_t x3 = vld1q_f16(x + 24);
+
+    float16x8_t y0, y1, y2, y3;
+
+    for (int i = 0; i < n1; ++i) {
+        y0 = vld1q_f16(y + 32*i + 0 );
+        y1 = vld1q_f16(y + 32*i + 8 );
+        y2 = vld1q_f16(y + 32*i + 16);
+        y3 = vld1q_f16(y + 32*i + 24);
+
+        y0 = vmulq_f16(x0, y0);
+        y1 = vmulq_f16(x1, y1);
+        y2 = vmulq_f16(x2, y2);
+        y3 = vmulq_f16(x3, y3);
+
+        y0 = vaddq_f16(y0, y1);
+        y2 = vaddq_f16(y2, y3);
+        y0 = vaddq_f16(y0, y2);
+
+        s[i*n0] += y0[0] + y0[1] + y0[2] + y0[3] + y0[4] + y0[5] + y0[6] + y0[7];
+    }
+}
+
+inline static void ggml_vec_dot_f16_blck64(const int n0, const int n1, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
+    float16x8_t x0 = vld1q_f16(x + 0 );
+    float16x8_t x1 = vld1q_f16(x + 8 );
+    float16x8_t x2 = vld1q_f16(x + 16);
+    float16x8_t x3 = vld1q_f16(x + 24);
+    float16x8_t x4 = vld1q_f16(x + 32);
+    float16x8_t x5 = vld1q_f16(x + 40);
+    float16x8_t x6 = vld1q_f16(x + 48);
+    float16x8_t x7 = vld1q_f16(x + 56);
+
+    float16x8_t y0, y1, y2, y3, y4, y5, y6, y7;
+
+    for (int i = 0; i < n1; ++i) {
+        y0 = vld1q_f16(y + 0 );
+        y1 = vld1q_f16(y + 8 );
+        y2 = vld1q_f16(y + 16);
+        y3 = vld1q_f16(y + 24);
+        y4 = vld1q_f16(y + 32);
+        y5 = vld1q_f16(y + 40);
+        y6 = vld1q_f16(y + 48);
+        y7 = vld1q_f16(y + 56);
+
+        y0 = vmulq_f16(x0, y0);
+        y1 = vmulq_f16(x1, y1);
+        y2 = vmulq_f16(x2, y2);
+        y3 = vmulq_f16(x3, y3);
+        y4 = vmulq_f16(x4, y4);
+        y5 = vmulq_f16(x5, y5);
+        y6 = vmulq_f16(x6, y6);
+        y7 = vmulq_f16(x7, y7);
+
+        y0 = vaddq_f16(y0, y1);
+        y2 = vaddq_f16(y2, y3);
+        y4 = vaddq_f16(y4, y5);
+        y6 = vaddq_f16(y6, y7);
+        y0 = vaddq_f16(y0, y2);
+        y4 = vaddq_f16(y4, y6);
+        y0 = vaddq_f16(y0, y4);
+
+        s[i*n0] += y0[0] + y0[1] + y0[2] + y0[3] + y0[4] + y0[5] + y0[6] + y0[7];
+        y += 64;
+    }
+}
+
+inline static void ggml_vec_dot_f16_blck128(const int n0, const int n1, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
+    float16x8_t x0  = vld1q_f16(x + 0  );
+    float16x8_t x1  = vld1q_f16(x + 8  );
+    float16x8_t x2  = vld1q_f16(x + 16 );
+    float16x8_t x3  = vld1q_f16(x + 24 );
+    float16x8_t x4  = vld1q_f16(x + 32 );
+    float16x8_t x5  = vld1q_f16(x + 40 );
+    float16x8_t x6  = vld1q_f16(x + 48 );
+    float16x8_t x7  = vld1q_f16(x + 56 );
+    float16x8_t x8  = vld1q_f16(x + 64 );
+    float16x8_t x9  = vld1q_f16(x + 72 );
+    float16x8_t x10 = vld1q_f16(x + 80 );
+    float16x8_t x11 = vld1q_f16(x + 88 );
+    float16x8_t x12 = vld1q_f16(x + 96 );
+    float16x8_t x13 = vld1q_f16(x + 104);
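Note on the three ggml_vec_dot_f16_blck* kernels above: each keeps one 32-, 64- or 128-wide block of the src0 row resident in 4, 8 or 16 NEON registers and streams n1 packed blocks of src1 past it, accumulating one partial dot product per output column into s with stride n0. Two caveats: the kernels accumulate with +=, so the destination must be zeroed before the first block (the memset added further down), and both the multiplies and the reduction tree run in fp16 (vmulq_f16/vaddq_f16 require the ARMv8.2-A fp16 arithmetic extension), so results can drift slightly from the fp32-accumulating ggml_vec_dot_f16. A scalar reference for what each kernel computes, with plain float standing in for ggml_fp16_t; a sketch for validation, not part of the patch:

    // Scalar reference for ggml_vec_dot_f16_blck{32,64,128}: bs is the block
    // width (32, 64 or 128), n0 the output stride, n1 the number of packed
    // bs-wide blocks in y. The NEON kernels additionally round every
    // intermediate to fp16.
    static void vec_dot_f16_blck_ref(int bs, int n0, int n1,
                                     float * s, const float * x, const float * y) {
        for (int i = 0; i < n1; ++i) {   // one output column per block of y
            float sum = 0.0f;
            for (int j = 0; j < bs; ++j) {
                sum += x[j]*y[i*bs + j];
            }
            s[i*n0] += sum;              // accumulate: caller zeroes s first
        }
    }
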
+    float16x8_t x14 = vld1q_f16(x + 112);
+    float16x8_t x15 = vld1q_f16(x + 120);
+
+    float16x8_t y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11, y12, y13, y14, y15;
+
+    for (int i = 0; i < n1; ++i) {
+        y0  = vld1q_f16(y + 0  );
+        y1  = vld1q_f16(y + 8  );
+        y2  = vld1q_f16(y + 16 );
+        y3  = vld1q_f16(y + 24 );
+        y4  = vld1q_f16(y + 32 );
+        y5  = vld1q_f16(y + 40 );
+        y6  = vld1q_f16(y + 48 );
+        y7  = vld1q_f16(y + 56 );
+        y8  = vld1q_f16(y + 64 );
+        y9  = vld1q_f16(y + 72 );
+        y10 = vld1q_f16(y + 80 );
+        y11 = vld1q_f16(y + 88 );
+        y12 = vld1q_f16(y + 96 );
+        y13 = vld1q_f16(y + 104);
+        y14 = vld1q_f16(y + 112);
+        y15 = vld1q_f16(y + 120);
+
+        y0  = vmulq_f16(x0,  y0);
+        y1  = vmulq_f16(x1,  y1);
+        y2  = vmulq_f16(x2,  y2);
+        y3  = vmulq_f16(x3,  y3);
+        y4  = vmulq_f16(x4,  y4);
+        y5  = vmulq_f16(x5,  y5);
+        y6  = vmulq_f16(x6,  y6);
+        y7  = vmulq_f16(x7,  y7);
+        y8  = vmulq_f16(x8,  y8);
+        y9  = vmulq_f16(x9,  y9);
+        y10 = vmulq_f16(x10, y10);
+        y11 = vmulq_f16(x11, y11);
+        y12 = vmulq_f16(x12, y12);
+        y13 = vmulq_f16(x13, y13);
+        y14 = vmulq_f16(x14, y14);
+        y15 = vmulq_f16(x15, y15);
+
+        y0  = vaddq_f16(y0,  y1);
+        y2  = vaddq_f16(y2,  y3);
+        y4  = vaddq_f16(y4,  y5);
+        y6  = vaddq_f16(y6,  y7);
+        y8  = vaddq_f16(y8,  y9);
+        y10 = vaddq_f16(y10, y11);
+        y12 = vaddq_f16(y12, y13);
+        y14 = vaddq_f16(y14, y15);
+        y0  = vaddq_f16(y0,  y2);
+        y4  = vaddq_f16(y4,  y6);
+        y8  = vaddq_f16(y8,  y10);
+        y12 = vaddq_f16(y12, y14);
+        y0  = vaddq_f16(y0,  y4);
+        y8  = vaddq_f16(y8,  y12);
+        y0  = vaddq_f16(y0,  y8);
+
+        s[i*n0] += (y0[0] + y0[1] + y0[2] + y0[3]) + (y0[4] + y0[5] + y0[6] + y0[7]);
+        y += 128;
+    }
+}
+
 inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
 #ifdef __ARM_NEON
     // NEON 128-bit
@@ -4083,6 +4232,9 @@ void ggml_compute_forward_mul_mat_f16_f32(
     assert(ne2 == ne02);
     assert(ne3 == ne03);
 
+    // blocking
+    const int bs = 128;
+
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
     //
@@ -4094,14 +4246,31 @@ void ggml_compute_forward_mul_mat_f16_f32(
             ggml_fp16_t * const wdata = params->wdata;
 
             int id = 0;
-            for (int i13 = 0; i13 < ne13; ++i13) {
-                for (int i12 = 0; i12 < ne12; ++i12) {
-                    for (int i11 = 0; i11 < ne11; ++i11) {
-                        for (int i10 = 0; i10 < ne10; ++i10) {
-                            wdata[id++] = ggml_fp32_to_fp16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
+            if (ne00 < bs) {
+                for (int i13 = 0; i13 < ne13; ++i13) {
+                    for (int i12 = 0; i12 < ne12; ++i12) {
+                        for (int i11 = 0; i11 < ne11; ++i11) {
+                            for (int i10 = 0; i10 < ne10; ++i10) {
+                                wdata[id++] = ggml_fp32_to_fp16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
+                            }
                         }
                     }
                 }
+            } else {
+                // blocking
+                for (int i13 = 0; i13 < ne13; ++i13) {
+                    for (int i12 = 0; i12 < ne12; ++i12) {
+                        for (int i10 = 0; i10 < ne10; i10 += bs) {
+                            for (int i11 = 0; i11 < ne11; ++i11) {
+                                for (int i = 0; i < bs; ++i) {
+                                    wdata[id++] = ggml_fp32_to_fp16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + (i10+i)*nb10));
+                                }
+                            }
+                        }
+                    }
+                }
+
+                memset(dst->data, 0, ne*sizeof(float));
             }
 
             GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize);
@@ -4181,10 +4350,26 @@ void ggml_compute_forward_mul_mat_f16_f32(
 
             float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
 
-            for (int ic = 0; ic < ne11; ++ic) {
-                assert(ne00 % 32 == 0);
+            if (ne00 < bs) {
+                for (int ic = 0; ic < ne11; ++ic) {
+                    assert(ne00 % 32 == 0);
 
-                ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
+                    ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
+                }
+            } else {
+                // blocking
+                //for (int k = 0; k < ne00/bs; ++k) {
+                //    for (int ic = 0; ic < ne11; ++ic) {
+                //        float d = 0.0f;
+                //        ggml_vec_dot_f16(bs, &d, src0_row + k*bs, src1_col + k*ne11*bs + ic*bs);
+                //        dst_col[ic*ne0] += d;
+                //    }
+                //}
+                for (int k = 0; k < ne00/bs; ++k) {
+                    //ggml_vec_dot_f16_blck32(ne0, ne11, dst_col, src0_row + k*bs, src1_col + k*ne11*bs);
+                    //ggml_vec_dot_f16_blck64(ne0, ne11, dst_col, src0_row + k*bs, src1_col + k*ne11*bs);
+                    ggml_vec_dot_f16_blck128(ne0, ne11, dst_col, src0_row + k*bs, src1_col + k*ne11*bs);
+                }
             }
         }
     } else {
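The two else branches above work as a pair: the packing loop lays src1 out block-major in wdata (within each (i13, i12) slice, block k of every one of the ne11 rows is stored contiguously), and the compute loop then addresses block k of all columns at once as src1_col + k*ne11*bs, one blck128 call per block. The memset(dst->data, ...) is required because the blck kernels only ever add partial sums into dst. Both blocked paths also assume ne10 (equal to ne00 here) is a multiple of bs = 128: rows shorter than bs take the old path, but an in-between size would let the inner i < bs loop read past the end of a row, so an assert(ne00 % bs == 0) in the else branch would make the assumption explicit. The index arithmetic of the layout, as a self-contained sketch (the helper is hypothetical, for illustration only):

    // Offset of src1 element (i11, i10) within one (i13, i12) slice of the
    // block-major wdata layout built above. Hypothetical helper, not part of
    // the patch; assumes ne10 is a multiple of bs.
    static inline int wdata_offset_blocked(int bs, int ne11, int i11, int i10) {
        const int k = i10 / bs;        // which bs-wide block along the row
        const int j = i10 % bs;        // position inside that block
        return (k*ne11 + i11)*bs + j;  // block k of all rows is contiguous
    }
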
@@ -5750,6 +5935,9 @@ void ggml_compute_forward_flash_ff_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
+    // blocking
+    const int bs = 128;
+
     for (int ir = ir0; ir < ir1; ++ir) {
         // a indices
         const int ia3 = ir/(nea2*nea1);
@@ -5758,19 +5946,41 @@
         float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
 
-        for (int ic = 0; ic < neb01; ++ic) {
-            // b0 indices
-            const int ib03 = ia3;
-            const int ib02 = ia2;
-            const int ib01 = ic;
+        memset(S, 0, neb01*sizeof(float));
 
-            // S indices
-            const int i1 = ib01;
+        if (bs == 1) {
+            for (int ic = 0; ic < neb01; ++ic) {
+                // b0 indices
+                const int ib03 = ia3;
+                const int ib02 = ia2;
+                const int ib01 = ic;
 
-            ggml_vec_dot_f16(nea0,
-                    S + i1,
-                    (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)),
-                    (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)));
+                // S indices
+                const int i1 = ib01;
+
+                ggml_vec_dot_f16(nea0,
+                        S + i1,
+                        (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)),
+                        (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)));
+            }
+        } else {
+            for (int k = 0; k < nea0/bs; ++k) {
+                for (int ic = 0; ic < neb01; ++ic) {
+                    // b0 indices
+                    const int ib03 = ia3;
+                    const int ib02 = ia2;
+                    const int ib01 = ic;
+
+                    // S indices
+                    const int i1 = ib01;
+
+                    float d = 0.0f;
+                    ggml_vec_dot_f16(bs, &d,
+                            (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03 + k*bs*nbb00)),
+                            (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3 + k*bs*nba0)));
+                    S[i1] += d;
+                }
+            }
         }
 
         ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
 
@@ -5791,7 +6001,6 @@
         const int i3 = ia3;
 
         for (int ic = 0; ic < nec01; ++ic) {
-
             ggml_vec_dot_f16(neb01,
                     (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
                     (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)),
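The flash_ff change applies the same idea in a generic form: instead of a dedicated NEON kernel, the nea0-long dot product is split into nea0/bs partial dots computed with the existing ggml_vec_dot_f16 and accumulated into S, which is why S is now cleared with memset up front. With const int bs = 128 the bs == 1 branch is dead code and serves as a compile-time toggle back to the unblocked path. The split-and-accumulate pattern in isolation, with plain floats; a sketch, not part of the patch:

    // Blocked dot product as in the flash_ff else branch above: split an
    // n-element dot into n/bs partial dots and sum them. Matches the
    // unblocked result up to floating-point summation order; assumes
    // n is a multiple of bs.
    static float dot_blocked(int n, int bs, const float * a, const float * b) {
        float s = 0.0f;
        for (int k = 0; k < n/bs; ++k) {
            float d = 0.0f;
            for (int j = 0; j < bs; ++j) {
                d += a[k*bs + j]*b[k*bs + j];
            }
            s += d;   // one partial sum per block
        }
        return s;
    }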