parent
bd9f710a45
commit
446ccf3ab1
@ -0,0 +1,272 @@
|
||||
// quantized matrix multiplication
|
||||
|
||||
#include <float.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <assert.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
#include <math.h>
|
||||
|
||||
#include <sys/time.h>
|
||||
|
||||
#ifdef __ARM_NEON
|
||||
#include "arm_neon.h"
|
||||
#endif
|
||||
|
||||
#ifndef MIN
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#endif
|
||||
|
||||
const int M = 1280;
|
||||
const int N = 1536;
|
||||
const int K = 1280;
|
||||
|
||||
const int QK = 64;
|
||||
const int QB = 7;
|
||||
|
||||
#define gq_t uint64_t
|
||||
#define gq_t_bits 64
|
||||
|
||||
uint64_t get_time_us() {
|
||||
struct timeval tv;
|
||||
gettimeofday(&tv, NULL);
|
||||
return tv.tv_sec * 1000000 + tv.tv_usec;
|
||||
}
|
||||
|
||||
//
|
||||
// naive implementation
|
||||
//
|
||||
|
||||
void mul_mat_vec_f32_0(
|
||||
const float * restrict src0, // M x K
|
||||
const float * restrict src1, // N x K (transposed)
|
||||
float * dst,
|
||||
int m, int n, int k) {
|
||||
for (int i = 0; i < m; i++) {
|
||||
for (int j = 0; j < n; j++) {
|
||||
float sum = 0;
|
||||
for (int l = 0; l < k; l++) {
|
||||
sum += src0[i*k + l] * src1[j*k + l];
|
||||
}
|
||||
dst[i*n + j] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void quantize(const float * src, void * dst, int n, int k) {
|
||||
char * p0 = dst;
|
||||
|
||||
for (int j = 0; j < n; j++) {
|
||||
for (int i = 0; i < k/QK; i++) {
|
||||
float min = FLT_MAX;
|
||||
float max = -FLT_MAX;
|
||||
|
||||
// find min/max
|
||||
#ifdef __ARM_NEON
|
||||
{
|
||||
float32x4_t minv = vdupq_n_f32(FLT_MAX);
|
||||
float32x4_t maxv = vdupq_n_f32(-FLT_MAX);
|
||||
|
||||
for (int l = 0; l < QK; l += 4) {
|
||||
float32x4_t v = vld1q_f32(src + j*k + i*QK + l);
|
||||
minv = vminq_f32(minv, v);
|
||||
maxv = vmaxq_f32(maxv, v);
|
||||
}
|
||||
|
||||
float32x2_t minv32 = vpmin_f32(vget_low_f32(minv), vget_high_f32(minv));
|
||||
float32x2_t maxv32 = vpmax_f32(vget_low_f32(maxv), vget_high_f32(maxv));
|
||||
|
||||
min = MIN(vget_lane_f32(minv32, 0), vget_lane_f32(minv32, 1));
|
||||
max = MAX(vget_lane_f32(maxv32, 0), vget_lane_f32(maxv32, 1));
|
||||
|
||||
//printf("SIMD min/max: %f %f\n", min, max);
|
||||
}
|
||||
#else
|
||||
{
|
||||
for (int l = 0; l < QK; l++) {
|
||||
const float v = src[j*k + i*QK + l];
|
||||
if (v < min) min = v;
|
||||
if (v > max) max = v;
|
||||
}
|
||||
|
||||
//printf("NORM min/max: %f %f\n", min, max);
|
||||
}
|
||||
#endif
|
||||
|
||||
const float d = (max - min) / ((1 << QB) - 1);
|
||||
const float id = d ? 1.0/d : 0.0;
|
||||
|
||||
memcpy(p0, &min, sizeof(float)); p0 += sizeof(float);
|
||||
memcpy(p0, &d, sizeof(float)); p0 += sizeof(float);
|
||||
|
||||
//printf("min/max/d/id: %f %f %f %f\n", min, max, d, id);
|
||||
|
||||
for (int s = 0; s < QK/gq_t_bits; ++s) {
|
||||
gq_t pp[QB] = {0};
|
||||
|
||||
for (int l = 0; l < gq_t_bits; l++) {
|
||||
const float v = src[j*k + i*QK + s*gq_t_bits + l];
|
||||
const uint8_t q = (v - min)*id;
|
||||
|
||||
for (int b = 0; b < QB; b++) {
|
||||
pp[b] |= q & (1 << b) ? (1LL << l) : 0;
|
||||
}
|
||||
}
|
||||
|
||||
for (int b = 0; b < QB; b++) {
|
||||
memcpy(p0, &pp[b], sizeof(gq_t)); p0 += sizeof(gq_t);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void mul_mat_vec_gq_0(
|
||||
const void * src0,
|
||||
const void * src1,
|
||||
float * dst,
|
||||
int m, int n, int k) {
|
||||
const int kp = k & ~(gq_t_bits - 1);
|
||||
|
||||
const char * restrict p0 = src0;
|
||||
const char * restrict p1 = src1;
|
||||
|
||||
for (int ir0 = 0; ir0 < m; ir0++) {
|
||||
for (int ir1 = 0; ir1 < n; ir1++) {
|
||||
float sumf = 0.0;
|
||||
|
||||
const char * restrict pp0 = p0 + ir0*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(k/QK));
|
||||
const char * restrict pp1 = p1 + ir1*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(k/QK));
|
||||
|
||||
for (int i = 0; i < kp/QK; i++) {
|
||||
float min0, d0;
|
||||
memcpy(&min0, pp0, sizeof(float)); pp0 += sizeof(float);
|
||||
memcpy(&d0, pp0, sizeof(float)); pp0 += sizeof(float);
|
||||
|
||||
float min1, d1;
|
||||
memcpy(&min1, pp1, sizeof(float)); pp1 += sizeof(float);
|
||||
memcpy(&d1, pp1, sizeof(float)); pp1 += sizeof(float);
|
||||
|
||||
//printf("min0/d0 = %f %f | min1/d1 = %f %f\n", min0, d0, min1, d1);
|
||||
|
||||
#if 1
|
||||
// >>> General case for any QB
|
||||
|
||||
float s0[QB + 1];
|
||||
float s1[QB + 1];
|
||||
|
||||
s0[0] = min0;
|
||||
s1[0] = min1;
|
||||
|
||||
for (int b = 0; b < QB; b++) {
|
||||
s0[b + 1] = d0*(1 << b);
|
||||
s1[b + 1] = d1*(1 << b);
|
||||
}
|
||||
|
||||
gq_t m0[QB + 1];
|
||||
gq_t m1[QB + 1];
|
||||
|
||||
m0[0] = -1LL;
|
||||
m1[0] = -1LL;
|
||||
|
||||
for (int s = 0; s < QK/gq_t_bits; ++s) {
|
||||
for (int b = 0; b < QB; b++) {
|
||||
memcpy(&m0[b + 1], pp0, sizeof(gq_t)); pp0 += sizeof(gq_t);
|
||||
memcpy(&m1[b + 1], pp1, sizeof(gq_t)); pp1 += sizeof(gq_t);
|
||||
}
|
||||
|
||||
for (int q0 = 0; q0 < QB + 1; q0++) {
|
||||
for (int q1 = 0; q1 < QB + 1; q1++) {
|
||||
sumf += s0[q0]*s1[q1]*__builtin_popcountll(m0[q0] & m1[q1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
#endif
|
||||
}
|
||||
|
||||
dst[ir0*n + ir1] = sumf;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, const char ** argv) {
|
||||
float * src0 = (float *)malloc(sizeof(float)*M*K);
|
||||
float * src1 = (float *)malloc(sizeof(float)*N*K);
|
||||
float * dst = (float *)malloc(sizeof(float)*M*N);
|
||||
|
||||
for (int i = 0; i < M*K; i++) {
|
||||
src0[i] = rand() / (float)RAND_MAX;
|
||||
}
|
||||
|
||||
for (int i = 0; i < N*K; i++) {
|
||||
src1[i] = rand() / (float)RAND_MAX;
|
||||
}
|
||||
|
||||
void * src0_gq = calloc(1, (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*M);
|
||||
void * src1_gq = calloc(1, (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*N);
|
||||
|
||||
const size_t sizef16 = sizeof(__fp16)*M*K + sizeof(__fp16)*N*K;
|
||||
const size_t sizegq = (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*M +
|
||||
(2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*N;
|
||||
|
||||
printf("compression: %f\n", (float)sizegq/sizef16);
|
||||
|
||||
// convert fp32 -> gq
|
||||
{
|
||||
const uint64_t t_start = get_time_us();
|
||||
|
||||
quantize(src0, src0_gq, M, K);
|
||||
quantize(src1, src1_gq, N, K);
|
||||
|
||||
const uint64_t t_end = get_time_us();
|
||||
printf("convert time: %f ms\n", (t_end - t_start) / 1000.0);
|
||||
}
|
||||
|
||||
int method = 0;
|
||||
if (argc > 1) {
|
||||
method = atoi(argv[1]);
|
||||
}
|
||||
|
||||
const int nIter = 1;
|
||||
|
||||
const clock_t start = clock();
|
||||
const uint64_t start_us = get_time_us();
|
||||
|
||||
double iM = 1.0/M;
|
||||
double sum = 0.0f;
|
||||
for (int i = 0; i < nIter; i++) {
|
||||
if (method == 0) {
|
||||
mul_mat_vec_f32_0(src0, src1, dst, M, N, K);
|
||||
}
|
||||
|
||||
if (method == 1) {
|
||||
mul_mat_vec_gq_0(src0_gq, src1_gq, dst, M, N, K);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < N; i++) {
|
||||
sum += dst[i]*iM;
|
||||
}
|
||||
|
||||
{
|
||||
const clock_t end = clock();
|
||||
const uint64_t end_us = get_time_us();
|
||||
printf("%s: elapsed ticks: %ld\n", __func__, end - start);
|
||||
printf("%s: elapsed us: %llu / %f ms\n", __func__, end_us - start_us, (end_us - start_us) / 1000.0 / nIter);
|
||||
}
|
||||
|
||||
printf("%f\n", sum);
|
||||
|
||||
free(src0);
|
||||
free(src1);
|
||||
free(dst);
|
||||
|
||||
free(src0_gq);
|
||||
free(src1_gq);
|
||||
|
||||
return 0;
|
||||
}
|
Loading…
Reference in new issue