Skip to content

Add support for JinaAI Re-rankers #582

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
28 changes: 24 additions & 4 deletions backends/candle/src/models/flash_jina.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::alibi::alibi_head_slopes;
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{HiddenAct, LayerNorm, Linear};
use crate::models::bert::PositionEmbeddingType;
use crate::models::jina::JinaEmbeddings;
use crate::models::jina::{ClassificationHead, JinaBertClassificationHead, JinaEmbeddings};
use crate::models::{BertConfig, Model};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::VarBuilder;
Expand Down Expand Up @@ -227,6 +227,8 @@ pub struct FlashJinaBertModel {
embeddings: JinaEmbeddings,
encoder: JinaBertEncoder,
pool: Pool,
classifier: Option<Box<dyn ClassificationHead + Send>>,

pub device: Device,

span: tracing::Span,
Expand Down Expand Up @@ -255,15 +257,19 @@ impl FlashJinaBertModel {
candle::bail!("FlashJinaBertModel requires DType::F16")
}

let pool = match model_type {
let (pool, classifier) = match model_type {
ModelType::Classifier => {
candle::bail!("`classifier` model type is not supported for Jina")
let pool = Pool::Cls;

let classifier: Box<dyn ClassificationHead + Send> =
Box::new(JinaBertClassificationHead::load(vb.clone(), config)?);
(pool, Some(classifier))
}
ModelType::Embedding(pool) => {
if pool == Pool::Splade {
candle::bail!("`splade` is not supported for Jina")
}
pool
(pool, None)
}
};

Expand All @@ -288,6 +294,7 @@ impl FlashJinaBertModel {
embeddings,
encoder,
pool,
classifier,
device: vb.device().clone(),
span: tracing::span!(tracing::Level::TRACE, "model"),
})
Expand Down Expand Up @@ -433,7 +440,20 @@ impl Model for FlashJinaBertModel {
// Flash-attention varlen kernels pack sequences back-to-back, so batches
// handed to this model are never padded.
fn is_padded(&self) -> bool {
false
}

// Embedding entry point: delegates to `forward`, which returns the
// (pooled_embeddings, raw_embeddings) pair requested by the batch.
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
self.forward(batch)
}

/// Classification entry point: runs the encoder forward pass, then projects
/// the pooled embeddings through the classification head to per-class logits.
/// Bails for embedding-only models (no classifier head was loaded).
fn predict(&self, batch: Batch) -> Result<Tensor> {
    let Some(classifier) = self.classifier.as_ref() else {
        candle::bail!("`predict` is not implemented for this model")
    };
    let (pooled, _raw) = self.forward(batch)?;
    // `ModelType::Classifier` forces CLS pooling, so pooled output must exist.
    let pooled = pooled.expect("pooled_embeddings is empty. This is a bug.");
    classifier.forward(&pooled)
}
}
81 changes: 78 additions & 3 deletions backends/candle/src/models/jina.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,11 +339,69 @@ impl JinaBertEncoder {
}
}

/// Model-specific classification head: maps pooled hidden states to
/// per-class logits (e.g. relevance scores for a re-ranker).
pub trait ClassificationHead {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor>;
}

/// Sequence-classification head for Jina BERT checkpoints (JinaAI re-rankers):
/// optional BERT-style pooler (dense + tanh) followed by a linear projection
/// to `id2label.len()` classes.
pub struct JinaBertClassificationHead {
// `bert.pooler.dense` weights; not all checkpoints ship a pooler.
pooler: Option<Linear>,
// `classifier` projection to the output logits.
output: Linear,
span: tracing::Span,
}

impl JinaBertClassificationHead {
/// Loads the classification head from `vb`.
///
/// Errors if `config.id2label` is unset (it determines the number of output
/// classes) or if the `classifier` weights are missing from the checkpoint.
/// The pooler is optional: it is only loaded when `bert.pooler.dense`
/// exists, so checkpoints without a pooler still load successfully.
pub(crate) fn load(vb: VarBuilder, config: &BertConfig) -> Result<Self> {
let n_classes = match &config.id2label {
None => candle::bail!("`id2label` must be set for classifier models"),
Some(id2label) => id2label.len(),
};

// Probe for an optional pooler: a failed weight lookup means "no pooler",
// but once the weight is present a missing bias is a hard error (`?`).
let pooler = if let Ok(pooler_weight) = vb
.pp("bert.pooler.dense")
.get((config.hidden_size, config.hidden_size), "weight")
{
let pooler_bias = vb.pp("bert.pooler.dense").get(config.hidden_size, "bias")?;
Some(Linear::new(pooler_weight, Some(pooler_bias), None))
} else {
None
};

// The output projection is mandatory: (n_classes, hidden_size) + bias.
let output_weight = vb
.pp("classifier")
.get((n_classes, config.hidden_size), "weight")?;
let output_bias = vb.pp("classifier").get(n_classes, "bias")?;
let output = Linear::new(output_weight, Some(output_bias), None);

Ok(Self {
pooler,
output,
span: tracing::span!(tracing::Level::TRACE, "classifier"),
})
}
}

impl ClassificationHead for JinaBertClassificationHead {
// Projects pooled embeddings to logits: optional (dense + tanh) pooler,
// then the output linear layer.
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();

// Insert a singleton dim at axis 1 before the linear layers, removed again
// below — presumably to match the rank this crate's `Linear` expects;
// NOTE(review): confirm against `Linear::forward`'s input-rank handling.
let mut hidden_states = hidden_states.unsqueeze(1)?;
if let Some(pooler) = self.pooler.as_ref() {
hidden_states = pooler.forward(&hidden_states)?;
hidden_states = hidden_states.tanh()?;
}

let hidden_states = self.output.forward(&hidden_states)?;
let hidden_states = hidden_states.squeeze(1)?;
Ok(hidden_states)
}
}

