whisper : add WHISPER_DEBUG + diagnostic prints + new main args

pull/291/head
Georgi Gerganov 3 years ago
parent c6a8a47039
commit c301a7942b
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -212,17 +212,7 @@ make large
## Limitations ## Limitations
- Inference only - Inference only
- No GPU support - No GPU support (yet)
- Very basic greedy sampling scheme - always pick the token with the highest probability.
This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274)
from the original python implementation, so in order to make a fair comparison between the 2 implementations, make sure
to run the python code with the following parameters:
```
whisper --best_of None --beam_size None ...
```
In the future, `whisper.cpp` will support more sampling strategies.
## Another example ## Another example

@ -59,8 +59,12 @@ struct whisper_params {
int32_t duration_ms = 0; int32_t duration_ms = 0;
int32_t max_context = -1; int32_t max_context = -1;
int32_t max_len = 0; int32_t max_len = 0;
int32_t best_of = 5;
int32_t beam_size = -1;
float word_thold = 0.01f; float word_thold = 0.01f;
float entropy_thold = 2.4f;
float logprob_thold = -1.0f;
bool speed_up = false; bool speed_up = false;
bool translate = false; bool translate = false;
@ -104,7 +108,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
@ -144,7 +152,11 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
@ -331,20 +343,19 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i); const char * text = whisper_full_get_segment_text(ctx, i);
if (text[0] == ' ') if (text[0] == ' ') {
text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character. text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
}
const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds. //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
fout << 10 * t0 << ", " fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n";
<< 10 * t1 << ", \""
<< text << "\"\n";
} }
return true; return true;
} }
// karaoke video generation // karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles // outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments // TODO: font parameter adjustments
@ -620,6 +631,8 @@ int main(int argc, char ** argv) {
{ {
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
wparams.print_realtime = false; wparams.print_realtime = false;
wparams.print_progress = params.print_progress; wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps; wparams.print_timestamps = !params.no_timestamps;
@ -633,6 +646,8 @@ int main(int argc, char ** argv) {
wparams.token_timestamps = params.output_wts || params.max_len > 0; wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold; wparams.thold_pt = params.word_thold;
wparams.entropy_threshold = params.entropy_thold;
wparams.logprob_threshold = params.logprob_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.speed_up = params.speed_up; wparams.speed_up = params.speed_up;

@ -25,6 +25,16 @@
} \ } \
} while (0) } while (0)
// Diagnostic output switch. While WHISPER_DEBUG is defined, WHISPER_PRINT_DEBUG(...)
// forwards its printf-style arguments to stderr; when it is not defined, the macro
// expands to nothing, so call sites cost nothing.
#define WHISPER_DEBUG
#ifndef WHISPER_DEBUG
#define WHISPER_PRINT_DEBUG(...)
#else
// Wrapped in do/while(0) so the expansion behaves as a single statement.
#define WHISPER_PRINT_DEBUG(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
#endif
#define WHISPER_USE_FLASH_ATTN #define WHISPER_USE_FLASH_ATTN
//#define WHISPER_USE_FLASH_FF //#define WHISPER_USE_FLASH_FF
#define WHISPER_MAX_DECODERS 16 #define WHISPER_MAX_DECODERS 16
@ -1640,7 +1650,7 @@ static bool whisper_decode(
const int N = n_tokens; const int N = n_tokens;
const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx; const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
//fprintf(stderr, "n_past = %d, N = %d, M = %d, n_ctx = %d\n", n_past, N, M, n_ctx); WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
struct ggml_init_params params; struct ggml_init_params params;
params.mem_size = wctx.buf_compute.size(); params.mem_size = wctx.buf_compute.size();
@ -3228,7 +3238,7 @@ int whisper_full(
return -4; return -4;
} }
fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, j); WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity()); decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
@ -3338,12 +3348,12 @@ int whisper_full(
} break; } break;
}; };
fprintf(stderr, "\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur); WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
if (t_cur > 0.5) { if (t_cur > 0.5) {
prompt_past.clear(); prompt_past.clear();
fprintf(stderr, "%s: clearing prompt_past\n", __func__); WHISPER_PRINT_DEBUG("%s: clearing prompt_past\n", __func__);
} }
// TAGS: WHISPER_DECODER_INIT // TAGS: WHISPER_DECODER_INIT
@ -3465,11 +3475,13 @@ int whisper_full(
has_ts = true; has_ts = true;
} }
//{ #ifdef WHISPER_DEBUG
// const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]"; {
// printf("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n", const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
// __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str()); WHISPER_PRINT_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
//} __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str());
}
#endif
// end of segment // end of segment
if (token.id == whisper_token_eot(ctx) || // end of text token if (token.id == whisper_token_eot(ctx) || // end of text token
@ -3541,7 +3553,7 @@ int whisper_full(
decoder.tokens_tmp.resize(1); decoder.tokens_tmp.resize(1);
decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id; decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
//fprintf(stderr, "%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta); //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta);
if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past, params.n_threads)) { if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past, params.n_threads)) {
fprintf(stderr, "%s: failed to decode\n", __func__); fprintf(stderr, "%s: failed to decode\n", __func__);
@ -3574,11 +3586,11 @@ int whisper_full(
decoder.sequence.tokens.resize(decoder.sequence.result_len); decoder.sequence.tokens.resize(decoder.sequence.result_len);
whisper_sequence_score(params, decoder.sequence); whisper_sequence_score(params, decoder.sequence);
fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n", WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
__func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy); __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_threshold) { if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_threshold) {
fprintf(stderr, "%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n", WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
__func__, j, decoder.sequence.entropy, params.entropy_threshold); __func__, j, decoder.sequence.entropy, params.entropy_threshold);
decoder.failed = true; decoder.failed = true;
@ -3606,15 +3618,15 @@ int whisper_full(
} }
if (success) { if (success) {
fprintf(stderr, "%s: best decoder = %d\n", __func__, best_decoder_id); WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) { //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
// fprintf(stderr, "%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str()); // WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
//} //}
break; break;
} }
fprintf(stderr, "\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur); WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
} }
{ {
@ -3734,7 +3746,7 @@ int whisper_full(
// update audio window // update audio window
seek += seek_delta; seek += seek_delta;
fprintf(stderr, "seek = %d, seek_delta = %d\n", seek, seek_delta); WHISPER_PRINT_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta);
} }
} }

@ -281,7 +281,7 @@ extern "C" {
float temperature_increment; float temperature_increment;
float entropy_threshold; // analogous to OpenAI's compression_ratio_threshold float entropy_threshold; // analogous to OpenAI's compression_ratio_threshold
float logprob_threshold; float logprob_threshold;
float no_speech_threshold; float no_speech_threshold; // TODO: not implemented
struct { struct {
int best_of; int best_of;

Loading…
Cancel
Save