|
|
|
@ -1106,7 +1106,7 @@ void quantize_4_row(const float * restrict src, void * restrict dst, int k) {
|
|
|
|
|
|
|
|
|
|
memcpy(pb + i*QK/2, pp, sizeof(pp));
|
|
|
|
|
}
|
|
|
|
|
#elif defined(__ARM_NEON)
|
|
|
|
|
#elif defined(__ARM_NEON) && 0
|
|
|
|
|
{
|
|
|
|
|
// TODO
|
|
|
|
|
}
|
|
|
|
@ -1147,6 +1147,7 @@ void quantize_4(const float * restrict src, char * restrict dst, int n, int k) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#if defined(__AVX2__)
|
|
|
|
|
// horizontally reduce 8 32-bit integers
|
|
|
|
|
static inline int32_t _mm256_hadd_epi32_gg(__m256i v) {
|
|
|
|
|
__m128i v0 = _mm256_extractf128_si256(v, 0);
|
|
|
|
@ -1178,6 +1179,7 @@ static inline int32_t _mm256_hadd_epi8_gg(__m256i v0) {
|
|
|
|
|
|
|
|
|
|
return _mm256_hadd_epi32_gg(v2);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
|
|
|
|
|
const int nb = quantize_4_blocks_per_row(n);
|
|
|
|
@ -1432,9 +1434,61 @@ void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, cons
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sumf = 64.0*sum00 + sum01 + sum10 + sum11;
|
|
|
|
|
#else
|
|
|
|
|
#error "QK must be 64"
|
|
|
|
|
#endif
|
|
|
|
|
#elif defined (__ARM_NEON)
|
|
|
|
|
float sum00 = 0.0f;
|
|
|
|
|
float sum01 = 0.0f;
|
|
|
|
|
float sum10 = 0.0f;
|
|
|
|
|
float sum11 = 0.0f;
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < nb; i++) {
|
|
|
|
|
const float m0 = GGML_GQ_TO_FP32(pm0[i]);
|
|
|
|
|
const float d0 = GGML_GQ_TO_FP32(pd0[i]);
|
|
|
|
|
|
|
|
|
|
const float m1 = GGML_GQ_TO_FP32(pm1[i]);
|
|
|
|
|
const float d1 = GGML_GQ_TO_FP32(pd1[i]);
|
|
|
|
|
|
|
|
|
|
const uint8_t * restrict p0 = pb0 + i*QK/2;
|
|
|
|
|
const uint8_t * restrict p1 = pb1 + i*QK/2;
|
|
|
|
|
|
|
|
|
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
|
|
|
|
|
|
|
|
|
const uint8x16_t v0_0 = vld1q_u8(p0);
|
|
|
|
|
const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
|
|
|
|
|
const uint8x16_t v1_0 = vld1q_u8(p1);
|
|
|
|
|
const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
|
|
|
|
|
|
|
|
|
|
// and with 0xf
|
|
|
|
|
const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
|
|
|
|
|
const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
|
|
|
|
|
const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
|
|
|
|
|
const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
|
|
|
|
|
|
|
|
|
|
const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
|
|
|
|
|
const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
|
|
|
|
|
const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
|
|
|
|
|
const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
|
|
|
|
|
|
|
|
|
|
// dot product into uint16x8_t
|
|
|
|
|
const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
|
|
|
|
|
const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
|
|
|
|
|
const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
|
|
|
|
|
const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
|
|
|
|
|
|
|
|
|
|
const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
|
|
|
|
|
const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
|
|
|
|
|
const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
|
|
|
|
|
const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
|
|
|
|
|
|
|
|
|
|
sum00 += m0*m1;
|
|
|
|
|
sum01 += m1*d0*(((int)vaddvq_u8(v0_0l)) + ((int)vaddvq_u8(v0_0h)) + ((int)vaddvq_u8(v0_1l)) + ((int)vaddvq_u8(v0_1h)));
|
|
|
|
|
sum10 += m0*d1*(((int)vaddvq_u8(v1_0l)) + ((int)vaddvq_u8(v1_0h)) + ((int)vaddvq_u8(v1_1l)) + ((int)vaddvq_u8(v1_1h)));
|
|
|
|
|
sum11 += d0*d1*(
|
|
|
|
|
vaddvq_u16(vaddq_u16(vaddq_u16(pl0l, pl0h), vaddq_u16(pl1l, pl1h))) +
|
|
|
|
|
vaddvq_u16(vaddq_u16(vaddq_u16(ph0l, ph0h), vaddq_u16(ph1l, ph1h))));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sumf = 64.0*sum00 + sum01 + sum10 + sum11;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
*s = sumf;
|
|
|
|
@ -1450,7 +1504,7 @@ void mul_mat_gq_4(
|
|
|
|
|
|
|
|
|
|
// needed to initialize f16 tables
|
|
|
|
|
{
|
|
|
|
|
struct ggml_init_params params;
|
|
|
|
|
struct ggml_init_params params = { 0, NULL };
|
|
|
|
|
struct ggml_context * ctx = ggml_init(params);
|
|
|
|
|
ggml_free(ctx);
|
|
|
|
|
}
|
|
|
|
|