From 05e7d26ba41ac727b9c2df395c2853a68e6f4aa6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 27 Feb 2023 18:28:27 +0200 Subject: [PATCH] ggml : add WASM SIMD for Q4_0 --- src/ggml.c | 142 ++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 118 insertions(+), 24 deletions(-) diff --git a/src/ggml.c b/src/ggml.c index fdbbde2..2c60942 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -407,6 +407,45 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { #else #error "not implemented for QK" #endif +#elif defined(__wasm_simd128__) +#if QK == 32 + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + v128_t srcv [8]; + v128_t asrcv[8]; + v128_t amaxv[8]; + + for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l); + for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]); + + for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]); + for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]); + for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]); + + amax = MAX( + MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)), + MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3))); + + const float d = amax / ((1 << 3) - 1); + const float id = d ? 1.0/d : 0.0; + + pd[i] = d; + + for (int l = 0; l < 8; l++) { + const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id)); + const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f)); + const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf); + + pp[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4); + pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4); + } + + memcpy(pb + i*16, pp, sizeof(pp)); + } +#else +#error "not implemented for QK" +#endif #else // scalar for (int i = 0; i < nb; i++) { @@ -1216,9 +1255,6 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void const int16x8_t p_0 = vaddq_s16(pl_0, ph_0); const int16x8_t p_1 = vaddq_s16(pl_1, ph_1); - //printf("p_0: %d %d %d %d %d %d %d %d\n", vgetq_lane_s16(p_0, 0), vgetq_lane_s16(p_0, 1), vgetq_lane_s16(p_0, 2), vgetq_lane_s16(p_0, 3), vgetq_lane_s16(p_0, 4), vgetq_lane_s16(p_0, 5), vgetq_lane_s16(p_0, 6), vgetq_lane_s16(p_0, 7)); - //printf("p_1: %d %d %d %d %d %d %d %d\n", vgetq_lane_s16(p_1, 0), vgetq_lane_s16(p_1, 1), vgetq_lane_s16(p_1, 2), vgetq_lane_s16(p_1, 3), vgetq_lane_s16(p_1, 4), vgetq_lane_s16(p_1, 5), vgetq_lane_s16(p_1, 6), vgetq_lane_s16(p_1, 7)); - // scalar #if defined(__ARM_FEATURE_QRDMX) sum0 += d0_0*d1_0*vaddvq_s16(p_0); @@ -1230,35 +1266,93 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void } sumf = sum0 + sum1; +#else +#error "not implemented for QK" +#endif +#elif defined(__wasm_simd128__) +#if QK == 32 + // wasm simd + float sum0 = 0.0f; + float sum1 = 0.0f; - //printf("sumf SIMD = %f\n", sumf); + for (int i = 0; i < nb; i += 2) { + const float d0_0 = pd0[i + 0]; + const float d0_1 = pd0[i + 1]; + const float d1_0 = pd1[i + 0]; + const float d1_1 = pd1[i + 1]; - //// scalar - //sumf = 0.0f; - //for (int i = 0; i < nb; i++) { - // const float d0 = pd0[i]; - // const float d1 = pd1[i]; + const uint8_t * restrict p0 = pb0 + i*16; + const uint8_t * restrict p1 = pb1 + i*16; - // const uint8_t * restrict p0 = pb0 + i*QK/2; - // const uint8_t * restrict p1 = pb1 + i*QK/2; + const v128_t m4b = wasm_u8x16_splat(0xf); + const v128_t s8b = wasm_i8x16_splat(0x8); - // for (int j = 0; j < QK/2; j++) { - // const uint8_t v0 = p0[j]; - // const uint8_t v1 = p1[j]; + const v128_t v0_0 = wasm_v128_load(p0); + const v128_t v0_1 = wasm_v128_load(p0 + 16); + const v128_t v1_0 = wasm_v128_load(p1); + const v128_t v1_1 = wasm_v128_load(p1 + 16); - // const float f0 = d0*((int8_t) (v0 & 0xf) - 8); - // const float f1 = d0*((int8_t) (v0 >> 4) - 8); + // 4-bit -> 8-bit + const v128_t v0_0l = wasm_v128_and(v0_0, m4b); + const v128_t v1_0l = wasm_v128_and(v1_0, m4b); - // const float f2 = d1*((int8_t) (v1 & 0xf) - 8); - // const float f3 = d1*((int8_t) (v1 >> 4) - 8); + const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4); + const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4); - // sumf += f0*f2 + f1*f3; - // } - //} - //printf("sumf scalar = %f\n", sumf); - //printf("--------\n"); + const v128_t v0_1l = wasm_v128_and(v0_1, m4b); + const v128_t v1_1l = wasm_v128_and(v1_1, m4b); + + const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4); + const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4); + + // sub 8 + const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b); + const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b); - //exit(0); + const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b); + const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b); + + const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b); + const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b); + + const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b); + const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b); + + // dot product into int16x8_t + const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls)); + const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls)); + + const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs)); + const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs)); + + const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls)); + const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls)); + + const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs)); + const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs)); + + const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h); + const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h); + + const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h); + const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h); + + const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0); + const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1); + + sum0 += d0_0*d1_0*( + wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) + + wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) + + wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) + + wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7)); + sum1 += d0_1*d1_1*( + wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) + + wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) + + wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) + + wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7)); + } + + sumf = sum0 + sum1; #else #error "not implemented for QK" #endif