diff --git a/CMakeLists.txt b/CMakeLists.txt
index d88c5b1..c7a3e00 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -31,7 +31,7 @@ option(GGML_NO_ACCELERATE "ggml: disable Accelerate framework" OFF)
 # sanitizers
 
 if (GGML_SANITIZE_THREAD)
-    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=thread")
+    set(CMAKE_C_FLAGS   "${CMAKE_C_FLAGS} -fsanitize=thread")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=thread")
 endif()
 
@@ -46,7 +46,7 @@ if (GGML_SANITIZE_UNDEFINED)
 endif()
 
 #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffast-math")
-#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native")
+set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native")
 #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mcpu=native")
 
 # dependencies
diff --git a/tests/test-mul-mat2.c b/tests/test-mul-mat2.c
index a3443c0..a39819f 100644
--- a/tests/test-mul-mat2.c
+++ b/tests/test-mul-mat2.c
@@ -35,7 +35,7 @@ const int K = 1280;
 #define QK 64
 #define QB 4
 
-/*#define GGML_GQ_USE_FP16_SCALE*/
+//#define GGML_GQ_USE_FP16_SCALE
 
 #if defined(GGML_GQ_USE_FP16_SCALE)
 #define gq_scale_t ggml_fp16_t
@@ -60,6 +60,47 @@ uint64_t get_time_us() {
     return tv.tv_sec * 1000000 + tv.tv_usec;
 }
 
+#if defined(__AVX2__)
+// horizontally reduce 8 32-bit integers
+static inline int32_t _mm256_hadd_epi32_gg(__m256i v) {
+    __m128i v0 = _mm256_extractf128_si256(v, 0);
+    __m128i v1 = _mm256_extractf128_si256(v, 1);
+
+    v0 = _mm_add_epi32(v0, v1);
+
+    v1 = _mm_shuffle_epi32(v0, 0x0e);
+    v0 = _mm_add_epi32(v0, v1);
+
+    v1 = _mm_shuffle_epi32(v0, 0x01);
+    v0 = _mm_add_epi32(v0, v1);
+
+    return _mm_cvtsi128_si32(v0);
+}
+
+//static inline float _mm256_hadd_epi32_gg(__m256i v) {
+//    const __m256 v0 = _mm256_cvtepi32_ps(v);
+//    const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(v0), _mm256_extractf128_ps(v0, 1));
+//    const __m128 t1 = _mm_hadd_ps(t0, t0);
+//
+//    return _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
+//}
+
+// horizontally reduce 32 8-bit integers
+static inline int32_t _mm256_hadd_epi8_gg(__m256i v0) {
+    __m256i v1 = _mm256_maddubs_epi16(v0, _mm256_set1_epi8(1));
+    __m256i v2 = _mm256_madd_epi16  (v1, _mm256_set1_epi16(1));
+
+    return _mm256_hadd_epi32_gg(v2);
+}
+
+static inline float _mm256_hadd_ps_gg(__m256 v) {
+    const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
+    const __m128 t1 = _mm_hadd_ps(t0, t0);
+
+    return _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
+}
+#endif
+
 //
 // naive implementation
 //
@@ -240,6 +281,7 @@ void mul_mat_gq_1(
 
 //
 // method 2
+// n-bit quantization (2nd attempt)
 //
 
 static inline int quantize_2_blocks_per_row(int k) {
@@ -1021,27 +1063,28 @@ void quantize_4_row(const float * restrict src, void * restrict dst, int k) {
 
 #if defined(__AVX2__)
     {
-        assert(QK < 64);
+        assert(QK == 64);
+
+        const int QK8 = QK/8;
 
-        __m256 srcv[QK/8];
-        __m256 minv[QK/8]; // ???????
-        __m256 maxv[QK/8]; // ???????
+        __m256 srcv[QK8];
+        __m256 minv[QK8];
+        __m256 maxv[QK8];
 
-        for (int l = 0; l < QK/8; l++) {
+        for (int l = 0; l < QK8; l++) {
             srcv[l] = _mm256_loadu_ps(src + i*QK + 8*l);
         }
 
-        for (int l = 0; l < QK/8/2; l++) {
+        for (int l = 0; l < QK8/2; l++) {
             minv[2*l] = _mm256_min_ps(srcv[2*l], srcv[2*l+1]);
             maxv[2*l] = _mm256_max_ps(srcv[2*l], srcv[2*l+1]);
         }
 
-        for (int l = 0; l < QK/8/4; l++) {
+        for (int l = 0; l < QK8/4; l++) {
             minv[4*l] = _mm256_min_ps(minv[4*l], minv[4*l+2]);
             maxv[4*l] = _mm256_max_ps(maxv[4*l], maxv[4*l+2]);
         }
 
-        for (int l = 0; l < QK/8/8; l++) {
+        for (int l = 0; l < QK8/8; l++) {
             minv[8*l] = _mm256_min_ps(minv[8*l], minv[8*l+4]);
             maxv[8*l] = _mm256_max_ps(maxv[8*l], maxv[8*l+4]);
         }
@@ -1147,40 +1190,6 @@ void quantize_4(const float * restrict src, char * restrict dst, int n, int k) {
     }
 }
 
-#if defined(__AVX2__)
-// horizontally reduce 8 32-bit integers
-static inline int32_t _mm256_hadd_epi32_gg(__m256i v) {
-    __m128i v0 = _mm256_extractf128_si256(v, 0);
-    __m128i v1 = _mm256_extractf128_si256(v, 1);
-
-    v0 = _mm_add_epi32(v0, v1);
-
-    v1 = _mm_shuffle_epi32(v0, 0x0e);
-    v0 = _mm_add_epi32(v0, v1);
-
-    v1 = _mm_shuffle_epi32(v0, 0x01);
-    v0 = _mm_add_epi32(v0, v1);
-
-    return _mm_cvtsi128_si32(v0);
-}
-
-//static inline float _mm256_hadd_epi32_gg(__m256i v) {
-//    const __m256 v0 = _mm256_cvtepi32_ps(v);
-//    const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(v0), _mm256_extractf128_ps(v0, 1));
-//    const __m128 t1 = _mm_hadd_ps(t0, t0);
-//
-//    return _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
-//}
-
-// horizontally reduce 32 8-bit integers
-static inline int32_t _mm256_hadd_epi8_gg(__m256i v0) {
-    __m256i v1 = _mm256_maddubs_epi16(v0, _mm256_set1_epi8(1));
-    __m256i v2 = _mm256_madd_epi16  (v1, _mm256_set1_epi16(1));
-
-    return _mm256_hadd_epi32_gg(v2);
-}
-#endif
-
 void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
     const int nb = quantize_4_blocks_per_row(n);
 
@@ -1195,31 +1204,32 @@ void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
     float sumf = 0.0;
 
+#if 1
     // scalar
-    //for (int i = 0; i < nb; i++) {
-    //    const float m0 = GGML_GQ_TO_FP32(pm0[i]);
-    //    const float d0 = GGML_GQ_TO_FP32(pd0[i]);
+    for (int i = 0; i < nb; i++) {
+        const float m0 = GGML_GQ_TO_FP32(pm0[i]);
+        const float d0 = GGML_GQ_TO_FP32(pd0[i]);
 
-    //    const uint8_t * restrict p0 = pb0 + i*QK/2;
-    //    const uint8_t * restrict p1 = pb1 + i*QK/2;
+        const float m1 = GGML_GQ_TO_FP32(pm1[i]);
+        const float d1 = GGML_GQ_TO_FP32(pd1[i]);
 
-    //    for (int j = 0; j < QK/2; j++) {
-    //        const uint8_t v0 = p0[j];
-    //        const uint8_t v1 = p1[j];
+        const uint8_t * restrict p0 = pb0 + i*QK/2;
+        const uint8_t * restrict p1 = pb1 + i*QK/2;
 
-    //        const float f0 = d0*(v0 & 0xf) + m0;
-    //        const float f1 = d0*(v0 >> 4) + m0;
+        for (int j = 0; j < QK/2; j++) {
+            const uint8_t v0 = p0[j];
+            const uint8_t v1 = p1[j];
 
-    //        const float f2 = d1*(v1 & 0xf) + m1;
-    //        const float f3 = d1*(v1 >> 4) + m1;
+            const float f0 = d0*(v0 & 0xf) + m0;
+            const float f1 = d0*(v0 >> 4) + m0;
 
-    //        sumf += f0*f2 + f1*f3;
-    //    }
-    //}
+            const float f2 = d1*(v1 & 0xf) + m1;
+            const float f3 = d1*(v1 >> 4) + m1;
 
+            sumf += f0*f2 + f1*f3;
+        }
+    }
+#else
 #if defined(__AVX2__)
 #if QK == 64 && 0
     __m256 sumv0 = _mm256_setzero_ps();
@@ -1394,12 +1404,14 @@ void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
     const __m128 t1 = _mm_hadd_ps(t0, t0);
 
     sumf += _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
-#elif QK == 64
+#elif QK == 64 && 0
     float sum00 = 0.0f;
     float sum01 = 0.0f;
     float sum10 = 0.0f;
    float sum11 = 0.0f;
 
+    const __m256i m4b = _mm256_set1_epi8(0xf);
+
     for (int i = 0; i < nb; i++) {
         const float m0 = GGML_GQ_TO_FP32(pm0[i]);
         const float d0 = GGML_GQ_TO_FP32(pd0[i]);
@@ -1410,13 +1422,13 @@ void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
         const uint8_t * restrict p0 = pb0 + i*QK/2;
         const uint8_t * restrict p1 = pb1 + i*QK/2;
 
-        const __m256i m4b = _mm256_set1_epi8(0xf);
-
-        const __m256i v0 = _mm256_loadu_si256((__m256i *) p0); // 64 x 4
+        // 64 x 4
+        const __m256i v0 = _mm256_loadu_si256((__m256i *) p0);
         const __m256i v1 = _mm256_loadu_si256((__m256i *) p1);
 
-        const __m256i v0l = _mm256_and_si256(v0, m4b); // 32 x 8
-        const __m256i v1l = _mm256_and_si256(v1, m4b); //
+        // 32 x 8
+        const __m256i v0l = _mm256_and_si256(v0, m4b);
+        const __m256i v1l = _mm256_and_si256(v1, m4b);
 
         const __m256i v0h = _mm256_and_si256(_mm256_srli_epi16(v0, 4), m4b);
         const __m256i v1h = _mm256_and_si256(_mm256_srli_epi16(v1, 4), m4b);
@@ -1434,6 +1446,83 @@ void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
     }
 
     sumf = 64.0*sum00 + sum01 + sum10 + sum11;
+#elif QK == 64 && 1 // this is the best when using min + d
+    float sum00 = 0.0f;
+
+    __m256 sum01 = _mm256_setzero_ps();
+    __m256 sum10 = _mm256_setzero_ps();
+    __m256 sum11 = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; i++) {
+        const float m0 = GGML_GQ_TO_FP32(pm0[i]);
+        const float d0 = GGML_GQ_TO_FP32(pd0[i]);
+
+        const float m1 = GGML_GQ_TO_FP32(pm1[i]);
+        const float d1 = GGML_GQ_TO_FP32(pd1[i]);
+
+        const uint8_t * restrict p0 = pb0 + i*QK/2;
+        const uint8_t * restrict p1 = pb1 + i*QK/2;
+
+        const __m256 m0v = _mm256_set1_ps(m0);
+        const __m256 d0v = _mm256_set1_ps(d0);
+
+        const __m256 m1v = _mm256_set1_ps(m1);
+        const __m256 d1v = _mm256_set1_ps(d1);
+
+        const __m256 m1d0v = _mm256_mul_ps(m1v, d0v);
+        const __m256 m0d1v = _mm256_mul_ps(m0v, d1v);
+        const __m256 d0d1v = _mm256_mul_ps(d0v, d1v);
+
+        const __m256i m4b = _mm256_set1_epi8(0xf);
+
+        // 64 x 4
+        const __m256i v0 = _mm256_loadu_si256((__m256i *) p0);
+        const __m256i v1 = _mm256_loadu_si256((__m256i *) p1);
+
+        // 32 x 8
+        const __m256i v0l = _mm256_and_si256(v0, m4b);
+        const __m256i v1l = _mm256_and_si256(v1, m4b);
+
+        const __m256i v0h = _mm256_and_si256(_mm256_srli_epi16(v0, 4), m4b);
+        const __m256i v1h = _mm256_and_si256(_mm256_srli_epi16(v1, 4), m4b);
+
+        const __m256i v0a = _mm256_add_epi8(v0l, v0h);
+        const __m256i v1a = _mm256_add_epi8(v1l, v1h);
+
+        const __m128i v0al = _mm256_extracti128_si256(v0a, 0);
+        const __m128i v0ah = _mm256_extracti128_si256(v0a, 1);
+
+        const __m128i v1al = _mm256_extracti128_si256(v1a, 0);
+        const __m128i v1ah = _mm256_extracti128_si256(v1a, 1);
+
+        const __m128i v0as = _mm_add_epi8(v0al, v0ah);
+        const __m128i v1as = _mm_add_epi8(v1al, v1ah);
+
+        const __m256i v0as_0 = _mm256_cvtepu8_epi32(v0as);
+        const __m256i v0as_1 = _mm256_cvtepu8_epi32(_mm_srli_si128(v0as, 8));
+
+        const __m256i v1as_0 = _mm256_cvtepu8_epi32(v1as);
+        const __m256i v1as_1 = _mm256_cvtepu8_epi32(_mm_srli_si128(v1as, 8));
+
+        const __m256i v0ass = _mm256_add_epi32(v0as_0, v0as_1);
+        const __m256i v1ass = _mm256_add_epi32(v1as_0, v1as_1);
+
+        const __m256 v0f = _mm256_cvtepi32_ps(v0ass);
+        const __m256 v1f = _mm256_cvtepi32_ps(v1ass);
+
+        const __m256i pl = _mm256_maddubs_epi16(v0l, v1l);
+        const __m256i ph = _mm256_maddubs_epi16(v0h, v1h);
+
+        const __m256i p16 = _mm256_add_epi16(ph, pl);
+        const __m256i p   = _mm256_madd_epi16(_mm256_set1_epi16(1), p16);
+
+        sum00 += m0*m1;
+        sum01 = _mm256_fmadd_ps(m1d0v, v0f, sum01);
+        sum10 = _mm256_fmadd_ps(m0d1v, v1f, sum10);
+        sum11 = _mm256_fmadd_ps(d0d1v, _mm256_cvtepi32_ps(p), sum11);
+    }
+
+    sumf = 64.0*sum00 + _mm256_hadd_ps_gg(sum01) + _mm256_hadd_ps_gg(sum10) + _mm256_hadd_ps_gg(sum11);
 #endif
 #elif defined (__ARM_NEON)
     float sum00 = 0.0f;
@@ -1489,6 +1578,7 @@ void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
     }
 
     sumf = 64.0*sum00 + sum01 + sum10 + sum11;
+#endif
 #endif
 
     *s = sumf;
@@ -1502,13 +1592,6 @@ void mul_mat_gq_4(
     int m, int n, int k) {
     assert(k % QK == 0);
 
-    // needed to initialize f16 tables
-    {
-        struct ggml_init_params params = { 0, NULL };
-        struct ggml_context * ctx = ggml_init(params);
-        ggml_free(ctx);
-    }
-
     const int nb = quantize_4_blocks_per_row(k);
 
     for (int ir0 = 0; ir0 < m; ir0++) {
@@ -1523,9 +1606,350 @@ void mul_mat_gq_4(
         }
     }
 }
 
+//
+// method 5
+// 4-bit quantization (without min, only delta)
+//
+
+static inline int quantize_5_blocks_per_row(int k) {
+    return k/QK;
+}
+
+static inline int quantize_5_row_size(int k) {
+    const int nb = quantize_5_blocks_per_row(k);
+
+    return nb*(sizeof(gq_scale_t) + QK/2);
+}
+
+void quantize_5_row(const float * restrict src, void * restrict dst, int k) {
+    assert(k % QK == 0);
+    assert(QB == 4);
+
+    const int nb = quantize_5_blocks_per_row(k);
+
+    gq_scale_t * restrict pd = (gq_scale_t *) (dst);
+    uint8_t * restrict pb = (uint8_t *) (pd + nb);
+
+    uint8_t pp[QK/2];
+
+    for (int i = 0; i < nb; i++) {
+        memset(pp, 0, sizeof(pp));
+
+        float amax = 0.0f; // absolute max
+
+#if defined(__AVX2__)
+        {
+            assert(QK == 64);
+            const int QK8 = QK/8;
+
+            __m256 srcv [QK8];
+            __m256 asrcv[QK8];
+            __m256 amaxv[QK8];
+
+            for (int l = 0; l < QK8; l++) {
+                srcv[l] = _mm256_loadu_ps(src + i*QK + 8*l);
+            }
+
+            for (int l = 0; l < QK8; l++) {
+                asrcv[l] = _mm256_and_ps(srcv[l], (__m256) _mm256_set1_epi32(0x7fffffff));
+            }
+
+
+            for (int l = 0; l < QK8/2; l++) {
+                amaxv[2*l] = _mm256_max_ps(asrcv[2*l], asrcv[2*l+1]);
+            }
+
+            for (int l = 0; l < QK8/4; l++) {
+                amaxv[4*l] = _mm256_max_ps(amaxv[4*l], amaxv[4*l+2]);
+            }
+
+            for (int l = 0; l < QK8/8; l++) {
+                amaxv[8*l] = _mm256_max_ps(amaxv[8*l], amaxv[8*l+4]);
+            }
+
+            //amax = MAX(amaxv[0][0], MAX(amaxv[0][1], MAX(amaxv[0][2], MAX(amaxv[0][3], MAX(amaxv[0][4], MAX(amaxv[0][5], MAX(amaxv[0][6], amaxv[0][7])))))));
+
+            const __m256 amaxv0_0 = _mm256_permute2f128_ps(amaxv[0], amaxv[0], 3);
+            const __m256 amaxv0_1 = _mm256_max_ps(amaxv[0], amaxv0_0);
+            const __m256 amaxv0_2 = _mm256_permute_ps(amaxv0_1, 0x4e);
+            const __m256 amaxv0_3 = _mm256_max_ps(amaxv0_1, amaxv0_2);
+            const __m256 amaxv0_4 = _mm256_permute_ps(amaxv0_3, 0xb1);
+            const __m256 amaxv0_5 = _mm256_max_ps(amaxv0_3, amaxv0_4);
+
+            amax = _mm256_cvtss_f32(amaxv0_5);
+
+            //printf("amax = %f\n", amax);
+
+            const float d = amax / ((1 << (QB - 1)) - 1);
+            const float id = d ? 1.0/d : 0.0;
+
+            pd[i] = GGML_FP32_TO_GQ(d);
+
+            const __m256 idv = _mm256_set1_ps(id);
+
+            for (int l = 0; l < QK/8; l++) {
+                __m256 v = _mm256_mul_ps(srcv[l], idv);
+#if 0
+                v[0] += frand(); v[1] += frand(); v[2] += frand(); v[3] += frand();
+                v[4] += frand(); v[5] += frand(); v[6] += frand(); v[7] += frand();
+#endif
+
+                // convert to int8
+                __m256i vi = _mm256_cvtps_epi32(v);
+                vi = _mm256_add_epi32(vi, _mm256_set1_epi32(8));
+
+                int32_t vi_0 = _mm256_extract_epi32(vi, 0);
+                int32_t vi_1 = _mm256_extract_epi32(vi, 1);
+                int32_t vi_2 = _mm256_extract_epi32(vi, 2);
+                int32_t vi_3 = _mm256_extract_epi32(vi, 3);
+
+                int32_t vi_4 = _mm256_extract_epi32(vi, 4);
+                int32_t vi_5 = _mm256_extract_epi32(vi, 5);
+                int32_t vi_6 = _mm256_extract_epi32(vi, 6);
+                int32_t vi_7 = _mm256_extract_epi32(vi, 7);
+
+                // convert to 4-bit, 2 consecutive packed into 1 byte
+                pp[4*l + 0] = vi_0 | (vi_1 << 4);
+                pp[4*l + 1] = vi_2 | (vi_3 << 4);
+                pp[4*l + 2] = vi_4 | (vi_5 << 4);
+                pp[4*l + 3] = vi_6 | (vi_7 << 4);
+
+                //printf("vi: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7);
+                ////printf("v : %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]);
+
+                assert(vi_0 >= 0 && vi_0 < 16);
+                assert(vi_1 >= 0 && vi_1 < 16);
+                assert(vi_2 >= 0 && vi_2 < 16);
+                assert(vi_3 >= 0 && vi_3 < 16);
+
+                assert(vi_4 >= 0 && vi_4 < 16);
+                assert(vi_5 >= 0 && vi_5 < 16);
+                assert(vi_6 >= 0 && vi_6 < 16);
+                assert(vi_7 >= 0 && vi_7 < 16);
+            }
+
+            memcpy(pb + i*QK/2, pp, sizeof(pp));
+        }
+#elif defined(__ARM_NEON) && 0
+        {
+            // TODO
+        }
+#else
+        {
+            for (int l = 0; l < QK; l++) {
+                const float v = src[i*QK + l];
+                amax = MAX(amax, fabsf(v));
+            }
+
+            const float d = amax / ((1 << (QB - 1)) - 1);
+            const float id = d ? 1.0/d : 0.0;
+
+            pd[i] = GGML_FP32_TO_GQ(d);
+
+            for (int l = 0; l < QK; l++) {
+                const float v = src[i*QK + l]*id;
+                const int8_t vi = ((int8_t) (v + frand())) + 8;
+                pp[l/2] |= (vi & 0xf) << (4*(l & 1));
+            }
+
+            memcpy(pb + i*QK/2, pp, sizeof(pp));
+        }
+#endif
+        //printf("min %f max %f\n", min, max);
+    }
+}
+
+// reimplementation of quantize_5 using quantize_5_row
+void quantize_5(const float * restrict src, char * restrict dst, int n, int k) {
+    assert(k % QK == 0);
+
+    for (int j = 0; j < n; j++) {
+        quantize_5_row(src + j*k, dst, k);
+        dst = (char *) dst + quantize_5_row_size(k);
+    }
+}
+
+void vec_dot_gq_5(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+    const int nb = quantize_5_blocks_per_row(n);
+
+    const gq_scale_t * restrict pd0 = (const gq_scale_t *) x;
+    const gq_scale_t * restrict pd1 = (const gq_scale_t *) y;
+
+    const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb);
+    const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb);
+
+    float sumf = 0.0;
+
+#if 0
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float d0 = GGML_GQ_TO_FP32(pd0[i]);
+        const float d1 = GGML_GQ_TO_FP32(pd1[i]);
+
+        const uint8_t * restrict p0 = pb0 + i*QK/2;
+        const uint8_t * restrict p1 = pb1 + i*QK/2;
+
+        for (int j = 0; j < QK/2; j++) {
+            const uint8_t v0 = p0[j];
+            const uint8_t v1 = p1[j];
+
+            const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
+            const float f1 = d0*((int8_t) (v0 >> 4) - 8);
+
+            const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
+            const float f3 = d1*((int8_t) (v1 >> 4) - 8);
+
+            sumf += f0*f2 + f1*f3;
+        }
+    }
+#else
+#if defined(__AVX2__)
+#if QK == 64 && 1
+    __m256 sum11 = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; i++) {
+        const float d0 = GGML_GQ_TO_FP32(pd0[i]);
+        const float d1 = GGML_GQ_TO_FP32(pd1[i]);
+
+        const uint8_t * restrict p0 = pb0 + i*QK/2;
+        const uint8_t * restrict p1 = pb1 + i*QK/2;
+
+        const __m256 d0v = _mm256_set1_ps(d0);
+        const __m256 d1v = _mm256_set1_ps(d1);
+
+        const __m256 d0d1v = _mm256_mul_ps(d0v, d1v);
+
+        const __m256i m4b = _mm256_set1_epi8(0xf);
+
+        // 64 x 4
+        const __m256i v0 = _mm256_loadu_si256((__m256i *) p0);
+        const __m256i v1 = _mm256_loadu_si256((__m256i *) p1);
+
+        // 32 x 8
+        __m256i v0l = _mm256_and_si256(v0, m4b);
+        __m256i v1l = _mm256_and_si256(v1, m4b);
+
+        __m256i v0h = _mm256_and_si256(_mm256_srli_epi16(v0, 4), m4b);
+        __m256i v1h = _mm256_and_si256(_mm256_srli_epi16(v1, 4), m4b);
+
+        // sub 8
+        v0l = _mm256_sub_epi8(v0l, _mm256_set1_epi8(8));
+        v0h = _mm256_sub_epi8(v0h, _mm256_set1_epi8(8));
+
+        v1l = _mm256_sub_epi8(v1l, _mm256_set1_epi8(8));
+        v1h = _mm256_sub_epi8(v1h, _mm256_set1_epi8(8));
+
+        // abs
+        const __m256i v0la = _mm256_sign_epi8(v0l, v0l);
+        const __m256i v0ha = _mm256_sign_epi8(v0h, v0h);
+
+        // sign
+        const __m256i v1ls = _mm256_sign_epi8(v1l, v0l);
+        const __m256i v1hs = _mm256_sign_epi8(v1h, v0h);
+
+        const __m256i pl = _mm256_maddubs_epi16(v0la, v1ls);
+        const __m256i ph = _mm256_maddubs_epi16(v0ha, v1hs);
+
+        const __m256i p16 = _mm256_add_epi16(ph, pl);
+        const __m256i p   = _mm256_madd_epi16(_mm256_set1_epi16(1), p16);
+
+        sum11 = _mm256_fmadd_ps(d0d1v, _mm256_cvtepi32_ps(p), sum11);
+    }
+
+    sumf = _mm256_hadd_ps_gg(sum11);
+#endif
+#elif defined (__ARM_NEON)
+    // note: method 5 has no min term, so only the d0*d1 sum is needed
+    float sum11 = 0.0f;
+
+    for (int i = 0; i < nb; i++) {
+        const float d0 = GGML_GQ_TO_FP32(pd0[i]);
+        const float d1 = GGML_GQ_TO_FP32(pd1[i]);
+
+        const uint8_t * restrict p0 = pb0 + i*QK/2;
+        const uint8_t * restrict p1 = pb1 + i*QK/2;
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
+
+        const uint8x16_t v0_0 = vld1q_u8(p0);
+        const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
+        const uint8x16_t v1_0 = vld1q_u8(p1);
+        const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
+        const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
+        const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
+
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+        const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
+        const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
+
+        // sub 8
+        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
+        const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
+
+        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+        const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
+        const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
+
+        // dot product into int16x8_t
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
+
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+
+        sum11 += d0*d1*(
+            vaddvq_s16(vaddq_s16(vaddq_s16(pl0l, pl0h), vaddq_s16(pl1l, pl1h))) +
+            vaddvq_s16(vaddq_s16(vaddq_s16(ph0l, ph0h), vaddq_s16(ph1l, ph1h))));
+    }
+
+    sumf = sum11;
+#endif
+#endif
+
+    *s = sumf;
+}
+
+// use vec_dot_gq_5 to compute the dot product of two rows
+void mul_mat_gq_5(
+    const void * src0,
+    const void * src1, // transposed
+    float * dst,
+    int m, int n, int k) {
+    assert(k % QK == 0);
+
+    const int nb = quantize_5_blocks_per_row(k);
+
+    for (int ir0 = 0; ir0 < m; ir0++) {
+        for (int ir1 = 0; ir1 < n; ir1++) {
+            vec_dot_gq_5(k, dst + ir1, src0, src1);
+            src1 = (const char *) src1 +   quantize_5_row_size(k);
+        }
+        src0 = (const char *) src0 +   quantize_5_row_size(k);
+        src1 = (const char *) src1 - n*quantize_5_row_size(k);
+
+        dst = (float *) dst + n;
+    }
+}
+
 int main(int argc, const char ** argv) {
     assert(sizeof(gq_quant_t)*8 == gq_t_bits);
 
+    // needed to initialize f16 tables
+    {
+        struct ggml_init_params params = { 0, NULL };
+        struct ggml_context * ctx = ggml_init(params);
+        ggml_free(ctx);
+    }
+
     int method = 0;
     if (argc > 1) {
         method = atoi(argv[1]);
@@ -1585,6 +2009,13 @@ int main(int argc, const char ** argv) {
 
             sizegq = quantize_4_row_size(K)*M + quantize_4_row_size(K)*N;
         }
+
+        if (method == 5) {
+            src0_gq = calloc(1, quantize_5_row_size(K)*M);
+            src1_gq = calloc(1, quantize_5_row_size(K)*N);
+
+            sizegq = quantize_5_row_size(K)*M + quantize_5_row_size(K)*N;
+        }
     }
 
     const size_t sizef16 = sizeof(ggml_fp16_t)*M*K + sizeof(ggml_fp16_t)*N*K;
@@ -1615,6 +2046,11 @@ int main(int argc, const char ** argv) {
         quantize_4(src1, src1_gq, N, K);
     }
 
+    if (method == 5) {
+        quantize_5(src0, src0_gq, M, K);
+        quantize_5(src1, src1_gq, N, K);
+    }
+
     const uint64_t t_end = get_time_us();
     printf("convert time: %f ms / method = %d\n", (t_end - t_start) / 1000.0, method);
 }
@@ -1650,6 +2086,10 @@ int main(int argc, const char ** argv) {
         if (method == 4) {
             mul_mat_gq_4(src0_gq, src1_gq, dst, M, N, K);
        }
+
+        if (method == 5) {
+            mul_mat_gq_5(src0_gq, src1_gq, dst, M, N, K);
+        }
     }
 
     for (int i = 0; i < N; i++) {
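
For reference, method 5 above stores a single scale d per block of QK values and packs two 4-bit codes per byte, biased by 8 so the codes land in [0, 15]. The following is a minimal scalar sketch of that round trip, assuming QK == 64 as in the test; the helper names (quantize_delta_row_ref, dot_delta_ref) are made up for illustration, and it uses plain round-to-nearest instead of the frand() dithering in the patch.

#include <assert.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>

#define QK 64

// quantize one block of QK floats: one scale d plus QK 4-bit codes (two per byte)
static void quantize_delta_row_ref(const float * src, float * d, uint8_t * pb) {
    float amax = 0.0f; // absolute max
    for (int l = 0; l < QK; l++) {
        amax = fmaxf(amax, fabsf(src[l]));
    }

    *d = amax / 7.0f; // 4 bits, signed: codes represent [-8, 7]
    const float id = *d ? 1.0f/(*d) : 0.0f;

    for (int l = 0; l < QK; l += 2) {
        const int v0 = (int) roundf(src[l + 0]*id) + 8; // bias by 8 -> [0, 15]
        const int v1 = (int) roundf(src[l + 1]*id) + 8;
        assert(v0 >= 0 && v0 < 16);
        assert(v1 >= 0 && v1 < 16);
        pb[l/2] = (uint8_t) (v0 | (v1 << 4));
    }
}

// dot product of two quantized blocks: sum of d0*d1*(q0 - 8)*(q1 - 8)
static float dot_delta_ref(float d0, const uint8_t * p0, float d1, const uint8_t * p1) {
    int sum = 0;
    for (int j = 0; j < QK/2; j++) {
        sum += ((p0[j] & 0xf) - 8)*((p1[j] & 0xf) - 8);
        sum += ((p0[j] >>  4) - 8)*((p1[j] >>  4) - 8);
    }
    return d0*d1*sum;
}

int main(void) {
    float x[QK], y[QK];
    for (int l = 0; l < QK; l++) {
        x[l] = sinf(0.1f*l);
        y[l] = cosf(0.2f*l);
    }

    float d0, d1;
    uint8_t p0[QK/2], p1[QK/2];
    quantize_delta_row_ref(x, &d0, p0);
    quantize_delta_row_ref(y, &d1, p1);

    float exact = 0.0f;
    for (int l = 0; l < QK; l++) {
        exact += x[l]*y[l];
    }

    printf("exact = %f, quantized = %f\n", exact, dot_delta_ref(d0, p0, d1, p1));
    return 0;
}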
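The AVX2 path of vec_dot_gq_5 leans on a standard trick: _mm256_maddubs_epi16 requires its first operand to be unsigned, so the code takes |a| via _mm256_sign_epi8(a, a) and folds a's sign into b via _mm256_sign_epi8(b, a), leaving the per-lane product equal to a*b (a zero lane in a zeroes the product, which is still correct). Below is a self-contained demonstration of just that trick against a scalar reference; the helper names are hypothetical and it should be compiled with -mavx2.

#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

// horizontal sum of the 8 32-bit lanes of an AVX2 register
static int32_t hsum_epi32(__m256i v) {
    __m128i s = _mm_add_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1));
    s = _mm_add_epi32(s, _mm_shuffle_epi32(s, 0x0e));
    s = _mm_add_epi32(s, _mm_shuffle_epi32(s, 0x01));
    return _mm_cvtsi128_si32(s);
}

// signed 8-bit dot product of 32 pairs using maddubs + sign
static int32_t dot_s8_avx2(const int8_t * a, const int8_t * b) {
    const __m256i va = _mm256_loadu_si256((const __m256i *) a);
    const __m256i vb = _mm256_loadu_si256((const __m256i *) b);

    const __m256i absa  = _mm256_sign_epi8(va, va); // |a|, the unsigned operand for maddubs
    const __m256i signb = _mm256_sign_epi8(vb, va); // b with a's sign folded in

    const __m256i p16 = _mm256_maddubs_epi16(absa, signb);            // 16 x i16
    const __m256i p32 = _mm256_madd_epi16(p16, _mm256_set1_epi16(1)); // 8 x i32

    return hsum_epi32(p32);
}

int main(void) {
    int8_t a[32], b[32];
    int32_t ref = 0;
    for (int i = 0; i < 32; i++) {
        a[i] = (int8_t) ((i % 16) - 8); // values in [-8, 7], as in the 4-bit case
        b[i] = (int8_t) (7 - (i % 16));
        ref += a[i]*b[i];
    }
    printf("ref = %d, avx2 = %d\n", ref, dot_s8_avx2(a, b));
    return 0;
}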