ggml : fixes for rpi4

gq
Georgi Gerganov 2 years ago
parent 2fcbd28143
commit b7621b4fda
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -1169,17 +1169,17 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
// 4-bit -> 8-bit
const int8x16_t v0_0l = vandq_u8(v0_0, m4b);
const int8x16_t v1_0l = vandq_u8(v1_0, m4b);
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
const int8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
const int8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
const int8x16_t v0_1l = vandq_u8(v0_1, m4b);
const int8x16_t v1_1l = vandq_u8(v1_1, m4b);
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
const int8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
const int8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
// sub 8
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
@ -1220,8 +1220,13 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
//printf("p_1: %d %d %d %d %d %d %d %d\n", vgetq_lane_s16(p_1, 0), vgetq_lane_s16(p_1, 1), vgetq_lane_s16(p_1, 2), vgetq_lane_s16(p_1, 3), vgetq_lane_s16(p_1, 4), vgetq_lane_s16(p_1, 5), vgetq_lane_s16(p_1, 6), vgetq_lane_s16(p_1, 7));
// scalar
#if defined(__ARM_FEATURE_QRDMX)
sum0 += d0_0*d1_0*vaddvq_s16(p_0);
sum1 += d0_1*d1_1*vaddvq_s16(p_1);
#else
sum0 += d0_0*d1_0*(vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
sum1 += d0_1*d1_1*(vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
#endif
}
sumf = sum0 + sum1;
@ -1468,13 +1473,15 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
for (int j = 0; j < 2; j++) {
const uint8x8_t vx = vld1_u8(pp + j*8);
const uint8x8_t vxl = vand_u8 (vx, m4b);
const uint8x8_t vxh = vshr_n_u8(vx, 4);
const int8x8_t vxl = vreinterpret_s8_u8(vand_u8(vx, m4b));
const int8x8_t vxh = vreinterpret_s8_u8(vshr_n_u8(vx, 4));
// sub 8
const int8x8_t vxls = vsub_s8(vxl, s8b);
const int8x8_t vxhs = vsub_s8(vxh, s8b);
//const int8x8_t vxlt = vzip_s8(vxls, vxhs)[0];
//const int8x8_t vxht = vzip_s8(vxls, vxhs)[1];
const int8x8_t vxlt = vzip1_s8(vxls, vxhs);
const int8x8_t vxht = vzip2_s8(vxls, vxhs);

Loading…
Cancel
Save