|
|
|
@ -27,7 +27,7 @@ const int N = 1536;
|
|
|
|
|
const int K = 1280;
|
|
|
|
|
|
|
|
|
|
const int QK = 64;
|
|
|
|
|
#define QB 7
|
|
|
|
|
#define QB 4
|
|
|
|
|
|
|
|
|
|
//#define GGML_GQ_USE_FP16_SCALE
|
|
|
|
|
|
|
|
|
@ -44,6 +44,10 @@ const int QK = 64;
|
|
|
|
|
#define gq_quant_t uint64_t
|
|
|
|
|
#define gq_t_bits 64
|
|
|
|
|
|
|
|
|
|
float frand() {
|
|
|
|
|
return (float) rand() / (float) RAND_MAX;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint64_t get_time_us() {
|
|
|
|
|
struct timeval tv;
|
|
|
|
|
gettimeofday(&tv, NULL);
|
|
|
|
@ -244,15 +248,41 @@ void quantize_2_row(const float * restrict src, void * restrict dst, int k) {
|
|
|
|
|
|
|
|
|
|
gq_quant_t pp[QB];
|
|
|
|
|
|
|
|
|
|
static const int32_t sh[32] = {
|
|
|
|
|
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
|
|
|
|
|
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < nb; i++) {
|
|
|
|
|
float min = FLT_MAX;
|
|
|
|
|
float max = -FLT_MAX;
|
|
|
|
|
|
|
|
|
|
#ifdef __ARM_NEON
|
|
|
|
|
{
|
|
|
|
|
float32x4_t minv = vdupq_n_f32(FLT_MAX);
|
|
|
|
|
float32x4_t maxv = vdupq_n_f32(-FLT_MAX);
|
|
|
|
|
|
|
|
|
|
for (int l = 0; l < QK; l += 4) {
|
|
|
|
|
float32x4_t v = vld1q_f32(src + i*QK + l);
|
|
|
|
|
minv = vminq_f32(minv, v);
|
|
|
|
|
maxv = vmaxq_f32(maxv, v);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
float32x2_t minv32 = vpmin_f32(vget_low_f32(minv), vget_high_f32(minv));
|
|
|
|
|
float32x2_t maxv32 = vpmax_f32(vget_low_f32(maxv), vget_high_f32(maxv));
|
|
|
|
|
|
|
|
|
|
min = MIN(vget_lane_f32(minv32, 0), vget_lane_f32(minv32, 1));
|
|
|
|
|
max = MAX(vget_lane_f32(maxv32, 0), vget_lane_f32(maxv32, 1));
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
{
|
|
|
|
|
for (int l = 0; l < QK; l++) {
|
|
|
|
|
const float v = src[i*QK + l];
|
|
|
|
|
if (v < min) min = v;
|
|
|
|
|
if (v > max) max = v;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
const float d = (max - min) / ((1 << QB) - 1);
|
|
|
|
|
const float id = d ? 1.0/d : 0.0;
|
|
|
|
@ -263,19 +293,142 @@ void quantize_2_row(const float * restrict src, void * restrict dst, int k) {
|
|
|
|
|
for (int s = 0; s < nq; ++s) {
|
|
|
|
|
memset(pp, 0, sizeof(pp));
|
|
|
|
|
|
|
|
|
|
#if 0
|
|
|
|
|
for (int l = 0; l < gq_t_bits; l++) {
|
|
|
|
|
const float v = src[i*QK + s*gq_t_bits + l];
|
|
|
|
|
const uint8_t q = (v - min)*id;
|
|
|
|
|
const uint8_t q = (v - min)*id + frand();
|
|
|
|
|
|
|
|
|
|
for (int b = 0; b < QB; b++) {
|
|
|
|
|
pp[b] |= q & (1 << b) ? (1ULL << l) : 0;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#elif defined(__ARM_NEON)
|
|
|
|
|
#if 1
|
|
|
|
|
{
|
|
|
|
|
uint32_t ppt[2*4*QB];
|
|
|
|
|
|
|
|
|
|
float32x4_t minv = vdupq_n_f32(min);
|
|
|
|
|
float32x4_t idv = vdupq_n_f32(id);
|
|
|
|
|
|
|
|
|
|
assert(gq_t_bits == 64);
|
|
|
|
|
|
|
|
|
|
uint32x4_t p0[QB] = { vdupq_n_u32(0) };
|
|
|
|
|
uint32x4_t p1[QB] = { vdupq_n_u32(0) };
|
|
|
|
|
|
|
|
|
|
for (int l = 0; l < gq_t_bits; l += 16) {
|
|
|
|
|
float32x4_t v0 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 0);
|
|
|
|
|
float32x4_t v1 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 4);
|
|
|
|
|
float32x4_t v2 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 8);
|
|
|
|
|
float32x4_t v3 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 12);
|
|
|
|
|
|
|
|
|
|
v0 = vsubq_f32(v0, minv);
|
|
|
|
|
v1 = vsubq_f32(v1, minv);
|
|
|
|
|
v2 = vsubq_f32(v2, minv);
|
|
|
|
|
v3 = vsubq_f32(v3, minv);
|
|
|
|
|
|
|
|
|
|
v0 = vmulq_f32(v0, idv);
|
|
|
|
|
v1 = vmulq_f32(v1, idv);
|
|
|
|
|
v2 = vmulq_f32(v2, idv);
|
|
|
|
|
v3 = vmulq_f32(v3, idv);
|
|
|
|
|
|
|
|
|
|
#if 0
|
|
|
|
|
v0[0] += frand(); v0[1] += frand(); v0[2] += frand(); v0[3] += frand();
|
|
|
|
|
v1[0] += frand(); v1[1] += frand(); v1[2] += frand(); v1[3] += frand();
|
|
|
|
|
v2[0] += frand(); v2[1] += frand(); v2[2] += frand(); v2[3] += frand();
|
|
|
|
|
v3[0] += frand(); v3[1] += frand(); v3[2] += frand(); v3[3] += frand();
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
uint32x4_t q0 = vcvtq_u32_f32(v0);
|
|
|
|
|
uint32x4_t q1 = vcvtq_u32_f32(v1);
|
|
|
|
|
uint32x4_t q2 = vcvtq_u32_f32(v2);
|
|
|
|
|
uint32x4_t q3 = vcvtq_u32_f32(v3);
|
|
|
|
|
|
|
|
|
|
for (int b = 0; b < QB; ++b) {
|
|
|
|
|
uint32x4_t m = vdupq_n_u32(1 << b);
|
|
|
|
|
uint32x4_t r = vdupq_n_u32(-b);
|
|
|
|
|
|
|
|
|
|
if (l < 32) {
|
|
|
|
|
p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q0, m), r), vld1q_s32(sh + l + 0)));
|
|
|
|
|
p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q1, m), r), vld1q_s32(sh + l + 4)));
|
|
|
|
|
p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q2, m), r), vld1q_s32(sh + l + 8)));
|
|
|
|
|
p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q3, m), r), vld1q_s32(sh + l + 12)));
|
|
|
|
|
} else {
|
|
|
|
|
p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q0, m), r), vld1q_s32(sh + l - 32)));
|
|
|
|
|
p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q1, m), r), vld1q_s32(sh + l - 28)));
|
|
|
|
|
p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q2, m), r), vld1q_s32(sh + l - 24)));
|
|
|
|
|
p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q3, m), r), vld1q_s32(sh + l - 20)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
vst1q_u32((uint32_t *) ppt + 0, p0[0]);
|
|
|
|
|
vst1q_u32((uint32_t *) ppt + 4, p1[0]);
|
|
|
|
|
vst1q_u32((uint32_t *) ppt + 8, p0[1]);
|
|
|
|
|
vst1q_u32((uint32_t *) ppt + 12, p1[1]);
|
|
|
|
|
vst1q_u32((uint32_t *) ppt + 16, p0[2]);
|
|
|
|
|
vst1q_u32((uint32_t *) ppt + 20, p1[2]);
|
|
|
|
|
vst1q_u32((uint32_t *) ppt + 24, p0[3]);
|
|
|
|
|
vst1q_u32((uint32_t *) ppt + 28, p1[3]);
|
|
|
|
|
|
|
|
|
|
pp[0] = (ppt[0] | ppt[1] | ppt[2] | ppt[3] ) | ((uint64_t) (ppt[4] | ppt[5] | ppt[6] | ppt[7]) ) << 32;
|
|
|
|
|
pp[1] = (ppt[8] | ppt[9] | ppt[10] | ppt[11]) | ((uint64_t) (ppt[12] | ppt[13] | ppt[14] | ppt[15])) << 32;
|
|
|
|
|
pp[2] = (ppt[16] | ppt[17] | ppt[18] | ppt[19]) | ((uint64_t) (ppt[20] | ppt[21] | ppt[22] | ppt[23])) << 32;
|
|
|
|
|
pp[3] = (ppt[24] | ppt[25] | ppt[26] | ppt[27]) | ((uint64_t) (ppt[28] | ppt[29] | ppt[30] | ppt[31])) << 32;
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
// less optimal SIMD
|
|
|
|
|
{
|
|
|
|
|
float32x4_t minv = vdupq_n_f32(min);
|
|
|
|
|
float32x4_t idv = vdupq_n_f32(id);
|
|
|
|
|
|
|
|
|
|
assert(gq_t_bits == 64);
|
|
|
|
|
uint8_t qq[gq_t_bits];
|
|
|
|
|
|
|
|
|
|
for (int l = 0; l < gq_t_bits; l += 16) {
|
|
|
|
|
float32x4_t v0 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 0);
|
|
|
|
|
float32x4_t v1 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 4);
|
|
|
|
|
float32x4_t v2 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 8);
|
|
|
|
|
float32x4_t v3 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 12);
|
|
|
|
|
|
|
|
|
|
v0 = vsubq_f32(v0, minv);
|
|
|
|
|
v1 = vsubq_f32(v1, minv);
|
|
|
|
|
v2 = vsubq_f32(v2, minv);
|
|
|
|
|
v3 = vsubq_f32(v3, minv);
|
|
|
|
|
|
|
|
|
|
v0 = vmulq_f32(v0, idv);
|
|
|
|
|
v1 = vmulq_f32(v1, idv);
|
|
|
|
|
v2 = vmulq_f32(v2, idv);
|
|
|
|
|
v3 = vmulq_f32(v3, idv);
|
|
|
|
|
|
|
|
|
|
#if 0
|
|
|
|
|
v0[0] += frand(); v0[1] += frand(); v0[2] += frand(); v0[3] += frand();
|
|
|
|
|
v1[0] += frand(); v1[1] += frand(); v1[2] += frand(); v1[3] += frand();
|
|
|
|
|
v2[0] += frand(); v2[1] += frand(); v2[2] += frand(); v2[3] += frand();
|
|
|
|
|
v3[0] += frand(); v3[1] += frand(); v3[2] += frand(); v3[3] += frand();
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
uint32x4_t q0 = vcvtq_u32_f32(v0);
|
|
|
|
|
uint32x4_t q1 = vcvtq_u32_f32(v1);
|
|
|
|
|
uint32x4_t q2 = vcvtq_u32_f32(v2);
|
|
|
|
|
uint32x4_t q3 = vcvtq_u32_f32(v3);
|
|
|
|
|
|
|
|
|
|
// store in qq as uint8_t
|
|
|
|
|
vst1_u8(qq + l + 0, vmovn_u16(vcombine_u16(vmovn_u32(q0), vmovn_u32(q1))));
|
|
|
|
|
vst1_u8(qq + l + 8, vmovn_u16(vcombine_u16(vmovn_u32(q2), vmovn_u32(q3))));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int l = 0; l < gq_t_bits; l++) {
|
|
|
|
|
for (int b = 0; b < QB; b++) {
|
|
|
|
|
pb[i*nq*QB + s*QB + b] = pp[b];
|
|
|
|
|
const uint64_t ql = qq[l];
|
|
|
|
|
/*pp[b] |= qq[l] & (1 << b) ? (1ULL << l) : 0;*/
|
|
|
|
|
pp[b] |= ((ql & (1 << b)) >> b) << l;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
#endif
|
|
|
|
|
memcpy(pb + i*nq*QB + s*QB, pp, sizeof(pp));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -430,6 +583,10 @@ int main(int argc, const char ** argv) {
|
|
|
|
|
printf("convert time: %f ms / method = %d\n", (t_end - t_start) / 1000.0, method);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < 16; ++i) {
|
|
|
|
|
printf("%f %f\n", src0[i], src1[i]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const int nIter = 1;
|
|
|
|
|
|
|
|
|
|
const clock_t start = clock();
|
|
|
|
|