diff --git a/ggml.c b/ggml.c index d67612c..fdbbde2 100644 --- a/ggml.c +++ b/ggml.c @@ -13,6 +13,7 @@ #include #include #include +#include // if C99 - static_assert is noop // ref: https://stackoverflow.com/a/53923785/4039976 @@ -348,6 +349,210 @@ int64_t ggml_cycles_per_ms(void) { static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); +// +// quantization +// + +#define QK 32 + +// method 5 +// blocks of QK elements +// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors) +void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { + assert(k % QK == 0); + + const int nb = k / QK; + + float * restrict pd = (float *) (y); + uint8_t * restrict pb = (uint8_t *) (pd + nb); + + uint8_t pp[QK/2]; + +#if __ARM_NEON +#if QK == 32 + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l); + for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]); + + for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]); + for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]); + for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]); + + amax = MAX( + MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)), + MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3))); + + const float d = amax / ((1 << 3) - 1); + const float id = d ? 1.0/d : 0.0; + + pd[i] = d; + + for (int l = 0; l < 8; l++) { + const float32x4_t v = vmulq_n_f32(srcv[l], id); + const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f)); + const int32x4_t vi = vcvtq_s32_f32(vf); + + pp[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4); + pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4); + } + + memcpy(pb + i*16, pp, sizeof(pp)); + } +#else +#error "not implemented for QK" +#endif +#else + // scalar + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int l = 0; l < QK; l++) { + const float v = x[i*QK + l]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax / ((1 << 3) - 1); + const float id = d ? 1.0f/d : 0.0f; + + pd[i] = d; + + for (int l = 0; l < QK; l += 2) { + const float v0 = x[i*QK + l + 0]*id; + const float v1 = x[i*QK + l + 1]*id; + + const uint8_t vi0 = ((int8_t) (round(v0))) + 8; + const uint8_t vi1 = ((int8_t) (round(v1))) + 8; + + assert(vi0 >= 0 && vi0 < 16); + assert(vi1 >= 0 && vi1 < 16); + + pp[l/2] = vi0 | (vi1 << 4); + } + + memcpy(pb + i*QK/2, pp, sizeof(pp)); + } +#endif +} + +// method 4 +// blocks of QK elements +// represented with 2 floats (min + delta) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors) +void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { + assert(k % QK == 0); + + const int nb = k / QK; + + float * restrict pm = (float *) (y); + float * restrict pd = (float *) (pm + nb); + uint8_t * restrict pb = (uint8_t *) (pd + nb); + + uint8_t pp[QK/2]; + + for (int i = 0; i < nb; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int l = 0; l < QK; l++) { + const float v = x[i*QK + l]; + if (v < min) min = v; + if (v > max) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + pm[i] = min; + pd[i] = d; + + for (int l = 0; l < QK; l += 2) { + const float v0 = (x[i*QK + l + 0] - min)*id; + const float v1 = (x[i*QK + l + 1] - min)*id; + + const uint8_t vi0 = round(v0); + const uint8_t vi1 = round(v1); + + assert(vi0 >= 0 && vi0 < 16); + assert(vi1 >= 0 && vi1 < 16); + + pp[l/2] = vi0 | (vi1 << 4); + } + + memcpy(pb + i*QK/2, pp, sizeof(pp)); + } +} + +// TODO: vectorize +void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { + assert(k % QK == 0); + + const int nb = k / QK; + + const float * restrict pd = (const float *) (x); + const uint8_t * restrict pb = (const uint8_t *) (pd + nb); + + // scalar + for (int i = 0; i < nb; i++) { + const float d = pd[i]; + + const uint8_t * restrict pp = pb + i*QK/2; + + for (int l = 0; l < QK; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; + + const float v0 = (vi0 - 8)*d; + const float v1 = (vi1 - 8)*d; + + y[i*QK + l + 0] = v0; + y[i*QK + l + 1] = v1; + + assert(!isnan(y[i*QK + l + 0])); + assert(!isnan(y[i*QK + l + 1])); + } + } +} + +void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) { + assert(k % QK == 0); + + const int nb = k / QK; + + const float * restrict pm = (const float *) (x); + const float * restrict pd = (const float *) (pm + nb); + const uint8_t * restrict pb = (const uint8_t *) (pd + nb); + + for (int i = 0; i < nb; i++) { + const float m = pm[i]; + const float d = pd[i]; + + const uint8_t * restrict pp = pb + i*QK/2; + + for (int l = 0; l < QK; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; + + const float v0 = vi0*d + m; + const float v1 = vi1*d + m; + + y[i*QK + l + 0] = v0; + y[i*QK + l + 1] = v1; + + assert(!isnan(y[i*QK + l + 0])); + assert(!isnan(y[i*QK + l + 1])); + } + } +} + // // simd mappings // @@ -925,6 +1130,209 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t *s = sumf; } +inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict x, const void * restrict y) { + const int nb = n / QK; + + assert(n % QK == 0); + assert(nb % 2 == 0); + + const float * restrict pd0 = (const float *) x; + const float * restrict pd1 = (const float *) y; + + const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb); + const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb); + + float sumf = 0.0; + +#ifdef __ARM_NEON +#if QK == 32 + float sum0 = 0.0f; + float sum1 = 0.0f; + + for (int i = 0; i < nb; i += 2) { + const float d0_0 = pd0[i + 0]; + const float d1_0 = pd1[i + 0]; + const float d0_1 = pd0[i + 1]; + const float d1_1 = pd1[i + 1]; + + //printf("d0_0: %f, d1_0: %f, d0_1: %f, d1_1: %f\n", d0_0, d1_0, d0_1, d1_1); + + const uint8_t * restrict p0 = pb0 + i*16; + const uint8_t * restrict p1 = pb1 + i*16; + + const uint8x16_t m4b = vdupq_n_u8(0xf); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_0 = vld1q_u8(p0); + const uint8x16_t v1_0 = vld1q_u8(p1); + const uint8x16_t v0_1 = vld1q_u8(p0 + 16); + const uint8x16_t v1_1 = vld1q_u8(p1 + 16); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b)); + const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b)); + + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4)); + + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b)); + const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b)); + + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4)); + + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b); + + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b); + + const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); + const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b); + + const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b); + + // dot product into int16x8_t + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); + + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls)); + + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs)); + + const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h); + const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h); + + const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h); + const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h); + + const int16x8_t p_0 = vaddq_s16(pl_0, ph_0); + const int16x8_t p_1 = vaddq_s16(pl_1, ph_1); + + //printf("p_0: %d %d %d %d %d %d %d %d\n", vgetq_lane_s16(p_0, 0), vgetq_lane_s16(p_0, 1), vgetq_lane_s16(p_0, 2), vgetq_lane_s16(p_0, 3), vgetq_lane_s16(p_0, 4), vgetq_lane_s16(p_0, 5), vgetq_lane_s16(p_0, 6), vgetq_lane_s16(p_0, 7)); + //printf("p_1: %d %d %d %d %d %d %d %d\n", vgetq_lane_s16(p_1, 0), vgetq_lane_s16(p_1, 1), vgetq_lane_s16(p_1, 2), vgetq_lane_s16(p_1, 3), vgetq_lane_s16(p_1, 4), vgetq_lane_s16(p_1, 5), vgetq_lane_s16(p_1, 6), vgetq_lane_s16(p_1, 7)); + + // scalar +#if defined(__ARM_FEATURE_QRDMX) + sum0 += d0_0*d1_0*vaddvq_s16(p_0); + sum1 += d0_1*d1_1*vaddvq_s16(p_1); +#else + sum0 += d0_0*d1_0*(vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7)); + sum1 += d0_1*d1_1*(vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7)); +#endif + } + + sumf = sum0 + sum1; + + //printf("sumf SIMD = %f\n", sumf); + + //// scalar + //sumf = 0.0f; + //for (int i = 0; i < nb; i++) { + // const float d0 = pd0[i]; + // const float d1 = pd1[i]; + + // const uint8_t * restrict p0 = pb0 + i*QK/2; + // const uint8_t * restrict p1 = pb1 + i*QK/2; + + // for (int j = 0; j < QK/2; j++) { + // const uint8_t v0 = p0[j]; + // const uint8_t v1 = p1[j]; + + // const float f0 = d0*((int8_t) (v0 & 0xf) - 8); + // const float f1 = d0*((int8_t) (v0 >> 4) - 8); + + // const float f2 = d1*((int8_t) (v1 & 0xf) - 8); + // const float f3 = d1*((int8_t) (v1 >> 4) - 8); + + // sumf += f0*f2 + f1*f3; + // } + //} + //printf("sumf scalar = %f\n", sumf); + //printf("--------\n"); + + //exit(0); +#else +#error "not implemented for QK" +#endif +#else + // scalar + for (int i = 0; i < nb; i++) { + const float d0 = pd0[i]; + const float d1 = pd1[i]; + + const uint8_t * restrict p0 = pb0 + i*QK/2; + const uint8_t * restrict p1 = pb1 + i*QK/2; + + for (int j = 0; j < QK/2; j++) { + const uint8_t v0 = p0[j]; + const uint8_t v1 = p1[j]; + + const float f0 = d0*((int8_t) (v0 & 0xf) - 8); + const float f1 = d0*((int8_t) (v0 >> 4) - 8); + + const float f2 = d1*((int8_t) (v1 & 0xf) - 8); + const float f3 = d1*((int8_t) (v1 >> 4) - 8); + + sumf += f0*f2 + f1*f3; + } + } +#endif + + *s = sumf; +} + +inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { + const int nb = n / QK; + + const float * restrict pm0 = (const float *) x; + const float * restrict pm1 = (const float *) y; + + const float * restrict pd0 = (const float *) (pm0 + nb); + const float * restrict pd1 = (const float *) (pm1 + nb); + + const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb); + const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb); + + float sumf = 0.0; + +#if 1 + // scalar + for (int i = 0; i < nb; i++) { + const float m0 = pm0[i]; + const float m1 = pm1[i]; + + const float d0 = pd0[i]; + const float d1 = pd1[i]; + + const uint8_t * restrict p0 = pb0 + i*QK/2; + const uint8_t * restrict p1 = pb1 + i*QK/2; + + for (int j = 0; j < QK/2; j++) { + const uint8_t v0 = p0[j]; + const uint8_t v1 = p1[j]; + + const float f0 = d0*(v0 & 0xf) + m0; + const float f1 = d0*(v0 >> 4) + m0; + + const float f2 = d1*(v1 & 0xf) + m1; + const float f3 = d1*(v1 >> 4) + m1; + + sumf += f0*f2 + f1*f3; + } + } +#endif + + *s = sumf; +} + // compute GGML_VEC_DOT_UNROLL dot products at once // xs - x row stride in bytes inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { @@ -1042,6 +1450,134 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ #endif } +inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * restrict x, const float v) { + assert(n % QK == 0); + + const int nb = n / QK; + + const float * restrict pd = (const float *) (x); + const uint8_t * restrict pb = (const uint8_t *) (pd + nb); + +#if __ARM_NEON +#if QK == 32 + for (int i = 0; i < nb; ++i) { + const float d0 = pd[i]*v; + + const uint8_t * restrict pp = pb + i*16; + + const uint8x8_t m4b = vdup_n_u8(0xf); + const int8x8_t s8b = vdup_n_s8(0x8); + + const float32x4_t vd = vdupq_n_f32(d0); + + for (int j = 0; j < 2; j++) { + const uint8x8_t vx = vld1_u8(pp + j*8); + + const int8x8_t vxl = vreinterpret_s8_u8(vand_u8(vx, m4b)); + const int8x8_t vxh = vreinterpret_s8_u8(vshr_n_u8(vx, 4)); + + // sub 8 + const int8x8_t vxls = vsub_s8(vxl, s8b); + const int8x8_t vxhs = vsub_s8(vxh, s8b); + + //const int8x8_t vxlt = vzip_s8(vxls, vxhs)[0]; + //const int8x8_t vxht = vzip_s8(vxls, vxhs)[1]; + const int8x8_t vxlt = vzip1_s8(vxls, vxhs); + const int8x8_t vxht = vzip2_s8(vxls, vxhs); + + const int8x16_t vxq = vcombine_s8(vxlt, vxht); + + // convert to 2x int16x8_t + const int16x8_t vxq0 = vmovl_s8(vget_low_s8 (vxq)); + const int16x8_t vxq1 = vmovl_s8(vget_high_s8(vxq)); + + // convert to 4x float32x4_t + const float32x4_t vx0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq0))); + const float32x4_t vx1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq0))); + const float32x4_t vx2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq1))); + const float32x4_t vx3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq1))); + + const float32x4_t vy0 = vld1q_f32(y + i*32 + j*16 + 0); + const float32x4_t vy1 = vld1q_f32(y + i*32 + j*16 + 4); + const float32x4_t vy2 = vld1q_f32(y + i*32 + j*16 + 8); + const float32x4_t vy3 = vld1q_f32(y + i*32 + j*16 + 12); + + const float32x4_t vr0 = vfmaq_f32(vy0, vx0, vd); + const float32x4_t vr1 = vfmaq_f32(vy1, vx1, vd); + const float32x4_t vr2 = vfmaq_f32(vy2, vx2, vd); + const float32x4_t vr3 = vfmaq_f32(vy3, vx3, vd); + + vst1q_f32(y + i*32 + j*16 + 0, vr0); + vst1q_f32(y + i*32 + j*16 + 4, vr1); + vst1q_f32(y + i*32 + j*16 + 8, vr2); + vst1q_f32(y + i*32 + j*16 + 12, vr3); + } + } +#endif +#else + // scalar + for (int i = 0; i < nb; i++) { + const float d = pd[i]; + + const uint8_t * restrict pp = pb + i*QK/2; + + for (int l = 0; l < QK; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; + + const float v0 = (vi0 - 8)*d; + const float v1 = (vi1 - 8)*d; + + y[i*QK + l + 0] += v0*v; + y[i*QK + l + 1] += v1*v; + + assert(!isnan(y[i*QK + l + 0])); + assert(!isnan(y[i*QK + l + 1])); + assert(!isinf(y[i*QK + l + 0])); + assert(!isinf(y[i*QK + l + 1])); + } + } +#endif +} + +inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * restrict x, const float v) { + assert(n % QK == 0); + + const int nb = n / QK; + + const float * restrict pm = (const float *) (x); + const float * restrict pd = (const float *) (pm + nb); + const uint8_t * restrict pb = (const uint8_t *) (pd + nb); + + for (int i = 0; i < nb; i++) { + const float m = pm[i]; + const float d = pd[i]; + + const uint8_t * restrict pp = pb + i*QK/2; + + for (int l = 0; l < QK; l += 2) { + const uint8_t vi = pp[l/2]; + + const uint8_t vi0 = vi & 0xf; + const uint8_t vi1 = vi >> 4; + + const float v0 = d*vi0 + m; + const float v1 = d*vi1 + m; + + y[i*QK + l + 0] += v0*v; + y[i*QK + l + 1] += v1*v; + + assert(!isnan(y[i*QK + l + 0])); + assert(!isnan(y[i*QK + l + 1])); + assert(!isinf(y[i*QK + l + 0])); + assert(!isinf(y[i*QK + l + 1])); + //printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1); + } + } +} + //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #if defined(GGML_SIMD) @@ -1165,7 +1701,21 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x // data types // +static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { + QK, + QK, + 1, + 1, + 1, + 1, + 1, +}; + +static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5"); + static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { + sizeof(float ) + QK/2, + sizeof(float )*2 + QK/2, sizeof(int8_t ), sizeof(int16_t), sizeof(int32_t), @@ -1173,6 +1723,9 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { sizeof(float ), }; +// don't forget to update the array above when adding new types +static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5"); + static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "NONE", @@ -1213,6 +1766,8 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "FLASH_FF", }; +static_assert(GGML_OP_COUNT == 33, "GGML_OP_COUNT != 33"); + static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1253,6 +1808,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "flash_ff(x)", }; +static_assert(GGML_OP_COUNT == 33, "GGML_OP_COUNT != 33"); + // // ggml object // @@ -1380,13 +1937,21 @@ int ggml_nrows(const struct ggml_tensor * tensor) { size_t ggml_nbytes(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - return ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type]; + return (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type]; +} + +int ggml_blck_size(enum ggml_type type) { + return GGML_BLCK_SIZE[type]; } size_t ggml_type_size(enum ggml_type type) { return GGML_TYPE_SIZE[type]; } +float ggml_type_sizef(enum ggml_type type) { + return ((float)(GGML_TYPE_SIZE[type]))/GGML_BLCK_SIZE[type]; +} + size_t ggml_element_size(const struct ggml_tensor * tensor) { return GGML_TYPE_SIZE[tensor->type]; } @@ -1423,7 +1988,7 @@ static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) { return tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] && - tensor->nb[1] == tensor->nb[0]*tensor->ne[0] && + tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/GGML_BLCK_SIZE[tensor->type] && tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; } @@ -1623,8 +2188,8 @@ struct ggml_tensor * ggml_new_tensor_impl( size_t size_needed = 0; if (data == NULL) { - size_needed += GGML_TYPE_SIZE[type]; - for (int i = 0; i < n_dims; i++) { + size_needed += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]); + for (int i = 1; i < n_dims; i++) { size_needed *= ne[i]; } // align to GGML_MEM_ALIGN @@ -1717,7 +2282,8 @@ struct ggml_tensor * ggml_new_tensor_impl( } result->nb[0] = GGML_TYPE_SIZE[type]; - for (int i = 1; i < GGML_MAX_DIMS; i++) { + result->nb[1] = result->nb[0]*(result->ne[0]/GGML_BLCK_SIZE[type]); + for (int i = 2; i < GGML_MAX_DIMS; i++) { result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; } @@ -1814,6 +2380,14 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { char * const data = tensor->data; switch (tensor->type) { + case GGML_TYPE_Q4_0: + { + GGML_ASSERT(false); + } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { assert(tensor->nb[0] == sizeof(int8_t)); @@ -1851,7 +2425,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { } break; case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } @@ -1866,6 +2440,14 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { char * const data = tensor->data; switch (tensor->type) { + case GGML_TYPE_Q4_0: + { + GGML_ASSERT(false); + } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { assert(tensor->nb[0] == sizeof(int8_t)); @@ -1903,7 +2485,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { } break; case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } @@ -1912,6 +2494,14 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { switch (tensor->type) { + case GGML_TYPE_Q4_0: + { + GGML_ASSERT(false); + } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -1948,6 +2538,14 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { switch (tensor->type) { + case GGML_TYPE_Q4_0: + { + GGML_ASSERT(false); + } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -1982,6 +2580,14 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { switch (tensor->type) { + case GGML_TYPE_Q4_0: + { + GGML_ASSERT(false); + } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -2018,6 +2624,14 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { switch (tensor->type) { + case GGML_TYPE_Q4_0: + { + GGML_ASSERT(false); + } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -2108,7 +2722,7 @@ struct ggml_tensor * ggml_add_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - assert(ggml_are_same_shape(a, b)); + GGML_ASSERT(ggml_are_same_shape(a, b)); bool is_node = false; @@ -2147,7 +2761,7 @@ struct ggml_tensor * ggml_sub_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - assert(ggml_are_same_shape(a, b)); + GGML_ASSERT(ggml_are_same_shape(a, b)); bool is_node = false; @@ -2186,7 +2800,7 @@ struct ggml_tensor * ggml_mul_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - assert(ggml_are_same_shape(a, b)); + GGML_ASSERT(ggml_are_same_shape(a, b)); bool is_node = false; @@ -2195,7 +2809,7 @@ struct ggml_tensor * ggml_mul_impl( } if (inplace) { - assert(is_node == false); + GGML_ASSERT(is_node == false); } struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); @@ -2229,7 +2843,7 @@ struct ggml_tensor * ggml_div_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - assert(ggml_are_same_shape(a, b)); + GGML_ASSERT(ggml_are_same_shape(a, b)); bool is_node = false; @@ -2238,7 +2852,7 @@ struct ggml_tensor * ggml_div_impl( } if (inplace) { - assert(is_node == false); + GGML_ASSERT(is_node == false); } struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); @@ -2362,7 +2976,7 @@ struct ggml_tensor * ggml_mean( bool is_node = false; if (a->grad) { - assert(false); // TODO: implement + GGML_ASSERT(false); // TODO: implement is_node = true; } @@ -2383,7 +2997,7 @@ struct ggml_tensor * ggml_repeat( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - assert(ggml_can_repeat(a, b)); + GGML_ASSERT(ggml_can_repeat(a, b)); bool is_node = false; @@ -2619,7 +3233,7 @@ struct ggml_tensor * ggml_norm_impl( bool is_node = false; if (!inplace && (a->grad)) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2651,7 +3265,7 @@ struct ggml_tensor * ggml_mul_mat( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - assert(ggml_can_mul_mat(a, b)); + GGML_ASSERT(ggml_can_mul_mat(a, b)); bool is_node = false; @@ -2677,13 +3291,13 @@ struct ggml_tensor * ggml_scale_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - assert(ggml_is_scalar(b)); - assert(ggml_is_padded_1d(a)); + GGML_ASSERT(ggml_is_scalar(b)); + GGML_ASSERT(ggml_is_padded_1d(a)); bool is_node = false; if (!inplace && (a->grad || b->grad)) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2720,12 +3334,12 @@ struct ggml_tensor * ggml_cpy_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - assert(ggml_nelements(a) == ggml_nelements(b)); + GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); bool is_node = false; if (!inplace && (a->grad || b->grad)) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2760,14 +3374,14 @@ struct ggml_tensor * ggml_reshape( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - assert(ggml_is_contiguous(a)); - assert(ggml_is_contiguous(b)); - assert(ggml_nelements(a) == ggml_nelements(b)); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_is_contiguous(b)); + GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); bool is_node = false; if (a->grad || b->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2786,13 +3400,13 @@ struct ggml_tensor * ggml_reshape_2d( struct ggml_tensor * a, int ne0, int ne1) { - assert(ggml_is_contiguous(a)); - assert(ggml_nelements(a) == ne0*ne1); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_nelements(a) == ne0*ne1); bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2813,13 +3427,13 @@ struct ggml_tensor * ggml_reshape_3d( int ne0, int ne1, int ne2) { - assert(ggml_is_contiguous(a)); - assert(ggml_nelements(a) == ne0*ne1*ne2); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2); bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2842,7 +3456,7 @@ struct ggml_tensor * ggml_view_1d( int ne0, size_t offset) { if (a->grad) { - assert(false); // gradient propagation is not supported + GGML_ASSERT(false); // gradient propagation is not supported } struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset); @@ -2865,7 +3479,7 @@ struct ggml_tensor * ggml_view_2d( size_t nb1, size_t offset) { if (a->grad) { - assert(false); // gradient propagation is not supported + GGML_ASSERT(false); // gradient propagation is not supported } const int ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 }; @@ -2893,22 +3507,22 @@ struct ggml_tensor * ggml_permute( int axis1, int axis2, int axis3) { - assert(axis0 >= 0 && axis0 < GGML_MAX_DIMS); - assert(axis1 >= 0 && axis1 < GGML_MAX_DIMS); - assert(axis2 >= 0 && axis2 < GGML_MAX_DIMS); - assert(axis3 >= 0 && axis3 < GGML_MAX_DIMS); - - assert(axis0 != axis1); - assert(axis0 != axis2); - assert(axis0 != axis3); - assert(axis1 != axis2); - assert(axis1 != axis3); - assert(axis2 != axis3); + GGML_ASSERT(axis0 >= 0 && axis0 < GGML_MAX_DIMS); + GGML_ASSERT(axis1 >= 0 && axis1 < GGML_MAX_DIMS); + GGML_ASSERT(axis2 >= 0 && axis2 < GGML_MAX_DIMS); + GGML_ASSERT(axis3 >= 0 && axis3 < GGML_MAX_DIMS); + + GGML_ASSERT(axis0 != axis1); + GGML_ASSERT(axis0 != axis2); + GGML_ASSERT(axis0 != axis3); + GGML_ASSERT(axis1 != axis2); + GGML_ASSERT(axis1 != axis3); + GGML_ASSERT(axis2 != axis3); bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2953,7 +3567,7 @@ struct ggml_tensor * ggml_transpose( bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2979,12 +3593,12 @@ struct ggml_tensor * ggml_get_rows( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - assert(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32); bool is_node = false; if (a->grad || b->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -3009,7 +3623,7 @@ struct ggml_tensor * ggml_diag_mask_inf( bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -3034,7 +3648,7 @@ struct ggml_tensor * ggml_soft_max( bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -3058,11 +3672,11 @@ struct ggml_tensor * ggml_rope( int n_past, int n_dims, int mode) { - assert(n_past >= 0); + GGML_ASSERT(n_past >= 0); bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -3089,13 +3703,13 @@ struct ggml_tensor * ggml_conv_1d_1s( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - assert(ggml_is_matrix(b)); - assert(a->ne[1] == b->ne[1]); - assert(a->ne[3] == 1); + GGML_ASSERT(ggml_is_matrix(b)); + GGML_ASSERT(a->ne[1] == b->ne[1]); + GGML_ASSERT(a->ne[3] == 1); bool is_node = false; if (a->grad || b->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -3116,13 +3730,13 @@ struct ggml_tensor * ggml_conv_1d_2s( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - assert(ggml_is_matrix(b)); - assert(a->ne[1] == b->ne[1]); - assert(a->ne[3] == 1); + GGML_ASSERT(ggml_is_matrix(b)); + GGML_ASSERT(a->ne[1] == b->ne[1]); + GGML_ASSERT(a->ne[3] == 1); bool is_node = false; if (a->grad || b->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -3145,7 +3759,7 @@ struct ggml_tensor * ggml_flash_attn( struct ggml_tensor * k, struct ggml_tensor * v, bool masked) { - assert(ggml_can_mul_mat(k, q)); + GGML_ASSERT(ggml_can_mul_mat(k, q)); // TODO: check if vT can be multiplied by (k*qT) bool is_node = false; @@ -3177,7 +3791,7 @@ struct ggml_tensor * ggml_flash_ff( struct ggml_tensor * b1, struct ggml_tensor * c0, struct ggml_tensor * c1) { - assert(ggml_can_mul_mat(b0, a)); + GGML_ASSERT(ggml_can_mul_mat(b0, a)); // TODO: more checks bool is_node = false; @@ -3208,7 +3822,7 @@ void ggml_set_param( struct ggml_tensor * tensor) { tensor->is_param = true; - assert(tensor->grad == NULL); + GGML_ASSERT(tensor->grad == NULL); tensor->grad = ggml_dup_tensor(ctx, tensor); } @@ -3218,9 +3832,9 @@ static void ggml_compute_forward_dup_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_is_contiguous(dst)); - assert(ggml_nelements(dst) == ggml_nelements(src0)); + GGML_ASSERT(params->ith == 0); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -3435,6 +4049,8 @@ static void ggml_compute_forward_dup( { ggml_compute_forward_dup_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -3510,13 +4126,15 @@ static void ggml_compute_forward_add( { ggml_compute_forward_add_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -3560,13 +4178,15 @@ static void ggml_compute_forward_sub( { ggml_compute_forward_sub_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -3610,13 +4230,15 @@ static void ggml_compute_forward_mul( { ggml_compute_forward_mul_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -3660,13 +4282,15 @@ static void ggml_compute_forward_div( { ggml_compute_forward_div_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -3706,13 +4330,15 @@ static void ggml_compute_forward_sqr( { ggml_compute_forward_sqr_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -3752,13 +4378,15 @@ static void ggml_compute_forward_sqrt( { ggml_compute_forward_sqrt_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -3808,13 +4436,15 @@ static void ggml_compute_forward_sum( { ggml_compute_forward_sum_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -3883,13 +4513,15 @@ static void ggml_compute_forward_mean( { ggml_compute_forward_mean_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -3945,13 +4577,15 @@ static void ggml_compute_forward_repeat( { ggml_compute_forward_repeat_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -3991,13 +4625,15 @@ static void ggml_compute_forward_abs( { ggml_compute_forward_abs_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -4037,13 +4673,15 @@ static void ggml_compute_forward_sgn( { ggml_compute_forward_sgn_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -4083,13 +4721,15 @@ static void ggml_compute_forward_neg( { ggml_compute_forward_neg_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -4129,13 +4769,15 @@ static void ggml_compute_forward_step( { ggml_compute_forward_step_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -4175,13 +4817,15 @@ static void ggml_compute_forward_relu( { ggml_compute_forward_relu_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -4238,15 +4882,19 @@ static void ggml_compute_forward_gelu( { ggml_compute_forward_gelu_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } + + //printf("XXXXXXXX gelu\n"); } // ggml_compute_forward_norm @@ -4320,13 +4968,15 @@ static void ggml_compute_forward_norm( { ggml_compute_forward_norm_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -4348,9 +4998,8 @@ static bool ggml_compute_forward_mul_mat_use_blas( const int ne1 = dst->ne[1]; // TODO: find the optimal values for these - if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ( - (ne0 >= 32 && ne1 >= 32 && ne10 >= 32) - )) { + if (ggml_is_contiguous(src0) && + ggml_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) { //printf("BLAS: %d %d %d\n", ne0, ne1, ne10); return true; } @@ -4449,25 +5098,631 @@ static void ggml_compute_forward_mul_mat_f32( float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); - // zT = y * xT + // zT = y * xT + { + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne10, + 0.0f, d, ne01); + } + } + } + + //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); + + return; + } +#endif + + if (params->type == GGML_TASK_INIT) { + if (nb01 >= nb00) { + return; + } + + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + if (nb01 >= nb00) { + return; + } + + // TODO: fix this memset (wsize is overestimated) + //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth); + + float * const wdata = params->wdata; + + // cols per thread + const int dc = (ne + nth - 1)/nth; + + // col range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, ne); + + ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0); + + for (int k = 1; k < nth; k++) { + ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0); + } + + return; + } + + if (nb01 >= nb00) { + // TODO: do not support transposed src1 + assert(nb10 == sizeof(float)); + + // parallelize by src0 rows using ggml_vec_dot_f32 + + // total rows in src0 + const int nr = ne01*ne02*ne03; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + for (int ic = 0; ic < ne11; ++ic) { + // src1 indices + const int i13 = i03; + const int i12 = i02; + const int i11 = ic; + + // dst indices + const int i0 = i01; + const int i1 = i11; + const int i2 = i02; + const int i3 = i03; + + ggml_vec_dot_f32(ne00, + (float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)), + (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13))); + } + } + } else { + // parallelize by src1 columns using ggml_vec_mad_f32 + // each thread has its own work data + // during FINALIZE we accumulate all work data into dst + + // total columns in src1 + const int nc = ne10; + + // columns per thread + const int dc = (nc + nth - 1)/nth; + + // column range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, nc); + + // work data for thread + const int wo = (ne + CACHE_LINE_SIZE_F32)*ith; + float * const wdata = params->wdata; + + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + for (int ic = ic0; ic < ic1; ++ic) { + // src1 indices + const int i10 = ic; + + // src0 indices + const int i03 = i13; + const int i02 = i12; + const int i00 = ic; + + // dst indices + const int i1 = i11; + const int i2 = i12; + const int i3 = i13; + + assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); + + ggml_vec_mad_f32(ne01, + (float *) (wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0), + (float *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)), + *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13))); + } + } + } + } + } + + //int64_t t1 = ggml_perf_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void ggml_compute_forward_mul_mat_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; + const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // TODO: we don't support permuted src0 + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t) || nb01 == sizeof(ggml_fp16_t)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + // + // nb00 < nb01 - src0 is transposed + // compute by src0 columns + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->ith != 0) { + return; + } + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + float * const wdata = params->wdata; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + { + int id = 0; + for (int i01 = 0; i01 < ne01; ++i01) { + for (int i00 = 0; i00 < ne00; ++i00) { + wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00)); + } + } + } + + const float * x = wdata; + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + // float * z = wdata + ne00*ne01; + + // z = x * yT + //{ + // cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + // ne01, ne11, ne00, + // 1.0f, x, ne00, + // y, ne00, + // 0.0f, z, ne11); + //} + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + // transpose z + //for (int j = 0; j < ne11; ++j) { + // for (int i = 0; i < ne01; ++i) { + // d[j*ne01 + i] = z[i*ne11 + j]; + // } + //} + + { +#if 1 + // zT = y * xT + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne00, + x, ne00, + 0.0f, d, ne01); +#else + // zT = (xT * y)T + cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, + ne01, ne11, ne10, + 1.0f, x, ne00, + y, ne00, + 0.0f, d, ne01); +#endif + } + } + } + + /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/ + + return; + } +#endif + + if (params->type == GGML_TASK_INIT) { + if (nb01 >= nb00) { + ggml_fp16_t * const wdata = params->wdata; + + int id = 0; + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + for (int i10 = 0; i10 < ne10; ++i10) { + wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10)); + } + } + } + } + + GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize); + + return; + } + + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + if (nb01 >= nb00) { + return; + } + + // TODO: fix this memset (wsize is overestimated) + //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth); + + ggml_fp16_t * const wdata = params->wdata; + + // cols per thread + const int dc = (ne + nth - 1)/nth; + + // col range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, ne); + + for (int i = ic0; i < ic1; ++i) { + ((float *) dst->data)[i] = GGML_FP16_TO_FP32(wdata[i]); + } + + for (int k = 1; k < nth; k++) { + for (int i = ic0; i < ic1; ++i) { + ((float *) dst->data)[i] += GGML_FP16_TO_FP32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]); + } + } + + return; + } + + if (nb01 >= nb00) { + // fp16 -> half the size, so divide by 2 + // TODO: do not support transposed src1 + assert(nb10/2 == sizeof(ggml_fp16_t)); + + // parallelize by src0 rows using ggml_vec_dot_f16 + + // total rows in src0 + const int nr = ne01*ne02*ne03; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + ggml_fp16_t * wdata = params->wdata; + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int i13 = i03; + const int i12 = i02; + + const int i0 = i01; + const int i2 = i02; + const int i3 = i03; + + ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + ggml_fp16_t * src1_col = wdata + ( 0 + i12*ne11 + i13*ne12*ne11)*ne00; + + float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); + + assert(ne00 % 32 == 0); + + for (int ic = 0; ic < ne11; ++ic) { + ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00); + } + } + } else { + // parallelize by src1 columns using ggml_vec_mad_f16 + // each thread has its own work data + // during FINALIZE we accumulate all work data into dst + + // total columns in src1 + const int nc = ne10; + + // columns per thread + const int dc = (nc + nth - 1)/nth; + + // column range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, nc); + + // work data for thread + const int wo = (ne + CACHE_LINE_SIZE_F32)*ith; + ggml_fp16_t * const wdata = params->wdata; + + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + // dst indices + const int i1 = i11; + const int i2 = i12; + const int i3 = i13; + + ggml_fp16_t * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0; + + for (int ic = ic0; ic < ic1; ++ic) { + // src1 indices + const int i10 = ic; + + // src0 indices + const int i03 = i13; + const int i02 = i12; + const int i00 = ic; + + assert(sizeof(ggml_fp16_t)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); + + ggml_fp16_t * src0_col = (ggml_fp16_t *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)); + float src1_val = * (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + + ggml_vec_mad_f16(ne01, dst_row, src0_col, src1_val); + } + } + } + } + } + + //int64_t t1 = ggml_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void ggml_compute_forward_mul_mat_q4_0_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; + const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // TODO: we don't support permuted src0 + GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0] || nb01 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0]); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + // + // nb00 < nb01 - src0 is transposed + // compute by src0 columns + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->ith != 0) { + return; + } + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + float * const wdata = params->wdata; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + { + int id = 0; + for (int i01 = 0; i01 < ne01; ++i01) { + //for (int i00 = 0; i00 < ne00; ++i00) { + // wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00)); + //} + dequantize_row_q4_0((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00); + id += ne00; + } + } + + const float * x = wdata; + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + // float * z = wdata + ne00*ne01; + + // z = x * yT + //{ + // cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + // ne01, ne11, ne00, + // 1.0f, x, ne00, + // y, ne00, + // 0.0f, z, ne11); + //} + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + // transpose z + //for (int j = 0; j < ne11; ++j) { + // for (int i = 0; i < ne01; ++i) { + // d[j*ne01 + i] = z[i*ne11 + j]; + // } + //} + { +#if 1 + // zT = y * xT cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, ne11, ne01, ne10, - 1.0f, y, ne10, - x, ne10, + 1.0f, y, ne00, + x, ne00, + 0.0f, d, ne01); +#else + // zT = (xT * y)T + cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, + ne01, ne11, ne10, + 1.0f, x, ne00, + y, ne00, 0.0f, d, ne01); +#endif } } } - //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); + /*printf("CBLAS Q4_0 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/ return; } #endif if (params->type == GGML_TASK_INIT) { + //printf("HHHHHHHHH ith = %d, nth = %d\n", ith, nth); if (nb01 >= nb00) { + char * wdata = params->wdata; + + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + //for (int i10 = 0; i10 < ne10; ++i10) { + // wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10)); + //} + quantize_row_q4_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]; + } + } + } + return; } @@ -4481,9 +5736,6 @@ static void ggml_compute_forward_mul_mat_f32( return; } - // TODO: fix this memset (wsize is overestimated) - //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth); - float * const wdata = params->wdata; // cols per thread @@ -4504,9 +5756,8 @@ static void ggml_compute_forward_mul_mat_f32( if (nb01 >= nb00) { // TODO: do not support transposed src1 - assert(nb10 == sizeof(float)); - // parallelize by src0 rows using ggml_vec_dot_f32 + // parallelize by src0 rows using ggml_vec_dot_q4_0 // total rows in src0 const int nr = ne01*ne02*ne03; @@ -4518,32 +5769,35 @@ static void ggml_compute_forward_mul_mat_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); + void * wdata = params->wdata; + for (int ir = ir0; ir < ir1; ++ir) { // src0 indices const int i03 = ir/(ne02*ne01); const int i02 = (ir - i03*ne02*ne01)/ne01; const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - for (int ic = 0; ic < ne11; ++ic) { - // src1 indices - const int i13 = i03; - const int i12 = i02; - const int i11 = ic; + const int i13 = i03; + const int i12 = i02; - // dst indices - const int i0 = i01; - const int i1 = i11; - const int i2 = i02; - const int i3 = i03; + const int i0 = i01; + const int i2 = i02; + const int i3 = i03; - ggml_vec_dot_f32(ne00, - (float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)), - (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13))); + void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]); + + float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); + + assert(ne00 % 32 == 0); + + for (int ic = 0; ic < ne11; ++ic) { + ggml_vec_dot_q4_0(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]))); } } } else { - // parallelize by src1 columns using ggml_vec_mad_f32 + //printf("AAAAA ith = %d, nth = %d\n", ith, nth); + // parallelize by src1 columns using ggml_vec_mad_q4_0 // each thread has its own work data // during FINALIZE we accumulate all work data into dst @@ -4564,6 +5818,13 @@ static void ggml_compute_forward_mul_mat_f32( for (int i13 = 0; i13 < ne13; ++i13) { for (int i12 = 0; i12 < ne12; ++i12) { for (int i11 = 0; i11 < ne11; ++i11) { + // dst indices + const int i1 = i11; + const int i2 = i12; + const int i3 = i13; + + float * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0; + for (int ic = ic0; ic < ic1; ++ic) { // src1 indices const int i10 = ic; @@ -4573,24 +5834,19 @@ static void ggml_compute_forward_mul_mat_f32( const int i02 = i12; const int i00 = ic; - // dst indices - const int i1 = i11; - const int i2 = i12; - const int i3 = i13; - assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); - ggml_vec_mad_f32(ne01, - (float *) (wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0), - (float *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)), - *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13))); + void * src0_col = (void *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)); + float src1_val = *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + + ggml_vec_mad_q4_0(ne01, dst_row, src0_col, src1_val); } } } } } - //int64_t t1 = ggml_perf_time_us(); + //int64_t t1 = ggml_time_us(); //static int64_t acc = 0; //acc += t1 - t0; //if (t1 - t0 > 10) { @@ -4598,13 +5854,12 @@ static void ggml_compute_forward_mul_mat_f32( // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); - // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); //} } -static void ggml_compute_forward_mul_mat_f16_f32( +static void ggml_compute_forward_mul_mat_q4_1_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -4652,7 +5907,7 @@ static void ggml_compute_forward_mul_mat_f16_f32( GGML_ASSERT(ne3 == ne13); // TODO: we don't support permuted src0 - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t) || nb01 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_1] || nb01 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_1]); // dst cannot be transposed or permuted GGML_ASSERT(nb0 == sizeof(float)); @@ -4694,9 +5949,11 @@ static void ggml_compute_forward_mul_mat_f16_f32( { int id = 0; for (int i01 = 0; i01 < ne01; ++i01) { - for (int i00 = 0; i00 < ne00; ++i00) { - wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00)); - } + //for (int i00 = 0; i00 < ne00; ++i00) { + // wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00)); + //} + dequantize_row_q4_1((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00); + id += ne00; } } @@ -4750,22 +6007,22 @@ static void ggml_compute_forward_mul_mat_f16_f32( #endif if (params->type == GGML_TASK_INIT) { + //printf("HHHHHHHHH ith = %d, nth = %d\n", ith, nth); if (nb01 >= nb00) { - ggml_fp16_t * const wdata = params->wdata; + char * wdata = params->wdata; - int id = 0; for (int i13 = 0; i13 < ne13; ++i13) { for (int i12 = 0; i12 < ne12; ++i12) { for (int i11 = 0; i11 < ne11; ++i11) { - for (int i10 = 0; i10 < ne10; ++i10) { - wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10)); - } + //for (int i10 = 0; i10 < ne10; ++i10) { + // wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10)); + //} + quantize_row_q4_1((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]; } } } - GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize); - return; } @@ -4779,10 +6036,7 @@ static void ggml_compute_forward_mul_mat_f16_f32( return; } - // TODO: fix this memset (wsize is overestimated) - //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth); - - ggml_fp16_t * const wdata = params->wdata; + float * const wdata = params->wdata; // cols per thread const int dc = (ne + nth - 1)/nth; @@ -4791,25 +6045,19 @@ static void ggml_compute_forward_mul_mat_f16_f32( const int ic0 = dc*ith; const int ic1 = MIN(ic0 + dc, ne); - for (int i = ic0; i < ic1; ++i) { - ((float *) dst->data)[i] = GGML_FP16_TO_FP32(wdata[i]); - } + ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0); for (int k = 1; k < nth; k++) { - for (int i = ic0; i < ic1; ++i) { - ((float *) dst->data)[i] += GGML_FP16_TO_FP32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]); - } + ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0); } return; } if (nb01 >= nb00) { - // fp16 -> half the size, so divide by 2 // TODO: do not support transposed src1 - assert(nb10/2 == sizeof(ggml_fp16_t)); - // parallelize by src0 rows using ggml_vec_dot_f16 + // parallelize by src0 rows using ggml_vec_dot_q4_1 // total rows in src0 const int nr = ne01*ne02*ne03; @@ -4821,7 +6069,7 @@ static void ggml_compute_forward_mul_mat_f16_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - ggml_fp16_t * wdata = params->wdata; + void * wdata = params->wdata; for (int ir = ir0; ir < ir1; ++ir) { // src0 indices @@ -4836,19 +6084,20 @@ static void ggml_compute_forward_mul_mat_f16_f32( const int i2 = i02; const int i3 = i03; - ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); - ggml_fp16_t * src1_col = wdata + ( 0 + i12*ne11 + i13*ne12*ne11)*ne00; + void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]); float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); assert(ne00 % 32 == 0); for (int ic = 0; ic < ne11; ++ic) { - ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00); + ggml_vec_dot_q4_1(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]))); } } } else { - // parallelize by src1 columns using ggml_vec_mad_f16 + //printf("AAAAA ith = %d, nth = %d\n", ith, nth); + // parallelize by src1 columns using ggml_vec_mad_q4_1 // each thread has its own work data // during FINALIZE we accumulate all work data into dst @@ -4864,7 +6113,7 @@ static void ggml_compute_forward_mul_mat_f16_f32( // work data for thread const int wo = (ne + CACHE_LINE_SIZE_F32)*ith; - ggml_fp16_t * const wdata = params->wdata; + float * const wdata = params->wdata; for (int i13 = 0; i13 < ne13; ++i13) { for (int i12 = 0; i12 < ne12; ++i12) { @@ -4874,7 +6123,7 @@ static void ggml_compute_forward_mul_mat_f16_f32( const int i2 = i12; const int i3 = i13; - ggml_fp16_t * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0; + float * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0; for (int ic = ic0; ic < ic1; ++ic) { // src1 indices @@ -4885,12 +6134,12 @@ static void ggml_compute_forward_mul_mat_f16_f32( const int i02 = i12; const int i00 = ic; - assert(sizeof(ggml_fp16_t)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); + assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); - ggml_fp16_t * src0_col = (ggml_fp16_t *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)); - float src1_val = * (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + void * src0_col = (void *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)); + float src1_val = *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - ggml_vec_mad_f16(ne01, dst_row, src0_col, src1_val); + ggml_vec_mad_q4_1(ne01, dst_row, src0_col, src1_val); } } } @@ -4916,6 +6165,14 @@ static void ggml_compute_forward_mul_mat( const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { + case GGML_TYPE_Q4_0: + { + ggml_compute_forward_mul_mat_q4_0_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_Q4_1: + { + ggml_compute_forward_mul_mat_q4_1_f32(params, src0, src1, dst); + } break; case GGML_TYPE_F16: { ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst); @@ -4929,9 +6186,37 @@ static void ggml_compute_forward_mul_mat( case GGML_TYPE_I32: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } + +#if 0 + if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) { + static int first = 8; + printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]); + printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]); + printf("dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]); + if (first) { + --first; + } else { + for (int k = 0; k < dst->ne[1]; ++k) { + for (int j = 0; j < dst->ne[0]/16; ++j) { + for (int i = 0; i < 16; ++i) { + printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); + } + printf("\n"); + } + printf("\n"); + } + printf("\n"); + exit(0); + } + } else { + printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]); + printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]); + printf("aaaa dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]); + } +#endif } // ggml_compute_forward_scale @@ -4981,13 +6266,15 @@ static void ggml_compute_forward_scale( { ggml_compute_forward_scale_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -5045,6 +6332,60 @@ static void ggml_compute_forward_transpose( // ggml_compute_forward_get_rows +static void ggml_compute_forward_get_rows_q4_0( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = ggml_nelements(src1); + + assert( dst->ne[0] == nc); + assert( dst->ne[1] == nr); + assert(src0->nb[0] == GGML_TYPE_SIZE[GGML_TYPE_Q4_0]); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + dequantize_row_q4_0( + (const void *) ((char *) src0->data + r*src0->nb[1]), + (float *) ((char *) dst->data + i*dst->nb[1]), nc); + } +} + +static void ggml_compute_forward_get_rows_q4_1( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = ggml_nelements(src1); + + assert( dst->ne[0] == nc); + assert( dst->ne[1] == nr); + assert(src0->nb[0] == GGML_TYPE_SIZE[GGML_TYPE_Q4_1]); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + dequantize_row_q4_1( + (const void *) ((char *) src0->data + r*src0->nb[1]), + (float *) ((char *) dst->data + i*dst->nb[1]), nc); + } +} + static void ggml_compute_forward_get_rows_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, @@ -5106,6 +6447,14 @@ static void ggml_compute_forward_get_rows( const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { + case GGML_TYPE_Q4_0: + { + ggml_compute_forward_get_rows_q4_0(params, src0, src1, dst); + } break; + case GGML_TYPE_Q4_1: + { + ggml_compute_forward_get_rows_q4_1(params, src0, src1, dst); + } break; case GGML_TYPE_F16: { ggml_compute_forward_get_rows_f16(params, src0, src1, dst); @@ -5119,9 +6468,27 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_I32: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } + + //static bool first = true; + //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); + //if (first) { + // first = false; + //} else { + // for (int k = 0; k < dst->ne[1]; ++k) { + // for (int j = 0; j < dst->ne[0]/16; ++j) { + // for (int i = 0; i < 16; ++i) { + // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); + // } + // printf("\n"); + // } + // printf("\n"); + // } + // printf("\n"); + // exit(0); + //} } // ggml_compute_forward_diag_mask_inf @@ -5172,13 +6539,15 @@ static void ggml_compute_forward_diag_mask_inf( { ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -5217,6 +6586,7 @@ static void ggml_compute_forward_soft_max_f32( #ifndef NDEBUG for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); assert(!isnan(p[i])); } #endif @@ -5263,13 +6633,15 @@ static void ggml_compute_forward_soft_max( { ggml_compute_forward_soft_max_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -5333,23 +6705,84 @@ static void ggml_compute_forward_rope_f32( } } +static void ggml_compute_forward_rope_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(src1->type == GGML_TYPE_I32); + assert(ggml_nelements(src1) == 3); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n_past = ((int32_t *) src1->data)[0]; + const int n_dims = ((int32_t *) src1->data)[1]; + const int mode = ((int32_t *) src1->data)[2]; + + //const int ne0 = src0->ne[0]; + const int ne1 = src0->ne[1]; + const int ne2 = src0->ne[2]; + const int ne3 = src0->ne[3]; + + const int nb0 = src0->nb[0]; + const int nb1 = src0->nb[1]; + const int nb2 = src0->nb[2]; + const int nb3 = src0->nb[3]; + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + assert(nb0 == sizeof(ggml_fp16_t)); + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) { + const int p = (mode == 0 ? n_past + i2 : i2); + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < n_dims; i0 += 2) { + const double theta = pow(10000.0, ((double)-i0)/n_dims); + + const double cos_theta = cos(p*theta); + const double sin_theta = sin(p*theta); + + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + double x0 = ggml_fp16_to_fp32(src[0]); + double x1 = ggml_fp16_to_fp32(src[1]); + + dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta); + dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta); + } + } + } + } +} + static void ggml_compute_forward_rope( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_rope_f16(params, src0, src1, dst); + } break; case GGML_TYPE_F32: { ggml_compute_forward_rope_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: - case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -5610,6 +7043,8 @@ static void ggml_compute_forward_conv_1d_1s( { ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5876,6 +7311,8 @@ static void ggml_compute_forward_conv_1d_2s( { ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6359,12 +7796,14 @@ static void ggml_compute_forward_flash_attn( { ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -6568,12 +8007,14 @@ static void ggml_compute_forward_flash_ff( { GGML_ASSERT(false); // TODO } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -6581,7 +8022,7 @@ static void ggml_compute_forward_flash_ff( ///////////////////////////////// static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { - assert(params); + GGML_ASSERT(params); switch (tensor->op) { case GGML_OP_DUP: @@ -6829,7 +8270,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_MEAN: { - assert(false); // TODO: implement + GGML_ASSERT(false); // TODO: implement } break; case GGML_OP_REPEAT: { @@ -6884,17 +8325,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_GELU: { - assert(false); // TODO: not implemented + GGML_ASSERT(false); // TODO: not implemented } break; case GGML_OP_NORM: { - assert(false); // TODO: not implemented + GGML_ASSERT(false); // TODO: not implemented } break; case GGML_OP_MUL_MAT: { if (src0->grad) { // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad); - assert(false); + GGML_ASSERT(false); } if (src1->grad) { src1->grad = @@ -7010,12 +8451,12 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * if (node->op == GGML_OP_NONE && node->grad == NULL) { // reached a leaf node, not part of the gradient graph (e.g. a constant) - assert(cgraph->n_leafs < GGML_MAX_NODES); + GGML_ASSERT(cgraph->n_leafs < GGML_MAX_NODES); cgraph->leafs[cgraph->n_leafs] = node; cgraph->n_leafs++; } else { - assert(cgraph->n_nodes < GGML_MAX_NODES); + GGML_ASSERT(cgraph->n_nodes < GGML_MAX_NODES); cgraph->nodes[cgraph->n_nodes] = node; cgraph->grads[cgraph->n_nodes] = node->grad; @@ -7039,7 +8480,7 @@ static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_ten if (n_new > 0) { // the last added node should always be starting point - assert(cgraph->nodes[cgraph->n_nodes - 1] == tensor); + GGML_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor); } } @@ -7070,7 +8511,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) { struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) { struct ggml_cgraph result = *gf; - assert(gf->n_nodes > 0); + GGML_ASSERT(gf->n_nodes > 0); // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph if (keep) { @@ -7269,7 +8710,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) }; int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]); - assert(rc == 0); + GGML_ASSERT(rc == 0); UNUSED(rc); } } @@ -7331,6 +8772,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) // TODO: better way to determine if the matrix is transposed if (node->src0->nb[1] < node->src0->nb[0]) { cur = ggml_nbytes(node)*node->n_tasks; // TODO: this can become (n_tasks-1) + // TODO: overestimated by factor of x2 for FP16 } else { if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) { @@ -7338,19 +8780,43 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { node->n_tasks = 1; // TODO: this actually is doing nothing // the threads are still spinning - cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]); + cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]); //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]); //printf("cur = %zu\n", cur); } else { - cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1); + cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1); } #else - cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1); + cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1); #endif } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) { cur = 0; + } else if (node->src0->type == GGML_TYPE_Q4_0 && + node->src1->type == GGML_TYPE_F32) { +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { + node->n_tasks = 1; + cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); + } else { + cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]; + } +#else + cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]; +#endif + } else if (node->src0->type == GGML_TYPE_Q4_1 && + node->src1->type == GGML_TYPE_F32) { +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { + node->n_tasks = 1; + cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); + } else { + cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]; + } +#else + cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]; +#endif } else { GGML_ASSERT(false); } @@ -7454,13 +8920,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } break; case GGML_OP_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } if (cgraph->work != NULL && work_size > cgraph->work_size) { - assert(false); // TODO: better handling + GGML_ASSERT(false); // TODO: better handling } if (work_size > 0 && cgraph->work == NULL) { @@ -7626,7 +9092,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) for (int j = 0; j < n_threads - 1; j++) { int rc = ggml_thread_join(workers[j].thrd, NULL); - assert(rc == 0); + GGML_ASSERT(rc == 0); UNUSED(rc); } @@ -7733,7 +9199,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph char color[16]; FILE * fp = fopen(filename, "w"); - assert(fp); + GGML_ASSERT(fp); fprintf(fp, "digraph G {\n"); fprintf(fp, " newrank = true;\n"); @@ -7891,7 +9357,7 @@ static enum ggml_opt_result ggml_opt_adam( struct ggml_tensor * f, struct ggml_cgraph * gf, struct ggml_cgraph * gb) { - assert(ggml_is_scalar(f)); + GGML_ASSERT(ggml_is_scalar(f)); gf->n_threads = params.n_threads; gb->n_threads = params.n_threads; @@ -7905,7 +9371,7 @@ static enum ggml_opt_result ggml_opt_adam( if (gf->nodes[i]->is_param) { GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); - assert(np < GGML_MAX_PARAMS); + GGML_ASSERT(np < GGML_MAX_PARAMS); ps[np++] = gf->nodes[i]; nx += ggml_nelements(gf->nodes[i]); @@ -8205,7 +9671,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( if (gf->nodes[i]->is_param) { GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); - assert(np < GGML_MAX_PARAMS); + GGML_ASSERT(np < GGML_MAX_PARAMS); ps[np++] = gf->nodes[i]; nx += ggml_nelements(gf->nodes[i]); diff --git a/ggml.h b/ggml.h index 18f317b..cdd8553 100644 --- a/ggml.h +++ b/ggml.h @@ -198,6 +198,8 @@ struct ggml_object; struct ggml_context; enum ggml_type { + GGML_TYPE_Q4_0, + GGML_TYPE_Q4_1, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -326,7 +328,10 @@ void ggml_print_objects(const struct ggml_context * ctx); int ggml_nelements(const struct ggml_tensor * tensor); size_t ggml_nbytes (const struct ggml_tensor * tensor); -size_t ggml_type_size (enum ggml_type type); +int ggml_blck_size (enum ggml_type type); +size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block +float ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float + size_t ggml_element_size(const struct ggml_tensor * tensor); struct ggml_context * ggml_init(struct ggml_init_params params); diff --git a/whisper.cpp b/whisper.cpp index 3a21581..027175e 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -252,12 +252,34 @@ static const std::map MEM_REQ_SCRATCH3 = { { MODEL_LARGE, 9ull*MB }, }; -static const std::map MEM_REQ_MODEL = { - { MODEL_TINY, 74ull*MB }, - { MODEL_BASE, 142ull*MB }, - { MODEL_SMALL, 466ull*MB }, - { MODEL_MEDIUM, 1464ull*MB }, - { MODEL_LARGE, 2952ull*MB }, +static const std::map> MEM_REQ_MODEL = { + { GGML_TYPE_F16, + { + { MODEL_TINY, 74ull*MB }, + { MODEL_BASE, 142ull*MB }, + { MODEL_SMALL, 466ull*MB }, + { MODEL_MEDIUM, 1464ull*MB }, + { MODEL_LARGE, 2952ull*MB }, + }, + }, + { GGML_TYPE_Q4_0, + { + { MODEL_TINY, 26ull*MB }, + { MODEL_BASE, 50ull*MB }, + { MODEL_SMALL, 154ull*MB }, + { MODEL_MEDIUM, 470ull*MB }, + { MODEL_LARGE, 940ull*MB }, + }, + }, + { GGML_TYPE_Q4_1, + { + { MODEL_TINY, 31ull*MB }, + { MODEL_BASE, 57ull*MB }, + { MODEL_SMALL, 181ull*MB }, + { MODEL_MEDIUM, 559ull*MB }, + { MODEL_LARGE, 1122ull*MB }, + }, + }, }; static const std::map MEM_REQ_KV_SELF = { @@ -681,7 +703,7 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache) { const ggml_type wtype = cache.k->type; WHISPER_ASSERT(wtype == cache.v->type); - WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype)); + WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_sizef(wtype)); struct ggml_init_params params; params.mem_size = cache.buf.size(); @@ -776,12 +798,25 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con model.type = e_model::MODEL_LARGE; } - // for the big tensors, we have the option to store the data in 16-bit floats + // for the big tensors, we have the option to store the data in 16-bit floats or quantized // in order to save memory and also to speed up the computation - wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; + wctx.wtype = GGML_TYPE_COUNT; + switch (model.hparams.f16) { + case 0: wctx.wtype = GGML_TYPE_F32; break; + case 1: wctx.wtype = GGML_TYPE_F16; break; + case 2: wctx.wtype = GGML_TYPE_Q4_0; break; + case 3: wctx.wtype = GGML_TYPE_Q4_1; break; + default: + { + fprintf(stderr, "%s: invalid model (bad f16 value %d)\n", __func__, model.hparams.f16); + return false; + } + } const size_t scale = model.hparams.f16 ? 1 : 2; + static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", }; + fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab); fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); @@ -792,7 +827,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head); fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer); fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels); - fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16); + fprintf(stderr, "%s: ftype = %s\n", __func__, ftype_str[model.hparams.f16]); fprintf(stderr, "%s: type = %d\n", __func__, model.type); // print memory requirements @@ -803,7 +838,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con MEM_REQ_SCRATCH1.at (model.type) + MEM_REQ_SCRATCH2.at (model.type) + MEM_REQ_SCRATCH3.at (model.type) + - scale*MEM_REQ_MODEL.at (model.type) + + scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) + scale*MEM_REQ_KV_CROSS.at(model.type) + scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)); @@ -819,9 +854,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // always have at least one decoder wctx.model.buf = new std::vector(); - wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type)); + wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type)); - if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) { + if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, GGML_TYPE_F16, model.hparams.n_text_ctx)) { fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); return false; } @@ -831,7 +866,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); } - if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) { + if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, GGML_TYPE_F16, model.hparams.n_audio_ctx)) { fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__); return false; } @@ -963,92 +998,92 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // encoder { - ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe; + ctx_size += n_audio_ctx*n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_pe; - ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w - ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b + ctx_size += 3*n_mels*n_audio_state*ggml_type_sizef(GGML_TYPE_F16); // e_conv_1_w + ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_1_b - ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w - ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b + ctx_size += 3*n_audio_state*n_audio_state*ggml_type_sizef(GGML_TYPE_F16); // e_conv_2_w + ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_2_b - ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w; - ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b; + ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_w; + ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_b; } // decoder { - ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe; + ctx_size += n_text_ctx*n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_pe; - ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te; + ctx_size += n_vocab*n_text_state*ggml_type_sizef(wtype); // d_te; - ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w; - ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b; + ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_w; + ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_b; } // encoder layers { - ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w - ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b + ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w + ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b - ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w - ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b + ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_0_w + ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b - ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b + ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_1_w + ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b - ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w - ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b + ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w + ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_q_w + ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_k_w - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_v_w + ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_ln_1_w + ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b } // decoder layers { - ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w - ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b + ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w + ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b - ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w - ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b + ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_0_w + ctx_size += n_text_layer*( 4*n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b - ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b + ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_1_w + ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b - ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w - ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b + ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w + ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_q_w + ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_k_w - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_v_w + ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_ln_1_w + ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b // - ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w - ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b + ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_w + ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_b - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_q_w + ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_q_b - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_k_w - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_v_w + ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_v_b - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_ln_1_w + ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_1_b } ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead @@ -1094,10 +1129,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con { model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx); - model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state); + model.e_conv_1_w = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 3, n_mels, n_audio_state); model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); - model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state); + model.e_conv_2_w = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 3, n_audio_state, n_audio_state); model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); @@ -1313,9 +1348,21 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con return false; } - const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t); + size_t bpe = 0; + + switch (ftype) { + case 0: bpe = ggml_type_size(GGML_TYPE_F32); break; + case 1: bpe = ggml_type_size(GGML_TYPE_F16); break; + case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break; + case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break; + default: + { + fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype); + return false; + } + }; - if (nelements*bpe != ggml_nbytes(tensor)) { + if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); return false; @@ -1513,14 +1560,14 @@ static bool whisper_encode( ggml_permute(ctx0, ggml_cpy(ctx0, Qcur, - ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), + ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); struct ggml_tensor * K = ggml_permute(ctx0, ggml_cpy(ctx0, Kcur, - ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), + ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); struct ggml_tensor * V = @@ -1530,7 +1577,7 @@ static bool whisper_encode( Vcur, n_state/n_head, n_head, n_ctx), 1, 2, 0, 3), - ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head) + ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head) ); struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false); @@ -1546,7 +1593,7 @@ static bool whisper_encode( ggml_permute(ctx0, ggml_cpy(ctx0, Kcur, - ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), + ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); // K * Q @@ -1564,7 +1611,7 @@ static bool whisper_encode( // ggml_permute(ctx0, // ggml_cpy(ctx0, // Vcur, - // ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), + // ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), // 1, 2, 0, 3); //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); @@ -1576,7 +1623,7 @@ static bool whisper_encode( Vcur, n_state/n_head, n_head, n_ctx), 0, 2, 1, 3), - ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head) + ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head) ); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max); @@ -1634,7 +1681,7 @@ static bool whisper_encode( wctx.use_buf(ctx0, 0); cur = ggml_flash_ff(ctx0, - ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)), + ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_state, n_ctx)), layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); #else wctx.use_buf(ctx0, 0);