Try to add AVX 512-bit support

pull/95/head
Georgi Gerganov 2 years ago
parent 058a27b2e5
commit 95f4fc70ca

Makefile

@@ -42,7 +42,11 @@ endif
# TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue
ifeq ($(UNAME_M),x86_64)
# AVX 512
CFLAGS += -mavx512f -mavx512dq -mfma -mf16c
# AVX 256
#CFLAGS += -mavx -mavx2 -mfma -mf16c
endif
ifeq ($(UNAME_M),amd64)
CFLAGS += -mavx -mavx2 -mfma -mf16c
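Note that this enables AVX-512 code generation unconditionally for every x86_64 build, so the binary will die with an illegal-instruction fault on CPUs that lack it. A minimal sketch of a runtime guard using GCC's __builtin_cpu_supports (hypothetical helper, not part of this commit):

#include <stdio.h>
#include <stdlib.h>

// call once at startup: fail loudly instead of faulting mid-computation
static void ggml_check_cpu(void) {
    if (!__builtin_cpu_supports("avx512f")) {
        fprintf(stderr, "error: this build requires a CPU with AVX-512F\n");
        exit(1);
    }
}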

ggml.c (+175 lines)

@@ -320,6 +320,45 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
for (int i = n16; i < n; ++i) {
sumf += x[i]*y[i];
}
#elif defined(__AVX512F__)
// AVX 512-bit
const int n64 = (n & ~63);
__m512 sum0 = _mm512_setzero_ps();
__m512 sum1 = _mm512_setzero_ps();
__m512 sum2 = _mm512_setzero_ps();
__m512 sum3 = _mm512_setzero_ps();
__m512 x0, x1, x2, x3;
__m512 y0, y1, y2, y3;
for (int i = 0; i < n64; i += 64) {
x0 = _mm512_loadu_ps(x + i + 0);
x1 = _mm512_loadu_ps(x + i + 16);
x2 = _mm512_loadu_ps(x + i + 32);
x3 = _mm512_loadu_ps(x + i + 48);
y0 = _mm512_loadu_ps(y + i + 0);
y1 = _mm512_loadu_ps(y + i + 16);
y2 = _mm512_loadu_ps(y + i + 32);
y3 = _mm512_loadu_ps(y + i + 48);
sum0 = _mm512_fmadd_ps(x0, y0, sum0);
sum1 = _mm512_fmadd_ps(x1, y1, sum1);
sum2 = _mm512_fmadd_ps(x2, y2, sum2);
sum3 = _mm512_fmadd_ps(x3, y3, sum3);
}
sum0 = _mm512_add_ps(sum0, sum1);
sum2 = _mm512_add_ps(sum2, sum3);
sum0 = _mm512_add_ps(sum0, sum2);
sumf = _mm512_reduce_add_ps(sum0);
// leftovers
for (int i = n64; i < n; ++i) {
sumf += x[i]*y[i];
}
#elif defined(__AVX2__)
// AVX 256-bit
const int n32 = (n & ~31);
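For reference, the AVX-512 main loop above computes the same thing as this scalar sketch (not part of the diff); the four sum registers are independent accumulators that hide FMA latency, and the final pairwise reduction matches theirs:

float sum[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
for (int i = 0; i < n64; i += 64) {
    for (int j = 0; j < 4; ++j) {      // one pass per __m512 accumulator
        for (int k = 0; k < 16; ++k) { // 16 fp32 lanes per 512-bit register
            sum[j] += x[i + 16*j + k]*y[i + 16*j + k];
        }
    }
}
sumf = (sum[0] + sum[1]) + (sum[2] + sum[3]);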
@@ -517,6 +556,61 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
for (int i = n32; i < n; ++i) {
sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
}
#elif defined(__AVX512F__)
// AVX 512-bit
const int n64 = (n & ~63);
__m512 sum0 = _mm512_setzero_ps();
__m512 sum1 = _mm512_setzero_ps();
__m512 sum2 = _mm512_setzero_ps();
__m512 sum3 = _mm512_setzero_ps();
__m512 x0, x1, x2, x3;
__m512 y0, y1, y2, y3;
__m256 t0lo, t0hi, t1lo, t1hi, t2lo, t2hi, t3lo, t3hi;
for (int i = 0; i < n64; i += 64) {
// TODO: is this the best way to do this?
// widen 8 fp16 values at a time to fp32 (F16C), two halves per 512-bit register;
// loading the fp16 data directly with _mm256_loadu_ps would reinterpret the bits as fp32
t0lo = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x + i + 0 )));
t0hi = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x + i + 8 )));
t1lo = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x + i + 16)));
t1hi = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x + i + 24)));
t2lo = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x + i + 32)));
t2hi = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x + i + 40)));
t3lo = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x + i + 48)));
t3hi = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x + i + 56)));
x0 = _mm512_insertf32x8(_mm512_castps256_ps512(t0lo), t0hi, 1);
x1 = _mm512_insertf32x8(_mm512_castps256_ps512(t1lo), t1hi, 1);
x2 = _mm512_insertf32x8(_mm512_castps256_ps512(t2lo), t2hi, 1);
x3 = _mm512_insertf32x8(_mm512_castps256_ps512(t3lo), t3hi, 1);
t0lo = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(y + i + 0 )));
t0hi = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(y + i + 8 )));
t1lo = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(y + i + 16)));
t1hi = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(y + i + 24)));
t2lo = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(y + i + 32)));
t2hi = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(y + i + 40)));
t3lo = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(y + i + 48)));
t3hi = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(y + i + 56)));
y0 = _mm512_insertf32x8(_mm512_castps256_ps512(t0lo), t0hi, 1);
y1 = _mm512_insertf32x8(_mm512_castps256_ps512(t1lo), t1hi, 1);
y2 = _mm512_insertf32x8(_mm512_castps256_ps512(t2lo), t2hi, 1);
y3 = _mm512_insertf32x8(_mm512_castps256_ps512(t3lo), t3hi, 1);
sum0 = _mm512_fmadd_ps(x0, y0, sum0);
sum1 = _mm512_fmadd_ps(x1, y1, sum1);
sum2 = _mm512_fmadd_ps(x2, y2, sum2);
sum3 = _mm512_fmadd_ps(x3, y3, sum3);
}
const __m512 sum01 = _mm512_add_ps(sum0, sum1);
const __m512 sum23 = _mm512_add_ps(sum2, sum3);
const __m512 sum0123 = _mm512_add_ps(sum01, sum23);
// reduce all 16 lanes; summing only lanes 0..7 would drop half the result
sumf = _mm512_reduce_add_ps(sum0123);
// leftovers
for (int i = n64; i < n; ++i) {
    sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
}
#elif defined(__AVX2__)
// AVX 256-bit
const int n32 = (n & ~31);
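On the TODO above: AVX512F also offers _mm512_cvtph_ps, which widens 16 fp16 values to fp32 in one instruction, so each 512-bit register could plausibly be loaded without assembling two 256-bit halves. A sketch of the idea for the x loads (an untested alternative, not what the commit does):

x0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(x + i + 0 )));
x1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(x + i + 16)));
x2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(x + i + 32)));
x3 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(x + i + 48)));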
@@ -654,6 +748,41 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
for (int i = n16; i < n; ++i) {
y[i] += x[i]*v;
}
#elif defined(__AVX512F__)
// AVX 512-bit
const int n64 = (n & ~63);
const __m512 v4 = _mm512_set1_ps(v);
__m512 x0, x1, x2, x3;
__m512 y0, y1, y2, y3;
for (int i = 0; i < n64; i += 64) {
x0 = _mm512_loadu_ps(x + i + 0);
x1 = _mm512_loadu_ps(x + i + 16);
x2 = _mm512_loadu_ps(x + i + 32);
x3 = _mm512_loadu_ps(x + i + 48);
y0 = _mm512_loadu_ps(y + i + 0);
y1 = _mm512_loadu_ps(y + i + 16);
y2 = _mm512_loadu_ps(y + i + 32);
y3 = _mm512_loadu_ps(y + i + 48);
y0 = _mm512_fmadd_ps(x0, v4, y0);
y1 = _mm512_fmadd_ps(x1, v4, y1);
y2 = _mm512_fmadd_ps(x2, v4, y2);
y3 = _mm512_fmadd_ps(x3, v4, y3);
_mm512_storeu_ps(y + i + 0, y0);
_mm512_storeu_ps(y + i + 16, y1);
_mm512_storeu_ps(y + i + 32, y2);
_mm512_storeu_ps(y + i + 48, y3);
}
// leftovers
for (int i = n64; i < n; ++i) {
y[i] += x[i]*v;
}
#elif defined(__AVX2__)
// AVX 256-bit
const int n32 = (n & ~31);
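AVX-512 masking could also fold the scalar leftovers of ggml_vec_mad_f32 into the vector path; a hedged sketch using the v4 broadcast register from above (not part of this commit):

// masked tail: handle the final n - n64 elements 16 at a time
for (int i = n64; i < n; i += 16) {
    const int rem = (n - i) < 16 ? (n - i) : 16;
    const __mmask16 m = (__mmask16)((1u << rem) - 1u);
    const __m512 xv = _mm512_maskz_loadu_ps(m, x + i);
    const __m512 yv = _mm512_maskz_loadu_ps(m, y + i);
    _mm512_mask_storeu_ps(y + i, m, _mm512_fmadd_ps(xv, v4, yv));
}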
@@ -815,6 +944,52 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
GGML_ASSERT(false);
y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
}
#elif defined(__AVX512F__)
// AVX 512-bit
const int n64 = (n & ~63);
const __m512 v16 = _mm512_set1_ps(v);
__m512 x0, x1, x2, x3;
__m512 y0, y1, y2, y3;
__m256 t0lo, t0hi, t1lo, t1hi, t2lo, t2hi, t3lo, t3hi;
for (int i = 0; i < n64; i += 64) {
t0lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
t0hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
t1lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
t1hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
t2lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 32)));
t2hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 40)));
t3lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 48)));
t3hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 56)));
y0 = _mm512_insertf32x8(_mm512_castps256_ps512(t0lo), t0hi, 1);
y1 = _mm512_insertf32x8(_mm512_castps256_ps512(t1lo), t1hi, 1);
y2 = _mm512_insertf32x8(_mm512_castps256_ps512(t2lo), t2hi, 1);
y3 = _mm512_insertf32x8(_mm512_castps256_ps512(t3lo), t3hi, 1);
t0lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
t0hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
t1lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
t1hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
t2lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 32)));
t2hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 40)));
t3lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 48)));
t3hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 56)));
// pack the two converted halves of x into 512-bit registers
// (x0..x3 were previously used without ever being assigned)
x0 = _mm512_insertf32x8(_mm512_castps256_ps512(t0lo), t0hi, 1);
x1 = _mm512_insertf32x8(_mm512_castps256_ps512(t1lo), t1hi, 1);
x2 = _mm512_insertf32x8(_mm512_castps256_ps512(t2lo), t2hi, 1);
x3 = _mm512_insertf32x8(_mm512_castps256_ps512(t3lo), t3hi, 1);
y0 = _mm512_fmadd_ps(x0, v16, y0);
y1 = _mm512_fmadd_ps(x1, v16, y1);
y2 = _mm512_fmadd_ps(x2, v16, y2);
y3 = _mm512_fmadd_ps(x3, v16, y3);
// convert the results back to fp16 and store them; without these
// stores the loop computed everything and wrote nothing back to y
_mm256_storeu_si256((__m256i*)(y + i + 0 ), _mm512_cvtps_ph(y0, 0));
_mm256_storeu_si256((__m256i*)(y + i + 16), _mm512_cvtps_ph(y1, 0));
_mm256_storeu_si256((__m256i*)(y + i + 32), _mm512_cvtps_ph(y2, 0));
_mm256_storeu_si256((__m256i*)(y + i + 48), _mm512_cvtps_ph(y3, 0));
}
// leftovers (should be unreachable: this path assumes n is a multiple of 64)
for (int i = n64; i < n; ++i) {
GGML_ASSERT(false);
y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
}
#elif defined(__AVX2__)
// AVX 256-bit
const int n32 = (n & ~31);
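Given that the original vec_mad_f16 loop computed results but never stored them back (fixed above), a scalar-vs-SIMD cross-check is worth having. A self-contained sketch for the f32 dot-product pattern (illustrative only, not part of the commit; compile with -mavx512f and run on AVX-512 hardware):

#include <immintrin.h>
#include <stdio.h>

int main(void) {
    enum { N = 67 };               // deliberately not a multiple of 64
    float x[N], y[N];
    for (int i = 0; i < N; ++i) {
        x[i] = 0.25f*(float)i;
        y[i] = 1.0f - 0.125f*(float)i;
    }

    // scalar reference
    float ref = 0.0f;
    for (int i = 0; i < N; ++i) ref += x[i]*y[i];

    // same shape as the AVX-512 path: 64-wide main loop + scalar leftovers
    const int n64 = N & ~63;
    __m512 sum = _mm512_setzero_ps();
    for (int i = 0; i < n64; i += 64) {
        sum = _mm512_fmadd_ps(_mm512_loadu_ps(x + i +  0), _mm512_loadu_ps(y + i +  0), sum);
        sum = _mm512_fmadd_ps(_mm512_loadu_ps(x + i + 16), _mm512_loadu_ps(y + i + 16), sum);
        sum = _mm512_fmadd_ps(_mm512_loadu_ps(x + i + 32), _mm512_loadu_ps(y + i + 32), sum);
        sum = _mm512_fmadd_ps(_mm512_loadu_ps(x + i + 48), _mm512_loadu_ps(y + i + 48), sum);
    }
    float sumf = _mm512_reduce_add_ps(sum);
    for (int i = n64; i < N; ++i) sumf += x[i]*y[i];

    printf("scalar = %f, avx512 = %f\n", (double)ref, (double)sumf);
    return 0;
}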
