|
|
|
@ -352,7 +352,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
|
|
|
|
|
// quantization
|
|
|
|
|
//
|
|
|
|
|
|
|
|
|
|
#define QK 64
|
|
|
|
|
#define QK 32
|
|
|
|
|
|
|
|
|
|
// method 5
|
|
|
|
|
// blocks of QK elements
|
|
|
|
@ -1094,6 +1094,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
|
|
|
|
|
inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
|
|
|
|
|
const int nb = n / QK;
|
|
|
|
|
|
|
|
|
|
assert(n % QK == 0);
|
|
|
|
|
assert(nb % 2 == 0);
|
|
|
|
|
|
|
|
|
|
const float * restrict pd0 = (const float *) x;
|
|
|
|
|
const float * restrict pd1 = (const float *) y;
|
|
|
|
|
|
|
|
|
@ -1102,6 +1105,120 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
|
|
|
|
|
|
|
|
|
|
float sumf = 0.0;
|
|
|
|
|
|
|
|
|
|
#ifdef __ARM_NEON
|
|
|
|
|
#if QK == 32
|
|
|
|
|
float sum0 = 0.0f;
|
|
|
|
|
float sum1 = 0.0f;
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < nb; i += 2) {
|
|
|
|
|
const float d0_0 = pd0[i + 0];
|
|
|
|
|
const float d1_0 = pd1[i + 0];
|
|
|
|
|
const float d0_1 = pd0[i + 1];
|
|
|
|
|
const float d1_1 = pd1[i + 1];
|
|
|
|
|
|
|
|
|
|
//printf("d0_0: %f, d1_0: %f, d0_1: %f, d1_1: %f\n", d0_0, d1_0, d0_1, d1_1);
|
|
|
|
|
|
|
|
|
|
const uint8_t * restrict p0 = pb0 + i*16;
|
|
|
|
|
const uint8_t * restrict p1 = pb1 + i*16;
|
|
|
|
|
|
|
|
|
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
|
|
|
|
const int8x16_t s8b = vdupq_n_s8(0x8);
|
|
|
|
|
|
|
|
|
|
const uint8x16_t v0_0 = vld1q_u8(p0);
|
|
|
|
|
const uint8x16_t v1_0 = vld1q_u8(p1);
|
|
|
|
|
const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
|
|
|
|
|
const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
|
|
|
|
|
|
|
|
|
|
// 4-bit -> 8-bit
|
|
|
|
|
const int8x16_t v0_0l = vandq_u8(v0_0, m4b);
|
|
|
|
|
const int8x16_t v1_0l = vandq_u8(v1_0, m4b);
|
|
|
|
|
|
|
|
|
|
const int8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
|
|
|
|
|
const int8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
|
|
|
|
|
|
|
|
|
|
const int8x16_t v0_1l = vandq_u8(v0_1, m4b);
|
|
|
|
|
const int8x16_t v1_1l = vandq_u8(v1_1, m4b);
|
|
|
|
|
|
|
|
|
|
const int8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
|
|
|
|
|
const int8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
|
|
|
|
|
|
|
|
|
|
// sub 8
|
|
|
|
|
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
|
|
|
|
|
const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
|
|
|
|
|
|
|
|
|
|
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
|
|
|
|
|
const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
|
|
|
|
|
|
|
|
|
|
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
|
|
|
|
const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
|
|
|
|
|
|
|
|
|
|
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
|
|
|
|
const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
|
|
|
|
|
|
|
|
|
|
// dot product into int16x8_t
|
|
|
|
|
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
|
|
|
|
|
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
|
|
|
|
|
|
|
|
|
|
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
|
|
|
|
|
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
|
|
|
|
|
|
|
|
|
|
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
|
|
|
|
|
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
|
|
|
|
|
|
|
|
|
|
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
|
|
|
|
|
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
|
|
|
|
|
|
|
|
|
|
const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
|
|
|
|
|
const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
|
|
|
|
|
|
|
|
|
|
const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
|
|
|
|
|
const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
|
|
|
|
|
|
|
|
|
|
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
|
|
|
|
|
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
|
|
|
|
|
|
|
|
|
|
//printf("p_0: %d %d %d %d %d %d %d %d\n", vgetq_lane_s16(p_0, 0), vgetq_lane_s16(p_0, 1), vgetq_lane_s16(p_0, 2), vgetq_lane_s16(p_0, 3), vgetq_lane_s16(p_0, 4), vgetq_lane_s16(p_0, 5), vgetq_lane_s16(p_0, 6), vgetq_lane_s16(p_0, 7));
|
|
|
|
|
//printf("p_1: %d %d %d %d %d %d %d %d\n", vgetq_lane_s16(p_1, 0), vgetq_lane_s16(p_1, 1), vgetq_lane_s16(p_1, 2), vgetq_lane_s16(p_1, 3), vgetq_lane_s16(p_1, 4), vgetq_lane_s16(p_1, 5), vgetq_lane_s16(p_1, 6), vgetq_lane_s16(p_1, 7));
|
|
|
|
|
|
|
|
|
|
// scalar
|
|
|
|
|
sum0 += d0_0*d1_0*vaddvq_s16(p_0);
|
|
|
|
|
sum1 += d0_1*d1_1*vaddvq_s16(p_1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sumf = sum0 + sum1;
|
|
|
|
|
|
|
|
|
|
//printf("sumf SIMD = %f\n", sumf);
|
|
|
|
|
|
|
|
|
|
//// scalar
|
|
|
|
|
//sumf = 0.0f;
|
|
|
|
|
//for (int i = 0; i < nb; i++) {
|
|
|
|
|
// const float d0 = pd0[i];
|
|
|
|
|
// const float d1 = pd1[i];
|
|
|
|
|
|
|
|
|
|
// const uint8_t * restrict p0 = pb0 + i*QK/2;
|
|
|
|
|
// const uint8_t * restrict p1 = pb1 + i*QK/2;
|
|
|
|
|
|
|
|
|
|
// for (int j = 0; j < QK/2; j++) {
|
|
|
|
|
// const uint8_t v0 = p0[j];
|
|
|
|
|
// const uint8_t v1 = p1[j];
|
|
|
|
|
|
|
|
|
|
// const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
|
|
|
|
|
// const float f1 = d0*((int8_t) (v0 >> 4) - 8);
|
|
|
|
|
|
|
|
|
|
// const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
|
|
|
|
|
// const float f3 = d1*((int8_t) (v1 >> 4) - 8);
|
|
|
|
|
|
|
|
|
|
// sumf += f0*f2 + f1*f3;
|
|
|
|
|
// }
|
|
|
|
|
//}
|
|
|
|
|
//printf("sumf scalar = %f\n", sumf);
|
|
|
|
|
//printf("--------\n");
|
|
|
|
|
|
|
|
|
|
//exit(0);
|
|
|
|
|
#else
|
|
|
|
|
#error "not implemented"
|
|
|
|
|
#endif
|
|
|
|
|
#else
|
|
|
|
|
// scalar
|
|
|
|
|
for (int i = 0; i < nb; i++) {
|
|
|
|
|
const float d0 = pd0[i];
|
|
|
|
@ -1123,6 +1240,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
|
|
|
|
|
sumf += f0*f2 + f1*f3;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
*s = sumf;
|
|
|
|
|
}
|
|
|
|
@ -1296,6 +1414,124 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
|
|
|
|
|
const float * restrict pd = (const float *) (x);
|
|
|
|
|
const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
|
|
|
|
|
|
|
|
|
|
#if __ARM_NEON
|
|
|
|
|
#if QK == 32
|
|
|
|
|
for (int i = 0; i < nb; ++i) {
|
|
|
|
|
const float d0 = pd[i]*v;
|
|
|
|
|
|
|
|
|
|
const uint8_t * restrict pp = pb + i*16;
|
|
|
|
|
|
|
|
|
|
const float32x4_t vd = vdupq_n_f32(d0);
|
|
|
|
|
|
|
|
|
|
const float32x4_t vy0 = vld1q_f32(y + i*32 + 0);
|
|
|
|
|
const float32x4_t vy1 = vld1q_f32(y + i*32 + 4);
|
|
|
|
|
const float32x4_t vy2 = vld1q_f32(y + i*32 + 8);
|
|
|
|
|
const float32x4_t vy3 = vld1q_f32(y + i*32 + 12);
|
|
|
|
|
const float32x4_t vy4 = vld1q_f32(y + i*32 + 16);
|
|
|
|
|
const float32x4_t vy5 = vld1q_f32(y + i*32 + 20);
|
|
|
|
|
const float32x4_t vy6 = vld1q_f32(y + i*32 + 24);
|
|
|
|
|
const float32x4_t vy7 = vld1q_f32(y + i*32 + 28);
|
|
|
|
|
|
|
|
|
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
|
|
|
|
const int8x16_t s8b = vdupq_n_s8(0x8);
|
|
|
|
|
|
|
|
|
|
const uint8x16_t vx = vld1q_u8(pp);
|
|
|
|
|
|
|
|
|
|
const uint8x16_t vxl = vandq_u8 (vx, m4b);
|
|
|
|
|
const uint8x16_t vxh = vshrq_n_u8(vx, 4);
|
|
|
|
|
|
|
|
|
|
// sub 8
|
|
|
|
|
const int8x16_t vxls = vsubq_s8(vxl, s8b);
|
|
|
|
|
const int8x16_t vxhs = vsubq_s8(vxh, s8b);
|
|
|
|
|
|
|
|
|
|
const int8x16_t vxlt = vzip1q_s8(vxls, vxhs);
|
|
|
|
|
const int8x16_t vxht = vzip2q_s8(vxls, vxhs);
|
|
|
|
|
|
|
|
|
|
//printf("vxv: %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d ",
|
|
|
|
|
// vgetq_lane_s8(vxlt, 0), vgetq_lane_s8(vxlt, 1), vgetq_lane_s8(vxlt, 2), vgetq_lane_s8(vxlt, 3),
|
|
|
|
|
// vgetq_lane_s8(vxlt, 4), vgetq_lane_s8(vxlt, 5), vgetq_lane_s8(vxlt, 6), vgetq_lane_s8(vxlt, 7),
|
|
|
|
|
// vgetq_lane_s8(vxlt, 8), vgetq_lane_s8(vxlt, 9), vgetq_lane_s8(vxlt, 10), vgetq_lane_s8(vxlt, 11),
|
|
|
|
|
// vgetq_lane_s8(vxlt, 12), vgetq_lane_s8(vxlt, 13), vgetq_lane_s8(vxlt, 14), vgetq_lane_s8(vxlt, 15));
|
|
|
|
|
//printf("%3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d\n",
|
|
|
|
|
// vgetq_lane_s8(vxht, 0), vgetq_lane_s8(vxht, 1), vgetq_lane_s8(vxht, 2), vgetq_lane_s8(vxht, 3),
|
|
|
|
|
// vgetq_lane_s8(vxht, 4), vgetq_lane_s8(vxht, 5), vgetq_lane_s8(vxht, 6), vgetq_lane_s8(vxht, 7),
|
|
|
|
|
// vgetq_lane_s8(vxht, 8), vgetq_lane_s8(vxht, 9), vgetq_lane_s8(vxht, 10), vgetq_lane_s8(vxht, 11),
|
|
|
|
|
// vgetq_lane_s8(vxht, 12), vgetq_lane_s8(vxht, 13), vgetq_lane_s8(vxht, 14), vgetq_lane_s8(vxht, 15));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// convert to 4x int16x8_t
|
|
|
|
|
const int16x8_t vxls0 = vmovl_s8(vget_low_s8 (vxlt));
|
|
|
|
|
const int16x8_t vxls1 = vmovl_s8(vget_high_s8(vxlt));
|
|
|
|
|
const int16x8_t vxhs0 = vmovl_s8(vget_low_s8 (vxht));
|
|
|
|
|
const int16x8_t vxhs1 = vmovl_s8(vget_high_s8(vxht));
|
|
|
|
|
|
|
|
|
|
// convert to 8x float32x4_t
|
|
|
|
|
const float32x4_t vx0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxls0)));
|
|
|
|
|
const float32x4_t vx1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxls0)));
|
|
|
|
|
const float32x4_t vx2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxls1)));
|
|
|
|
|
const float32x4_t vx3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxls1)));
|
|
|
|
|
const float32x4_t vx4 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxhs0)));
|
|
|
|
|
const float32x4_t vx5 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxhs0)));
|
|
|
|
|
const float32x4_t vx6 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxhs1)));
|
|
|
|
|
const float32x4_t vx7 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxhs1)));
|
|
|
|
|
|
|
|
|
|
//printf("vxv: %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f\n",
|
|
|
|
|
// vgetq_lane_f32(vx0m, 0), vgetq_lane_f32(vx0m, 1), vgetq_lane_f32(vx0m, 2), vgetq_lane_f32(vx0m, 3),
|
|
|
|
|
// vgetq_lane_f32(vx1m, 0), vgetq_lane_f32(vx1m, 1), vgetq_lane_f32(vx1m, 2), vgetq_lane_f32(vx1m, 3),
|
|
|
|
|
// vgetq_lane_f32(vx2m, 0), vgetq_lane_f32(vx2m, 1), vgetq_lane_f32(vx2m, 2), vgetq_lane_f32(vx2m, 3),
|
|
|
|
|
// vgetq_lane_f32(vx3m, 0), vgetq_lane_f32(vx3m, 1), vgetq_lane_f32(vx3m, 2), vgetq_lane_f32(vx3m, 3));
|
|
|
|
|
|
|
|
|
|
const float32x4_t vr0 = vfmaq_f32(vy0, vx0, vd);
|
|
|
|
|
const float32x4_t vr1 = vfmaq_f32(vy1, vx1, vd);
|
|
|
|
|
const float32x4_t vr2 = vfmaq_f32(vy2, vx2, vd);
|
|
|
|
|
const float32x4_t vr3 = vfmaq_f32(vy3, vx3, vd);
|
|
|
|
|
const float32x4_t vr4 = vfmaq_f32(vy4, vx4, vd);
|
|
|
|
|
const float32x4_t vr5 = vfmaq_f32(vy5, vx5, vd);
|
|
|
|
|
const float32x4_t vr6 = vfmaq_f32(vy6, vx6, vd);
|
|
|
|
|
const float32x4_t vr7 = vfmaq_f32(vy7, vx7, vd);
|
|
|
|
|
|
|
|
|
|
vst1q_f32(y + i*32 + 0, vr0);
|
|
|
|
|
vst1q_f32(y + i*32 + 4, vr1);
|
|
|
|
|
vst1q_f32(y + i*32 + 8, vr2);
|
|
|
|
|
vst1q_f32(y + i*32 + 12, vr3);
|
|
|
|
|
vst1q_f32(y + i*32 + 16, vr4);
|
|
|
|
|
vst1q_f32(y + i*32 + 20, vr5);
|
|
|
|
|
vst1q_f32(y + i*32 + 24, vr6);
|
|
|
|
|
vst1q_f32(y + i*32 + 28, vr7);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
//printf("------\n");
|
|
|
|
|
|
|
|
|
|
//for (int i = 0; i < nb; i++) {
|
|
|
|
|
// const float d = pd[i];
|
|
|
|
|
|
|
|
|
|
// const uint8_t * restrict pp = pb + i*QK/2;
|
|
|
|
|
|
|
|
|
|
// printf("vxs: ");
|
|
|
|
|
// for (int l = 0; l < QK; l += 2) {
|
|
|
|
|
// const uint8_t vi = pp[l/2];
|
|
|
|
|
|
|
|
|
|
// const int8_t vi0 = vi & 0xf;
|
|
|
|
|
// const int8_t vi1 = vi >> 4;
|
|
|
|
|
|
|
|
|
|
// const float v0 = (vi0 - 8)*d;
|
|
|
|
|
// const float v1 = (vi1 - 8)*d;
|
|
|
|
|
|
|
|
|
|
// y[i*QK + l + 0] += v0*v;
|
|
|
|
|
// y[i*QK + l + 1] += v1*v;
|
|
|
|
|
|
|
|
|
|
// if (l < QK) {
|
|
|
|
|
// //printf("%6.2f %6.2f ", v0, v1);
|
|
|
|
|
// printf("%3d %3d ", vi0 - 8, vi1 - 8);
|
|
|
|
|
// }
|
|
|
|
|
// }
|
|
|
|
|
// printf("\n");
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
//exit(0);
|
|
|
|
|
#endif
|
|
|
|
|
#else
|
|
|
|
|
// scalar
|
|
|
|
|
for (int i = 0; i < nb; i++) {
|
|
|
|
|
const float d = pd[i];
|
|
|
|
|
|
|
|
|
@ -1320,6 +1556,7 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
|
|
|
|
|
//printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * restrict x, const float v) {
|
|
|
|
|