gq : method 4 (AVX2)

gq
Georgi Gerganov 1 year ago
parent 37dcfad83b
commit 054d97e0e1
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -46,7 +46,7 @@ if (GGML_SANITIZE_UNDEFINED)
endif()
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffast-math")
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native")
#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mcpu=native")
# dependencies

@ -13,8 +13,10 @@
#include <sys/time.h>
#ifdef __ARM_NEON
#if defined(__ARM_NEON)
#include "arm_neon.h"
#elif defined(__AVX__) || defined(__AVX2__)
#include "immintrin.h"
#endif
#ifndef MIN
@ -30,10 +32,10 @@ const int K = 1280;
//const int N = 64;
//const int K = 64;
const int QK = 32;
#define QB 2
#define QK 64
#define QB 4
//#define GGML_GQ_USE_FP16_SCALE
/*#define GGML_GQ_USE_FP16_SCALE*/
#if defined(GGML_GQ_USE_FP16_SCALE)
#define gq_scale_t ggml_fp16_t
@ -45,8 +47,8 @@ const int QK = 32;
#define GGML_GQ_TO_FP32(x) (x)
#endif
#define gq_t_bits 32
#define gq_quant_t uint32_t
#define gq_t_bits 64
#define gq_quant_t uint64_t
float frand() {
return (float) rand() / (float) RAND_MAX;
@ -312,7 +314,7 @@ void quantize_2_row(const float * restrict src, void * restrict dst, int k) {
for (int s = 0; s < nq; ++s) {
memset(pp, 0, sizeof(pp));
#if 0
#if 1
for (int l = 0; l < gq_t_bits; l++) {
const float v = src[i*QK + s*gq_t_bits + l];
const uint8_t q = (v - min)*id + frand();
@ -637,6 +639,7 @@ void mul_mat_gq_2(
//
// method 3
// (does not work)
//
static inline int quantize_3_blocks_per_row(int k) {
@ -832,7 +835,7 @@ void vec_dot_gq_3(const int n, float * restrict s, const void * restrict x, cons
const gq_quant_t * restrict pb0 = (const gq_quant_t *) (pd0 + nb);
const gq_quant_t * restrict pb1 = (const gq_quant_t *) (pd1 + nb);
#if 0
#if 1
for (int i = 0; i < nb; i++) {
int isum = 0;
@ -960,32 +963,6 @@ void vec_dot_gq_3(const int n, float * restrict s, const void * restrict x, cons
*s = sumf;
}
// for (int s = 0; s < nq; ++s) {
// const uint8x8_t m00 = vld1_u8((const uint8_t *) (pb0 + i*nq*QB + s*QB + 0));
// const uint8x8_t m01 = vld1_u8((const uint8_t *) (pb0 + i*nq*QB + s*QB + 1));
// const uint8x8_t m02 = vld1_u8((const uint8_t *) (pb0 + i*nq*QB + s*QB + 2));
// const uint8x8_t m03 = vld1_u8((const uint8_t *) (pb0 + i*nq*QB + s*QB + 3));
//
// const uint8x8_t m10 = vld1_u8((const uint8_t *) (pb1 + i*nq*QB + s*QB + 0));
// const uint8x8_t m11 = vld1_u8((const uint8_t *) (pb1 + i*nq*QB + s*QB + 1));
// const uint8x8_t m12 = vld1_u8((const uint8_t *) (pb1 + i*nq*QB + s*QB + 2));
// const uint8x8_t m13 = vld1_u8((const uint8_t *) (pb1 + i*nq*QB + s*QB + 3));
//
// isum += (1 << 0)*(vaddv_u8(vcnt_u8(vand_u8(m00, m10))));
// isum += (1 << 1)*(vaddv_u8(vcnt_u8(vand_u8(m00, m11))) + vaddv_u8(vcnt_u8(vand_u8(m01, m10))));
// isum += (1 << 2)*(vaddv_u8(vcnt_u8(vand_u8(m00, m12))) + vaddv_u8(vcnt_u8(vand_u8(m01, m11))) + vaddv_u8(vcnt_u8(vand_u8(m02, m10))));
// isum += (1 << 3)*(vaddv_u8(vcnt_u8(vand_u8(m00, m13))) + vaddv_u8(vcnt_u8(vand_u8(m01, m12))) + vaddv_u8(vcnt_u8(vand_u8(m02, m11))) + vaddv_u8(vcnt_u8(vand_u8(m03, m10))));
// isum += (1 << 4)*(vaddv_u8(vcnt_u8(vand_u8(m01, m13))) + vaddv_u8(vcnt_u8(vand_u8(m02, m12))) + vaddv_u8(vcnt_u8(vand_u8(m03, m11))));
// isum += (1 << 5)*(vaddv_u8(vcnt_u8(vand_u8(m02, m13))) + vaddv_u8(vcnt_u8(vand_u8(m03, m12))));
// isum += (1 << 6)*(vaddv_u8(vcnt_u8(vand_u8(m03, m13))));
// }
//
// const float d0 = GGML_GQ_TO_FP32(pd0[i]);
// const float d1 = GGML_GQ_TO_FP32(pd1[i]);
//
// sumf += d0*d1*isum;
// }
// use vec_dot_gq_3 to compute the dot product of two rows
void mul_mat_gq_3(
const void * src0,
@ -1009,6 +986,489 @@ void mul_mat_gq_3(
}
}
//
// method 4
// 4-bit quantization
//
static inline int quantize_4_blocks_per_row(int k) {
return k/QK;
}
static inline int quantize_4_row_size(int k) {
const int nb = quantize_4_blocks_per_row(k);
return nb*(2*sizeof(gq_scale_t) + QK/2);
}
void quantize_4_row(const float * restrict src, void * restrict dst, int k) {
assert(k % QK == 0);
assert(QB == 4);
const int nb = quantize_4_blocks_per_row(k);
gq_scale_t * restrict pm = (gq_scale_t *) (dst);
gq_scale_t * restrict pd = (gq_scale_t *) (pm + nb);
uint8_t * restrict pb = (uint8_t *) (pd + nb);
uint8_t pp[QK/2];
for (int i = 0; i < nb; i++) {
memset(pp, 0, sizeof(pp));
float min = FLT_MAX;
float max = -FLT_MAX;
#if defined(__AVX2__)
{
assert(QK < 64);
__m256 srcv[QK/8];
__m256 minv[QK/8]; // ???????
__m256 maxv[QK/8]; // ???????
for (int l = 0; l < QK/8; l++) {
srcv[l] = _mm256_loadu_ps(src + i*QK + 8*l);
}
for (int l = 0; l < QK/8/2; l++) {
minv[2*l] = _mm256_min_ps(srcv[2*l], srcv[2*l+1]);
maxv[2*l] = _mm256_max_ps(srcv[2*l], srcv[2*l+1]);
}
for (int l = 0; l < QK/8/4; l++) {
minv[4*l] = _mm256_min_ps(minv[4*l], minv[4*l+2]);
maxv[4*l] = _mm256_max_ps(maxv[4*l], maxv[4*l+2]);
}
for (int l = 0; l < QK/8/8; l++) {
minv[8*l] = _mm256_min_ps(minv[8*l], minv[8*l+4]);
maxv[8*l] = _mm256_max_ps(maxv[8*l], maxv[8*l+4]);
}
//min = MIN(minv[0][0], MIN(minv[0][1], MIN(minv[0][2], MIN(minv[0][3], MIN(minv[0][4], MIN(minv[0][5], MIN(minv[0][6], minv[0][7])))))));
//max = MAX(maxv[0][0], MAX(maxv[0][1], MAX(maxv[0][2], MAX(maxv[0][3], MAX(maxv[0][4], MAX(maxv[0][5], MAX(maxv[0][6], maxv[0][7])))))));
const __m256 minv0_0 = _mm256_permute2f128_ps(minv[0], minv[0], 3);
const __m256 minv0_1 = _mm256_min_ps(minv[0], minv0_0);
const __m256 minv0_2 = _mm256_permute_ps(minv0_1, 0x4e);
const __m256 minv0_3 = _mm256_min_ps(minv0_1, minv0_2);
const __m256 minv0_4 = _mm256_permute_ps(minv0_3, 0xb1);
const __m256 minv0_5 = _mm256_min_ps(minv0_3, minv0_4);
const __m256 maxv0_0 = _mm256_permute2f128_ps(maxv[0], maxv[0], 3);
const __m256 maxv0_1 = _mm256_max_ps(maxv[0], maxv0_0);
const __m256 maxv0_2 = _mm256_permute_ps(maxv0_1, 0x4e);
const __m256 maxv0_3 = _mm256_max_ps(maxv0_1, maxv0_2);
const __m256 maxv0_4 = _mm256_permute_ps(maxv0_3, 0xb1);
const __m256 maxv0_5 = _mm256_max_ps(maxv0_3, maxv0_4);
min = _mm256_cvtss_f32(minv0_5);
max = _mm256_cvtss_f32(maxv0_5);
const float d = (max - min) / ((1 << QB) - 2);
const float id = d ? 1.0/d : 0.0;
pm[i] = GGML_FP32_TO_GQ(min);
pd[i] = GGML_FP32_TO_GQ(d);
const __m256 idv = _mm256_set1_ps(id);
for (int l = 0; l < QK/8; l++) {
__m256 v = _mm256_mul_ps(_mm256_sub_ps(srcv[l], _mm256_set1_ps(min)), idv);
#if 0
v[0] += frand(); v[1] += frand(); v[2] += frand(); v[3] += frand();
v[4] += frand(); v[5] += frand(); v[6] += frand(); v[7] += frand();
#endif
// convert to uint8
__m256i vi = _mm256_cvtps_epi32(v);
uint32_t vi_0 = _mm256_extract_epi32(vi, 0);
uint32_t vi_1 = _mm256_extract_epi32(vi, 1);
uint32_t vi_2 = _mm256_extract_epi32(vi, 2);
uint32_t vi_3 = _mm256_extract_epi32(vi, 3);
uint32_t vi_4 = _mm256_extract_epi32(vi, 4);
uint32_t vi_5 = _mm256_extract_epi32(vi, 5);
uint32_t vi_6 = _mm256_extract_epi32(vi, 6);
uint32_t vi_7 = _mm256_extract_epi32(vi, 7);
// convert to 4-bit, 2 consecutive packed into 1 byte
pp[4*l + 0] = vi_0 | (vi_1 << 4);
pp[4*l + 1] = vi_2 | (vi_3 << 4);
pp[4*l + 2] = vi_4 | (vi_5 << 4);
pp[4*l + 3] = vi_6 | (vi_7 << 4);
//printf("vi: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7);
//printf("v : %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]);
}
memcpy(pb + i*QK/2, pp, sizeof(pp));
}
#elif defined(__ARM_NEON)
{
// TODO
}
#else
{
for (int l = 0; l < QK; l++) {
const float v = src[i*QK + l];
if (v < min) min = v;
if (v > max) max = v;
}
const float d = (max - min) / ((1 << QB) - 1);
const float id = d ? 1.0/d : 0.0;
pm[i] = GGML_FP32_TO_GQ(min);
pd[i] = GGML_FP32_TO_GQ(d);
for (int l = 0; l < QK; l++) {
const float v = (src[i*QK + l] - min) * id;
const uint8_t vi = (uint8_t) (v + frand());
pp[l/2] |= (vi & 0xf) << (4*(l & 1));
}
memcpy(pb + i*QK/2, pp, sizeof(pp));
}
#endif
//printf("min %f max %f\n", min, max);
}
}
// reimplementation of quantize_4 using quantize_4_row
void quantize_4(const float * restrict src, char * restrict dst, int n, int k) {
assert(k % QK == 0);
for (int j = 0; j < n; j++) {
quantize_4_row(src + j*k, dst, k);
dst = (char *) dst + quantize_4_row_size(k);
}
}
// horizontally reduce 8 32-bit integers
static inline int32_t _mm256_hadd_epi32_gg(__m256i v) {
__m128i v0 = _mm256_extractf128_si256(v, 0);
__m128i v1 = _mm256_extractf128_si256(v, 1);
v0 = _mm_add_epi32(v0, v1);
v1 = _mm_shuffle_epi32(v0, 0x0e);
v0 = _mm_add_epi32(v0, v1);
v1 = _mm_shuffle_epi32(v0, 0x01);
v0 = _mm_add_epi32(v0, v1);
return _mm_cvtsi128_si32(v0);
}
//static inline float _mm256_hadd_epi32_gg(__m256i v) {
// const __m256 v0 = _mm256_cvtepi32_ps(v);
// const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(v0), _mm256_extractf128_ps(v0, 1));
// const __m128 t1 = _mm_hadd_ps(t0, t0);
//
// return _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
//}
// horizontally reduce 32 8-bit integers
static inline int32_t _mm256_hadd_epi8_gg(__m256i v0) {
__m256i v1 = _mm256_maddubs_epi16(v0, _mm256_set1_epi8(1));
__m256i v2 = _mm256_madd_epi16 (v1, _mm256_set1_epi16(1));
return _mm256_hadd_epi32_gg(v2);
}
void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
const int nb = quantize_4_blocks_per_row(n);
const gq_scale_t * restrict pm0 = (const gq_scale_t *) x;
const gq_scale_t * restrict pm1 = (const gq_scale_t *) y;
const gq_scale_t * restrict pd0 = pm0 + nb;
const gq_scale_t * restrict pd1 = pm1 + nb;
const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb);
const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb);
float sumf = 0.0;
// scalar
//for (int i = 0; i < nb; i++) {
// const float m0 = GGML_GQ_TO_FP32(pm0[i]);
// const float d0 = GGML_GQ_TO_FP32(pd0[i]);
// const float m1 = GGML_GQ_TO_FP32(pm1[i]);
// const float d1 = GGML_GQ_TO_FP32(pd1[i]);
// const uint8_t * restrict p0 = pb0 + i*QK/2;
// const uint8_t * restrict p1 = pb1 + i*QK/2;
// for (int j = 0; j < QK/2; j++) {
// const uint8_t v0 = p0[j];
// const uint8_t v1 = p1[j];
// const float f0 = d0*(v0 & 0xf) + m0;
// const float f1 = d0*(v0 >> 4) + m0;
// const float f2 = d1*(v1 & 0xf) + m1;
// const float f3 = d1*(v1 >> 4) + m1;
// sumf += f0*f2 + f1*f3;
// }
//}
#if defined(__AVX2__)
#if QK == 64 && 0
__m256 sumv0 = _mm256_setzero_ps();
__m256 sumv1 = _mm256_setzero_ps();
for (int i = 0; i < nb; i++) {
const float m0 = GGML_GQ_TO_FP32(pm0[i]);
const float d0 = GGML_GQ_TO_FP32(pd0[i]);
const float m1 = GGML_GQ_TO_FP32(pm1[i]);
const float d1 = GGML_GQ_TO_FP32(pd1[i]);
const uint8_t * restrict p0 = pb0 + i*QK/2;
const uint8_t * restrict p1 = pb1 + i*QK/2;
const __m256 m0v = _mm256_set1_ps(m0);
const __m256 d0v = _mm256_set1_ps(d0);
const __m256 m1v = _mm256_set1_ps(m1);
const __m256 d1v = _mm256_set1_ps(d1);
const __m256i m4b = _mm256_set1_epi8(0xf);
__m256i v0 = _mm256_loadu_si256((__m256i *) p0);
//_mm_prefetch((const char *) (p0 + 32), _MM_HINT_T0);
//_mm_prefetch((const char *) (p1 + 32), _MM_HINT_T0);
//_mm_prefetch((const char *) (pm0 + i + 1), _MM_HINT_T0);
//_mm_prefetch((const char *) (pm1 + i + 1), _MM_HINT_T0);
//_mm_prefetch((const char *) (pd0 + i + 1), _MM_HINT_T0);
//_mm_prefetch((const char *) (pd1 + i + 1), _MM_HINT_T0);
__m256i v00 = _mm256_and_si256(v0, _mm256_set1_epi32(0x000000FF));
__m256i v01 = _mm256_srli_epi32(_mm256_and_si256(v0, _mm256_set1_epi32(0x0000FFFF)), 8);
__m256i v02 = _mm256_srli_epi32(_mm256_and_si256(v0, _mm256_set1_epi32(0x00FFFFFF)), 16);
__m256i v03 = _mm256_srli_epi32(v0, 24);
//////////////////////
//{
// uint32_t vi_0 = _mm256_extract_epi32(v00, 0);
// uint32_t vi_1 = _mm256_extract_epi32(v00, 1);
// uint32_t vi_2 = _mm256_extract_epi32(v00, 2);
// uint32_t vi_3 = _mm256_extract_epi32(v00, 3);
// uint32_t vi_4 = _mm256_extract_epi32(v00, 4);
// uint32_t vi_5 = _mm256_extract_epi32(v00, 5);
// uint32_t vi_6 = _mm256_extract_epi32(v00, 6);
// uint32_t vi_7 = _mm256_extract_epi32(v00, 7);
// printf("v0: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7);
// printf("p0: %7d %7d %7d %7d %7d %7d %7d %7d\n", p0[0], p0[4], p0[8], p0[12], p0[16], p0[20], p0[24], p0[28]);
// printf("p1: %7d %7d %7d %7d %7d %7d %7d %7d\n", p0[1], p0[5], p0[9], p0[13], p0[17], p0[21], p0[25], p0[29]);
// printf("p2: %7d %7d %7d %7d %7d %7d %7d %7d\n", p0[2], p0[6], p0[10], p0[14], p0[18], p0[22], p0[26], p0[30]);
// printf("p3: %7d %7d %7d %7d %7d %7d %7d %7d\n", p0[3], p0[7], p0[11], p0[15], p0[19], p0[23], p0[27], p0[31]);
//}
// compute 32 x 4-bit values (low and high)
__m256i v00l = _mm256_and_si256(v00, m4b);
__m256i v01l = _mm256_and_si256(v01, m4b);
__m256i v02l = _mm256_and_si256(v02, m4b);
__m256i v03l = _mm256_and_si256(v03, m4b);
__m256i v00h = _mm256_srli_epi32(v00, 4);
__m256i v01h = _mm256_srli_epi32(v01, 4);
__m256i v02h = _mm256_srli_epi32(v02, 4);
__m256i v03h = _mm256_srli_epi32(v03, 4);
//{
// uint32_t vi_0 = _mm256_extract_epi32(v00l, 0);
// uint32_t vi_1 = _mm256_extract_epi32(v00l, 1);
// uint32_t vi_2 = _mm256_extract_epi32(v00l, 2);
// uint32_t vi_3 = _mm256_extract_epi32(v00l, 3);
// uint32_t vi_4 = _mm256_extract_epi32(v00l, 4);
// uint32_t vi_5 = _mm256_extract_epi32(v00l, 5);
// uint32_t vi_6 = _mm256_extract_epi32(v00l, 6);
// uint32_t vi_7 = _mm256_extract_epi32(v00l, 7);
// printf("v0l: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7);
// vi_0 = _mm256_extract_epi32(v00h, 0);
// vi_1 = _mm256_extract_epi32(v00h, 1);
// vi_2 = _mm256_extract_epi32(v00h, 2);
// vi_3 = _mm256_extract_epi32(v00h, 3);
// vi_4 = _mm256_extract_epi32(v00h, 4);
// vi_5 = _mm256_extract_epi32(v00h, 5);
// vi_6 = _mm256_extract_epi32(v00h, 6);
// vi_7 = _mm256_extract_epi32(v00h, 7);
// printf("v0h: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7);
//}
// convert to float
__m256 vf00l = _mm256_cvtepi32_ps(v00l);
__m256 vf01l = _mm256_cvtepi32_ps(v01l);
__m256 vf02l = _mm256_cvtepi32_ps(v02l);
__m256 vf03l = _mm256_cvtepi32_ps(v03l);
__m256 vf00h = _mm256_cvtepi32_ps(v00h);
__m256 vf01h = _mm256_cvtepi32_ps(v01h);
__m256 vf02h = _mm256_cvtepi32_ps(v02h);
__m256 vf03h = _mm256_cvtepi32_ps(v03h);
//{
// printf("vf00l: %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", vf00l[0], vf00l[1], vf00l[2], vf00l[3], vf00l[4], vf00l[5], vf00l[6], vf00l[7]);
// printf("vf01l: %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", vf01l[0], vf01l[1], vf01l[2], vf01l[3], vf01l[4], vf01l[5], vf01l[6], vf01l[7]);
// printf("vf02l: %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", vf02l[0], vf02l[1], vf02l[2], vf02l[3], vf02l[4], vf02l[5], vf02l[6], vf02l[7]);
// printf("vf03l: %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", vf03l[0], vf03l[1], vf03l[2], vf03l[3], vf03l[4], vf03l[5], vf03l[6], vf03l[7]);
//}
// multiply by scale and add offset
vf00l = _mm256_fmadd_ps(vf00l, d0v, m0v);
vf01l = _mm256_fmadd_ps(vf01l, d0v, m0v);
vf02l = _mm256_fmadd_ps(vf02l, d0v, m0v);
vf03l = _mm256_fmadd_ps(vf03l, d0v, m0v);
vf00h = _mm256_fmadd_ps(vf00h, d0v, m0v);
vf01h = _mm256_fmadd_ps(vf01h, d0v, m0v);
vf02h = _mm256_fmadd_ps(vf02h, d0v, m0v);
vf03h = _mm256_fmadd_ps(vf03h, d0v, m0v);
__m256i v1 = _mm256_loadu_si256((__m256i *) p1);
__m256i v10 = _mm256_and_si256(v1, _mm256_set1_epi32(0x000000FF));
__m256i v11 = _mm256_srli_epi32(_mm256_and_si256(v1, _mm256_set1_epi32(0x0000FFFF)), 8);
__m256i v12 = _mm256_srli_epi32(_mm256_and_si256(v1, _mm256_set1_epi32(0x00FFFFFF)), 16);
__m256i v13 = _mm256_srli_epi32(v1, 24);
__m256i v10l = _mm256_and_si256(v10, m4b);
__m256i v11l = _mm256_and_si256(v11, m4b);
__m256i v12l = _mm256_and_si256(v12, m4b);
__m256i v13l = _mm256_and_si256(v13, m4b);
__m256i v10h = _mm256_srli_epi32(v10, 4);
__m256i v11h = _mm256_srli_epi32(v11, 4);
__m256i v12h = _mm256_srli_epi32(v12, 4);
__m256i v13h = _mm256_srli_epi32(v13, 4);
__m256 vf10l = _mm256_cvtepi32_ps(v10l);
__m256 vf11l = _mm256_cvtepi32_ps(v11l);
__m256 vf12l = _mm256_cvtepi32_ps(v12l);
__m256 vf13l = _mm256_cvtepi32_ps(v13l);
__m256 vf10h = _mm256_cvtepi32_ps(v10h);
__m256 vf11h = _mm256_cvtepi32_ps(v11h);
__m256 vf12h = _mm256_cvtepi32_ps(v12h);
__m256 vf13h = _mm256_cvtepi32_ps(v13h);
vf10l = _mm256_fmadd_ps(vf10l, d1v, m1v);
vf11l = _mm256_fmadd_ps(vf11l, d1v, m1v);
vf12l = _mm256_fmadd_ps(vf12l, d1v, m1v);
vf13l = _mm256_fmadd_ps(vf13l, d1v, m1v);
vf10h = _mm256_fmadd_ps(vf10h, d1v, m1v);
vf11h = _mm256_fmadd_ps(vf11h, d1v, m1v);
vf12h = _mm256_fmadd_ps(vf12h, d1v, m1v);
vf13h = _mm256_fmadd_ps(vf13h, d1v, m1v);
// compute dot product
sumv0 = _mm256_fmadd_ps(vf00l, vf10l, sumv0);
sumv0 = _mm256_fmadd_ps(vf01l, vf11l, sumv0);
sumv0 = _mm256_fmadd_ps(vf02l, vf12l, sumv0);
sumv0 = _mm256_fmadd_ps(vf03l, vf13l, sumv0);
sumv1 = _mm256_fmadd_ps(vf00h, vf10h, sumv1);
sumv1 = _mm256_fmadd_ps(vf01h, vf11h, sumv1);
sumv1 = _mm256_fmadd_ps(vf02h, vf12h, sumv1);
sumv1 = _mm256_fmadd_ps(vf03h, vf13h, sumv1);
}
// accumulate (horizontal sum)
const __m256 vdot = _mm256_add_ps(sumv0, sumv1);
const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(vdot), _mm256_extractf128_ps(vdot, 1));
const __m128 t1 = _mm_hadd_ps(t0, t0);
sumf += _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
#elif QK == 64
float sum00 = 0.0f;
float sum01 = 0.0f;
float sum10 = 0.0f;
float sum11 = 0.0f;
for (int i = 0; i < nb; i++) {
const float m0 = GGML_GQ_TO_FP32(pm0[i]);
const float d0 = GGML_GQ_TO_FP32(pd0[i]);
const float m1 = GGML_GQ_TO_FP32(pm1[i]);
const float d1 = GGML_GQ_TO_FP32(pd1[i]);
const uint8_t * restrict p0 = pb0 + i*QK/2;
const uint8_t * restrict p1 = pb1 + i*QK/2;
const __m256i m4b = _mm256_set1_epi8(0xf);
const __m256i v0 = _mm256_loadu_si256((__m256i *) p0); // 64 x 4
const __m256i v1 = _mm256_loadu_si256((__m256i *) p1);
const __m256i v0l = _mm256_and_si256(v0, m4b); // 32 x 8
const __m256i v1l = _mm256_and_si256(v1, m4b); //
const __m256i v0h = _mm256_and_si256(_mm256_srli_epi16(v0, 4), m4b);
const __m256i v1h = _mm256_and_si256(_mm256_srli_epi16(v1, 4), m4b);
const __m256i pl = _mm256_maddubs_epi16(v0l, v1l);
const __m256i ph = _mm256_maddubs_epi16(v0h, v1h);
const __m256i p16 = _mm256_add_epi16(ph, pl);
const __m256i p = _mm256_madd_epi16(_mm256_set1_epi16(1), p16);
sum00 += m0*m1;
sum01 += m1*d0*(_mm256_hadd_epi8_gg(_mm256_add_epi8(v0l, v0h)));
sum10 += m0*d1*(_mm256_hadd_epi8_gg(_mm256_add_epi8(v1l, v1h)));
sum11 += d0*d1*(_mm256_hadd_epi32_gg(p));
}
sumf = 64.0*sum00 + sum01 + sum10 + sum11;
#else
#error "QK must be 64"
#endif
#endif
*s = sumf;
}
// use vec_dot_gq_4 to compute the dot product of two rows
void mul_mat_gq_4(
const void * src0,
const void * src1, // transposed
float * dst,
int m, int n, int k) {
assert(k % QK == 0);
// needed to initialize f16 tables
{
struct ggml_init_params params;
struct ggml_context * ctx = ggml_init(params);
ggml_free(ctx);
}
const int nb = quantize_4_blocks_per_row(k);
for (int ir0 = 0; ir0 < m; ir0++) {
for (int ir1 = 0; ir1 < n; ir1++) {
vec_dot_gq_4(k, dst + ir1, src0, src1);
src1 = (const char *) src1 + quantize_4_row_size(k);
}
src0 = (const char *) src0 + quantize_4_row_size(k);
src1 = (const char *) src1 - n*quantize_4_row_size(k);
dst = (float *) dst + n;
}
}
int main(int argc, const char ** argv) {
assert(sizeof(gq_quant_t)*8 == gq_t_bits);
@ -1029,7 +1489,7 @@ int main(int argc, const char ** argv) {
for (int i = 0; i < M*K; i++) {
src0[i] = 0.8 - rand() / (float)RAND_MAX;
/*src0[i] = rand() / (float)RAND_MAX;*/
/*src0[i] = i % 3;*/
/*src0[i] = i % 2;*/
}
for (int i = 0; i < N*K; i++) {
@ -1064,6 +1524,13 @@ int main(int argc, const char ** argv) {
sizegq = quantize_3_row_size(K)*M + quantize_3_row_size(K)*N;
}
if (method == 4) {
src0_gq = calloc(1, quantize_4_row_size(K)*M);
src1_gq = calloc(1, quantize_4_row_size(K)*N);
sizegq = quantize_4_row_size(K)*M + quantize_4_row_size(K)*N;
}
}
const size_t sizef16 = sizeof(ggml_fp16_t)*M*K + sizeof(ggml_fp16_t)*N*K;
@ -1089,6 +1556,11 @@ int main(int argc, const char ** argv) {
quantize_3(src1, src1_gq, N, K);
}
if (method == 4) {
quantize_4(src0, src0_gq, M, K);
quantize_4(src1, src1_gq, N, K);
}
const uint64_t t_end = get_time_us();
printf("convert time: %f ms / method = %d\n", (t_end - t_start) / 1000.0, method);
}
@ -1120,6 +1592,10 @@ int main(int argc, const char ** argv) {
if (method == 3) {
mul_mat_gq_3(src0_gq, src1_gq, dst, M, N, K);
}
if (method == 4) {
mul_mat_gq_4(src0_gq, src1_gq, dst, M, N, K);
}
}
for (int i = 0; i < N; i++) {

Loading…
Cancel
Save