@@ -1204,7 +1204,7 @@ void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, cons
    float sumf = 0.0;

#if 1
#if 0
    // scalar
    for (int i = 0; i < nb; i++) {
        const float m0 = GGML_GQ_TO_FP32(pm0[i]);
@@ -1569,12 +1569,21 @@ void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, cons
        const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
        const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));

        const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
        const uint16x8_t pl1 = vaddq_u16(pl1l, pl1h);
        const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
        const uint16x8_t ph1 = vaddq_u16(ph1l, ph1h);

        const uint16x8_t pl = vaddq_u16(pl0, pl1);
        const uint16x8_t ph = vaddq_u16(ph0, ph1);
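
        // pl/ph: per-lane 16-bit sums of the low/high nibble products,
        // folded to scalars below with vaddvq_u16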
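
        // expanding (m0 + d0*v0)*(m1 + d1*v1) summed over the QK elements of a
        // block gives QK*m0*m1 + m1*d0*sum(v0) + m0*d1*sum(v1) + d0*d1*sum(v0*v1),
        // which is what sum00/sum01/sum10/sum11 accumulate (QK = 64 here,
        // hence the 64.0*sum00 in the final reduction)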
        sum00 += m0*m1;
        sum01 += m1*d0*(((int)vaddvq_u8(v0_0l)) + ((int)vaddvq_u8(v0_0h)) + ((int)vaddvq_u8(v0_1l)) + ((int)vaddvq_u8(v0_1h)));
        sum10 += m0*d1*(((int)vaddvq_u8(v1_0l)) + ((int)vaddvq_u8(v1_0h)) + ((int)vaddvq_u8(v1_1l)) + ((int)vaddvq_u8(v1_1h)));
        sum11 += d0*d1*(
                vaddvq_u16(vaddq_u16(vaddq_u16(pl0l, pl0h), vaddq_u16(pl1l, pl1h))) +
                vaddvq_u16(vaddq_u16(vaddq_u16(ph0l, ph0h), vaddq_u16(ph1l, ph1h))));
        sum01 += m1*d0*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h) + vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
        sum10 += m0*d1*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h) + vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
        //sum11 += d0*d1*(
        //        vaddvq_u16(vaddq_u16(vaddq_u16(pl0l, pl0h), vaddq_u16(pl1l, pl1h))) +
        //        vaddvq_u16(vaddq_u16(vaddq_u16(ph0l, ph0h), vaddq_u16(ph1l, ph1h))));
        sum11 += d0*d1*vaddvq_u16(vaddq_u16(pl, ph));
    }

    sumf = 64.0*sum00 + sum01 + sum10 + sum11;
@@ -1748,7 +1757,8 @@ void quantize_5_row(const float * restrict src, void * restrict dst, int k) {
        for (int l = 0; l < QK; l++) {
            const float v = src[i*QK + l]*id;
            const int8_t vi = ((int8_t) (v + frand())) + 8;
            const int8_t vi = ((int8_t) (round(v))) + 8;
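            // v + frand() truncates after adding uniform noise, i.e. stochastic
            // rounding (assuming frand() is U[0,1)); round(v) is deterministic
            // round-to-nearest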

            assert(vi >= 0 && vi < 16);
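
            // pack two 4-bit values per byte: even l in the low nibble, odd l in the high nibble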
            pp[l/2] |= (vi & 0xf) << (4*(l & 1));
        }
@@ -1859,29 +1869,33 @@ void vec_dot_gq_5(const int n, float * restrict s, const void * restrict x, cons
    sumf = _mm256_hadd_ps_gg(sum11);
#endif
#elif defined (__ARM_NEON)
    float sum00 = 0.0f;
    float sum01 = 0.0f;
    float sum10 = 0.0f;
    float sum11 = 0.0f;

    //float32x4_t sum_0 = vdupq_n_f32(0.0f);
    //float32x4_t sum_1 = vdupq_n_f32(0.0f);

    //float16x8_t sum_0 = vdupq_n_f16(0.0f);
    //float16x8_t sum_1 = vdupq_n_f16(0.0f);
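    // (sum_0/sum_1 above: disabled f32/f16 vector-accumulator variants of the d0*d1 term)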

    for (int i = 0; i < nb; i++) {
        const float m0 = GGML_GQ_TO_FP32(pm0[i]);
        const float d0 = GGML_GQ_TO_FP32(pd0[i]);

        const float m1 = GGML_GQ_TO_FP32(pm1[i]);
        const float d1 = GGML_GQ_TO_FP32(pd1[i]);

        //float32x4_t d0d1v = vdupq_n_f32(d0*d1);
        //float16x8_t d0d1v = vdupq_n_f16(d0*d1);

        const uint8_t * restrict p0 = pb0 + i*QK/2;
        const uint8_t * restrict p1 = pb1 + i*QK/2;
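        // (each block stores QK 4-bit values packed into QK/2 bytes)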

        const uint8x16_t m4b = vdupq_n_u8(0xf);
        const int8x16_t  s8b = vdupq_n_s8(0x8);
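        // m4b extracts the low nibble; s8b recenters [0, 15] to [-8, 7]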

        const uint8x16_t v0_0 = vld1q_u8(p0);
        const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
        const uint8x16_t v1_0 = vld1q_u8(p1);
        const uint8x16_t v1_1 = vld1q_u8(p1 + 16);

        // and with 0xf
        // 4-bit -> 8-bit
        const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
        const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
        const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
@@ -1892,26 +1906,76 @@ void vec_dot_gq_5(const int n, float * restrict s, const void * restrict x, cons
        const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
        const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);

        // dot product into uint16x8_t
        const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
        const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
        const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
        const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));

        const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
        const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
        const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
        const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));

        // sub 8
        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
        const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
        const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);

        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
        const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
        const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);

        // dot product into int16x8_t
        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));

        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
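        // each vmull_s8 widens 8 signed products into 16-bit lanes, so the
        // per-lane accumulation below cannot overflow for 4-bit operands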

        const int16x8_t pl0 = vaddq_s16(pl0l, pl0h);
        const int16x8_t pl1 = vaddq_s16(pl1l, pl1h);
        const int16x8_t ph0 = vaddq_s16(ph0l, ph0h);
        const int16x8_t ph1 = vaddq_s16(ph1l, ph1h);

        const int16x8_t pl = vaddq_s16(pl0, pl1);
        const int16x8_t ph = vaddq_s16(ph0, ph1);

        //const int8x16_t pl0 = vmulq_s8(v0_0ls, v1_0ls);
        //const int8x16_t pl1 = vmulq_s8(v0_1ls, v1_1ls);
        //const int8x16_t ph0 = vmulq_s8(v0_0hs, v1_0hs);
        //const int8x16_t ph1 = vmulq_s8(v0_1hs, v1_1hs);

        //const int16x8_t pll = vaddl_s8(vget_low_s8(pl0), vget_low_s8(pl1));
        //const int16x8_t plh = vaddl_s8(vget_high_s8(pl0), vget_high_s8(pl1));
        //const int16x8_t phl = vaddl_s8(vget_low_s8(ph0), vget_low_s8(ph1));
        //const int16x8_t phh = vaddl_s8(vget_high_s8(ph0), vget_high_s8(ph1));

        //const int16x8_t pl = vaddq_s16(pll, plh);
        //const int16x8_t ph = vaddq_s16(phl, phh);

        const int16x8_t p = vaddq_s16(pl, ph);
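        // p folds the 64 (v0-8)*(v1-8) products of the block into eight 16-bit lanes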

        sum00 += m0*m1;
        sum01 += m1*d0*(((int)vaddvq_u8(v0_0l)) + ((int)vaddvq_u8(v0_0h)) + ((int)vaddvq_u8(v0_1l)) + ((int)vaddvq_u8(v0_1h)));
        sum10 += m0*d1*(((int)vaddvq_u8(v1_0l)) + ((int)vaddvq_u8(v1_0h)) + ((int)vaddvq_u8(v1_1l)) + ((int)vaddvq_u8(v1_1h)));
        sum11 += d0*d1*(
                vaddvq_u16(vaddq_u16(vaddq_u16(pl0l, pl0h), vaddq_u16(pl1l, pl1h))) +
                vaddvq_u16(vaddq_u16(vaddq_u16(ph0l, ph0h), vaddq_u16(ph1l, ph1h))));

        // convert to float
        //const float32x4_t pf0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (p)));
        //const float32x4_t pf1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(p)));

        // scalar
        sum11 += d0*d1*vaddvq_s16(p);
        //sum11 += d0*d1*(vaddvq_u16(pl) + vaddvq_u16(ph));
        //sum11 += d0*d1*vaddvq_u16(vaddq_s16(pl, ph));
        //sum11 += d0*d1*(vaddvq_s8(pl0) + vaddvq_s8(pl1) + vaddvq_s8(ph0) + vaddvq_s8(ph1));
        //sum11 += d0*d1*(vaddvq_s16(pll) + vaddvq_s16(plh) + vaddvq_s16(phl) + vaddvq_s16(phh));

        //sum_0 = vfmaq_f16(sum_0, d0d1v, vcvtq_f16_s16(p));

        //sum_0 = vfmaq_f16(sum_0, d0d1v, vcvtq_f16_s16(pl));
        //sum_1 = vfmaq_f16(sum_1, d0d1v, vcvtq_f16_s16(ph));

        // vectorize
        //sum_0 = vmlaq_f32(sum_0, d0d1v, pf0);
        //sum_1 = vmlaq_f32(sum_1, d0d1v, pf1);
    }

    sumf = 64.0*sum00 + sum01 + sum10 + sum11;
    sumf = sum11;

    //sumf = vaddvq_f32(sum_0) + vaddvq_f32(sum_1);
    //sumf = sum_0[0] + sum_0[1] + sum_0[2] + sum_0[3] + sum_0[4] + sum_0[5] + sum_0[6] + sum_0[7];

    //sum_0 = vaddq_f16(sum_0, sum_1);
    //sumf = sum_0[0] + sum_0[1] + sum_0[2] + sum_0[3] + sum_0[4] + sum_0[5] + sum_0[6] + sum_0[7];
#endif
#endif