diff --git a/backends/candle/src/models/flash_jina.rs b/backends/candle/src/models/flash_jina.rs
index 947fac6c..94878c5a 100644
--- a/backends/candle/src/models/flash_jina.rs
+++ b/backends/candle/src/models/flash_jina.rs
@@ -2,7 +2,7 @@ use crate::alibi::alibi_head_slopes;
 use crate::flash_attn::flash_attn_varlen;
 use crate::layers::{HiddenAct, LayerNorm, Linear};
 use crate::models::bert::PositionEmbeddingType;
-use crate::models::jina::JinaEmbeddings;
+use crate::models::jina::{ClassificationHead, JinaBertClassificationHead, JinaEmbeddings};
 use crate::models::{BertConfig, Model};
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::VarBuilder;
@@ -227,6 +227,8 @@ pub struct FlashJinaBertModel {
     embeddings: JinaEmbeddings,
     encoder: JinaBertEncoder,
     pool: Pool,
+    classifier: Option<Box<dyn ClassificationHead + Send>>,
+
     pub device: Device,
 
     span: tracing::Span,
@@ -255,15 +257,19 @@ impl FlashJinaBertModel {
             candle::bail!("FlashJinaBertModel requires DType::F16")
         }
 
-        let pool = match model_type {
+        let (pool, classifier) = match model_type {
             ModelType::Classifier => {
-                candle::bail!("`classifier` model type is not supported for Jina")
+                let pool = Pool::Cls;
+
+                let classifier: Box<dyn ClassificationHead + Send> =
+                    Box::new(JinaBertClassificationHead::load(vb.clone(), config)?);
+                (pool, Some(classifier))
             }
             ModelType::Embedding(pool) => {
                 if pool == Pool::Splade {
                     candle::bail!("`splade` is not supported for Jina")
                 }
-                pool
+                (pool, None)
             }
         };
 
@@ -288,6 +294,7 @@ impl FlashJinaBertModel {
             embeddings,
             encoder,
             pool,
+            classifier,
             device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "model"),
         })
@@ -433,7 +440,20 @@ impl Model for FlashJinaBertModel {
     fn is_padded(&self) -> bool {
         false
     }
+
     fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
         self.forward(batch)
     }
+
+    fn predict(&self, batch: Batch) -> Result<Tensor> {
+        match &self.classifier {
+            None => candle::bail!("`predict` is not implemented for this model"),
+            Some(classifier) => {
+                let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
+                let pooled_embeddings =
+                    pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
+                classifier.forward(&pooled_embeddings)
+            }
+        }
+    }
 }
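The model struct stores the head behind `Option<Box<dyn ClassificationHead + Send>>`: `None` for embedding models, `Some` when loaded with `ModelType::Classifier`, and only `predict` dispatches through it. A minimal, self-contained sketch of that optional-trait-object pattern; all names and the plain-`Vec` math here are illustrative, not from the codebase (the real code works on candle tensors):

```rust
// Sketch of the optional-trait-object dispatch used above (illustrative names).
trait ClassificationHead {
    fn forward(&self, pooled: &[f32]) -> Vec<f32>;
}

struct LinearHead {
    weight: Vec<f32>, // one weight per hidden dim (single output class)
    bias: f32,
}

impl ClassificationHead for LinearHead {
    fn forward(&self, pooled: &[f32]) -> Vec<f32> {
        let logit: f32 = self.weight.iter().zip(pooled).map(|(w, x)| w * x).sum::<f32>() + self.bias;
        vec![logit]
    }
}

struct Model {
    // `None` for pure embedding models; `Some` when loaded as a classifier.
    classifier: Option<Box<dyn ClassificationHead + Send>>,
}

impl Model {
    fn predict(&self, pooled: &[f32]) -> Result<Vec<f32>, String> {
        match &self.classifier {
            None => Err("`predict` is not implemented for this model".into()),
            Some(head) => Ok(head.forward(pooled)),
        }
    }
}

fn main() {
    let model = Model {
        classifier: Some(Box::new(LinearHead {
            weight: vec![0.5, -0.25],
            bias: 0.1,
        })),
    };
    println!("{:?}", model.predict(&[1.0, 2.0])); // Ok([0.1])
}
```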
diff --git a/backends/candle/src/models/jina.rs b/backends/candle/src/models/jina.rs
index ecee8bfe..b694befb 100644
--- a/backends/candle/src/models/jina.rs
+++ b/backends/candle/src/models/jina.rs
@@ -339,11 +339,69 @@ impl JinaBertEncoder {
     }
 }
 
+pub trait ClassificationHead {
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor>;
+}
+
+pub struct JinaBertClassificationHead {
+    pooler: Option<Linear>,
+    output: Linear,
+    span: tracing::Span,
+}
+
+impl JinaBertClassificationHead {
+    pub(crate) fn load(vb: VarBuilder, config: &BertConfig) -> Result<Self> {
+        let n_classes = match &config.id2label {
+            None => candle::bail!("`id2label` must be set for classifier models"),
+            Some(id2label) => id2label.len(),
+        };
+
+        let pooler = if let Ok(pooler_weight) = vb
+            .pp("bert.pooler.dense")
+            .get((config.hidden_size, config.hidden_size), "weight")
+        {
+            let pooler_bias = vb.pp("bert.pooler.dense").get(config.hidden_size, "bias")?;
+            Some(Linear::new(pooler_weight, Some(pooler_bias), None))
+        } else {
+            None
+        };
+
+        let output_weight = vb
+            .pp("classifier")
+            .get((n_classes, config.hidden_size), "weight")?;
+        let output_bias = vb.pp("classifier").get(n_classes, "bias")?;
+        let output = Linear::new(output_weight, Some(output_bias), None);
+
+        Ok(Self {
+            pooler,
+            output,
+            span: tracing::span!(tracing::Level::TRACE, "classifier"),
+        })
+    }
+}
+
+impl ClassificationHead for JinaBertClassificationHead {
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let mut hidden_states = hidden_states.unsqueeze(1)?;
+        if let Some(pooler) = self.pooler.as_ref() {
+            hidden_states = pooler.forward(&hidden_states)?;
+            hidden_states = hidden_states.tanh()?;
+        }
+
+        let hidden_states = self.output.forward(&hidden_states)?;
+        let hidden_states = hidden_states.squeeze(1)?;
+        Ok(hidden_states)
+    }
+}
+
 pub struct JinaBertModel {
     embeddings: JinaEmbeddings,
     encoder: JinaBertEncoder,
     pool: Pool,
     alibi: Option<Tensor>,
+    classifier: Option<Box<dyn ClassificationHead + Send>>,
 
     num_attention_heads: usize,
 
@@ -366,9 +424,12 @@ impl JinaBertModel {
             _ => candle::bail!("not supported"),
         };
 
-        let pool = match model_type {
+        let (pool, classifier) = match model_type {
             ModelType::Classifier => {
-                candle::bail!("`classifier` model type is not supported for Jina")
+                let pool = Pool::Cls;
+                let classifier: Box<dyn ClassificationHead + Send> =
+                    Box::new(JinaBertClassificationHead::load(vb.clone(), config)?);
+                (pool, Some(classifier))
             }
             ModelType::Embedding(pool) => {
                 if pool == Pool::Splade {
@@ -377,7 +438,7 @@ impl JinaBertModel {
                 if pool == Pool::LastToken {
                     candle::bail!("`last_token` is not supported for Jina");
                 }
-                pool
+                (pool, None)
             }
         };
 
@@ -403,6 +464,7 @@ impl JinaBertModel {
             encoder,
             pool,
             alibi,
+            classifier,
             num_attention_heads: config.num_attention_heads,
             device: vb.device().clone(),
             dtype: vb.dtype(),
@@ -667,7 +729,20 @@ impl Model for JinaBertModel {
     fn is_padded(&self) -> bool {
         true
     }
+
     fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
         self.forward(batch)
     }
+
+    fn predict(&self, batch: Batch) -> Result<Tensor> {
+        match &self.classifier {
+            None => candle::bail!("`predict` is not implemented for this model"),
+            Some(classifier) => {
+                let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
+                let pooled_embeddings =
+                    pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
+                classifier.forward(&pooled_embeddings)
+            }
+        }
+    }
 }
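`JinaBertClassificationHead::forward` is an optional BERT-style pooler (a dense layer followed by `tanh`, loaded only when `bert.pooler.dense` weights exist in the checkpoint) and a final `classifier` projection to one logit per `id2label` entry. A numeric sketch of the same computation with plain `Vec` math and made-up shapes; this stands in for the candle tensor ops, it is not the implementation above:

```rust
// logits = W_out * tanh(W_pool * x + b_pool) + b_out, with the pooler optional.
fn matvec(w: &[Vec<f32>], x: &[f32], b: &[f32]) -> Vec<f32> {
    w.iter()
        .zip(b)
        .map(|(row, bi)| row.iter().zip(x).map(|(wi, xi)| wi * xi).sum::<f32>() + bi)
        .collect()
}

fn head_forward(
    pooled: &[f32],
    pooler: Option<(&[Vec<f32>], &[f32])>, // (weight, bias), hidden x hidden
    out_w: &[Vec<f32>],                    // n_classes x hidden
    out_b: &[f32],                         // n_classes
) -> Vec<f32> {
    let hidden: Vec<f32> = match pooler {
        // BERT-style pooler: dense followed by tanh
        Some((w, b)) => matvec(w, pooled, b).iter().map(|v| v.tanh()).collect(),
        None => pooled.to_vec(),
    };
    matvec(out_w, &hidden, out_b) // raw logits, one per class in `id2label`
}

fn main() {
    let pooler_w = vec![vec![1.0, 0.0], vec![0.0, 1.0]]; // identity, for the demo
    let pooler_b = vec![0.0, 0.0];
    let out_w = vec![vec![0.5, -0.5]]; // a single class, as for a reranker score
    let out_b = vec![0.0];
    let logits = head_forward(&[0.3, -0.3], Some((&pooler_w, &pooler_b)), &out_w, &out_b);
    println!("{logits:?}"); // one relevance score
}
```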
diff --git a/backends/candle/src/models/jina_code.rs b/backends/candle/src/models/jina_code.rs
index 5f13fe08..8dadea00 100644
--- a/backends/candle/src/models/jina_code.rs
+++ b/backends/candle/src/models/jina_code.rs
@@ -656,6 +656,7 @@ impl Model for JinaCodeBertModel {
     fn is_padded(&self) -> bool {
         true
     }
+
     fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
         self.forward(batch)
     }
diff --git a/backends/candle/src/models/mod.rs b/backends/candle/src/models/mod.rs
index a5127da6..74db7d84 100644
--- a/backends/candle/src/models/mod.rs
+++ b/backends/candle/src/models/mod.rs
@@ -92,6 +92,6 @@ pub(crate) trait Model {
     }
 
     fn predict(&self, _batch: Batch) -> Result<Tensor> {
-        candle::bail!("`predict is not implemented for this model");
+        candle::bail!("`predict` is not implemented for this model");
     }
 }
diff --git a/backends/candle/tests/snapshots/test_jina__jinabert_reranker_single.snap b/backends/candle/tests/snapshots/test_jina__jinabert_reranker_single.snap
new file mode 100644
index 00000000..b84fbc28
--- /dev/null
+++ b/backends/candle/tests/snapshots/test_jina__jinabert_reranker_single.snap
@@ -0,0 +1,5 @@
+---
+source: backends/candle/tests/test_jina.rs
+expression: predictions
+---
+- - -0.6045344
diff --git a/backends/candle/tests/test_jina.rs b/backends/candle/tests/test_jina.rs
index ae162368..9ea15b50 100644
--- a/backends/candle/tests/test_jina.rs
+++ b/backends/candle/tests/test_jina.rs
@@ -1,8 +1,8 @@
 mod common;
 
-use crate::common::{sort_embeddings, SnapshotEmbeddings};
+use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores};
 use anyhow::Result;
-use common::{batch, cosine_matcher, download_artifacts, load_tokenizer};
+use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher};
 use text_embeddings_backend_candle::CandleBackend;
 use text_embeddings_backend_core::{Backend, ModelType, Pool};
 
@@ -48,3 +48,29 @@ fn test_jina_small() -> Result<()> {
 
     Ok(())
 }
+
+#[test]
+#[serial_test::serial]
+fn test_jina_rerank() -> Result<()> {
+    let model_root = download_artifacts("jinaai/jina-reranker-v1-tiny-en", Some("refs/pr/11"))?;
+    let tokenizer = load_tokenizer(&model_root)?;
+
+    let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?;
+
+    let input_single = batch(
+        vec![tokenizer.encode("What is Deep Learning?", true).unwrap()],
+        [0].to_vec(),
+        vec![],
+    );
+
+    let predictions: Vec<Vec<f32>> = backend
+        .predict(input_single)?
+        .into_iter()
+        .map(|(_, v)| v)
+        .collect();
+
+    let predictions = SnapshotScores::from(predictions);
+    insta::assert_yaml_snapshot!("jinabert_reranker_single", predictions, &relative_matcher());
+
+    Ok(())
+}
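The `Classifier` path pins pooling to `Pool::Cls`, so the head always receives exactly one vector per sequence: the hidden state of the first ([CLS]) token. A toy illustration of that pooling step, with invented shapes and values:

```rust
// CLS pooling: out of seq_len per-token hidden states, keep only row 0.
fn cls_pool(hidden_states: &[Vec<f32>]) -> &[f32] {
    // hidden_states: seq_len x hidden_size; row 0 is the [CLS] token
    &hidden_states[0]
}

fn main() {
    let hidden_states = vec![
        vec![0.1, 0.9], // [CLS]
        vec![0.4, 0.2], // "What"
        vec![0.7, 0.5], // "is"
    ];
    let pooled = cls_pool(&hidden_states);
    assert_eq!(pooled, &[0.1, 0.9][..]); // this vector feeds the classification head
}
```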
diff --git a/backends/ort/src/lib.rs b/backends/ort/src/lib.rs
index a0cc19b7..4e6cea44 100644
--- a/backends/ort/src/lib.rs
+++ b/backends/ort/src/lib.rs
@@ -21,8 +21,7 @@ impl OrtBackend {
         model_type: ModelType,
     ) -> Result<Self, BackendError> {
         // Check dtype
-        if dtype == "float32" {
-        } else {
+        if dtype != "float32" {
             return Err(BackendError::Start(format!(
                 "DType {dtype} is not supported"
             )));
@@ -167,8 +166,8 @@ impl Backend for OrtBackend {
         // Run model
         let outputs = self.session.run(inputs).e()?;
-        // Get last_hidden_state ndarray
 
+        // Get last_hidden_state ndarray
         let outputs = outputs
             .get("last_hidden_state")
             .or(outputs.get("token_embeddings"))
@@ -362,6 +361,7 @@ impl Backend for OrtBackend {
         // Run model
         let outputs = self.session.run(inputs).e()?;
 
+        // Get logits ndarray
         let outputs = outputs["logits"]
             .try_extract_tensor::<f32>()
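The `OrtBackend::new` change is a pure refactor: the empty `if` arm plus `else` collapses into a single negated guard clause. A minimal sketch of the same shape, with `BackendError` stubbed locally for the demo:

```rust
#[derive(Debug)]
enum BackendError {
    Start(String),
}

fn check_dtype(dtype: &str) -> Result<(), BackendError> {
    // Before: `if dtype == "float32" { } else { return Err(...) }`
    // After: one negated guard, no empty arm.
    if dtype != "float32" {
        return Err(BackendError::Start(format!("DType {dtype} is not supported")));
    }
    Ok(())
}

fn main() {
    assert!(check_dtype("float32").is_ok());
    assert!(check_dtype("float16").is_err());
}
```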