pub struct JinaBertModel {
embeddings: JinaEmbeddings,
encoder: JinaBertEncoder,
pool: Pool,
alibi: Option<Tensor>,
classifier: Option<Box<dyn ClassificationHead + Send>>,

num_attention_heads: usize,

Expand All @@ -366,9 +424,12 @@ impl JinaBertModel {
_ => candle::bail!("not supported"),
};

let pool = match model_type {
let (pool, classifier) = match model_type {
ModelType::Classifier => {
candle::bail!("`classifier` model type is not supported for Jina")
let pool = Pool::Cls;
let classifier: Box<dyn ClassificationHead + Send> =
Box::new(JinaBertClassificationHead::load(vb.clone(), config)?);
(pool, Some(classifier))
}
ModelType::Embedding(pool) => {
if pool == Pool::Splade {
Expand All @@ -377,7 +438,7 @@ impl JinaBertModel {
if pool == Pool::LastToken {
candle::bail!("`last_token` is not supported for Jina");
}
pool
(pool, None)
}
};

Expand All @@ -403,6 +464,7 @@ impl JinaBertModel {
encoder,
pool,
alibi,
classifier,
num_attention_heads: config.num_attention_heads,
device: vb.device().clone(),
dtype: vb.dtype(),
Expand Down Expand Up @@ -667,7 +729,20 @@ impl Model for JinaBertModel {
// Non-flash path processes rectangular batches, so inputs are padded.
fn is_padded(&self) -> bool {
true
}

// Embedding entry point: delegates to `forward`, which returns the
// (pooled_embeddings, raw_embeddings) pair requested by the batch.
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
self.forward(batch)
}

// Classification entry point: forward pass, then the classifier head maps
// pooled embeddings to logits. Bails when no classifier head was loaded
// (i.e. the model was constructed as `ModelType::Embedding`).
fn predict(&self, batch: Batch) -> Result<Tensor> {
match &self.classifier {
None => candle::bail!("`predict` is not implemented for this model"),
Some(classifier) => {
let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
// Classifier models force CLS pooling, so pooled output must exist.
let pooled_embeddings =
pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
classifier.forward(&pooled_embeddings)
}
}
}
}
1 change: 1 addition & 0 deletions backends/candle/src/models/jina_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,7 @@ impl Model for JinaCodeBertModel {
// JinaCode model processes rectangular batches, so inputs are padded.
fn is_padded(&self) -> bool {
true
}

// Embedding entry point: delegates to `forward`, which returns the
// (pooled_embeddings, raw_embeddings) pair requested by the batch.
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
self.forward(batch)
}
Expand Down
2 changes: 1 addition & 1 deletion backends/candle/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,6 @@ pub(crate) trait Model {
}

fn predict(&self, _batch: Batch) -> Result<Tensor> {
candle::bail!("`predict is not implemented for this model");
candle::bail!("`predict` is not implemented for this model");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
source: backends/candle/tests/test_jina.rs
expression: predictions
---
- - -0.6045344
30 changes: 28 additions & 2 deletions backends/candle/tests/test_jina.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
mod common;

use crate::common::{sort_embeddings, SnapshotEmbeddings};
use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores};
use anyhow::Result;
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer};
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher};
use text_embeddings_backend_candle::CandleBackend;
use text_embeddings_backend_core::{Backend, ModelType, Pool};

Expand Down Expand Up @@ -48,3 +48,29 @@ fn test_jina_small() -> Result<()> {

Ok(())
}

// Snapshot test for the JinaAI re-ranker classifier path: loads
// jina-reranker-v1-tiny-en, runs `predict` on a single query, and compares
// the resulting score against the stored insta snapshot.
#[test]
#[serial_test::serial]
fn test_jina_rerank() -> Result<()> {
// "refs/pr/11" pins a specific hub revision so the snapshot stays stable.
let model_root = download_artifacts("jinaai/jina-reranker-v1-tiny-en", Some("refs/pr/11"))?;
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?;

// Single-example batch: one encoded query, cumulative offset [0].
let input_single = batch(
vec![tokenizer.encode("What is Deep Learning?", true).unwrap()],
[0].to_vec(),
vec![],
);

// `predict` yields (index, scores) pairs; keep only the score vectors.
let predictions: Vec<Vec<f32>> = backend
.predict(input_single)?
.into_iter()
.map(|(_, v)| v)
.collect();

let predictions = SnapshotScores::from(predictions);
insta::assert_yaml_snapshot!("jinabert_reranker_single", predictions, &relative_matcher());

Ok(())
}
6 changes: 3 additions & 3 deletions backends/ort/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ impl OrtBackend {
model_type: ModelType,
) -> Result<Self, BackendError> {
// Check dtype
if dtype == "float32" {
} else {
if dtype != "float32" {
return Err(BackendError::Start(format!(
"DType {dtype} is not supported"
)));
Expand Down Expand Up @@ -167,8 +166,8 @@ impl Backend for OrtBackend {

// Run model
let outputs = self.session.run(inputs).e()?;
// Get last_hidden_state ndarray

// Get last_hidden_state ndarray
let outputs = outputs
.get("last_hidden_state")
.or(outputs.get("token_embeddings"))
Expand Down Expand Up @@ -362,6 +361,7 @@ impl Backend for OrtBackend {

// Run model
let outputs = self.session.run(inputs).e()?;

// Get last_hidden_state ndarray
let outputs = outputs["logits"]
.try_extract_tensor::<f32>()
Expand Down
Loading