From 1ca898f94b6b9f847977cc227a830c106d544582 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 24 Feb 2023 21:15:13 +0200
Subject: [PATCH] gq : method 5 (ARM)

---
Review notes (ignored by git-am, not part of the commit message):

 - vec_dot_gq_5 (NEON): the per-block accumulator `p` is an int16x8_t of
   signed products (both operands are re-centered by subtracting 8), so it
   is now reduced with vaddvq_s16. The previous vaddvq_u16 reinterpreted a
   negative block sum as a large positive uint16_t.
 - vec_dot_gq_5 (NEON): the masked/shifted nibbles are uint8x16_t; they are
   now passed through vreinterpretq_s8_u8 before vsubq_s8 so the code does
   not depend on lax vector conversions.
 - Both changes replace added lines one-for-one, so hunk headers and the
   diffstat below are unchanged.
 - NOTE(review): the new assert in quantize_5_row requires round(v) to stay
   within [-8, 7]; the scaling that would guarantee this is outside the
   visible hunk - confirm.
 - NOTE(review): line breaks and leading whitespace were restored from a
   flattened copy of this patch; verify context-line indentation against the
   tree before applying.

 CMakeLists.txt        |   2 +-
 tests/test-mul-mat2.c | 126 +++++++++++++++++++++++++++++++-----------
 2 files changed, 96 insertions(+), 32 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index c7a3e00..54d18b0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -46,7 +46,7 @@ if (GGML_SANITIZE_UNDEFINED)
 endif()
 
 #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffast-math")
-set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native")
+#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native")
 #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mcpu=native")
 
 # dependencies
diff --git a/tests/test-mul-mat2.c b/tests/test-mul-mat2.c
index a39819f..66bc13b 100644
--- a/tests/test-mul-mat2.c
+++ b/tests/test-mul-mat2.c
@@ -1204,7 +1204,7 @@ void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, cons
 
     float sumf = 0.0;
 
-#if 1
+#if 0
     // scalar
     for (int i = 0; i < nb; i++) {
         const float m0 = GGML_GQ_TO_FP32(pm0[i]);
@@ -1569,12 +1569,21 @@ void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, cons
         const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
         const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
 
+        const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
+        const uint16x8_t pl1 = vaddq_u16(pl1l, pl1h);
+        const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
+        const uint16x8_t ph1 = vaddq_u16(ph1l, ph1h);
+
+        const uint16x8_t pl = vaddq_u16(pl0, pl1);
+        const uint16x8_t ph = vaddq_u16(ph0, ph1);
+
         sum00 += m0*m1;
-        sum01 += m1*d0*(((int)vaddvq_u8(v0_0l)) + ((int)vaddvq_u8(v0_0h)) + ((int)vaddvq_u8(v0_1l)) + ((int)vaddvq_u8(v0_1h)));
-        sum10 += m0*d1*(((int)vaddvq_u8(v1_0l)) + ((int)vaddvq_u8(v1_0h)) + ((int)vaddvq_u8(v1_1l)) + ((int)vaddvq_u8(v1_1h)));
-        sum11 += d0*d1*(
-                vaddvq_u16(vaddq_u16(vaddq_u16(pl0l, pl0h), vaddq_u16(pl1l, pl1h))) +
-                vaddvq_u16(vaddq_u16(vaddq_u16(ph0l, ph0h), vaddq_u16(ph1l, ph1h))));
+        sum01 += m1*d0*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h) + vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
+        sum10 += m0*d1*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h) + vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
+        //sum11 += d0*d1*(
+        //        vaddvq_u16(vaddq_u16(vaddq_u16(pl0l, pl0h), vaddq_u16(pl1l, pl1h))) +
+        //        vaddvq_u16(vaddq_u16(vaddq_u16(ph0l, ph0h), vaddq_u16(ph1l, ph1h))));
+        sum11 += d0*d1*vaddvq_u16(vaddq_u16(pl, ph));
     }
 
     sumf = 64.0*sum00 + sum01 + sum10 + sum11;
