diff --git a/ggml.c b/ggml.c index 31ae2b0..232ae28 100644 --- a/ggml.c +++ b/ggml.c @@ -316,192 +316,426 @@ int64_t ggml_cycles_per_ms(void) { static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); // -// fundamental operations +// simd mappings // -inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } -inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } -inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } -inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } -inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } -inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } -inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } -inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } -inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } +// we define a common set of C macros which map to specific intrinsics based on the current architecture +// we then implement the fundamental computation operations below using only these macros +// adding support for new architectures requires to define the corresponding SIMD macros +// +// GGML_F32_STEP / GGML_F16_STEP +// number of elements to process in a single step +// +// GGML_F32_EPR / GGML_F16_EPR +// number of elements to fit in a single register +// -inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { - ggml_float sumf = 0.0; #if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) - // NEON 128-bit - const int n16 = (n & ~15); - - float32x4_t sum0 = vdupq_n_f32(0); - float32x4_t sum1 = vdupq_n_f32(0); - float32x4_t sum2 = vdupq_n_f32(0); - float32x4_t sum3 = vdupq_n_f32(0); - float32x4_t x0, x1, x2, x3; - float32x4_t y0, y1, y2, y3; +#define GGML_SIMD - for (int i = 0; i < n16; i += 16) { - x0 = vld1q_f32(x + i + 0); - x1 = vld1q_f32(x + i + 4); - x2 = vld1q_f32(x + i + 8); - x3 = vld1q_f32(x + i + 12); +// F32 NEON - y0 = vld1q_f32(y + i + 0); - y1 = vld1q_f32(y + i + 4); - y2 = vld1q_f32(y + i + 8); - y3 = vld1q_f32(y + i + 12); - - sum0 = vfmaq_f32(sum0, x0, y0); - sum1 = vfmaq_f32(sum1, x1, y1); - sum2 = vfmaq_f32(sum2, x2, y2); - sum3 = vfmaq_f32(sum3, x3, y3); - } +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 - // reduce sum0..sum3 to sum0 - sum0 = vaddq_f32(sum0, sum1); - sum2 = vaddq_f32(sum2, sum3); - sum0 = vaddq_f32(sum0, sum2); - - sumf = vaddvq_f32(sum0); +#define GGML_F32x4 float32x4_t +#define GGML_F32x4_ZERO vdupq_n_f32(0.0f) +#define GGML_F32x4_SET1(x) vdupq_n_f32(x) +#define GGML_F32x4_LOAD vld1q_f32 +#define GGML_F32x4_STORE vst1q_f32 +#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c) +#define GGML_F32x4_ADD vaddq_f32 +#define GGML_F32x4_MUL vmulq_f32 +#if defined(__ARM_FEATURE_QRDMX) + #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) +#else + #define GGML_F32x4_REDUCE_ONE(x) \ + (vgetq_lane_f32(x, 0) + \ + vgetq_lane_f32(x, 1) + \ + vgetq_lane_f32(x, 2) + \ + vgetq_lane_f32(x, 3)) +#endif +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = vaddq_f32(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = vaddq_f32(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = vaddq_f32(x[8*i], x[8*i+4]); \ + } \ + res = GGML_F32x4_REDUCE_ONE(x[0]); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 NEON - // leftovers - for (int i = n16; i < n; ++i) { - sumf += x[i]*y[i]; - } -#elif defined(__AVX2__) - // AVX 256-bit - const int n32 = (n & ~31); +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + #define GGML_F16_STEP 32 + #define GGML_F16_EPR 8 + + #define GGML_F16x8 float16x8_t + #define GGML_F16x8_ZERO vdupq_n_f16(0.0f) + #define GGML_F16x8_SET1(x) vdupq_n_f16(x) + #define GGML_F16x8_LOAD vld1q_f16 + #define GGML_F16x8_STORE vst1q_f16 + #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) + #define GGML_F16x8_ADD vaddq_f16 + #define GGML_F16x8_MUL vmulq_f16 + #define GGML_F16x8_REDUCE(res, x) \ + { \ + for (int i = 0; i < GGML_F16_ARR/2; ++i) { \ + x[2*i] = vaddq_f16(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/4; ++i) { \ + x[4*i] = vaddq_f16(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/8; ++i) { \ + x[8*i] = vaddq_f16(x[8*i], x[8*i+4]); \ + } \ + const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \ + const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \ + res = vaddvq_f32(vaddq_f32(t0, t1)); \ + } + + #define GGML_F16_VEC GGML_F16x8 + #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO + #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 + #define GGML_F16_VEC_LOAD GGML_F16x8_LOAD + #define GGML_F16_VEC_STORE GGML_F16x8_STORE + #define GGML_F16_VEC_FMA GGML_F16x8_FMA + #define GGML_F16_VEC_ADD GGML_F16x8_ADD + #define GGML_F16_VEC_MUL GGML_F16x8_MUL + #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE +#else + // if FP16 vector arithmetic is not supported, we use FP32 instead + // and take advantage of the vcvt_ functions to convert to/from FP16 + + #define GGML_F16_STEP 16 + #define GGML_F16_EPR 4 + + #define GGML_F32Cx4 float32x4_t + #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f) + #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x) + #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x)) + #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) + #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) + #define GGML_F32Cx4_ADD vaddq_f32 + #define GGML_F32Cx4_MUL vmulq_f32 + #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE + + #define GGML_F16_VEC GGML_F32Cx4 + #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO + #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 + #define GGML_F16_VEC_LOAD GGML_F32Cx4_LOAD + #define GGML_F16_VEC_STORE GGML_F32Cx4_STORE + #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA + #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD + #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL + #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE +#endif - __m256 sum0 = _mm256_setzero_ps(); - __m256 sum1 = _mm256_setzero_ps(); - __m256 sum2 = _mm256_setzero_ps(); - __m256 sum3 = _mm256_setzero_ps(); +#elif defined(__AVX__) - __m256 x0, x1, x2, x3; - __m256 y0, y1, y2, y3; +#define GGML_SIMD - for (int i = 0; i < n32; i += 32) { - x0 = _mm256_loadu_ps(x + i + 0); - x1 = _mm256_loadu_ps(x + i + 8); - x2 = _mm256_loadu_ps(x + i + 16); - x3 = _mm256_loadu_ps(x + i + 24); +// F32 AVX - y0 = _mm256_loadu_ps(y + i + 0); - y1 = _mm256_loadu_ps(y + i + 8); - y2 = _mm256_loadu_ps(y + i + 16); - y3 = _mm256_loadu_ps(y + i + 24); +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 8 - sum0 = _mm256_fmadd_ps(x0, y0, sum0); - sum1 = _mm256_fmadd_ps(x1, y1, sum1); - sum2 = _mm256_fmadd_ps(x2, y2, sum2); - sum3 = _mm256_fmadd_ps(x3, y3, sum3); - } +#define GGML_F32x8 __m256 +#define GGML_F32x8_ZERO _mm256_setzero_ps() +#define GGML_F32x8_SET1(x) _mm256_set1_ps(x) +#define GGML_F32x8_LOAD _mm256_loadu_ps +#define GGML_F32x8_STORE _mm256_storeu_ps +#if defined(__FMA__) + #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a) +#else + #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a) +#endif +#define GGML_F32x8_ADD _mm256_add_ps +#define GGML_F32x8_MUL _mm256_mul_ps +#define GGML_F32x8_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = _mm256_add_ps(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = _mm256_add_ps(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = _mm256_add_ps(x[8*i], x[8*i+4]); \ + } \ + const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \ + _mm256_extractf128_ps(x[0], 1)); \ + const __m128 t1 = _mm_hadd_ps(t0, t0); \ + res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \ +} +// TODO: is this optimal ? + +#define GGML_F32_VEC GGML_F32x8 +#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x8_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD +#define GGML_F32_VEC_STORE GGML_F32x8_STORE +#define GGML_F32_VEC_FMA GGML_F32x8_FMA +#define GGML_F32_VEC_ADD GGML_F32x8_ADD +#define GGML_F32_VEC_MUL GGML_F32x8_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE + +// F16 AVX + +#define GGML_F16_STEP 32 +#define GGML_F16_EPR 8 + +// F16 arithmetic is not supported by AVX, so we use F32 instead +// we take advantage of the _mm256_cvt intrinsics to convert F16 <-> F32 + +#define GGML_F32Cx8 __m256 +#define GGML_F32Cx8_ZERO _mm256_setzero_ps() +#define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x) +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) +#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) +#define GGML_F32Cx8_FMA GGML_F32x8_FMA +#define GGML_F32Cx8_ADD _mm256_add_ps +#define GGML_F32Cx8_MUL _mm256_mul_ps +#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE + +#define GGML_F16_VEC GGML_F32Cx8 +#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO +#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1 +#define GGML_F16_VEC_LOAD GGML_F32Cx8_LOAD +#define GGML_F16_VEC_STORE GGML_F32Cx8_STORE +#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA +#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD +#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL +#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE - sum0 = _mm256_add_ps(sum0, sum1); - sum2 = _mm256_add_ps(sum2, sum3); - sum0 = _mm256_add_ps(sum0, sum2); +#elif defined(__POWER9_VECTOR__) - const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0), _mm256_extractf128_ps(sum0, 1)); - const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4)); - const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2)); +// TODO: uncomment this when it works +//#define GGML_SIMD + +// F32 POWER9 + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 8 + +// TODO: not tested !! +#define GGML_F32x4 __vector float +#define GGML_F32x4_ZERO (__vector float){0.0f, 0.0f, 0.0f, 0.0f} +#define GGML_F32x4_SET1(x) (__vector float){x, x, x, x} +#define GGML_F32x4_LOAD vec_vsx_ld +#define GGML_F32x4_STORE vec_vsx_st +#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a) +#define GGML_F32x4_ADD vec_add +#define GGML_F32x4_MUL vec_mul +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = vec_add(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = vec_add(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = vec_add(x[8*i], x[8*i+4]); \ + } \ + res = vec_extract(x[0], 0) + \ + vec_extract(x[0], 1) + \ + vec_extract(x[0], 2) + \ + vec_extract(x[0], 3); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 POWER9 +// TODO: implement here +// ... - sumf = _mm_cvtss_f32(r1); +#elif defined(__wasm_simd128__) - // leftovers - for (int i = n32; i < n; ++i) { - sumf += x[i]*y[i]; - } -#elif defined(__AVX__) - // AVX 256-bit - const int n32 = (n & ~31); +#define GGML_SIMD + +// F32 WASM + +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 v128_t +#define GGML_F32x4_ZERO wasm_f32x4_splat(0.0f) +#define GGML_F32x4_SET1(x) wasm_f32x4_splat(x) +#define GGML_F32x4_LOAD wasm_v128_load +#define GGML_F32x4_STORE wasm_v128_store +#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a) +#define GGML_F32x4_ADD wasm_f32x4_add +#define GGML_F32x4_MUL wasm_f32x4_mul +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 WASM + +#define GGML_F16_STEP 16 +#define GGML_F16_EPR 4 + +inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) { + float tmp[4]; + + tmp[0] = GGML_FP16_TO_FP32(p[0]); + tmp[1] = GGML_FP16_TO_FP32(p[1]); + tmp[2] = GGML_FP16_TO_FP32(p[2]); + tmp[3] = GGML_FP16_TO_FP32(p[3]); + + return wasm_v128_load(tmp); +} + +inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) { + float tmp[4]; + + wasm_v128_store(tmp, x); + + p[0] = GGML_FP32_TO_FP16(tmp[0]); + p[1] = GGML_FP32_TO_FP16(tmp[1]); + p[2] = GGML_FP32_TO_FP16(tmp[2]); + p[3] = GGML_FP32_TO_FP16(tmp[3]); +} + +#define GGML_F16x4 v128_t +#define GGML_F16x4_ZERO wasm_f32x4_splat(0.0f) +#define GGML_F16x4_SET1(x) wasm_f32x4_splat(x) +#define GGML_F16x4_LOAD(x) __wasm_f16x4_load(x) +#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y) +#define GGML_F16x4_FMA GGML_F32x4_FMA +#define GGML_F16x4_ADD wasm_f32x4_add +#define GGML_F16x4_MUL wasm_f32x4_mul +#define GGML_F16x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F16_ARR/2; ++i) { \ + x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/4; ++i) { \ + x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/8; ++i) { \ + x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define GGML_F16_VEC GGML_F16x4 +#define GGML_F16_VEC_ZERO GGML_F16x4_ZERO +#define GGML_F16_VEC_SET1 GGML_F16x4_SET1 +#define GGML_F16_VEC_LOAD GGML_F16x4_LOAD +#define GGML_F16_VEC_STORE GGML_F16x4_STORE +#define GGML_F16_VEC_FMA GGML_F16x4_FMA +#define GGML_F16_VEC_ADD GGML_F16x4_ADD +#define GGML_F16_VEC_MUL GGML_F16x4_MUL +#define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE - __m256 sum0 = _mm256_setzero_ps(); - __m256 sum1 = _mm256_setzero_ps(); - __m256 sum2 = _mm256_setzero_ps(); - __m256 sum3 = _mm256_setzero_ps(); +#endif - __m256 x0, x1, x2, x3; - __m256 y0, y1, y2, y3; +// GGML_F32_ARR / GGML_F16_ARR +// number of registers to use per step +#ifdef GGML_SIMD +#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR) +#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR) +#endif - for (int i = 0; i < n32; i += 32) { - x0 = _mm256_loadu_ps(x + i + 0); - x1 = _mm256_loadu_ps(x + i + 8); - x2 = _mm256_loadu_ps(x + i + 16); - x3 = _mm256_loadu_ps(x + i + 24); +// +// fundamental operations +// - y0 = _mm256_loadu_ps(y + i + 0); - y1 = _mm256_loadu_ps(y + i + 8); - y2 = _mm256_loadu_ps(y + i + 16); - y3 = _mm256_loadu_ps(y + i + 24); +inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0); - sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1); - sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2); - sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3); - } +inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - sum0 = _mm256_add_ps(sum0, sum1); - sum2 = _mm256_add_ps(sum2, sum3); - sum0 = _mm256_add_ps(sum0, sum2); +inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0), _mm256_extractf128_ps(sum0, 1)); - const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4)); - const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2)); +inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - sumf = _mm_cvtss_f32(r1); +inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } +inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } +inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } +inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } +inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } +inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } +inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } +inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } - // leftovers - for (int i = n32; i < n; ++i) { - sumf += x[i]*y[i]; - } -#elif defined(__wasm_simd128__) - // WASM 128-bit - const int n16 = (n & ~15); +inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { + ggml_float sumf = 0.0; - v128_t sum0 = wasm_f32x4_splat(0); - v128_t sum1 = wasm_f32x4_splat(0); - v128_t sum2 = wasm_f32x4_splat(0); - v128_t sum3 = wasm_f32x4_splat(0); +#ifdef GGML_SIMD + const int np = (n & ~(GGML_F32_STEP - 1)); - v128_t x0, x1, x2, x3; - v128_t y0, y1, y2, y3; + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; - for (int i = 0; i < n16; i += 16) { - x0 = wasm_v128_load(x + i + 0); - x1 = wasm_v128_load(x + i + 4); - x2 = wasm_v128_load(x + i + 8); - x3 = wasm_v128_load(x + i + 12); + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; - y0 = wasm_v128_load(y + i + 0); - y1 = wasm_v128_load(y + i + 4); - y2 = wasm_v128_load(y + i + 8); - y3 = wasm_v128_load(y + i + 12); + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - sum0 = wasm_f32x4_add(sum0, wasm_f32x4_mul(x0, y0)); - sum1 = wasm_f32x4_add(sum1, wasm_f32x4_mul(x1, y1)); - sum2 = wasm_f32x4_add(sum2, wasm_f32x4_mul(x2, y2)); - sum3 = wasm_f32x4_add(sum3, wasm_f32x4_mul(x3, y3)); + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + } } - sum0 = wasm_f32x4_add(sum0, sum1); - sum2 = wasm_f32x4_add(sum2, sum3); - sum0 = wasm_f32x4_add(sum0, sum2); - - sumf = wasm_f32x4_extract_lane(sum0, 0) + wasm_f32x4_extract_lane(sum0, 1) + wasm_f32x4_extract_lane(sum0, 2) + wasm_f32x4_extract_lane(sum0, 3); + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); // leftovers - for (int i = n16; i < n; ++i) { + for (int i = np; i < n; ++i) { sumf += x[i]*y[i]; } #else @@ -516,194 +750,34 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) { ggml_float sumf = 0.0; -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) - const int n32 = (n & ~31); -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - float16x8_t sum0 = vdupq_n_f16(0); - float16x8_t sum1 = vdupq_n_f16(0); - float16x8_t sum2 = vdupq_n_f16(0); - float16x8_t sum3 = vdupq_n_f16(0); +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); - float16x8_t x0, x1, x2, x3; - float16x8_t y0, y1, y2, y3; + GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; - for (int i = 0; i < n32; i += 32) { - x0 = vld1q_f16(x + i + 0 ); - x1 = vld1q_f16(x + i + 8 ); - x2 = vld1q_f16(x + i + 16); - x3 = vld1q_f16(x + i + 24); + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; - y0 = vld1q_f16(y + i + 0 ); - y1 = vld1q_f16(y + i + 8 ); - y2 = vld1q_f16(y + i + 16); - y3 = vld1q_f16(y + i + 24); + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR); - sum0 = vfmaq_f16(sum0, x0, y0); - sum1 = vfmaq_f16(sum1, x1, y1); - sum2 = vfmaq_f16(sum2, x2, y2); - sum3 = vfmaq_f16(sum3, x3, y3); + sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); + } } // reduce sum0..sum3 to sum0 - sum0 = vaddq_f16(sum0, sum1); - sum2 = vaddq_f16(sum2, sum3); - sum0 = vaddq_f16(sum0, sum2); - - // load sum0 into 2 float32x4_t - float32x4_t sum0f32 = vcvt_f32_f16(vget_low_f16(sum0)); - float32x4_t sum1f32 = vcvt_f32_f16(vget_high_f16(sum0)); - - // reduce sum0f32 and sum1f32 to sumf - sum0f32 = vaddq_f32(sum0f32, sum1f32); - sumf = vaddvq_f32(sum0f32); -#else - float32x4_t sum0 = vdupq_n_f32(0); - float32x4_t sum1 = vdupq_n_f32(0); - float32x4_t sum2 = vdupq_n_f32(0); - float32x4_t sum3 = vdupq_n_f32(0); - float32x4_t sum4 = vdupq_n_f32(0); - float32x4_t sum5 = vdupq_n_f32(0); - float32x4_t sum6 = vdupq_n_f32(0); - float32x4_t sum7 = vdupq_n_f32(0); - - float32x4_t x0, x1, x2, x3, x4, x5, x6, x7; - float32x4_t y0, y1, y2, y3, y4, y5, y6, y7; - - for (int i = 0; i < n32; i += 32) { - x0 = vcvt_f32_f16(vld1_f16(x + i + 0 )); - x1 = vcvt_f32_f16(vld1_f16(x + i + 4 )); - x2 = vcvt_f32_f16(vld1_f16(x + i + 8 )); - x3 = vcvt_f32_f16(vld1_f16(x + i + 12)); - x4 = vcvt_f32_f16(vld1_f16(x + i + 16)); - x5 = vcvt_f32_f16(vld1_f16(x + i + 20)); - x6 = vcvt_f32_f16(vld1_f16(x + i + 24)); - x7 = vcvt_f32_f16(vld1_f16(x + i + 28)); - - y0 = vcvt_f32_f16(vld1_f16(y + i + 0 )); - y1 = vcvt_f32_f16(vld1_f16(y + i + 4 )); - y2 = vcvt_f32_f16(vld1_f16(y + i + 8 )); - y3 = vcvt_f32_f16(vld1_f16(y + i + 12)); - y4 = vcvt_f32_f16(vld1_f16(y + i + 16)); - y5 = vcvt_f32_f16(vld1_f16(y + i + 20)); - y6 = vcvt_f32_f16(vld1_f16(y + i + 24)); - y7 = vcvt_f32_f16(vld1_f16(y + i + 28)); - - sum0 = vfmaq_f32(sum0, x0, y0); - sum1 = vfmaq_f32(sum1, x1, y1); - sum2 = vfmaq_f32(sum2, x2, y2); - sum3 = vfmaq_f32(sum3, x3, y3); - sum4 = vfmaq_f32(sum4, x4, y4); - sum5 = vfmaq_f32(sum5, x5, y5); - sum6 = vfmaq_f32(sum6, x6, y6); - sum7 = vfmaq_f32(sum7, x7, y7); - } - - // reduce sum0..sum7 to sum0 - sum0 = vaddq_f32(sum0, sum1); - sum2 = vaddq_f32(sum2, sum3); - sum4 = vaddq_f32(sum4, sum5); - sum6 = vaddq_f32(sum6, sum7); - sum0 = vaddq_f32(sum0, sum2); - sum4 = vaddq_f32(sum4, sum6); - sum0 = vaddq_f32(sum0, sum4); - - sumf = vaddvq_f32(sum0); -#endif + GGML_F16_VEC_REDUCE(sumf, sum); // leftovers - for (int i = n32; i < n; ++i) { - sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); - } -#elif defined(__AVX2__) - // AVX 256-bit - const int n32 = (n & ~31); - - __m256 sum0 = _mm256_setzero_ps(); - __m256 sum1 = _mm256_setzero_ps(); - __m256 sum2 = _mm256_setzero_ps(); - __m256 sum3 = _mm256_setzero_ps(); - - __m256 x0, x1, x2, x3; - __m256 y0, y1, y2, y3; - - for (int i = 0; i < n32; i += 32) { - x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 ))); - x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 ))); - x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16))); - x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24))); - - y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 ))); - y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 ))); - y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16))); - y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24))); - - sum0 = _mm256_fmadd_ps(x0, y0, sum0); - sum1 = _mm256_fmadd_ps(x1, y1, sum1); - sum2 = _mm256_fmadd_ps(x2, y2, sum2); - sum3 = _mm256_fmadd_ps(x3, y3, sum3); - } - - const __m256 sum01 = _mm256_add_ps(sum0, sum1); - const __m256 sum23 = _mm256_add_ps(sum2, sum3); - const __m256 sum0123 = _mm256_add_ps(sum01, sum23); - - const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0123), _mm256_extractf128_ps(sum0123, 1)); - const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4)); - const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2)); - - sumf = _mm_cvtss_f32(r1); - - // leftovers - for (int i = n32; i < n; ++i) { - //GGML_ASSERT(false); - sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); - } -#elif defined(__AVX__) - // AVX 256-bit - const int n32 = (n & ~31); - - __m256 sum0 = _mm256_setzero_ps(); - __m256 sum1 = _mm256_setzero_ps(); - __m256 sum2 = _mm256_setzero_ps(); - __m256 sum3 = _mm256_setzero_ps(); - - __m256 x0, x1, x2, x3; - __m256 y0, y1, y2, y3; - - for (int i = 0; i < n32; i += 32) { - x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 ))); - x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 ))); - x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16))); - x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24))); - - y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 ))); - y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 ))); - y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16))); - y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24))); - - sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0); - sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1); - sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2); - sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3); - } - - const __m256 sum01 = _mm256_add_ps(sum0, sum1); - const __m256 sum23 = _mm256_add_ps(sum2, sum3); - const __m256 sum0123 = _mm256_add_ps(sum01, sum23); - - const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0123), _mm256_extractf128_ps(sum0123, 1)); - const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4)); - const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2)); - - sumf = _mm_cvtss_f32(r1); - - // leftovers - for (int i = n32; i < n; ++i) { - //GGML_ASSERT(false); + for (int i = np; i < n; ++i) { sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); } #elif defined(__POWER9_VECTOR__) + // TODO: this is temporary because I cannot fit it in the GGML_SIMD pattern like all other architectures without + // being able to test it. hoping someone with access to a POWER9 machine can help out here. const int n32 = (n & ~31); vector float sum0 = vec_splats (0.0f); @@ -754,54 +828,6 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t for (int i = n32; i < n; ++i) { sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); } -#elif defined(__wasm_simd128__) - // WASM 128-bit - const int n16 = (n & ~15); - - v128_t sum0 = wasm_f32x4_splat(0.0f); - v128_t sum1 = wasm_f32x4_splat(0.0f); - v128_t sum2 = wasm_f32x4_splat(0.0f); - v128_t sum3 = wasm_f32x4_splat(0.0f); - - v128_t x0, x1, x2, x3; - v128_t y0, y1, y2, y3; - - float tx[16]; - float ty[16]; - - for (int i = 0; i < n16; i += 16) { - for (int k = 0; k < 16; ++k) { - tx[k] = GGML_FP16_TO_FP32(x[i + k]); - ty[k] = GGML_FP16_TO_FP32(y[i + k]); - } - - x0 = wasm_v128_load(tx + 0); - x1 = wasm_v128_load(tx + 4); - x2 = wasm_v128_load(tx + 8); - x3 = wasm_v128_load(tx + 12); - - y0 = wasm_v128_load(ty + 0); - y1 = wasm_v128_load(ty + 4); - y2 = wasm_v128_load(ty + 8); - y3 = wasm_v128_load(ty + 12); - - sum0 = wasm_f32x4_add(sum0, wasm_f32x4_mul(x0, y0)); - sum1 = wasm_f32x4_add(sum1, wasm_f32x4_mul(x1, y1)); - sum2 = wasm_f32x4_add(sum2, wasm_f32x4_mul(x2, y2)); - sum3 = wasm_f32x4_add(sum3, wasm_f32x4_mul(x3, y3)); - } - - sum0 = wasm_f32x4_add(sum0, sum1); - sum2 = wasm_f32x4_add(sum2, sum3); - sum0 = wasm_f32x4_add(sum0, sum2); - - sumf = wasm_f32x4_extract_lane(sum0, 0) + wasm_f32x4_extract_lane(sum0, 1) + wasm_f32x4_extract_lane(sum0, 2) + wasm_f32x4_extract_lane(sum0, 3); - - // leftovers - for (int i = n16; i < n; ++i) { - //GGML_ASSERT(false); - sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); - } #else for (int i = 0; i < n; ++i) { sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); @@ -812,144 +838,26 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t } inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) { -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) - // NEON 128-bit - const int n16 = (n & ~15); - - const float32x4_t v4 = vdupq_n_f32(v); - - float32x4_t x0, x1, x2, x3; - float32x4_t y0, y1, y2, y3; - - for (int i = 0; i < n16; i += 16) { - x0 = vld1q_f32(x + i + 0); - x1 = vld1q_f32(x + i + 4); - x2 = vld1q_f32(x + i + 8); - x3 = vld1q_f32(x + i + 12); - - y0 = vld1q_f32(y + i + 0); - y1 = vld1q_f32(y + i + 4); - y2 = vld1q_f32(y + i + 8); - y3 = vld1q_f32(y + i + 12); - - y0 = vfmaq_f32(y0, x0, v4); - y1 = vfmaq_f32(y1, x1, v4); - y2 = vfmaq_f32(y2, x2, v4); - y3 = vfmaq_f32(y3, x3, v4); - - vst1q_f32(y + i + 0, y0); - vst1q_f32(y + i + 4, y1); - vst1q_f32(y + i + 8, y2); - vst1q_f32(y + i + 12, y3); - } - - // leftovers - for (int i = n16; i < n; ++i) { - y[i] += x[i]*v; - } -#elif defined(__AVX2__) - // AVX 256-bit - const int n32 = (n & ~31); - - const __m256 v4 = _mm256_set1_ps(v); +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); - __m256 x0, x1, x2, x3; - __m256 y0, y1, y2, y3; - - for (int i = 0; i < n32; i += 32) { - x0 = _mm256_loadu_ps(x + i + 0); - x1 = _mm256_loadu_ps(x + i + 8); - x2 = _mm256_loadu_ps(x + i + 16); - x3 = _mm256_loadu_ps(x + i + 24); + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); - y0 = _mm256_loadu_ps(y + i + 0); - y1 = _mm256_loadu_ps(y + i + 8); - y2 = _mm256_loadu_ps(y + i + 16); - y3 = _mm256_loadu_ps(y + i + 24); + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; - y0 = _mm256_fmadd_ps(x0, v4, y0); - y1 = _mm256_fmadd_ps(x1, v4, y1); - y2 = _mm256_fmadd_ps(x2, v4, y2); - y3 = _mm256_fmadd_ps(x3, v4, y3); + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); - _mm256_storeu_ps(y + i + 0, y0); - _mm256_storeu_ps(y + i + 8, y1); - _mm256_storeu_ps(y + i + 16, y2); - _mm256_storeu_ps(y + i + 24, y3); - } - - // leftovers - for (int i = n32; i < n; ++i) { - y[i] += x[i]*v; - } -#elif defined(__AVX__) - // AVX 256-bit - const int n32 = (n & ~31); - - const __m256 v4 = _mm256_set1_ps(v); - - __m256 x0, x1, x2, x3; - __m256 y0, y1, y2, y3; - - for (int i = 0; i < n32; i += 32) { - x0 = _mm256_loadu_ps(x + i + 0); - x1 = _mm256_loadu_ps(x + i + 8); - x2 = _mm256_loadu_ps(x + i + 16); - x3 = _mm256_loadu_ps(x + i + 24); - - y0 = _mm256_loadu_ps(y + i + 0); - y1 = _mm256_loadu_ps(y + i + 8); - y2 = _mm256_loadu_ps(y + i + 16); - y3 = _mm256_loadu_ps(y + i + 24); - - y0 = _mm256_add_ps(_mm256_mul_ps(x0, v4), y0); - y1 = _mm256_add_ps(_mm256_mul_ps(x1, v4), y1); - y2 = _mm256_add_ps(_mm256_mul_ps(x2, v4), y2); - y3 = _mm256_add_ps(_mm256_mul_ps(x3, v4), y3); - - _mm256_storeu_ps(y + i + 0, y0); - _mm256_storeu_ps(y + i + 8, y1); - _mm256_storeu_ps(y + i + 16, y2); - _mm256_storeu_ps(y + i + 24, y3); - } - - // leftovers - for (int i = n32; i < n; ++i) { - y[i] += x[i]*v; - } -#elif defined(__wasm_simd128__) - // WASM SIMD 128-bit - const int n16 = (n & ~15); - - const v128_t v4 = wasm_f32x4_splat(v); - - v128_t x0, x1, x2, x3; - v128_t y0, y1, y2, y3; - - for (int i = 0; i < n16; i += 16) { - x0 = wasm_v128_load(x + i + 0); - x1 = wasm_v128_load(x + i + 4); - x2 = wasm_v128_load(x + i + 8); - x3 = wasm_v128_load(x + i + 12); - - y0 = wasm_v128_load(y + i + 0); - y1 = wasm_v128_load(y + i + 4); - y2 = wasm_v128_load(y + i + 8); - y3 = wasm_v128_load(y + i + 12); - - y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4)); - y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4)); - y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4)); - y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4)); - - wasm_v128_store(y + i + 0, y0); - wasm_v128_store(y + i + 4, y1); - wasm_v128_store(y + i + 8, y2); - wasm_v128_store(y + i + 12, y3); + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } } // leftovers - for (int i = n16; i < n; ++i) { + for (int i = np; i < n; ++i) { y[i] += x[i]*v; } #else @@ -961,263 +869,86 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float } inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_fp16_t * restrict x, const float v) { -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) - // NEON 128-bit - const int n32 = (n & ~31); - -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - const float16x8_t v8 = vdupq_n_f16(v); +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); - float16x8_t x0, x1, x2, x3; - float16x8_t y0, y1, y2, y3; + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); - for (int i = 0; i < n32; i += 32) { - y0 = vld1q_f16(y + i + 0 ); - y1 = vld1q_f16(y + i + 8 ); - y2 = vld1q_f16(y + i + 16); - y3 = vld1q_f16(y + i + 24); - - x0 = vld1q_f16(x + i + 0 ); - x1 = vld1q_f16(x + i + 8 ); - x2 = vld1q_f16(x + i + 16); - x3 = vld1q_f16(x + i + 24); - - y0 = vfmaq_f16(y0, x0, v8); - y1 = vfmaq_f16(y1, x1, v8); - y2 = vfmaq_f16(y2, x2, v8); - y3 = vfmaq_f16(y3, x3, v8); - - vst1q_f16(y + i + 0 , y0); - vst1q_f16(y + i + 8 , y1); - vst1q_f16(y + i + 16, y2); - vst1q_f16(y + i + 24, y3); - } -#else - const float32x4_t v40 = vdupq_n_f32(v); - const float32x4_t v41 = vdupq_n_f32(v); + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; - float32x4_t x0, x1, x2, x3, x4, x5, x6, x7; - float32x4_t y0, y1, y2, y3, y4, y5, y6, y7; + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); - for (int i = 0; i < n32; i += 32) { - y0 = vcvt_f32_f16(vld1_f16(y + i + 0 )); - y1 = vcvt_f32_f16(vld1_f16(y + i + 4 )); - y2 = vcvt_f32_f16(vld1_f16(y + i + 8 )); - y3 = vcvt_f32_f16(vld1_f16(y + i + 12)); - y4 = vcvt_f32_f16(vld1_f16(y + i + 16)); - y5 = vcvt_f32_f16(vld1_f16(y + i + 20)); - y6 = vcvt_f32_f16(vld1_f16(y + i + 24)); - y7 = vcvt_f32_f16(vld1_f16(y + i + 28)); - - x0 = vcvt_f32_f16(vld1_f16(x + i + 0 )); - x1 = vcvt_f32_f16(vld1_f16(x + i + 4 )); - x2 = vcvt_f32_f16(vld1_f16(x + i + 8 )); - x3 = vcvt_f32_f16(vld1_f16(x + i + 12)); - x4 = vcvt_f32_f16(vld1_f16(x + i + 16)); - x5 = vcvt_f32_f16(vld1_f16(x + i + 20)); - x6 = vcvt_f32_f16(vld1_f16(x + i + 24)); - x7 = vcvt_f32_f16(vld1_f16(x + i + 28)); - - y0 = vfmaq_f32(y0, x0, v40); - y1 = vfmaq_f32(y1, x1, v40); - y2 = vfmaq_f32(y2, x2, v40); - y3 = vfmaq_f32(y3, x3, v40); - y4 = vfmaq_f32(y4, x4, v41); - y5 = vfmaq_f32(y5, x5, v41); - y6 = vfmaq_f32(y6, x6, v41); - y7 = vfmaq_f32(y7, x7, v41); - - vst1_f16(y + i + 0 , vcvt_f16_f32(y0)); - vst1_f16(y + i + 4 , vcvt_f16_f32(y1)); - vst1_f16(y + i + 8 , vcvt_f16_f32(y2)); - vst1_f16(y + i + 12, vcvt_f16_f32(y3)); - vst1_f16(y + i + 16, vcvt_f16_f32(y4)); - vst1_f16(y + i + 20, vcvt_f16_f32(y5)); - vst1_f16(y + i + 24, vcvt_f16_f32(y6)); - vst1_f16(y + i + 28, vcvt_f16_f32(y7)); + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay[j]); + } } -#endif // leftovers - for (int i = n32; i < n; ++i) { + for (int i = np; i < n; ++i) { GGML_ASSERT(false); y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); } -#elif defined(__AVX2__) - // AVX 256-bit +#elif defined(__POWER9_VECTOR__) + // TODO: this is temporary because I cannot fit it in the GGML_SIMD pattern like all other architectures without + // being able to test it. hoping someone with access to a POWER9 machine can help out here. const int n32 = (n & ~31); - - const __m256 v8 = _mm256_set1_ps(v); - - __m256 x0, x1, x2, x3; - __m256 y0, y1, y2, y3; - for (int i = 0; i < n32; i += 32) { - y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 ))); - y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 ))); - y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16))); - y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24))); - - x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 ))); - x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 ))); - x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16))); - x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24))); - - y0 = _mm256_fmadd_ps(x0, v8, y0); - y1 = _mm256_fmadd_ps(x1, v8, y1); - y2 = _mm256_fmadd_ps(x2, v8, y2); - y3 = _mm256_fmadd_ps(x3, v8, y3); - - _mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0)); - _mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0)); - _mm_storeu_si128((__m128i*)(y + i + 16), _mm256_cvtps_ph(y2, 0)); - _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0)); - } + // Use vec_xl, not vec_ld, because x is sometimes unaligned! + vector unsigned short x0 = vec_xl(i * 2 + 0, x); + vector unsigned short x1 = vec_xl(i * 2 + 16, x); + vector unsigned short x2 = vec_xl(i * 2 + 32, x); + vector unsigned short x3 = vec_xl(i * 2 + 48, x); - // leftovers - for (int i = n32; i < n; ++i) { - GGML_ASSERT(false); - y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); - } -#elif defined(__AVX__) - // AVX 256-bit - const int n32 = (n & ~31); + vector unsigned short y0 = vec_xl(i * 2 + 0, y); + vector unsigned short y1 = vec_xl(i * 2 + 16, y); + vector unsigned short y2 = vec_xl(i * 2 + 32, y); + vector unsigned short y3 = vec_xl(i * 2 + 48, y); - const __m256 v8 = _mm256_set1_ps(v); + vector float v4 = vec_splats(v); - __m256 x0, x1, x2, x3; - __m256 y0, y1, y2, y3; + vector float fx0l = vec_extract_fp32_from_shortl(x0); + vector float fx0h = vec_extract_fp32_from_shorth(x0); + vector float fx1l = vec_extract_fp32_from_shortl(x1); + vector float fx1h = vec_extract_fp32_from_shorth(x1); + vector float fx2l = vec_extract_fp32_from_shortl(x2); + vector float fx2h = vec_extract_fp32_from_shorth(x2); + vector float fx3l = vec_extract_fp32_from_shortl(x3); + vector float fx3h = vec_extract_fp32_from_shorth(x3); - for (int i = 0; i < n32; i += 32) { - y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 ))); - y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 ))); - y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16))); - y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24))); + vector float fy0l = vec_extract_fp32_from_shortl(y0); + vector float fy0h = vec_extract_fp32_from_shorth(y0); + vector float fy1l = vec_extract_fp32_from_shortl(y1); + vector float fy1h = vec_extract_fp32_from_shorth(y1); + vector float fy2l = vec_extract_fp32_from_shortl(y2); + vector float fy2h = vec_extract_fp32_from_shorth(y2); + vector float fy3l = vec_extract_fp32_from_shortl(y3); + vector float fy3h = vec_extract_fp32_from_shorth(y3); - x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 ))); - x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 ))); - x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16))); - x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24))); + fy0l = vec_madd(fx0l, v4, fy0l); + fy0h = vec_madd(fx0h, v4, fy0h); + fy1l = vec_madd(fx1l, v4, fy1l); + fy1h = vec_madd(fx1h, v4, fy1h); + fy2l = vec_madd(fx2l, v4, fy2l); + fy2h = vec_madd(fx2h, v4, fy2h); + fy3l = vec_madd(fx3l, v4, fy3l); + fy3h = vec_madd(fx3h, v4, fy3h); - y0 = _mm256_add_ps(_mm256_mul_ps(x0, v8), y0); - y1 = _mm256_add_ps(_mm256_mul_ps(x1, v8), y1); - y2 = _mm256_add_ps(_mm256_mul_ps(x2, v8), y2); - y3 = _mm256_add_ps(_mm256_mul_ps(x3, v8), y3); + y0 = vec_pack_to_short_fp32(fy0h, fy0l); + y1 = vec_pack_to_short_fp32(fy1h, fy1l); + y2 = vec_pack_to_short_fp32(fy2h, fy2l); + y3 = vec_pack_to_short_fp32(fy3h, fy3l); - _mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0)); - _mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0)); - _mm_storeu_si128((__m128i*)(y + i + 16), _mm256_cvtps_ph(y2, 0)); - _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0)); + vec_xst(y0, i * 2 + 0, y); + vec_xst(y1, i * 2 + 16, y); + vec_xst(y2, i * 2 + 32, y); + vec_xst(y3, i * 2 + 48, y); } - // leftovers for (int i = n32; i < n; ++i) { - GGML_ASSERT(false); - y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); - } -#elif defined(__POWER9_VECTOR__) - const int n32 = (n & ~31); - for (int i = 0; i < n32; i += 32) { - // Use vec_xl, not vec_ld, because x is sometimes unaligned! - vector unsigned short x0 = vec_xl(i * 2 + 0, x); - vector unsigned short x1 = vec_xl(i * 2 + 16, x); - vector unsigned short x2 = vec_xl(i * 2 + 32, x); - vector unsigned short x3 = vec_xl(i * 2 + 48, x); - - vector unsigned short y0 = vec_xl(i * 2 + 0, y); - vector unsigned short y1 = vec_xl(i * 2 + 16, y); - vector unsigned short y2 = vec_xl(i * 2 + 32, y); - vector unsigned short y3 = vec_xl(i * 2 + 48, y); - - vector float v4 = vec_splats(v); - - vector float fx0l = vec_extract_fp32_from_shortl(x0); - vector float fx0h = vec_extract_fp32_from_shorth(x0); - vector float fx1l = vec_extract_fp32_from_shortl(x1); - vector float fx1h = vec_extract_fp32_from_shorth(x1); - vector float fx2l = vec_extract_fp32_from_shortl(x2); - vector float fx2h = vec_extract_fp32_from_shorth(x2); - vector float fx3l = vec_extract_fp32_from_shortl(x3); - vector float fx3h = vec_extract_fp32_from_shorth(x3); - - vector float fy0l = vec_extract_fp32_from_shortl(y0); - vector float fy0h = vec_extract_fp32_from_shorth(y0); - vector float fy1l = vec_extract_fp32_from_shortl(y1); - vector float fy1h = vec_extract_fp32_from_shorth(y1); - vector float fy2l = vec_extract_fp32_from_shortl(y2); - vector float fy2h = vec_extract_fp32_from_shorth(y2); - vector float fy3l = vec_extract_fp32_from_shortl(y3); - vector float fy3h = vec_extract_fp32_from_shorth(y3); - - fy0l = vec_madd(fx0l, v4, fy0l); - fy0h = vec_madd(fx0h, v4, fy0h); - fy1l = vec_madd(fx1l, v4, fy1l); - fy1h = vec_madd(fx1h, v4, fy1h); - fy2l = vec_madd(fx2l, v4, fy2l); - fy2h = vec_madd(fx2h, v4, fy2h); - fy3l = vec_madd(fx3l, v4, fy3l); - fy3h = vec_madd(fx3h, v4, fy3h); - - y0 = vec_pack_to_short_fp32(fy0h, fy0l); - y1 = vec_pack_to_short_fp32(fy1h, fy1l); - y2 = vec_pack_to_short_fp32(fy2h, fy2l); - y3 = vec_pack_to_short_fp32(fy3h, fy3l); - - vec_xst(y0, i * 2 + 0, y); - vec_xst(y1, i * 2 + 16, y); - vec_xst(y2, i * 2 + 32, y); - vec_xst(y3, i * 2 + 48, y); - } - - for (int i = n32; i < n; ++i) { - y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); - } -#elif defined(__wasm_simd128__) - // WASM SIMD 128-bit - const int n16 = (n & ~15); - - const v128_t v4 = wasm_f32x4_splat(v); - - v128_t x0, x1, x2, x3; - v128_t y0, y1, y2, y3; - - float tx[16]; - float ty[16]; - - for (int i = 0; i < n16; i += 16) { - for (int k = 0; k < 16; ++k) { - tx[k] = GGML_FP16_TO_FP32(x[i + k]); - ty[k] = GGML_FP16_TO_FP32(y[i + k]); - } - - x0 = wasm_v128_load(tx + 0); - x1 = wasm_v128_load(tx + 4); - x2 = wasm_v128_load(tx + 8); - x3 = wasm_v128_load(tx + 12); - - y0 = wasm_v128_load(ty + 0); - y1 = wasm_v128_load(ty + 4); - y2 = wasm_v128_load(ty + 8); - y3 = wasm_v128_load(ty + 12); - - y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4)); - y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4)); - y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4)); - y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4)); - - wasm_v128_store(ty + 0, y0); - wasm_v128_store(ty + 4, y1); - wasm_v128_store(ty + 8, y2); - wasm_v128_store(ty + 12, y3); - - for (int k = 0; k < 16; ++k) { - y[i + k] = GGML_FP32_TO_FP16(ty[k]); - } - } - - // leftovers - for (int i = n16; i < n; ++i) { - GGML_ASSERT(false); y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); } #else @@ -1229,33 +960,24 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { -#if defined(__AVX__) || defined(__AVX2__) - // AVX 256-bit - const int n32 = (n & ~31); +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); - const __m256 v4 = _mm256_set1_ps(v); + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); - __m256 y0, y1, y2, y3; - - for (int i = 0; i < n32; i += 32) { - y0 = _mm256_loadu_ps(y + i + 0); - y1 = _mm256_loadu_ps(y + i + 8); - y2 = _mm256_loadu_ps(y + i + 16); - y3 = _mm256_loadu_ps(y + i + 24); + GGML_F32_VEC ay[GGML_F32_ARR]; - y0 = _mm256_mul_ps(y0, v4); - y1 = _mm256_mul_ps(y1, v4); - y2 = _mm256_mul_ps(y2, v4); - y3 = _mm256_mul_ps(y3, v4); + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_MUL(ay[j], vx); - _mm256_storeu_ps(y + i + 0, y0); - _mm256_storeu_ps(y + i + 8, y1); - _mm256_storeu_ps(y + i + 16, y2); - _mm256_storeu_ps(y + i + 24, y3); + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } } // leftovers - for (int i = n32; i < n; ++i) { + for (int i = np; i < n; ++i) { y[i] *= v; } #else @@ -8533,6 +8255,14 @@ int ggml_cpu_has_avx512(void) { #endif } +int ggml_cpu_has_fma(void) { +#if defined(__FMA__) + return 1; +#else + return 0; +#endif +} + int ggml_cpu_has_neon(void) { #if defined(__ARM_NEON) return 1; diff --git a/ggml.h b/ggml.h index 4bb6118..a217d2d 100644 --- a/ggml.h +++ b/ggml.h @@ -724,6 +724,7 @@ enum ggml_opt_result ggml_opt( int ggml_cpu_has_avx(void); int ggml_cpu_has_avx2(void); int ggml_cpu_has_avx512(void); +int ggml_cpu_has_fma(void); int ggml_cpu_has_neon(void); int ggml_cpu_has_arm_fma(void); int ggml_cpu_has_f16c(void); diff --git a/whisper.cpp b/whisper.cpp index f43bb4a..0776076 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2555,6 +2555,7 @@ const char * whisper_print_system_info(void) { s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; + s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";