diff --git a/src/ggml.c b/src/ggml.c index 2b59648..18f143a 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -1256,7 +1256,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void //exit(0); #else -#error "not implemented" +#error "not implemented for QK" #endif #else // scalar @@ -1461,114 +1461,52 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res const uint8_t * restrict pp = pb + i*16; - const float32x4_t vd = vdupq_n_f32(d0); - - const float32x4_t vy0 = vld1q_f32(y + i*32 + 0); - const float32x4_t vy1 = vld1q_f32(y + i*32 + 4); - const float32x4_t vy2 = vld1q_f32(y + i*32 + 8); - const float32x4_t vy3 = vld1q_f32(y + i*32 + 12); - const float32x4_t vy4 = vld1q_f32(y + i*32 + 16); - const float32x4_t vy5 = vld1q_f32(y + i*32 + 20); - const float32x4_t vy6 = vld1q_f32(y + i*32 + 24); - const float32x4_t vy7 = vld1q_f32(y + i*32 + 28); + const uint8x8_t m4b = vdup_n_u8(0xf); + const int8x8_t s8b = vdup_n_s8(0x8); - const uint8x16_t m4b = vdupq_n_u8(0xf); - const int8x16_t s8b = vdupq_n_s8(0x8); + const float32x4_t vd = vdupq_n_f32(d0); - const uint8x16_t vx = vld1q_u8(pp); + for (int j = 0; j < 2; j++) { + const uint8x8_t vx = vld1_u8(pp + j*8); - const uint8x16_t vxl = vandq_u8 (vx, m4b); - const uint8x16_t vxh = vshrq_n_u8(vx, 4); + const uint8x8_t vxl = vand_u8 (vx, m4b); + const uint8x8_t vxh = vshr_n_u8(vx, 4); - // sub 8 - const int8x16_t vxls = vsubq_s8(vxl, s8b); - const int8x16_t vxhs = vsubq_s8(vxh, s8b); - - const int8x16_t vxlt = vzip1q_s8(vxls, vxhs); - const int8x16_t vxht = vzip2q_s8(vxls, vxhs); - - //printf("vxv: %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d ", - // vgetq_lane_s8(vxlt, 0), vgetq_lane_s8(vxlt, 1), vgetq_lane_s8(vxlt, 2), vgetq_lane_s8(vxlt, 3), - // vgetq_lane_s8(vxlt, 4), vgetq_lane_s8(vxlt, 5), vgetq_lane_s8(vxlt, 6), vgetq_lane_s8(vxlt, 7), - // vgetq_lane_s8(vxlt, 8), vgetq_lane_s8(vxlt, 9), vgetq_lane_s8(vxlt, 10), vgetq_lane_s8(vxlt, 11), - // vgetq_lane_s8(vxlt, 12), vgetq_lane_s8(vxlt, 13), vgetq_lane_s8(vxlt, 14), vgetq_lane_s8(vxlt, 15)); - //printf("%3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d\n", - // vgetq_lane_s8(vxht, 0), vgetq_lane_s8(vxht, 1), vgetq_lane_s8(vxht, 2), vgetq_lane_s8(vxht, 3), - // vgetq_lane_s8(vxht, 4), vgetq_lane_s8(vxht, 5), vgetq_lane_s8(vxht, 6), vgetq_lane_s8(vxht, 7), - // vgetq_lane_s8(vxht, 8), vgetq_lane_s8(vxht, 9), vgetq_lane_s8(vxht, 10), vgetq_lane_s8(vxht, 11), - // vgetq_lane_s8(vxht, 12), vgetq_lane_s8(vxht, 13), vgetq_lane_s8(vxht, 14), vgetq_lane_s8(vxht, 15)); - - - // convert to 4x int16x8_t - const int16x8_t vxls0 = vmovl_s8(vget_low_s8 (vxlt)); - const int16x8_t vxls1 = vmovl_s8(vget_high_s8(vxlt)); - const int16x8_t vxhs0 = vmovl_s8(vget_low_s8 (vxht)); - const int16x8_t vxhs1 = vmovl_s8(vget_high_s8(vxht)); - - // convert to 8x float32x4_t - const float32x4_t vx0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxls0))); - const float32x4_t vx1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxls0))); - const float32x4_t vx2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxls1))); - const float32x4_t vx3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxls1))); - const float32x4_t vx4 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxhs0))); - const float32x4_t vx5 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxhs0))); - const float32x4_t vx6 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxhs1))); - const float32x4_t vx7 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxhs1))); - - //printf("vxv: %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f\n", - // vgetq_lane_f32(vx0m, 0), vgetq_lane_f32(vx0m, 1), vgetq_lane_f32(vx0m, 2), vgetq_lane_f32(vx0m, 3), - // vgetq_lane_f32(vx1m, 0), vgetq_lane_f32(vx1m, 1), vgetq_lane_f32(vx1m, 2), vgetq_lane_f32(vx1m, 3), - // vgetq_lane_f32(vx2m, 0), vgetq_lane_f32(vx2m, 1), vgetq_lane_f32(vx2m, 2), vgetq_lane_f32(vx2m, 3), - // vgetq_lane_f32(vx3m, 0), vgetq_lane_f32(vx3m, 1), vgetq_lane_f32(vx3m, 2), vgetq_lane_f32(vx3m, 3)); - - const float32x4_t vr0 = vfmaq_f32(vy0, vx0, vd); - const float32x4_t vr1 = vfmaq_f32(vy1, vx1, vd); - const float32x4_t vr2 = vfmaq_f32(vy2, vx2, vd); - const float32x4_t vr3 = vfmaq_f32(vy3, vx3, vd); - const float32x4_t vr4 = vfmaq_f32(vy4, vx4, vd); - const float32x4_t vr5 = vfmaq_f32(vy5, vx5, vd); - const float32x4_t vr6 = vfmaq_f32(vy6, vx6, vd); - const float32x4_t vr7 = vfmaq_f32(vy7, vx7, vd); - - vst1q_f32(y + i*32 + 0, vr0); - vst1q_f32(y + i*32 + 4, vr1); - vst1q_f32(y + i*32 + 8, vr2); - vst1q_f32(y + i*32 + 12, vr3); - vst1q_f32(y + i*32 + 16, vr4); - vst1q_f32(y + i*32 + 20, vr5); - vst1q_f32(y + i*32 + 24, vr6); - vst1q_f32(y + i*32 + 28, vr7); - } - - //printf("------\n"); + // sub 8 + const int8x8_t vxls = vsub_s8(vxl, s8b); + const int8x8_t vxhs = vsub_s8(vxh, s8b); - //for (int i = 0; i < nb; i++) { - // const float d = pd[i]; + const int8x8_t vxlt = vzip1_s8(vxls, vxhs); + const int8x8_t vxht = vzip2_s8(vxls, vxhs); - // const uint8_t * restrict pp = pb + i*QK/2; + const int8x16_t vxq = vcombine_s8(vxlt, vxht); - // printf("vxs: "); - // for (int l = 0; l < QK; l += 2) { - // const uint8_t vi = pp[l/2]; + // convert to 2x int16x8_t + const int16x8_t vxq0 = vmovl_s8(vget_low_s8 (vxq)); + const int16x8_t vxq1 = vmovl_s8(vget_high_s8(vxq)); - // const int8_t vi0 = vi & 0xf; - // const int8_t vi1 = vi >> 4; + // convert to 4x float32x4_t + const float32x4_t vx0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq0))); + const float32x4_t vx1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq0))); + const float32x4_t vx2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq1))); + const float32x4_t vx3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq1))); - // const float v0 = (vi0 - 8)*d; - // const float v1 = (vi1 - 8)*d; + const float32x4_t vy0 = vld1q_f32(y + i*32 + j*16 + 0); + const float32x4_t vy1 = vld1q_f32(y + i*32 + j*16 + 4); + const float32x4_t vy2 = vld1q_f32(y + i*32 + j*16 + 8); + const float32x4_t vy3 = vld1q_f32(y + i*32 + j*16 + 12); - // y[i*QK + l + 0] += v0*v; - // y[i*QK + l + 1] += v1*v; + const float32x4_t vr0 = vfmaq_f32(vy0, vx0, vd); + const float32x4_t vr1 = vfmaq_f32(vy1, vx1, vd); + const float32x4_t vr2 = vfmaq_f32(vy2, vx2, vd); + const float32x4_t vr3 = vfmaq_f32(vy3, vx3, vd); - // if (l < QK) { - // //printf("%6.2f %6.2f ", v0, v1); - // printf("%3d %3d ", vi0 - 8, vi1 - 8); - // } - // } - // printf("\n"); - //} - - //exit(0); + vst1q_f32(y + i*32 + j*16 + 0, vr0); + vst1q_f32(y + i*32 + j*16 + 4, vr1); + vst1q_f32(y + i*32 + j*16 + 8, vr2); + vst1q_f32(y + i*32 + j*16 + 12, vr3); + } + } #endif #else // scalar @@ -1593,7 +1531,6 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res assert(!isnan(y[i*QK + l + 1])); assert(!isinf(y[i*QK + l + 0])); assert(!isinf(y[i*QK + l + 1])); - //printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1); } } #endif