@@ -1748,7 +1757,8 @@ void quantize_5_row(const float * restrict src, void * restrict dst, int k) {
 
         for (int l = 0; l < QK; l++) {
             const float v = src[i*QK + l]*id;
-            const int8_t vi = ((int8_t) (v + frand())) + 8;
+            const int8_t vi = ((int8_t) (round(v))) + 8;
+            assert(vi >= 0 && vi < 16);
             pp[l/2] |= (vi & 0xf) << (4*(l & 1));
         }
 
@@ -1859,29 +1869,33 @@ void vec_dot_gq_5(const int n, float * restrict s, const void * restrict x, cons
     sumf = _mm256_hadd_ps_gg(sum11);
 #endif
 #elif defined (__ARM_NEON)
-    float sum00 = 0.0f;
-    float sum01 = 0.0f;
-    float sum10 = 0.0f;
     float sum11 = 0.0f;
 
+    //float32x4_t sum_0 = vdupq_n_f32(0.0f);
+    //float32x4_t sum_1 = vdupq_n_f32(0.0f);
+
+    //float16x8_t sum_0 = vdupq_n_f16(0.0f);
+    //float16x8_t sum_1 = vdupq_n_f16(0.0f);
+
     for (int i = 0; i < nb; i++) {
-        const float m0 = GGML_GQ_TO_FP32(pm0[i]);
         const float d0 = GGML_GQ_TO_FP32(pd0[i]);
-
-        const float m1 = GGML_GQ_TO_FP32(pm1[i]);
         const float d1 = GGML_GQ_TO_FP32(pd1[i]);
 
+        //float32x4_t d0d1v = vdupq_n_f32(d0*d1);
+        //float16x8_t d0d1v = vdupq_n_f16(d0*d1);
+
         const uint8_t * restrict p0 = pb0 + i*QK/2;
         const uint8_t * restrict p1 = pb1 + i*QK/2;
 
         const uint8x16_t m4b = vdupq_n_u8(0xf);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
 
         const uint8x16_t v0_0 = vld1q_u8(p0);
         const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
         const uint8x16_t v1_0 = vld1q_u8(p1);
         const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
 
-        // and with 0xf
+        // 4-bit -> 8-bit
         const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
         const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
         const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
@@ -1892,26 +1906,76 @@ void vec_dot_gq_5(const int n, float * restrict s, const void * restrict x, cons
         const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
         const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
 
-        // dot product into uint16x8_t
-        const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
-        const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
-        const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
-        const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
-
-        const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
-        const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
-        const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
-        const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
+        // sub 8: view the unsigned nibbles as signed bytes, then re-center [0, 15] -> [-8, 7]
+        const int8x16_t v0_0ls = vsubq_s8(vreinterpretq_s8_u8(v0_0l), s8b);
+        const int8x16_t v0_1ls = vsubq_s8(vreinterpretq_s8_u8(v0_1l), s8b);
+        const int8x16_t v1_0ls = vsubq_s8(vreinterpretq_s8_u8(v1_0l), s8b);
+        const int8x16_t v1_1ls = vsubq_s8(vreinterpretq_s8_u8(v1_1l), s8b);
+
+        const int8x16_t v0_0hs = vsubq_s8(vreinterpretq_s8_u8(v0_0h), s8b);
+        const int8x16_t v0_1hs = vsubq_s8(vreinterpretq_s8_u8(v0_1h), s8b);
+        const int8x16_t v1_0hs = vsubq_s8(vreinterpretq_s8_u8(v1_0h), s8b);
+        const int8x16_t v1_1hs = vsubq_s8(vreinterpretq_s8_u8(v1_1h), s8b);
+
+        // dot product into int16x8_t
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
+
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+
+        const int16x8_t pl0 = vaddq_s16(pl0l, pl0h);
+        const int16x8_t pl1 = vaddq_s16(pl1l, pl1h);
+        const int16x8_t ph0 = vaddq_s16(ph0l, ph0h);
+        const int16x8_t ph1 = vaddq_s16(ph1l, ph1h);
+
+        const int16x8_t pl = vaddq_s16(pl0, pl1);
+        const int16x8_t ph = vaddq_s16(ph0, ph1);
+
+        //const int8x16_t pl0 = vmulq_s8(v0_0ls, v1_0ls);
+        //const int8x16_t pl1 = vmulq_s8(v0_1ls, v1_1ls);
+        //const int8x16_t ph0 = vmulq_s8(v0_0hs, v1_0hs);
+        //const int8x16_t ph1 = vmulq_s8(v0_1hs, v1_1hs);
+
+        //const int16x8_t pll = vaddl_s8(vget_low_s8(pl0), vget_low_s8(pl1));
+        //const int16x8_t plh = vaddl_s8(vget_high_s8(pl0), vget_high_s8(pl1));
+        //const int16x8_t phl = vaddl_s8(vget_low_s8(ph0), vget_low_s8(ph1));
+        //const int16x8_t phh = vaddl_s8(vget_high_s8(ph0), vget_high_s8(ph1));
+
+        //const int16x8_t pl = vaddq_s16(pll, plh);
+        //const int16x8_t ph = vaddq_s16(phl, phh);
+
+        const int16x8_t p = vaddq_s16(pl, ph);
 
-        sum00 += m0*m1;
-        sum01 += m1*d0*(((int)vaddvq_u8(v0_0l)) + ((int)vaddvq_u8(v0_0h)) + ((int)vaddvq_u8(v0_1l)) + ((int)vaddvq_u8(v0_1h)));
-        sum10 += m0*d1*(((int)vaddvq_u8(v1_0l)) + ((int)vaddvq_u8(v1_0h)) + ((int)vaddvq_u8(v1_1l)) + ((int)vaddvq_u8(v1_1h)));
-        sum11 += d0*d1*(
-                vaddvq_u16(vaddq_u16(vaddq_u16(pl0l, pl0h), vaddq_u16(pl1l, pl1h))) +
-                vaddvq_u16(vaddq_u16(vaddq_u16(ph0l, ph0h), vaddq_u16(ph1l, ph1h))));
+        // convert to float
+        //const float32x4_t pf0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (p)));
+        //const float32x4_t pf1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(p)));
+
+        // scalar: p holds signed products, so use the signed horizontal add
+        sum11 += d0*d1*vaddvq_s16(p);
+        //sum11 += d0*d1*(vaddvq_u16(pl) + vaddvq_u16(ph));
+        //sum11 += d0*d1*vaddvq_u16(vaddq_s16(pl, ph));
+        //sum11 += d0*d1*(vaddvq_s8(pl0) + vaddvq_s8(pl1) + vaddvq_s8(ph0) + vaddvq_s8(ph1));
+        //sum11 += d0*d1*(vaddvq_s16(pll) + vaddvq_s16(plh) + vaddvq_s16(phl) + vaddvq_s16(phh));
+
+        //sum_0 = vfmaq_f16(sum_0, d0d1v, vcvtq_f16_s16(p));
+        //sum_0 = vfmaq_f16(sum_0, d0d1v, vcvtq_f16_s16(pl));
+        //sum_1 = vfmaq_f16(sum_1, d0d1v, vcvtq_f16_s16(ph));
+
+        // vectorize
+        //sum_0 = vmlaq_f32(sum_0, d0d1v, pf0);
+        //sum_1 = vmlaq_f32(sum_1, d0d1v, pf1);
     }
 
-    sumf = 64.0*sum00 + sum01 + sum10 + sum11;
+    sumf = sum11;
+    //sumf = vaddvq_f32(sum_0) + vaddvq_f32(sum_1);
+    //sumf = sum_0[0] + sum_0[1] + sum_0[2] + sum_0[3] + sum_0[4] + sum_0[5] + sum_0[6] + sum_0[7];
+    //sum_0 = vaddq_f16(sum_0, sum_1);
+    //sumf = sum_0[0] + sum_0[1] + sum_0[2] + sum_0[3] + sum_0[4] + sum_0[5] + sum_0[6] + sum_0[7];
 #endif
 #endif
 