gq : attempt at n-bit quantization

gq
Georgi Gerganov 2 years ago
parent 4c2f924553
commit 0a7debb7bf
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -27,7 +27,7 @@ const int N = 1536;
const int K = 1280;
const int QK = 64;
#define QB 7
#define QB 4
//#define GGML_GQ_USE_FP16_SCALE
@ -44,6 +44,10 @@ const int QK = 64;
#define gq_quant_t uint64_t
#define gq_t_bits 64
float frand() {
return (float) rand() / (float) RAND_MAX;
}
uint64_t get_time_us() {
struct timeval tv;
gettimeofday(&tv, NULL);
@ -244,15 +248,41 @@ void quantize_2_row(const float * restrict src, void * restrict dst, int k) {
gq_quant_t pp[QB];
static const int32_t sh[32] = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
};
for (int i = 0; i < nb; i++) {
float min = FLT_MAX;
float max = -FLT_MAX;
#ifdef __ARM_NEON
{
float32x4_t minv = vdupq_n_f32(FLT_MAX);
float32x4_t maxv = vdupq_n_f32(-FLT_MAX);
for (int l = 0; l < QK; l += 4) {
float32x4_t v = vld1q_f32(src + i*QK + l);
minv = vminq_f32(minv, v);
maxv = vmaxq_f32(maxv, v);
}
float32x2_t minv32 = vpmin_f32(vget_low_f32(minv), vget_high_f32(minv));
float32x2_t maxv32 = vpmax_f32(vget_low_f32(maxv), vget_high_f32(maxv));
min = MIN(vget_lane_f32(minv32, 0), vget_lane_f32(minv32, 1));
max = MAX(vget_lane_f32(maxv32, 0), vget_lane_f32(maxv32, 1));
}
#else
{
for (int l = 0; l < QK; l++) {
const float v = src[i*QK + l];
if (v < min) min = v;
if (v > max) max = v;
}
}
#endif
const float d = (max - min) / ((1 << QB) - 1);
const float id = d ? 1.0/d : 0.0;
@ -263,19 +293,142 @@ void quantize_2_row(const float * restrict src, void * restrict dst, int k) {
for (int s = 0; s < nq; ++s) {
memset(pp, 0, sizeof(pp));
#if 0
for (int l = 0; l < gq_t_bits; l++) {
const float v = src[i*QK + s*gq_t_bits + l];
const uint8_t q = (v - min)*id;
const uint8_t q = (v - min)*id + frand();
for (int b = 0; b < QB; b++) {
pp[b] |= q & (1 << b) ? (1ULL << l) : 0;
}
}
#elif defined(__ARM_NEON)
#if 1
{
uint32_t ppt[2*4*QB];
float32x4_t minv = vdupq_n_f32(min);
float32x4_t idv = vdupq_n_f32(id);
assert(gq_t_bits == 64);
uint32x4_t p0[QB] = { vdupq_n_u32(0) };
uint32x4_t p1[QB] = { vdupq_n_u32(0) };
for (int l = 0; l < gq_t_bits; l += 16) {
float32x4_t v0 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 0);
float32x4_t v1 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 4);
float32x4_t v2 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 8);
float32x4_t v3 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 12);
v0 = vsubq_f32(v0, minv);
v1 = vsubq_f32(v1, minv);
v2 = vsubq_f32(v2, minv);
v3 = vsubq_f32(v3, minv);
v0 = vmulq_f32(v0, idv);
v1 = vmulq_f32(v1, idv);
v2 = vmulq_f32(v2, idv);
v3 = vmulq_f32(v3, idv);
#if 0
v0[0] += frand(); v0[1] += frand(); v0[2] += frand(); v0[3] += frand();
v1[0] += frand(); v1[1] += frand(); v1[2] += frand(); v1[3] += frand();
v2[0] += frand(); v2[1] += frand(); v2[2] += frand(); v2[3] += frand();
v3[0] += frand(); v3[1] += frand(); v3[2] += frand(); v3[3] += frand();
#endif
uint32x4_t q0 = vcvtq_u32_f32(v0);
uint32x4_t q1 = vcvtq_u32_f32(v1);
uint32x4_t q2 = vcvtq_u32_f32(v2);
uint32x4_t q3 = vcvtq_u32_f32(v3);
for (int b = 0; b < QB; ++b) {
uint32x4_t m = vdupq_n_u32(1 << b);
uint32x4_t r = vdupq_n_u32(-b);
if (l < 32) {
p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q0, m), r), vld1q_s32(sh + l + 0)));
p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q1, m), r), vld1q_s32(sh + l + 4)));
p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q2, m), r), vld1q_s32(sh + l + 8)));
p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q3, m), r), vld1q_s32(sh + l + 12)));
} else {
p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q0, m), r), vld1q_s32(sh + l - 32)));
p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q1, m), r), vld1q_s32(sh + l - 28)));
p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q2, m), r), vld1q_s32(sh + l - 24)));
p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q3, m), r), vld1q_s32(sh + l - 20)));
}
}
}
vst1q_u32((uint32_t *) ppt + 0, p0[0]);
vst1q_u32((uint32_t *) ppt + 4, p1[0]);
vst1q_u32((uint32_t *) ppt + 8, p0[1]);
vst1q_u32((uint32_t *) ppt + 12, p1[1]);
vst1q_u32((uint32_t *) ppt + 16, p0[2]);
vst1q_u32((uint32_t *) ppt + 20, p1[2]);
vst1q_u32((uint32_t *) ppt + 24, p0[3]);
vst1q_u32((uint32_t *) ppt + 28, p1[3]);
pp[0] = (ppt[0] | ppt[1] | ppt[2] | ppt[3] ) | ((uint64_t) (ppt[4] | ppt[5] | ppt[6] | ppt[7]) ) << 32;
pp[1] = (ppt[8] | ppt[9] | ppt[10] | ppt[11]) | ((uint64_t) (ppt[12] | ppt[13] | ppt[14] | ppt[15])) << 32;
pp[2] = (ppt[16] | ppt[17] | ppt[18] | ppt[19]) | ((uint64_t) (ppt[20] | ppt[21] | ppt[22] | ppt[23])) << 32;
pp[3] = (ppt[24] | ppt[25] | ppt[26] | ppt[27]) | ((uint64_t) (ppt[28] | ppt[29] | ppt[30] | ppt[31])) << 32;
}
#else
// less optimal SIMD
{
float32x4_t minv = vdupq_n_f32(min);
float32x4_t idv = vdupq_n_f32(id);
assert(gq_t_bits == 64);
uint8_t qq[gq_t_bits];
for (int l = 0; l < gq_t_bits; l += 16) {
float32x4_t v0 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 0);
float32x4_t v1 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 4);
float32x4_t v2 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 8);
float32x4_t v3 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 12);
v0 = vsubq_f32(v0, minv);
v1 = vsubq_f32(v1, minv);
v2 = vsubq_f32(v2, minv);
v3 = vsubq_f32(v3, minv);
v0 = vmulq_f32(v0, idv);
v1 = vmulq_f32(v1, idv);
v2 = vmulq_f32(v2, idv);
v3 = vmulq_f32(v3, idv);
#if 0
v0[0] += frand(); v0[1] += frand(); v0[2] += frand(); v0[3] += frand();
v1[0] += frand(); v1[1] += frand(); v1[2] += frand(); v1[3] += frand();
v2[0] += frand(); v2[1] += frand(); v2[2] += frand(); v2[3] += frand();
v3[0] += frand(); v3[1] += frand(); v3[2] += frand(); v3[3] += frand();
#endif
uint32x4_t q0 = vcvtq_u32_f32(v0);
uint32x4_t q1 = vcvtq_u32_f32(v1);
uint32x4_t q2 = vcvtq_u32_f32(v2);
uint32x4_t q3 = vcvtq_u32_f32(v3);
// store in qq as uint8_t
vst1_u8(qq + l + 0, vmovn_u16(vcombine_u16(vmovn_u32(q0), vmovn_u32(q1))));
vst1_u8(qq + l + 8, vmovn_u16(vcombine_u16(vmovn_u32(q2), vmovn_u32(q3))));
}
for (int l = 0; l < gq_t_bits; l++) {
for (int b = 0; b < QB; b++) {
pb[i*nq*QB + s*QB + b] = pp[b];
const uint64_t ql = qq[l];
/*pp[b] |= qq[l] & (1 << b) ? (1ULL << l) : 0;*/
pp[b] |= ((ql & (1 << b)) >> b) << l;
}
}
}
#endif
#endif
memcpy(pb + i*nq*QB + s*QB, pp, sizeof(pp));
}
}
}
@ -430,6 +583,10 @@ int main(int argc, const char ** argv) {
printf("convert time: %f ms / method = %d\n", (t_end - t_start) / 1000.0, method);
}
for (int i = 0; i < 16; ++i) {
printf("%f %f\n", src0[i], src1[i]);
}
const int nIter = 1;
const clock_t start = clock();

Loading…
Cancel
Save