Skip to content

Commit 643bb17

Browse files
committed
diarization : some unsuccessful experiments with audio embd clustering
1 parent 0bfe728 commit 643bb17

File tree

3 files changed

+155
-12
lines changed

3 files changed

+155
-12
lines changed

examples/main/main.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,8 @@ int main(int argc, char ** argv) {
552552
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
553553
return 8;
554554
}
555+
556+
whisper_full_cluster_segments(ctx);
555557
}
556558

557559
// output stuff

whisper.cpp

+151-12
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,8 @@ struct whisper_context {
424424
int64_t t_last;
425425
whisper_token tid_last;
426426
std::vector<float> energy; // PCM signal energy
427+
428+
std::vector<float> audio_embd;
427429
};
428430

429431
// load the model from a ggml file
@@ -1383,18 +1385,34 @@ static bool whisper_encode(
13831385
}
13841386

13851387
// cur
1386-
//{
1387-
// printf("ne0 = %d\n", cur->ne[0]);
1388-
// printf("ne1 = %d\n", cur->ne[1]);
1389-
// for (int i = 0; i < 10; ++i) {
1390-
// printf("%8.4f ", ((float *)(cur->data))[i]);
1391-
// }
1392-
// printf("... ");
1393-
// for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1394-
// printf("%8.4f ", ((float *)(cur->data))[i]);
1395-
// }
1396-
// printf("\n");
1397-
//}
1388+
{
1389+
//printf("ne0 = %d\n", cur->ne[0]);
1390+
//printf("ne1 = %d\n", cur->ne[1]);
1391+
//for (int i = 0; i < 10; ++i) {
1392+
// printf("%8.4f ", ((float *)(cur->data))[i]);
1393+
//}
1394+
//printf("... ");
1395+
//for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1396+
// printf("%8.4f ", ((float *)(cur->data))[i]);
1397+
//}
1398+
//printf("\n");
1399+
}
1400+
1401+
{
1402+
const int i0 = std::min(mel_offset, mel_inp.n_len);
1403+
const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
1404+
1405+
printf("i0 = %d, i1 = %d, (i1 - i0) = %d, embd size = %d\n", i0, i1, i1 - i0, cur->ne[0]);
1406+
1407+
wctx.audio_embd.clear();
1408+
wctx.audio_embd.resize(cur->ne[0], 0.0f);
1409+
for (int j = 0; j < cur->ne[0]; ++j) {
1410+
for (int i = i0; i < i1; ++i) {
1411+
wctx.audio_embd[j] += ((float *)(cur->data))[(i - i0)*cur->ne[0] + j];
1412+
}
1413+
wctx.audio_embd[j] /= (i1 - i0);
1414+
}
1415+
}
13981416

13991417
// pre-compute cross-attention memory
14001418
{
@@ -2936,6 +2954,127 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
29362954
return ctx->result_all[i_segment].tokens[i_token].p;
29372955
}
29382956

