diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b8366b7..105082c 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -618,6 +618,8 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: failed to process audio\n", argv[0]); return 10; } + + whisper_full_cluster_segments(ctx); } // output stuff diff --git a/whisper.cpp b/whisper.cpp index 04cbc36..4c208b9 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -603,6 +603,8 @@ struct whisper_context { // [EXPERIMENTAL] speed-up techniques int32_t exp_n_audio_ctx; // 0 - use default + std::vector audio_embd; + void use_buf(struct ggml_context * ctx, int i) { #if defined(WHISPER_USE_SCRATCH) size_t last_size = 0; @@ -1707,18 +1709,34 @@ static bool whisper_encode( } // cur - //{ - // printf("ne0 = %d\n", cur->ne[0]); - // printf("ne1 = %d\n", cur->ne[1]); - // for (int i = 0; i < 10; ++i) { - // printf("%8.4f ", ((float *)(cur->data))[i]); - // } - // printf("... "); - // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) { - // printf("%8.4f ", ((float *)(cur->data))[i]); - // } - // printf("\n"); - //} + { + //printf("ne0 = %d\n", cur->ne[0]); + //printf("ne1 = %d\n", cur->ne[1]); + //for (int i = 0; i < 10; ++i) { + // printf("%8.4f ", ((float *)(cur->data))[i]); + //} + //printf("... "); + //for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) { + // printf("%8.4f ", ((float *)(cur->data))[i]); + //} + //printf("\n"); + } + + { + const int i0 = std::min(mel_offset, mel_inp.n_len); + const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len); + + printf("i0 = %d, i1 = %d, (i1 - i0) = %d, embd size = %d\n", i0, i1, i1 - i0, cur->ne[0]); + + wctx.audio_embd.clear(); + wctx.audio_embd.resize(cur->ne[0], 0.0f); + for (int j = 0; j < cur->ne[0]; ++j) { + for (int i = i0; i < i1; ++i) { + wctx.audio_embd[j] += ((float *)(cur->data))[(i - i0)*cur->ne[0] + j]; + } + wctx.audio_embd[j] /= (i1 - i0); + } + } // pre-compute cross-attention memory { @@ -4806,3 +4824,129 @@ static void whisper_exp_compute_token_level_timestamps( // } //} } + +// +// diarization stuff +// + +void whisper_full_cluster_segments(struct whisper_context * ctx) { + const int n_segments = ctx->result_all.size(); + printf("%s: clustering %d segments\n", __func__, n_segments); + + const auto mel_len_save = ctx->mel.n_len; + printf("%s: mel_len_save = %d\n", __func__, mel_len_save); + + std::vector> features(n_segments); + + for (int i = 0; i < n_segments; ++i) { + const auto & segment_i = ctx->result_all[i]; + printf("%s: segment %d: t0 = %d, t1 = %d, text = %s\n", __func__, i, (int) segment_i.t0, (int) segment_i.t1, segment_i.text.c_str()); + + ctx->mel.n_len = segment_i.t1; + whisper_encode(ctx, segment_i.t0, 4); + + features[i] = ctx->audio_embd; + } + + const int n_features = features[0].size(); + + // fuzzy c-means clustering + const int n_clusters = 4; + + std::vector> centroids(n_clusters, std::vector(n_features, 0.0)); + std::vector> membership(n_segments, std::vector(n_clusters, 0.0)); + + // initialize the centroids + for (int i = 0; i < n_clusters; ++i) { + for (int j = 0; j < n_features; ++j) { + centroids[i][j] = features[i][j]; + } + } + + // initialize the membership + for (int i = 0; i < n_segments; ++i) { + membership[i][i % n_clusters] = 1.0; + } + + // iterate + for (int i = 0; i < 100; ++i) { + // update the centroids + for (int j = 0; j < n_clusters; ++j) { + for (int k = 0; k < n_features; ++k) { + centroids[j][k] = 0.0; + } + } + + for (int j = 0; j < n_segments; ++j) { + for (int k = 0; k < n_clusters; ++k) { + for (int l = 0; l < n_features; ++l) { + centroids[k][l] += membership[j][k]*features[j][l]; + } + } + } + + for (int j = 0; j < n_clusters; ++j) { + float sum = 0.0; + for (int k = 0; k < n_segments; ++k) { + sum += membership[k][j]; + } + + for (int k = 0; k < n_features; ++k) { + centroids[j][k] /= sum; + } + } + + // update the membership + for (int j = 0; j < n_segments; ++j) { + for (int k = 0; k < n_clusters; ++k) { + float sum = 0.0; + for (int l = 0; l < n_clusters; ++l) { + //sum += std::pow(whisper_distance(features[j], centroids[k])/whisper_distance(features[j], centroids[l]), 2.0/(2.0 - 1.0)); + + // use the euclidean distance + double d0 = 0.0; + for (int m = 0; m < n_features; ++m) { + d0 += std::pow(features[j][m] - centroids[k][m], 2.0); + } + d0 = std::sqrt(d0); + + double d1 = 0.0; + for (int m = 0; m < n_features; ++m) { + d1 += std::pow(features[j][m] - centroids[l][m], 2.0); + } + d1 = std::sqrt(d1); + if (d1 == 0.0) { + sum += 1.0; + } else { + sum += std::pow(d0/d1, 2.0/(2.0 - 1.0)); + } + } + + membership[j][k] = 1.0/sum; + } + } + + // print the membership + for (int i = 0; i < n_segments; ++i) { + printf("%s: membership %d: ", __func__, i); + for (int j = 0; j < n_clusters; ++j) { + printf("%f ", membership[i][j]); + } + printf(" '%s'\n", ctx->result_all[i].text.c_str()); + } + printf("----------------\n"); + } + + // print the centroids + //for (int i = 0; i < n_clusters; ++i) { + // printf("%s: centroid %d: ", __func__, i); + // for (int j = 0; j < n_features; ++j) { + // printf("%f ", centroids[i][j]); + // } + // printf("\n"); + //} + + // restore the mel length + ctx->mel.n_len = mel_len_save; +} + diff --git a/whisper.h b/whisper.h index 7eece79..9e40e70 100644 --- a/whisper.h +++ b/whisper.h @@ -372,6 +372,10 @@ extern "C" { WHISPER_API int whisper_bench_memcpy(int n_threads); WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads); + // Temporary experimental API + + WHISPER_API void whisper_full_cluster_segments(struct whisper_context * ctx); + #ifdef __cplusplus } #endif