From ea97a5f469472f98fdea1eab4ef56b5cb8931707 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 25 Feb 2023 19:08:02 +0200
Subject: [PATCH] ggml : vectorized mad q4_0 (ARM)

---
 examples/gpt-2/quantize.cpp |   2 +-
 src/ggml.c                  | 239 +++++++++++++++++++++++++++++++++++-
 tests/test-mul-mat2.c       |  14 +--
 3 files changed, 246 insertions(+), 9 deletions(-)

diff --git a/examples/gpt-2/quantize.cpp b/examples/gpt-2/quantize.cpp
index 128cd5d..08895e7 100644
--- a/examples/gpt-2/quantize.cpp
+++ b/examples/gpt-2/quantize.cpp
@@ -12,7 +12,7 @@
 #include 
 #include 
 
-#define QK 64
+#define QK 32
 
 size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k) {
     const int nb = k / QK;
diff --git a/src/ggml.c b/src/ggml.c
index 7378af2..1a138c4 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -352,7 +352,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
-#define QK 64
+#define QK 32
 
 // method 5
 // blocks of QK elements
@@ -1094,6 +1094,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
     const int nb = n / QK;
 
+    assert(n % QK == 0);
+    assert(nb % 2 == 0);
+
     const float * restrict pd0 = (const float *) x;
     const float * restrict pd1 = (const float *) y;
 
@@ -1102,6 +1105,120 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
 
     float sumf = 0.0;
 
+#ifdef __ARM_NEON
+#if QK == 32
+    float sum0 = 0.0f;
+    float sum1 = 0.0f;
+
+    for (int i = 0; i < nb; i += 2) {
+        const float d0_0 = pd0[i + 0];
+        const float d1_0 = pd1[i + 0];
+        const float d0_1 = pd0[i + 1];
+        const float d1_1 = pd1[i + 1];
+
+        //printf("d0_0: %f, d1_0: %f, d0_1: %f, d1_1: %f\n", d0_0, d1_0, d0_1, d1_1);
+
+        const uint8_t * restrict p0 = pb0 + i*16;
+        const uint8_t * restrict p1 = pb1 + i*16;
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
+
+        const uint8x16_t v0_0 = vld1q_u8(p0);
+        const uint8x16_t v1_0 = vld1q_u8(p1);
+        const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
+        const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vandq_u8(v0_0, m4b);
+        const int8x16_t v1_0l = vandq_u8(v1_0, m4b);
+
+        const int8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
+        const int8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
+
+        const int8x16_t v0_1l = vandq_u8(v0_1, m4b);
+        const int8x16_t v1_1l = vandq_u8(v1_1, m4b);
+
+        const int8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
+        const int8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
+
+        // sub 8
+        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
+
+        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
+
+        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
+
+        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+        const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
+
+        // dot product into int16x8_t
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
+
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
+
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+
+        const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
+        const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
+
+        const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
+        const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
+
+        const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
+        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
+
+        //printf("p_0: %d %d %d %d %d %d %d %d\n", vgetq_lane_s16(p_0, 0), vgetq_lane_s16(p_0, 1), vgetq_lane_s16(p_0, 2), vgetq_lane_s16(p_0, 3), vgetq_lane_s16(p_0, 4), vgetq_lane_s16(p_0, 5), vgetq_lane_s16(p_0, 6), vgetq_lane_s16(p_0, 7));
+        //printf("p_1: %d %d %d %d %d %d %d %d\n", vgetq_lane_s16(p_1, 0), vgetq_lane_s16(p_1, 1), vgetq_lane_s16(p_1, 2), vgetq_lane_s16(p_1, 3), vgetq_lane_s16(p_1, 4), vgetq_lane_s16(p_1, 5), vgetq_lane_s16(p_1, 6), vgetq_lane_s16(p_1, 7));
+
+        // scalar
+        sum0 += d0_0*d1_0*vaddvq_s16(p_0);
+        sum1 += d0_1*d1_1*vaddvq_s16(p_1);
+    }
+
+    sumf = sum0 + sum1;
+
+    //printf("sumf SIMD = %f\n", sumf);
+
+    //// scalar
+    //sumf = 0.0f;
+    //for (int i = 0; i < nb; i++) {
+    //    const float d0 = pd0[i];
+    //    const float d1 = pd1[i];
+
+    //    const uint8_t * restrict p0 = pb0 + i*QK/2;
+    //    const uint8_t * restrict p1 = pb1 + i*QK/2;
+
+    //    for (int j = 0; j < QK/2; j++) {
+    //        const uint8_t v0 = p0[j];
+    //        const uint8_t v1 = p1[j];
+
+    //        const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
+    //        const float f1 = d0*((int8_t) (v0 >> 4) - 8);
+
+    //        const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
+    //        const float f3 = d1*((int8_t) (v1 >> 4) - 8);
+
+    //        sumf += f0*f2 + f1*f3;
+    //    }
+    //}
+    //printf("sumf scalar = %f\n", sumf);
+    //printf("--------\n");
+
+    //exit(0);
+#else
+#error "not implemented"
+#endif
+#else
     // scalar
     for (int i = 0; i < nb; i++) {
         const float d0 = pd0[i];
@@ -1123,6 +1240,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
             sumf += f0*f2 + f1*f3;
         }
     }
+#endif
 
     *s = sumf;
 }
@@ -1296,6 +1414,124 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
     const float * restrict pd = (const float *) (x);
     const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
 
+#if __ARM_NEON
+#if QK == 32
+    for (int i = 0; i < nb; ++i) {
+        const float d0 = pd[i]*v;
+
+        const uint8_t * restrict pp = pb + i*16;
+
+        const float32x4_t vd = vdupq_n_f32(d0);
+
+        const float32x4_t vy0 = vld1q_f32(y + i*32 + 0);
+        const float32x4_t vy1 = vld1q_f32(y + i*32 + 4);
+        const float32x4_t vy2 = vld1q_f32(y + i*32 + 8);
+        const float32x4_t vy3 = vld1q_f32(y + i*32 + 12);
+        const float32x4_t vy4 = vld1q_f32(y + i*32 + 16);
+        const float32x4_t vy5 = vld1q_f32(y + i*32 + 20);
+        const float32x4_t vy6 = vld1q_f32(y + i*32 + 24);
+        const float32x4_t vy7 = vld1q_f32(y + i*32 + 28);
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
+
+        const uint8x16_t vx = vld1q_u8(pp);
+
+        const uint8x16_t vxl = vandq_u8 (vx, m4b);
+        const uint8x16_t vxh = vshrq_n_u8(vx, 4);
+
+        // sub 8
+        const int8x16_t vxls = vsubq_s8(vxl, s8b);
+        const int8x16_t vxhs = vsubq_s8(vxh, s8b);
+
+        const int8x16_t vxlt = vzip1q_s8(vxls, vxhs);
+        const int8x16_t vxht = vzip2q_s8(vxls, vxhs);
+
+        //printf("vxv: %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d ",
+        //        vgetq_lane_s8(vxlt, 0),  vgetq_lane_s8(vxlt, 1),  vgetq_lane_s8(vxlt, 2),  vgetq_lane_s8(vxlt, 3),
+        //        vgetq_lane_s8(vxlt, 4),  vgetq_lane_s8(vxlt, 5),  vgetq_lane_s8(vxlt, 6),  vgetq_lane_s8(vxlt, 7),
+        //        vgetq_lane_s8(vxlt, 8),  vgetq_lane_s8(vxlt, 9),  vgetq_lane_s8(vxlt, 10), vgetq_lane_s8(vxlt, 11),
+        //        vgetq_lane_s8(vxlt, 12), vgetq_lane_s8(vxlt, 13), vgetq_lane_s8(vxlt, 14), vgetq_lane_s8(vxlt, 15));
+        //printf("%3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d\n",
+        //        vgetq_lane_s8(vxht, 0),  vgetq_lane_s8(vxht, 1),  vgetq_lane_s8(vxht, 2),  vgetq_lane_s8(vxht, 3),
+        //        vgetq_lane_s8(vxht, 4),  vgetq_lane_s8(vxht, 5),  vgetq_lane_s8(vxht, 6),  vgetq_lane_s8(vxht, 7),
+        //        vgetq_lane_s8(vxht, 8),  vgetq_lane_s8(vxht, 9),  vgetq_lane_s8(vxht, 10), vgetq_lane_s8(vxht, 11),
+        //        vgetq_lane_s8(vxht, 12), vgetq_lane_s8(vxht, 13), vgetq_lane_s8(vxht, 14), vgetq_lane_s8(vxht, 15));
+
+
+        // convert to 4x int16x8_t
+        const int16x8_t vxls0 = vmovl_s8(vget_low_s8 (vxlt));
+        const int16x8_t vxls1 = vmovl_s8(vget_high_s8(vxlt));
+        const int16x8_t vxhs0 = vmovl_s8(vget_low_s8 (vxht));
+        const int16x8_t vxhs1 = vmovl_s8(vget_high_s8(vxht));
+
+        // convert to 8x float32x4_t
+        const float32x4_t vx0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxls0)));
+        const float32x4_t vx1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxls0)));
+        const float32x4_t vx2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxls1)));
+        const float32x4_t vx3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxls1)));
+        const float32x4_t vx4 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxhs0)));
+        const float32x4_t vx5 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxhs0)));
+        const float32x4_t vx6 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxhs1)));
+        const float32x4_t vx7 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxhs1)));
+
+        //printf("vxv: %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f\n",
+        //        vgetq_lane_f32(vx0m, 0), vgetq_lane_f32(vx0m, 1), vgetq_lane_f32(vx0m, 2), vgetq_lane_f32(vx0m, 3),
+        //        vgetq_lane_f32(vx1m, 0), vgetq_lane_f32(vx1m, 1), vgetq_lane_f32(vx1m, 2), vgetq_lane_f32(vx1m, 3),
+        //        vgetq_lane_f32(vx2m, 0), vgetq_lane_f32(vx2m, 1), vgetq_lane_f32(vx2m, 2), vgetq_lane_f32(vx2m, 3),
+        //        vgetq_lane_f32(vx3m, 0), vgetq_lane_f32(vx3m, 1), vgetq_lane_f32(vx3m, 2), vgetq_lane_f32(vx3m, 3));
+
+        const float32x4_t vr0 = vfmaq_f32(vy0, vx0, vd);
+        const float32x4_t vr1 = vfmaq_f32(vy1, vx1, vd);
+        const float32x4_t vr2 = vfmaq_f32(vy2, vx2, vd);
+        const float32x4_t vr3 = vfmaq_f32(vy3, vx3, vd);
+        const float32x4_t vr4 = vfmaq_f32(vy4, vx4, vd);
+        const float32x4_t vr5 = vfmaq_f32(vy5, vx5, vd);
+        const float32x4_t vr6 = vfmaq_f32(vy6, vx6, vd);
+        const float32x4_t vr7 = vfmaq_f32(vy7, vx7, vd);
+
+        vst1q_f32(y + i*32 + 0,  vr0);
+        vst1q_f32(y + i*32 + 4,  vr1);
+        vst1q_f32(y + i*32 + 8,  vr2);
+        vst1q_f32(y + i*32 + 12, vr3);
+        vst1q_f32(y + i*32 + 16, vr4);
+        vst1q_f32(y + i*32 + 20, vr5);
+        vst1q_f32(y + i*32 + 24, vr6);
+        vst1q_f32(y + i*32 + 28, vr7);
+    }
+
+    //printf("------\n");
+
+    //for (int i = 0; i < nb; i++) {
+    //    const float d = pd[i];
+
+    //    const uint8_t * restrict pp = pb + i*QK/2;
+
+    //    printf("vxs: ");
+    //    for (int l = 0; l < QK; l += 2) {
+    //        const uint8_t vi = pp[l/2];
+
+    //        const int8_t vi0 = vi & 0xf;
+    //        const int8_t vi1 = vi >> 4;
+
+    //        const float v0 = (vi0 - 8)*d;
+    //        const float v1 = (vi1 - 8)*d;
+
+    //        y[i*QK + l + 0] += v0*v;
+    //        y[i*QK + l + 1] += v1*v;
+
+    //        if (l < QK) {
+    //            //printf("%6.2f %6.2f ", v0, v1);
+    //            printf("%3d %3d ", vi0 - 8, vi1 - 8);
+    //        }
+    //    }
+    //    printf("\n");
+    //}
+
+    //exit(0);
+#endif
+#else
+
     // scalar
     for (int i = 0; i < nb; i++) {
         const float d = pd[i];
@@ -1320,6 +1556,7 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
             //printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1);
         }
     }
+#endif
 }
 
 inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * restrict x, const float v) {
diff --git a/tests/test-mul-mat2.c b/tests/test-mul-mat2.c
index c9bee75..f5ebfd5 100644
--- a/tests/test-mul-mat2.c
+++ b/tests/test-mul-mat2.c
@@ -1956,9 +1956,9 @@ void vec_dot_gq_5(const int n, float * restrict s, const void * restrict x, cons
         //const float32x4_t pf1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(p)));
 
         // scalar
-        sum11 += d0*d1*vaddvq_u16(p);
-        //sum11 += d0*d1*(vaddvq_u16(pl) + vaddvq_u16(ph));
-        //sum11 += d0*d1*vaddvq_u16(vaddq_s16(pl, ph));
+        sum11 += d0*d1*vaddvq_s16(p);
+        //sum11 += d0*d1*(vaddvq_s16(pl) + vaddvq_s16(ph));
+        //sum11 += d0*d1*vaddvq_s16(vaddq_s16(pl, ph));
 
         //sum11 += d0*d1*(vaddvq_s8(pl0) + vaddvq_s8(pl1) + vaddvq_s8(ph0) + vaddvq_s8(ph1));
         //sum11 += d0*d1*(vaddvq_s16(pll) + vaddvq_s16(plh) + vaddvq_s16(phl) + vaddvq_s16(phh));
@@ -2246,7 +2246,7 @@ void vec_dot_gq_6(const int n, float * restrict s, const void * restrict x, cons
         const int16x8_t p = vaddq_s16(pl, ph);
 
         // scalar
-        sum0 += d0*d1*vaddvq_u16(p);
+        sum0 += d0*d1*vaddvq_s16(p);
     }
 
     sumf = sum0;
@@ -2267,8 +2267,8 @@ void vec_dot_gq_6(const int n, float * restrict s, const void * restrict x, cons
         const int8x16_t  s8b = vdupq_n_s8(0x8);
 
         const uint8x16_t v0_0 = vld1q_u8(p0);
-        const uint8x16_t v1_0 = vld1q_u8(p1);
         const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
+        const uint8x16_t v1_0 = vld1q_u8(p1);
         const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
 
         // 4-bit -> 8-bit
@@ -2320,8 +2320,8 @@ void vec_dot_gq_6(const int n, float * restrict s, const void * restrict x, cons
         const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
 
         // scalar
-        sum0 += d0_0*d1_0*vaddvq_u16(p_0);
-        sum1 += d0_1*d1_1*vaddvq_u16(p_1);
+        sum0 += d0_0*d1_0*vaddvq_s16(p_0);
+        sum1 += d0_1*d1_1*vaddvq_s16(p_1);
     }
 
     sumf = sum0 + sum1;
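
Note (not part of the patch): for reference, here is a standalone scalar sketch of the q4_0 layout that the kernels above assume with QK == 32, namely nb = n/QK per-block float deltas followed by n/2 bytes of packed nibbles, two 4-bit values per byte (low nibble = even element, high nibble = odd element, both stored with a +8 bias). The helper names quantize_q4_0_ref and vec_dot_q4_0_ref are illustrative only, and the quantizer is a simplified round-to-nearest, not bit-exact with ggml_quantize_q4_0. It can serve as a scalar oracle when checking the NEON dot/mad paths on device; note that ggml_vec_dot_q4_0 expects the deltas and the packed bytes back to back in a single buffer (pb = pd + nb), while this sketch keeps them in two arrays for clarity.

#include <math.h>
#include <stdint.h>
#include <stdio.h>

#define QK 32

// quantize n floats into nb = n/QK per-block deltas d[] and n/2 packed nibbles pb[]
// (simplified round-to-nearest, not bit-exact with ggml_quantize_q4_0)
static void quantize_q4_0_ref(const float * src, float * d, uint8_t * pb, int n) {
    const int nb = n / QK;

    for (int i = 0; i < nb; i++) {
        float amax = 0.0f; // absolute max in the block

        for (int l = 0; l < QK; l++) {
            const float v = src[i*QK + l];
            if (fabsf(v) > amax) amax = fabsf(v);
        }

        d[i] = amax / 7.0f;
        const float id = d[i] != 0.0f ? 1.0f/d[i] : 0.0f;

        for (int l = 0; l < QK; l += 2) {
            // map to [1, 15] with the +8 bias; low nibble = even element, high nibble = odd element
            const uint8_t vi0 = (uint8_t) (roundf(src[i*QK + l + 0]*id) + 8);
            const uint8_t vi1 = (uint8_t) (roundf(src[i*QK + l + 1]*id) + 8);

            pb[i*QK/2 + l/2] = vi0 | (vi1 << 4);
        }
    }
}

// scalar dot product over the same layout; the vectorized ggml_vec_dot_q4_0 above
// should produce the same value up to float rounding
static float vec_dot_q4_0_ref(int n, const float * d0, const uint8_t * pb0,
                                     const float * d1, const uint8_t * pb1) {
    const int nb = n / QK;

    float sumf = 0.0f;

    for (int i = 0; i < nb; i++) {
        for (int j = 0; j < QK/2; j++) {
            const uint8_t v0 = pb0[i*QK/2 + j];
            const uint8_t v1 = pb1[i*QK/2 + j];

            // undo the +8 bias on both nibbles and scale by the per-block deltas
            sumf += d0[i]*d1[i]*(((int8_t) (v0 & 0xf) - 8)*((int8_t) (v1 & 0xf) - 8) +
                                 ((int8_t) (v0 >>  4) - 8)*((int8_t) (v1 >>  4) - 8));
        }
    }

    return sumf;
}

int main(void) {
    enum { N = 2*QK }; // two blocks, matching the nb % 2 == 0 assumption of the NEON dot path

    float a[N], b[N];
    for (int k = 0; k < N; k++) {
        a[k] = sinf(0.1f*(float) k);
        b[k] = cosf(0.2f*(float) k);
    }

    float   da[N/QK], db[N/QK];
    uint8_t pa[N/2],  pbv[N/2];

    quantize_q4_0_ref(a, da, pa, N);
    quantize_q4_0_ref(b, db, pbv, N);

    printf("q4_0 reference dot = %f\n", vec_dot_q4_0_ref(N, da, pa, db, pbv));

    return 0;
}

Building this with something like "cc -O2 q4_0_ref.c -lm" (the file name is hypothetical) and comparing the printed value against the SIMD result, for example via the commented-out "sumf SIMD" / "sumf scalar" prints in the patch, is one way to sanity-check the vectorized code.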