diff --git a/CMakeLists.txt b/CMakeLists.txt index 4d074ef..d88c5b1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -46,7 +46,7 @@ if (GGML_SANITIZE_UNDEFINED) endif() #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffast-math") -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native") +#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native") #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mcpu=native") # dependencies diff --git a/tests/test-mul-mat2.c b/tests/test-mul-mat2.c index 6d5956f..a3443c0 100644 --- a/tests/test-mul-mat2.c +++ b/tests/test-mul-mat2.c @@ -1106,7 +1106,7 @@ void quantize_4_row(const float * restrict src, void * restrict dst, int k) { memcpy(pb + i*QK/2, pp, sizeof(pp)); } -#elif defined(__ARM_NEON) +#elif defined(__ARM_NEON) && 0 { // TODO } @@ -1147,6 +1147,7 @@ void quantize_4(const float * restrict src, char * restrict dst, int n, int k) { } } +#if defined(__AVX2__) // horizontally reduce 8 32-bit integers static inline int32_t _mm256_hadd_epi32_gg(__m256i v) { __m128i v0 = _mm256_extractf128_si256(v, 0); @@ -1178,6 +1179,7 @@ static inline int32_t _mm256_hadd_epi8_gg(__m256i v0) { return _mm256_hadd_epi32_gg(v2); } +#endif void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, const void * restrict y) { const int nb = quantize_4_blocks_per_row(n); @@ -1432,9 +1434,61 @@ void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, cons } sumf = 64.0*sum00 + sum01 + sum10 + sum11; -#else -#error "QK must be 64" #endif +#elif defined (__ARM_NEON) + float sum00 = 0.0f; + float sum01 = 0.0f; + float sum10 = 0.0f; + float sum11 = 0.0f; + + for (int i = 0; i < nb; i++) { + const float m0 = GGML_GQ_TO_FP32(pm0[i]); + const float d0 = GGML_GQ_TO_FP32(pd0[i]); + + const float m1 = GGML_GQ_TO_FP32(pm1[i]); + const float d1 = GGML_GQ_TO_FP32(pd1[i]); + + const uint8_t * restrict p0 = pb0 + i*QK/2; + const uint8_t * restrict p1 = pb1 + i*QK/2; + + const uint8x16_t m4b = vdupq_n_u8(0xf); + + const uint8x16_t v0_0 = vld1q_u8(p0); + const uint8x16_t v0_1 = vld1q_u8(p0 + 16); + const uint8x16_t v1_0 = vld1q_u8(p1); + const uint8x16_t v1_1 = vld1q_u8(p1 + 16); + + // and with 0xf + const uint8x16_t v0_0l = vandq_u8(v0_0, m4b); + const uint8x16_t v0_1l = vandq_u8(v0_1, m4b); + const uint8x16_t v1_0l = vandq_u8(v1_0, m4b); + const uint8x16_t v1_1l = vandq_u8(v1_1, m4b); + + const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4); + const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4); + const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4); + const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4); + + // dot product into uint16x8_t + const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l)); + const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l)); + const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l)); + const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l)); + + const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h)); + const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h)); + const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h)); + const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h)); + + sum00 += m0*m1; + sum01 += m1*d0*(((int)vaddvq_u8(v0_0l)) + ((int)vaddvq_u8(v0_0h)) + ((int)vaddvq_u8(v0_1l)) + ((int)vaddvq_u8(v0_1h))); + sum10 += m0*d1*(((int)vaddvq_u8(v1_0l)) + ((int)vaddvq_u8(v1_0h)) + ((int)vaddvq_u8(v1_1l)) + ((int)vaddvq_u8(v1_1h))); + sum11 += d0*d1*( + vaddvq_u16(vaddq_u16(vaddq_u16(pl0l, pl0h), vaddq_u16(pl1l, pl1h))) + + vaddvq_u16(vaddq_u16(vaddq_u16(ph0l, ph0h), vaddq_u16(ph1l, ph1h)))); + } + + sumf = 64.0*sum00 + sum01 + sum10 + sum11; #endif *s = sumf; @@ -1450,7 +1504,7 @@ void mul_mat_gq_4( // needed to initialize f16 tables { - struct ggml_init_params params; + struct ggml_init_params params = { 0, NULL }; struct ggml_context * ctx = ggml_init(params); ggml_free(ctx); }