sync : whisper.cpp

- Add MSVC header
- FP16 GELU
- C interface fixes (no unions)
- Minor CMake updates
pull/12/head
Georgi Gerganov 2 years ago
parent 270829aa9f
commit 7094be1f37
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -15,7 +15,7 @@ endif()
# options
option(GGML_ALL_WARNINGS "ggml: enable all compiler warnings" ON)
option(GGML_ALL_WARNINGS "ggml: enable all compiler warnings" ON)
option(GGML_ALL_WARNINGS_3RD_PARTY "ggml: enable all compiler warnings in 3rd party libs" OFF)
option(GGML_SANITIZE_THREAD "ggml: enable thread sanitizer" OFF)
@ -25,7 +25,7 @@ option(GGML_SANITIZE_UNDEFINED "ggml: enable undefined sanitizer" OFF)
option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE})
option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE})
option(GGML_PERF "ggml: enable perf timings" ${GGML_PERF})
option(GGML_PERF "ggml: enable perf timings" OFF)
option(GGML_NO_ACCELERATE "ggml: disable Accelerate framework" OFF)
# sanitizers

@ -216,7 +216,7 @@ int main(int argc, char ** argv) {
// run the inference
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY);
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_realtime = true;
wparams.print_progress = false;

@ -2256,51 +2256,63 @@ void whisper_print_timings(struct whisper_context * ctx) {
////////////////////////////////////////////////////////////////////////////
struct whisper_full_params whisper_full_default_params(enum whisper_decode_strategy strategy) {
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
struct whisper_full_params result;
switch (strategy) {
case WHISPER_DECODE_GREEDY:
case WHISPER_SAMPLING_GREEDY:
{
result = {
.strategy = WHISPER_DECODE_GREEDY,
.n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()),
.offset_ms = 0,
/*.strategy =*/ WHISPER_SAMPLING_GREEDY,
.translate = false,
.no_context = false,
.print_special_tokens = false,
.print_progress = true,
.print_realtime = false,
.print_timestamps = true,
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.offset_ms =*/ 0,
.language = "en",
/*.translate =*/ false,
/*.no_context =*/ false,
/*.print_special_tokens =*/ false,
/*.print_progress =*/ true,
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
.greedy = {
.n_past = 0,
/*.language =*/ "en",
/*.greedy =*/ {
/*.n_past =*/ 0,
},
/*.beam_search =*/ {
/*.n_past =*/ -1,
/*.beam_width =*/ -1,
/*.n_best =*/ -1,
},
};
} break;
case WHISPER_DECODE_BEAM_SEARCH:
case WHISPER_SAMPLING_BEAM_SEARCH:
{
result = {
.strategy = WHISPER_DECODE_GREEDY,
.n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()),
.offset_ms = 0,
.translate = false,
.no_context = false,
.print_special_tokens = false,
.print_progress = true,
.print_realtime = false,
.print_timestamps = true,
.language = "en",
.beam_search = {
.n_past = 0,
.beam_width = 10,
.n_best = 5,
/*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH,
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.offset_ms =*/ 0,
/*.translate =*/ false,
/*.no_context =*/ false,
/*.print_special_tokens =*/ false,
/*.print_progress =*/ true,
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
/*.language =*/ "en",
/*.greedy =*/ {
/*.n_past =*/ -1,
},
/*.beam_search =*/ {
/*.n_past =*/ 0,
/*.beam_width =*/ 10,
/*.n_best =*/ 5,
},
};
} break;
@ -2425,7 +2437,7 @@ int whisper_full(
whisper_token id = 0;
whisper_token tid = whisper_token_beg(ctx);
id = whisper_sample_best(ctx, result_len == 0 || i > 32);
id = whisper_sample_best(ctx, result_len == 0);
if (i > 0) {
tid = whisper_sample_timestamp(ctx);
}
@ -2445,9 +2457,12 @@ int whisper_full(
// end of text token
if (id == whisper_token_eot(ctx)) {
if (result_len == 0) {
// TODO: figure out how to resolve this
fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__);
//result_len = i + 1;
if (seek + seek_delta + 100 >= whisper_n_len(ctx)) {
result_len = i + 1;
} else {
// TODO: figure out how to resolve this
fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__);
}
}
break;
}

@ -31,7 +31,8 @@ extern "C" {
//
// C interface
//
// The following interface is thread-safe as long as the same whisper_context is not used by multiple threads
// concurrently.
//
// Basic usage:
//
@ -153,14 +154,14 @@ extern "C" {
////////////////////////////////////////////////////////////////////////////
// Available decoding strategies
enum whisper_decode_strategy {
WHISPER_DECODE_GREEDY, // Always select the most probable token
WHISPER_DECODE_BEAM_SEARCH, // TODO: not implemented yet!
// Available sampling strategies
enum whisper_sampling_strategy {
WHISPER_SAMPLING_GREEDY, // Always select the most probable token
WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
};
struct whisper_full_params {
enum whisper_decode_strategy strategy;
enum whisper_sampling_strategy strategy;
int n_threads;
int offset_ms;
@ -174,20 +175,18 @@ extern "C" {
const char * language;
union {
struct {
int n_past;
} greedy;
struct {
int n_past;
int beam_width;
int n_best;
} beam_search;
};
struct {
int n_past;
} greedy;
struct {
int n_past;
int beam_width;
int n_best;
} beam_search;
};
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_decode_strategy strategy);
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
// Uses the specified sampling strategy to obtain the text.

@ -14,7 +14,6 @@
#include <stdint.h>
#include <stdio.h>
#if defined _MSC_VER
#include "msvc_thread_atomic.h"
#else
@ -24,6 +23,7 @@ typedef void* thread_ret_t;
#endif
#define GGML_DEBUG 0
#define GGML_GELU_FP16
#if UINTPTR_MAX == 0xFFFFFFFF
#define GGML_MEM_ALIGN 4
@ -723,20 +723,22 @@ inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp
}
}
// Element-wise GELU over n floats: writes one output y[i] for every input x[i].
// The implementation is selected at compile time by GGML_GELU_FP16.
#ifdef GGML_GELU_FP16
// Table-driven variant: each input is converted to fp16 and its raw 16-bit pattern
// is used as an index into table_gelu_f16.
inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
uint16_t t;
for (int i = 0; i < n; ++i) {
ggml_fp16_t fp16 = ggml_fp32_to_fp16(x[i]);
// copy the fp16 value's bytes into an integer so it can serve as the table index
memcpy(&t, &fp16, sizeof(uint16_t));
// NOTE(review): the two assignments below look like the old and new sides of one changed
// diff line; the second (explicit fp16 -> fp32 conversion of the table entry) overwrites the first.
y[i] = table_gelu_f16[t];
y[i] = ggml_fp16_to_fp32(table_gelu_f16[t]);
}
}
// NOTE(review): the commented-out function below duplicates the #else branch - presumably
// the removed side of the diff; confirm before relying on it.
//inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
// for (int i = 0; i < n; ++i) {
// y[i] = ggml_gelu_f32(x[i]);
// }
//}
#else
// Direct variant: calls ggml_gelu_f32 on every element.
inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
for (int i = 0; i < n; ++i) {
y[i] = ggml_gelu_f32(x[i]);
}
}
#endif
inline static void ggml_vec_sum_f32 (const int n, float * s, const float * x) { ggml_float sum = 0.0; for (int i = 0; i < n; ++i) sum += x[i]; *s += sum; }
inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); }

