From c065defc22cbf0c2e0c9204efe4875b38a9cb240 Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Mon, 14 Apr 2025 09:01:08 +0200 Subject: [PATCH 01/11] Fix spacing in `OrtBackend` --- backends/ort/src/lib.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backends/ort/src/lib.rs b/backends/ort/src/lib.rs index a0cc19b7..3737d384 100644 --- a/backends/ort/src/lib.rs +++ b/backends/ort/src/lib.rs @@ -167,8 +167,8 @@ impl Backend for OrtBackend { // Run model let outputs = self.session.run(inputs).e()?; - // Get last_hidden_state ndarray + // Get last_hidden_state ndarray let outputs = outputs .get("last_hidden_state") .or(outputs.get("token_embeddings")) @@ -362,6 +362,7 @@ impl Backend for OrtBackend { // Run model let outputs = self.session.run(inputs).e()?; + // Get last_hidden_state ndarray let outputs = outputs["logits"] .try_extract_tensor::() From 9342ba314b9e19907276aba775b5c03db91a003f Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Mon, 14 Apr 2025 09:01:38 +0200 Subject: [PATCH 02/11] Use default `Pool::Cls` for `{Jina,JinaCode}BertModel` if classifier --- backends/candle/src/models/jina.rs | 4 +--- backends/candle/src/models/jina_code.rs | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/backends/candle/src/models/jina.rs b/backends/candle/src/models/jina.rs index ecee8bfe..51fbba36 100644 --- a/backends/candle/src/models/jina.rs +++ b/backends/candle/src/models/jina.rs @@ -367,9 +367,7 @@ impl JinaBertModel { }; let pool = match model_type { - ModelType::Classifier => { - candle::bail!("`classifier` model type is not supported for Jina") - } + ModelType::Classifier => Pool::Cls, ModelType::Embedding(pool) => { if pool == Pool::Splade { candle::bail!("`splade` is not supported for Jina") diff --git a/backends/candle/src/models/jina_code.rs b/backends/candle/src/models/jina_code.rs index 5f13fe08..fcd75c3b 100644 --- a/backends/candle/src/models/jina_code.rs +++ b/backends/candle/src/models/jina_code.rs @@ -356,9 +356,7 @@ impl JinaCodeBertModel { }; let pool = match model_type { - ModelType::Classifier => { - candle::bail!("`classifier` model type is not supported for JinaCode") - } + ModelType::Classifier => Pool::Cls, ModelType::Embedding(pool) => { if pool == Pool::Splade { candle::bail!("`splade` is not supported for JinaCode") From b43fa49d7d6a8c11ef5883d023ed20f1f660bfb6 Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Mon, 14 Apr 2025 09:22:59 +0200 Subject: [PATCH 03/11] Add missing backtick --- backends/candle/src/models/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/candle/src/models/mod.rs b/backends/candle/src/models/mod.rs index a5127da6..74db7d84 100644 --- a/backends/candle/src/models/mod.rs +++ b/backends/candle/src/models/mod.rs @@ -92,6 +92,6 @@ pub(crate) trait Model { } fn predict(&self, _batch: Batch) -> Result { - candle::bail!("`predict is not implemented for this model"); + candle::bail!("`predict` is not implemented for this model"); } } From 6ce15593b03454b3981353066de9ffea1a8ff2b1 Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Mon, 14 Apr 2025 10:42:43 +0200 Subject: [PATCH 04/11] Improve `dtype` check readability in `OrtBackend` --- backends/ort/src/lib.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/backends/ort/src/lib.rs b/backends/ort/src/lib.rs index 3737d384..4e6cea44 100644 --- a/backends/ort/src/lib.rs +++ b/backends/ort/src/lib.rs @@ -21,8 +21,7 @@ impl OrtBackend { model_type: ModelType, ) -> Result { // Check dtype - if dtype == "float32" { - } else { + if dtype != "float32" { return Err(BackendError::Start(format!( "DType {dtype} is not supported" ))); From ea24c2cf4591e44aabed4c95f8f6f86a0bcb989a Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Mon, 14 Apr 2025 11:52:11 +0200 Subject: [PATCH 05/11] Revert default `Pool::Cls` for `JinaCodeBertModel` --- backends/candle/src/models/jina_code.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/backends/candle/src/models/jina_code.rs b/backends/candle/src/models/jina_code.rs index fcd75c3b..8dadea00 100644 --- a/backends/candle/src/models/jina_code.rs +++ b/backends/candle/src/models/jina_code.rs @@ -356,7 +356,9 @@ impl JinaCodeBertModel { }; let pool = match model_type { - ModelType::Classifier => Pool::Cls, + ModelType::Classifier => { + candle::bail!("`classifier` model type is not supported for JinaCode") + } ModelType::Embedding(pool) => { if pool == Pool::Splade { candle::bail!("`splade` is not supported for JinaCode") @@ -654,6 +656,7 @@ impl Model for JinaCodeBertModel { fn is_padded(&self) -> bool { true } + fn embed(&self, batch: Batch) -> Result<(Option, Option)> { self.forward(batch) } From 1a902e2af23dd4a42da4d54e34270d8969f743aa Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Mon, 14 Apr 2025 11:52:30 +0200 Subject: [PATCH 06/11] Add `JinaBertClassificationHead` --- backends/candle/src/models/jina.rs | 83 ++++++++++++++++++++++++++++-- 1 file changed, 80 insertions(+), 3 deletions(-) diff --git a/backends/candle/src/models/jina.rs b/backends/candle/src/models/jina.rs index 51fbba36..b694befb 100644 --- a/backends/candle/src/models/jina.rs +++ b/backends/candle/src/models/jina.rs @@ -339,11 +339,69 @@ impl JinaBertEncoder { } } +pub trait ClassificationHead { + fn forward(&self, hidden_states: &Tensor) -> Result; +} + +pub struct JinaBertClassificationHead { + pooler: Option, + output: Linear, + span: tracing::Span, +} + +impl JinaBertClassificationHead { + pub(crate) fn load(vb: VarBuilder, config: &BertConfig) -> Result { + let n_classes = match &config.id2label { + None => candle::bail!("`id2label` must be set for classifier models"), + Some(id2label) => id2label.len(), + }; + + let pooler = if let Ok(pooler_weight) = vb + .pp("bert.pooler.dense") + .get((config.hidden_size, config.hidden_size), "weight") + { + let pooler_bias = vb.pp("bert.pooler.dense").get(config.hidden_size, "bias")?; + Some(Linear::new(pooler_weight, Some(pooler_bias), None)) + } else { + None + }; + + let output_weight = vb + .pp("classifier") + .get((n_classes, config.hidden_size), "weight")?; + let output_bias = vb.pp("classifier").get(n_classes, "bias")?; + let output = Linear::new(output_weight, Some(output_bias), None); + + Ok(Self { + pooler, + output, + span: tracing::span!(tracing::Level::TRACE, "classifier"), + }) + } +} + +impl ClassificationHead for JinaBertClassificationHead { + fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + + let mut hidden_states = hidden_states.unsqueeze(1)?; + if let Some(pooler) = self.pooler.as_ref() { + hidden_states = pooler.forward(&hidden_states)?; + hidden_states = hidden_states.tanh()?; + } + + let hidden_states = self.output.forward(&hidden_states)?; + let hidden_states = hidden_states.squeeze(1)?; + Ok(hidden_states) + } +} + pub struct JinaBertModel { embeddings: JinaEmbeddings, encoder: JinaBertEncoder, pool: Pool, alibi: Option, + classifier: Option>, num_attention_heads: usize, @@ -366,8 +424,13 @@ impl JinaBertModel { _ => candle::bail!("not supported"), }; - let pool = match model_type { - ModelType::Classifier => Pool::Cls, + let (pool, classifier) = match model_type { + ModelType::Classifier => { + let pool = Pool::Cls; + let classifier: Box = + Box::new(JinaBertClassificationHead::load(vb.clone(), config)?); + (pool, Some(classifier)) + } ModelType::Embedding(pool) => { if pool == Pool::Splade { candle::bail!("`splade` is not supported for Jina") @@ -375,7 +438,7 @@ impl JinaBertModel { if pool == Pool::LastToken { candle::bail!("`last_token` is not supported for Jina"); } - pool + (pool, None) } }; @@ -401,6 +464,7 @@ impl JinaBertModel { encoder, pool, alibi, + classifier, num_attention_heads: config.num_attention_heads, device: vb.device().clone(), dtype: vb.dtype(), @@ -665,7 +729,20 @@ impl Model for JinaBertModel { fn is_padded(&self) -> bool { true } + fn embed(&self, batch: Batch) -> Result<(Option, Option)> { self.forward(batch) } + + fn predict(&self, batch: Batch) -> Result { + match &self.classifier { + None => candle::bail!("`predict` is not implemented for this model"), + Some(classifier) => { + let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?; + let pooled_embeddings = + pooled_embeddings.expect("pooled_embeddings is empty. This is a bug."); + classifier.forward(&pooled_embeddings) + } + } + } } From 121c2b5ff56069666bfb4e617c50e08dbee2b53c Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Mon, 14 Apr 2025 12:01:41 +0200 Subject: [PATCH 07/11] Add `classifier` to `FlashJinaBertModel` --- backends/candle/src/models/flash_jina.rs | 28 ++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/backends/candle/src/models/flash_jina.rs b/backends/candle/src/models/flash_jina.rs index 947fac6c..94878c5a 100644 --- a/backends/candle/src/models/flash_jina.rs +++ b/backends/candle/src/models/flash_jina.rs @@ -2,7 +2,7 @@ use crate::alibi::alibi_head_slopes; use crate::flash_attn::flash_attn_varlen; use crate::layers::{HiddenAct, LayerNorm, Linear}; use crate::models::bert::PositionEmbeddingType; -use crate::models::jina::JinaEmbeddings; +use crate::models::jina::{ClassificationHead, JinaBertClassificationHead, JinaEmbeddings}; use crate::models::{BertConfig, Model}; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::VarBuilder; @@ -227,6 +227,8 @@ pub struct FlashJinaBertModel { embeddings: JinaEmbeddings, encoder: JinaBertEncoder, pool: Pool, + classifier: Option>, + pub device: Device, span: tracing::Span, @@ -255,15 +257,19 @@ impl FlashJinaBertModel { candle::bail!("FlashJinaBertModel requires DType::F16") } - let pool = match model_type { + let (pool, classifier) = match model_type { ModelType::Classifier => { - candle::bail!("`classifier` model type is not supported for Jina") + let pool = Pool::Cls; + + let classifier: Box = + Box::new(JinaBertClassificationHead::load(vb.clone(), config)?); + (pool, Some(classifier)) } ModelType::Embedding(pool) => { if pool == Pool::Splade { candle::bail!("`splade` is not supported for Jina") } - pool + (pool, None) } }; @@ -288,6 +294,7 @@ impl FlashJinaBertModel { embeddings, encoder, pool, + classifier, device: vb.device().clone(), span: tracing::span!(tracing::Level::TRACE, "model"), }) @@ -433,7 +440,20 @@ impl Model for FlashJinaBertModel { fn is_padded(&self) -> bool { false } + fn embed(&self, batch: Batch) -> Result<(Option, Option)> { self.forward(batch) } + + fn predict(&self, batch: Batch) -> Result { + match &self.classifier { + None => candle::bail!("`predict` is not implemented for this model"), + Some(classifier) => { + let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?; + let pooled_embeddings = + pooled_embeddings.expect("pooled_embeddings is empty. This is a bug."); + classifier.forward(&pooled_embeddings) + } + } + } } From 0e212a2156aa017fe1488597e977daa0f3ce82e9 Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Tue, 15 Apr 2025 17:37:02 +0200 Subject: [PATCH 08/11] Add `test_jina_rerank` --- .../test_jina__jinabert_reranker_single.snap | 5 ++++ backends/candle/tests/test_jina.rs | 30 +++++++++++++++++-- 2 files changed, 33 insertions(+), 2 deletions(-) create mode 100644 backends/candle/tests/snapshots/test_jina__jinabert_reranker_single.snap diff --git a/backends/candle/tests/snapshots/test_jina__jinabert_reranker_single.snap b/backends/candle/tests/snapshots/test_jina__jinabert_reranker_single.snap new file mode 100644 index 00000000..98f2245e --- /dev/null +++ b/backends/candle/tests/snapshots/test_jina__jinabert_reranker_single.snap @@ -0,0 +1,5 @@ +--- +source: backends/candle/tests/test_jina.rs +expression: predictions +--- +- - -0.9144413 diff --git a/backends/candle/tests/test_jina.rs b/backends/candle/tests/test_jina.rs index ae162368..901de1d0 100644 --- a/backends/candle/tests/test_jina.rs +++ b/backends/candle/tests/test_jina.rs @@ -1,8 +1,8 @@ mod common; -use crate::common::{sort_embeddings, SnapshotEmbeddings}; +use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores}; use anyhow::Result; -use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; +use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher}; use text_embeddings_backend_candle::CandleBackend; use text_embeddings_backend_core::{Backend, ModelType, Pool}; @@ -48,3 +48,29 @@ fn test_jina_small() -> Result<()> { Ok(()) } + +#[test] +#[serial_test::serial] +fn test_jina_rerank() -> Result<()> { + let model_root = download_artifacts("jinaai/jina-reranker-v1-turbo-en", Some("refs/pr/13"))?; + let tokenizer = load_tokenizer(&model_root)?; + + let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?; + + let input_single = batch( + vec![tokenizer.encode("What is Deep Learning?", true).unwrap()], + [0].to_vec(), + vec![], + ); + + let predictions: Vec> = backend + .predict(input_single)? + .into_iter() + .map(|(_, v)| v) + .collect(); + + let predictions = SnapshotScores::from(predictions); + insta::assert_yaml_snapshot!("jinabert_reranker_single", predictions, &relative_matcher()); + + Ok(()) +} From d41ef6e5c67ac6cc0086af6fca9d9edd04ac39ce Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Tue, 15 Apr 2025 17:40:54 +0200 Subject: [PATCH 09/11] Use `jinaai/jina-reranker-v1-tiny-en` instead It's smaller and just created the PR to add support for TEI --- .../tests/snapshots/test_jina__jinabert_reranker_single.snap | 2 +- backends/candle/tests/test_jina.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backends/candle/tests/snapshots/test_jina__jinabert_reranker_single.snap b/backends/candle/tests/snapshots/test_jina__jinabert_reranker_single.snap index 98f2245e..b84fbc28 100644 --- a/backends/candle/tests/snapshots/test_jina__jinabert_reranker_single.snap +++ b/backends/candle/tests/snapshots/test_jina__jinabert_reranker_single.snap @@ -2,4 +2,4 @@ source: backends/candle/tests/test_jina.rs expression: predictions --- -- - -0.9144413 +- - -0.6045344 diff --git a/backends/candle/tests/test_jina.rs b/backends/candle/tests/test_jina.rs index 901de1d0..9ea15b50 100644 --- a/backends/candle/tests/test_jina.rs +++ b/backends/candle/tests/test_jina.rs @@ -52,7 +52,7 @@ fn test_jina_small() -> Result<()> { #[test] #[serial_test::serial] fn test_jina_rerank() -> Result<()> { - let model_root = download_artifacts("jinaai/jina-reranker-v1-turbo-en", Some("refs/pr/13"))?; + let model_root = download_artifacts("jinaai/jina-reranker-v1-tiny-en", Some("refs/pr/11"))?; let tokenizer = load_tokenizer(&model_root)?; let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?; From ca27d1337688206b7b367078485ae0b5205d585c Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Wed, 16 Apr 2025 09:45:38 +0200 Subject: [PATCH 10/11] Bump `sccache-action` to 0.0.9 --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 92620c25..e2b6db13 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -31,7 +31,7 @@ jobs: uses: actions/checkout@v4 - uses: actions-rust-lang/setup-rust-toolchain@v1 - name: Run sccache-cache - uses: mozilla-actions/sccache-action@v0.0.7 + uses: mozilla-actions/sccache-action@v0.0.9 - name: Compile project env: SCCACHE_GHA_ENABLED: "true" From 5b1b336b5c4a3b5b81146072037b48c7f13409a9 Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Wed, 16 Apr 2025 10:08:54 +0200 Subject: [PATCH 11/11] Revert "Bump `sccache-action` to 0.0.9" This reverts commit ca27d1337688206b7b367078485ae0b5205d585c. --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index e2b6db13..92620c25 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -31,7 +31,7 @@ jobs: uses: actions/checkout@v4 - uses: actions-rust-lang/setup-rust-toolchain@v1 - name: Run sccache-cache - uses: mozilla-actions/sccache-action@v0.0.9 + uses: mozilla-actions/sccache-action@v0.0.7 - name: Compile project env: SCCACHE_GHA_ENABLED: "true"