Skip to content

Commit c2f5be7

Browse files
committed
diarization : some unsuccessful experiments with audio embd clustering
1 parent f254e78 commit c2f5be7

File tree

3 files changed

+162
-12
lines changed

3 files changed

+162
-12
lines changed

examples/main/main.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,8 @@ int main(int argc, char ** argv) {
618618
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
619619
return 10;
620620
}
621+
622+
whisper_full_cluster_segments(ctx);
621623
}
622624

623625
// output stuff

whisper.cpp

+156-12
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,8 @@ struct whisper_context {
603603
// [EXPERIMENTAL] speed-up techniques
604604
int32_t exp_n_audio_ctx; // 0 - use default
605605

606+
std::vector<float> audio_embd;
607+
606608
void use_buf(struct ggml_context * ctx, int i) {
607609
#if defined(WHISPER_USE_SCRATCH)
608610
size_t last_size = 0;
@@ -1707,18 +1709,34 @@ static bool whisper_encode(
17071709
}
17081710

17091711
// cur
1710-
//{
1711-
// printf("ne0 = %d\n", cur->ne[0]);
1712-
// printf("ne1 = %d\n", cur->ne[1]);
1713-
// for (int i = 0; i < 10; ++i) {
1714-
// printf("%8.4f ", ((float *)(cur->data))[i]);
1715-
// }
1716-
// printf("... ");
1717-
// for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1718-
// printf("%8.4f ", ((float *)(cur->data))[i]);
1719-
// }
1720-
// printf("\n");
1721-
//}
1712+
{
1713+
//printf("ne0 = %d\n", cur->ne[0]);
1714+
//printf("ne1 = %d\n", cur->ne[1]);
1715+
//for (int i = 0; i < 10; ++i) {
1716+
// printf("%8.4f ", ((float *)(cur->data))[i]);
1717+
//}
1718+
//printf("... ");
1719+
//for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1720+
// printf("%8.4f ", ((float *)(cur->data))[i]);
1721+
//}
1722+
//printf("\n");
1723+
}
1724+
1725+
{
1726+
const int i0 = std::min(mel_offset, mel_inp.n_len);
1727+
const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
1728+
1729+
printf("i0 = %d, i1 = %d, (i1 - i0) = %d, embd size = %d\n", i0, i1, i1 - i0, cur->ne[0]);
1730+
1731+
wctx.audio_embd.clear();
1732+
wctx.audio_embd.resize(cur->ne[0], 0.0f);
1733+
for (int j = 0; j < cur->ne[0]; ++j) {
1734+
for (int i = i0; i < i1; ++i) {
1735+
wctx.audio_embd[j] += ((float *)(cur->data))[(i - i0)*cur->ne[0] + j];
1736+
}
1737+
wctx.audio_embd[j] /= (i1 - i0);
1738+
}
1739+
}
17221740

17231741
// pre-compute cross-attention memory
17241742
{
@@ -4806,3 +4824,129 @@ static void whisper_exp_compute_token_level_timestamps(
48064824
// }
48074825
//}
48084826
}
4827+
4828+
//
4829+
// diarization stuff
4830+
//
4831+
4832+
void whisper_full_cluster_segments(struct whisper_context * ctx) {
4833+
const int n_segments = ctx->result_all.size();
4834+
printf("%s: clustering %d segments\n", __func__, n_segments);
4835+
4836+
const auto mel_len_save = ctx->mel.n_len;
4837+
printf("%s: mel_len_save = %d\n", __func__, mel_len_save);
4838+
4839+
std::vector<std::vector<float>> features(n_segments);
4840+
4841+
for (int i = 0; i < n_segments; ++i) {
4842+
const auto & segment_i = ctx->result_all[i];
4843+
printf("%s: segment %d: t0 = %d, t1 = %d, text = %s\n", __func__, i, (int) segment_i.t0, (int) segment_i.t1, segment_i.text.c_str());
4844+
4845+
ctx->mel.n_len = segment_i.t1;
4846+
whisper_encode(ctx, segment_i.t0, 4);
4847+
4848+
features[i] = ctx->audio_embd;
4849+
}
4850+
4851+
const int n_features = features[0].size();
4852+
4853+
// fuzzy c-means clustering
4854+
const int n_clusters = 4;
4855+
4856+
std::vector<std::vector<float>> centroids(n_clusters, std::vector<float>(n_features, 0.0));
4857+
std::vector<std::vector<float>> membership(n_segments, std::vector<float>(n_clusters, 0.0));
4858+
4859+
// initialize the centroids
4860+
for (int i = 0; i < n_clusters; ++i) {
4861+
for (int j = 0; j < n_features; ++j) {
4862+
centroids[i][j] = features[i][j];
4863+
}
4864+
}
4865+
4866+
// initialize the membership
4867+
for (int i = 0; i < n_segments; ++i) {
4868+
membership[i][i % n_clusters] = 1.0;
4869+
}
4870+
4871+
// iterate
4872+
for (int i = 0; i < 100; ++i) {
4873+
// update the centroids
4874+
for (int j = 0; j < n_clusters; ++j) {
4875+
for (int k = 0; k < n_features; ++k) {
4876+
centroids[j][k] = 0.0;
4877+
}
4878+
}
4879+
4880+
for (int j = 0; j < n_segments; ++j) {
4881+
for (int k = 0; k < n_clusters; ++k) {
4882+
for (int l = 0; l < n_features; ++l) {
4883+
centroids[k][l] += membership[j][k]*features[j][l];
4884+
}
4885+
}
4886+
}
4887+
4888+
for (int j = 0; j < n_clusters; ++j) {
4889+
float sum = 0.0;
4890+
for (int k = 0; k < n_segments; ++k) {
4891+
sum += membership[k][j];
4892+
}
4893+
4894+
for (int k = 0; k < n_features; ++k) {
4895+
centroids[j][k] /= sum;
4896+
}
4897+
}
4898+
4899+
// update the membership
4900+
for (int j = 0; j < n_segments; ++j) {
4901+
for (int k = 0; k < n_clusters; ++k) {
4902+
float sum = 0.0;
4903+
for (int l = 0; l < n_clusters; ++l) {
4904+
//sum += std::pow(whisper_distance(features[j], centroids[k])/whisper_distance(features[j], centroids[l]), 2.0/(2.0 - 1.0));
4905+
4906+
// use the euclidean distance
4907+
double d0 = 0.0;
4908+
for (int m = 0; m < n_features; ++m) {
4909+
d0 += std::pow(features[j][m] - centroids[k][m], 2.0);
4910+
}
4911+
d0 = std::sqrt(d0);
4912+
4913+
double d1 = 0.0;
4914+
for (int m = 0; m < n_features; ++m) {
4915+
d1 += std::pow(features[j][m] - centroids[l][m], 2.0);
4916+
}
4917+
d1 = std::sqrt(d1);
4918+
if (d1 == 0.0) {
4919+
sum += 1.0;
4920+
} else {
4921+
sum += std::pow(d0/d1, 2.0/(2.0 - 1.0));
4922+
}
4923+
}
4924+
4925+
membership[j][k] = 1.0/sum;
4926+
}
4927+
}
4928+
4929+
// print the membership
4930+
for (int i = 0; i < n_segments; ++i) {
4931+
printf("%s: membership %d: ", __func__, i);
4932+
for (int j = 0; j < n_clusters; ++j) {
4933+
printf("%f ", membership[i][j]);
4934+
}
4935+
printf(" '%s'\n", ctx->result_all[i].text.c_str());
4936+
}
4937+
printf("----------------\n");
4938+
}
4939+
4940+
// print the centroids
4941+
//for (int i = 0; i < n_clusters; ++i) {
4942+
// printf("%s: centroid %d: ", __func__, i);
4943+
// for (int j = 0; j < n_features; ++j) {
4944+
// printf("%f ", centroids[i][j]);
4945+
// }
4946+
// printf("\n");
4947+
//}
4948+
4949+
// restore the mel length
4950+
ctx->mel.n_len = mel_len_save;
4951+
}
4952+

whisper.h

+4
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,10 @@ extern "C" {
372372
WHISPER_API int whisper_bench_memcpy(int n_threads);
373373
WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads);
374374

375+
// Temporary experimental API
376+
377+
WHISPER_API void whisper_full_cluster_segments(struct whisper_context * ctx);
378+
375379
#ifdef __cplusplus
376380
}
377381
#endif

0 commit comments

Comments
 (0)