gq : method 5 (ARM)

gq
Georgi Gerganov 2 years ago
parent 5a96c91bea
commit 1ca898f94b
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -46,7 +46,7 @@ if (GGML_SANITIZE_UNDEFINED)
endif() endif()
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffast-math") #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffast-math")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native") #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native")
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mcpu=native") #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mcpu=native")
# dependencies # dependencies

@ -1204,7 +1204,7 @@ void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, cons
float sumf = 0.0; float sumf = 0.0;
#if 1 #if 0
// scalar // scalar
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
const float m0 = GGML_GQ_TO_FP32(pm0[i]); const float m0 = GGML_GQ_TO_FP32(pm0[i]);
@ -1569,12 +1569,21 @@ void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, cons
const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h)); const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h)); const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
const uint16x8_t pl1 = vaddq_u16(pl1l, pl1h);
const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
const uint16x8_t ph1 = vaddq_u16(ph1l, ph1h);
const uint16x8_t pl = vaddq_u16(pl0, pl1);
const uint16x8_t ph = vaddq_u16(ph0, ph1);
sum00 += m0*m1; sum00 += m0*m1;
sum01 += m1*d0*(((int)vaddvq_u8(v0_0l)) + ((int)vaddvq_u8(v0_0h)) + ((int)vaddvq_u8(v0_1l)) + ((int)vaddvq_u8(v0_1h))); sum01 += m1*d0*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h) + vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
sum10 += m0*d1*(((int)vaddvq_u8(v1_0l)) + ((int)vaddvq_u8(v1_0h)) + ((int)vaddvq_u8(v1_1l)) + ((int)vaddvq_u8(v1_1h))); sum10 += m0*d1*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h) + vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
sum11 += d0*d1*( //sum11 += d0*d1*(
vaddvq_u16(vaddq_u16(vaddq_u16(pl0l, pl0h), vaddq_u16(pl1l, pl1h))) + // vaddvq_u16(vaddq_u16(vaddq_u16(pl0l, pl0h), vaddq_u16(pl1l, pl1h))) +
vaddvq_u16(vaddq_u16(vaddq_u16(ph0l, ph0h), vaddq_u16(ph1l, ph1h)))); // vaddvq_u16(vaddq_u16(vaddq_u16(ph0l, ph0h), vaddq_u16(ph1l, ph1h))));
sum11 += d0*d1*vaddvq_u16(vaddq_u16(pl, ph));
} }
sumf = 64.0*sum00 + sum01 + sum10 + sum11; sumf = 64.0*sum00 + sum01 + sum10 + sum11;
@ -1748,7 +1757,8 @@ void quantize_5_row(const float * restrict src, void * restrict dst, int k) {
for (int l = 0; l < QK; l++) { for (int l = 0; l < QK; l++) {
const float v = src[i*QK + l]*id; const float v = src[i*QK + l]*id;
const int8_t vi = ((int8_t) (v + frand())) + 8; const int8_t vi = ((int8_t) (round(v))) + 8;
assert(vi >= 0 && vi < 16);
pp[l/2] |= (vi & 0xf) << (4*(l & 1)); pp[l/2] |= (vi & 0xf) << (4*(l & 1));
} }
@ -1859,29 +1869,33 @@ void vec_dot_gq_5(const int n, float * restrict s, const void * restrict x, cons
sumf = _mm256_hadd_ps_gg(sum11); sumf = _mm256_hadd_ps_gg(sum11);
#endif #endif
#elif defined (__ARM_NEON) #elif defined (__ARM_NEON)
float sum00 = 0.0f;
float sum01 = 0.0f;
float sum10 = 0.0f;
float sum11 = 0.0f; float sum11 = 0.0f;
//float32x4_t sum_0 = vdupq_n_f32(0.0f);
//float32x4_t sum_1 = vdupq_n_f32(0.0f);
//float16x8_t sum_0 = vdupq_n_f16(0.0f);
//float16x8_t sum_1 = vdupq_n_f16(0.0f);
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
const float m0 = GGML_GQ_TO_FP32(pm0[i]);
const float d0 = GGML_GQ_TO_FP32(pd0[i]); const float d0 = GGML_GQ_TO_FP32(pd0[i]);
const float m1 = GGML_GQ_TO_FP32(pm1[i]);
const float d1 = GGML_GQ_TO_FP32(pd1[i]); const float d1 = GGML_GQ_TO_FP32(pd1[i]);
//float32x4_t d0d1v = vdupq_n_f32(d0*d1);
//float16x8_t d0d1v = vdupq_n_f16(d0*d1);
const uint8_t * restrict p0 = pb0 + i*QK/2; const uint8_t * restrict p0 = pb0 + i*QK/2;
const uint8_t * restrict p1 = pb1 + i*QK/2; const uint8_t * restrict p1 = pb1 + i*QK/2;
const uint8x16_t m4b = vdupq_n_u8(0xf); const uint8x16_t m4b = vdupq_n_u8(0xf);
const int8x16_t s8b = vdupq_n_s8(0x8);
const uint8x16_t v0_0 = vld1q_u8(p0); const uint8x16_t v0_0 = vld1q_u8(p0);
const uint8x16_t v0_1 = vld1q_u8(p0 + 16); const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
const uint8x16_t v1_0 = vld1q_u8(p1); const uint8x16_t v1_0 = vld1q_u8(p1);
const uint8x16_t v1_1 = vld1q_u8(p1 + 16); const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
// and with 0xf // 4-bit -> 8-bit
const uint8x16_t v0_0l = vandq_u8(v0_0, m4b); const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
const uint8x16_t v0_1l = vandq_u8(v0_1, m4b); const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
const uint8x16_t v1_0l = vandq_u8(v1_0, m4b); const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
@ -1892,26 +1906,76 @@ void vec_dot_gq_5(const int n, float * restrict s, const void * restrict x, cons
const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4); const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4); const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
// dot product into uint16x8_t // sub 8
const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l)); const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l)); const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l)); const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l)); const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h)); const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h)); const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h)); const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h)); const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
// dot product into int16x8_t
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
const int16x8_t pl0 = vaddq_s16(pl0l, pl0h);
const int16x8_t pl1 = vaddq_s16(pl1l, pl1h);
const int16x8_t ph0 = vaddq_s16(ph0l, ph0h);
const int16x8_t ph1 = vaddq_s16(ph1l, ph1h);
const int16x8_t pl = vaddq_s16(pl0, pl1);
const int16x8_t ph = vaddq_s16(ph0, ph1);
//const int8x16_t pl0 = vmulq_s8(v0_0ls, v1_0ls);
//const int8x16_t pl1 = vmulq_s8(v0_1ls, v1_1ls);
//const int8x16_t ph0 = vmulq_s8(v0_0hs, v1_0hs);
//const int8x16_t ph1 = vmulq_s8(v0_1hs, v1_1hs);
//const int16x8_t pll = vaddl_s8(vget_low_s8(pl0), vget_low_s8(pl1));
//const int16x8_t plh = vaddl_s8(vget_high_s8(pl0), vget_high_s8(pl1));
//const int16x8_t phl = vaddl_s8(vget_low_s8(ph0), vget_low_s8(ph1));
//const int16x8_t phh = vaddl_s8(vget_high_s8(ph0), vget_high_s8(ph1));
//const int16x8_t pl = vaddq_s16(pll, plh);
//const int16x8_t ph = vaddq_s16(phl, phh);
const int16x8_t p = vaddq_s16(pl, ph);
sum00 += m0*m1; // convert to float
sum01 += m1*d0*(((int)vaddvq_u8(v0_0l)) + ((int)vaddvq_u8(v0_0h)) + ((int)vaddvq_u8(v0_1l)) + ((int)vaddvq_u8(v0_1h))); //const float32x4_t pf0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (p)));
sum10 += m0*d1*(((int)vaddvq_u8(v1_0l)) + ((int)vaddvq_u8(v1_0h)) + ((int)vaddvq_u8(v1_1l)) + ((int)vaddvq_u8(v1_1h))); //const float32x4_t pf1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(p)));
sum11 += d0*d1*(
vaddvq_u16(vaddq_u16(vaddq_u16(pl0l, pl0h), vaddq_u16(pl1l, pl1h))) + // scalar
vaddvq_u16(vaddq_u16(vaddq_u16(ph0l, ph0h), vaddq_u16(ph1l, ph1h)))); sum11 += d0*d1*vaddvq_u16(p);
//sum11 += d0*d1*(vaddvq_u16(pl) + vaddvq_u16(ph));
//sum11 += d0*d1*vaddvq_u16(vaddq_s16(pl, ph));
//sum11 += d0*d1*(vaddvq_s8(pl0) + vaddvq_s8(pl1) + vaddvq_s8(ph0) + vaddvq_s8(ph1));
//sum11 += d0*d1*(vaddvq_s16(pll) + vaddvq_s16(plh) + vaddvq_s16(phl) + vaddvq_s16(phh));
//sum_0 = vfmaq_f16(sum_0, d0d1v, vcvtq_f16_s16(p));
//sum_0 = vfmaq_f16(sum_0, d0d1v, vcvtq_f16_s16(pl));
//sum_1 = vfmaq_f16(sum_1, d0d1v, vcvtq_f16_s16(ph));
// vectorize
//sum_0 = vmlaq_f32(sum_0, d0d1v, pf0);
//sum_1 = vmlaq_f32(sum_1, d0d1v, pf1);
} }
sumf = 64.0*sum00 + sum01 + sum10 + sum11; sumf = sum11;
//sumf = vaddvq_f32(sum_0) + vaddvq_f32(sum_1);
//sumf = sum_0[0] + sum_0[1] + sum_0[2] + sum_0[3] + sum_0[4] + sum_0[5] + sum_0[6] + sum_0[7];
//sum_0 = vaddq_f16(sum_0, sum_1);
//sumf = sum_0[0] + sum_0[1] + sum_0[2] + sum_0[3] + sum_0[4] + sum_0[5] + sum_0[6] + sum_0[7];
#endif #endif
#endif #endif

Loading…
Cancel
Save