tests : wip quantized matrix multiplication method 2

pull/15/head
Georgi Gerganov 2 years ago
parent d677c7f61d
commit deb0c486c7
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -47,6 +47,7 @@ endif()
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffast-math") #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffast-math")
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native") #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native")
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mcpu=native")
# dependencies # dependencies

@ -42,7 +42,7 @@ uint64_t get_time_us() {
// naive implementation // naive implementation
// //
void mul_mat_vec_f32_0( void mul_mat_vec_f32_naive(
const float * restrict src0, // M x K const float * restrict src0, // M x K
const float * restrict src1, // N x K (transposed) const float * restrict src1, // N x K (transposed)
float * dst, float * dst,
@ -58,7 +58,11 @@ void mul_mat_vec_f32_0(
} }
} }
void quantize(const float * src, void * dst, int n, int k) { //
// method 1
//
void quantize_1(const float * src, void * dst, int n, int k) {
char * p0 = dst; char * p0 = dst;
gq_t pp[QB]; gq_t pp[QB];
@ -128,7 +132,7 @@ void quantize(const float * src, void * dst, int n, int k) {
} }
} }
void mul_mat_vec_gq_0( void mul_mat_vec_gq_1(
const void * src0, const void * src0,
const void * src1, const void * src1,
float * dst, float * dst,
@ -138,6 +142,12 @@ void mul_mat_vec_gq_0(
const char * restrict p0 = src0; const char * restrict p0 = src0;
const char * restrict p1 = src1; const char * restrict p1 = src1;
float s0[QB + 1];
float s1[QB + 1];
gq_t m0[QB + 1];
gq_t m1[QB + 1];
for (int ir0 = 0; ir0 < m; ir0++) { for (int ir0 = 0; ir0 < m; ir0++) {
for (int ir1 = 0; ir1 < n; ir1++) { for (int ir1 = 0; ir1 < n; ir1++) {
float sumf = 0.0; float sumf = 0.0;
@ -159,9 +169,6 @@ void mul_mat_vec_gq_0(
#if 1 #if 1
// >>> General case for any QB // >>> General case for any QB
float s0[QB + 1];
float s1[QB + 1];
s0[0] = min0; s0[0] = min0;
s1[0] = min1; s1[0] = min1;
@ -170,9 +177,147 @@ void mul_mat_vec_gq_0(
s1[b + 1] = d1*(1 << b); s1[b + 1] = d1*(1 << b);
} }
m0[0] = -1LL;
m1[0] = -1LL;
for (int s = 0; s < QK/gq_t_bits; ++s) {
for (int b = 0; b < QB; b++) {
memcpy(&m0[b + 1], pp0, sizeof(gq_t)); pp0 += sizeof(gq_t);
memcpy(&m1[b + 1], pp1, sizeof(gq_t)); pp1 += sizeof(gq_t);
}
for (int q0 = 0; q0 < QB + 1; q0++) {
for (int q1 = 0; q1 < QB + 1; q1++) {
sumf += s0[q0]*s1[q1]*__builtin_popcountll(m0[q0] & m1[q1]);
}
}
}
#else
#endif
}
dst[ir0*n + ir1] = sumf;
}
}
}
//
// method 2
//
void quantize_2(const float * src, void * dst, int n, int k) {
char * p0 = dst;
for (int j = 0; j < n; j++) {
for (int i = 0; i < k/QK; i++) {
float min = FLT_MAX;
float max = -FLT_MAX;
// find min/max
#ifdef __ARM_NEON
{
float32x4_t minv = vdupq_n_f32(FLT_MAX);
float32x4_t maxv = vdupq_n_f32(-FLT_MAX);
for (int l = 0; l < QK; l += 4) {
float32x4_t v = vld1q_f32(src + j*k + i*QK + l);
minv = vminq_f32(minv, v);
maxv = vmaxq_f32(maxv, v);
}
float32x2_t minv32 = vpmin_f32(vget_low_f32(minv), vget_high_f32(minv));
float32x2_t maxv32 = vpmax_f32(vget_low_f32(maxv), vget_high_f32(maxv));
min = MIN(vget_lane_f32(minv32, 0), vget_lane_f32(minv32, 1));
max = MAX(vget_lane_f32(maxv32, 0), vget_lane_f32(maxv32, 1));
//printf("SIMD min/max: %f %f\n", min, max);
}
#else
{
for (int l = 0; l < QK; l++) {
const float v = src[j*k + i*QK + l];
if (v < min) min = v;
if (v > max) max = v;
}
//printf("NORM min/max: %f %f\n", min, max);
}
#endif
const float d = (max - min) / ((1 << QB) - 1);
const float id = d ? 1.0/d : 0.0;
memcpy(p0, &min, sizeof(float)); p0 += sizeof(float);
memcpy(p0, &d, sizeof(float)); p0 += sizeof(float);
//printf("min/max/d/id: %f %f %f %f\n", min, max, d, id);
for (int s = 0; s < QK/gq_t_bits; ++s) {
gq_t pp[QB] = {0};
for (int l = 0; l < gq_t_bits; l++) {
const float v = src[j*k + i*QK + s*gq_t_bits + l];
const uint8_t q = (v - min)*id;
for (int b = 0; b < QB; b++) {
pp[b] |= q & (1 << b) ? (1LL << l) : 0;
}
}
for (int b = 0; b < QB; b++) {
memcpy(p0, &pp[b], sizeof(gq_t)); p0 += sizeof(gq_t);
}
}
}
}
}
void mul_mat_vec_gq_2(
const void * src0,
const void * src1,
float * dst,
int m, int n, int k) {
const int kp = k & ~(gq_t_bits - 1);
const char * restrict p0 = src0;
const char * restrict p1 = src1;
float s0[QB + 1];
float s1[QB + 1];
gq_t m0[QB + 1]; gq_t m0[QB + 1];
gq_t m1[QB + 1]; gq_t m1[QB + 1];
for (int ir0 = 0; ir0 < m; ir0++) {
for (int ir1 = 0; ir1 < n; ir1++) {
float sumf = 0.0;
const char * restrict pp0 = p0 + ir0*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(k/QK));
const char * restrict pp1 = p1 + ir1*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(k/QK));
for (int i = 0; i < kp/QK; i++) {
float min0, d0;
memcpy(&min0, pp0, sizeof(float)); pp0 += sizeof(float);
memcpy(&d0, pp0, sizeof(float)); pp0 += sizeof(float);
float min1, d1;
memcpy(&min1, pp1, sizeof(float)); pp1 += sizeof(float);
memcpy(&d1, pp1, sizeof(float)); pp1 += sizeof(float);
//printf("min0/d0 = %f %f | min1/d1 = %f %f\n", min0, d0, min1, d1);
#if 1
// >>> General case for any QB
s0[0] = min0;
s1[0] = min1;
for (int b = 0; b < QB; b++) {
s0[b + 1] = d0*(1 << b);
s1[b + 1] = d1*(1 << b);
}
m0[0] = -1LL; m0[0] = -1LL;
m1[0] = -1LL; m1[0] = -1LL;
@ -198,6 +343,8 @@ void mul_mat_vec_gq_0(
} }
int main(int argc, const char ** argv) { int main(int argc, const char ** argv) {
assert(sizeof(gq_t)*8 == gq_t_bits);
float * src0 = (float *)malloc(sizeof(float)*M*K); float * src0 = (float *)malloc(sizeof(float)*M*K);
float * src1 = (float *)malloc(sizeof(float)*N*K); float * src1 = (float *)malloc(sizeof(float)*N*K);
float * dst = (float *)malloc(sizeof(float)*M*N); float * dst = (float *)malloc(sizeof(float)*M*N);
@ -219,20 +366,27 @@ int main(int argc, const char ** argv) {
printf("compression: %f\n", (float)sizegq/sizef16); printf("compression: %f\n", (float)sizegq/sizef16);
int method = 0;
if (argc > 1) {
method = atoi(argv[1]);
}
// convert fp32 -> gq // convert fp32 -> gq
{ {
const uint64_t t_start = get_time_us(); const uint64_t t_start = get_time_us();
quantize(src0, src0_gq, M, K); if (method == 1) {
quantize(src1, src1_gq, N, K); quantize_1(src0, src0_gq, M, K);
quantize_1(src1, src1_gq, N, K);
}
const uint64_t t_end = get_time_us(); if (method == 2) {
printf("convert time: %f ms\n", (t_end - t_start) / 1000.0); quantize_2(src0, src0_gq, M, K);
quantize_2(src1, src1_gq, N, K);
} }
int method = 0; const uint64_t t_end = get_time_us();
if (argc > 1) { printf("convert time: %f ms / method = %d\n", (t_end - t_start) / 1000.0, method);
method = atoi(argv[1]);
} }
const int nIter = 1; const int nIter = 1;
@ -244,11 +398,15 @@ int main(int argc, const char ** argv) {
double sum = 0.0f; double sum = 0.0f;
for (int i = 0; i < nIter; i++) { for (int i = 0; i < nIter; i++) {
if (method == 0) { if (method == 0) {
mul_mat_vec_f32_0(src0, src1, dst, M, N, K); mul_mat_vec_f32_naive(src0, src1, dst, M, N, K);
} }
if (method == 1) { if (method == 1) {
mul_mat_vec_gq_0(src0_gq, src1_gq, dst, M, N, K); mul_mat_vec_gq_1(src0_gq, src1_gq, dst, M, N, K);
}
if (method == 2) {
mul_mat_vec_gq_1(src0_gq, src1_gq, dst, M, N, K);
} }
} }

Loading…
Cancel
Save