ggml : simplify mad q4_0 (ARM)

gq
Georgi Gerganov 1 year ago
parent 6309a60bac
commit e89cb32625
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -1256,7 +1256,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
//exit(0);
#else
#error "not implemented"
#error "not implemented for QK"
#endif
#else
// scalar
@ -1461,114 +1461,52 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
const uint8_t * restrict pp = pb + i*16;
const float32x4_t vd = vdupq_n_f32(d0);
const float32x4_t vy0 = vld1q_f32(y + i*32 + 0);
const float32x4_t vy1 = vld1q_f32(y + i*32 + 4);
const float32x4_t vy2 = vld1q_f32(y + i*32 + 8);
const float32x4_t vy3 = vld1q_f32(y + i*32 + 12);
const float32x4_t vy4 = vld1q_f32(y + i*32 + 16);
const float32x4_t vy5 = vld1q_f32(y + i*32 + 20);
const float32x4_t vy6 = vld1q_f32(y + i*32 + 24);
const float32x4_t vy7 = vld1q_f32(y + i*32 + 28);
const uint8x8_t m4b = vdup_n_u8(0xf);
const int8x8_t s8b = vdup_n_s8(0x8);
const uint8x16_t m4b = vdupq_n_u8(0xf);
const int8x16_t s8b = vdupq_n_s8(0x8);
const float32x4_t vd = vdupq_n_f32(d0);
const uint8x16_t vx = vld1q_u8(pp);
for (int j = 0; j < 2; j++) {
const uint8x8_t vx = vld1_u8(pp + j*8);
const uint8x16_t vxl = vandq_u8 (vx, m4b);
const uint8x16_t vxh = vshrq_n_u8(vx, 4);
const uint8x8_t vxl = vand_u8 (vx, m4b);
const uint8x8_t vxh = vshr_n_u8(vx, 4);
// sub 8
const int8x16_t vxls = vsubq_s8(vxl, s8b);
const int8x16_t vxhs = vsubq_s8(vxh, s8b);
const int8x16_t vxlt = vzip1q_s8(vxls, vxhs);
const int8x16_t vxht = vzip2q_s8(vxls, vxhs);
//printf("vxv: %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d ",
// vgetq_lane_s8(vxlt, 0), vgetq_lane_s8(vxlt, 1), vgetq_lane_s8(vxlt, 2), vgetq_lane_s8(vxlt, 3),
// vgetq_lane_s8(vxlt, 4), vgetq_lane_s8(vxlt, 5), vgetq_lane_s8(vxlt, 6), vgetq_lane_s8(vxlt, 7),
// vgetq_lane_s8(vxlt, 8), vgetq_lane_s8(vxlt, 9), vgetq_lane_s8(vxlt, 10), vgetq_lane_s8(vxlt, 11),
// vgetq_lane_s8(vxlt, 12), vgetq_lane_s8(vxlt, 13), vgetq_lane_s8(vxlt, 14), vgetq_lane_s8(vxlt, 15));
//printf("%3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d %3d\n",
// vgetq_lane_s8(vxht, 0), vgetq_lane_s8(vxht, 1), vgetq_lane_s8(vxht, 2), vgetq_lane_s8(vxht, 3),
// vgetq_lane_s8(vxht, 4), vgetq_lane_s8(vxht, 5), vgetq_lane_s8(vxht, 6), vgetq_lane_s8(vxht, 7),
// vgetq_lane_s8(vxht, 8), vgetq_lane_s8(vxht, 9), vgetq_lane_s8(vxht, 10), vgetq_lane_s8(vxht, 11),
// vgetq_lane_s8(vxht, 12), vgetq_lane_s8(vxht, 13), vgetq_lane_s8(vxht, 14), vgetq_lane_s8(vxht, 15));
// convert to 4x int16x8_t
const int16x8_t vxls0 = vmovl_s8(vget_low_s8 (vxlt));
const int16x8_t vxls1 = vmovl_s8(vget_high_s8(vxlt));
const int16x8_t vxhs0 = vmovl_s8(vget_low_s8 (vxht));
const int16x8_t vxhs1 = vmovl_s8(vget_high_s8(vxht));
// convert to 8x float32x4_t
const float32x4_t vx0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxls0)));
const float32x4_t vx1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxls0)));
const float32x4_t vx2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxls1)));
const float32x4_t vx3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxls1)));
const float32x4_t vx4 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxhs0)));
const float32x4_t vx5 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxhs0)));
const float32x4_t vx6 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxhs1)));
const float32x4_t vx7 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxhs1)));
//printf("vxv: %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f %6.2f\n",
// vgetq_lane_f32(vx0m, 0), vgetq_lane_f32(vx0m, 1), vgetq_lane_f32(vx0m, 2), vgetq_lane_f32(vx0m, 3),
// vgetq_lane_f32(vx1m, 0), vgetq_lane_f32(vx1m, 1), vgetq_lane_f32(vx1m, 2), vgetq_lane_f32(vx1m, 3),
// vgetq_lane_f32(vx2m, 0), vgetq_lane_f32(vx2m, 1), vgetq_lane_f32(vx2m, 2), vgetq_lane_f32(vx2m, 3),
// vgetq_lane_f32(vx3m, 0), vgetq_lane_f32(vx3m, 1), vgetq_lane_f32(vx3m, 2), vgetq_lane_f32(vx3m, 3));
const float32x4_t vr0 = vfmaq_f32(vy0, vx0, vd);
const float32x4_t vr1 = vfmaq_f32(vy1, vx1, vd);
const float32x4_t vr2 = vfmaq_f32(vy2, vx2, vd);
const float32x4_t vr3 = vfmaq_f32(vy3, vx3, vd);
const float32x4_t vr4 = vfmaq_f32(vy4, vx4, vd);
const float32x4_t vr5 = vfmaq_f32(vy5, vx5, vd);
const float32x4_t vr6 = vfmaq_f32(vy6, vx6, vd);
const float32x4_t vr7 = vfmaq_f32(vy7, vx7, vd);
vst1q_f32(y + i*32 + 0, vr0);
vst1q_f32(y + i*32 + 4, vr1);
vst1q_f32(y + i*32 + 8, vr2);
vst1q_f32(y + i*32 + 12, vr3);
vst1q_f32(y + i*32 + 16, vr4);
vst1q_f32(y + i*32 + 20, vr5);
vst1q_f32(y + i*32 + 24, vr6);
vst1q_f32(y + i*32 + 28, vr7);
}
//printf("------\n");
// sub 8
const int8x8_t vxls = vsub_s8(vxl, s8b);
const int8x8_t vxhs = vsub_s8(vxh, s8b);
//for (int i = 0; i < nb; i++) {
// const float d = pd[i];
const int8x8_t vxlt = vzip1_s8(vxls, vxhs);
const int8x8_t vxht = vzip2_s8(vxls, vxhs);
// const uint8_t * restrict pp = pb + i*QK/2;
const int8x16_t vxq = vcombine_s8(vxlt, vxht);
// printf("vxs: ");
// for (int l = 0; l < QK; l += 2) {
// const uint8_t vi = pp[l/2];
// convert to 2x int16x8_t
const int16x8_t vxq0 = vmovl_s8(vget_low_s8 (vxq));
const int16x8_t vxq1 = vmovl_s8(vget_high_s8(vxq));
// const int8_t vi0 = vi & 0xf;
// const int8_t vi1 = vi >> 4;
// convert to 4x float32x4_t
const float32x4_t vx0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq0)));
const float32x4_t vx1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq0)));
const float32x4_t vx2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq1)));
const float32x4_t vx3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq1)));
// const float v0 = (vi0 - 8)*d;
// const float v1 = (vi1 - 8)*d;
const float32x4_t vy0 = vld1q_f32(y + i*32 + j*16 + 0);
const float32x4_t vy1 = vld1q_f32(y + i*32 + j*16 + 4);
const float32x4_t vy2 = vld1q_f32(y + i*32 + j*16 + 8);
const float32x4_t vy3 = vld1q_f32(y + i*32 + j*16 + 12);
// y[i*QK + l + 0] += v0*v;
// y[i*QK + l + 1] += v1*v;
const float32x4_t vr0 = vfmaq_f32(vy0, vx0, vd);
const float32x4_t vr1 = vfmaq_f32(vy1, vx1, vd);
const float32x4_t vr2 = vfmaq_f32(vy2, vx2, vd);
const float32x4_t vr3 = vfmaq_f32(vy3, vx3, vd);
// if (l < QK) {
// //printf("%6.2f %6.2f ", v0, v1);
// printf("%3d %3d ", vi0 - 8, vi1 - 8);
// }
// }
// printf("\n");
//}
//exit(0);
vst1q_f32(y + i*32 + j*16 + 0, vr0);
vst1q_f32(y + i*32 + j*16 + 4, vr1);
vst1q_f32(y + i*32 + j*16 + 8, vr2);
vst1q_f32(y + i*32 + j*16 + 12, vr3);
}
}
#endif
#else
// scalar
@ -1593,7 +1531,6 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
assert(!isnan(y[i*QK + l + 1]));
assert(!isinf(y[i*QK + l + 0]));
assert(!isinf(y[i*QK + l + 1]));
//printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1);
}
}
#endif

Loading…
Cancel
Save