ggml : q4_0 quantization support

gq
Georgi Gerganov 2 years ago
parent 751aa84f1a
commit a37776ddc0
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -117,19 +117,20 @@ def convert_to_ftype(data, ftype):
# populate the destination array # populate the destination array
n = data.size n = data.size
for i in range(0, n, 64): nr = data.shape[-1]
d = pd[i//64][0] nn = nr//64
b = pb[i:i+64].reshape(-1) for i in range(0, n, nr):
#print("d:", d) for j in range(0, nr, 64):
#print("b:", b) d = pd[(i//nr)*nn + j//64][0]
b = pb[i+j:i+j+64].reshape(-1)
db = struct.unpack("4B", struct.pack("f", d)) db = struct.unpack("4B", struct.pack("f", d))
dst[(i//64)*36 + 0] = db[0] dst[(i//nr)*nn*36 + (j//64)*4 + 0] = db[0]
dst[(i//64)*36 + 1] = db[1] dst[(i//nr)*nn*36 + (j//64)*4 + 1] = db[1]
dst[(i//64)*36 + 2] = db[2] dst[(i//nr)*nn*36 + (j//64)*4 + 2] = db[2]
dst[(i//64)*36 + 3] = db[3] dst[(i//nr)*nn*36 + (j//64)*4 + 3] = db[3]
for j in range(32): for k in range(32):
dst[(i//64)*36 + 4 + j] = b[j] | (b[j+1] << 4) dst[(i//nr)*nn*36 + nn*4 + (j//64)*32 + k] = b[k] | (b[k+1] << 4)
return dst return dst

@ -375,7 +375,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
__m256 amaxv[8]; __m256 amaxv[8];
// TODO: unroll these loops (maybe not? this is more compact) // TODO: unroll these loops (maybe not? this is more compact)
for (int l = 0; l < 8; l++) srcv[l] = _mm256_loadu_ps(x + i*QK + 8*l); for (int l = 0; l < 8; l++) srcv[l] = _mm256_loadu_ps(x + i*64 + 8*l);
for (int l = 0; l < 8; l++) asrcv[l] = _mm256_and_ps(srcv[l], (__m256) _mm256_set1_epi32(0x7fffffff)); for (int l = 0; l < 8; l++) asrcv[l] = _mm256_and_ps(srcv[l], (__m256) _mm256_set1_epi32(0x7fffffff));
for (int l = 0; l < 4; l++) amaxv[2*l] = _mm256_max_ps(asrcv[2*l], asrcv[2*l+1]); for (int l = 0; l < 4; l++) amaxv[2*l] = _mm256_max_ps(asrcv[2*l], asrcv[2*l+1]);
for (int l = 0; l < 2; l++) amaxv[4*l] = _mm256_max_ps(amaxv[4*l], amaxv[4*l+2]); for (int l = 0; l < 2; l++) amaxv[4*l] = _mm256_max_ps(amaxv[4*l], amaxv[4*l+2]);
@ -470,7 +470,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
assert(vi0 >= 0 && vi0 < 16); assert(vi0 >= 0 && vi0 < 16);
assert(vi1 >= 0 && vi1 < 16); assert(vi1 >= 0 && vi1 < 16);
pp[l] = vi0 | (vi1 << 4); pp[l/2] = vi0 | (vi1 << 4);
} }
memcpy(pb + i*32, pp, sizeof(pp)); memcpy(pb + i*32, pp, sizeof(pp));
@ -504,6 +504,10 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
y[i*64 + l + 0] = v0; y[i*64 + l + 0] = v0;
y[i*64 + l + 1] = v1; y[i*64 + l + 1] = v1;
assert(!isnan(y[i*64 + l + 0]));
assert(!isnan(y[i*64 + l + 1]));
//printf("v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1);
} }
} }
} }
@ -1085,6 +1089,165 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
*s = sumf; *s = sumf;
} }
inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
const int nb = n / 64;
const float * restrict pd0 = (const float *) x;
const float * restrict pd1 = (const float *) y;
const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb);
const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb);
float sumf = 0.0;
#if 0
// scalar
for (int i = 0; i < nb; i++) {
const float d0 = pd0[i];
const float d1 = pd1[i];
const uint8_t * restrict p0 = pb0 + i*32;
const uint8_t * restrict p1 = pb1 + i*32;
for (int j = 0; j < 32; j++) {
const uint8_t v0 = p0[j];
const uint8_t v1 = p1[j];
const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
const float f1 = d0*((int8_t) (v0 >> 4) - 8);
const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
const float f3 = d1*((int8_t) (v1 >> 4) - 8);
sumf += f0*f2 + f1*f3;
}
}
#else
#if defined(__AVX2__)
__m256 sum11 = _mm256_setzero_ps();
for (int i = 0; i < nb; i++) {
const float d0 = pd0[i];
const float d1 = pd1[i];
const uint8_t * restrict p0 = pb0 + i*32;
const uint8_t * restrict p1 = pb1 + i*32;
const __m256 d0v = _mm256_set1_ps(d0);
const __m256 d1v = _mm256_set1_ps(d1);
const __m256 d0d1v = _mm256_mul_ps(d0v, d1v);
const __m256i m4b = _mm256_set1_epi8(0xf);
// 64 x 4
const __m256i v0 = _mm256_loadu_si256((__m256i *) p0);
const __m256i v1 = _mm256_loadu_si256((__m256i *) p1);
// 32 x 8
__m256i v0l = _mm256_and_si256(v0, m4b);
__m256i v1l = _mm256_and_si256(v1, m4b);
__m256i v0h = _mm256_and_si256(_mm256_srli_epi16(v0, 4), m4b);
__m256i v1h = _mm256_and_si256(_mm256_srli_epi16(v1, 4), m4b);
// sub 8
v0l = _mm256_sub_epi8(v0l, _mm256_set1_epi8(8));
v0h = _mm256_sub_epi8(v0h, _mm256_set1_epi8(8));
v1l = _mm256_sub_epi8(v1l, _mm256_set1_epi8(8));
v1h = _mm256_sub_epi8(v1h, _mm256_set1_epi8(8));
// abs
const __m256i v0la = _mm256_sign_epi8(v0l, v0l);
const __m256i v0ha = _mm256_sign_epi8(v0h, v0h);
// sign
const __m256i v1ls = _mm256_sign_epi8(v1l, v0l);
const __m256i v1hs = _mm256_sign_epi8(v1h, v0h);
const __m256i pl = _mm256_maddubs_epi16(v0la, v1ls);
const __m256i ph = _mm256_maddubs_epi16(v0ha, v1hs);
const __m256i p16 = _mm256_add_epi16(ph, pl);
const __m256i p = _mm256_madd_epi16(_mm256_set1_epi16(1), p16);
sum11 = _mm256_fmadd_ps(d0d1v, _mm256_cvtepi32_ps(p), sum11);
}
sumf = _mm256_hadd_ps_gg(sum11);
#elif defined (__ARM_NEON)
float sum11 = 0.0f;
for (int i = 0; i < nb; i++) {
const float d0 = pd0[i];
const float d1 = pd1[i];
const uint8_t * restrict p0 = pb0 + i*32;
const uint8_t * restrict p1 = pb1 + i*32;
const uint8x16_t m4b = vdupq_n_u8(0xf);
const int8x16_t s8b = vdupq_n_s8(0x8);
const uint8x16_t v0_0 = vld1q_u8(p0);
const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
const uint8x16_t v1_0 = vld1q_u8(p1);
const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
// 4-bit -> 8-bit
const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
// sub 8
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
// dot product into int16x8_t
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
const int16x8_t pl0 = vaddq_s16(pl0l, pl0h);
const int16x8_t pl1 = vaddq_s16(pl1l, pl1h);
const int16x8_t ph0 = vaddq_s16(ph0l, ph0h);
const int16x8_t ph1 = vaddq_s16(ph1l, ph1h);
const int16x8_t pl = vaddq_s16(pl0, pl1);
const int16x8_t ph = vaddq_s16(ph0, ph1);
const int16x8_t p = vaddq_s16(pl, ph);
// scalar
sum11 += d0*d1*vaddvq_u16(p);
}
sumf = sum11;
#endif
#endif
*s = sumf;
}
// compute GGML_VEC_DOT_UNROLL dot products at once // compute GGML_VEC_DOT_UNROLL dot products at once
// xs - x row stride in bytes // xs - x row stride in bytes
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
@ -1202,6 +1365,40 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
#endif #endif
} }
inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * restrict x, const float v) {
assert(n % 64 == 0);
const int nb = n / 64;
const float * restrict pd = (const float *) (x);
const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
for (int i = 0; i < nb; i++) {
const float d = pd[i];
const uint8_t * restrict pp = pb + i*32;
for (int l = 0; l < 64; l += 2) {
const uint8_t vi = pp[l/2];
const int8_t vi0 = vi & 0xf;
const int8_t vi1 = vi >> 4;
const float v0 = (vi0 - 8)*d;
const float v1 = (vi1 - 8)*d;
y[i*64 + l + 0] += v0*v;
y[i*64 + l + 1] += v1*v;
assert(!isnan(y[i*64 + l + 0]));
assert(!isnan(y[i*64 + l + 1]));
assert(!isinf(y[i*64 + l + 0]));
assert(!isinf(y[i*64 + l + 1]));
//printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1);
}
}
}
//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
#if defined(GGML_SIMD) #if defined(GGML_SIMD)
@ -1808,8 +2005,8 @@ struct ggml_tensor * ggml_new_tensor_impl(
size_t size_needed = 0; size_t size_needed = 0;
if (data == NULL) { if (data == NULL) {
size_needed += GGML_TYPE_SIZE[type]; size_needed += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]);
for (int i = 0; i < n_dims; i++) { for (int i = 1; i < n_dims; i++) {
size_needed *= ne[i]; size_needed *= ne[i];
} }
// align to GGML_MEM_ALIGN // align to GGML_MEM_ALIGN
@ -1902,7 +2099,8 @@ struct ggml_tensor * ggml_new_tensor_impl(
} }
result->nb[0] = GGML_TYPE_SIZE[type]; result->nb[0] = GGML_TYPE_SIZE[type];
for (int i = 1; i < GGML_MAX_DIMS; i++) { result->nb[1] = result->nb[0]*(result->ne[0]/GGML_BLCK_SIZE[type]);
for (int i = 2; i < GGML_MAX_DIMS; i++) {
result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; result->nb[i] = result->nb[i - 1]*result->ne[i - 1];
} }
@ -4618,7 +4816,7 @@ static bool ggml_compute_forward_mul_mat_use_blas(
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ( if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && (
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32) (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)
)) { )) {
//printf("BLAS: %d %d %d\n", ne0, ne1, ne10); printf("BLAS: %d %d %d\n", ne0, ne1, ne10);
return true; return true;
} }
@ -5177,6 +5375,306 @@ static void ggml_compute_forward_mul_mat_f16_f32(
//} //}
} }
static void ggml_compute_forward_mul_mat_q4_0_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
int64_t t0 = ggml_perf_time_us();
UNUSED(t0);
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];
const int ne10 = src1->ne[0];
const int ne11 = src1->ne[1];
const int ne12 = src1->ne[2];
const int ne13 = src1->ne[3];
const int ne0 = dst->ne[0];
const int ne1 = dst->ne[1];
const int ne2 = dst->ne[2];
const int ne3 = dst->ne[3];
const int ne = ne0*ne1*ne2*ne3;
const int nb00 = src0->nb[0];
const int nb01 = src0->nb[1];
const int nb02 = src0->nb[2];
const int nb03 = src0->nb[3];
const int nb10 = src1->nb[0];
const int nb11 = src1->nb[1];
const int nb12 = src1->nb[2];
const int nb13 = src1->nb[3];
const int nb0 = dst->nb[0];
const int nb1 = dst->nb[1];
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
const int ith = params->ith;
const int nth = params->nth;
GGML_ASSERT(ne02 == ne12);
GGML_ASSERT(ne03 == ne13);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);
// TODO: we don't support permuted src0
GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0] || nb01 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0]);
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne3 == ne03);
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
//
// nb00 < nb01 - src0 is transposed
// compute by src0 columns
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
GGML_ASSERT(nb10 == sizeof(float));
if (params->ith != 0) {
return;
}
if (params->type == GGML_TASK_INIT) {
return;
}
if (params->type == GGML_TASK_FINALIZE) {
return;
}
float * const wdata = params->wdata;
for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
{
int id = 0;
for (int i01 = 0; i01 < ne01; ++i01) {
//for (int i00 = 0; i00 < ne00; ++i00) {
// wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
//}
dequantize_row_q4_0((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
id += ne00;
}
}
const float * x = wdata;
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
// float * z = wdata + ne00*ne01;
// z = x * yT
//{
// cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
// ne01, ne11, ne00,
// 1.0f, x, ne00,
// y, ne00,
// 0.0f, z, ne11);
//}
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
// transpose z
//for (int j = 0; j < ne11; ++j) {
// for (int i = 0; i < ne01; ++i) {
// d[j*ne01 + i] = z[i*ne11 + j];
// }
//}
{
#if 1
// zT = y * xT
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
ne11, ne01, ne10,
1.0f, y, ne00,
x, ne00,
0.0f, d, ne01);
#else
// zT = (xT * y)T
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans,
ne01, ne11, ne10,
1.0f, x, ne00,
y, ne00,
0.0f, d, ne01);
#endif
}
}
}
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
return;
}
#endif
if (params->type == GGML_TASK_INIT) {
//printf("HHHHHHHHH ith = %d, nth = %d\n", ith, nth);
if (nb01 >= nb00) {
char * wdata = params->wdata;
for (int i13 = 0; i13 < ne13; ++i13) {
for (int i12 = 0; i12 < ne12; ++i12) {
for (int i11 = 0; i11 < ne11; ++i11) {
//for (int i10 = 0; i10 < ne10; ++i10) {
// wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
//}
quantize_row_q4_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
}
}
}
return;
}
// TODO: fix this memset (wsize is overestimated)
memset(params->wdata, 0, params->wsize);
return;
}
if (params->type == GGML_TASK_FINALIZE) {
if (nb01 >= nb00) {
return;
}
float * const wdata = params->wdata;
// cols per thread
const int dc = (ne + nth - 1)/nth;
// col range for this thread
const int ic0 = dc*ith;
const int ic1 = MIN(ic0 + dc, ne);
ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0);
for (int k = 1; k < nth; k++) {
ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0);
}
return;
}
if (nb01 >= nb00) {
// TODO: do not support transposed src1
// parallelize by src0 rows using ggml_vec_dot_q4_0
// total rows in src0
const int nr = ne01*ne02*ne03;
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
void * wdata = params->wdata;
for (int ir = ir0; ir < ir1; ++ir) {
// src0 indices
const int i03 = ir/(ne02*ne01);
const int i02 = (ir - i03*ne02*ne01)/ne01;
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
const int i13 = i03;
const int i12 = i02;
const int i0 = i01;
const int i2 = i02;
const int i3 = i03;
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]);
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
assert(ne00 % 32 == 0);
for (int ic = 0; ic < ne11; ++ic) {
ggml_vec_dot_q4_0(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0])));
}
}
} else {
//printf("AAAAA ith = %d, nth = %d\n", ith, nth);
// parallelize by src1 columns using ggml_vec_mad_q4_0
// each thread has its own work data
// during FINALIZE we accumulate all work data into dst
// total columns in src1
const int nc = ne10;
// columns per thread
const int dc = (nc + nth - 1)/nth;
// column range for this thread
const int ic0 = dc*ith;
const int ic1 = MIN(ic0 + dc, nc);
// work data for thread
const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
float * const wdata = params->wdata;
for (int i13 = 0; i13 < ne13; ++i13) {
for (int i12 = 0; i12 < ne12; ++i12) {
for (int i11 = 0; i11 < ne11; ++i11) {
// dst indices
const int i1 = i11;
const int i2 = i12;
const int i3 = i13;
float * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
for (int ic = ic0; ic < ic1; ++ic) {
// src1 indices
const int i10 = ic;
// src0 indices
const int i03 = i13;
const int i02 = i12;
const int i00 = ic;
assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
void * src0_col = (void *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03));
float src1_val = *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
ggml_vec_mad_q4_0(ne01, dst_row, src0_col, src1_val);
}
}
}
}
}
//int64_t t1 = ggml_time_us();
//static int64_t acc = 0;
//acc += t1 - t0;
//if (t1 - t0 > 10) {
// printf("\n");
// printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
// printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
// printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
// printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
//}
}
static void ggml_compute_forward_mul_mat( static void ggml_compute_forward_mul_mat(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
@ -5185,7 +5683,7 @@ static void ggml_compute_forward_mul_mat(
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
{ {
GGML_ASSERT(false); // TODO: implement ggml_compute_forward_mul_mat_q4_0_f32(params, src0, src1, dst);
} break; } break;
case GGML_TYPE_F16: case GGML_TYPE_F16:
{ {
@ -5319,7 +5817,7 @@ static void ggml_compute_forward_transpose(
// ggml_compute_forward_get_rows // ggml_compute_forward_get_rows
static void ggml_compute_forward_get_rows_f16( static void ggml_compute_forward_get_rows_q4_0(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1, const struct ggml_tensor * src1,
@ -5335,19 +5833,18 @@ static void ggml_compute_forward_get_rows_f16(
assert( dst->ne[0] == nc); assert( dst->ne[0] == nc);
assert( dst->ne[1] == nr); assert( dst->ne[1] == nr);
assert(src0->nb[0] == sizeof(ggml_fp16_t)); assert(src0->nb[0] == GGML_TYPE_SIZE[GGML_TYPE_Q4_0]);
for (int i = 0; i < nr; ++i) { for (int i = 0; i < nr; ++i) {
const int r = ((int32_t *) src1->data)[i]; const int r = ((int32_t *) src1->data)[i];
for (int j = 0; j < nc; ++j) { dequantize_row_q4_0(
ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j]; (const void *) ((char *) src0->data + r*src0->nb[1]),
((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v); (float *) ((char *) dst->data + i*dst->nb[1]), nc);
}
} }
} }
static void ggml_compute_forward_get_rows_q4_0( static void ggml_compute_forward_get_rows_f16(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1, const struct ggml_tensor * src1,
@ -5363,14 +5860,15 @@ static void ggml_compute_forward_get_rows_q4_0(
assert( dst->ne[0] == nc); assert( dst->ne[0] == nc);
assert( dst->ne[1] == nr); assert( dst->ne[1] == nr);
assert(src0->nb[0] == GGML_TYPE_SIZE[GGML_TYPE_Q4_0]); assert(src0->nb[0] == sizeof(ggml_fp16_t));
for (int i = 0; i < nr; ++i) { for (int i = 0; i < nr; ++i) {
const int r = ((int32_t *) src1->data)[i]; const int r = ((int32_t *) src1->data)[i];
dequantize_row_q4_0( for (int j = 0; j < nc; ++j) {
(const void *) ((char *) src0->data + r*src0->nb[1]), ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
(float *) ((char *) dst->data + i*dst->nb[1]), nc); ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
}
} }
} }
@ -5525,6 +6023,7 @@ static void ggml_compute_forward_soft_max_f32(
#ifndef NDEBUG #ifndef NDEBUG
for (int i = 0; i < nc; ++i) { for (int i = 0; i < nc; ++i) {
//printf("p[%d] = %f\n", i, p[i]);
assert(!isnan(p[i])); assert(!isnan(p[i]));
} }
#endif #endif
@ -7651,6 +8150,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
// TODO: better way to determine if the matrix is transposed // TODO: better way to determine if the matrix is transposed
if (node->src0->nb[1] < node->src0->nb[0]) { if (node->src0->nb[1] < node->src0->nb[0]) {
cur = ggml_nbytes(node)*node->n_tasks; // TODO: this can become (n_tasks-1) cur = ggml_nbytes(node)*node->n_tasks; // TODO: this can become (n_tasks-1)
// TODO: overestimated by factor of x2 for FP16
} else { } else {
if (node->src0->type == GGML_TYPE_F16 && if (node->src0->type == GGML_TYPE_F16 &&
node->src1->type == GGML_TYPE_F32) { node->src1->type == GGML_TYPE_F32) {
@ -7663,7 +8163,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
//printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]); //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
//printf("cur = %zu\n", cur); //printf("cur = %zu\n", cur);
} else { } else {
cur = GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1); cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
} }
#else #else
cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1); cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
@ -7673,7 +8173,16 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
cur = 0; cur = 0;
} else if (node->src0->type == GGML_TYPE_Q4_0 && } else if (node->src0->type == GGML_TYPE_Q4_0 &&
node->src1->type == GGML_TYPE_F32) { node->src1->type == GGML_TYPE_F32) {
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
node->n_tasks = 1;
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
} else {
cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]; cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
}
#else
cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
#endif
} else { } else {
GGML_ASSERT(false); GGML_ASSERT(false);
} }

Loading…
Cancel
Save