From 95f4fc70caf0bd492ce18dc48b76216c1b891bde Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 26 Oct 2022 18:48:54 +0300 Subject: [PATCH] Try to add AVX 512-bit support --- Makefile | 6 +- ggml.c | 175 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 180 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 1b9d4b6..d120f1f 100644 --- a/Makefile +++ b/Makefile @@ -42,7 +42,11 @@ endif # TODO: probably these flags need to be tweaked on some architectures # feel free to update the Makefile for your architecture and send a pull request or issue ifeq ($(UNAME_M),x86_64) - CFLAGS += -mavx -mavx2 -mfma -mf16c + # AVX 512 + CFLAGS += -mavx512f -mavx512dq -mfma -mf16c + + # AVX 256 + #CFLAGS += -mavx -mavx2 -mfma -mf16c endif ifeq ($(UNAME_M),amd64) CFLAGS += -mavx -mavx2 -mfma -mf16c diff --git a/ggml.c b/ggml.c index 3a36802..6f80820 100644 --- a/ggml.c +++ b/ggml.c @@ -320,6 +320,45 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float for (int i = n16; i < n; ++i) { sumf += x[i]*y[i]; } +#elif defined(__AVX512F__) + const int n64 = (n & ~63); + + __m512 sum0 = _mm512_setzero_ps(); + __m512 sum1 = _mm512_setzero_ps(); + __m512 sum2 = _mm512_setzero_ps(); + __m512 sum3 = _mm512_setzero_ps(); + + __m512 x0, x1, x2, x3; + __m512 y0, y1, y2, y3; + + for (int i = 0; i < n64; i += 64) { + x0 = _mm512_loadu_ps(x + i + 0); + x1 = _mm512_loadu_ps(x + i + 16); + x2 = _mm512_loadu_ps(x + i + 32); + x3 = _mm512_loadu_ps(x + i + 48); + + y0 = _mm512_loadu_ps(y + i + 0); + y1 = _mm512_loadu_ps(y + i + 16); + y2 = _mm512_loadu_ps(y + i + 32); + y3 = _mm512_loadu_ps(y + i + 48); + + sum0 = _mm512_fmadd_ps(x0, y0, sum0); + sum1 = _mm512_fmadd_ps(x1, y1, sum1); + sum2 = _mm512_fmadd_ps(x2, y2, sum2); + sum3 = _mm512_fmadd_ps(x3, y3, sum3); + } + + sum0 = _mm512_add_ps(sum0, sum1); + sum2 = _mm512_add_ps(sum2, sum3); + sum0 = _mm512_add_ps(sum0, sum2); + + sumf = sum0[0] + sum0[1] + sum0[2] + sum0[3] + sum0[4] + sum0[5] + sum0[6] + sum0[7] + + sum0[8] + sum0[9] + sum0[10] + sum0[11] + sum0[12] + sum0[13] + sum0[14] + sum0[15]; + + // leftovers + for (int i = n64; i < n; ++i) { + sumf += x[i]*y[i]; + } #elif defined(__AVX2__) // AVX 256-bit const int n32 = (n & ~31); @@ -517,6 +556,61 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t for (int i = n32; i < n; ++i) { sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); } +#elif defined(__AVX512F__) + // AVX 512-bit + const int n64 = (n & ~63); + + __m512 sum0 = _mm512_setzero_ps(); + __m512 sum1 = _mm512_setzero_ps(); + __m512 sum2 = _mm512_setzero_ps(); + __m512 sum3 = _mm512_setzero_ps(); + + __m512 x0, x1, x2, x3; + __m512 y0, y1, y2, y3; + + __m256 t0lo, t0hi, t1lo, t1hi, t2lo, t2hi, t3lo, t3hi; + + for (int i = 0; i < n64; i += 64) { + // TODO: is this the best way to do this? + t0lo = _mm256_loadu_ps((const float*)(x + i + 0 )); + t0hi = _mm256_loadu_ps((const float*)(x + i + 8 )); + t1lo = _mm256_loadu_ps((const float*)(x + i + 16)); + t1hi = _mm256_loadu_ps((const float*)(x + i + 24)); + t2lo = _mm256_loadu_ps((const float*)(x + i + 32)); + t2hi = _mm256_loadu_ps((const float*)(x + i + 40)); + t3lo = _mm256_loadu_ps((const float*)(x + i + 48)); + t3hi = _mm256_loadu_ps((const float*)(x + i + 56)); + + x0 = _mm512_insertf32x8(_mm512_castps256_ps512(t0lo), t0hi, 1); + x1 = _mm512_insertf32x8(_mm512_castps256_ps512(t1lo), t1hi, 1); + x2 = _mm512_insertf32x8(_mm512_castps256_ps512(t2lo), t2hi, 1); + x3 = _mm512_insertf32x8(_mm512_castps256_ps512(t3lo), t3hi, 1); + + t0lo = _mm256_loadu_ps((const float*)(y + i + 0 )); + t0hi = _mm256_loadu_ps((const float*)(y + i + 8 )); + t1lo = _mm256_loadu_ps((const float*)(y + i + 16)); + t1hi = _mm256_loadu_ps((const float*)(y + i + 24)); + t2lo = _mm256_loadu_ps((const float*)(y + i + 32)); + t2hi = _mm256_loadu_ps((const float*)(y + i + 40)); + t3lo = _mm256_loadu_ps((const float*)(y + i + 48)); + t3hi = _mm256_loadu_ps((const float*)(y + i + 56)); + + y0 = _mm512_insertf32x8(_mm512_castps256_ps512(t0lo), t0hi, 1); + y1 = _mm512_insertf32x8(_mm512_castps256_ps512(t1lo), t1hi, 1); + y2 = _mm512_insertf32x8(_mm512_castps256_ps512(t2lo), t2hi, 1); + y3 = _mm512_insertf32x8(_mm512_castps256_ps512(t3lo), t3hi, 1); + + sum0 = _mm512_fmadd_ps(x0, y0, sum0); + sum1 = _mm512_fmadd_ps(x1, y1, sum1); + sum2 = _mm512_fmadd_ps(x2, y2, sum2); + sum3 = _mm512_fmadd_ps(x3, y3, sum3); + } + + const __m512 sum01 = _mm512_add_ps(sum0, sum1); + const __m512 sum23 = _mm512_add_ps(sum2, sum3); + const __m512 sum0123 = _mm512_add_ps(sum01, sum23); + + sumf = sum0123[0] + sum0123[1] + sum0123[2] + sum0123[3] + sum0123[4] + sum0123[5] + sum0123[6] + sum0123[7]; #elif defined(__AVX2__) // AVX 256-bit const int n32 = (n & ~31); @@ -654,6 +748,41 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float for (int i = n16; i < n; ++i) { y[i] += x[i]*v; } +#elif defined(__AVX512F__) + // AVX512 512-bit + const int n64 = (n & ~63); + + const __m512 v4 = _mm512_set1_ps(v); + + __m512 x0, x1, x2, x3; + __m512 y0, y1, y2, y3; + + for (int i = 0; i < n64; i += 64) { + x0 = _mm512_loadu_ps(x + i + 0); + x1 = _mm512_loadu_ps(x + i + 16); + x2 = _mm512_loadu_ps(x + i + 32); + x3 = _mm512_loadu_ps(x + i + 48); + + y0 = _mm512_loadu_ps(y + i + 0); + y1 = _mm512_loadu_ps(y + i + 16); + y2 = _mm512_loadu_ps(y + i + 32); + y3 = _mm512_loadu_ps(y + i + 48); + + y0 = _mm512_fmadd_ps(x0, v4, y0); + y1 = _mm512_fmadd_ps(x1, v4, y1); + y2 = _mm512_fmadd_ps(x2, v4, y2); + y3 = _mm512_fmadd_ps(x3, v4, y3); + + _mm512_storeu_ps(y + i + 0, y0); + _mm512_storeu_ps(y + i + 16, y1); + _mm512_storeu_ps(y + i + 32, y2); + _mm512_storeu_ps(y + i + 48, y3); + } + + // leftovers + for (int i = n64; i < n; ++i) { + y[i] += x[i]*v; + } #elif defined(__AVX2__) // AVX 256-bit const int n32 = (n & ~31); @@ -815,6 +944,52 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ GGML_ASSERT(false); y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v); } +#elif defined(__AVX512F__) + // AVX 512-bit + const int n64 = (n & ~63); + + const __m512 v16 = _mm512_set1_ps(v); + + __m512 x0, x1, x2, x3; + __m512 y0, y1, y2, y3; + + __m256 t0lo, t0hi, t1lo, t1hi, t2lo, t2hi, t3lo, t3hi; + + for (int i = 0; i < n64; i += 64) { + t0lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 ))); + t0hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 ))); + t1lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16))); + t1hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24))); + t2lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 32))); + t2hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 40))); + t3lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 48))); + t3hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 56))); + + y0 = _mm512_insertf32x8(_mm512_castps256_ps512(t0lo), t0hi, 1); + y1 = _mm512_insertf32x8(_mm512_castps256_ps512(t1lo), t1hi, 1); + y2 = _mm512_insertf32x8(_mm512_castps256_ps512(t2lo), t2hi, 1); + y3 = _mm512_insertf32x8(_mm512_castps256_ps512(t3lo), t3hi, 1); + + t0lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 ))); + t0hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 ))); + t1lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16))); + t1hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24))); + t2lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 32))); + t2hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 40))); + t3lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 48))); + t3hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 56))); + + y0 = _mm512_fmadd_ps(x0, v16, y0); + y1 = _mm512_fmadd_ps(x1, v16, y1); + y2 = _mm512_fmadd_ps(x2, v16, y2); + y3 = _mm512_fmadd_ps(x3, v16, y3); + } + + // leftovers + for (int i = n64; i < n; ++i) { + GGML_ASSERT(false); + y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v); + } #elif defined(__AVX2__) // AVX 256-bit const int n32 = (n & ~31);