diff --git a/src/ggml.c b/src/ggml.c index 1a138c4..2b59648 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -367,36 +367,76 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { uint8_t pp[QK/2]; +#if __ARM_NEON +#if QK == 32 for (int i = 0; i < nb; i++) { float amax = 0.0f; // absolute max - { - for (int l = 0; l < QK; l++) { - const float v = x[i*QK + l]; - amax = MAX(amax, fabsf(v)); - } + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; - const float d = amax / ((1 << 3) - 1); - const float id = d ? 1.0f/d : 0.0f; + for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l); + for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]); - pd[i] = d; + for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]); + for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]); + for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]); - for (int l = 0; l < QK; l += 2) { - const float v0 = x[i*QK + l + 0]*id; - const float v1 = x[i*QK + l + 1]*id; + amax = MAX( + MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)), + MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3))); - const uint8_t vi0 = ((int8_t) (round(v0))) + 8; - const uint8_t vi1 = ((int8_t) (round(v1))) + 8; + const float d = amax / ((1 << 3) - 1); + const float id = d ? 1.0/d : 0.0; - assert(vi0 >= 0 && vi0 < 16); - assert(vi1 >= 0 && vi1 < 16); + pd[i] = d; - pp[l/2] = vi0 | (vi1 << 4); - } + for (int l = 0; l < 8; l++) { + const float32x4_t v = vmulq_n_f32(srcv[l], id); + const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f)); + const int32x4_t vi = vcvtq_s32_f32(vf); - memcpy(pb + i*QK/2, pp, sizeof(pp)); + pp[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4); + pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4); + } + + memcpy(pb + i*16, pp, sizeof(pp)); + } +#else +#error "not implemented for QK" +#endif +#else + // scalar + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int l = 0; l < QK; l++) { + const float v = x[i*QK + l]; + amax = MAX(amax, fabsf(v)); } + + const float d = amax / ((1 << 3) - 1); + const float id = d ? 1.0f/d : 0.0f; + + pd[i] = d; + + for (int l = 0; l < QK; l += 2) { + const float v0 = x[i*QK + l + 0]*id; + const float v1 = x[i*QK + l + 1]*id; + + const uint8_t vi0 = ((int8_t) (round(v0))) + 8; + const uint8_t vi1 = ((int8_t) (round(v1))) + 8; + + assert(vi0 >= 0 && vi0 < 16); + assert(vi1 >= 0 && vi1 < 16); + + pp[l/2] = vi0 | (vi1 << 4); + } + + memcpy(pb + i*QK/2, pp, sizeof(pp)); } +#endif } // method 4 diff --git a/tests/test-mul-mat2.c b/tests/test-mul-mat2.c index f5ebfd5..be7b038 100644 --- a/tests/test-mul-mat2.c +++ b/tests/test-mul-mat2.c @@ -2051,7 +2051,6 @@ void quantize_6_row(const float * restrict src, void * restrict dst, int k) { asrcv[l] = _mm256_and_ps(srcv[l], (__m256) _mm256_set1_epi32(0x7fffffff)); } - for (int l = 0; l < QK8/2; l++) { amaxv[2*l] = _mm256_max_ps(asrcv[2*l], asrcv[2*l+1]); } @@ -2060,8 +2059,6 @@ void quantize_6_row(const float * restrict src, void * restrict dst, int k) { amaxv[4*l] = _mm256_max_ps(amaxv[4*l], amaxv[4*l+2]); } - //amax = MAX(amaxv[0][0], MAX(amaxv[0][1], MAX(amaxv[0][2], MAX(amaxv[0][3], MAX(amaxv[0][4], MAX(amaxv[0][5], MAX(amaxv[0][6], amaxv[0][7]))))))); - const __m256 amaxv0_0 = _mm256_permute2f128_ps(amaxv[0], amaxv[0], 3); const __m256 amaxv0_1 = _mm256_max_ps(amaxv[0], amaxv0_0); const __m256 amaxv0_2 = _mm256_permute_ps(amaxv0_1, 0x4e); @@ -2071,8 +2068,6 @@ void quantize_6_row(const float * restrict src, void * restrict dst, int k) { amax = _mm256_cvtss_f32(amaxv0_5); - //printf("amax = %f\n", amax); - const float d = amax / ((1 << (QB - 1)) - 1); const float id = d ? 1.0/d : 0.0; @@ -2082,10 +2077,6 @@ void quantize_6_row(const float * restrict src, void * restrict dst, int k) { for (int l = 0; l < 4; l++) { __m256 v = _mm256_mul_ps(srcv[l], idv); -#if 0 - v[0] += frand(); v[1] += frand(); v[2] += frand(); v[3] += frand(); - v[4] += frand(); v[5] += frand(); v[6] += frand(); v[7] += frand(); -#endif // convert to int8 __m256i vi = _mm256_cvtps_epi32(v); @@ -2107,9 +2098,6 @@ void quantize_6_row(const float * restrict src, void * restrict dst, int k) { pp[4*l + 2] = vi_4 | (vi_5 << 4); pp[4*l + 3] = vi_6 | (vi_7 << 4); - //printf("vi: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7); - ////printf("v : %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]); - assert(vi_0 >= 0 && vi_0 < 16); assert(vi_1 >= 0 && vi_1 < 16); assert(vi_2 >= 0 && vi_2 < 16); @@ -2123,9 +2111,38 @@ void quantize_6_row(const float * restrict src, void * restrict dst, int k) { memcpy(pb + i*16, pp, sizeof(pp)); } -#elif defined(__ARM_NEON) && 0 +#elif defined(__ARM_NEON) { - // TODO + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(src + i*32 + 4*l); + for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]); + + for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]); + for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]); + for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]); + + amax = MAX( + MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)), + MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3))); + + const float d = amax / ((1 << 3) - 1); + const float id = d ? 1.0/d : 0.0; + + pd[i] = GGML_FP32_TO_GQ(d); + + for (int l = 0; l < 8; l++) { + const float32x4_t v = vmulq_n_f32(srcv[l], id); + const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f)); + const int32x4_t vi = vcvtq_s32_f32(vf); + + pp[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4); + pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4); + } + + memcpy(pb + i*16, pp, sizeof(pp)); } #else { @@ -2149,7 +2166,7 @@ void quantize_6_row(const float * restrict src, void * restrict dst, int k) { memcpy(pb + i*16, pp, sizeof(pp)); } #endif - //printf("min %f max %f\n", min, max); + //printf("amax = %f\n", amax); } }