perf: calculate the max scores for posting lists then have a tighter upper bound (#2763)
BubbleCal authored Aug 22, 2024
1 parent f7cc676 commit a7d2f6a
Showing 3 changed files with 79 additions and 26 deletions.
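Context for the change: WAND skips documents by comparing each term's score upper bound against the current top-k threshold. Previously the bound for every posting list was the generic BM25 limit idf * (K1 + 1), which no real document may come close to; this commit precomputes each list's actual maximum BM25 score at build time and uses that as a tighter bound. Because the per-document norm depends on document length, the highest-frequency document is not necessarily the top scorer, which is why calculate_max_score below scans the whole list. A minimal sketch contrasting the two bounds (not code from this commit; K1, B, and the idf form are assumptions modeled on the BM25 comment in the diff):

// A minimal sketch, not code from this commit: contrasts the generic BM25
// bound with the exact per-list maximum that this change precomputes.
// K1, B, and the idf form are assumptions modeled on the comment in
// calculate_max_score below.
const K1: f32 = 1.2;
const B: f32 = 0.75;

fn idf(posting_len: usize, num_docs: usize) -> f32 {
    let n = posting_len as f32;
    ((num_docs as f32 - n + 0.5) / (n + 0.5) + 1.0).ln()
}

fn main() {
    let num_docs = 1_000;
    let avgdl = 50.0_f32;
    // (term frequency, document length) for each doc in one posting list
    let postings = [(1.0_f32, 40_usize), (2.0, 120), (3.0, 800)];

    // Old bound: freq / (freq + norm) < 1, so score < idf * (K1 + 1) always.
    let generic = idf(postings.len(), num_docs) * (K1 + 1.0);

    // New bound: the largest score any document in this list can actually get.
    let exact = postings
        .iter()
        .map(|&(freq, len)| {
            let norm = K1 * (1.0 - B + B * len as f32 / avgdl);
            idf(postings.len(), num_docs) * (K1 + 1.0) * freq / (freq + norm)
        })
        .fold(0.0_f32, f32::max);

    assert!(exact < generic); // the precomputed bound is strictly tighter
    println!("generic = {generic:.3}, exact max = {exact:.3}");
}

Note that in this sample data the maximum comes from the frequency-1 term in a short document, not the frequency-3 term in a long one, illustrating why the scan cannot stop at the highest frequency.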
11 changes: 11 additions & 0 deletions rust/lance-index/src/scalar/inverted/builder.rs
@@ -17,6 +17,7 @@ use futures::TryStreamExt;
 use itertools::Itertools;
 use lance_arrow::{iter_str_array, RecordBatchExt};
 use lance_core::{Error, Result, ROW_ID};
+use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
 use snafu::{location, Location};
 
 use super::index::*;
@@ -68,7 +69,17 @@ impl InvertedIndexBuilder {
         token_set_writer.write_record_batch(token_set_batch).await?;
         token_set_writer.finish().await?;
 
+        // calculate the max BM25 score for each posting list
+        let max_scores = self
+            .invert_list
+            .inverted_list
+            .par_iter()
+            .map(|list| list.calculate_max_score(&self.docs))
+            .collect::<Vec<_>>();
+        let max_scores = serde_json::to_string(&max_scores)?;
         let invert_list_batch = self.invert_list.to_batch()?;
+        let invert_list_batch =
+            invert_list_batch.add_metadata("max_scores".to_owned(), max_scores)?;
         let mut invert_list_writer = dest_store
             .new_index_file(INVERT_LIST_FILE, invert_list_batch.schema())
             .await?;
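The scores travel with the index file as schema metadata (a JSON array under the key "max_scores") rather than as a new column, so the file layout is unchanged and files written before this change simply lack the key. A hypothetical round trip with plain serde_json, using a HashMap to stand in for the file metadata:

// Hypothetical round trip of the "max_scores" metadata entry; a HashMap
// stands in for the index file's schema metadata.
use std::collections::HashMap;

fn main() -> Result<(), serde_json::Error> {
    let max_scores = vec![1.25_f32, 0.87, 2.04];

    let mut metadata = HashMap::new();
    metadata.insert("max_scores".to_owned(), serde_json::to_string(&max_scores)?);

    // Mirrors InvertedListReader::new below: a missing key (an index built
    // before this change) yields None and the old generic bound is used.
    let decoded: Option<Vec<f32>> = match metadata.get("max_scores") {
        Some(s) => serde_json::from_str(s)?,
        None => None,
    };
    assert_eq!(decoded, Some(max_scores));
    Ok(())
}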
39 changes: 37 additions & 2 deletions rust/lance-index/src/scalar/inverted/index.rs
@@ -143,7 +143,7 @@ impl InvertedIndex {
             .try_collect::<Vec<_>>()
             .await?;
 
-        let mut wand = Wand::new(postings.into_iter());
+        let mut wand = Wand::new(self.docs.len(), postings.into_iter());
         wand.search(limit, wand_factor, |doc, freq| {
             let doc_norm =
                 K1 * (1.0 - B + B * self.docs.num_tokens(doc) as f32 / self.docs.average_length());
@@ -349,6 +349,7 @@ impl TokenSet {
 struct InvertedListReader {
     reader: Arc<dyn IndexReader>,
     offsets: Vec<usize>,
+    max_scores: Option<Vec<f32>>,
 
     // cache
     posting_cache: Cache<u32, PostingList>,
@@ -379,13 +380,19 @@ impl InvertedListReader {
             .ok_or_else(|| Error::io("offsets not found".to_string(), location!()))?;
         let offsets: Vec<usize> = serde_json::from_str(offsets)?;
 
+        let max_scores = match reader.schema().metadata.get("max_scores") {
+            Some(max_scores) => serde_json::from_str(max_scores)?,
+            None => None,
+        };
+
         let cache = Cache::builder()
             .max_capacity(*CACHE_SIZE as u64)
             .weigher(|_, posting: &PostingList| posting.deep_size_of() as u32)
             .build();
         Ok(Self {
             reader,
             offsets,
+            max_scores,
             posting_cache: cache,
         })
     }
@@ -413,6 +420,9 @@ impl InvertedListReader {
                 Result::Ok(PostingList::new(
                     row_ids.values().clone(),
                     frequencies.values().clone(),
+                    self.max_scores
+                        .as_ref()
+                        .map(|max_scores| max_scores[token_id]),
                 ))
             })
             .await
@@ -424,6 +434,7 @@
 pub struct PostingList {
     pub row_ids: ScalarBuffer<u64>,
     pub frequencies: ScalarBuffer<f32>,
+    pub max_score: Option<f32>,
 }
 
 impl DeepSizeOf for PostingList {
@@ -434,10 +445,15 @@ impl DeepSizeOf for PostingList {
 }
 
 impl PostingList {
-    pub fn new(row_ids: ScalarBuffer<u64>, frequencies: ScalarBuffer<f32>) -> Self {
+    pub fn new(
+        row_ids: ScalarBuffer<u64>,
+        frequencies: ScalarBuffer<f32>,
+        max_score: Option<f32>,
+    ) -> Self {
         Self {
             row_ids,
             frequencies,
+            max_score,
         }
     }
 
@@ -453,6 +469,10 @@
         (self.row_ids[i], self.frequencies[i])
     }
 
+    pub fn max_score(&self) -> Option<f32> {
+        self.max_score
+    }
+
     pub fn row_id(&self, i: usize) -> u64 {
         self.row_ids[i]
     }
@@ -480,6 +500,21 @@ impl PostingListBuilder {
         self.len() == 0
     }
 
+    pub fn calculate_max_score(&self, docs: &DocSet) -> f32 {
+        // score(q, D) = IDF(q) * (f(q, D) * (k1 + 1)) / (f(q, D) + k1 * (1 - b + b * |D| / avgdl))
+        let num_docs = docs.len();
+        let avgdl = docs.average_length();
+        let mut max_score = 0.0;
+        for (&row_id, &freq) in self.row_ids.iter().zip(self.frequencies.iter()) {
+            let doc_norm = K1 * (1.0 - B + B * docs.num_tokens(row_id) as f32 / avgdl);
+            let score = idf(self.len(), num_docs) * (K1 + 1.0) * freq / (freq + doc_norm);
+            if score > max_score {
+                max_score = score;
+            }
+        }
+        max_score
+    }
+
     pub fn to_batch(&self) -> Result<RecordBatch> {
         let indices = (0..self.row_ids.len())
             .sorted_unstable_by_key(|&i| self.row_ids[i])
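One invariant worth noting: the reader's `max_scores[token_id]` lookup is only correct because the builder serializes the scores in posting-list order, so entry i belongs to token id i (rayon's par_iter preserves input order through collect). A tiny illustrative sketch of that correspondence, with made-up values:

// Sketch of the ordering invariant the reader relies on: the builder writes
// max_scores in posting-list order, so entry i must belong to token id i.
fn main() {
    let max_scores = vec![0.9_f32, 1.4, 0.3]; // as serialized by the builder
    let token_id = 1usize;
    // Mirrors self.max_scores.as_ref().map(|s| s[token_id]) in the reader.
    let bound = Some(&max_scores).map(|s| s[token_id]);
    assert_eq!(bound, Some(1.4));
}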
55 changes: 31 additions & 24 deletions rust/lance-index/src/scalar/inverted/wand.rs
@@ -55,7 +55,10 @@ impl PostingIterator {
         num_doc: usize,
         mask: Arc<RowIdMask>,
     ) -> Self {
-        let approximate_upper_bound = idf(list.len(), num_doc) * (K1 + 1.0);
+        let approximate_upper_bound = match list.max_score() {
+            Some(max_score) => max_score,
+            None => idf(list.len(), num_doc) * (K1 + 1.0),
+        };
         Self {
             token_id,
             list,
@@ -96,15 +99,17 @@ impl PostingIterator {
 pub struct Wand {
     threshold: f32, // multiple of factor and the minimum score of the top-k documents
     cur_doc: Option<u64>,
+    num_docs: usize,
     postings: Vec<PostingIterator>,
     candidates: BinaryHeap<Reverse<OrderedDoc>>,
 }
 
 impl Wand {
-    pub(crate) fn new(postings: impl Iterator<Item = PostingIterator>) -> Self {
+    pub(crate) fn new(num_docs: usize, postings: impl Iterator<Item = PostingIterator>) -> Self {
         Self {
             threshold: 0.0,
             cur_doc: None,
+            num_docs,
             postings: postings.collect(),
             candidates: BinaryHeap::new(),
         }
@@ -152,8 +157,8 @@ impl Wand {
                 break;
             }
             debug_assert!(cur_doc == doc);
-
-            score += posting.approximate_upper_bound() * scorer(doc, freq);
+            let idf = idf(posting.list.len(), self.num_docs);
+            score += idf * (K1 + 1.0) * scorer(doc, freq);
         }
         score
     }
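Worth spelling out for the hunk above: the old exact scorer could multiply by approximate_upper_bound() only because that value was identically idf * (K1 + 1) for every list. Once the bound can be a per-list max score, the two diverge, so the exact BM25 weight is recomputed from idf. A small sketch under the same assumed constants as earlier:

// Sketch, assumed constants and idf form: shows why reusing the (now
// tighter) upper bound inside the exact scorer would under-weight matches.
const K1: f32 = 1.2;

fn idf(posting_len: usize, num_docs: usize) -> f32 {
    let n = posting_len as f32;
    ((num_docs as f32 - n + 0.5) / (n + 0.5) + 1.0).ln()
}

fn main() {
    let (len, n) = (100, 10_000);
    let exact_weight = idf(len, n) * (K1 + 1.0); // what the scorer needs
    let old_bound = exact_weight;                // old approximate_upper_bound
    let new_bound = 0.6 * exact_weight;          // e.g. a per-list max score
    assert_eq!(old_bound, exact_weight);         // reuse used to be safe
    assert!(new_bound < exact_weight);           // reuse would now be wrong
}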
@@ -169,16 +174,8 @@
 
             let cur_doc = self.cur_doc.unwrap_or(0);
             if self.cur_doc.is_some() && doc <= cur_doc {
                 // the pivot doc id is less than the current doc id,
                 // that means this doc id has been processed before, so skip it
-                self.move_terms(cur_doc + 1);
-            } else if self
-                .postings
-                .first()
-                .and_then(|posting| posting.doc().map(|(d, _)| d))
-                .expect("the postings can't be empty")
-                == doc
-            {
+                self.move_term(cur_doc + 1);
+            } else if self.postings[0].doc().unwrap().0 == doc {
                 // all the posting iterators have reached this doc id,
                 // so that means the sum of upper bound of all terms is not less than the threshold,
                 // this document is a candidate
@@ -187,7 +184,7 @@
             } else {
                 // some posting iterators haven't reached this doc id,
                 // so move such terms to the doc id
-                self.move_terms(doc);
+                self.move_term(doc);
             }
         }
         Ok(None)
@@ -210,15 +207,9 @@
     // pick the term that has the maximum upper bound and the current doc id is less than the given doc id
     // so that we can move the posting iterator to the next doc id that is possible to be candidate
     #[instrument(level = "debug", skip_all)]
-    fn move_terms<'b>(&mut self, least_id: u64) {
-        for posting in self.postings.iter_mut() {
-            match posting.doc() {
-                Some((d, _)) if d < least_id => {
-                    posting.next(least_id);
-                }
-                _ => break,
-            }
-        }
+    fn move_term(&mut self, least_id: u64) {
+        let picked = self.pick_term(least_id);
+        self.postings[picked].next(least_id);
 
         self.postings.sort_unstable();
         while let Some(last) = self.postings.last() {
@@ -229,4 +220,20 @@
             }
         }
     }
+
+    fn pick_term(&self, least_id: u64) -> usize {
+        let mut least_length = usize::MAX;
+        let mut pick_index = 0;
+        for (i, posting) in self.postings.iter().enumerate() {
+            let (doc, _) = posting.doc().unwrap();
+            if doc >= least_id {
+                break;
+            }
+            if posting.list.len() < least_length {
+                least_length = posting.list.len();
+                pick_index = i;
+            }
+        }
+        pick_index
+    }
 }
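How the tighter bound pays off in wand.rs as a whole: the iterators stay sorted by current doc id, upper bounds are accumulated until they reach the threshold, and every document before that pivot is skipped without being scored; pick_term then advances the shortest of the lagging lists, presumably because a short list is both cheap to advance and the most selective. A simplified sketch of the pivot step with illustrative types (not the crate's API):

// Simplified sketch of WAND pivot selection; Posting is an illustrative
// stand-in for PostingIterator, not the crate's type.
struct Posting {
    cur_doc: u64,
    upper_bound: f32, // max_score when available, else idf * (K1 + 1)
}

/// `postings` must be sorted by `cur_doc`. Returns the first doc id at which
/// the accumulated upper bounds reach `threshold`; no earlier doc can qualify.
fn find_pivot(postings: &[Posting], threshold: f32) -> Option<u64> {
    let mut acc = 0.0_f32;
    for p in postings {
        acc += p.upper_bound;
        if acc >= threshold {
            return Some(p.cur_doc);
        }
    }
    None
}

fn main() {
    let postings = vec![
        Posting { cur_doc: 3, upper_bound: 0.4 },
        Posting { cur_doc: 7, upper_bound: 0.9 },
        Posting { cur_doc: 9, upper_bound: 1.5 },
    ];
    // With threshold 1.2 the pivot is doc 7: doc 3 alone tops out at 0.4,
    // so docs 3..7 are skipped entirely. A tighter per-list upper_bound
    // pushes the pivot further right and skips even more documents.
    assert_eq!(find_pivot(&postings, 1.2), Some(7));
}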
