diarization : some unsuccessful experiments with audio embd clustering

pull/130/head
Georgi Gerganov 2 years ago
parent f254e78737
commit c2f5be7c11
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -618,6 +618,8 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]); fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 10; return 10;
} }
whisper_full_cluster_segments(ctx);
} }
// output stuff // output stuff

@ -603,6 +603,8 @@ struct whisper_context {
// [EXPERIMENTAL] speed-up techniques // [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx; // 0 - use default int32_t exp_n_audio_ctx; // 0 - use default
std::vector<float> audio_embd;
void use_buf(struct ggml_context * ctx, int i) { void use_buf(struct ggml_context * ctx, int i) {
#if defined(WHISPER_USE_SCRATCH) #if defined(WHISPER_USE_SCRATCH)
size_t last_size = 0; size_t last_size = 0;
@ -1707,7 +1709,7 @@ static bool whisper_encode(
} }
// cur // cur
//{ {
//printf("ne0 = %d\n", cur->ne[0]); //printf("ne0 = %d\n", cur->ne[0]);
//printf("ne1 = %d\n", cur->ne[1]); //printf("ne1 = %d\n", cur->ne[1]);
//for (int i = 0; i < 10; ++i) { //for (int i = 0; i < 10; ++i) {
@ -1718,7 +1720,23 @@ static bool whisper_encode(
// printf("%8.4f ", ((float *)(cur->data))[i]); // printf("%8.4f ", ((float *)(cur->data))[i]);
//} //}
//printf("\n"); //printf("\n");
//} }
{
const int i0 = std::min(mel_offset, mel_inp.n_len);
const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
printf("i0 = %d, i1 = %d, (i1 - i0) = %d, embd size = %d\n", i0, i1, i1 - i0, cur->ne[0]);
wctx.audio_embd.clear();
wctx.audio_embd.resize(cur->ne[0], 0.0f);
for (int j = 0; j < cur->ne[0]; ++j) {
for (int i = i0; i < i1; ++i) {
wctx.audio_embd[j] += ((float *)(cur->data))[(i - i0)*cur->ne[0] + j];
}
wctx.audio_embd[j] /= (i1 - i0);
}
}
// pre-compute cross-attention memory // pre-compute cross-attention memory
{ {
@ -4806,3 +4824,129 @@ static void whisper_exp_compute_token_level_timestamps(
// } // }
//} //}
} }
//
// diarization stuff
//
void whisper_full_cluster_segments(struct whisper_context * ctx) {
const int n_segments = ctx->result_all.size();
printf("%s: clustering %d segments\n", __func__, n_segments);
const auto mel_len_save = ctx->mel.n_len;
printf("%s: mel_len_save = %d\n", __func__, mel_len_save);
std::vector<std::vector<float>> features(n_segments);
for (int i = 0; i < n_segments; ++i) {
const auto & segment_i = ctx->result_all[i];
printf("%s: segment %d: t0 = %d, t1 = %d, text = %s\n", __func__, i, (int) segment_i.t0, (int) segment_i.t1, segment_i.text.c_str());
ctx->mel.n_len = segment_i.t1;
whisper_encode(ctx, segment_i.t0, 4);
features[i] = ctx->audio_embd;
}
const int n_features = features[0].size();
// fuzzy c-means clustering
const int n_clusters = 4;
std::vector<std::vector<float>> centroids(n_clusters, std::vector<float>(n_features, 0.0));
std::vector<std::vector<float>> membership(n_segments, std::vector<float>(n_clusters, 0.0));
// initialize the centroids
for (int i = 0; i < n_clusters; ++i) {
for (int j = 0; j < n_features; ++j) {
centroids[i][j] = features[i][j];
}
}
// initialize the membership
for (int i = 0; i < n_segments; ++i) {
membership[i][i % n_clusters] = 1.0;
}
// iterate
for (int i = 0; i < 100; ++i) {
// update the centroids
for (int j = 0; j < n_clusters; ++j) {
for (int k = 0; k < n_features; ++k) {
centroids[j][k] = 0.0;
}
}
for (int j = 0; j < n_segments; ++j) {
for (int k = 0; k < n_clusters; ++k) {
for (int l = 0; l < n_features; ++l) {
centroids[k][l] += membership[j][k]*features[j][l];
}
}
}
for (int j = 0; j < n_clusters; ++j) {
float sum = 0.0;
for (int k = 0; k < n_segments; ++k) {
sum += membership[k][j];
}
for (int k = 0; k < n_features; ++k) {
centroids[j][k] /= sum;
}
}
// update the membership
for (int j = 0; j < n_segments; ++j) {
for (int k = 0; k < n_clusters; ++k) {
float sum = 0.0;
for (int l = 0; l < n_clusters; ++l) {
//sum += std::pow(whisper_distance(features[j], centroids[k])/whisper_distance(features[j], centroids[l]), 2.0/(2.0 - 1.0));
// use the euclidean distance
double d0 = 0.0;
for (int m = 0; m < n_features; ++m) {
d0 += std::pow(features[j][m] - centroids[k][m], 2.0);
}
d0 = std::sqrt(d0);
double d1 = 0.0;
for (int m = 0; m < n_features; ++m) {
d1 += std::pow(features[j][m] - centroids[l][m], 2.0);
}
d1 = std::sqrt(d1);
if (d1 == 0.0) {
sum += 1.0;
} else {
sum += std::pow(d0/d1, 2.0/(2.0 - 1.0));
}
}
membership[j][k] = 1.0/sum;
}
}
// print the membership
for (int i = 0; i < n_segments; ++i) {
printf("%s: membership %d: ", __func__, i);
for (int j = 0; j < n_clusters; ++j) {
printf("%f ", membership[i][j]);
}
printf(" '%s'\n", ctx->result_all[i].text.c_str());
}
printf("----------------\n");
}
// print the centroids
//for (int i = 0; i < n_clusters; ++i) {
// printf("%s: centroid %d: ", __func__, i);
// for (int j = 0; j < n_features; ++j) {
// printf("%f ", centroids[i][j]);
// }
// printf("\n");
//}
// restore the mel length
ctx->mel.n_len = mel_len_save;
}

@ -372,6 +372,10 @@ extern "C" {
WHISPER_API int whisper_bench_memcpy(int n_threads); WHISPER_API int whisper_bench_memcpy(int n_threads);
WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads); WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads);
// Temporary experimental API
WHISPER_API void whisper_full_cluster_segments(struct whisper_context * ctx);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

Loading…
Cancel
Save