wip : WASM 128-bit SIMD support

pull/70/head
Georgi Gerganov 2 years ago
parent e905c6f827
commit db460b78ff

@ -123,13 +123,13 @@ else()
if (MSVC) if (MSVC)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2 /D_CRT_SECURE_NO_WARNINGS=1") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2 /D_CRT_SECURE_NO_WARNINGS=1")
else() else()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mavx2 -mfma -mf16c") if (EMSCRIPTEN)
endif()
endif()
if (EMSCRIPTEN)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread -msimd128") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread -msimd128")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
else()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mavx2 -mfma -mf16c")
endif()
endif()
endif() endif()
# whisper - this is the main library of the project # whisper - this is the main library of the project

@ -21,13 +21,14 @@ if (WHISPER_WASM_SINGLE_FILE)
) )
endif() endif()
#-s TOTAL_MEMORY=536870912 \
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \ --bind \
-s MODULARIZE=1 \ -s MODULARIZE=1 \
-s ASSERTIONS=1 \ -s ASSERTIONS=1 \
-s USE_PTHREADS=1 \ -s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \ -s PTHREAD_POOL_SIZE=9 \
-s TOTAL_MEMORY=536870912 \ -s ALLOW_MEMORY_GROWTH=1 \
-s FORCE_FILESYSTEM=1 \ -s FORCE_FILESYSTEM=1 \
-s EXPORT_NAME=\"'whisper_factory'\" \ -s EXPORT_NAME=\"'whisper_factory'\" \
${EXTRA_FLAGS} \ ${EXTRA_FLAGS} \

