|
|
|
@ -26,8 +26,12 @@ const int M = 1280;
|
|
|
|
|
const int N = 1536;
|
|
|
|
|
const int K = 1280;
|
|
|
|
|
|
|
|
|
|
const int QK = 64;
|
|
|
|
|
#define QB 4
|
|
|
|
|
//const int M = 64;
|
|
|
|
|
//const int N = 64;
|
|
|
|
|
//const int K = 64;
|
|
|
|
|
|
|
|
|
|
const int QK = 32;
|
|
|
|
|
#define QB 2
|
|
|
|
|
|
|
|
|
|
//#define GGML_GQ_USE_FP16_SCALE
|
|
|
|
|
|
|
|
|
@ -41,8 +45,8 @@ const int QK = 64;
|
|
|
|
|
#define GGML_GQ_TO_FP32(x) (x)
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#define gq_quant_t uint64_t
|
|
|
|
|
#define gq_t_bits 64
|
|
|
|
|
#define gq_t_bits 32
|
|
|
|
|
#define gq_quant_t uint32_t
|
|
|
|
|
|
|
|
|
|
float frand() {
|
|
|
|
|
return (float) rand() / (float) RAND_MAX;
|
|
|
|
@ -325,7 +329,7 @@ void quantize_2_row(const float * restrict src, void * restrict dst, int k) {
|
|
|
|
|
float32x4_t minv = vdupq_n_f32(min);
|
|
|
|
|
float32x4_t idv = vdupq_n_f32(id);
|
|
|
|
|
|
|
|
|
|
assert(gq_t_bits == 64);
|
|
|
|
|
assert(gq_t_bits % 16 == 0);
|
|
|
|
|
|
|
|
|
|
uint32x4_t p0[QB] = { vdupq_n_u32(0) };
|
|
|
|
|
uint32x4_t p1[QB] = { vdupq_n_u32(0) };
|
|
|
|
@ -376,6 +380,7 @@ void quantize_2_row(const float * restrict src, void * restrict dst, int k) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#if QB == 4
|
|
|
|
|
vst1q_u32((uint32_t *) ppt + 0, p0[0]);
|
|
|
|
|
vst1q_u32((uint32_t *) ppt + 4, p1[0]);
|
|
|
|
|
vst1q_u32((uint32_t *) ppt + 8, p0[1]);
|
|
|
|
@ -389,6 +394,14 @@ void quantize_2_row(const float * restrict src, void * restrict dst, int k) {
|
|
|
|
|
pp[1] = (ppt[8] | ppt[9] | ppt[10] | ppt[11]) | ((uint64_t) (ppt[12] | ppt[13] | ppt[14] | ppt[15])) << 32;
|
|
|
|
|
pp[2] = (ppt[16] | ppt[17] | ppt[18] | ppt[19]) | ((uint64_t) (ppt[20] | ppt[21] | ppt[22] | ppt[23])) << 32;
|
|
|
|
|
pp[3] = (ppt[24] | ppt[25] | ppt[26] | ppt[27]) | ((uint64_t) (ppt[28] | ppt[29] | ppt[30] | ppt[31])) << 32;
|
|
|
|
|
#else
|
|
|
|
|
for (int b = 0; b < QB; ++b) {
|
|
|
|
|
vst1q_u32((uint32_t *) ppt + 0, p0[b]);
|
|
|
|
|
vst1q_u32((uint32_t *) ppt + 4, p1[b]);
|
|
|
|
|
|
|
|
|
|
pp[b] = (ppt[0] | ppt[1] | ppt[2] | ppt[3]) | ((uint64_t) (ppt[4] | ppt[5] | ppt[6] | ppt[7])) << 32;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
// less optimal SIMD
|
|
|
|
@ -458,9 +471,6 @@ void quantize_2(const float * restrict src, char * restrict dst, int n, int k) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void vec_dot_gq_2(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
|
|
|
|
|
float sumf[(QB + 1)*(QB + 1)];
|
|
|
|
|
memset(sumf, 0, sizeof(sumf));
|
|
|
|
|
|
|
|
|
|
const int nb = quantize_2_blocks_per_row(n);
|
|
|
|
|
const int nq = quantize_2_quants_per_block();
|
|
|
|
|
|
|
|
|
@ -473,10 +483,9 @@ void vec_dot_gq_2(const int n, float * restrict s, const void * restrict x, cons
|
|
|
|
|
const gq_quant_t * restrict pb0 = (const gq_quant_t *) (pd0 + nb);
|
|
|
|
|
const gq_quant_t * restrict pb1 = (const gq_quant_t *) (pd1 + nb);
|
|
|
|
|
|
|
|
|
|
#if 1
|
|
|
|
|
float s0[QB + 1];
|
|
|
|
|
float s1[QB + 1];
|
|
|
|
|
float sumf = 0.0;
|
|
|
|
|
|
|
|
|
|
#if 1
|
|
|
|
|
for (int i = 0; i < nb; i++) {
|
|
|
|
|
const float m0 = GGML_GQ_TO_FP32(pm0[i]);
|
|
|
|
|
const float d0 = GGML_GQ_TO_FP32(pd0[i]);
|
|
|
|
@ -484,6 +493,99 @@ void vec_dot_gq_2(const int n, float * restrict s, const void * restrict x, cons
|
|
|
|
|
const float m1 = GGML_GQ_TO_FP32(pm1[i]);
|
|
|
|
|
const float d1 = GGML_GQ_TO_FP32(pd1[i]);
|
|
|
|
|
|
|
|
|
|
#if QB == 4
|
|
|
|
|
int isum01 = 0;
|
|
|
|
|
int isum10 = 0;
|
|
|
|
|
int isum11 = 0;
|
|
|
|
|
|
|
|
|
|
for (int s = 0; s < nq; ++s) {
|
|
|
|
|
const gq_quant_t * restrict mm0 = pb0 + i*nq*QB + s*QB;
|
|
|
|
|
const gq_quant_t * restrict mm1 = pb1 + i*nq*QB + s*QB;
|
|
|
|
|
|
|
|
|
|
#define bpcnt(x) __builtin_popcountll(x)
|
|
|
|
|
isum01 += (1 << 0)*(bpcnt(mm1[0]));
|
|
|
|
|
isum01 += (1 << 1)*(bpcnt(mm1[1]));
|
|
|
|
|
isum01 += (1 << 2)*(bpcnt(mm1[2]));
|
|
|
|
|
isum01 += (1 << 3)*(bpcnt(mm1[3]));
|
|
|
|
|
|
|
|
|
|
isum10 += (1 << 0)*(bpcnt(mm0[0]));
|
|
|
|
|
isum10 += (1 << 1)*(bpcnt(mm0[1]));
|
|
|
|
|
isum10 += (1 << 2)*(bpcnt(mm0[2]));
|
|
|
|
|
isum10 += (1 << 3)*(bpcnt(mm0[3]));
|
|
|
|
|
|
|
|
|
|
isum11 += (1 << 0)*(bpcnt(mm0[0] & mm1[0]));
|
|
|
|
|
isum11 += (1 << 1)*(bpcnt(mm0[0] & mm1[1]) + bpcnt(mm0[1] & mm1[0]));
|
|
|
|
|
isum11 += (1 << 2)*(bpcnt(mm0[0] & mm1[2]) + bpcnt(mm0[1] & mm1[1]) + bpcnt(mm0[2] & mm1[0]));
|
|
|
|
|
isum11 += (1 << 3)*(bpcnt(mm0[0] & mm1[3]) + bpcnt(mm0[1] & mm1[2]) + bpcnt(mm0[2] & mm1[1]) + bpcnt(mm0[3] & mm1[0]));
|
|
|
|
|
isum11 += (1 << 4)*(bpcnt(mm0[1] & mm1[3]) + bpcnt(mm0[2] & mm1[2]) + bpcnt(mm0[3] & mm1[1]));
|
|
|
|
|
isum11 += (1 << 5)*(bpcnt(mm0[2] & mm1[3]) + bpcnt(mm0[3] & mm1[2]));
|
|
|
|
|
isum11 += (1 << 6)*(bpcnt(mm0[3] & mm1[3]));
|
|
|
|
|
#undef bpcnt
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sumf += nq*gq_t_bits*(m0*m1) + isum01*(m0*d1) + isum10*(m1*d0) + isum11*(d0*d1);
|
|
|
|
|
#elif QB == 3
|
|
|
|
|
int isum01 = 0;
|
|
|
|
|
int isum10 = 0;
|
|
|
|
|
int isum11 = 0;
|
|
|
|
|
|
|
|
|
|
for (int s = 0; s < nq; ++s) {
|
|
|
|
|
const gq_quant_t * restrict mm0 = pb0 + i*nq*QB + s*QB;
|
|
|
|
|
const gq_quant_t * restrict mm1 = pb1 + i*nq*QB + s*QB;
|
|
|
|
|
|
|
|
|
|
#if gq_t_bits == 32
|
|
|
|
|
#define bpcnt(x) __builtin_popcount(x)
|
|
|
|
|
#else
|
|
|
|
|
#define bpcnt(x) __builtin_popcountll(x)
|
|
|
|
|
#endif
|
|
|
|
|
isum01 += (1 << 0)*(bpcnt(mm1[0]));
|
|
|
|
|
isum01 += (1 << 1)*(bpcnt(mm1[1]));
|
|
|
|
|
isum01 += (1 << 2)*(bpcnt(mm1[2]));
|
|
|
|
|
|
|
|
|
|
isum10 += (1 << 0)*(bpcnt(mm0[0]));
|
|
|
|
|
isum10 += (1 << 1)*(bpcnt(mm0[1]));
|
|
|
|
|
isum10 += (1 << 2)*(bpcnt(mm0[2]));
|
|
|
|
|
|
|
|
|
|
isum11 += (1 << 0)*(bpcnt(mm0[0] & mm1[0]));
|
|
|
|
|
isum11 += (1 << 1)*(bpcnt(mm0[0] & mm1[1]) + bpcnt(mm0[1] & mm1[0]));
|
|
|
|
|
isum11 += (1 << 2)*(bpcnt(mm0[0] & mm1[2]) + bpcnt(mm0[1] & mm1[1]) + bpcnt(mm0[2] & mm1[0]));
|
|
|
|
|
isum11 += (1 << 3)*(bpcnt(mm0[1] & mm1[2]) + bpcnt(mm0[2] & mm1[1]));
|
|
|
|
|
isum11 += (1 << 4)*(bpcnt(mm0[2] & mm1[2]));
|
|
|
|
|
#undef bpcnt
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sumf += nq*gq_t_bits*(m0*m1) + isum01*(m0*d1) + isum10*(m1*d0) + isum11*(d0*d1);
|
|
|
|
|
#elif QB == 2
|
|
|
|
|
int isum01 = 0;
|
|
|
|
|
int isum10 = 0;
|
|
|
|
|
int isum11 = 0;
|
|
|
|
|
|
|
|
|
|
for (int s = 0; s < nq; ++s) {
|
|
|
|
|
const gq_quant_t * restrict mm0 = pb0 + i*nq*QB + s*QB;
|
|
|
|
|
const gq_quant_t * restrict mm1 = pb1 + i*nq*QB + s*QB;
|
|
|
|
|
|
|
|
|
|
#if gq_t_bits == 32
|
|
|
|
|
#define bpcnt(x) __builtin_popcount(x)
|
|
|
|
|
#else
|
|
|
|
|
#define bpcnt(x) __builtin_popcountll(x)
|
|
|
|
|
#endif
|
|
|
|
|
isum01 += (1 << 0)*(bpcnt(mm1[0]));
|
|
|
|
|
isum01 += (1 << 1)*(bpcnt(mm1[1]));
|
|
|
|
|
|
|
|
|
|
isum10 += (1 << 0)*(bpcnt(mm0[0]));
|
|
|
|
|
isum10 += (1 << 1)*(bpcnt(mm0[1]));
|
|
|
|
|
|
|
|
|
|
isum11 += (1 << 0)*(bpcnt(mm0[0] & mm1[0]));
|
|
|
|
|
isum11 += (1 << 1)*(bpcnt(mm0[0] & mm1[1]) + bpcnt(mm0[1] & mm1[0]));
|
|
|
|
|
isum11 += (1 << 2)*(bpcnt(mm0[1] & mm1[1]));
|
|
|
|
|
#undef bpcnt
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sumf += nq*gq_t_bits*(m0*m1) + isum01*(m0*d1) + isum10*(m1*d0) + isum11*(d0*d1);
|
|
|
|
|
#else
|
|
|
|
|
float s0[QB + 1];
|
|
|
|
|
float s1[QB + 1];
|
|
|
|
|
|
|
|
|
|
s0[0] = m0;
|
|
|
|
|
s1[0] = m1;
|
|
|
|
|
|
|
|
|
@ -497,36 +599,17 @@ void vec_dot_gq_2(const int n, float * restrict s, const void * restrict x, cons
|
|
|
|
|
const gq_quant_t mm0 = q0 ? pb0[i*nq*QB + s*QB + q0 - 1] : -1ULL;
|
|
|
|
|
for (int q1 = 0; q1 < QB + 1; q1++) {
|
|
|
|
|
const gq_quant_t mm1 = q1 ? pb1[i*nq*QB + s*QB + q1 - 1] : -1ULL;
|
|
|
|
|
sumf[q0*(QB + 1) + q1] += s0[q0]*s1[q1]*__builtin_popcountll(mm0 & mm1);
|
|
|
|
|
sumf += s0[q0]*s1[q1]*__builtin_popcountll(mm0 & mm1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
// SIMD-ify with the assumptions:
|
|
|
|
|
// - nb is a multiple of 4
|
|
|
|
|
// - gq_scale_t is float
|
|
|
|
|
// - gq_quant_t is uint64_t
|
|
|
|
|
// - QB == 7
|
|
|
|
|
assert(nb % 4 == 0);
|
|
|
|
|
|
|
|
|
|
#ifdef __ARM_NEON
|
|
|
|
|
#else
|
|
|
|
|
// TODO
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#error "not implemented"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
for (int q0 = 0; q0 < QB + 1; q0++) {
|
|
|
|
|
for (int q1 = 1; q1 < QB + 1; q1++) {
|
|
|
|
|
sumf[q0*(QB + 1)] += sumf[q0*(QB + 1) + q1];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
*s = sumf[0];
|
|
|
|
|
for (int q0 = 1; q0 < QB + 1; q0++) {
|
|
|
|
|
*s += sumf[q0*(QB + 1)];
|
|
|
|
|
}
|
|
|
|
|
*s = sumf;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// use vec_dot_gq_2 to compute the dot product of two rows
|
|
|
|
@ -698,6 +781,7 @@ void quantize_3_row(const float * restrict src, void * restrict dst, int k) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#if QB == 4
|
|
|
|
|
vst1q_u32((uint32_t *) ppt + 0, p0[0]);
|
|
|
|
|
vst1q_u32((uint32_t *) ppt + 4, p1[0]);
|
|
|
|
|
vst1q_u32((uint32_t *) ppt + 8, p0[1]);
|
|
|
|
@ -711,6 +795,14 @@ void quantize_3_row(const float * restrict src, void * restrict dst, int k) {
|
|
|
|
|
pp[1] = (ppt[8] | ppt[9] | ppt[10] | ppt[11]) | ((uint64_t) (ppt[12] | ppt[13] | ppt[14] | ppt[15])) << 32;
|
|
|
|
|
pp[2] = (ppt[16] | ppt[17] | ppt[18] | ppt[19]) | ((uint64_t) (ppt[20] | ppt[21] | ppt[22] | ppt[23])) << 32;
|
|
|
|
|
pp[3] = (ppt[24] | ppt[25] | ppt[26] | ppt[27]) | ((uint64_t) (ppt[28] | ppt[29] | ppt[30] | ppt[31])) << 32;
|
|
|
|
|
#else
|
|
|
|
|
for (int q = 0; q < QB; ++q) {
|
|
|
|
|
vst1q_u32((uint32_t *) ppt + 0, p0[q]);
|
|
|
|
|
vst1q_u32((uint32_t *) ppt + 4, p1[q]);
|
|
|
|
|
|
|
|
|
|
pp[q] = (ppt[0] | ppt[1] | ppt[2] | ppt[3]) | ((uint64_t) (ppt[4] | ppt[5] | ppt[6] | ppt[7])) << 32;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
memcpy(pb + i*nq*QB + s*QB, pp, sizeof(pp));
|
|
|
|
@ -729,8 +821,7 @@ void quantize_3(const float * restrict src, char * restrict dst, int n, int k) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void vec_dot_gq_3(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
|
|
|
|
|
float sumf[QB*QB];
|
|
|
|
|
memset(sumf, 0, sizeof(sumf));
|
|
|
|
|
float sumf = 0.0f;
|
|
|
|
|
|
|
|
|
|
const int nb = quantize_3_blocks_per_row(n);
|
|
|
|
|
const int nq = quantize_3_quants_per_block();
|
|
|
|
@ -741,56 +832,160 @@ void vec_dot_gq_3(const int n, float * restrict s, const void * restrict x, cons
|
|
|
|
|
const gq_quant_t * restrict pb0 = (const gq_quant_t *) (pd0 + nb);
|
|
|
|
|
const gq_quant_t * restrict pb1 = (const gq_quant_t *) (pd1 + nb);
|
|
|
|
|
|
|
|
|
|
#if 1
|
|
|
|
|
float s0[QB];
|
|
|
|
|
float s1[QB];
|
|
|
|
|
|
|
|
|
|
#if 0
|
|
|
|
|
for (int i = 0; i < nb; i++) {
|
|
|
|
|
const float d0 = GGML_GQ_TO_FP32(pd0[i]);
|
|
|
|
|
const float d1 = GGML_GQ_TO_FP32(pd1[i]);
|
|
|
|
|
int isum = 0;
|
|
|
|
|
|
|
|
|
|
for (int b = 0; b < QB; b++) {
|
|
|
|
|
s0[b] = d0*(1 << b);
|
|
|
|
|
s1[b] = d1*(1 << b);
|
|
|
|
|
}
|
|
|
|
|
#if QB == 4
|
|
|
|
|
for (int s = 0; s < nq; ++s) {
|
|
|
|
|
const gq_quant_t * restrict m0 = pb0 + i*nq*QB + s*QB;
|
|
|
|
|
const gq_quant_t * restrict m1 = pb1 + i*nq*QB + s*QB;
|
|
|
|
|
|
|
|
|
|
isum += (1 << 0)*(__builtin_popcountll(m0[0] & m1[0]));
|
|
|
|
|
isum += (1 << 1)*(__builtin_popcountll(m0[0] & m1[1]) + __builtin_popcountll(m0[1] & m1[0]));
|
|
|
|
|
isum += (1 << 2)*(__builtin_popcountll(m0[0] & m1[2]) + __builtin_popcountll(m0[1] & m1[1]) + __builtin_popcountll(m0[2] & m1[0]));
|
|
|
|
|
isum += (1 << 3)*(__builtin_popcountll(m0[0] & m1[3]) + __builtin_popcountll(m0[1] & m1[2]) + __builtin_popcountll(m0[2] & m1[1]) + __builtin_popcountll(m0[3] & m1[0]));
|
|
|
|
|
isum += (1 << 4)*(__builtin_popcountll(m0[1] & m1[3]) + __builtin_popcountll(m0[2] & m1[2]) + __builtin_popcountll(m0[3] & m1[1]));
|
|
|
|
|
isum += (1 << 5)*(__builtin_popcountll(m0[2] & m1[3]) + __builtin_popcountll(m0[3] & m1[2]));
|
|
|
|
|
isum += (1 << 6)*(__builtin_popcountll(m0[3] & m1[3]));
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
for (int s = 0; s < nq; ++s) {
|
|
|
|
|
for (int q0 = 0; q0 < QB; q0++) {
|
|
|
|
|
const gq_quant_t mm0 = pb0[i*nq*QB + s*QB + q0];
|
|
|
|
|
for (int q1 = 0; q1 < QB; q1++) {
|
|
|
|
|
const gq_quant_t mm1 = pb1[i*nq*QB + s*QB + q1];
|
|
|
|
|
sumf[q0*QB + q1] += s0[q0]*s1[q1]*__builtin_popcountll(mm0 & mm1);
|
|
|
|
|
isum += (1 << (q0 + q1))*(__builtin_popcountll(mm0 & mm1));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
const float d0 = GGML_GQ_TO_FP32(pd0[i]);
|
|
|
|
|
const float d1 = GGML_GQ_TO_FP32(pd1[i]);
|
|
|
|
|
|
|
|
|
|
sumf += d0*d1*isum;
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
// SIMD-ify with the assumptions:
|
|
|
|
|
// - nb is a multiple of 4
|
|
|
|
|
// - gq_scale_t is float
|
|
|
|
|
// - gq_quant_t is uint64_t
|
|
|
|
|
// - QB == 7
|
|
|
|
|
assert(nb % 4 == 0);
|
|
|
|
|
|
|
|
|
|
#ifdef __ARM_NEON
|
|
|
|
|
#else
|
|
|
|
|
// TODO
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
// gq_quant_t == uint64_t
|
|
|
|
|
for (int i = 0; i < nb; i += 4) {
|
|
|
|
|
int isum[4] = {0, 0, 0, 0};
|
|
|
|
|
|
|
|
|
|
for (int k = 0; k < 4; ++k) {
|
|
|
|
|
for (int s = 0; s < nq; ++s) {
|
|
|
|
|
const gq_quant_t * restrict m0 = pb0 + (i+k)*nq*QB + s*QB;
|
|
|
|
|
const gq_quant_t * restrict m1 = pb1 + (i+k)*nq*QB + s*QB;
|
|
|
|
|
|
|
|
|
|
#if QB == 4
|
|
|
|
|
#define bpcnt(x) __builtin_popcountll(x)
|
|
|
|
|
//isum[k] += (1ULL << 0)*(bpcnt(m0[0] & m1[0])) +
|
|
|
|
|
// (1ULL << 1)*(bpcnt(m0[0] & m1[1]) + bpcnt(m0[1] & m1[0])) +
|
|
|
|
|
// (1ULL << 2)*(bpcnt(m0[0] & m1[2]) + bpcnt(m0[1] & m1[1]) + bpcnt(m0[2] & m1[0])) +
|
|
|
|
|
// (1ULL << 3)*(bpcnt(m0[0] & m1[3]) + bpcnt(m0[1] & m1[2]) + bpcnt(m0[2] & m1[1]) + bpcnt(m0[3] & m1[0])) +
|
|
|
|
|
// (1ULL << 4)*(bpcnt(m0[1] & m1[3]) + bpcnt(m0[2] & m1[2]) + bpcnt(m0[3] & m1[1])) +
|
|
|
|
|
// (1ULL << 5)*(bpcnt(m0[2] & m1[3]) + bpcnt(m0[3] & m1[2])) +
|
|
|
|
|
// (1ULL << 6)*(bpcnt(m0[3] & m1[3]));
|
|
|
|
|
#undef bpcnt
|
|
|
|
|
|
|
|
|
|
const uint8x8_t m00 = vld1_u8((const uint8_t *) (m0 + 0));
|
|
|
|
|
const uint8x8_t m01 = vld1_u8((const uint8_t *) (m0 + 1));
|
|
|
|
|
const uint8x8_t m02 = vld1_u8((const uint8_t *) (m0 + 2));
|
|
|
|
|
const uint8x8_t m03 = vld1_u8((const uint8_t *) (m0 + 3));
|
|
|
|
|
|
|
|
|
|
const uint8x8_t m10 = vld1_u8((const uint8_t *) (m1 + 0));
|
|
|
|
|
const uint8x8_t m11 = vld1_u8((const uint8_t *) (m1 + 1));
|
|
|
|
|
const uint8x8_t m12 = vld1_u8((const uint8_t *) (m1 + 2));
|
|
|
|
|
const uint8x8_t m13 = vld1_u8((const uint8_t *) (m1 + 3));
|
|
|
|
|
|
|
|
|
|
const uint8x8_t m00m10 = vand_u8(m00, m10);
|
|
|
|
|
|
|
|
|
|
const uint8x8_t m00m11 = vand_u8(m00, m11);
|
|
|
|
|
const uint8x8_t m01m10 = vand_u8(m01, m10);
|
|
|
|
|
|
|
|
|
|
const uint8x8_t m00m12 = vand_u8(m00, m12);
|
|
|
|
|
const uint8x8_t m01m11 = vand_u8(m01, m11);
|
|
|
|
|
const uint8x8_t m02m10 = vand_u8(m02, m10);
|
|
|
|
|
|
|
|
|
|
const uint8x8_t m00m13 = vand_u8(m00, m13);
|
|
|
|
|
const uint8x8_t m01m12 = vand_u8(m01, m12);
|
|
|
|
|
const uint8x8_t m02m11 = vand_u8(m02, m11);
|
|
|
|
|
const uint8x8_t m03m10 = vand_u8(m03, m10);
|
|
|
|
|
|
|
|
|
|
const uint8x8_t m01m13 = vand_u8(m01, m13);
|
|
|
|
|
const uint8x8_t m02m12 = vand_u8(m02, m12);
|
|
|
|
|
const uint8x8_t m03m11 = vand_u8(m03, m11);
|
|
|
|
|
|
|
|
|
|
const uint8x8_t m02m13 = vand_u8(m02, m13);
|
|
|
|
|
const uint8x8_t m03m12 = vand_u8(m03, m12);
|
|
|
|
|
|
|
|
|
|
const uint8x8_t m03m13 = vand_u8(m03, m13);
|
|
|
|
|
|
|
|
|
|
#define bpcnt(x) vaddv_u8(vcnt_u8(x))
|
|
|
|
|
isum[k] += (1ULL << 0)*(bpcnt(m00m10)) +
|
|
|
|
|
(1ULL << 1)*(bpcnt(m00m11) + bpcnt(m01m10)) +
|
|
|
|
|
(1ULL << 2)*(bpcnt(m00m12) + bpcnt(m01m11) + bpcnt(m02m10)) +
|
|
|
|
|
(1ULL << 3)*(bpcnt(m00m13) + bpcnt(m01m12) + bpcnt(m02m11) + bpcnt(m03m10)) +
|
|
|
|
|
(1ULL << 4)*(bpcnt(m01m13) + bpcnt(m02m12) + bpcnt(m03m11)) +
|
|
|
|
|
(1ULL << 5)*(bpcnt(m02m13) + bpcnt(m03m12)) +
|
|
|
|
|
(1ULL << 6)*(bpcnt(m03m13));
|
|
|
|
|
#undef bpcnt
|
|
|
|
|
#else
|
|
|
|
|
for (int q0 = 0; q0 < QB; q0++) {
|
|
|
|
|
for (int q1 = 1; q1 < QB; q1++) {
|
|
|
|
|
sumf[q0*QB] += sumf[q0*QB + q1];
|
|
|
|
|
const gq_quant_t mm0 = m0[q0];
|
|
|
|
|
for (int q1 = 0; q1 < QB; q1++) {
|
|
|
|
|
const gq_quant_t mm1 = m1[q1];
|
|
|
|
|
isum[k] += (1ULL << (q0 + q1))*(__builtin_popcountll(mm0 & mm1));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
*s = sumf[0];
|
|
|
|
|
for (int q0 = 1; q0 < QB; q0++) {
|
|
|
|
|
*s += sumf[q0*QB];
|
|
|
|
|
int32x4_t isumv = vld1q_s32(isum);
|
|
|
|
|
|
|
|
|
|
float32x4_t d0v = vld1q_f32(pd0 + i);
|
|
|
|
|
float32x4_t d1v = vld1q_f32(pd1 + i);
|
|
|
|
|
|
|
|
|
|
float32x4_t sumfv = vmulq_f32(d0v, d1v);
|
|
|
|
|
|
|
|
|
|
sumfv = vmulq_f32(sumfv, vcvtq_f32_s32(isumv));
|
|
|
|
|
sumf += vaddvq_f32(sumfv);
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
#error "not implemented"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
*s = sumf;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// for (int s = 0; s < nq; ++s) {
|
|
|
|
|
// const uint8x8_t m00 = vld1_u8((const uint8_t *) (pb0 + i*nq*QB + s*QB + 0));
|
|
|
|
|
// const uint8x8_t m01 = vld1_u8((const uint8_t *) (pb0 + i*nq*QB + s*QB + 1));
|
|
|
|
|
// const uint8x8_t m02 = vld1_u8((const uint8_t *) (pb0 + i*nq*QB + s*QB + 2));
|
|
|
|
|
// const uint8x8_t m03 = vld1_u8((const uint8_t *) (pb0 + i*nq*QB + s*QB + 3));
|
|
|
|
|
//
|
|
|
|
|
// const uint8x8_t m10 = vld1_u8((const uint8_t *) (pb1 + i*nq*QB + s*QB + 0));
|
|
|
|
|
// const uint8x8_t m11 = vld1_u8((const uint8_t *) (pb1 + i*nq*QB + s*QB + 1));
|
|
|
|
|
// const uint8x8_t m12 = vld1_u8((const uint8_t *) (pb1 + i*nq*QB + s*QB + 2));
|
|
|
|
|
// const uint8x8_t m13 = vld1_u8((const uint8_t *) (pb1 + i*nq*QB + s*QB + 3));
|
|
|
|
|
//
|
|
|
|
|
// isum += (1 << 0)*(vaddv_u8(vcnt_u8(vand_u8(m00, m10))));
|
|
|
|
|
// isum += (1 << 1)*(vaddv_u8(vcnt_u8(vand_u8(m00, m11))) + vaddv_u8(vcnt_u8(vand_u8(m01, m10))));
|
|
|
|
|
// isum += (1 << 2)*(vaddv_u8(vcnt_u8(vand_u8(m00, m12))) + vaddv_u8(vcnt_u8(vand_u8(m01, m11))) + vaddv_u8(vcnt_u8(vand_u8(m02, m10))));
|
|
|
|
|
// isum += (1 << 3)*(vaddv_u8(vcnt_u8(vand_u8(m00, m13))) + vaddv_u8(vcnt_u8(vand_u8(m01, m12))) + vaddv_u8(vcnt_u8(vand_u8(m02, m11))) + vaddv_u8(vcnt_u8(vand_u8(m03, m10))));
|
|
|
|
|
// isum += (1 << 4)*(vaddv_u8(vcnt_u8(vand_u8(m01, m13))) + vaddv_u8(vcnt_u8(vand_u8(m02, m12))) + vaddv_u8(vcnt_u8(vand_u8(m03, m11))));
|
|
|
|
|
// isum += (1 << 5)*(vaddv_u8(vcnt_u8(vand_u8(m02, m13))) + vaddv_u8(vcnt_u8(vand_u8(m03, m12))));
|
|
|
|
|
// isum += (1 << 6)*(vaddv_u8(vcnt_u8(vand_u8(m03, m13))));
|
|
|
|
|
// }
|
|
|
|
|
//
|
|
|
|
|
// const float d0 = GGML_GQ_TO_FP32(pd0[i]);
|
|
|
|
|
// const float d1 = GGML_GQ_TO_FP32(pd1[i]);
|
|
|
|
|
//
|
|
|
|
|
// sumf += d0*d1*isum;
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
// use vec_dot_gq_3 to compute the dot product of two rows
|
|
|
|
|
void mul_mat_gq_3(
|
|
|
|
|
const void * src0,
|
|
|
|
@ -826,14 +1021,21 @@ int main(int argc, const char ** argv) {
|
|
|
|
|
float * src1 = (float *)malloc(sizeof(float)*N*K);
|
|
|
|
|
float * dst = (float *)malloc(sizeof(float)*M*N);
|
|
|
|
|
|
|
|
|
|
// allocate aligned memory
|
|
|
|
|
//float * src0 = (float *)aligned_alloc(32, sizeof(float)*M*K);
|
|
|
|
|
//float * src1 = (float *)aligned_alloc(32, sizeof(float)*N*K);
|
|
|
|
|
//float * dst = (float *)aligned_alloc(32, sizeof(float)*M*N);
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < M*K; i++) {
|
|
|
|
|
src0[i] = 0.8 - rand() / (float)RAND_MAX;
|
|
|
|
|
/*src0[i] = rand() / (float)RAND_MAX;*/
|
|
|
|
|
src0[i] = i % 3;
|
|
|
|
|
/*src0[i] = i % 3;*/
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < N*K; i++) {
|
|
|
|
|
src1[i] = 0.8 - rand() / (float)RAND_MAX;
|
|
|
|
|
/*src1[i] = rand() / (float)RAND_MAX;*/
|
|
|
|
|
src1[i] = i % 4;
|
|
|
|
|
/*src1[i] = i % 3;*/
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void * src0_gq = NULL;
|
|
|
|
@ -931,6 +1133,34 @@ int main(int argc, const char ** argv) {
|
|
|
|
|
printf("%s: elapsed us: %d / %f ms\n", __func__, (int)(end_us - start_us), (end_us - start_us) / 1000.0 / nIter);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#if 0
|
|
|
|
|
// print src0
|
|
|
|
|
printf("src0:\n");
|
|
|
|
|
for (int i = 0; i < M; i++) {
|
|
|
|
|
for (int j = 0; j < K; j++) {
|
|
|
|
|
printf("%4.1f ", src0[i*K+j]);
|
|
|
|
|
}
|
|
|
|
|
printf("\n");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// print src1
|
|
|
|
|
printf("src1:\n");
|
|
|
|
|
for (int i = 0; i < N; i++) {
|
|
|
|
|
for (int j = 0; j < K; j++) {
|
|
|
|
|
printf("%4.1f ", src1[i*K+j]);
|
|
|
|
|
}
|
|
|
|
|
printf("\n");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
printf("dst:\n");
|
|
|
|
|
for (int i = 0; i < M; i++) {
|
|
|
|
|
for (int j = 0; j < N; j++) {
|
|
|
|
|
printf("%4.1f ", dst[i*N+j]);
|
|
|
|
|
}
|
|
|
|
|
printf("\n");
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
printf("%f\n", sum);
|
|
|
|
|
|
|
|
|
|
free(src0);
|
|
|
|
|