@@ -407,6 +407,45 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
 #else
 #error "not implemented for QK"
 #endif
+#elif defined(__wasm_simd128__)
+#if QK == 32
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        v128_t srcv [8];
+        v128_t asrcv[8];
+        v128_t amaxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l);
+        for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);
+
+        for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]);
+        for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]);
+        for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]);
+
+        amax = MAX(
+                MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)),
+                MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 3) - 1);
+        const float id = d ? 1.0/d : 0.0;
+
+        pd[i] = d;
+
+        for (int l = 0; l < 8; l++) {
+            const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
+            const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
+            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
+
+            pp[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
+            pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
+        }
+
+        memcpy(pb + i*16, pp, sizeof(pp));
+    }
+#else
+#error "not implemented for QK"
+#endif
 #else
 // scalar
     for (int i = 0; i < nb; i++) {
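// For reference, a minimal scalar sketch of the same q4_0 quantization step the
// WASM SIMD block above performs, assuming QK == 32 and the same layout as the
// code above (one float scale per block of 32 values, then 16 bytes holding two
// 4-bit values each). The function and variable names here are illustrative,
// not the ones used in ggml.c.
#include <math.h>
#include <stdint.h>

static void quantize_block_q4_0_ref(const float x[32], float *d_out, uint8_t qs[16]) {
    float amax = 0.0f; // absolute max over the block
    for (int l = 0; l < 32; l++) {
        const float v = fabsf(x[l]);
        if (v > amax) amax = v;
    }

    const float d  = amax / ((1 << 3) - 1); // map the block into [-7, 7]
    const float id = d ? 1.0f/d : 0.0f;

    *d_out = d;

    for (int l = 0; l < 16; l++) {
        // adding 8.5f and truncating mirrors the SIMD block above: round to
        // nearest, then shift into the unsigned nibble range [0, 15]
        const uint8_t v0 = (uint8_t)(x[2*l + 0]*id + 8.5f);
        const uint8_t v1 = (uint8_t)(x[2*l + 1]*id + 8.5f);
        qs[l] = v0 | (v1 << 4); // low nibble: even element, high nibble: odd element
    }
}
// Calling this once per 32-float block and storing d into pd[i] and the 16 bytes
// at pb + i*16 reproduces what the loop above writes out.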
@@ -1216,9 +1255,6 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
         const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
         const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
 
-        //printf("p_0: %d %d %d %d %d %d %d %d\n", vgetq_lane_s16(p_0, 0), vgetq_lane_s16(p_0, 1), vgetq_lane_s16(p_0, 2), vgetq_lane_s16(p_0, 3), vgetq_lane_s16(p_0, 4), vgetq_lane_s16(p_0, 5), vgetq_lane_s16(p_0, 6), vgetq_lane_s16(p_0, 7));
-        //printf("p_1: %d %d %d %d %d %d %d %d\n", vgetq_lane_s16(p_1, 0), vgetq_lane_s16(p_1, 1), vgetq_lane_s16(p_1, 2), vgetq_lane_s16(p_1, 3), vgetq_lane_s16(p_1, 4), vgetq_lane_s16(p_1, 5), vgetq_lane_s16(p_1, 6), vgetq_lane_s16(p_1, 7));
-
         // scalar
 #if defined(__ARM_FEATURE_QRDMX)
         sum0 += d0_0*d1_0*vaddvq_s16(p_0);
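// The vaddvq_s16() above is NEON's horizontal add across the eight int16 lanes.
// The WASM path added in the next hunk has no such single-instruction reduction
// available to it, so it extracts and sums the lanes explicitly. A scalar sketch
// of that reduction (hypothetical helper, not part of ggml.c):
#include <stdint.h>

static int32_t hsum_i16x8_ref(const int16_t p[8]) {
    int32_t sum = 0;
    for (int l = 0; l < 8; l++) {
        sum += p[l]; // same result as the eight wasm_i16x8_extract_lane() additions below
    }
    return sum;
}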
@@ -1230,35 +1266,93 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
     }
 
     sumf = sum0 + sum1;
-
-    //printf("sumf SIMD = %f\n", sumf);
-
-    //// scalar
-    //sumf = 0.0f;
-    //for (int i = 0; i < nb; i++) {
-    //    const float d0 = pd0[i];
-    //    const float d1 = pd1[i];
-    //
-    //    const uint8_t * restrict p0 = pb0 + i*QK/2;
-    //    const uint8_t * restrict p1 = pb1 + i*QK/2;
-    //
-    //    for (int j = 0; j < QK/2; j++) {
-    //        const uint8_t v0 = p0[j];
-    //        const uint8_t v1 = p1[j];
-    //
-    //        const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
-    //        const float f1 = d0*((int8_t) (v0 >> 4) - 8);
-    //
-    //        const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
-    //        const float f3 = d1*((int8_t) (v1 >> 4) - 8);
-    //
-    //        sumf += f0*f2 + f1*f3;
-    //    }
-    //}
-    //printf("sumf scalar = %f\n", sumf);
-    //printf("--------\n");
-    //exit(0);
+#else
+#error "not implemented for QK"
+#endif
+#elif defined(__wasm_simd128__)
+#if QK == 32
+    // wasm simd
+    float sum0 = 0.0f;
+    float sum1 = 0.0f;
+
+    for (int i = 0; i < nb; i += 2) {
+        const float d0_0 = pd0[i + 0];
+        const float d0_1 = pd0[i + 1];
+        const float d1_0 = pd1[i + 0];
+        const float d1_1 = pd1[i + 1];
+
+        const uint8_t * restrict p0 = pb0 + i*16;
+        const uint8_t * restrict p1 = pb1 + i*16;
+
+        const v128_t m4b = wasm_u8x16_splat(0xf);
+        const v128_t s8b = wasm_i8x16_splat(0x8);
+
+        const v128_t v0_0 = wasm_v128_load(p0);
+        const v128_t v0_1 = wasm_v128_load(p0 + 16);
+        const v128_t v1_0 = wasm_v128_load(p1);
+        const v128_t v1_1 = wasm_v128_load(p1 + 16);
+
+        // 4-bit -> 8-bit
+        const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
+        const v128_t v1_0l = wasm_v128_and(v1_0, m4b);
+
+        const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
+        const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);
+
+        const v128_t v0_1l = wasm_v128_and(v0_1, m4b);
+        const v128_t v1_1l = wasm_v128_and(v1_1, m4b);
+
+        const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
+        const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4);
+
+        // sub 8
+        const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
+        const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b);
+
+        const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
+        const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b);
+
+        const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
+        const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);
+
+        const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
+        const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b);
+
+        // dot product into int16x8_t
+        const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls));
+        const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));
+
+        const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs));
+        const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs));
+
+        const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls));
+        const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls));
+
+        const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs));
+        const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs));
+
+        const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h);
+        const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h);
+
+        const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h);
+        const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);
+
+        const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
+        const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
+
+        sum0 += d0_0*d1_0*(
+                wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
+                wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
+                wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
+                wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
+        sum1 += d0_1*d1_1*(
+                wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
+                wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
+                wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
+                wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
+    }
+
+    sumf = sum0 + sum1;
 #else
 #error "not implemented for QK"
 #endif
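// The commented-out block removed above doubled as a scalar reference for this
// dot product. A cleaned-up, self-contained version of it, assuming QK == 32 and
// the same layout as above (per-block float scale plus 16 packed bytes); the
// function name and parameter names are illustrative, not the ones in ggml.c.
#include <stdint.h>

static float vec_dot_q4_0_ref(int nb,
                              const float *pd0, const uint8_t *pb0,
                              const float *pd1, const uint8_t *pb1) {
    float sumf = 0.0f;

    for (int i = 0; i < nb; i++) {
        const float d0 = pd0[i];
        const float d1 = pd1[i];

        const uint8_t *p0 = pb0 + i*16;
        const uint8_t *p1 = pb1 + i*16;

        for (int j = 0; j < 16; j++) { // 16 == QK/2 packed bytes per block
            const uint8_t v0 = p0[j];
            const uint8_t v1 = p1[j];

            // unpack low and high nibbles and shift them back from [0, 15] to [-8, 7]
            const float f0 = d0*((int8_t)(v0 & 0xf) - 8);
            const float f1 = d0*((int8_t)(v0 >> 4) - 8);

            const float f2 = d1*((int8_t)(v1 & 0xf) - 8);
            const float f3 = d1*((int8_t)(v1 >> 4) - 8);

            sumf += f0*f2 + f1*f3;
        }
    }

    return sumf;
}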