@ -4,6 +4,7 @@
#include <emscripten/bind.h> #include <emscripten/bind.h>
#include <vector> #include <vector>
#include <thread>
std::vector<struct whisper_context *> g_contexts(4, nullptr); std::vector<struct whisper_context *> g_contexts(4, nullptr);
@ -47,7 +48,7 @@ EMSCRIPTEN_BINDINGS(whisper) {
params.print_special_tokens = false; params.print_special_tokens = false;
params.translate = false; params.translate = false;
params.language = "en"; params.language = "en";
params.n_threads = 4; params.n_threads = std::min(8, (int) std::thread::hardware_concurrency());
params.offset_ms = 0; params.offset_ms = 0;
std::vector<float> pcmf32; std::vector<float> pcmf32;

File diff suppressed because one or more lines are too long

180
ggml.c

@ -73,7 +73,11 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
#else #else
#ifdef __wasm_simd128__
#include <wasm_simd128.h>
#else
#include <immintrin.h> #include <immintrin.h>
#endif
// FP16 <-> FP32 // FP16 <-> FP32
// ref: https://github.com/Maratyszcza/FP16 // ref: https://github.com/Maratyszcza/FP16
@ -288,7 +292,7 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
sumf += x[i]*y[i]; sumf += x[i]*y[i];
} }
#elif defined(__AVX2__) #elif defined(__AVX2__)
// AVX 256-bit (unroll 4) // AVX 256-bit
const int n32 = (n & ~31); const int n32 = (n & ~31);
__m256 sum0 = _mm256_setzero_ps(); __m256 sum0 = _mm256_setzero_ps();
@ -330,6 +334,45 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
for (int i = n32; i < n; ++i) { for (int i = n32; i < n; ++i) {
sumf += x[i]*y[i]; sumf += x[i]*y[i];
} }
#elif defined(__wasm_simd128__)
// WASM 128-bit
const int n16 = (n & ~15);
v128_t sum0 = wasm_f32x4_splat(0);
v128_t sum1 = wasm_f32x4_splat(0);
v128_t sum2 = wasm_f32x4_splat(0);
v128_t sum3 = wasm_f32x4_splat(0);
v128_t x0, x1, x2, x3;
v128_t y0, y1, y2, y3;
for (int i = 0; i < n16; i += 16) {
x0 = wasm_v128_load(x + i + 0);
x1 = wasm_v128_load(x + i + 4);
x2 = wasm_v128_load(x + i + 8);
x3 = wasm_v128_load(x + i + 12);
y0 = wasm_v128_load(y + i + 0);
y1 = wasm_v128_load(y + i + 4);
y2 = wasm_v128_load(y + i + 8);
y3 = wasm_v128_load(y + i + 12);
sum0 = wasm_f32x4_add(sum0, wasm_f32x4_mul(x0, y0));
sum1 = wasm_f32x4_add(sum1, wasm_f32x4_mul(x1, y1));
sum2 = wasm_f32x4_add(sum2, wasm_f32x4_mul(x2, y2));
sum3 = wasm_f32x4_add(sum3, wasm_f32x4_mul(x3, y3));
}
sum0 = wasm_f32x4_add(sum0, sum1);
sum2 = wasm_f32x4_add(sum2, sum3);
sum0 = wasm_f32x4_add(sum0, sum2);
sumf = wasm_f32x4_extract_lane(sum0, 0) + wasm_f32x4_extract_lane(sum0, 1) + wasm_f32x4_extract_lane(sum0, 2) + wasm_f32x4_extract_lane(sum0, 3);
// leftovers
for (int i = n16; i < n; ++i) {
sumf += x[i]*y[i];
}
#else #else
// scalar // scalar
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
@ -446,7 +489,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
} }
#elif defined(__AVX2__) #elif defined(__AVX2__)
// AVX 256-bit (unroll 4) // AVX 256-bit
const int n32 = (n & ~31); const int n32 = (n & ~31);
__m256 sum0 = _mm256_setzero_ps(); __m256 sum0 = _mm256_setzero_ps();
@ -489,6 +532,54 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
//GGML_ASSERT(false); //GGML_ASSERT(false);
sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
} }
#elif defined(__wasm_simd128__)
// WASM 128-bit
const int n16 = (n & ~15);
v128_t sum0 = wasm_f32x4_splat(0.0f);
v128_t sum1 = wasm_f32x4_splat(0.0f);
v128_t sum2 = wasm_f32x4_splat(0.0f);
v128_t sum3 = wasm_f32x4_splat(0.0f);
v128_t x0, x1, x2, x3;
v128_t y0, y1, y2, y3;
float tx[16];
float ty[16];
for (int i = 0; i < n16; i += 16) {
for (int k = 0; k < 16; ++k) {
tx[k] = ggml_fp16_to_fp32(x[i + k]);
ty[k] = ggml_fp16_to_fp32(y[i + k]);
}
x0 = wasm_v128_load(tx + 0);
x1 = wasm_v128_load(tx + 4);
x2 = wasm_v128_load(tx + 8);
x3 = wasm_v128_load(tx + 12);
y0 = wasm_v128_load(ty + 0);
y1 = wasm_v128_load(ty + 4);
y2 = wasm_v128_load(ty + 8);
y3 = wasm_v128_load(ty + 12);
sum0 = wasm_f32x4_add(sum0, wasm_f32x4_mul(x0, y0));
sum1 = wasm_f32x4_add(sum1, wasm_f32x4_mul(x1, y1));
sum2 = wasm_f32x4_add(sum2, wasm_f32x4_mul(x2, y2));
sum3 = wasm_f32x4_add(sum3, wasm_f32x4_mul(x3, y3));
}
sum0 = wasm_f32x4_add(sum0, sum1);
sum2 = wasm_f32x4_add(sum2, sum3);
sum0 = wasm_f32x4_add(sum0, sum2);
sumf = wasm_f32x4_extract_lane(sum0, 0) + wasm_f32x4_extract_lane(sum0, 1) + wasm_f32x4_extract_lane(sum0, 2) + wasm_f32x4_extract_lane(sum0, 3);
// leftovers
for (int i = n16; i < n; ++i) {
//GGML_ASSERT(false);
sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
}
#else #else
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
@ -535,7 +626,7 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
y[i] += x[i]*v; y[i] += x[i]*v;
} }
#elif defined(__AVX2__) #elif defined(__AVX2__)
// AVX 256-bit (unroll 4) // AVX 256-bit
const int n32 = (n & ~31); const int n32 = (n & ~31);
const __m256 v4 = _mm256_set1_ps(v); const __m256 v4 = _mm256_set1_ps(v);
@ -569,6 +660,41 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
for (int i = n32; i < n; ++i) { for (int i = n32; i < n; ++i) {
y[i] += x[i]*v; y[i] += x[i]*v;
} }
#elif defined(__wasm_simd128__)
// WASM SIMD 128-bit
const int n16 = (n & ~15);
const v128_t v4 = wasm_f32x4_splat(v);
v128_t x0, x1, x2, x3;
v128_t y0, y1, y2, y3;
for (int i = 0; i < n16; i += 16) {
x0 = wasm_v128_load(x + i + 0);
x1 = wasm_v128_load(x + i + 4);
x2 = wasm_v128_load(x + i + 8);
x3 = wasm_v128_load(x + i + 12);
y0 = wasm_v128_load(y + i + 0);
y1 = wasm_v128_load(y + i + 4);
y2 = wasm_v128_load(y + i + 8);
y3 = wasm_v128_load(y + i + 12);
y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4));
y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4));
y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4));
y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4));
wasm_v128_store(y + i + 0, y0);
wasm_v128_store(y + i + 4, y1);
wasm_v128_store(y + i + 8, y2);
wasm_v128_store(y + i + 12, y3);
}
// leftovers
for (int i = n16; i < n; ++i) {
y[i] += x[i]*v;
}
#else #else
// scalar // scalar
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
@ -696,6 +822,54 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
GGML_ASSERT(false); GGML_ASSERT(false);
y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v); y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
} }
#elif defined(__wasm_simd128__)
// WASM SIMD 128-bit
const int n16 = (n & ~15);
const v128_t v4 = wasm_f32x4_splat(v);
v128_t x0, x1, x2, x3;
v128_t y0, y1, y2, y3;
float tx[16];
float ty[16];
for (int i = 0; i < n16; i += 16) {
for (int k = 0; k < 16; ++k) {
tx[k] = ggml_fp16_to_fp32(x[i + k]);
ty[k] = ggml_fp16_to_fp32(y[i + k]);
}
x0 = wasm_v128_load(tx + 0);
x1 = wasm_v128_load(tx + 4);
x2 = wasm_v128_load(tx + 8);
x3 = wasm_v128_load(tx + 12);
y0 = wasm_v128_load(ty + 0);
y1 = wasm_v128_load(ty + 4);
y2 = wasm_v128_load(ty + 8);
y3 = wasm_v128_load(ty + 12);
y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4));
y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4));
y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4));
y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4));
wasm_v128_store(ty + 0, y0);
wasm_v128_store(ty + 4, y1);
wasm_v128_store(ty + 8, y2);
wasm_v128_store(ty + 12, y3);
for (int k = 0; k < 16; ++k) {
y[i + k] = ggml_fp32_to_fp16(ty[k]);
}
}
// leftovers
for (int i = n16; i < n; ++i) {
GGML_ASSERT(false);
y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
}
#else #else
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v); y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);

Loading…
Cancel
Save