ggml : vectorized mad q4_0 (ARM)

gq
Georgi Gerganov 1 year ago
parent 8ce6d1e492
commit ea97a5f469
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -12,7 +12,7 @@
#include <vector>
#include <regex>
#define QK 64
#define QK 32
size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k) {
const int nb = k / QK;

@ -352,7 +352,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
// quantization
//
#define QK 64
#define QK 32
// method 5
// blocks of QK elements
@ -1094,6 +1094,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
const int nb = n / QK;
assert(n % QK == 0);
assert(nb % 2 == 0);
const float * restrict pd0 = (const float *) x;
const float * restrict pd1 = (const float *) y;
@ -1102,6 +1105,120 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
float sumf = 0.0;
#ifdef __ARM_NEON
#if QK == 32
float sum0 = 0.0f;
float sum1 = 0.0f;
for (int i = 0; i < nb; i += 2) {
const float d0_0 = pd0[i + 0];
const float d1_0 = pd1[i + 0];
const float d0_1 = pd0[i + 1];
const float d1_1 = pd1[i + 1];
//printf("d0_0: %f, d1_0: %f, d0_1: %f, d1_1: %f\n", d0_0, d1_0, d0_1, d1_1);
const uint8_t * restrict p0 = pb0 + i*16;
const uint8_t * restrict p1 = pb1 + i*16;
const uint8x16_t m4b = vdupq_n_u8(0xf);
const int8x16_t s8b = vdupq_n_s8(0x8);
const uint8x16_t v0_0 = vld1q_u8(p0);
const uint8x16_t v1_0 = vld1q_u8(p1);
const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
// 4-bit -> 8-bit
const int8x16_t v0_0l = vandq_u8(v0_0, m4b);
const int8x16_t v1_0l = vandq_u8(v1_0, m4b);
const int8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
const int8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
const int8x16_t v0_1l = vandq_u8(v0_1, m4b);
const int8x16_t v1_1l = vandq_u8(v1_1, m4b);
const int8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
const int8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
// sub 8
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
// dot product into int16x8_t
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
//printf("p_0: %d %d %d %d %d %d %d %d\n", vgetq_lane_s16(p_0, 0), vgetq_lane_s16(p_0, 1), vgetq_lane_s16(p_0, 2), vgetq_lane_s16(p_0, 3), vgetq_lane_s16(p_0, 4), vgetq_lane_s16(p_0, 5), vgetq_lane_s16(p_0, 6), vgetq_lane_s16(p_0, 7));
//printf("p_1: %d %d %d %d %d %d %d %d\n", vgetq_lane_s16(p_1, 0), vgetq_lane_s16(p_1, 1), vgetq_lane_s16(p_1, 2), vgetq_lane_s16(p_1, 3), vgetq_lane_s16(p_1, 4), vgetq_lane_s16(p_1, 5), vgetq_lane_s16(p_1, 6), vgetq_lane_s16(p_1, 7));
// scalar
sum0 += d0_0*d1_0*vaddvq_s16(p_0);
sum1 += d0_1*d1_1*vaddvq_s16(p_1);
}
sumf = sum0 + sum1;
//printf("sumf SIMD = %f\n", sumf);
//// scalar
//sumf = 0.0f;
//for (int i = 0; i < nb; i++) {
// const float d0 = pd0[i];
// const float d1 = pd1[i];
// const uint8_t * restrict p0 = pb0 + i*QK/2;
// const uint8_t * restrict p1 = pb1 + i*QK/2;
// for (int j = 0; j < QK/2; j++) {
// const uint8_t v0 = p0[j];
// const uint8_t v1 = p1[j];
// const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
// const float f1 = d0*((int8_t) (v0 >> 4) - 8);
// const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
// const float f3 = d1*((int8_t) (v1 >> 4) - 8);
// sumf += f0*f2 + f1*f3;
// }
//}
//printf("sumf scalar = %f\n", sumf);
//printf("--------\n");
//exit(0);
#else
#error "not implemented"
#endif
#else
// scalar
for (int i = 0; i < nb; i++) {
const float d0 = pd0[i];
@ -1123,6 +1240,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
sumf += f0*f2 + f1*f3;
}
}
#endif
*s = sumf;
}
@ -1296,6 +1414,124 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
const float * restrict pd = (const float *) (x);
const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
#if __ARM_NEON
#if QK == 32
for (int i = 0; i < nb; ++i) {
const float d0 = pd[i]*v;
const uint8_t * restrict pp = pb + i*16;
const float32x4_t vd = vdupq_n_f32(d0);
const float32x4_t vy0 = vld1q_f32(y + i*32 + 0);
const float32x4_t vy1 = vld1q_f32(y + i*32 + 4);
const float32x4_t vy2 = vld1q_f32(y + i*32 + 8);
const float32x4_t vy3 = vld1q_f32(y + i*32 + 12);
const float32x4_t vy4 = vld1q_f32(y + i*32 + 16);
const float32x4_t vy5 = vld1q_f32(y + i*32 + 20);
const float32x4_t vy6 = vld1q_f32(y + i*32 + 24);
const float32x4_t vy7 = vld1q_f32(y + i*32 + 28);
const uint8x16_t m4b = vdupq_n_u8(0xf);
const int8x16_t s8b = vdupq_n_s8(0x8);
const uint8x16_t vx = vld1q_u8(pp);
const uint8x16_t vxl = vandq_u8 (vx, m4b);
const uint8x16_t vxh = vshrq_n_u8(vx, 4);
// sub 8
const int8x16_t vxls = vsubq_s8(vxl, s8b);
const int8x16_t vxhs = vsubq_s8(vxh, s8b);
const int8x16_t vxlt = vzip1q_s8(vxls, vxhs);
const int8x16_t vxht = vzip2q_s8(vxls, vxhs);
//printf("vxv: %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d ",
// vgetq_lane_s8(vxlt, 0), vgetq_lane_s8(vxlt, 1), vgetq_lane_s8(vxlt, 2), vgetq_lane_s8(vxlt, 3),
// vgetq_lane_s8(vxlt, 4), vgetq_lane_s8(vxlt, 5), vgetq_lane_s8(vxlt, 6), vgetq_lane_s8(vxlt, 7),
// vgetq_lane_s8(vxlt, 8), vgetq_lane_s8(vxlt, 9), vgetq_lane_s8(vxlt, 10), vgetq_lane_s8(vxlt, 11),
// vgetq_lane_s8(vxlt, 12), vgetq_lane_s8(vxlt, 13), vgetq_lane_s8(vxlt, 14), vgetq_lane_s8(vxlt, 15));
//printf("%3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d\n",
// vgetq_lane_s8(vxht, 0), vgetq_lane_s8(vxht, 1), vgetq_lane_s8(vxht, 2), vgetq_lane_s8(vxht, 3),
// vgetq_lane_s8(vxht, 4), vgetq_lane_s8(vxht, 5), vgetq_lane_s8(vxht, 6), vgetq_lane_s8(vxht, 7),
// vgetq_lane_s8(vxht, 8), vgetq_lane_s8(vxht, 9), vgetq_lane_s8(vxht, 10), vgetq_lane_s8(vxht, 11),
// vgetq_lane_s8(vxht, 12), vgetq_lane_s8(vxht, 13), vgetq_lane_s8(vxht, 14), vgetq_lane_s8(vxht, 15));
// convert to 4x int16x8_t
const int16x8_t vxls0 = vmovl_s8(vget_low_s8 (vxlt));
const int16x8_t vxls1 = vmovl_s8(vget_high_s8(vxlt));
const int16x8_t vxhs0 = vmovl_s8(vget_low_s8 (vxht));
const int16x8_t vxhs1 = vmovl_s8(vget_high_s8(vxht));
// convert to 8x float32x4_t
const float32x4_t vx0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxls0)));
const float32x4_t vx1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxls0)));
const float32x4_t vx2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxls1)));
const float32x4_t vx3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxls1)));
const float32x4_t vx4 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxhs0)));
const float32x4_t vx5 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxhs0)));
const float32x4_t vx6 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxhs1)));
const float32x4_t vx7 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxhs1)));
//printf("vxv: %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f\n",
// vgetq_lane_f32(vx0m, 0), vgetq_lane_f32(vx0m, 1), vgetq_lane_f32(vx0m, 2), vgetq_lane_f32(vx0m, 3),
// vgetq_lane_f32(vx1m, 0), vgetq_lane_f32(vx1m, 1), vgetq_lane_f32(vx1m, 2), vgetq_lane_f32(vx1m, 3),
// vgetq_lane_f32(vx2m, 0), vgetq_lane_f32(vx2m, 1), vgetq_lane_f32(vx2m, 2), vgetq_lane_f32(vx2m, 3),
// vgetq_lane_f32(vx3m, 0), vgetq_lane_f32(vx3m, 1), vgetq_lane_f32(vx3m, 2), vgetq_lane_f32(vx3m, 3));
const float32x4_t vr0 = vfmaq_f32(vy0, vx0, vd);
const float32x4_t vr1 = vfmaq_f32(vy1, vx1, vd);
const float32x4_t vr2 = vfmaq_f32(vy2, vx2, vd);
const float32x4_t vr3 = vfmaq_f32(vy3, vx3, vd);
const float32x4_t vr4 = vfmaq_f32(vy4, vx4, vd);
const float32x4_t vr5 = vfmaq_f32(vy5, vx5, vd);
const float32x4_t vr6 = vfmaq_f32(vy6, vx6, vd);
const float32x4_t vr7 = vfmaq_f32(vy7, vx7, vd);
vst1q_f32(y + i*32 + 0, vr0);
vst1q_f32(y + i*32 + 4, vr1);
vst1q_f32(y + i*32 + 8, vr2);
vst1q_f32(y + i*32 + 12, vr3);
vst1q_f32(y + i*32 + 16, vr4);
vst1q_f32(y + i*32 + 20, vr5);
vst1q_f32(y + i*32 + 24, vr6);
vst1q_f32(y + i*32 + 28, vr7);
}
//printf("------\n");
//for (int i = 0; i < nb; i++) {
// const float d = pd[i];
// const uint8_t * restrict pp = pb + i*QK/2;
// printf("vxs: ");
// for (int l = 0; l < QK; l += 2) {
// const uint8_t vi = pp[l/2];
// const int8_t vi0 = vi & 0xf;
// const int8_t vi1 = vi >> 4;
// const float v0 = (vi0 - 8)*d;
// const float v1 = (vi1 - 8)*d;
// y[i*QK + l + 0] += v0*v;
// y[i*QK + l + 1] += v1*v;
// if (l < QK) {
// //printf("%6.2f %6.2f ", v0, v1);
// printf("%3d %3d ", vi0 - 8, vi1 - 8);
// }
// }
// printf("\n");
//}
//exit(0);
#endif
#else
// scalar
for (int i = 0; i < nb; i++) {
const float d = pd[i];
@ -1320,6 +1556,7 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
//printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1);
}
}
#endif
}
inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * restrict x, const float v) {

@ -1956,9 +1956,9 @@ void vec_dot_gq_5(const int n, float * restrict s, const void * restrict x, cons
//const float32x4_t pf1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(p)));
// scalar
sum11 += d0*d1*vaddvq_u16(p);
//sum11 += d0*d1*(vaddvq_u16(pl) + vaddvq_u16(ph));
//sum11 += d0*d1*vaddvq_u16(vaddq_s16(pl, ph));
sum11 += d0*d1*vaddvq_s16(p);
//sum11 += d0*d1*(vaddvq_s16(pl) + vaddvq_s16(ph));
//sum11 += d0*d1*vaddvq_s16(vaddq_s16(pl, ph));
//sum11 += d0*d1*(vaddvq_s8(pl0) + vaddvq_s8(pl1) + vaddvq_s8(ph0) + vaddvq_s8(ph1));
//sum11 += d0*d1*(vaddvq_s16(pll) + vaddvq_s16(plh) + vaddvq_s16(phl) + vaddvq_s16(phh));
@ -2246,7 +2246,7 @@ void vec_dot_gq_6(const int n, float * restrict s, const void * restrict x, cons
const int16x8_t p = vaddq_s16(pl, ph);
// scalar
sum0 += d0*d1*vaddvq_u16(p);
sum0 += d0*d1*vaddvq_s16(p);
}
sumf = sum0;
@ -2267,8 +2267,8 @@ void vec_dot_gq_6(const int n, float * restrict s, const void * restrict x, cons
const int8x16_t s8b = vdupq_n_s8(0x8);
const uint8x16_t v0_0 = vld1q_u8(p0);
const uint8x16_t v1_0 = vld1q_u8(p1);
const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
const uint8x16_t v1_0 = vld1q_u8(p1);
const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
// 4-bit -> 8-bit
@ -2320,8 +2320,8 @@ void vec_dot_gq_6(const int n, float * restrict s, const void * restrict x, cons
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
// scalar
sum0 += d0_0*d1_0*vaddvq_u16(p_0);
sum1 += d0_1*d1_1*vaddvq_u16(p_1);
sum0 += d0_0*d1_0*vaddvq_s16(p_0);
sum1 += d0_1*d1_1*vaddvq_s16(p_1);
}
sumf = sum0 + sum1;

Loading…
Cancel
Save