@@ -424,6 +424,8 @@ struct whisper_context {
424
424
int64_t t_last;
425
425
whisper_token tid_last;
426
426
std::vector<float > energy; // PCM signal energy
427
+
428
+ std::vector<float > audio_embd;
427
429
};
428
430
429
431
// load the model from a ggml file
@@ -1383,18 +1385,34 @@ static bool whisper_encode(
1383
1385
}
1384
1386
1385
1387
// cur
1386
- // {
1387
- // printf("ne0 = %d\n", cur->ne[0]);
1388
- // printf("ne1 = %d\n", cur->ne[1]);
1389
- // for (int i = 0; i < 10; ++i) {
1390
- // printf("%8.4f ", ((float *)(cur->data))[i]);
1391
- // }
1392
- // printf("... ");
1393
- // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1394
- // printf("%8.4f ", ((float *)(cur->data))[i]);
1395
- // }
1396
- // printf("\n");
1397
- // }
1388
+ {
1389
+ // printf("ne0 = %d\n", cur->ne[0]);
1390
+ // printf("ne1 = %d\n", cur->ne[1]);
1391
+ // for (int i = 0; i < 10; ++i) {
1392
+ // printf("%8.4f ", ((float *)(cur->data))[i]);
1393
+ // }
1394
+ // printf("... ");
1395
+ // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1396
+ // printf("%8.4f ", ((float *)(cur->data))[i]);
1397
+ // }
1398
+ // printf("\n");
1399
+ }
1400
+
1401
+ {
1402
+ const int i0 = std::min (mel_offset, mel_inp.n_len );
1403
+ const int i1 = std::min (mel_offset + 2 *n_ctx, mel_inp.n_len );
1404
+
1405
+ printf (" i0 = %d, i1 = %d, (i1 - i0) = %d, embd size = %d\n " , i0, i1, i1 - i0, cur->ne [0 ]);
1406
+
1407
+ wctx.audio_embd .clear ();
1408
+ wctx.audio_embd .resize (cur->ne [0 ], 0 .0f );
1409
+ for (int j = 0 ; j < cur->ne [0 ]; ++j) {
1410
+ for (int i = i0; i < i1; ++i) {
1411
+ wctx.audio_embd [j] += ((float *)(cur->data ))[(i - i0)*cur->ne [0 ] + j];
1412
+ }
1413
+ wctx.audio_embd [j] /= (i1 - i0);
1414
+ }
1415
+ }
1398
1416
1399
1417
// pre-compute cross-attention memory
1400
1418
{
@@ -2936,6 +2954,127 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
2936
2954
return ctx->result_all [i_segment].tokens [i_token].p ;
2937
2955
}
2938
2956
2957
+ void whisper_full_cluster_segments (struct whisper_context * ctx) {
2958
+ const int n_segments = ctx->result_all .size ();
2959
+ printf (" %s: clustering %d segments\n " , __func__, n_segments);
2960
+
2961
+ const auto mel_len_save = ctx->mel .n_len ;
2962
+ printf (" %s: mel_len_save = %d\n " , __func__, mel_len_save);
2963
+
2964
+ std::vector<std::vector<float >> features (n_segments);
2965
+
2966
+ for (int i = 0 ; i < n_segments; ++i) {
2967
+ const auto & segment_i = ctx->result_all [i];
2968
+ printf (" %s: segment %d: t0 = %d, t1 = %d, text = %s\n " , __func__, i, (int ) segment_i.t0 , (int ) segment_i.t1 , segment_i.text .c_str ());
2969
+
2970
+ ctx->mel .n_len = segment_i.t1 ;
2971
+ whisper_encode (ctx, segment_i.t0 , 4 );
2972
+
2973
+ features[i] = ctx->audio_embd ;
2974
+ }
2975
+
2976
+ const int n_features = features[0 ].size ();
2977
+
2978
+ // fuzzy c-means clustering
2979
+ const int n_clusters = 4 ;
2980
+
2981
+ std::vector<std::vector<float >> centroids (n_clusters, std::vector<float >(n_features, 0.0 ));
2982
+ std::vector<std::vector<float >> membership (n_segments, std::vector<float >(n_clusters, 0.0 ));
2983
+
2984
+ // initialize the centroids
2985
+ for (int i = 0 ; i < n_clusters; ++i) {
2986
+ for (int j = 0 ; j < n_features; ++j) {
2987
+ centroids[i][j] = features[i][j];
2988
+ }
2989
+ }
2990
+
2991
+ // initialize the membership
2992
+ for (int i = 0 ; i < n_segments; ++i) {
2993
+ membership[i][i % n_clusters] = 1.0 ;
2994
+ }
2995
+
2996
+ // iterate
2997
+ for (int i = 0 ; i < 100 ; ++i) {
2998
+ // update the centroids
2999
+ for (int j = 0 ; j < n_clusters; ++j) {
3000
+ for (int k = 0 ; k < n_features; ++k) {
3001
+ centroids[j][k] = 0.0 ;
3002
+ }
3003
+ }
3004
+
3005
+ for (int j = 0 ; j < n_segments; ++j) {
3006
+ for (int k = 0 ; k < n_clusters; ++k) {
3007
+ for (int l = 0 ; l < n_features; ++l) {
3008
+ centroids[k][l] += membership[j][k]*features[j][l];
3009
+ }
3010
+ }
3011
+ }
3012
+
3013
+ for (int j = 0 ; j < n_clusters; ++j) {
3014
+ float sum = 0.0 ;
3015
+ for (int k = 0 ; k < n_segments; ++k) {
3016
+ sum += membership[k][j];
3017
+ }
3018
+
3019
+ for (int k = 0 ; k < n_features; ++k) {
3020
+ centroids[j][k] /= sum;
3021
+ }
3022
+ }
3023
+
3024
+ // update the membership
3025
+ for (int j = 0 ; j < n_segments; ++j) {
3026
+ for (int k = 0 ; k < n_clusters; ++k) {
3027
+ float sum = 0.0 ;
3028
+ for (int l = 0 ; l < n_clusters; ++l) {
3029
+ // sum += std::pow(whisper_distance(features[j], centroids[k])/whisper_distance(features[j], centroids[l]), 2.0/(2.0 - 1.0));
3030
+
3031
+ // use the euclidean distance
3032
+ double d0 = 0.0 ;
3033
+ for (int m = 0 ; m < n_features; ++m) {
3034
+ d0 += std::pow (features[j][m] - centroids[k][m], 2.0 );
3035
+ }
3036
+ d0 = std::sqrt (d0);
3037
+
3038
+ double d1 = 0.0 ;
3039
+ for (int m = 0 ; m < n_features; ++m) {
3040
+ d1 += std::pow (features[j][m] - centroids[l][m], 2.0 );
3041
+ }
3042
+ d1 = std::sqrt (d1);
3043
+ if (d1 == 0.0 ) {
3044
+ sum += 1.0 ;
3045
+ } else {
3046
+ sum += std::pow (d0/d1, 2.0 /(2.0 - 1.0 ));
3047
+ }
3048
+ }
3049
+
3050
+ membership[j][k] = 1.0 /sum;
3051
+ }
3052
+ }
3053
+
3054
+ // print the membership
3055
+ for (int i = 0 ; i < n_segments; ++i) {
3056
+ printf (" %s: membership %d: " , __func__, i);
3057
+ for (int j = 0 ; j < n_clusters; ++j) {
3058
+ printf (" %f " , membership[i][j]);
3059
+ }
3060
+ printf (" '%s'\n " , ctx->result_all [i].text .c_str ());
3061
+ }
3062
+ printf (" ----------------\n " );
3063
+ }
3064
+
3065
+ // print the centroids
3066
+ // for (int i = 0; i < n_clusters; ++i) {
3067
+ // printf("%s: centroid %d: ", __func__, i);
3068
+ // for (int j = 0; j < n_features; ++j) {
3069
+ // printf("%f ", centroids[i][j]);
3070
+ // }
3071
+ // printf("\n");
3072
+ // }
3073
+
3074
+ // restore the mel length
3075
+ ctx->mel .n_len = mel_len_save;
3076
+ }
3077
+
2939
3078
const char * whisper_print_system_info () {
2940
3079
static std::string s;
2941
3080
0 commit comments