gq : method 4 (ARM)

gq
Georgi Gerganov 1 year ago
parent 054d97e0e1
commit cde7c22ab1
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -46,7 +46,7 @@ if (GGML_SANITIZE_UNDEFINED)
endif()
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffast-math")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native")
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native")
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mcpu=native")
# dependencies

@ -1106,7 +1106,7 @@ void quantize_4_row(const float * restrict src, void * restrict dst, int k) {
memcpy(pb + i*QK/2, pp, sizeof(pp));
}
#elif defined(__ARM_NEON)
#elif defined(__ARM_NEON) && 0
{
// TODO
}
@ -1147,6 +1147,7 @@ void quantize_4(const float * restrict src, char * restrict dst, int n, int k) {
}
}
#if defined(__AVX2__)
// horizontally reduce 8 32-bit integers
static inline int32_t _mm256_hadd_epi32_gg(__m256i v) {
__m128i v0 = _mm256_extractf128_si256(v, 0);
@ -1178,6 +1179,7 @@ static inline int32_t _mm256_hadd_epi8_gg(__m256i v0) {
return _mm256_hadd_epi32_gg(v2);
}
#endif
void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
const int nb = quantize_4_blocks_per_row(n);
@ -1432,9 +1434,61 @@ void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, cons
}
sumf = 64.0*sum00 + sum01 + sum10 + sum11;
#else
#error "QK must be 64"
#endif
#elif defined (__ARM_NEON)
float sum00 = 0.0f;
float sum01 = 0.0f;
float sum10 = 0.0f;
float sum11 = 0.0f;
for (int i = 0; i < nb; i++) {
const float m0 = GGML_GQ_TO_FP32(pm0[i]);
const float d0 = GGML_GQ_TO_FP32(pd0[i]);
const float m1 = GGML_GQ_TO_FP32(pm1[i]);
const float d1 = GGML_GQ_TO_FP32(pd1[i]);
const uint8_t * restrict p0 = pb0 + i*QK/2;
const uint8_t * restrict p1 = pb1 + i*QK/2;
const uint8x16_t m4b = vdupq_n_u8(0xf);
const uint8x16_t v0_0 = vld1q_u8(p0);
const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
const uint8x16_t v1_0 = vld1q_u8(p1);
const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
// and with 0xf
const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
// dot product into uint16x8_t
const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
sum00 += m0*m1;
sum01 += m1*d0*(((int)vaddvq_u8(v0_0l)) + ((int)vaddvq_u8(v0_0h)) + ((int)vaddvq_u8(v0_1l)) + ((int)vaddvq_u8(v0_1h)));
sum10 += m0*d1*(((int)vaddvq_u8(v1_0l)) + ((int)vaddvq_u8(v1_0h)) + ((int)vaddvq_u8(v1_1l)) + ((int)vaddvq_u8(v1_1h)));
sum11 += d0*d1*(
vaddvq_u16(vaddq_u16(vaddq_u16(pl0l, pl0h), vaddq_u16(pl1l, pl1h))) +
vaddvq_u16(vaddq_u16(vaddq_u16(ph0l, ph0h), vaddq_u16(ph1l, ph1h))));
}
sumf = 64.0*sum00 + sum01 + sum10 + sum11;
#endif
*s = sumf;
@ -1450,7 +1504,7 @@ void mul_mat_gq_4(
// needed to initialize f16 tables
{
struct ggml_init_params params;
struct ggml_init_params params = { 0, NULL };
struct ggml_context * ctx = ggml_init(params);
ggml_free(ctx);
}

Loading…
Cancel
Save