diff --git a/Makefile b/Makefile index d120f1f..40a0187 100644 --- a/Makefile +++ b/Makefile @@ -43,7 +43,7 @@ endif # feel free to update the Makefile for your architecture and send a pull request or issue ifeq ($(UNAME_M),x86_64) # AVX 512 - CFLAGS += -mavx512f -mavx512dq -mfma -mf16c + CFLAGS += -mavx512f -mfma -mf16c # AVX 256 #CFLAGS += -mavx -mavx2 -mfma -mf16c diff --git a/ggml.c b/ggml.c index 6f80820..7d67078 100644 --- a/ggml.c +++ b/ggml.c @@ -568,37 +568,16 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t __m512 x0, x1, x2, x3; __m512 y0, y1, y2, y3; - __m256 t0lo, t0hi, t1lo, t1hi, t2lo, t2hi, t3lo, t3hi; - for (int i = 0; i < n64; i += 64) { - // TODO: is this the best way to do this? - t0lo = _mm256_loadu_ps((const float*)(x + i + 0 )); - t0hi = _mm256_loadu_ps((const float*)(x + i + 8 )); - t1lo = _mm256_loadu_ps((const float*)(x + i + 16)); - t1hi = _mm256_loadu_ps((const float*)(x + i + 24)); - t2lo = _mm256_loadu_ps((const float*)(x + i + 32)); - t2hi = _mm256_loadu_ps((const float*)(x + i + 40)); - t3lo = _mm256_loadu_ps((const float*)(x + i + 48)); - t3hi = _mm256_loadu_ps((const float*)(x + i + 56)); - - x0 = _mm512_insertf32x8(_mm512_castps256_ps512(t0lo), t0hi, 1); - x1 = _mm512_insertf32x8(_mm512_castps256_ps512(t1lo), t1hi, 1); - x2 = _mm512_insertf32x8(_mm512_castps256_ps512(t2lo), t2hi, 1); - x3 = _mm512_insertf32x8(_mm512_castps256_ps512(t3lo), t3hi, 1); - - t0lo = _mm256_loadu_ps((const float*)(y + i + 0 )); - t0hi = _mm256_loadu_ps((const float*)(y + i + 8 )); - t1lo = _mm256_loadu_ps((const float*)(y + i + 16)); - t1hi = _mm256_loadu_ps((const float*)(y + i + 24)); - t2lo = _mm256_loadu_ps((const float*)(y + i + 32)); - t2hi = _mm256_loadu_ps((const float*)(y + i + 40)); - t3lo = _mm256_loadu_ps((const float*)(y + i + 48)); - t3hi = _mm256_loadu_ps((const float*)(y + i + 56)); - - y0 = _mm512_insertf32x8(_mm512_castps256_ps512(t0lo), t0hi, 1); - y1 = _mm512_insertf32x8(_mm512_castps256_ps512(t1lo), t1hi, 1); - y2 = _mm512_insertf32x8(_mm512_castps256_ps512(t2lo), t2hi, 1); - y3 = _mm512_insertf32x8(_mm512_castps256_ps512(t3lo), t3hi, 1); + x0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 0 ))); + x1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 16))); + x2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 32))); + x3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 48))); + + y0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 0 ))); + y1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 16))); + y2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 32))); + y3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 48))); sum0 = _mm512_fmadd_ps(x0, y0, sum0); sum1 = _mm512_fmadd_ps(x1, y1, sum1); @@ -953,36 +932,26 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ __m512 x0, x1, x2, x3; __m512 y0, y1, y2, y3; - __m256 t0lo, t0hi, t1lo, t1hi, t2lo, t2hi, t3lo, t3hi; - for (int i = 0; i < n64; i += 64) { - t0lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 ))); - t0hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 ))); - t1lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16))); - t1hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24))); - t2lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 32))); - t2hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 40))); - t3lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 48))); - t3hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 56))); - - y0 = _mm512_insertf32x8(_mm512_castps256_ps512(t0lo), t0hi, 1); - y1 = _mm512_insertf32x8(_mm512_castps256_ps512(t1lo), t1hi, 1); - y2 = _mm512_insertf32x8(_mm512_castps256_ps512(t2lo), t2hi, 1); - y3 = _mm512_insertf32x8(_mm512_castps256_ps512(t3lo), t3hi, 1); - - t0lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 ))); - t0hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 ))); - t1lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16))); - t1hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24))); - t2lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 32))); - t2hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 40))); - t3lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 48))); - t3hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 56))); + x0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 0 ))); + x1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 16))); + x2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 32))); + x3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 48))); + + y0 = _mm512_loadu_ps(y + i + 0 ); + y1 = _mm512_loadu_ps(y + i + 16); + y2 = _mm512_loadu_ps(y + i + 32); + y3 = _mm512_loadu_ps(y + i + 48); y0 = _mm512_fmadd_ps(x0, v16, y0); y1 = _mm512_fmadd_ps(x1, v16, y1); y2 = _mm512_fmadd_ps(x2, v16, y2); y3 = _mm512_fmadd_ps(x3, v16, y3); + + _mm256_storeu_si256((__m256i*)(y + i + 0 ), _mm512_cvtps_ph(y0, 0)); + _mm256_storeu_si256((__m256i*)(y + i + 16), _mm512_cvtps_ph(y1, 0)); + _mm256_storeu_si256((__m256i*)(y + i + 32), _mm512_cvtps_ph(y2, 0)); + _mm256_storeu_si256((__m256i*)(y + i + 48), _mm512_cvtps_ph(y3, 0)); } // leftovers