Added whisper state + default state on the whisper_context

pull/523/head
Sandro Hanea 1 year ago
parent 373043cabe
commit ca6efc47d0

File diff suppressed because it is too large Load Diff

@ -66,6 +66,7 @@ extern "C" {
//
struct whisper_context;
struct whisper_state;
typedef int whisper_token;
@ -101,11 +102,16 @@ extern "C" {
WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx);
// Frees all memory allocated by the model.
WHISPER_API void whisper_free(struct whisper_context * ctx);
// Frees all memory allocated by the state.
WHISPER_API void whisper_free_state(struct whisper_state * state);
// Convert RAW PCM audio to log mel spectrogram.
// The resulting spectrogram is stored inside the provided whisper context.
// The resulting spectrogram is stored inside the default state of the provided whisper context.
// Returns 0 on success
WHISPER_API int whisper_pcm_to_mel(
struct whisper_context * ctx,
@ -113,17 +119,36 @@ extern "C" {
int n_samples,
int n_threads);
// Convert RAW PCM audio to log mel spectrogram.
// The resulting spectrogram is stored inside the provided state.
// Returns 0 on success
WHISPER_API int whisper_pcm_to_mel_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const float * samples,
int n_samples,
int n_threads);
// Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
// The resulting spectrogram is stored inside the provided whisper context.
// The resulting spectrogram is stored inside the default state of the provided whisper context.
// Returns 0 on success
WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
struct whisper_context* ctx,
const float* samples,
int n_samples,
int n_threads);
// This can be used to set a custom log mel spectrogram inside the provided whisper context.
struct whisper_context * ctx,
const float * samples,
int n_samples,
int n_threads);
// Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
// The resulting spectrogram is stored inside the provided state.
// Returns 0 on success
WHISPER_API int whisper_pcm_to_mel_phase_vocoder_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const float * samples,
int n_samples,
int n_threads);
// This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80
// Returns 0 on success
@ -132,8 +157,19 @@ extern "C" {
const float * data,
int n_len,
int n_mel);
// This can be used to set a custom log mel spectrogram inside the provided state.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80
// Returns 0 on success
WHISPER_API int whisper_set_mel_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const float * data,
int n_len,
int n_mel);
// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
// Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context.
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// offset can be used to specify the offset of the first frame in the spectrogram.
// Returns 0 on success
@ -141,6 +177,16 @@ extern "C" {
struct whisper_context * ctx,
int offset,
int n_threads);
// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper state.
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// offset can be used to specify the offset of the first frame in the spectrogram.
// Returns 0 on success
WHISPER_API int whisper_encode_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
int offset,
int n_threads);
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
// Make sure to call whisper_encode() first.
@ -155,6 +201,21 @@ extern "C" {
int n_past,
int n_threads);
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
// Make sure to call whisper_encode() first.
// tokens + n_tokens is the provided context for the decoder.
// n_past is the number of tokens to use from previous decoder calls.
// Returns 0 on success
// TODO: add support for multiple decoders
WHISPER_API int whisper_decode_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const whisper_token * tokens,
int n_tokens,
int n_past,
int n_threads);
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
@ -189,6 +250,19 @@ extern "C" {
int offset_ms,
int n_threads,
float * lang_probs);
// Use mel data at offset_ms to try and auto-detect the spoken language
// Make sure to call whisper_pcm_to_mel_with_state() or whisper_set_mel_with_state() first
// Returns the top language id or negative on failure
// If not null, fills the lang_probs array with the probabilities of all languages
// The array must be whispe_lang_max_id() + 1 in size
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
WHISPER_API int whisper_lang_auto_detect_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
int offset_ms,
int n_threads,
float * lang_probs);
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
@ -196,12 +270,18 @@ extern "C" {
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
// Token logits obtained from the last call to whisper_decode()
// Token logits obtained from the last call to whisper_decode() for the default state
// The logits for the last token are stored in the last row
// Rows: n_tokens
// Cols: n_vocab
WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
// Token logits obtained from the last call to whisper_decode() for the given state
// The logits for the last token are stored in the last row
// Rows: n_tokens
// Cols: n_vocab
WHISPER_API float* whisper_get_logits_from_state(struct whisper_state * state);
// Token Id -> String. Uses the vocabulary in the provided context
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
@ -218,7 +298,7 @@ extern "C" {
WHISPER_API whisper_token whisper_token_translate (void);
WHISPER_API whisper_token whisper_token_transcribe(void);
// Performance information
// Performance information from the default state.
WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
@ -236,18 +316,19 @@ extern "C" {
// Text segment callback
// Called on every newly generated text segment
// Use the whisper_full_...() functions to obtain the text segments
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data);
// Encoder begin callback
// If not NULL, called before the encoder starts
// If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
// Logits filter callback
// Can be used to modify the logits before sampling
// If not NULL, called after applying temperature to logits
typedef void (*whisper_logits_filter_callback)(
struct whisper_context * ctx,
struct whisper_state * state,
const whisper_token_data * tokens,
int n_tokens,
float * logits,
@ -333,7 +414,8 @@ extern "C" {
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text on the default state from the context
// Not Thread safe for same context
// Uses the specified decoding strategy to obtain the text.
WHISPER_API int whisper_full(
struct whisper_context * ctx,
@ -341,7 +423,19 @@ extern "C" {
const float * samples,
int n_samples);
// Split the input audio in chunks and process each chunk separately using whisper_full()
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text on the given state
// Thread safe for same context, as long as different state is provided.
// Uses the specified decoding strategy to obtain the text.
WHISPER_API int whisper_full_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
struct whisper_full_params params,
const float * samples,
int n_samples);
// Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
// Result is stored on the default state from the context
// Not Thread safe if executed in parallel on the same context.
// It seems this approach can offer some speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
WHISPER_API int whisper_full_parallel(
@ -351,34 +445,62 @@ extern "C" {
int n_samples,
int n_processors);
// Number of generated text segments.
// Number of generated text segments in default state.
// A segment can be a few words, a sentence, or even a paragraph.
WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
// Language id associated with the current context
// Number of generated text segments from given state.
// A segment can be a few words, a sentence, or even a paragraph.
WHISPER_API int whisper_full_n_segments_from_state(struct whisper_state * state);
// Language id associated with the current context's default state
WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);
// Get the start and end time of the specified segment.
// Language id associated with the current context's default state
WHISPER_API int whisper_full_lang_id_from_state(struct whisper_state * state);
// Get the start and end time of the specified segment from the default state
WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment);
// Get the text of the specified segment.
// Get the start and end time of the specified segment from the given state
WHISPER_API int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment);
// Get the text of the specified segment in the default state
WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment);
// Get number of tokens in the specified segment.
// Get the text of the specified segment in the given state
WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment);
// Get number of tokens in the specified segment from the default state
WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment);
// Get the token text of the specified token in the specified segment.
// Get number of tokens in the specified segment from the given state
WHISPER_API int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment);
// Get the token text of the specified token in the specified segment from the default state
WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
// Get token data for the specified token in the specified segment.
// Get the token text of the specified token index in the specified segment from the given state
WHISPER_API const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token);
WHISPER_API whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token);
// Get token data for the specified token in the specified segment from the given state
// This contains probabilities, timestamps, etc.
WHISPER_API whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token);
// Get token data for the specified token in the specified segment from the default state
// This contains probabilities, timestamps, etc.
WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
// Get the probability of the specified token in the specified segment.
// Get the probability of the specified token in the specified segment from the default state.
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
// Get the probability of the specified token in the specified segment from the given state.
WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token);
////////////////////////////////////////////////////////////////////////////
// Temporary helpers needed for exposing ggml interface

Loading…
Cancel
Save