|
|
|
@ -1256,7 +1256,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
|
|
|
|
|
|
|
|
|
|
//exit(0);
|
|
|
|
|
#else
|
|
|
|
|
#error "not implemented"
|
|
|
|
|
#error "not implemented for QK"
|
|
|
|
|
#endif
|
|
|
|
|
#else
|
|
|
|
|
// scalar
|
|
|
|
@ -1461,114 +1461,52 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
|
|
|
|
|
|
|
|
|
|
const uint8_t * restrict pp = pb + i*16;
|
|
|
|
|
|
|
|
|
|
const float32x4_t vd = vdupq_n_f32(d0);
|
|
|
|
|
|
|
|
|
|
const float32x4_t vy0 = vld1q_f32(y + i*32 + 0);
|
|
|
|
|
const float32x4_t vy1 = vld1q_f32(y + i*32 + 4);
|
|
|
|
|
const float32x4_t vy2 = vld1q_f32(y + i*32 + 8);
|
|
|
|
|
const float32x4_t vy3 = vld1q_f32(y + i*32 + 12);
|
|
|
|
|
const float32x4_t vy4 = vld1q_f32(y + i*32 + 16);
|
|
|
|
|
const float32x4_t vy5 = vld1q_f32(y + i*32 + 20);
|
|
|
|
|
const float32x4_t vy6 = vld1q_f32(y + i*32 + 24);
|
|
|
|
|
const float32x4_t vy7 = vld1q_f32(y + i*32 + 28);
|
|
|
|
|
const uint8x8_t m4b = vdup_n_u8(0xf);
|
|
|
|
|
const int8x8_t s8b = vdup_n_s8(0x8);
|
|
|
|
|
|
|
|
|
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
|
|
|
|
const int8x16_t s8b = vdupq_n_s8(0x8);
|
|
|
|
|
const float32x4_t vd = vdupq_n_f32(d0);
|
|
|
|
|
|
|
|
|
|
const uint8x16_t vx = vld1q_u8(pp);
|
|
|
|
|
for (int j = 0; j < 2; j++) {
|
|
|
|
|
const uint8x8_t vx = vld1_u8(pp + j*8);
|
|
|
|
|
|
|
|
|
|
const uint8x16_t vxl = vandq_u8 (vx, m4b);
|
|
|
|
|
const uint8x16_t vxh = vshrq_n_u8(vx, 4);
|
|
|
|
|
const uint8x8_t vxl = vand_u8 (vx, m4b);
|
|
|
|
|
const uint8x8_t vxh = vshr_n_u8(vx, 4);
|
|
|
|
|
|
|
|
|
|
// sub 8
|
|
|
|
|
const int8x16_t vxls = vsubq_s8(vxl, s8b);
|
|
|
|
|
const int8x16_t vxhs = vsubq_s8(vxh, s8b);
|
|
|
|
|
|
|
|
|
|
const int8x16_t vxlt = vzip1q_s8(vxls, vxhs);
|
|
|
|
|
const int8x16_t vxht = vzip2q_s8(vxls, vxhs);
|
|
|
|
|
|
|
|
|
|
//printf("vxv: %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d ",
|
|
|
|
|
// vgetq_lane_s8(vxlt, 0), vgetq_lane_s8(vxlt, 1), vgetq_lane_s8(vxlt, 2), vgetq_lane_s8(vxlt, 3),
|
|
|
|
|
// vgetq_lane_s8(vxlt, 4), vgetq_lane_s8(vxlt, 5), vgetq_lane_s8(vxlt, 6), vgetq_lane_s8(vxlt, 7),
|
|
|
|
|
// vgetq_lane_s8(vxlt, 8), vgetq_lane_s8(vxlt, 9), vgetq_lane_s8(vxlt, 10), vgetq_lane_s8(vxlt, 11),
|
|
|
|
|
// vgetq_lane_s8(vxlt, 12), vgetq_lane_s8(vxlt, 13), vgetq_lane_s8(vxlt, 14), vgetq_lane_s8(vxlt, 15));
|
|
|
|
|
//printf("%3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d\n",
|
|
|
|
|
// vgetq_lane_s8(vxht, 0), vgetq_lane_s8(vxht, 1), vgetq_lane_s8(vxht, 2), vgetq_lane_s8(vxht, 3),
|
|
|
|
|
// vgetq_lane_s8(vxht, 4), vgetq_lane_s8(vxht, 5), vgetq_lane_s8(vxht, 6), vgetq_lane_s8(vxht, 7),
|
|
|
|
|
// vgetq_lane_s8(vxht, 8), vgetq_lane_s8(vxht, 9), vgetq_lane_s8(vxht, 10), vgetq_lane_s8(vxht, 11),
|
|
|
|
|
// vgetq_lane_s8(vxht, 12), vgetq_lane_s8(vxht, 13), vgetq_lane_s8(vxht, 14), vgetq_lane_s8(vxht, 15));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// convert to 4x int16x8_t
|
|
|
|
|
const int16x8_t vxls0 = vmovl_s8(vget_low_s8 (vxlt));
|
|
|
|
|
const int16x8_t vxls1 = vmovl_s8(vget_high_s8(vxlt));
|
|
|
|
|
const int16x8_t vxhs0 = vmovl_s8(vget_low_s8 (vxht));
|
|
|
|
|
const int16x8_t vxhs1 = vmovl_s8(vget_high_s8(vxht));
|
|
|
|
|
|
|
|
|
|
// convert to 8x float32x4_t
|
|
|
|
|
const float32x4_t vx0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxls0)));
|
|
|
|
|
const float32x4_t vx1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxls0)));
|
|
|
|
|
const float32x4_t vx2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxls1)));
|
|
|
|
|
const float32x4_t vx3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxls1)));
|
|
|
|
|
const float32x4_t vx4 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxhs0)));
|
|
|
|
|
const float32x4_t vx5 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxhs0)));
|
|
|
|
|
const float32x4_t vx6 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxhs1)));
|
|
|
|
|
const float32x4_t vx7 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxhs1)));
|
|
|
|
|
|
|
|
|
|
//printf("vxv: %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f\n",
|
|
|
|
|
// vgetq_lane_f32(vx0m, 0), vgetq_lane_f32(vx0m, 1), vgetq_lane_f32(vx0m, 2), vgetq_lane_f32(vx0m, 3),
|
|
|
|
|
// vgetq_lane_f32(vx1m, 0), vgetq_lane_f32(vx1m, 1), vgetq_lane_f32(vx1m, 2), vgetq_lane_f32(vx1m, 3),
|
|
|
|
|
// vgetq_lane_f32(vx2m, 0), vgetq_lane_f32(vx2m, 1), vgetq_lane_f32(vx2m, 2), vgetq_lane_f32(vx2m, 3),
|
|
|
|
|
// vgetq_lane_f32(vx3m, 0), vgetq_lane_f32(vx3m, 1), vgetq_lane_f32(vx3m, 2), vgetq_lane_f32(vx3m, 3));
|
|
|
|
|
|
|
|
|
|
const float32x4_t vr0 = vfmaq_f32(vy0, vx0, vd);
|
|
|
|
|
const float32x4_t vr1 = vfmaq_f32(vy1, vx1, vd);
|
|
|
|
|
const float32x4_t vr2 = vfmaq_f32(vy2, vx2, vd);
|
|
|
|
|
const float32x4_t vr3 = vfmaq_f32(vy3, vx3, vd);
|
|
|
|
|
const float32x4_t vr4 = vfmaq_f32(vy4, vx4, vd);
|
|
|
|
|
const float32x4_t vr5 = vfmaq_f32(vy5, vx5, vd);
|
|
|
|
|
const float32x4_t vr6 = vfmaq_f32(vy6, vx6, vd);
|
|
|
|
|
const float32x4_t vr7 = vfmaq_f32(vy7, vx7, vd);
|
|
|
|
|
|
|
|
|
|
vst1q_f32(y + i*32 + 0, vr0);
|
|
|
|
|
vst1q_f32(y + i*32 + 4, vr1);
|
|
|
|
|
vst1q_f32(y + i*32 + 8, vr2);
|
|
|
|
|
vst1q_f32(y + i*32 + 12, vr3);
|
|
|
|
|
vst1q_f32(y + i*32 + 16, vr4);
|
|
|
|
|
vst1q_f32(y + i*32 + 20, vr5);
|
|
|
|
|
vst1q_f32(y + i*32 + 24, vr6);
|
|
|
|
|
vst1q_f32(y + i*32 + 28, vr7);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
//printf("------\n");
|
|
|
|
|
// sub 8
|
|
|
|
|
const int8x8_t vxls = vsub_s8(vxl, s8b);
|
|
|
|
|
const int8x8_t vxhs = vsub_s8(vxh, s8b);
|
|
|
|
|
|
|
|
|
|
//for (int i = 0; i < nb; i++) {
|
|
|
|
|
// const float d = pd[i];
|
|
|
|
|
const int8x8_t vxlt = vzip1_s8(vxls, vxhs);
|
|
|
|
|
const int8x8_t vxht = vzip2_s8(vxls, vxhs);
|
|
|
|
|
|
|
|
|
|
// const uint8_t * restrict pp = pb + i*QK/2;
|
|
|
|
|
const int8x16_t vxq = vcombine_s8(vxlt, vxht);
|
|
|
|
|
|
|
|
|
|
// printf("vxs: ");
|
|
|
|
|
// for (int l = 0; l < QK; l += 2) {
|
|
|
|
|
// const uint8_t vi = pp[l/2];
|
|
|
|
|
// convert to 2x int16x8_t
|
|
|
|
|
const int16x8_t vxq0 = vmovl_s8(vget_low_s8 (vxq));
|
|
|
|
|
const int16x8_t vxq1 = vmovl_s8(vget_high_s8(vxq));
|
|
|
|
|
|
|
|
|
|
// const int8_t vi0 = vi & 0xf;
|
|
|
|
|
// const int8_t vi1 = vi >> 4;
|
|
|
|
|
// convert to 4x float32x4_t
|
|
|
|
|
const float32x4_t vx0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq0)));
|
|
|
|
|
const float32x4_t vx1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq0)));
|
|
|
|
|
const float32x4_t vx2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq1)));
|
|
|
|
|
const float32x4_t vx3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq1)));
|
|
|
|
|
|
|
|
|
|
// const float v0 = (vi0 - 8)*d;
|
|
|
|
|
// const float v1 = (vi1 - 8)*d;
|
|
|
|
|
const float32x4_t vy0 = vld1q_f32(y + i*32 + j*16 + 0);
|
|
|
|
|
const float32x4_t vy1 = vld1q_f32(y + i*32 + j*16 + 4);
|
|
|
|
|
const float32x4_t vy2 = vld1q_f32(y + i*32 + j*16 + 8);
|
|
|
|
|
const float32x4_t vy3 = vld1q_f32(y + i*32 + j*16 + 12);
|
|
|
|
|
|
|
|
|
|
// y[i*QK + l + 0] += v0*v;
|
|
|
|
|
// y[i*QK + l + 1] += v1*v;
|
|
|
|
|
const float32x4_t vr0 = vfmaq_f32(vy0, vx0, vd);
|
|
|
|
|
const float32x4_t vr1 = vfmaq_f32(vy1, vx1, vd);
|
|
|
|
|
const float32x4_t vr2 = vfmaq_f32(vy2, vx2, vd);
|
|
|
|
|
const float32x4_t vr3 = vfmaq_f32(vy3, vx3, vd);
|
|
|
|
|
|
|
|
|
|
// if (l < QK) {
|
|
|
|
|
// //printf("%6.2f %6.2f ", v0, v1);
|
|
|
|
|
// printf("%3d %3d ", vi0 - 8, vi1 - 8);
|
|
|
|
|
// }
|
|
|
|
|
// }
|
|
|
|
|
// printf("\n");
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
//exit(0);
|
|
|
|
|
vst1q_f32(y + i*32 + j*16 + 0, vr0);
|
|
|
|
|
vst1q_f32(y + i*32 + j*16 + 4, vr1);
|
|
|
|
|
vst1q_f32(y + i*32 + j*16 + 8, vr2);
|
|
|
|
|
vst1q_f32(y + i*32 + j*16 + 12, vr3);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
#else
|
|
|
|
|
// scalar
|
|
|
|
@ -1593,7 +1531,6 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
|
|
|
|
|
assert(!isnan(y[i*QK + l + 1]));
|
|
|
|
|
assert(!isinf(y[i*QK + l + 0]));
|
|
|
|
|
assert(!isinf(y[i*QK + l + 1]));
|
|
|
|
|
//printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|