From 95f4fc70caf0bd492ce18dc48b76216c1b891bde Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 26 Oct 2022 18:48:54 +0300 Subject: [PATCH 1/4] Try to add AVX 512-bit support --- Makefile | 6 +- ggml.c | 175 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 180 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 1b9d4b6..d120f1f 100644 --- a/Makefile +++ b/Makefile @@ -42,7 +42,11 @@ endif # TODO: probably these flags need to be tweaked on some architectures # feel free to update the Makefile for your architecture and send a pull request or issue ifeq ($(UNAME_M),x86_64) - CFLAGS += -mavx -mavx2 -mfma -mf16c + # AVX 512 + CFLAGS += -mavx512f -mavx512dq -mfma -mf16c + + # AVX 256 + #CFLAGS += -mavx -mavx2 -mfma -mf16c endif ifeq ($(UNAME_M),amd64) CFLAGS += -mavx -mavx2 -mfma -mf16c diff --git a/ggml.c b/ggml.c index 3a36802..6f80820 100644 --- a/ggml.c +++ b/ggml.c @@ -320,6 +320,45 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float for (int i = n16; i < n; ++i) { sumf += x[i]*y[i]; } +#elif defined(__AVX512F__) + const int n64 = (n & ~63); + + __m512 sum0 = _mm512_setzero_ps(); + __m512 sum1 = _mm512_setzero_ps(); + __m512 sum2 = _mm512_setzero_ps(); + __m512 sum3 = _mm512_setzero_ps(); + + __m512 x0, x1, x2, x3; + __m512 y0, y1, y2, y3; + + for (int i = 0; i < n64; i += 64) { + x0 = _mm512_loadu_ps(x + i + 0); + x1 = _mm512_loadu_ps(x + i + 16); + x2 = _mm512_loadu_ps(x + i + 32); + x3 = _mm512_loadu_ps(x + i + 48); + + y0 = _mm512_loadu_ps(y + i + 0); + y1 = _mm512_loadu_ps(y + i + 16); + y2 = _mm512_loadu_ps(y + i + 32); + y3 = _mm512_loadu_ps(y + i + 48); + + sum0 = _mm512_fmadd_ps(x0, y0, sum0); + sum1 = _mm512_fmadd_ps(x1, y1, sum1); + sum2 = _mm512_fmadd_ps(x2, y2, sum2); + sum3 = _mm512_fmadd_ps(x3, y3, sum3); + } + + sum0 = _mm512_add_ps(sum0, sum1); + sum2 = _mm512_add_ps(sum2, sum3); + sum0 = _mm512_add_ps(sum0, sum2); + + sumf = sum0[0] + sum0[1] + sum0[2] + sum0[3] + sum0[4] + sum0[5] + sum0[6] + sum0[7] + + sum0[8] + sum0[9] + sum0[10] + sum0[11] + sum0[12] + sum0[13] + sum0[14] + sum0[15]; + + // leftovers + for (int i = n64; i < n; ++i) { + sumf += x[i]*y[i]; + } #elif defined(__AVX2__) // AVX 256-bit const int n32 = (n & ~31); @@ -517,6 +556,61 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t for (int i = n32; i < n; ++i) { sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); } +#elif defined(__AVX512F__) + // AVX 512-bit + const int n64 = (n & ~63); + + __m512 sum0 = _mm512_setzero_ps(); + __m512 sum1 = _mm512_setzero_ps(); + __m512 sum2 = _mm512_setzero_ps(); + __m512 sum3 = _mm512_setzero_ps(); + + __m512 x0, x1, x2, x3; + __m512 y0, y1, y2, y3; + + __m256 t0lo, t0hi, t1lo, t1hi, t2lo, t2hi, t3lo, t3hi; + + for (int i = 0; i < n64; i += 64) { + // TODO: is this the best way to do this? 
+ t0lo = _mm256_loadu_ps((const float*)(x + i + 0 )); + t0hi = _mm256_loadu_ps((const float*)(x + i + 8 )); + t1lo = _mm256_loadu_ps((const float*)(x + i + 16)); + t1hi = _mm256_loadu_ps((const float*)(x + i + 24)); + t2lo = _mm256_loadu_ps((const float*)(x + i + 32)); + t2hi = _mm256_loadu_ps((const float*)(x + i + 40)); + t3lo = _mm256_loadu_ps((const float*)(x + i + 48)); + t3hi = _mm256_loadu_ps((const float*)(x + i + 56)); + + x0 = _mm512_insertf32x8(_mm512_castps256_ps512(t0lo), t0hi, 1); + x1 = _mm512_insertf32x8(_mm512_castps256_ps512(t1lo), t1hi, 1); + x2 = _mm512_insertf32x8(_mm512_castps256_ps512(t2lo), t2hi, 1); + x3 = _mm512_insertf32x8(_mm512_castps256_ps512(t3lo), t3hi, 1); + + t0lo = _mm256_loadu_ps((const float*)(y + i + 0 )); + t0hi = _mm256_loadu_ps((const float*)(y + i + 8 )); + t1lo = _mm256_loadu_ps((const float*)(y + i + 16)); + t1hi = _mm256_loadu_ps((const float*)(y + i + 24)); + t2lo = _mm256_loadu_ps((const float*)(y + i + 32)); + t2hi = _mm256_loadu_ps((const float*)(y + i + 40)); + t3lo = _mm256_loadu_ps((const float*)(y + i + 48)); + t3hi = _mm256_loadu_ps((const float*)(y + i + 56)); + + y0 = _mm512_insertf32x8(_mm512_castps256_ps512(t0lo), t0hi, 1); + y1 = _mm512_insertf32x8(_mm512_castps256_ps512(t1lo), t1hi, 1); + y2 = _mm512_insertf32x8(_mm512_castps256_ps512(t2lo), t2hi, 1); + y3 = _mm512_insertf32x8(_mm512_castps256_ps512(t3lo), t3hi, 1); + + sum0 = _mm512_fmadd_ps(x0, y0, sum0); + sum1 = _mm512_fmadd_ps(x1, y1, sum1); + sum2 = _mm512_fmadd_ps(x2, y2, sum2); + sum3 = _mm512_fmadd_ps(x3, y3, sum3); + } + + const __m512 sum01 = _mm512_add_ps(sum0, sum1); + const __m512 sum23 = _mm512_add_ps(sum2, sum3); + const __m512 sum0123 = _mm512_add_ps(sum01, sum23); + + sumf = sum0123[0] + sum0123[1] + sum0123[2] + sum0123[3] + sum0123[4] + sum0123[5] + sum0123[6] + sum0123[7]; #elif defined(__AVX2__) // AVX 256-bit const int n32 = (n & ~31); @@ -654,6 +748,41 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float for (int i = n16; i < n; ++i) { y[i] += x[i]*v; } +#elif defined(__AVX512F__) + // AVX512 512-bit + const int n64 = (n & ~63); + + const __m512 v4 = _mm512_set1_ps(v); + + __m512 x0, x1, x2, x3; + __m512 y0, y1, y2, y3; + + for (int i = 0; i < n64; i += 64) { + x0 = _mm512_loadu_ps(x + i + 0); + x1 = _mm512_loadu_ps(x + i + 16); + x2 = _mm512_loadu_ps(x + i + 32); + x3 = _mm512_loadu_ps(x + i + 48); + + y0 = _mm512_loadu_ps(y + i + 0); + y1 = _mm512_loadu_ps(y + i + 16); + y2 = _mm512_loadu_ps(y + i + 32); + y3 = _mm512_loadu_ps(y + i + 48); + + y0 = _mm512_fmadd_ps(x0, v4, y0); + y1 = _mm512_fmadd_ps(x1, v4, y1); + y2 = _mm512_fmadd_ps(x2, v4, y2); + y3 = _mm512_fmadd_ps(x3, v4, y3); + + _mm512_storeu_ps(y + i + 0, y0); + _mm512_storeu_ps(y + i + 16, y1); + _mm512_storeu_ps(y + i + 32, y2); + _mm512_storeu_ps(y + i + 48, y3); + } + + // leftovers + for (int i = n64; i < n; ++i) { + y[i] += x[i]*v; + } #elif defined(__AVX2__) // AVX 256-bit const int n32 = (n & ~31); @@ -815,6 +944,52 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ GGML_ASSERT(false); y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v); } +#elif defined(__AVX512F__) + // AVX 512-bit + const int n64 = (n & ~63); + + const __m512 v16 = _mm512_set1_ps(v); + + __m512 x0, x1, x2, x3; + __m512 y0, y1, y2, y3; + + __m256 t0lo, t0hi, t1lo, t1hi, t2lo, t2hi, t3lo, t3hi; + + for (int i = 0; i < n64; i += 64) { + t0lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 ))); + t0hi = 
_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 ))); + t1lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16))); + t1hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24))); + t2lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 32))); + t2hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 40))); + t3lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 48))); + t3hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 56))); + + y0 = _mm512_insertf32x8(_mm512_castps256_ps512(t0lo), t0hi, 1); + y1 = _mm512_insertf32x8(_mm512_castps256_ps512(t1lo), t1hi, 1); + y2 = _mm512_insertf32x8(_mm512_castps256_ps512(t2lo), t2hi, 1); + y3 = _mm512_insertf32x8(_mm512_castps256_ps512(t3lo), t3hi, 1); + + t0lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 ))); + t0hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 ))); + t1lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16))); + t1hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24))); + t2lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 32))); + t2hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 40))); + t3lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 48))); + t3hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 56))); + + y0 = _mm512_fmadd_ps(x0, v16, y0); + y1 = _mm512_fmadd_ps(x1, v16, y1); + y2 = _mm512_fmadd_ps(x2, v16, y2); + y3 = _mm512_fmadd_ps(x3, v16, y3); + } + + // leftovers + for (int i = n64; i < n; ++i) { + GGML_ASSERT(false); + y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v); + } #elif defined(__AVX2__) // AVX 256-bit const int n32 = (n & ~31); From 7fc52fa7ef95cfb1e794f64602d6061dd3c0a5cb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 27 Oct 2022 17:45:38 +0300 Subject: [PATCH 2/4] Another shot at AVX-512 support --- Makefile | 2 +- ggml.c | 77 +++++++++++++++++--------------------------------------- 2 files changed, 24 insertions(+), 55 deletions(-) diff --git a/Makefile b/Makefile index d120f1f..40a0187 100644 --- a/Makefile +++ b/Makefile @@ -43,7 +43,7 @@ endif # feel free to update the Makefile for your architecture and send a pull request or issue ifeq ($(UNAME_M),x86_64) # AVX 512 - CFLAGS += -mavx512f -mavx512dq -mfma -mf16c + CFLAGS += -mavx512f -mfma -mf16c # AVX 256 #CFLAGS += -mavx -mavx2 -mfma -mf16c diff --git a/ggml.c b/ggml.c index 6f80820..7d67078 100644 --- a/ggml.c +++ b/ggml.c @@ -568,37 +568,16 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t __m512 x0, x1, x2, x3; __m512 y0, y1, y2, y3; - __m256 t0lo, t0hi, t1lo, t1hi, t2lo, t2hi, t3lo, t3hi; - for (int i = 0; i < n64; i += 64) { - // TODO: is this the best way to do this? 
- t0lo = _mm256_loadu_ps((const float*)(x + i + 0 )); - t0hi = _mm256_loadu_ps((const float*)(x + i + 8 )); - t1lo = _mm256_loadu_ps((const float*)(x + i + 16)); - t1hi = _mm256_loadu_ps((const float*)(x + i + 24)); - t2lo = _mm256_loadu_ps((const float*)(x + i + 32)); - t2hi = _mm256_loadu_ps((const float*)(x + i + 40)); - t3lo = _mm256_loadu_ps((const float*)(x + i + 48)); - t3hi = _mm256_loadu_ps((const float*)(x + i + 56)); - - x0 = _mm512_insertf32x8(_mm512_castps256_ps512(t0lo), t0hi, 1); - x1 = _mm512_insertf32x8(_mm512_castps256_ps512(t1lo), t1hi, 1); - x2 = _mm512_insertf32x8(_mm512_castps256_ps512(t2lo), t2hi, 1); - x3 = _mm512_insertf32x8(_mm512_castps256_ps512(t3lo), t3hi, 1); - - t0lo = _mm256_loadu_ps((const float*)(y + i + 0 )); - t0hi = _mm256_loadu_ps((const float*)(y + i + 8 )); - t1lo = _mm256_loadu_ps((const float*)(y + i + 16)); - t1hi = _mm256_loadu_ps((const float*)(y + i + 24)); - t2lo = _mm256_loadu_ps((const float*)(y + i + 32)); - t2hi = _mm256_loadu_ps((const float*)(y + i + 40)); - t3lo = _mm256_loadu_ps((const float*)(y + i + 48)); - t3hi = _mm256_loadu_ps((const float*)(y + i + 56)); - - y0 = _mm512_insertf32x8(_mm512_castps256_ps512(t0lo), t0hi, 1); - y1 = _mm512_insertf32x8(_mm512_castps256_ps512(t1lo), t1hi, 1); - y2 = _mm512_insertf32x8(_mm512_castps256_ps512(t2lo), t2hi, 1); - y3 = _mm512_insertf32x8(_mm512_castps256_ps512(t3lo), t3hi, 1); + x0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 0 ))); + x1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 16))); + x2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 32))); + x3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 48))); + + y0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 0 ))); + y1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 16))); + y2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 32))); + y3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 48))); sum0 = _mm512_fmadd_ps(x0, y0, sum0); sum1 = _mm512_fmadd_ps(x1, y1, sum1); @@ -953,36 +932,26 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ __m512 x0, x1, x2, x3; __m512 y0, y1, y2, y3; - __m256 t0lo, t0hi, t1lo, t1hi, t2lo, t2hi, t3lo, t3hi; - for (int i = 0; i < n64; i += 64) { - t0lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 ))); - t0hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 ))); - t1lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16))); - t1hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24))); - t2lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 32))); - t2hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 40))); - t3lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 48))); - t3hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 56))); - - y0 = _mm512_insertf32x8(_mm512_castps256_ps512(t0lo), t0hi, 1); - y1 = _mm512_insertf32x8(_mm512_castps256_ps512(t1lo), t1hi, 1); - y2 = _mm512_insertf32x8(_mm512_castps256_ps512(t2lo), t2hi, 1); - y3 = _mm512_insertf32x8(_mm512_castps256_ps512(t3lo), t3hi, 1); - - t0lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 ))); - t0hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 ))); - t1lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16))); - t1hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24))); - t2lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 32))); - t2hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 40))); - t3lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 48))); - t3hi = 
_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 56))); + x0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 0 ))); + x1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 16))); + x2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 32))); + x3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 48))); + + y0 = _mm512_loadu_ps(y + i + 0 ); + y1 = _mm512_loadu_ps(y + i + 16); + y2 = _mm512_loadu_ps(y + i + 32); + y3 = _mm512_loadu_ps(y + i + 48); y0 = _mm512_fmadd_ps(x0, v16, y0); y1 = _mm512_fmadd_ps(x1, v16, y1); y2 = _mm512_fmadd_ps(x2, v16, y2); y3 = _mm512_fmadd_ps(x3, v16, y3); + + _mm256_storeu_si256((__m256i*)(y + i + 0 ), _mm512_cvtps_ph(y0, 0)); + _mm256_storeu_si256((__m256i*)(y + i + 16), _mm512_cvtps_ph(y1, 0)); + _mm256_storeu_si256((__m256i*)(y + i + 32), _mm512_cvtps_ph(y2, 0)); + _mm256_storeu_si256((__m256i*)(y + i + 48), _mm512_cvtps_ph(y3, 0)); } // leftovers From c4350356deafcc748b64f6aece9d4de2cf223de5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 5 Nov 2022 22:56:56 +0200 Subject: [PATCH 3/4] Update ggml.c --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 7d67078..056f85d 100644 --- a/ggml.c +++ b/ggml.c @@ -589,7 +589,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t const __m512 sum23 = _mm512_add_ps(sum2, sum3); const __m512 sum0123 = _mm512_add_ps(sum01, sum23); - sumf = sum0123[0] + sum0123[1] + sum0123[2] + sum0123[3] + sum0123[4] + sum0123[5] + sum0123[6] + sum0123[7]; + sumf = sum0123[0] + sum0123[1] + sum0123[2] + sum0123[3] + sum0123[4] + sum0123[5] + sum0123[6] + sum0123[7] + sum0123[8] + sum0123[9] + sum0123[10] + sum0123[11] + sum0123[12] + sum0123[13] + sum0123[14] + sum0123[15]; #elif defined(__AVX2__) // AVX 256-bit const int n32 = (n & ~31); From b71d45beffa40f5386d826215009f19f728b05b9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 6 Nov 2022 08:50:57 +0200 Subject: [PATCH 4/4] ggml : fix AVX 512-bit kernels --- ggml.c | 128 ++++++++++++++++++++++++++++++--------------------------- 1 file changed, 67 insertions(+), 61 deletions(-) diff --git a/ggml.c b/ggml.c index 056f85d..154a76d 100644 --- a/ggml.c +++ b/ggml.c @@ -589,7 +589,14 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t const __m512 sum23 = _mm512_add_ps(sum2, sum3); const __m512 sum0123 = _mm512_add_ps(sum01, sum23); - sumf = sum0123[0] + sum0123[1] + sum0123[2] + sum0123[3] + sum0123[4] + sum0123[5] + sum0123[6] + sum0123[7] + sum0123[8] + sum0123[9] + sum0123[10] + sum0123[11] + sum0123[12] + sum0123[13] + sum0123[14] + sum0123[15]; + sumf = sum0123[0] + sum0123[1] + sum0123[2] + sum0123[3] + sum0123[4] + sum0123[5] + sum0123[6] + sum0123[7] + + sum0123[8] + sum0123[9] + sum0123[10] + sum0123[11] + sum0123[12] + sum0123[13] + sum0123[14] + sum0123[15]; + + // leftovers + for (int i = n64; i < n; ++i) { + //GGML_ASSERT(false); + sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); + } #elif defined(__AVX2__) // AVX 256-bit const int n32 = (n & ~31); @@ -696,7 +703,7 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float // NEON 128-bit const int n16 = (n & ~15); - const float32x4_t v4 = vdupq_n_f32(v); + const float32x4_t v0 = vdupq_n_f32(v); float32x4_t x0, x1, x2, x3; float32x4_t y0, y1, y2, y3; @@ -712,14 +719,14 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float y2 = vld1q_f32(y + i + 8); y3 = vld1q_f32(y + i + 12); - y0 = vfmaq_f32(y0, x0, v4); - y1 = 
vfmaq_f32(y1, x1, v4); - y2 = vfmaq_f32(y2, x2, v4); - y3 = vfmaq_f32(y3, x3, v4); + y0 = vfmaq_f32(y0, x0, v0); + y1 = vfmaq_f32(y1, x1, v0); + y2 = vfmaq_f32(y2, x2, v0); + y3 = vfmaq_f32(y3, x3, v0); - vst1q_f32(y + i + 0, y0); - vst1q_f32(y + i + 4, y1); - vst1q_f32(y + i + 8, y2); + vst1q_f32(y + i + 0, y0); + vst1q_f32(y + i + 4, y1); + vst1q_f32(y + i + 8, y2); vst1q_f32(y + i + 12, y3); } @@ -731,7 +738,7 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float // AVX512 512-bit const int n64 = (n & ~63); - const __m512 v4 = _mm512_set1_ps(v); + const __m512 v0 = _mm512_set1_ps(v); __m512 x0, x1, x2, x3; __m512 y0, y1, y2, y3; @@ -747,12 +754,12 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float y2 = _mm512_loadu_ps(y + i + 32); y3 = _mm512_loadu_ps(y + i + 48); - y0 = _mm512_fmadd_ps(x0, v4, y0); - y1 = _mm512_fmadd_ps(x1, v4, y1); - y2 = _mm512_fmadd_ps(x2, v4, y2); - y3 = _mm512_fmadd_ps(x3, v4, y3); + y0 = _mm512_fmadd_ps(x0, v0, y0); + y1 = _mm512_fmadd_ps(x1, v0, y1); + y2 = _mm512_fmadd_ps(x2, v0, y2); + y3 = _mm512_fmadd_ps(x3, v0, y3); - _mm512_storeu_ps(y + i + 0, y0); + _mm512_storeu_ps(y + i + 0, y0); _mm512_storeu_ps(y + i + 16, y1); _mm512_storeu_ps(y + i + 32, y2); _mm512_storeu_ps(y + i + 48, y3); @@ -766,7 +773,7 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float // AVX 256-bit const int n32 = (n & ~31); - const __m256 v4 = _mm256_set1_ps(v); + const __m256 v0 = _mm256_set1_ps(v); __m256 x0, x1, x2, x3; __m256 y0, y1, y2, y3; @@ -782,13 +789,13 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float y2 = _mm256_loadu_ps(y + i + 16); y3 = _mm256_loadu_ps(y + i + 24); - y0 = _mm256_fmadd_ps(x0, v4, y0); - y1 = _mm256_fmadd_ps(x1, v4, y1); - y2 = _mm256_fmadd_ps(x2, v4, y2); - y3 = _mm256_fmadd_ps(x3, v4, y3); + y0 = _mm256_fmadd_ps(x0, v0, y0); + y1 = _mm256_fmadd_ps(x1, v0, y1); + y2 = _mm256_fmadd_ps(x2, v0, y2); + y3 = _mm256_fmadd_ps(x3, v0, y3); - _mm256_storeu_ps(y + i + 0, y0); - _mm256_storeu_ps(y + i + 8, y1); + _mm256_storeu_ps(y + i + 0, y0); + _mm256_storeu_ps(y + i + 8, y1); _mm256_storeu_ps(y + i + 16, y2); _mm256_storeu_ps(y + i + 24, y3); } @@ -801,7 +808,7 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float // WASM SIMD 128-bit const int n16 = (n & ~15); - const v128_t v4 = wasm_f32x4_splat(v); + const v128_t v0 = wasm_f32x4_splat(v); v128_t x0, x1, x2, x3; v128_t y0, y1, y2, y3; @@ -817,10 +824,10 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float y2 = wasm_v128_load(y + i + 8); y3 = wasm_v128_load(y + i + 12); - y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4)); - y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4)); - y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4)); - y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4)); + y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v0)); + y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v0)); + y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v0)); + y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v0)); wasm_v128_store(y + i + 0, y0); wasm_v128_store(y + i + 4, y1); @@ -846,7 +853,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ const int n32 = (n & ~31); #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - const float16x8_t v8 = vdupq_n_f16(v); + const float16x8_t v0 = vdupq_n_f16(v); float16x8_t x0, x1, x2, x3; float16x8_t y0, y1, y2, y3; @@ -862,10 +869,10 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * 
restrict y, ggml_ x2 = vld1q_f16(x + i + 16); x3 = vld1q_f16(x + i + 24); - y0 = vfmaq_f16(y0, x0, v8); - y1 = vfmaq_f16(y1, x1, v8); - y2 = vfmaq_f16(y2, x2, v8); - y3 = vfmaq_f16(y3, x3, v8); + y0 = vfmaq_f16(y0, x0, v0); + y1 = vfmaq_f16(y1, x1, v0); + y2 = vfmaq_f16(y2, x2, v0); + y3 = vfmaq_f16(y3, x3, v0); vst1q_f16(y + i + 0 , y0); vst1q_f16(y + i + 8 , y1); @@ -873,8 +880,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ vst1q_f16(y + i + 24, y3); } #else - const float32x4_t v40 = vdupq_n_f32(v); - const float32x4_t v41 = vdupq_n_f32(v); + const float32x4_t v0 = vdupq_n_f32(v); float32x4_t x0, x1, x2, x3, x4, x5, x6, x7; float32x4_t y0, y1, y2, y3, y4, y5, y6, y7; @@ -898,14 +904,14 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ x6 = vcvt_f32_f16(vld1_f16(x + i + 24)); x7 = vcvt_f32_f16(vld1_f16(x + i + 28)); - y0 = vfmaq_f32(y0, x0, v40); - y1 = vfmaq_f32(y1, x1, v40); - y2 = vfmaq_f32(y2, x2, v40); - y3 = vfmaq_f32(y3, x3, v40); - y4 = vfmaq_f32(y4, x4, v41); - y5 = vfmaq_f32(y5, x5, v41); - y6 = vfmaq_f32(y6, x6, v41); - y7 = vfmaq_f32(y7, x7, v41); + y0 = vfmaq_f32(y0, x0, v0); + y1 = vfmaq_f32(y1, x1, v0); + y2 = vfmaq_f32(y2, x2, v0); + y3 = vfmaq_f32(y3, x3, v0); + y4 = vfmaq_f32(y4, x4, v0); + y5 = vfmaq_f32(y5, x5, v0); + y6 = vfmaq_f32(y6, x6, v0); + y7 = vfmaq_f32(y7, x7, v0); vst1_f16(y + i + 0 , vcvt_f16_f32(y0)); vst1_f16(y + i + 4 , vcvt_f16_f32(y1)); @@ -927,7 +933,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ // AVX 512-bit const int n64 = (n & ~63); - const __m512 v16 = _mm512_set1_ps(v); + const __m512 v0 = _mm512_set1_ps(v); __m512 x0, x1, x2, x3; __m512 y0, y1, y2, y3; @@ -938,15 +944,15 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ x2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 32))); x3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 48))); - y0 = _mm512_loadu_ps(y + i + 0 ); - y1 = _mm512_loadu_ps(y + i + 16); - y2 = _mm512_loadu_ps(y + i + 32); - y3 = _mm512_loadu_ps(y + i + 48); + y0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 0 ))); + y1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 16))); + y2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 32))); + y3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 48))); - y0 = _mm512_fmadd_ps(x0, v16, y0); - y1 = _mm512_fmadd_ps(x1, v16, y1); - y2 = _mm512_fmadd_ps(x2, v16, y2); - y3 = _mm512_fmadd_ps(x3, v16, y3); + y0 = _mm512_fmadd_ps(x0, v0, y0); + y1 = _mm512_fmadd_ps(x1, v0, y1); + y2 = _mm512_fmadd_ps(x2, v0, y2); + y3 = _mm512_fmadd_ps(x3, v0, y3); _mm256_storeu_si256((__m256i*)(y + i + 0 ), _mm512_cvtps_ph(y0, 0)); _mm256_storeu_si256((__m256i*)(y + i + 16), _mm512_cvtps_ph(y1, 0)); @@ -963,7 +969,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ // AVX 256-bit const int n32 = (n & ~31); - const __m256 v8 = _mm256_set1_ps(v); + const __m256 v0 = _mm256_set1_ps(v); __m256 x0, x1, x2, x3; __m256 y0, y1, y2, y3; @@ -979,10 +985,10 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16))); x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24))); - y0 = _mm256_fmadd_ps(x0, v8, y0); - y1 = _mm256_fmadd_ps(x1, v8, y1); - y2 = _mm256_fmadd_ps(x2, v8, y2); - y3 = _mm256_fmadd_ps(x3, v8, y3); + y0 = _mm256_fmadd_ps(x0, v0, y0); + y1 = _mm256_fmadd_ps(x1, v0, y1); + y2 = 
_mm256_fmadd_ps(x2, v0, y2); + y3 = _mm256_fmadd_ps(x3, v0, y3); _mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0)); _mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0)); @@ -999,7 +1005,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ // WASM SIMD 128-bit const int n16 = (n & ~15); - const v128_t v4 = wasm_f32x4_splat(v); + const v128_t v0 = wasm_f32x4_splat(v); v128_t x0, x1, x2, x3; v128_t y0, y1, y2, y3; @@ -1023,10 +1029,10 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ y2 = wasm_v128_load(ty + 8); y3 = wasm_v128_load(ty + 12); - y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4)); - y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4)); - y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4)); - y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4)); + y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v0)); + y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v0)); + y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v0)); + y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v0)); wasm_v128_store(ty + 0, y0); wasm_v128_store(ty + 4, y1);
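
Note: the following is a minimal, self-contained sketch of the F32 dot-product pattern used by the AVX-512 branch of ggml_vec_dot_f32 in the patches above (four independent accumulators to hide FMA latency, 64 floats per iteration, scalar handling of the leftovers). The test harness, array contents, and build command are assumptions for illustration only and are not part of the patch series.

    // build (assumed): gcc -O2 -mavx512f -mfma dot_f32_avx512.c -o dot_f32_avx512
    #include <immintrin.h>
    #include <stdio.h>

    static float dot_f32_avx512(const float * x, const float * y, const int n) {
        const int n64 = (n & ~63);

        __m512 sum0 = _mm512_setzero_ps();
        __m512 sum1 = _mm512_setzero_ps();
        __m512 sum2 = _mm512_setzero_ps();
        __m512 sum3 = _mm512_setzero_ps();

        // 64 floats per iteration; 4 independent accumulators keep the FMA pipes busy
        for (int i = 0; i < n64; i += 64) {
            sum0 = _mm512_fmadd_ps(_mm512_loadu_ps(x + i +  0), _mm512_loadu_ps(y + i +  0), sum0);
            sum1 = _mm512_fmadd_ps(_mm512_loadu_ps(x + i + 16), _mm512_loadu_ps(y + i + 16), sum1);
            sum2 = _mm512_fmadd_ps(_mm512_loadu_ps(x + i + 32), _mm512_loadu_ps(y + i + 32), sum2);
            sum3 = _mm512_fmadd_ps(_mm512_loadu_ps(x + i + 48), _mm512_loadu_ps(y + i + 48), sum3);
        }

        // fold the four accumulators, then do a horizontal sum
        // (element indexing on __m512 is the same GCC/Clang vector extension the patch relies on)
        sum0 = _mm512_add_ps(_mm512_add_ps(sum0, sum1), _mm512_add_ps(sum2, sum3));

        float sumf = sum0[0] + sum0[1] + sum0[2]  + sum0[3]  + sum0[4]  + sum0[5]  + sum0[6]  + sum0[7] +
                     sum0[8] + sum0[9] + sum0[10] + sum0[11] + sum0[12] + sum0[13] + sum0[14] + sum0[15];

        // leftovers
        for (int i = n64; i < n; ++i) {
            sumf += x[i]*y[i];
        }

        return sumf;
    }

    int main(void) {
        enum { N = 1000 }; // deliberately not a multiple of 64 to exercise the leftover loop
        static float x[N], y[N];

        float ref = 0.0f;
        for (int i = 0; i < N; ++i) {
            x[i] = 0.25f*(float)(i % 7);
            y[i] = 0.50f*(float)(i % 5);
            ref += x[i]*y[i];
        }

        printf("avx512: %f, scalar: %f\n", dot_f32_avx512(x, y, N), ref);
        return 0;
    }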
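
Similarly, the FP16 path as it stands after PATCH 4/4 (ggml_vec_mad_f16) keeps both operands in half precision in memory and only widens to F32 inside the loop: 16 halves per _mm512_cvtph_ps, one _mm512_fmadd_ps against the broadcast scalar, then a narrowing _mm512_cvtps_ph before the store (immediate 0 selects round-to-nearest-even, as in the patch). Below is a hedged standalone sketch of that pattern; the fp16_t typedef, the _cvtsh_ss/_cvtss_sh scalar tail (standing in for ggml_fp16_to_fp32/ggml_fp32_to_fp16), and the test values are assumptions for illustration, and the patch's manual 4x unroll is collapsed into an inner loop for brevity.

    // build (assumed): gcc -O2 -mavx512f -mfma -mf16c mad_f16_avx512.c -o mad_f16_avx512
    #include <immintrin.h>
    #include <stdint.h>
    #include <stdio.h>

    typedef uint16_t fp16_t; // same storage type as ggml_fp16_t (IEEE half)

    static void mad_f16_avx512(const int n, fp16_t * y, const fp16_t * x, const float v) {
        const int n64 = (n & ~63);

        const __m512 v0 = _mm512_set1_ps(v);

        for (int i = 0; i < n64; i += 64) {
            for (int j = 0; j < 64; j += 16) {
                // widen 16 halves of x and y to F32, fused multiply-add, narrow back to F16
                const __m512 xv = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(x + i + j)));
                const __m512 yv = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(y + i + j)));
                _mm256_storeu_si256((__m256i*)(y + i + j),
                                    _mm512_cvtps_ph(_mm512_fmadd_ps(xv, v0, yv), 0));
            }
        }

        // leftovers, converted one element at a time via the F16C scalar helpers
        for (int i = n64; i < n; ++i) {
            y[i] = _cvtss_sh(_cvtsh_ss(y[i]) + _cvtsh_ss(x[i])*v, 0);
        }
    }

    int main(void) {
        enum { N = 100 };
        fp16_t x[N], y[N];
        for (int i = 0; i < N; ++i) {
            x[i] = _cvtss_sh((float)i, 0);
            y[i] = _cvtss_sh(1.0f, 0);
        }

        mad_f16_avx512(N, y, x, 0.5f); // expect y[i] = 1 + 0.5*i

        printf("y[0] = %f, y[99] = %f\n", _cvtsh_ss(y[0]), _cvtsh_ss(y[99]));
        return 0;
    }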