Add AVX,AVX2 support for ggml_vec_scale_f32

pull/288/head
katsu560 1 year ago committed by Georgi Gerganov
parent 1eb81f863f
commit 419b8a6402

@ -1118,7 +1118,45 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
#endif
}
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
#if defined(__AVX__) || defined(__AVX2__)
// AVX 256-bit
const int n32 = (n & ~31);
const __m256 v4 = _mm256_set1_ps(v);
__m256 y0, y1, y2, y3;
for (int i = 0; i < n32; i += 32) {
y0 = _mm256_loadu_ps(y + i + 0);
y1 = _mm256_loadu_ps(y + i + 8);
y2 = _mm256_loadu_ps(y + i + 16);
y3 = _mm256_loadu_ps(y + i + 24);
y0 = _mm256_mul_ps(y0, v4);
y1 = _mm256_mul_ps(y1, v4);
y2 = _mm256_mul_ps(y2, v4);
y3 = _mm256_mul_ps(y3, v4);
_mm256_storeu_ps(y + i + 0, y0);
_mm256_storeu_ps(y + i + 8, y1);
_mm256_storeu_ps(y + i + 16, y2);
_mm256_storeu_ps(y + i + 24, y3);
}
// leftovers
for (int i = n32; i < n; ++i) {
y[i] *= v;
}
#else
// scalar
for (int i = 0; i < n; ++i) {
y[i] *= v;
}
#endif
}
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s); }
inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); }

Loading…
Cancel
Save