From cc94fdafe7bc9e570fd122c2021c318141b4ff06 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 25 Feb 2023 17:51:27 +0200 Subject: [PATCH] ggml : 4-bit quantization works (only scalar for now) --- examples/gpt-2/convert-ckpt-to-ggml.py | 190 +------------- examples/gpt-2/quantize.cpp | 4 +- src/ggml.c | 347 +++++-------------------- 3 files changed, 76 insertions(+), 465 deletions(-) diff --git a/examples/gpt-2/convert-ckpt-to-ggml.py b/examples/gpt-2/convert-ckpt-to-ggml.py index b2f06e2..0d018ca 100644 --- a/examples/gpt-2/convert-ckpt-to-ggml.py +++ b/examples/gpt-2/convert-ckpt-to-ggml.py @@ -51,190 +51,6 @@ def convert_to_ftype(data, ftype): if ftype == 1: return data.astype(np.float16) - # qint4_0 - # C code: - # { - # for (int l = 0; l < QK; l++) { - # const float v = src[i*QK + l]; - # amax = MAX(amax, fabsf(v)); - # } - # - # const float d = amax / ((1 << (QB - 1)) - 1); - # const float id = d ? 1.0/d : 0.0; - # - # pd[i] = GGML_FP32_TO_GQ(d); - # - # for (int l = 0; l < QK; l++) { - # const float v = src[i*QK + l]*id; - # const int8_t vi = ((int8_t) (round(v))) + 8; - # assert(vi >= 0 && vi < 16); - # pp[l/2] |= (vi & 0xf) << (4*(l & 1)); - # } - # - # memcpy(pb + i*QK/2, pp, sizeof(pp)); - # } - if ftype == 2: - assert data.dtype == np.float32 - assert data.shape[-1] % 64 == 0 - - # create 2 new arrays: - # - pd: float32 (lowest dimension is data.shape[-1] // 64) - # - pb: int8 - pd = np.zeros(data.shape[:-1] + (data.shape[-1] // 64,), dtype=np.float32) - pb = np.zeros(data.shape[:-1] + (data.shape[-1], ), dtype=np.int8) - - # the quantized data goes here - dst = np.zeros((data.size // 64) * (4 + 32), dtype=np.uint8) - - print("data:", data.shape, data.size) - print("pd: ", pd.shape, pd.size) - print("pb: ", pb.shape, pb.size) - print("dst: ", dst.shape, dst.size) - - for i in range(0, data.shape[-1], 64): - max_abs = np.max(np.abs(data[..., i:i+64])) - max_q = (1 << 3) - 1 - d = max_abs / max_q - id = 1.0 / d if d != 0 else 0.0 - pd[..., i//64] = d - - for j in range(64): - v = data[..., i+j] * id - vi = np.round(v).astype(np.int8) + 8 - assert np.all(vi >= 0) and np.all(vi < 16) - - #ve = vi[...,(j & 1) == 0].reshape(-1, 1) - - #print("ve:", ve.shape, ve) - #print("vo:", vo.shape, vo) - #print("pb:", pb[..., (i+j)//2].shape, pb[..., (i+j)//2]) - - pb[..., i+j] = vi - - # convert to 1D array - pd = pd.reshape(-1, 1) - pb = pb.reshape(-1, 1) - - # populate the destination array - n = data.size - nr = data.shape[-1] - nn = nr//64 - for i in range(0, n, nr): - for j in range(0, nr, 64): - d = pd[(i//nr)*nn + j//64][0] - b = pb[i+j:i+j+64].reshape(-1) - - db = struct.unpack("4B", struct.pack("f", d)) - dst[(i//nr)*nn*36 + (j//64)*4 + 0] = db[0] - dst[(i//nr)*nn*36 + (j//64)*4 + 1] = db[1] - dst[(i//nr)*nn*36 + (j//64)*4 + 2] = db[2] - dst[(i//nr)*nn*36 + (j//64)*4 + 3] = db[3] - for k in range(32): - dst[(i//nr)*nn*36 + nn*4 + (j//64)*32 + k] = b[2*k] | (b[2*k+1] << 4) - - return dst - - # qint4_1 - # C code: - # { - # for (int l = 0; l < QK; l++) { - # const float v = src[i*QK + l]; - # if (v < min) min = v; - # if (v > max) max = v; - # } - - # const float d = (max - min) / ((1 << QB) - 1); - # const float id = d ? 1.0/d : 0.0; - - # pm[i] = GGML_FP32_TO_GQ(min); - # pd[i] = GGML_FP32_TO_GQ(d); - - # for (int l = 0; l < QK; l++) { - # const float v = (src[i*QK + l] - min) * id; - # const uint8_t vi = (uint8_t) (v + frand()); - # pp[l/2] |= (vi & 0xf) << (4*(l & 1)); - # } - - # memcpy(pb + i*QK/2, pp, sizeof(pp)); - # } - if ftype == 3: - assert data.dtype == np.float32 - assert data.shape[-1] % 64 == 0 - - # create 2 new arrays: - # - pd: float32 (lowest dimension is data.shape[-1] // 64) - # - pb: int8 - pm = np.zeros(data.shape[:-1] + (data.shape[-1] // 64,), dtype=np.float32) - pd = np.zeros(data.shape[:-1] + (data.shape[-1] // 64,), dtype=np.float32) - pb = np.zeros(data.shape[:-1] + (data.shape[-1], ), dtype=np.int8) - - # the quantized data goes here - dst = np.zeros((data.size // 64) * (4 + 4 + 32), dtype=np.uint8) - - print("data:", data.shape, data.size) - print("pm: ", pm.shape, pm.size) - print("pd: ", pd.shape, pd.size) - print("pb: ", pb.shape, pb.size) - print("dst: ", dst.shape, dst.size) - - for i in range(0, data.shape[-1], 64): - mmin = np.min(data[..., i:i+64]) - mmax = np.max(data[..., i:i+64]) - max_q = (1 << 4) - 1 - d = (mmax - mmin) / max_q - id = 1.0 / d if d != 0 else 0.0 - - pm[..., i//64] = mmin - pd[..., i//64] = d - - for j in range(64): - v = (data[..., i+j] - mmin) * id - vi = np.round(v).astype(np.uint8) - assert np.all(vi >= 0) and np.all(vi < 16) - - pb[..., i+j] = vi - - # convert to 1D array - pm = pm.reshape(-1, 1) - pd = pd.reshape(-1, 1) - pb = pb.reshape(-1, 1) - - # populate the destination array - n = data.size - nr = data.shape[-1] - nn = nr//64 - for i in range(0, n, nr): - for j in range(0, nr, 64): - m = pm[(i//nr)*nn + j//64][0] - - idx = (i//nr)*nn*40 + (j//64)*4 - - mb = struct.unpack("4B", struct.pack("f", m)) - dst[idx + 0] = mb[0] - dst[idx + 1] = mb[1] - dst[idx + 2] = mb[2] - dst[idx + 3] = mb[3] - - for j in range(0, nr, 64): - d = pd[(i//nr)*nn + j//64][0] - - idx = (i//nr)*nn*40 + 4*nn + (j//64)*4 - - db = struct.unpack("4B", struct.pack("f", d)) - dst[idx + 0] = db[0] - dst[idx + 1] = db[1] - dst[idx + 2] = db[2] - dst[idx + 3] = db[3] - - for j in range(0, nr, 64): - b = pb[i+j:i+j+64].reshape(-1) - - idx = (i//nr)*nn*40 + nn*8 + (j//64)*32 - for k in range(32): - dst[idx + k] = b[2*k] | (b[2*k+1] << 4) - - return dst - assert False, "Invalid ftype: " + str(ftype) if len(sys.argv) < 2: @@ -258,12 +74,12 @@ with open(dir_model + "/hparams.json", "r") as f: # ftype == 3 -> qint4_1 # # map from ftype to string -ftype_str = ["f32", "f16", "q4_0", "q4_1"] +ftype_str = ["f32", "f16"] ftype = 1 if len(sys.argv) > 2: ftype = int(sys.argv[2]) - if ftype < 0 or ftype > 3: + if ftype < 0 or ftype > 1: print("Invalid ftype: " + str(ftype)) sys.exit(1) fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" @@ -312,8 +128,6 @@ for name, shape in list_vars: # "model/h.*/mlp/c_fc/w" # "model/h.*/mlp/c_proj/w" if name == "model/wte" or name[-2:] == "/w": - #if name[-6:] == "attn/w": - #if name == "model/wte": print(" Converting to " + ftype_str[ftype]) data = convert_to_ftype(data, ftype) ftype_cur = ftype diff --git a/examples/gpt-2/quantize.cpp b/examples/gpt-2/quantize.cpp index 7c9d44d..128cd5d 100644 --- a/examples/gpt-2/quantize.cpp +++ b/examples/gpt-2/quantize.cpp @@ -12,8 +12,9 @@ #include #include +#define QK 64 + size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k) { - const int QK = 64; const int nb = k / QK; const size_t row_size = nb*(sizeof(float) + sizeof(uint8_t)*QK/2); @@ -63,7 +64,6 @@ size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k) { } size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k) { - const int QK = 64; const int nb = k / QK; const size_t row_size = nb*(2*sizeof(float) + sizeof(uint8_t)*QK/2); diff --git a/src/ggml.c b/src/ggml.c index d6cce7b..7378af2 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -352,106 +352,27 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); // quantization // +#define QK 64 + // method 5 -// blocks of 64 elements -// represented with a single float (delta) and 32 8-bit ints (i.e 64 4-bit signed integer factors) +// blocks of QK elements +// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors) void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { - assert(k % 64 == 0); + assert(k % QK == 0); - const int nb = k / 64; + const int nb = k / QK; float * restrict pd = (float *) (y); uint8_t * restrict pb = (uint8_t *) (pd + nb); - uint8_t pp[32]; + uint8_t pp[QK/2]; for (int i = 0; i < nb; i++) { float amax = 0.0f; // absolute max -#if defined(__AVX2__) { - __m256 srcv [8]; - __m256 asrcv[8]; - __m256 amaxv[8]; - - // TODO: unroll these loops (maybe not? this is more compact) - for (int l = 0; l < 8; l++) srcv[l] = _mm256_loadu_ps(x + i*64 + 8*l); - for (int l = 0; l < 8; l++) asrcv[l] = _mm256_and_ps(srcv[l], (__m256) _mm256_set1_epi32(0x7fffffff)); - for (int l = 0; l < 4; l++) amaxv[2*l] = _mm256_max_ps(asrcv[2*l], asrcv[2*l+1]); - for (int l = 0; l < 2; l++) amaxv[4*l] = _mm256_max_ps(amaxv[4*l], amaxv[4*l+2]); - for (int l = 0; l < 1; l++) amaxv[8*l] = _mm256_max_ps(amaxv[8*l], amaxv[8*l+4]); - - const __m256 amaxv0_0 = _mm256_permute2f128_ps(amaxv[0], amaxv[0], 3); - const __m256 amaxv0_1 = _mm256_max_ps(amaxv[0], amaxv0_0); - const __m256 amaxv0_2 = _mm256_permute_ps(amaxv0_1, 0x4e); - const __m256 amaxv0_3 = _mm256_max_ps(amaxv0_1, amaxv0_2); - const __m256 amaxv0_4 = _mm256_permute_ps(amaxv0_3, 0xb1); - const __m256 amaxv0_5 = _mm256_max_ps(amaxv0_3, amaxv0_4); - - amax = _mm256_cvtss_f32(amaxv0_5); - - //printf("amax = %f\n", amax); - - const float d = amax / ((1 << 3) - 1); - const float id = d ? 1.0f/d : 0.0f; - - pd[i] = d; - - const __m256 idv = _mm256_set1_ps(id); - - for (int l = 0; l < 8; l++) { - __m256 v = _mm256_mul_ps(srcv[l], idv); - -#if 0 - v[0] += frand(); v[1] += frand(); v[2] += frand(); v[3] += frand(); - v[4] += frand(); v[5] += frand(); v[6] += frand(); v[7] += frand(); -#endif - - // convert to int8 - __m256i vi = _mm256_cvtps_epi32(v); - vi = _mm256_add_epi32(vi, _mm256_set1_epi32(8)); - - int32_t vi_0 = _mm256_extract_epi32(vi, 0); - int32_t vi_1 = _mm256_extract_epi32(vi, 1); - int32_t vi_2 = _mm256_extract_epi32(vi, 2); - int32_t vi_3 = _mm256_extract_epi32(vi, 3); - - int32_t vi_4 = _mm256_extract_epi32(vi, 4); - int32_t vi_5 = _mm256_extract_epi32(vi, 5); - int32_t vi_6 = _mm256_extract_epi32(vi, 6); - int32_t vi_7 = _mm256_extract_epi32(vi, 7); - - // convert to 4-bit, 2 consecutive packed into 1 byte - pp[4*l + 0] = vi_0 | (vi_1 << 4); - pp[4*l + 1] = vi_2 | (vi_3 << 4); - pp[4*l + 2] = vi_4 | (vi_5 << 4); - pp[4*l + 3] = vi_6 | (vi_7 << 4); - - //printf("vi: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7); - //printf("v : %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]); - - assert(vi_0 >= 0 && vi_0 < 16); - assert(vi_1 >= 0 && vi_1 < 16); - assert(vi_2 >= 0 && vi_2 < 16); - assert(vi_3 >= 0 && vi_3 < 16); - - assert(vi_4 >= 0 && vi_4 < 16); - assert(vi_5 >= 0 && vi_5 < 16); - assert(vi_6 >= 0 && vi_6 < 16); - assert(vi_7 >= 0 && vi_7 < 16); - } - - memcpy(pb + i*32, pp, sizeof(pp)); - } -#elif defined(__ARM_NEON) && 0 - { - // TODO -#pragma warning "implement me !!" - } -#else - { - for (int l = 0; l < 64; l++) { - const float v = x[i*64 + l]; + for (int l = 0; l < QK; l++) { + const float v = x[i*QK + l]; amax = MAX(amax, fabsf(v)); } @@ -460,9 +381,9 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { pd[i] = d; - for (int l = 0; l < 64; l += 2) { - const float v0 = x[i*64 + l + 0]*id; - const float v1 = x[i*64 + l + 1]*id; + for (int l = 0; l < QK; l += 2) { + const float v0 = x[i*QK + l + 0]*id; + const float v1 = x[i*QK + l + 1]*id; const uint8_t vi0 = ((int8_t) (round(v0))) + 8; const uint8_t vi1 = ((int8_t) (round(v1))) + 8; @@ -473,33 +394,32 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { pp[l/2] = vi0 | (vi1 << 4); } - memcpy(pb + i*32, pp, sizeof(pp)); + memcpy(pb + i*QK/2, pp, sizeof(pp)); } -#endif } } // method 4 -// blocks of 64 elements -// represented with 2 floats (min + delta) and 32 8-bit ints (i.e 64 4-bit unsigned integer factors) +// blocks of QK elements +// represented with 2 floats (min + delta) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors) void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { - assert(k % 64 == 0); + assert(k % QK == 0); - const int nb = k / 64; + const int nb = k / QK; float * restrict pm = (float *) (y); float * restrict pd = (float *) (pm + nb); uint8_t * restrict pb = (uint8_t *) (pd + nb); - uint8_t pp[32]; + uint8_t pp[QK/2]; for (int i = 0; i < nb; i++) { float min = FLT_MAX; float max = -FLT_MAX; { - for (int l = 0; l < 64; l++) { - const float v = x[i*64 + l]; + for (int l = 0; l < QK; l++) { + const float v = x[i*QK + l]; if (v < min) min = v; if (v > max) max = v; } @@ -510,9 +430,9 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { pm[i] = min; pd[i] = d; - for (int l = 0; l < 64; l += 2) { - const float v0 = (x[i*64 + l + 0] - min)*id; - const float v1 = (x[i*64 + l + 1] - min)*id; + for (int l = 0; l < QK; l += 2) { + const float v0 = (x[i*QK + l + 0] - min)*id; + const float v1 = (x[i*QK + l + 1] - min)*id; const uint8_t vi0 = round(v0); const uint8_t vi1 = round(v1); @@ -523,15 +443,15 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { pp[l/2] = vi0 | (vi1 << 4); } - memcpy(pb + i*32, pp, sizeof(pp)); + memcpy(pb + i*QK/2, pp, sizeof(pp)); } } } void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { - assert(k % 64 == 0); + assert(k % QK == 0); - const int nb = k / 64; + const int nb = k / QK; const float * restrict pd = (const float *) (x); const uint8_t * restrict pb = (const uint8_t *) (pd + nb); @@ -539,9 +459,9 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { for (int i = 0; i < nb; i++) { const float d = pd[i]; - const uint8_t * restrict pp = pb + i*32; + const uint8_t * restrict pp = pb + i*QK/2; - for (int l = 0; l < 64; l += 2) { + for (int l = 0; l < QK; l += 2) { const uint8_t vi = pp[l/2]; const int8_t vi0 = vi & 0xf; @@ -550,20 +470,20 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { const float v0 = (vi0 - 8)*d; const float v1 = (vi1 - 8)*d; - y[i*64 + l + 0] = v0; - y[i*64 + l + 1] = v1; + y[i*QK + l + 0] = v0; + y[i*QK + l + 1] = v1; - assert(!isnan(y[i*64 + l + 0])); - assert(!isnan(y[i*64 + l + 1])); + assert(!isnan(y[i*QK + l + 0])); + assert(!isnan(y[i*QK + l + 1])); //printf("v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1); } } } void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) { - assert(k % 64 == 0); + assert(k % QK == 0); - const int nb = k / 64; + const int nb = k / QK; const float * restrict pm = (const float *) (x); const float * restrict pd = (const float *) (pm + nb); @@ -573,9 +493,9 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) { const float m = pm[i]; const float d = pd[i]; - const uint8_t * restrict pp = pb + i*32; + const uint8_t * restrict pp = pb + i*QK/2; - for (int l = 0; l < 64; l += 2) { + for (int l = 0; l < QK; l += 2) { const uint8_t vi = pp[l/2]; const int8_t vi0 = vi & 0xf; @@ -584,11 +504,11 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) { const float v0 = vi0*d + m; const float v1 = vi1*d + m; - y[i*64 + l + 0] = v0; - y[i*64 + l + 1] = v1; + y[i*QK + l + 0] = v0; + y[i*QK + l + 1] = v1; - assert(!isnan(y[i*64 + l + 0])); - assert(!isnan(y[i*64 + l + 1])); + assert(!isnan(y[i*QK + l + 0])); + assert(!isnan(y[i*QK + l + 1])); //printf("v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1); } } @@ -1172,7 +1092,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t } inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict x, const void * restrict y) { - const int nb = n / 64; + const int nb = n / QK; const float * restrict pd0 = (const float *) x; const float * restrict pd1 = (const float *) y; @@ -1182,16 +1102,15 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void float sumf = 0.0; -#if 0 // scalar for (int i = 0; i < nb; i++) { const float d0 = pd0[i]; const float d1 = pd1[i]; - const uint8_t * restrict p0 = pb0 + i*32; - const uint8_t * restrict p1 = pb1 + i*32; + const uint8_t * restrict p0 = pb0 + i*QK/2; + const uint8_t * restrict p1 = pb1 + i*QK/2; - for (int j = 0; j < 32; j++) { + for (int j = 0; j < QK/2; j++) { const uint8_t v0 = p0[j]; const uint8_t v1 = p1[j]; @@ -1204,134 +1123,12 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void sumf += f0*f2 + f1*f3; } } -#else -#if defined(__AVX2__) - __m256 sum11 = _mm256_setzero_ps(); - - for (int i = 0; i < nb; i++) { - const float d0 = pd0[i]; - const float d1 = pd1[i]; - - const uint8_t * restrict p0 = pb0 + i*32; - const uint8_t * restrict p1 = pb1 + i*32; - - const __m256 d0v = _mm256_set1_ps(d0); - const __m256 d1v = _mm256_set1_ps(d1); - - const __m256 d0d1v = _mm256_mul_ps(d0v, d1v); - - const __m256i m4b = _mm256_set1_epi8(0xf); - - // 64 x 4 - const __m256i v0 = _mm256_loadu_si256((__m256i *) p0); - const __m256i v1 = _mm256_loadu_si256((__m256i *) p1); - - // 32 x 8 - __m256i v0l = _mm256_and_si256(v0, m4b); - __m256i v1l = _mm256_and_si256(v1, m4b); - - __m256i v0h = _mm256_and_si256(_mm256_srli_epi16(v0, 4), m4b); - __m256i v1h = _mm256_and_si256(_mm256_srli_epi16(v1, 4), m4b); - - // sub 8 - v0l = _mm256_sub_epi8(v0l, _mm256_set1_epi8(8)); - v0h = _mm256_sub_epi8(v0h, _mm256_set1_epi8(8)); - - v1l = _mm256_sub_epi8(v1l, _mm256_set1_epi8(8)); - v1h = _mm256_sub_epi8(v1h, _mm256_set1_epi8(8)); - - // abs - const __m256i v0la = _mm256_sign_epi8(v0l, v0l); - const __m256i v0ha = _mm256_sign_epi8(v0h, v0h); - - // sign - const __m256i v1ls = _mm256_sign_epi8(v1l, v0l); - const __m256i v1hs = _mm256_sign_epi8(v1h, v0h); - - const __m256i pl = _mm256_maddubs_epi16(v0la, v1ls); - const __m256i ph = _mm256_maddubs_epi16(v0ha, v1hs); - - const __m256i p16 = _mm256_add_epi16(ph, pl); - const __m256i p = _mm256_madd_epi16(_mm256_set1_epi16(1), p16); - - sum11 = _mm256_fmadd_ps(d0d1v, _mm256_cvtepi32_ps(p), sum11); - } - - sumf = _mm256_hadd_ps_gg(sum11); -#elif defined (__ARM_NEON) - float sum11 = 0.0f; - - for (int i = 0; i < nb; i++) { - const float d0 = pd0[i]; - const float d1 = pd1[i]; - - const uint8_t * restrict p0 = pb0 + i*32; - const uint8_t * restrict p1 = pb1 + i*32; - - const uint8x16_t m4b = vdupq_n_u8(0xf); - const int8x16_t s8b = vdupq_n_s8(0x8); - - const uint8x16_t v0_0 = vld1q_u8(p0); - const uint8x16_t v0_1 = vld1q_u8(p0 + 16); - const uint8x16_t v1_0 = vld1q_u8(p1); - const uint8x16_t v1_1 = vld1q_u8(p1 + 16); - - // 4-bit -> 8-bit - const uint8x16_t v0_0l = vandq_u8(v0_0, m4b); - const uint8x16_t v0_1l = vandq_u8(v0_1, m4b); - const uint8x16_t v1_0l = vandq_u8(v1_0, m4b); - const uint8x16_t v1_1l = vandq_u8(v1_1, m4b); - - const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4); - const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4); - const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4); - const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4); - - // sub 8 - const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); - const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); - const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b); - const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b); - - const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); - const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); - const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b); - const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b); - - // dot product into int16x8_t - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls)); - - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs)); - - const int16x8_t pl0 = vaddq_s16(pl0l, pl0h); - const int16x8_t pl1 = vaddq_s16(pl1l, pl1h); - const int16x8_t ph0 = vaddq_s16(ph0l, ph0h); - const int16x8_t ph1 = vaddq_s16(ph1l, ph1h); - - const int16x8_t pl = vaddq_s16(pl0, pl1); - const int16x8_t ph = vaddq_s16(ph0, ph1); - - const int16x8_t p = vaddq_s16(pl, ph); - - // scalar - sum11 += d0*d1*vaddvq_u16(p); - } - - sumf = sum11; -#endif -#endif *s = sumf; } inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { - const int nb = n / 64; + const int nb = n / QK; const float * restrict pm0 = (const float *) x; const float * restrict pm1 = (const float *) y; @@ -1353,10 +1150,10 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void const float d0 = pd0[i]; const float d1 = pd1[i]; - const uint8_t * restrict p0 = pb0 + i*32; - const uint8_t * restrict p1 = pb1 + i*32; + const uint8_t * restrict p0 = pb0 + i*QK/2; + const uint8_t * restrict p1 = pb1 + i*QK/2; - for (int j = 0; j < 32; j++) { + for (int j = 0; j < QK/2; j++) { const uint8_t v0 = p0[j]; const uint8_t v1 = p1[j]; @@ -1492,9 +1289,9 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ } inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * restrict x, const float v) { - assert(n % 64 == 0); + assert(n % QK == 0); - const int nb = n / 64; + const int nb = n / QK; const float * restrict pd = (const float *) (x); const uint8_t * restrict pb = (const uint8_t *) (pd + nb); @@ -1502,9 +1299,9 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res for (int i = 0; i < nb; i++) { const float d = pd[i]; - const uint8_t * restrict pp = pb + i*32; + const uint8_t * restrict pp = pb + i*QK/2; - for (int l = 0; l < 64; l += 2) { + for (int l = 0; l < QK; l += 2) { const uint8_t vi = pp[l/2]; const int8_t vi0 = vi & 0xf; @@ -1513,22 +1310,22 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res const float v0 = (vi0 - 8)*d; const float v1 = (vi1 - 8)*d; - y[i*64 + l + 0] += v0*v; - y[i*64 + l + 1] += v1*v; + y[i*QK + l + 0] += v0*v; + y[i*QK + l + 1] += v1*v; - assert(!isnan(y[i*64 + l + 0])); - assert(!isnan(y[i*64 + l + 1])); - assert(!isinf(y[i*64 + l + 0])); - assert(!isinf(y[i*64 + l + 1])); + assert(!isnan(y[i*QK + l + 0])); + assert(!isnan(y[i*QK + l + 1])); + assert(!isinf(y[i*QK + l + 0])); + assert(!isinf(y[i*QK + l + 1])); //printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1); } } } inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * restrict x, const float v) { - assert(n % 64 == 0); + assert(n % QK == 0); - const int nb = n / 64; + const int nb = n / QK; const float * restrict pm = (const float *) (x); const float * restrict pd = (const float *) (pm + nb); @@ -1538,9 +1335,9 @@ inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * res const float m = pm[i]; const float d = pd[i]; - const uint8_t * restrict pp = pb + i*32; + const uint8_t * restrict pp = pb + i*QK/2; - for (int l = 0; l < 64; l += 2) { + for (int l = 0; l < QK; l += 2) { const uint8_t vi = pp[l/2]; const uint8_t vi0 = vi & 0xf; @@ -1549,13 +1346,13 @@ inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * res const float v0 = d*vi0 + m; const float v1 = d*vi1 + m; - y[i*64 + l + 0] += v0*v; - y[i*64 + l + 1] += v1*v; + y[i*QK + l + 0] += v0*v; + y[i*QK + l + 1] += v1*v; - assert(!isnan(y[i*64 + l + 0])); - assert(!isnan(y[i*64 + l + 1])); - assert(!isinf(y[i*64 + l + 0])); - assert(!isinf(y[i*64 + l + 1])); + assert(!isnan(y[i*QK + l + 0])); + assert(!isnan(y[i*QK + l + 1])); + assert(!isinf(y[i*QK + l + 0])); + assert(!isinf(y[i*QK + l + 1])); //printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1); } } @@ -1685,8 +1482,8 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x // static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { - 64, - 64, + QK, + QK, 1, 1, 1, @@ -1697,8 +1494,8 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5"); static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { - sizeof(float ) + 32, - sizeof(float )*2 + 32, + sizeof(float ) + QK/2, + sizeof(float )*2 + QK/2, sizeof(int8_t ), sizeof(int16_t), sizeof(int32_t),