2957+
void whisper_full_cluster_segments(struct whisper_context * ctx) {
2958+
const int n_segments = ctx->result_all.size();
2959+
printf("%s: clustering %d segments\n", __func__, n_segments);
2960+
2961+
const auto mel_len_save = ctx->mel.n_len;
2962+
printf("%s: mel_len_save = %d\n", __func__, mel_len_save);
2963+
2964+
std::vector<std::vector<float>> features(n_segments);
2965+
2966+
for (int i = 0; i < n_segments; ++i) {
2967+
const auto & segment_i = ctx->result_all[i];
2968+
printf("%s: segment %d: t0 = %d, t1 = %d, text = %s\n", __func__, i, (int) segment_i.t0, (int) segment_i.t1, segment_i.text.c_str());
2969+
2970+
ctx->mel.n_len = segment_i.t1;
2971+
whisper_encode(ctx, segment_i.t0, 4);
2972+
2973+
features[i] = ctx->audio_embd;
2974+
}
2975+
2976+
const int n_features = features[0].size();
2977+
2978+
// fuzzy c-means clustering
2979+
const int n_clusters = 4;
2980+
2981+
std::vector<std::vector<float>> centroids(n_clusters, std::vector<float>(n_features, 0.0));
2982+
std::vector<std::vector<float>> membership(n_segments, std::vector<float>(n_clusters, 0.0));
2983+
2984+
// initialize the centroids
2985+
for (int i = 0; i < n_clusters; ++i) {
2986+
for (int j = 0; j < n_features; ++j) {
2987+
centroids[i][j] = features[i][j];
2988+
}
2989+
}
2990+
2991+
// initialize the membership
2992+
for (int i = 0; i < n_segments; ++i) {
2993+
membership[i][i % n_clusters] = 1.0;
2994+
}
2995+
2996+
// iterate
2997+
for (int i = 0; i < 100; ++i) {
2998+
// update the centroids
2999+
for (int j = 0; j < n_clusters; ++j) {
3000+
for (int k = 0; k < n_features; ++k) {
3001+
centroids[j][k] = 0.0;
3002+
}
3003+
}
3004+
3005+
for (int j = 0; j < n_segments; ++j) {
3006+
for (int k = 0; k < n_clusters; ++k) {
3007+
for (int l = 0; l < n_features; ++l) {
3008+
centroids[k][l] += membership[j][k]*features[j][l];
3009+
}
3010+
}
3011+
}
3012+
3013+
for (int j = 0; j < n_clusters; ++j) {
3014+
float sum = 0.0;
3015+
for (int k = 0; k < n_segments; ++k) {
3016+
sum += membership[k][j];
3017+
}
3018+
3019+
for (int k = 0; k < n_features; ++k) {
3020+
centroids[j][k] /= sum;
3021+
}
3022+
}
3023+
3024+
// update the membership
3025+
for (int j = 0; j < n_segments; ++j) {
3026+
for (int k = 0; k < n_clusters; ++k) {
3027+
float sum = 0.0;
3028+
for (int l = 0; l < n_clusters; ++l) {
3029+
//sum += std::pow(whisper_distance(features[j], centroids[k])/whisper_distance(features[j], centroids[l]), 2.0/(2.0 - 1.0));
3030+
3031+
// use the euclidean distance
3032+
double d0 = 0.0;
3033+
for (int m = 0; m < n_features; ++m) {
3034+
d0 += std::pow(features[j][m] - centroids[k][m], 2.0);
3035+
}
3036+
d0 = std::sqrt(d0);
3037+
3038+
double d1 = 0.0;
3039+
for (int m = 0; m < n_features; ++m) {
3040+
d1 += std::pow(features[j][m] - centroids[l][m], 2.0);
3041+
}
3042+
d1 = std::sqrt(d1);
3043+
if (d1 == 0.0) {
3044+
sum += 1.0;
3045+
} else {
3046+
sum += std::pow(d0/d1, 2.0/(2.0 - 1.0));
3047+
}
3048+
}
3049+
3050+
membership[j][k] = 1.0/sum;
3051+
}
3052+
}
3053+
3054+
// print the membership
3055+
for (int i = 0; i < n_segments; ++i) {
3056+
printf("%s: membership %d: ", __func__, i);
3057+
for (int j = 0; j < n_clusters; ++j) {
3058+
printf("%f ", membership[i][j]);
3059+
}
3060+
printf(" '%s'\n", ctx->result_all[i].text.c_str());
3061+
}
3062+
printf("----------------\n");
3063+
}
3064+
3065+
// print the centroids
3066+
//for (int i = 0; i < n_clusters; ++i) {
3067+
// printf("%s: centroid %d: ", __func__, i);
3068+
// for (int j = 0; j < n_features; ++j) {
3069+
// printf("%f ", centroids[i][j]);
3070+
// }
3071+
// printf("\n");
3072+
//}
3073+
3074+
// restore the mel length
3075+
ctx->mel.n_len = mel_len_save;
3076+
}
3077+
29393078
const char * whisper_print_system_info() {
29403079
static std::string s;
29413080

whisper.h

+2
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,8 @@ extern "C" {
263263
// Get the probability of the specified token in the specified segment.
264264
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
265265

266+
WHISPER_API void whisper_full_cluster_segments(struct whisper_context * ctx);
267+
266268
// Print system information
267269
WHISPER_API const char * whisper_print_system_info();
268270

0 commit comments

Comments
 (0)