@ -0,0 +1,31 @@
#pragma once
#include <Windows.h>
#include <errno.h>
// Integer type operated on by the atomic_* helpers in this header; it is a LONG because
// that is the operand type of the Win32 Interlocked* functions they call.
typedef volatile LONG atomic_int;
// Boolean flags share the same LONG-based representation as atomic_int.
typedef atomic_int atomic_bool;
// Atomically replaces the value at *ptr with val.
// InterlockedExchange also hands back the previous value; it is not needed here.
static void atomic_store(atomic_int* ptr, LONG val) {
    (void) InterlockedExchange(ptr, val);
}
// Atomically reads and returns the value at *ptr.
static LONG atomic_load(atomic_int* ptr) {
    // Compare-exchange with comparand and replacement both 0: the stored value is
    // never altered (0 is only ever replaced by 0), and the current value comes back.
    const LONG current = InterlockedCompareExchange(ptr, 0, 0);
    return current;
}
// Atomically adds inc to *ptr and returns the value *ptr held before the addition.
static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) {
    const LONG previous = InterlockedExchangeAdd(ptr, inc);
    return previous;
}
// Atomically subtracts dec from *ptr and returns the value *ptr held before the subtraction.
static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) {
    // Subtraction is expressed as an interlocked add of the negated amount.
    const LONG negated = -(dec);
    return InterlockedExchangeAdd(ptr, negated);
}
// pthread-style thread identifier, represented here by a Win32 HANDLE.
typedef HANDLE pthread_t;
// Return type of a thread entry function as accepted by pthread_create() below.
typedef DWORD thread_ret_t;
// Minimal pthread_create() replacement built on CreateThread().
//
//   out    - receives the handle of the new thread on success; left untouched on failure
//   unused - placeholder for the pthread attributes argument; ignored
//   func   - thread entry point
//   arg    - opaque pointer forwarded to func
//
// Returns 0 on success and EAGAIN when the thread could not be created, following the
// POSIX convention so that callers can check `rc == 0`.
static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
    (void) unused;

    // The cast bridges the plain function-pointer type to the start-routine type
    // that CreateThread() declares.
    HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
    if (handle == NULL) {
        return EAGAIN;
    }

    // Store through the pointer: assigning to `out` itself would only modify the local
    // parameter and the caller would never see the handle it later has to join.
    *out = handle;
    return 0;
}
// Minimal pthread_join() replacement: blocks until `thread` has finished, then releases
// its handle.
//
//   thread - handle obtained from pthread_create()
//   unused - placeholder for the pthread return-value pointer; ignored
//
// Returns the WaitForSingleObject() result converted to int.
static int pthread_join(pthread_t thread, void* unused) {
    (void) unused;

    const int ret = (int) WaitForSingleObject(thread, INFINITE);

    // A joined thread cannot be used again, so close the handle here; otherwise every
    // created thread leaks one kernel handle.
    CloseHandle(thread);

    return ret;
}
Loading…
Cancel
Save