@@ -603,6 +603,8 @@ struct whisper_context {
603
603
// [EXPERIMENTAL] speed-up techniques
604
604
int32_t exp_n_audio_ctx; // 0 - use default
605
605
606
+ std::vector<float > audio_embd;
607
+
606
608
void use_buf (struct ggml_context * ctx, int i) {
607
609
#if defined(WHISPER_USE_SCRATCH)
608
610
size_t last_size = 0 ;
@@ -1707,18 +1709,34 @@ static bool whisper_encode(
1707
1709
}
1708
1710
1709
1711
// cur
1710
- // {
1711
- // printf("ne0 = %d\n", cur->ne[0]);
1712
- // printf("ne1 = %d\n", cur->ne[1]);
1713
- // for (int i = 0; i < 10; ++i) {
1714
- // printf("%8.4f ", ((float *)(cur->data))[i]);
1715
- // }
1716
- // printf("... ");
1717
- // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1718
- // printf("%8.4f ", ((float *)(cur->data))[i]);
1719
- // }
1720
- // printf("\n");
1721
- // }
1712
+ {
1713
+ // printf("ne0 = %d\n", cur->ne[0]);
1714
+ // printf("ne1 = %d\n", cur->ne[1]);
1715
+ // for (int i = 0; i < 10; ++i) {
1716
+ // printf("%8.4f ", ((float *)(cur->data))[i]);
1717
+ // }
1718
+ // printf("... ");
1719
+ // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1720
+ // printf("%8.4f ", ((float *)(cur->data))[i]);
1721
+ // }
1722
+ // printf("\n");
1723
+ }
1724
+
1725
+ {
1726
+ const int i0 = std::min (mel_offset, mel_inp.n_len );
1727
+ const int i1 = std::min (mel_offset + 2 *n_ctx, mel_inp.n_len );
1728
+
1729
+ printf (" i0 = %d, i1 = %d, (i1 - i0) = %d, embd size = %d\n " , i0, i1, i1 - i0, cur->ne [0 ]);
1730
+
1731
+ wctx.audio_embd .clear ();
1732
+ wctx.audio_embd .resize (cur->ne [0 ], 0 .0f );
1733
+ for (int j = 0 ; j < cur->ne [0 ]; ++j) {
1734
+ for (int i = i0; i < i1; ++i) {
1735
+ wctx.audio_embd [j] += ((float *)(cur->data ))[(i - i0)*cur->ne [0 ] + j];
1736
+ }
1737
+ wctx.audio_embd [j] /= (i1 - i0);
1738
+ }
1739
+ }
1722
1740
1723
1741
// pre-compute cross-attention memory
1724
1742
{
@@ -4806,3 +4824,129 @@ static void whisper_exp_compute_token_level_timestamps(
4806
4824
// }
4807
4825
// }
4808
4826
}
4827
+
4828
+ //
4829
+ // diarization stuff
4830
+ //
4831
+
4832
+ void whisper_full_cluster_segments (struct whisper_context * ctx) {
4833
+ const int n_segments = ctx->result_all .size ();
4834
+ printf (" %s: clustering %d segments\n " , __func__, n_segments);
4835
+
4836
+ const auto mel_len_save = ctx->mel .n_len ;
4837
+ printf (" %s: mel_len_save = %d\n " , __func__, mel_len_save);
4838
+
4839
+ std::vector<std::vector<float >> features (n_segments);
4840
+
4841
+ for (int i = 0 ; i < n_segments; ++i) {
4842
+ const auto & segment_i = ctx->result_all [i];
4843
+ printf (" %s: segment %d: t0 = %d, t1 = %d, text = %s\n " , __func__, i, (int ) segment_i.t0 , (int ) segment_i.t1 , segment_i.text .c_str ());
4844
+
4845
+ ctx->mel .n_len = segment_i.t1 ;
4846
+ whisper_encode (ctx, segment_i.t0 , 4 );
4847
+
4848
+ features[i] = ctx->audio_embd ;
4849
+ }
4850
+
4851
+ const int n_features = features[0 ].size ();
4852
+
4853
+ // fuzzy c-means clustering
4854
+ const int n_clusters = 4 ;
4855
+
4856
+ std::vector<std::vector<float >> centroids (n_clusters, std::vector<float >(n_features, 0.0 ));
4857
+ std::vector<std::vector<float >> membership (n_segments, std::vector<float >(n_clusters, 0.0 ));
4858
+
4859
+ // initialize the centroids
4860
+ for (int i = 0 ; i < n_clusters; ++i) {
4861
+ for (int j = 0 ; j < n_features; ++j) {
4862
+ centroids[i][j] = features[i][j];
4863
+ }
4864
+ }
4865
+
4866
+ // initialize the membership
4867
+ for (int i = 0 ; i < n_segments; ++i) {
4868
+ membership[i][i % n_clusters] = 1.0 ;
4869
+ }
4870
+
4871
+ // iterate
4872
+ for (int i = 0 ; i < 100 ; ++i) {
4873
+ // update the centroids
4874
+ for (int j = 0 ; j < n_clusters; ++j) {
4875
+ for (int k = 0 ; k < n_features; ++k) {
4876
+ centroids[j][k] = 0.0 ;
4877
+ }
4878
+ }
4879
+
4880
+ for (int j = 0 ; j < n_segments; ++j) {
4881
+ for (int k = 0 ; k < n_clusters; ++k) {
4882
+ for (int l = 0 ; l < n_features; ++l) {
4883
+ centroids[k][l] += membership[j][k]*features[j][l];
4884
+ }
4885
+ }
4886
+ }
4887
+
4888
+ for (int j = 0 ; j < n_clusters; ++j) {
4889
+ float sum = 0.0 ;
4890
+ for (int k = 0 ; k < n_segments; ++k) {
4891
+ sum += membership[k][j];
4892
+ }
4893
+
4894
+ for (int k = 0 ; k < n_features; ++k) {
4895
+ centroids[j][k] /= sum;
4896
+ }
4897
+ }
4898
+
4899
+ // update the membership
4900
+ for (int j = 0 ; j < n_segments; ++j) {
4901
+ for (int k = 0 ; k < n_clusters; ++k) {
4902
+ float sum = 0.0 ;
4903
+ for (int l = 0 ; l < n_clusters; ++l) {
4904
+ // sum += std::pow(whisper_distance(features[j], centroids[k])/whisper_distance(features[j], centroids[l]), 2.0/(2.0 - 1.0));
4905
+
4906
+ // use the euclidean distance
4907
+ double d0 = 0.0 ;
4908
+ for (int m = 0 ; m < n_features; ++m) {
4909
+ d0 += std::pow (features[j][m] - centroids[k][m], 2.0 );
4910
+ }
4911
+ d0 = std::sqrt (d0);
4912
+
4913
+ double d1 = 0.0 ;
4914
+ for (int m = 0 ; m < n_features; ++m) {
4915
+ d1 += std::pow (features[j][m] - centroids[l][m], 2.0 );
4916
+ }
4917
+ d1 = std::sqrt (d1);
4918
+ if (d1 == 0.0 ) {
4919
+ sum += 1.0 ;
4920
+ } else {
4921
+ sum += std::pow (d0/d1, 2.0 /(2.0 - 1.0 ));
4922
+ }
4923
+ }
4924
+
4925
+ membership[j][k] = 1.0 /sum;
4926
+ }
4927
+ }
4928
+
4929
+ // print the membership
4930
+ for (int i = 0 ; i < n_segments; ++i) {
4931
+ printf (" %s: membership %d: " , __func__, i);
4932
+ for (int j = 0 ; j < n_clusters; ++j) {
4933
+ printf (" %f " , membership[i][j]);
4934
+ }
4935
+ printf (" '%s'\n " , ctx->result_all [i].text .c_str ());
4936
+ }
4937
+ printf (" ----------------\n " );
4938
+ }
4939
+
4940
+ // print the centroids
4941
+ // for (int i = 0; i < n_clusters; ++i) {
4942
+ // printf("%s: centroid %d: ", __func__, i);
4943
+ // for (int j = 0; j < n_features; ++j) {
4944
+ // printf("%f ", centroids[i][j]);
4945
+ // }
4946
+ // printf("\n");
4947
+ // }
4948
+
4949
+ // restore the mel length
4950
+ ctx->mel .n_len = mel_len_save;
4951
+ }
4952
+
0 commit comments