ggml : 4-bit quantization works (only scalar for now)

gq
Georgi Gerganov 2 years ago
parent b48b09c37f
commit cc94fdafe7

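For orientation before the diff: the 4-bit formats touched here group values into blocks of QK = 64. The sketch below shows what one block of each format carries; the struct names are illustrative only, since the code in this commit stores each row's scale factors and packed nibbles in separate runs rather than as interleaved structs.

#include <stdio.h>
#include <stdint.h>

#define QK 64   // block size used throughout this commit

// Illustrative-only structs showing what one quantized block carries.
typedef struct {
    float   d;          // delta (scale): amax / 7
    uint8_t qs[QK/2];   // 64 values, 2 per byte, stored as (q + 8) in [0, 15]
} q4_0_block_sketch;    // 36 bytes per 64 weights

typedef struct {
    float   min;        // block minimum
    float   d;          // delta (scale): (max - min) / 15
    uint8_t qs[QK/2];   // 64 unsigned offsets from min, 2 per byte
} q4_1_block_sketch;    // 40 bytes per 64 weights

_Static_assert(sizeof(q4_0_block_sketch) == 36, "q4_0 block is 36 bytes");
_Static_assert(sizeof(q4_1_block_sketch) == 40, "q4_1 block is 40 bytes");

int main(void) {
    printf("q4_0 block: %zu bytes, q4_1 block: %zu bytes\n",
           sizeof(q4_0_block_sketch), sizeof(q4_1_block_sketch));   // 36, 40
    return 0;
}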
@@ -51,190 +51,6 @@ def convert_to_ftype(data, ftype):
if ftype == 1: if ftype == 1:
return data.astype(np.float16) return data.astype(np.float16)
# qint4_0
# C code:
# {
# for (int l = 0; l < QK; l++) {
# const float v = src[i*QK + l];
# amax = MAX(amax, fabsf(v));
# }
#
# const float d = amax / ((1 << (QB - 1)) - 1);
# const float id = d ? 1.0/d : 0.0;
#
# pd[i] = GGML_FP32_TO_GQ(d);
#
# for (int l = 0; l < QK; l++) {
# const float v = src[i*QK + l]*id;
# const int8_t vi = ((int8_t) (round(v))) + 8;
# assert(vi >= 0 && vi < 16);
# pp[l/2] |= (vi & 0xf) << (4*(l & 1));
# }
#
# memcpy(pb + i*QK/2, pp, sizeof(pp));
# }
if ftype == 2:
assert data.dtype == np.float32
assert data.shape[-1] % 64 == 0
# create 2 new arrays:
# - pd: float32 (lowest dimension is data.shape[-1] // 64)
# - pb: int8
pd = np.zeros(data.shape[:-1] + (data.shape[-1] // 64,), dtype=np.float32)
pb = np.zeros(data.shape[:-1] + (data.shape[-1], ), dtype=np.int8)
# the quantized data goes here
dst = np.zeros((data.size // 64) * (4 + 32), dtype=np.uint8)
print("data:", data.shape, data.size)
print("pd: ", pd.shape, pd.size)
print("pb: ", pb.shape, pb.size)
print("dst: ", dst.shape, dst.size)
for i in range(0, data.shape[-1], 64):
max_abs = np.max(np.abs(data[..., i:i+64]))
max_q = (1 << 3) - 1
d = max_abs / max_q
id = 1.0 / d if d != 0 else 0.0
pd[..., i//64] = d
for j in range(64):
v = data[..., i+j] * id
vi = np.round(v).astype(np.int8) + 8
assert np.all(vi >= 0) and np.all(vi < 16)
#ve = vi[...,(j & 1) == 0].reshape(-1, 1)
#print("ve:", ve.shape, ve)
#print("vo:", vo.shape, vo)
#print("pb:", pb[..., (i+j)//2].shape, pb[..., (i+j)//2])
pb[..., i+j] = vi
# convert to 1D array
pd = pd.reshape(-1, 1)
pb = pb.reshape(-1, 1)
# populate the destination array
n = data.size
nr = data.shape[-1]
nn = nr//64
for i in range(0, n, nr):
for j in range(0, nr, 64):
d = pd[(i//nr)*nn + j//64][0]
b = pb[i+j:i+j+64].reshape(-1)
db = struct.unpack("4B", struct.pack("f", d))
dst[(i//nr)*nn*36 + (j//64)*4 + 0] = db[0]
dst[(i//nr)*nn*36 + (j//64)*4 + 1] = db[1]
dst[(i//nr)*nn*36 + (j//64)*4 + 2] = db[2]
dst[(i//nr)*nn*36 + (j//64)*4 + 3] = db[3]
for k in range(32):
dst[(i//nr)*nn*36 + nn*4 + (j//64)*32 + k] = b[2*k] | (b[2*k+1] << 4)
return dst
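The packing step above is the part that is easiest to get wrong when re-implementing: two 4-bit values share one byte, the first in the low nibble, and q4_0 stores each signed value with a +8 bias so it fits in [0, 15]. A minimal stand-alone check of that convention (the helper name pack2 is made up for this sketch):

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

// pack two biased 4-bit values (each in 0..15) into one byte, low nibble first
static uint8_t pack2(int v0, int v1) {
    assert(v0 >= 0 && v0 < 16 && v1 >= 0 && v1 < 16);
    return (uint8_t)((v0 & 0xf) | ((v1 & 0xf) << 4));
}

int main(void) {
    // q4_0 stores quantized values with a +8 bias: -3 -> 5, 7 -> 15
    uint8_t byte = pack2(-3 + 8, 7 + 8);

    // unpacking masks/shifts the nibbles and removes the bias again
    int r0 = (byte & 0xf) - 8;
    int r1 = (byte >> 4)  - 8;
    printf("byte = 0x%02x -> %d, %d\n", byte, r0, r1);   // prints 0xf5 -> -3, 7
    return 0;
}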
# qint4_1
# C code:
# {
# for (int l = 0; l < QK; l++) {
# const float v = src[i*QK + l];
# if (v < min) min = v;
# if (v > max) max = v;
# }
# const float d = (max - min) / ((1 << QB) - 1);
# const float id = d ? 1.0/d : 0.0;
# pm[i] = GGML_FP32_TO_GQ(min);
# pd[i] = GGML_FP32_TO_GQ(d);
# for (int l = 0; l < QK; l++) {
# const float v = (src[i*QK + l] - min) * id;
# const uint8_t vi = (uint8_t) (v + frand());
# pp[l/2] |= (vi & 0xf) << (4*(l & 1));
# }
# memcpy(pb + i*QK/2, pp, sizeof(pp));
# }
if ftype == 3:
assert data.dtype == np.float32
assert data.shape[-1] % 64 == 0
# create 2 new arrays:
# - pd: float32 (lowest dimension is data.shape[-1] // 64)
# - pb: int8
pm = np.zeros(data.shape[:-1] + (data.shape[-1] // 64,), dtype=np.float32)
pd = np.zeros(data.shape[:-1] + (data.shape[-1] // 64,), dtype=np.float32)
pb = np.zeros(data.shape[:-1] + (data.shape[-1], ), dtype=np.int8)
# the quantized data goes here
dst = np.zeros((data.size // 64) * (4 + 4 + 32), dtype=np.uint8)
print("data:", data.shape, data.size)
print("pm: ", pm.shape, pm.size)
print("pd: ", pd.shape, pd.size)
print("pb: ", pb.shape, pb.size)
print("dst: ", dst.shape, dst.size)
for i in range(0, data.shape[-1], 64):
mmin = np.min(data[..., i:i+64])
mmax = np.max(data[..., i:i+64])
max_q = (1 << 4) - 1
d = (mmax - mmin) / max_q
id = 1.0 / d if d != 0 else 0.0
pm[..., i//64] = mmin
pd[..., i//64] = d
for j in range(64):
v = (data[..., i+j] - mmin) * id
vi = np.round(v).astype(np.uint8)
assert np.all(vi >= 0) and np.all(vi < 16)
pb[..., i+j] = vi
# convert to 1D array
pm = pm.reshape(-1, 1)
pd = pd.reshape(-1, 1)
pb = pb.reshape(-1, 1)
# populate the destination array
n = data.size
nr = data.shape[-1]
nn = nr//64
for i in range(0, n, nr):
for j in range(0, nr, 64):
m = pm[(i//nr)*nn + j//64][0]
idx = (i//nr)*nn*40 + (j//64)*4
mb = struct.unpack("4B", struct.pack("f", m))
dst[idx + 0] = mb[0]
dst[idx + 1] = mb[1]
dst[idx + 2] = mb[2]
dst[idx + 3] = mb[3]
for j in range(0, nr, 64):
d = pd[(i//nr)*nn + j//64][0]
idx = (i//nr)*nn*40 + 4*nn + (j//64)*4
db = struct.unpack("4B", struct.pack("f", d))
dst[idx + 0] = db[0]
dst[idx + 1] = db[1]
dst[idx + 2] = db[2]
dst[idx + 3] = db[3]
for j in range(0, nr, 64):
b = pb[i+j:i+j+64].reshape(-1)
idx = (i//nr)*nn*40 + nn*8 + (j//64)*32
for k in range(32):
dst[idx + k] = b[2*k] | (b[2*k+1] << 4)
return dst
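To make the qint4_1 formula above concrete, here is one worked block, assuming its values span [-1.0, 2.0]: d = (2.0 - (-1.0)) / 15 = 0.2, so 0.5 quantizes to round((0.5 + 1.0) / 0.2) = 8 and dequantizes to -1.0 + 8 * 0.2 = 0.6, an error of 0.1 (at most d/2). A tiny sketch of that arithmetic:

#include <math.h>
#include <stdio.h>

int main(void) {
    // one q4_1 block whose values happen to span [-1.0, 2.0]
    const float min = -1.0f, max = 2.0f;
    const float d  = (max - min) / 15.0f;         // 0.2
    const float id = 1.0f / d;

    const float x  = 0.5f;                        // a value inside the block
    const int   q  = (int)roundf((x - min) * id); // 8, stored as a 4-bit offset
    const float y  = min + q * d;                 // 0.6 after dequantization

    printf("d = %.3f, q = %d, y = %.3f, err = %.3f\n", d, q, y, fabsf(y - x));
    return 0;
}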
assert False, "Invalid ftype: " + str(ftype) assert False, "Invalid ftype: " + str(ftype)
if len(sys.argv) < 2: if len(sys.argv) < 2:
@@ -258,12 +74,12 @@ with open(dir_model + "/hparams.json", "r") as f:
# ftype == 3 -> qint4_1 # ftype == 3 -> qint4_1
# #
# map from ftype to string # map from ftype to string
ftype_str = ["f32", "f16", "q4_0", "q4_1"] ftype_str = ["f32", "f16"]
ftype = 1 ftype = 1
if len(sys.argv) > 2: if len(sys.argv) > 2:
ftype = int(sys.argv[2]) ftype = int(sys.argv[2])
if ftype < 0 or ftype > 3: if ftype < 0 or ftype > 1:
print("Invalid ftype: " + str(ftype)) print("Invalid ftype: " + str(ftype))
sys.exit(1) sys.exit(1)
fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
@@ -312,8 +128,6 @@ for name, shape in list_vars:
# "model/h.*/mlp/c_fc/w" # "model/h.*/mlp/c_fc/w"
# "model/h.*/mlp/c_proj/w" # "model/h.*/mlp/c_proj/w"
if name == "model/wte" or name[-2:] == "/w": if name == "model/wte" or name[-2:] == "/w":
#if name[-6:] == "attn/w":
#if name == "model/wte":
print(" Converting to " + ftype_str[ftype]) print(" Converting to " + ftype_str[ftype])
data = convert_to_ftype(data, ftype) data = convert_to_ftype(data, ftype)
ftype_cur = ftype ftype_cur = ftype

@@ -12,8 +12,9 @@
#include <vector> #include <vector>
#include <regex> #include <regex>
#define QK 64
size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k) { size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k) {
const int QK = 64;
const int nb = k / QK; const int nb = k / QK;
const size_t row_size = nb*(sizeof(float) + sizeof(uint8_t)*QK/2); const size_t row_size = nb*(sizeof(float) + sizeof(uint8_t)*QK/2);
@@ -63,7 +64,6 @@ size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k) {
} }
size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k) { size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k) {
const int QK = 64;
const int nb = k / QK; const int nb = k / QK;
const size_t row_size = nb*(2*sizeof(float) + sizeof(uint8_t)*QK/2); const size_t row_size = nb*(2*sizeof(float) + sizeof(uint8_t)*QK/2);
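The two row_size expressions above fix the on-disk cost of each format: a block of QK = 64 values costs 4 + 32 = 36 bytes in q4_0 and 8 + 32 = 40 bytes in q4_1. A small sketch of that arithmetic for one 4096-wide row (the helper names are illustrative, not from the codebase):

#include <stddef.h>
#include <stdio.h>

#define QK 64

// bytes per row of k floats, mirroring the row_size expressions above
static size_t row_size_q4_0(int k) { return (size_t)(k / QK) * (sizeof(float)   + QK/2); }
static size_t row_size_q4_1(int k) { return (size_t)(k / QK) * (2*sizeof(float) + QK/2); }

int main(void) {
    const int k = 4096;                                     // 64 blocks per row
    printf("f32 : %zu bytes\n", (size_t)k * sizeof(float)); // 16384
    printf("q4_0: %zu bytes\n", row_size_q4_0(k));          // 64 * 36 = 2304
    printf("q4_1: %zu bytes\n", row_size_q4_1(k));          // 64 * 40 = 2560
    return 0;
}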

@@ -352,106 +352,27 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
// quantization // quantization
// //
#define QK 64
// method 5 // method 5
// blocks of 64 elements // blocks of QK elements
// represented with a single float (delta) and 32 8-bit ints (i.e 64 4-bit signed integer factors) // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
assert(k % 64 == 0); assert(k % QK == 0);
const int nb = k / 64; const int nb = k / QK;
float * restrict pd = (float *) (y); float * restrict pd = (float *) (y);
uint8_t * restrict pb = (uint8_t *) (pd + nb); uint8_t * restrict pb = (uint8_t *) (pd + nb);
uint8_t pp[32]; uint8_t pp[QK/2];
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max float amax = 0.0f; // absolute max
#if defined(__AVX2__)
{ {
__m256 srcv [8]; for (int l = 0; l < QK; l++) {
__m256 asrcv[8]; const float v = x[i*QK + l];
__m256 amaxv[8];
// TODO: unroll these loops (maybe not? this is more compact)
for (int l = 0; l < 8; l++) srcv[l] = _mm256_loadu_ps(x + i*64 + 8*l);
for (int l = 0; l < 8; l++) asrcv[l] = _mm256_and_ps(srcv[l], (__m256) _mm256_set1_epi32(0x7fffffff));
for (int l = 0; l < 4; l++) amaxv[2*l] = _mm256_max_ps(asrcv[2*l], asrcv[2*l+1]);
for (int l = 0; l < 2; l++) amaxv[4*l] = _mm256_max_ps(amaxv[4*l], amaxv[4*l+2]);
for (int l = 0; l < 1; l++) amaxv[8*l] = _mm256_max_ps(amaxv[8*l], amaxv[8*l+4]);
const __m256 amaxv0_0 = _mm256_permute2f128_ps(amaxv[0], amaxv[0], 3);
const __m256 amaxv0_1 = _mm256_max_ps(amaxv[0], amaxv0_0);
const __m256 amaxv0_2 = _mm256_permute_ps(amaxv0_1, 0x4e);
const __m256 amaxv0_3 = _mm256_max_ps(amaxv0_1, amaxv0_2);
const __m256 amaxv0_4 = _mm256_permute_ps(amaxv0_3, 0xb1);
const __m256 amaxv0_5 = _mm256_max_ps(amaxv0_3, amaxv0_4);
amax = _mm256_cvtss_f32(amaxv0_5);
//printf("amax = %f\n", amax);
const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0f/d : 0.0f;
pd[i] = d;
const __m256 idv = _mm256_set1_ps(id);
for (int l = 0; l < 8; l++) {
__m256 v = _mm256_mul_ps(srcv[l], idv);
#if 0
v[0] += frand(); v[1] += frand(); v[2] += frand(); v[3] += frand();
v[4] += frand(); v[5] += frand(); v[6] += frand(); v[7] += frand();
#endif
// convert to int8
__m256i vi = _mm256_cvtps_epi32(v);
vi = _mm256_add_epi32(vi, _mm256_set1_epi32(8));
int32_t vi_0 = _mm256_extract_epi32(vi, 0);
int32_t vi_1 = _mm256_extract_epi32(vi, 1);
int32_t vi_2 = _mm256_extract_epi32(vi, 2);
int32_t vi_3 = _mm256_extract_epi32(vi, 3);
int32_t vi_4 = _mm256_extract_epi32(vi, 4);
int32_t vi_5 = _mm256_extract_epi32(vi, 5);
int32_t vi_6 = _mm256_extract_epi32(vi, 6);
int32_t vi_7 = _mm256_extract_epi32(vi, 7);
// convert to 4-bit, 2 consecutive packed into 1 byte
pp[4*l + 0] = vi_0 | (vi_1 << 4);
pp[4*l + 1] = vi_2 | (vi_3 << 4);
pp[4*l + 2] = vi_4 | (vi_5 << 4);
pp[4*l + 3] = vi_6 | (vi_7 << 4);
//printf("vi: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7);
//printf("v : %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]);
assert(vi_0 >= 0 && vi_0 < 16);
assert(vi_1 >= 0 && vi_1 < 16);
assert(vi_2 >= 0 && vi_2 < 16);
assert(vi_3 >= 0 && vi_3 < 16);
assert(vi_4 >= 0 && vi_4 < 16);
assert(vi_5 >= 0 && vi_5 < 16);
assert(vi_6 >= 0 && vi_6 < 16);
assert(vi_7 >= 0 && vi_7 < 16);
}
memcpy(pb + i*32, pp, sizeof(pp));
}
#elif defined(__ARM_NEON) && 0
{
// TODO
#pragma warning "implement me !!"
}
#else
{
for (int l = 0; l < 64; l++) {
const float v = x[i*64 + l];
amax = MAX(amax, fabsf(v)); amax = MAX(amax, fabsf(v));
} }
@@ -460,9 +381,9 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
pd[i] = d; pd[i] = d;
for (int l = 0; l < 64; l += 2) { for (int l = 0; l < QK; l += 2) {
const float v0 = x[i*64 + l + 0]*id; const float v0 = x[i*QK + l + 0]*id;
const float v1 = x[i*64 + l + 1]*id; const float v1 = x[i*QK + l + 1]*id;
const uint8_t vi0 = ((int8_t) (round(v0))) + 8; const uint8_t vi0 = ((int8_t) (round(v0))) + 8;
const uint8_t vi1 = ((int8_t) (round(v1))) + 8; const uint8_t vi1 = ((int8_t) (round(v1))) + 8;
@@ -473,33 +394,32 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
pp[l/2] = vi0 | (vi1 << 4); pp[l/2] = vi0 | (vi1 << 4);
} }
memcpy(pb + i*32, pp, sizeof(pp)); memcpy(pb + i*QK/2, pp, sizeof(pp));
} }
#endif
} }
} }
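As a concrete instance of the scalar path above: for a block whose largest magnitude is 1.75, d = 1.75 / 7 = 0.25, so the value 0.6 maps to round(0.6 / 0.25) + 8 = 10 and dequantizes to (10 - 8) * 0.25 = 0.5, a round-trip error of 0.1 (bounded by d/2 = 0.125). A minimal sketch of that single-value round trip:

#include <math.h>
#include <stdio.h>

int main(void) {
    const float amax = 1.75f;            // largest |value| in the block
    const float d    = amax / 7.0f;      // 0.25, the q4_0 delta
    const float id   = 1.0f / d;

    const float x  = 0.6f;
    const int   vi = (int)roundf(x * id) + 8;   // 10, the stored 4-bit code
    const float y  = (vi - 8) * d;              // 0.5 after dequantization

    printf("d = %.3f, vi = %d, y = %.3f, err = %.3f\n", d, vi, y, fabsf(y - x));
    return 0;
}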
// method 4 // method 4
// blocks of 64 elements // blocks of QK elements
// represented with 2 floats (min + delta) and 32 8-bit ints (i.e 64 4-bit unsigned integer factors) // represented with 2 floats (min + delta) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
assert(k % 64 == 0); assert(k % QK == 0);
const int nb = k / 64; const int nb = k / QK;
float * restrict pm = (float *) (y); float * restrict pm = (float *) (y);
float * restrict pd = (float *) (pm + nb); float * restrict pd = (float *) (pm + nb);
uint8_t * restrict pb = (uint8_t *) (pd + nb); uint8_t * restrict pb = (uint8_t *) (pd + nb);
uint8_t pp[32]; uint8_t pp[QK/2];
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
float min = FLT_MAX; float min = FLT_MAX;
float max = -FLT_MAX; float max = -FLT_MAX;
{ {
for (int l = 0; l < 64; l++) { for (int l = 0; l < QK; l++) {
const float v = x[i*64 + l]; const float v = x[i*QK + l];
if (v < min) min = v; if (v < min) min = v;
if (v > max) max = v; if (v > max) max = v;
} }
@@ -510,9 +430,9 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
pm[i] = min; pm[i] = min;
pd[i] = d; pd[i] = d;
for (int l = 0; l < 64; l += 2) { for (int l = 0; l < QK; l += 2) {
const float v0 = (x[i*64 + l + 0] - min)*id; const float v0 = (x[i*QK + l + 0] - min)*id;
const float v1 = (x[i*64 + l + 1] - min)*id; const float v1 = (x[i*QK + l + 1] - min)*id;
const uint8_t vi0 = round(v0); const uint8_t vi0 = round(v0);
const uint8_t vi1 = round(v1); const uint8_t vi1 = round(v1);
@@ -523,15 +443,15 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
pp[l/2] = vi0 | (vi1 << 4); pp[l/2] = vi0 | (vi1 << 4);
} }
memcpy(pb + i*32, pp, sizeof(pp)); memcpy(pb + i*QK/2, pp, sizeof(pp));
} }
} }
} }
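The difference between the two schemes shows up on blocks that are far from zero: q4_0 centres its 16 levels on zero (step amax/7), while q4_1 spends them on the actual [min, max] range (step (max - min)/15). A self-contained sketch that quantizes one such block both ways and prints the worst round-trip error (the function names are invented for this comparison):

#include <math.h>
#include <stdio.h>

#define QK 64

// round-trip one block with the q4_0 rule: symmetric levels (q - 8) * d
static float max_err_q4_0(const float * x) {
    float amax = 0.0f;
    for (int l = 0; l < QK; l++) amax = fmaxf(amax, fabsf(x[l]));
    const float d  = amax / 7.0f;
    const float id = d ? 1.0f/d : 0.0f;
    float err = 0.0f;
    for (int l = 0; l < QK; l++) {
        const int   q = (int)roundf(x[l] * id);  // signed level, i.e. (vi - 8) above
        const float y = q * d;
        err = fmaxf(err, fabsf(y - x[l]));
    }
    return err;
}

// round-trip one block with the q4_1 rule: min + q * d, q in 0..15
static float max_err_q4_1(const float * x) {
    float min = x[0], max = x[0];
    for (int l = 0; l < QK; l++) { min = fminf(min, x[l]); max = fmaxf(max, x[l]); }
    const float d  = (max - min) / 15.0f;
    const float id = d ? 1.0f/d : 0.0f;
    float err = 0.0f;
    for (int l = 0; l < QK; l++) {
        const int   q = (int)roundf((x[l] - min) * id);
        const float y = min + q * d;
        err = fmaxf(err, fabsf(y - x[l]));
    }
    return err;
}

int main(void) {
    // a block whose values all sit in [10, 11]: a large offset, small spread
    float x[QK];
    for (int l = 0; l < QK; l++) x[l] = 10.0f + (float)l / (QK - 1);

    printf("q4_0 max error: %.4f\n", max_err_q4_0(x));   // ~0.8 (step is 11/7, about 1.57)
    printf("q4_1 max error: %.4f\n", max_err_q4_1(x));   // ~0.03 (step is 1/15)
    return 0;
}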
void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
assert(k % 64 == 0); assert(k % QK == 0);
const int nb = k / 64; const int nb = k / QK;
const float * restrict pd = (const float *) (x); const float * restrict pd = (const float *) (x);
const uint8_t * restrict pb = (const uint8_t *) (pd + nb); const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
@@ -539,9 +459,9 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
const float d = pd[i]; const float d = pd[i];
const uint8_t * restrict pp = pb + i*32; const uint8_t * restrict pp = pb + i*QK/2;
for (int l = 0; l < 64; l += 2) { for (int l = 0; l < QK; l += 2) {
const uint8_t vi = pp[l/2]; const uint8_t vi = pp[l/2];
const int8_t vi0 = vi & 0xf; const int8_t vi0 = vi & 0xf;
@@ -550,20 +470,20 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
const float v0 = (vi0 - 8)*d; const float v0 = (vi0 - 8)*d;
const float v1 = (vi1 - 8)*d; const float v1 = (vi1 - 8)*d;
y[i*64 + l + 0] = v0; y[i*QK + l + 0] = v0;
y[i*64 + l + 1] = v1; y[i*QK + l + 1] = v1;
assert(!isnan(y[i*64 + l + 0])); assert(!isnan(y[i*QK + l + 0]));
assert(!isnan(y[i*64 + l + 1])); assert(!isnan(y[i*QK + l + 1]));
//printf("v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1); //printf("v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1);
} }
} }
} }
void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) { void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
assert(k % 64 == 0); assert(k % QK == 0);
const int nb = k / 64; const int nb = k / QK;
const float * restrict pm = (const float *) (x); const float * restrict pm = (const float *) (x);
const float * restrict pd = (const float *) (pm + nb); const float * restrict pd = (const float *) (pm + nb);
@@ -573,9 +493,9 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
const float m = pm[i]; const float m = pm[i];
const float d = pd[i]; const float d = pd[i];
const uint8_t * restrict pp = pb + i*32; const uint8_t * restrict pp = pb + i*QK/2;
for (int l = 0; l < 64; l += 2) { for (int l = 0; l < QK; l += 2) {
const uint8_t vi = pp[l/2]; const uint8_t vi = pp[l/2];
const int8_t vi0 = vi & 0xf; const int8_t vi0 = vi & 0xf;
@@ -584,11 +504,11 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
const float v0 = vi0*d + m; const float v0 = vi0*d + m;
const float v1 = vi1*d + m; const float v1 = vi1*d + m;
y[i*64 + l + 0] = v0; y[i*QK + l + 0] = v0;
y[i*64 + l + 1] = v1; y[i*QK + l + 1] = v1;
assert(!isnan(y[i*64 + l + 0])); assert(!isnan(y[i*QK + l + 0]));
assert(!isnan(y[i*64 + l + 1])); assert(!isnan(y[i*QK + l + 1]));
//printf("v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1); //printf("v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1);
} }
} }
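One detail of the memory layout that is easy to miss from the dequantize routines above: within a quantized row, all per-block scale factors come first (for q4_1, all minima and then all deltas), followed by all the packed nibbles; blocks are not stored as interleaved records. A short sketch of the resulting byte offsets for one row, assuming QK = 64:

#include <stdio.h>

#define QK 64

int main(void) {
    const int k  = 256;        // row length
    const int nb = k / QK;     // 4 blocks per row

    // q4_0 row: [ nb floats (deltas) ][ nb * QK/2 packed bytes ]
    printf("q4_0: deltas at 0..%d, packed data at %d..%d\n",
           nb*4 - 1, nb*4, nb*4 + nb*QK/2 - 1);                  // 0..15, 16..143

    // q4_1 row: [ nb floats (mins) ][ nb floats (deltas) ][ nb * QK/2 packed bytes ]
    printf("q4_1: mins at 0..%d, deltas at %d..%d, packed data at %d..%d\n",
           nb*4 - 1, nb*4, nb*8 - 1, nb*8, nb*8 + nb*QK/2 - 1);  // 0..15, 16..31, 32..159
    return 0;
}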
@@ -1172,7 +1092,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
} }
inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict x, const void * restrict y) { inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
const int nb = n / 64; const int nb = n / QK;
const float * restrict pd0 = (const float *) x; const float * restrict pd0 = (const float *) x;
const float * restrict pd1 = (const float *) y; const float * restrict pd1 = (const float *) y;
@@ -1182,16 +1102,15 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
float sumf = 0.0; float sumf = 0.0;
#if 0
// scalar // scalar
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
const float d0 = pd0[i]; const float d0 = pd0[i];
const float d1 = pd1[i]; const float d1 = pd1[i];
const uint8_t * restrict p0 = pb0 + i*32; const uint8_t * restrict p0 = pb0 + i*QK/2;
const uint8_t * restrict p1 = pb1 + i*32; const uint8_t * restrict p1 = pb1 + i*QK/2;
for (int j = 0; j < 32; j++) { for (int j = 0; j < QK/2; j++) {
const uint8_t v0 = p0[j]; const uint8_t v0 = p0[j];
const uint8_t v1 = p1[j]; const uint8_t v1 = p1[j];
@@ -1204,134 +1123,12 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
sumf += f0*f2 + f1*f3; sumf += f0*f2 + f1*f3;
} }
} }
#else
#if defined(__AVX2__)
__m256 sum11 = _mm256_setzero_ps();
for (int i = 0; i < nb; i++) {
const float d0 = pd0[i];
const float d1 = pd1[i];
const uint8_t * restrict p0 = pb0 + i*32;
const uint8_t * restrict p1 = pb1 + i*32;
const __m256 d0v = _mm256_set1_ps(d0);
const __m256 d1v = _mm256_set1_ps(d1);
const __m256 d0d1v = _mm256_mul_ps(d0v, d1v);
const __m256i m4b = _mm256_set1_epi8(0xf);
// 64 x 4
const __m256i v0 = _mm256_loadu_si256((__m256i *) p0);
const __m256i v1 = _mm256_loadu_si256((__m256i *) p1);
// 32 x 8
__m256i v0l = _mm256_and_si256(v0, m4b);
__m256i v1l = _mm256_and_si256(v1, m4b);
__m256i v0h = _mm256_and_si256(_mm256_srli_epi16(v0, 4), m4b);
__m256i v1h = _mm256_and_si256(_mm256_srli_epi16(v1, 4), m4b);
// sub 8
v0l = _mm256_sub_epi8(v0l, _mm256_set1_epi8(8));
v0h = _mm256_sub_epi8(v0h, _mm256_set1_epi8(8));
v1l = _mm256_sub_epi8(v1l, _mm256_set1_epi8(8));
v1h = _mm256_sub_epi8(v1h, _mm256_set1_epi8(8));
// abs
const __m256i v0la = _mm256_sign_epi8(v0l, v0l);
const __m256i v0ha = _mm256_sign_epi8(v0h, v0h);
// sign
const __m256i v1ls = _mm256_sign_epi8(v1l, v0l);
const __m256i v1hs = _mm256_sign_epi8(v1h, v0h);
const __m256i pl = _mm256_maddubs_epi16(v0la, v1ls);
const __m256i ph = _mm256_maddubs_epi16(v0ha, v1hs);
const __m256i p16 = _mm256_add_epi16(ph, pl);
const __m256i p = _mm256_madd_epi16(_mm256_set1_epi16(1), p16);
sum11 = _mm256_fmadd_ps(d0d1v, _mm256_cvtepi32_ps(p), sum11);
}
sumf = _mm256_hadd_ps_gg(sum11);
#elif defined (__ARM_NEON)
float sum11 = 0.0f;
for (int i = 0; i < nb; i++) {
const float d0 = pd0[i];
const float d1 = pd1[i];
const uint8_t * restrict p0 = pb0 + i*32;
const uint8_t * restrict p1 = pb1 + i*32;
const uint8x16_t m4b = vdupq_n_u8(0xf);
const int8x16_t s8b = vdupq_n_s8(0x8);
const uint8x16_t v0_0 = vld1q_u8(p0);
const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
const uint8x16_t v1_0 = vld1q_u8(p1);
const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
// 4-bit -> 8-bit
const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
// sub 8
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
// dot product into int16x8_t
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
const int16x8_t pl0 = vaddq_s16(pl0l, pl0h);
const int16x8_t pl1 = vaddq_s16(pl1l, pl1h);
const int16x8_t ph0 = vaddq_s16(ph0l, ph0h);
const int16x8_t ph1 = vaddq_s16(ph1l, ph1h);
const int16x8_t pl = vaddq_s16(pl0, pl1);
const int16x8_t ph = vaddq_s16(ph0, ph1);
const int16x8_t p = vaddq_s16(pl, ph);
// scalar
sum11 += d0*d1*vaddvq_u16(p);
}
sumf = sum11;
#endif
#endif
*s = sumf; *s = sumf;
} }
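The scalar loop above makes the key identity visible: within one block, every product is (q0 - 8)*d0 * (q1 - 8)*d1, so the whole block contributes d0*d1 times an integer sum of nibble products (the AVX2 path removed above exploited exactly this). A small self-contained check of that identity on one pair of packed blocks; all names here are invented for the sketch:

#include <assert.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>

#define QK 64

int main(void) {
    // two packed q4_0 blocks with fixed deltas and a simple nibble pattern
    const float d0 = 0.25f, d1 = 0.5f;
    uint8_t p0[QK/2], p1[QK/2];
    for (int j = 0; j < QK/2; j++) {
        p0[j] = (uint8_t)((j % 16) | (((j + 3) % 16) << 4));
        p1[j] = (uint8_t)(((j + 7) % 16) | (((j + 11) % 16) << 4));
    }

    float sum_f = 0.0f;   // 1) dequantize-then-multiply, as in the scalar loop above
    int   sum_i = 0;      // 2) integer nibble products, scaled once per block by d0*d1
    for (int j = 0; j < QK/2; j++) {
        const int v00 = (p0[j] & 0xf) - 8, v01 = (p0[j] >> 4) - 8;
        const int v10 = (p1[j] & 0xf) - 8, v11 = (p1[j] >> 4) - 8;
        sum_f += (v00*d0)*(v10*d1) + (v01*d0)*(v11*d1);
        sum_i += v00*v10 + v01*v11;
    }
    const float sum_blocked = d0*d1*(float)sum_i;

    printf("%f vs %f\n", sum_f, sum_blocked);
    assert(fabsf(sum_f - sum_blocked) < 1e-3f);
    return 0;
}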
inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
const int nb = n / 64; const int nb = n / QK;
const float * restrict pm0 = (const float *) x; const float * restrict pm0 = (const float *) x;
const float * restrict pm1 = (const float *) y; const float * restrict pm1 = (const float *) y;
@@ -1353,10 +1150,10 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
const float d0 = pd0[i]; const float d0 = pd0[i];
const float d1 = pd1[i]; const float d1 = pd1[i];
const uint8_t * restrict p0 = pb0 + i*32; const uint8_t * restrict p0 = pb0 + i*QK/2;
const uint8_t * restrict p1 = pb1 + i*32; const uint8_t * restrict p1 = pb1 + i*QK/2;
for (int j = 0; j < 32; j++) { for (int j = 0; j < QK/2; j++) {
const uint8_t v0 = p0[j]; const uint8_t v0 = p0[j];
const uint8_t v1 = p1[j]; const uint8_t v1 = p1[j];
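For q4_1 the same kind of per-block factoring exists, just with more terms: since each value is m + q*d, the block dot product expands to d0*d1*sum(q0*q1) + d0*m1*sum(q0) + d1*m0*sum(q1) + QK*m0*m1, so it too reduces to integer sums plus a few per-block float corrections. A short numeric check of that expansion (names invented for the sketch):

#include <assert.h>
#include <math.h>
#include <stdio.h>

#define QK 64

int main(void) {
    const float m0 = -1.0f, d0 = 0.125f;   // min/delta of block 0
    const float m1 =  2.0f, d1 = 0.25f;    // min/delta of block 1

    // unsigned 4-bit codes for the two blocks (any values in 0..15 will do)
    int q0[QK], q1[QK];
    for (int l = 0; l < QK; l++) { q0[l] = l % 16; q1[l] = (l * 7) % 16; }

    float direct = 0.0f;                  // sum of dequantized products
    long  s01 = 0, s0 = 0, s1 = 0;        // integer sums used by the expansion
    for (int l = 0; l < QK; l++) {
        direct += (m0 + q0[l]*d0) * (m1 + q1[l]*d1);
        s01 += q0[l]*q1[l];  s0 += q0[l];  s1 += q1[l];
    }
    const float expanded = d0*d1*s01 + d0*m1*s0 + d1*m0*s1 + QK*m0*m1;

    printf("%f vs %f\n", direct, expanded);
    assert(fabsf(direct - expanded) < 1e-2f);
    return 0;
}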
@@ -1492,9 +1289,9 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
} }
inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * restrict x, const float v) { inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * restrict x, const float v) {
assert(n % 64 == 0); assert(n % QK == 0);
const int nb = n / 64; const int nb = n / QK;
const float * restrict pd = (const float *) (x); const float * restrict pd = (const float *) (x);
const uint8_t * restrict pb = (const uint8_t *) (pd + nb); const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
@@ -1502,9 +1299,9 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
const float d = pd[i]; const float d = pd[i];
const uint8_t * restrict pp = pb + i*32; const uint8_t * restrict pp = pb + i*QK/2;
for (int l = 0; l < 64; l += 2) { for (int l = 0; l < QK; l += 2) {
const uint8_t vi = pp[l/2]; const uint8_t vi = pp[l/2];
const int8_t vi0 = vi & 0xf; const int8_t vi0 = vi & 0xf;
@@ -1513,22 +1310,22 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
const float v0 = (vi0 - 8)*d; const float v0 = (vi0 - 8)*d;
const float v1 = (vi1 - 8)*d; const float v1 = (vi1 - 8)*d;
y[i*64 + l + 0] += v0*v; y[i*QK + l + 0] += v0*v;
y[i*64 + l + 1] += v1*v; y[i*QK + l + 1] += v1*v;
assert(!isnan(y[i*64 + l + 0])); assert(!isnan(y[i*QK + l + 0]));
assert(!isnan(y[i*64 + l + 1])); assert(!isnan(y[i*QK + l + 1]));
assert(!isinf(y[i*64 + l + 0])); assert(!isinf(y[i*QK + l + 0]));
assert(!isinf(y[i*64 + l + 1])); assert(!isinf(y[i*QK + l + 1]));
//printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1); //printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1);
} }
} }
} }
inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * restrict x, const float v) { inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * restrict x, const float v) {
assert(n % 64 == 0); assert(n % QK == 0);
const int nb = n / 64; const int nb = n / QK;
const float * restrict pm = (const float *) (x); const float * restrict pm = (const float *) (x);
const float * restrict pd = (const float *) (pm + nb); const float * restrict pd = (const float *) (pm + nb);
@@ -1538,9 +1335,9 @@ inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * res
const float m = pm[i]; const float m = pm[i];
const float d = pd[i]; const float d = pd[i];
const uint8_t * restrict pp = pb + i*32; const uint8_t * restrict pp = pb + i*QK/2;
for (int l = 0; l < 64; l += 2) { for (int l = 0; l < QK; l += 2) {
const uint8_t vi = pp[l/2]; const uint8_t vi = pp[l/2];
const uint8_t vi0 = vi & 0xf; const uint8_t vi0 = vi & 0xf;
@@ -1549,13 +1346,13 @@ inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * res
const float v0 = d*vi0 + m; const float v0 = d*vi0 + m;
const float v1 = d*vi1 + m; const float v1 = d*vi1 + m;
y[i*64 + l + 0] += v0*v; y[i*QK + l + 0] += v0*v;
y[i*64 + l + 1] += v1*v; y[i*QK + l + 1] += v1*v;
assert(!isnan(y[i*64 + l + 0])); assert(!isnan(y[i*QK + l + 0]));
assert(!isnan(y[i*64 + l + 1])); assert(!isnan(y[i*QK + l + 1]));
assert(!isinf(y[i*64 + l + 0])); assert(!isinf(y[i*QK + l + 0]));
assert(!isinf(y[i*64 + l + 1])); assert(!isinf(y[i*QK + l + 1]));
//printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1); //printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1);
} }
} }
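The two mad routines above are the quantized form of an axpy: y[l] += v * dequantize(x[l]), with the dequantization folded into the same loop. A minimal sketch of the inner step for one packed byte of a q4_0 row, using the same nibble and bias convention (values are illustrative):

#include <stdint.h>
#include <stdio.h>

int main(void) {
    const float d = 0.25f;    // block delta
    const float v = 2.0f;     // the scalar multiplier passed to the mad routine
    float y[2] = { 1.0f, 1.0f };

    const uint8_t byte = 0xf5;                 // packs the codes 5 and 15, i.e. -3 and 7
    y[0] += ((byte & 0xf) - 8) * d * v;        // y[0] += (-3 * 0.25) * 2 = -1.5
    y[1] += ((byte >> 4)  - 8) * d * v;        // y[1] += ( 7 * 0.25) * 2 =  3.5

    printf("y = %.2f, %.2f\n", y[0], y[1]);    // -0.50, 4.50
    return 0;
}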
@@ -1685,8 +1482,8 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
// //
static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
64, QK,
64, QK,
1, 1,
1, 1,
1, 1,
@@ -1697,8 +1494,8 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5"); static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
sizeof(float ) + 32, sizeof(float ) + QK/2,
sizeof(float )*2 + 32, sizeof(float )*2 + QK/2,
sizeof(int8_t ), sizeof(int8_t ),
sizeof(int16_t), sizeof(int16_t),
sizeof(int32_t), sizeof(int32_t),
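With the table entries above, the size of a whole tensor follows as nelements / GGML_BLCK_SIZE * GGML_TYPE_SIZE. A quick sketch of what that works out to for a 4096x4096 weight matrix, assuming QK = 64:

#include <stddef.h>
#include <stdio.h>

#define QK 64

int main(void) {
    const long ne = 4096L * 4096L;                              // elements in one weight matrix

    const size_t f32_bytes  = (size_t)ne * 4;                   // 64 MiB
    const size_t q4_0_bytes = (size_t)(ne / QK) * (4 + QK/2);   // 36 bytes per block, 9 MiB
    const size_t q4_1_bytes = (size_t)(ne / QK) * (8 + QK/2);   // 40 bytes per block, 10 MiB

    printf("f32 : %zu bytes\n", f32_bytes);
    printf("q4_0: %zu bytes (%.1fx smaller)\n", q4_0_bytes, (double)f32_bytes / q4_0_bytes);
    printf("q4_1: %zu bytes (%.1fx smaller)\n", q4_1_bytes, (double)f32_bytes / q4_1_bytes);
    return 0;
}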
