whisper : add whisper_n_audio_ctx and check for invalid audio_ctx

closes #344
pull/350/head
Georgi Gerganov 2 years ago
parent 3467230a77
commit d97e6005e9
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -2497,6 +2497,10 @@ int whisper_n_text_ctx(struct whisper_context * ctx) {
return ctx->model.hparams.n_text_ctx;
}
int whisper_n_audio_ctx(struct whisper_context * ctx) {
return ctx->model.hparams.n_audio_ctx;
}
int whisper_is_multilingual(struct whisper_context * ctx) {
return ctx->vocab.is_multilingual() ? 1 : 0;
}
@ -2822,7 +2826,11 @@ int whisper_full(
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
}
// overwrite audio_ctx
// overwrite audio_ctx, max allowed is hparams.n_audio_ctx
if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
return -4;
}
ctx->exp_n_audio_ctx = params.audio_ctx;
// these tokens determine the task that will be performed

@ -177,6 +177,7 @@ extern "C" {
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
// The probabilities for the next token

Loading…
Cancel
Save