diff --git a/backends/candle/Cargo.toml b/backends/candle/Cargo.toml index f26d213a..60931fb6 100644 --- a/backends/candle/Cargo.toml +++ b/backends/candle/Cargo.toml @@ -25,13 +25,13 @@ thiserror = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } memmap2 = "^0.9" +tokenizers = { workspace = true } [dev-dependencies] insta = { git = "https://github.com/OlivierDehaene/insta", rev = "f4f98c0410b91fb5a28b10df98e4422955be9c2c", features = ["yaml"] } is_close = "0.1.3" hf-hub = "0.3.2" anyhow = { workspace = true } -tokenizers = { workspace = true } serial_test = "2.0.0" [build-dependencies] diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index acd93750..d6bdfd26 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -24,11 +24,12 @@ use candle::{DType, Device}; use candle_nn::VarBuilder; use nohash_hasher::BuildNoHashHasher; use serde::Deserialize; -use std::collections::HashMap; use std::path::PathBuf; +use std::{cmp::max, collections::HashMap}; use text_embeddings_backend_core::{ Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions, }; +use tokenizers::Encoding; /// This enum is needed to be able to differentiate between jina models that also use /// the `bert` model type and valid Bert models. @@ -465,3 +466,48 @@ impl WrapErr for Result { self.map_err(|e| BackendError::Inference(e.to_string())) } } + +pub fn batch(encodings: Vec, pooled_indices: Vec, raw_indices: Vec) -> Batch { + let mut input_ids = Vec::new(); + let mut token_type_ids = Vec::new(); + let mut position_ids = Vec::new(); + let mut cumulative_seq_lengths = Vec::with_capacity(encodings.len() + 1); + cumulative_seq_lengths.push(0); + + let mut max_length = 0; + let mut cumulative_length = 0; + + for encoding in encodings.iter() { + let encoding_length = encoding.len() as u32; + input_ids.extend(encoding.get_ids().to_vec()); + token_type_ids.extend(encoding.get_type_ids().to_vec()); + position_ids.extend(0..encoding_length); + cumulative_length += encoding_length; + cumulative_seq_lengths.push(cumulative_length); + max_length = max(max_length, encoding_length); + } + + Batch { + input_ids, + token_type_ids, + position_ids, + cumulative_seq_lengths, + max_length, + pooled_indices, + raw_indices, + } +} + +pub fn sort_embeddings(embeddings: Embeddings) -> (Vec>, Vec>) { + let mut pooled_embeddings = Vec::new(); + let mut raw_embeddings = Vec::new(); + + for (_, embedding) in embeddings { + match embedding { + Embedding::Pooled(e) => pooled_embeddings.push(e), + Embedding::All(e) => raw_embeddings.extend(e), + } + } + + (pooled_embeddings, raw_embeddings) +} diff --git a/backends/candle/tests/common.rs b/backends/candle/tests/common.rs index 0c069a47..e848212d 100644 --- a/backends/candle/tests/common.rs +++ b/backends/candle/tests/common.rs @@ -89,20 +89,6 @@ impl From>> for SnapshotEmbeddings { } } -pub fn sort_embeddings(embeddings: Embeddings) -> (Vec>, Vec>) { - let mut pooled_embeddings = Vec::new(); - let mut raw_embeddings = Vec::new(); - - for (_, embedding) in embeddings { - match embedding { - Embedding::Pooled(e) => pooled_embeddings.push(e), - Embedding::All(e) => raw_embeddings.extend(e), - } - } - - (pooled_embeddings, raw_embeddings) -} - pub fn download_artifacts( model_id: &'static str, revision: Option<&'static str>, @@ -232,34 +218,3 @@ pub fn load_tokenizer(model_root: &Path) -> Result { tokenizer.with_padding(None); Ok(tokenizer) } - -pub fn batch(encodings: Vec, pooled_indices: Vec, raw_indices: Vec) -> Batch { - let mut input_ids = Vec::new(); - let mut token_type_ids = Vec::new(); - let mut position_ids = Vec::new(); - let mut cumulative_seq_lengths = Vec::with_capacity(encodings.len() + 1); - cumulative_seq_lengths.push(0); - - let mut max_length = 0; - let mut cumulative_length = 0; - - for encoding in encodings.iter() { - let encoding_length = encoding.len() as u32; - input_ids.extend(encoding.get_ids().to_vec()); - token_type_ids.extend(encoding.get_type_ids().to_vec()); - position_ids.extend(0..encoding_length); - cumulative_length += encoding_length; - cumulative_seq_lengths.push(cumulative_length); - max_length = max(max_length, encoding_length); - } - - Batch { - input_ids, - token_type_ids, - position_ids, - cumulative_seq_lengths, - max_length, - pooled_indices, - raw_indices, - } -} diff --git a/backends/candle/tests/test_bert.rs b/backends/candle/tests/test_bert.rs index 1bd5017f..0598476b 100644 --- a/backends/candle/tests/test_bert.rs +++ b/backends/candle/tests/test_bert.rs @@ -2,8 +2,8 @@ mod common; use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores}; use anyhow::Result; -use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher}; -use text_embeddings_backend_candle::CandleBackend; +use common::{cosine_matcher, download_artifacts, load_tokenizer, relative_matcher}; +use text_embeddings_backend_candle::{batch, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; #[test] diff --git a/backends/candle/tests/test_flash_bert.rs b/backends/candle/tests/test_flash_bert.rs index ea150e7f..5c8cda74 100644 --- a/backends/candle/tests/test_flash_bert.rs +++ b/backends/candle/tests/test_flash_bert.rs @@ -2,10 +2,10 @@ mod common; -use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores}; +use crate::common::{SnapshotEmbeddings, SnapshotScores}; use anyhow::Result; -use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher}; -use text_embeddings_backend_candle::CandleBackend; +use common::{cosine_matcher, download_artifacts, load_tokenizer, relative_matcher}; +use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; #[test] diff --git a/backends/candle/tests/test_flash_gte.rs b/backends/candle/tests/test_flash_gte.rs index 20b06b2f..4fce7de6 100644 --- a/backends/candle/tests/test_flash_gte.rs +++ b/backends/candle/tests/test_flash_gte.rs @@ -1,10 +1,10 @@ #![allow(dead_code, unused_imports)] mod common; -use crate::common::{sort_embeddings, SnapshotEmbeddings}; +use crate::common::SnapshotEmbeddings; use anyhow::Result; -use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; -use text_embeddings_backend_candle::CandleBackend; +use common::{cosine_matcher, download_artifacts, load_tokenizer}; +use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; #[test] diff --git a/backends/candle/tests/test_flash_jina.rs b/backends/candle/tests/test_flash_jina.rs index 255b82a2..fbd0352c 100644 --- a/backends/candle/tests/test_flash_jina.rs +++ b/backends/candle/tests/test_flash_jina.rs @@ -1,10 +1,10 @@ #![allow(dead_code, unused_imports)] mod common; -use crate::common::{sort_embeddings, SnapshotEmbeddings}; +use crate::common::SnapshotEmbeddings; use anyhow::Result; -use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; -use text_embeddings_backend_candle::CandleBackend; +use common::{cosine_matcher, download_artifacts, load_tokenizer}; +use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; #[test] diff --git a/backends/candle/tests/test_flash_jina_code.rs b/backends/candle/tests/test_flash_jina_code.rs index d84848dc..71ec6bac 100644 --- a/backends/candle/tests/test_flash_jina_code.rs +++ b/backends/candle/tests/test_flash_jina_code.rs @@ -1,10 +1,10 @@ #![allow(dead_code, unused_imports)] mod common; -use crate::common::{sort_embeddings, SnapshotEmbeddings}; +use crate::common::SnapshotEmbeddings; use anyhow::Result; -use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; -use text_embeddings_backend_candle::CandleBackend; +use common::{cosine_matcher, download_artifacts, load_tokenizer}; +use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; #[test] diff --git a/backends/candle/tests/test_flash_mistral.rs b/backends/candle/tests/test_flash_mistral.rs index 71749c8b..a472db59 100644 --- a/backends/candle/tests/test_flash_mistral.rs +++ b/backends/candle/tests/test_flash_mistral.rs @@ -1,10 +1,10 @@ #![allow(dead_code, unused_imports)] mod common; -use crate::common::{sort_embeddings, SnapshotEmbeddings}; +use crate::common::SnapshotEmbeddings; use anyhow::Result; -use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; -use text_embeddings_backend_candle::CandleBackend; +use common::{cosine_matcher, download_artifacts, load_tokenizer}; +use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; #[test] diff --git a/backends/candle/tests/test_flash_nomic.rs b/backends/candle/tests/test_flash_nomic.rs index 263bbe43..29f7314f 100644 --- a/backends/candle/tests/test_flash_nomic.rs +++ b/backends/candle/tests/test_flash_nomic.rs @@ -1,10 +1,10 @@ #![allow(dead_code, unused_imports)] mod common; -use crate::common::{sort_embeddings, SnapshotEmbeddings}; +use crate::common::SnapshotEmbeddings; use anyhow::Result; -use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; -use text_embeddings_backend_candle::CandleBackend; +use common::{cosine_matcher, download_artifacts, load_tokenizer}; +use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; #[test] diff --git a/backends/candle/tests/test_flash_qwen2.rs b/backends/candle/tests/test_flash_qwen2.rs index 38e45553..35ae9aea 100644 --- a/backends/candle/tests/test_flash_qwen2.rs +++ b/backends/candle/tests/test_flash_qwen2.rs @@ -2,10 +2,10 @@ mod common; -use crate::common::{sort_embeddings, SnapshotEmbeddings}; +use crate::common::SnapshotEmbeddings; use anyhow::Result; -use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; -use text_embeddings_backend_candle::CandleBackend; +use common::{cosine_matcher, download_artifacts, load_tokenizer}; +use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; use tokenizers::processors::sequence::Sequence; use tokenizers::processors::template::TemplateProcessing; diff --git a/backends/candle/tests/test_jina.rs b/backends/candle/tests/test_jina.rs index 4aa30d03..f516df77 100644 --- a/backends/candle/tests/test_jina.rs +++ b/backends/candle/tests/test_jina.rs @@ -1,9 +1,9 @@ mod common; -use crate::common::{sort_embeddings, SnapshotEmbeddings}; +use crate::common::SnapshotEmbeddings; use anyhow::Result; -use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; -use text_embeddings_backend_candle::CandleBackend; +use common::{cosine_matcher, download_artifacts, load_tokenizer}; +use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; #[test] diff --git a/backends/candle/tests/test_jina_code.rs b/backends/candle/tests/test_jina_code.rs index 6c3b3f20..edbd6c7a 100644 --- a/backends/candle/tests/test_jina_code.rs +++ b/backends/candle/tests/test_jina_code.rs @@ -1,9 +1,9 @@ mod common; -use crate::common::{sort_embeddings, SnapshotEmbeddings}; +use crate::common::SnapshotEmbeddings; use anyhow::Result; -use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; -use text_embeddings_backend_candle::CandleBackend; +use common::{cosine_matcher, download_artifacts, load_tokenizer}; +use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; #[test] diff --git a/backends/candle/tests/test_nomic.rs b/backends/candle/tests/test_nomic.rs index ce0a4559..a0e3e072 100644 --- a/backends/candle/tests/test_nomic.rs +++ b/backends/candle/tests/test_nomic.rs @@ -1,9 +1,9 @@ mod common; -use crate::common::{sort_embeddings, SnapshotEmbeddings}; +use crate::common::SnapshotEmbeddings; use anyhow::Result; -use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; -use text_embeddings_backend_candle::CandleBackend; +use common::{cosine_matcher, download_artifacts, load_tokenizer}; +use text_embeddings_backend_candle::{batch, sort_embeddings, CandleBackend}; use text_embeddings_backend_core::{Backend, ModelType, Pool}; #[test] diff --git a/backends/src/dtype.rs b/backends/src/dtype.rs index d2c896ce..01c0b60f 100644 --- a/backends/src/dtype.rs +++ b/backends/src/dtype.rs @@ -33,6 +33,7 @@ impl fmt::Display for DType { DType::Float32 => write!(f, "float32"), // #[cfg(feature = "candle")] // DType::Q6K => write!(f, "q6k"), + _ => unreachable!() } } }