From 0608cfdf486908ae00c6e2e16a36be41c0596b56 Mon Sep 17 00:00:00 2001
From: Will Jones
Date: Thu, 20 Jun 2024 20:34:22 -0700
Subject: [PATCH] feat: support stable row ids in Dataset::take_rows() (#2447)

Part of https://github.com/lancedb/lance/issues/2307

* `Dataset::take_rows()` now takes `row_ids`, which may be stable row ids.
  These are translated into row addresses internally and then passed through
  the existing logic.
* `Fragment::take_rows()` now optionally returns the row address column
  instead of the row id column, if requested.
---
 rust/lance-table/src/utils/stream.rs |  40 +++--
 rust/lance/src/dataset/fragment.rs   |  27 +--
 rust/lance/src/dataset/take.rs       | 255 ++++++++++++++-------------
 3 files changed, 170 insertions(+), 152 deletions(-)

diff --git a/rust/lance-table/src/utils/stream.rs b/rust/lance-table/src/utils/stream.rs
index cd624dcf4c..9edea6341a 100644
--- a/rust/lance-table/src/utils/stream.rs
+++ b/rust/lance-table/src/utils/stream.rs
@@ -196,30 +196,32 @@ pub fn apply_row_id_and_deletes(
         batch.num_columns() > 0 || config.with_row_id || config.with_row_addr || has_deletions
     );
 
-    let should_fetch_row_id = config.with_row_id || has_deletions;
+    // If the row id sequence is None, then the row id IS the row address.
+    let should_fetch_row_addr = config.with_row_addr
+        || (config.with_row_id && config.row_id_sequence.is_none())
+        || has_deletions;
 
     let num_rows = batch.num_rows() as u32;
 
-    let row_addrs =
-        if config.with_row_addr || (should_fetch_row_id && config.row_id_sequence.is_none()) {
-            let ids_in_batch = config
-                .params
-                .slice(batch_offset as usize, num_rows as usize)
-                .unwrap()
-                .to_offsets()
-                .unwrap();
-            let row_addrs: UInt64Array = ids_in_batch
-                .values()
-                .iter()
-                .map(|row_id| u64::from(RowAddress::new_from_parts(fragment_id, *row_id)))
-                .collect();
+    let row_addrs = if should_fetch_row_addr {
+        let ids_in_batch = config
+            .params
+            .slice(batch_offset as usize, num_rows as usize)
+            .unwrap()
+            .to_offsets()
+            .unwrap();
+        let row_addrs: UInt64Array = ids_in_batch
+            .values()
+            .iter()
+            .map(|row_id| u64::from(RowAddress::new_from_parts(fragment_id, *row_id)))
+            .collect();
 
-            Some(Arc::new(row_addrs))
-        } else {
-            None
-        };
+        Some(Arc::new(row_addrs))
+    } else {
+        None
+    };
 
-    let row_ids = if should_fetch_row_id {
+    let row_ids = if config.with_row_id {
         if let Some(row_id_sequence) = &config.row_id_sequence {
             let row_ids = row_id_sequence
                 .slice(batch_offset as usize, num_rows as usize)
diff --git a/rust/lance/src/dataset/fragment.rs b/rust/lance/src/dataset/fragment.rs
index acf4662521..a3930a42e9 100644
--- a/rust/lance/src/dataset/fragment.rs
+++ b/rust/lance/src/dataset/fragment.rs
@@ -944,30 +944,31 @@ impl FileFragment {
         }
     }
 
-    /// Take rows based on internal local row ids
+    /// Take rows based on internal local row offsets
     ///
-    /// If the row ids are out-of-bounds, this will return an error. But if the
-    /// row id is marked deleted, it will be ignored. Thus, the number of rows
-    /// returned may be less than the number of row ids provided.
+    /// If the row offsets are out-of-bounds, this will return an error. But if the
+    /// row offset is marked deleted, it will be ignored. Thus, the number of rows
+    /// returned may be less than the number of row offsets provided.
     ///
-    /// To recover the original row ids from the returned RecordBatch, set the
-    /// `with_row_id` parameter to true. This will add a column named `_row_id`
+    /// To recover the original row addresses from the returned RecordBatch, set the
+    /// `with_row_address` parameter to true. This will add a column named `_rowaddr`
     /// to the RecordBatch at the end.
     pub(crate) async fn take_rows(
         &self,
-        row_ids: &[u32],
+        row_offsets: &[u32],
         projection: &Schema,
-        with_row_id: bool,
+        with_row_address: bool,
     ) -> Result<RecordBatch> {
         // TODO: support taking row addresses
-        let reader = self.open(projection, with_row_id, false).await?;
+        let reader = self.open(projection, false, with_row_address).await?;
 
-        if row_ids.len() > 1 && Self::row_ids_contiguous(row_ids) {
-            let range = (row_ids[0] as usize)..(row_ids[row_ids.len() - 1] as usize + 1);
+        if row_offsets.len() > 1 && Self::row_ids_contiguous(row_offsets) {
+            let range =
+                (row_offsets[0] as usize)..(row_offsets[row_offsets.len() - 1] as usize + 1);
             reader.legacy_read_range_as_batch(range).await
         } else {
             // FIXME, change this method to streams
-            reader.take_as_batch(row_ids).await
+            reader.take_as_batch(row_offsets).await
         }
     }
 
@@ -2008,7 +2009,7 @@ mod tests {
             &Int32Array::from(vec![121, 125, 128])
         );
         assert_eq!(
-            batch.column_by_name(ROW_ID).unwrap().as_ref(),
+            batch.column_by_name(ROW_ADDR).unwrap().as_ref(),
             &UInt64Array::from(vec![(3 << 32) + 1, (3 << 32) + 5, (3 << 32) + 8])
         );
     }
diff --git a/rust/lance/src/dataset/take.rs b/rust/lance/src/dataset/take.rs
index 22cd2cf97d..bd0e23c251 100644
--- a/rust/lance/src/dataset/take.rs
+++ b/rust/lance/src/dataset/take.rs
@@ -1,8 +1,10 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The Lance Authors
 
+use std::borrow::Cow;
 use std::{collections::BTreeMap, ops::Range, pin::Pin, sync::Arc};
 
+use crate::dataset::rowids::get_row_id_index;
 use crate::{Error, Result};
 use arrow::{array::as_struct_array, compute::concat_batches, datatypes::UInt64Type};
 use arrow_array::cast::AsArray;
@@ -12,7 +14,9 @@ use arrow_select::interleave::interleave;
 use datafusion::error::DataFusionError;
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use futures::{Future, Stream, StreamExt, TryStreamExt};
-use lance_core::{datatypes::Schema, ROW_ID};
+use lance_core::datatypes::Schema;
+use lance_core::utils::address::RowAddress;
+use lance_core::ROW_ADDR;
 use snafu::{location, Location};
 
 use super::{fragment::FileFragment, scanner::DatasetRecordBatchStream, Dataset};
@@ -137,8 +141,20 @@ pub async fn take_rows(
         return Ok(RecordBatch::new_empty(Arc::new(projection.into())));
     }
 
+    let row_addrs = if dataset.manifest.uses_move_stable_row_ids() {
+        // Need to map the row ids to addresses
+        let index = get_row_id_index(dataset).await?;
+        let addresses = row_ids
+            .iter()
+            .filter_map(|id| index.get(*id).map(|address| address.into()))
+            .collect::<Vec<u64>>();
+        Cow::Owned(addresses)
+    } else {
+        Cow::Borrowed(row_ids)
+    };
+
     let projection = Arc::new(projection.clone());
-    let row_id_meta = check_row_ids(row_ids);
+    let row_addr_stats = check_row_addrs(&row_addrs);
 
     // This method is mostly to annotate the send bound to avoid the
     // higher-order lifetime error.
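Note on the hunk above: this is the heart of the feature. When the manifest
uses move-stable row ids, each requested row id is translated to a physical
row address up front via the id-to-address index, and `Cow` lets the
non-stable path keep borrowing the caller's slice instead of copying it. Ids
that miss the index (rows that have been deleted) are silently dropped, which
matches the documented behavior that a take may return fewer rows than
requested. A minimal, self-contained sketch of the same pattern, where
`RowIdIndex` is a hypothetical stand-in for the structure returned by
`get_row_id_index`:

    use std::borrow::Cow;
    use std::collections::HashMap;

    /// Hypothetical stand-in for the id-to-address index; the real one is
    /// built from the manifest's row id sequences.
    struct RowIdIndex(HashMap<u64, u64>);

    impl RowIdIndex {
        /// Look up the current physical address of a stable row id.
        fn get(&self, row_id: u64) -> Option<u64> {
            self.0.get(&row_id).copied()
        }
    }

    /// Translate row ids to row addresses only when stable row ids are
    /// enabled; otherwise the ids already are addresses and can be borrowed.
    fn resolve_addrs<'a>(row_ids: &'a [u64], index: Option<&RowIdIndex>) -> Cow<'a, [u64]> {
        match index {
            Some(index) => Cow::Owned(
                row_ids
                    .iter()
                    // Deleted rows are absent from the index and are dropped.
                    .filter_map(|id| index.get(*id))
                    .collect(),
            ),
            None => Cow::Borrowed(row_ids),
        }
    }

Everything downstream of this point operates purely on row addresses, which is
why the rest of the diff is largely a rename plus remapping the results back
into the caller's requested order.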
@@ -146,44 +162,44 @@ pub async fn take_rows(
     #[allow(clippy::manual_async_fn)]
     fn do_take(
         fragment: FileFragment,
-        row_ids: Vec<u32>,
+        row_offsets: Vec<u32>,
         projection: Arc<Schema>,
-        with_row_id: bool,
+        with_row_addresses: bool,
     ) -> impl Future<Output = Result<RecordBatch>> + Send {
         async move {
             fragment
-                .take_rows(&row_ids, projection.as_ref(), with_row_id)
+                .take_rows(&row_offsets, projection.as_ref(), with_row_addresses)
                 .await
         }
     }
 
-    if row_id_meta.contiguous {
+    if row_addr_stats.contiguous {
         // Fastest path: Can use `read_range` directly
-        let start = row_ids.first().expect("empty range passed to take_rows");
+        let start = row_addrs.first().expect("empty range passed to take_rows");
         let fragment_id = (start >> 32) as usize;
         let range_start = *start as u32 as usize;
-        let range_end = *row_ids.last().expect("empty range passed to take_rows") as u32 as usize;
+        let range_end = *row_addrs.last().expect("empty range passed to take_rows") as u32 as usize;
         let range = range_start..(range_end + 1);
 
         let fragment = dataset.get_fragment(fragment_id).ok_or_else(|| {
             Error::invalid_input(
-                format!("row_id belongs to non-existant fragment: {start}"),
+                format!("_rowaddr belongs to non-existent fragment: {start}"),
                 location!(),
             )
         })?;
         let reader = fragment.open(projection.as_ref(), false, false).await?;
         reader.legacy_read_range_as_batch(range).await
-    } else if row_id_meta.sorted {
+    } else if row_addr_stats.sorted {
         // Don't need to re-arrange data, just concatenate
 
         let mut batches: Vec<_> = Vec::new();
-        let mut current_fragment = row_ids[0] >> 32;
+        let mut current_fragment = row_addrs[0] >> 32;
         let mut current_start = 0;
-        let mut row_ids_iter = row_ids.iter().enumerate();
+        let mut row_addr_iter = row_addrs.iter().enumerate();
         'outer: loop {
             let (fragment_id, range) = loop {
-                if let Some((i, row_id)) = row_ids_iter.next() {
+                if let Some((i, row_id)) = row_addr_iter.next() {
                     let fragment_id = row_id >> 32;
                     if fragment_id != current_fragment {
                         let next = (current_fragment, current_start..i);
@@ -191,9 +207,9 @@ pub async fn take_rows(
                         current_start = i;
                         break next;
                     }
-                } else if current_start != row_ids.len() {
-                    let next = (current_fragment, current_start..row_ids.len());
-                    current_start = row_ids.len();
+                } else if current_start != row_addrs.len() {
+                    let next = (current_fragment, current_start..row_addrs.len());
+                    current_start = row_addrs.len();
                     break next;
                 } else {
                     break 'outer;
@@ -203,15 +219,15 @@ pub async fn take_rows(
             let fragment = dataset.get_fragment(fragment_id as usize).ok_or_else(|| {
                 Error::invalid_input(
                     format!(
-                        "row_id {} belongs to non-existant fragment: {}",
-                        row_ids[range.start], fragment_id
+                        "_rowaddr {} belongs to non-existent fragment: {}",
+                        row_addrs[range.start], fragment_id
                     ),
                     location!(),
                 )
             })?;
-            let row_ids: Vec<u32> = row_ids[range].iter().map(|x| *x as u32).collect();
+            let row_offsets: Vec<u32> = row_addrs[range].iter().map(|x| *x as u32).collect();
 
-            let batch_fut = do_take(fragment, row_ids, projection.clone(), false);
+            let batch_fut = do_take(fragment, row_offsets, projection.clone(), false);
             batches.push(batch_fut);
         }
         let batches: Vec<RecordBatch> = futures::stream::iter(batches)
            .buffered(4 * num_cpus::get())
            .try_collect()
            .await?;
        Ok(concat_batches(&batches[0].schema(), &batches)?)
    } else {
-        let projection_with_row_id = Schema::merge(
+        let projection_with_row_addr = Schema::merge(
             projection.as_ref(),
             &ArrowSchema::new(vec![ArrowField::new(
-                ROW_ID,
+                ROW_ADDR,
                 arrow::datatypes::DataType::UInt64,
                 false,
             )]),
         )?;
-        let schema_with_row_id = Arc::new(ArrowSchema::from(&projection_with_row_id));
+        let schema_with_row_addr = Arc::new(ArrowSchema::from(&projection_with_row_addr));
 
         // Slow case: need to re-map data into expected order
-        let mut sorted_row_ids = Vec::from(row_ids);
-        sorted_row_ids.sort();
+        let mut sorted_row_addrs = Vec::from(row_addrs.clone());
+        sorted_row_addrs.sort();
 
         // Group ROW Ids by the fragment
-        let mut row_ids_per_fragment: BTreeMap<u64, Vec<u32>> = BTreeMap::new();
-        sorted_row_ids.iter().for_each(|row_id| {
-            let fragment_id = row_id >> 32;
-            let offset = (row_id - (fragment_id << 32)) as u32;
-            row_ids_per_fragment
+        let mut row_addrs_per_fragment: BTreeMap<u32, Vec<u32>> = BTreeMap::new();
+        sorted_row_addrs.iter().for_each(|row_addr| {
+            let row_addr = RowAddress::new_from_id(*row_addr);
+            let fragment_id = row_addr.fragment_id();
+            let offset = row_addr.row_id();
+            row_addrs_per_fragment
                 .entry(fragment_id)
                 .and_modify(|v| v.push(offset))
                 .or_insert_with(|| vec![offset]);
@@ -246,8 +263,8 @@ pub async fn take_rows(
 
         let fragments = dataset.get_fragments();
         let fragment_and_indices = fragments.into_iter().filter_map(|f| {
-            let local_row_ids = row_ids_per_fragment.remove(&(f.id() as u64))?;
-            Some((f, local_row_ids))
+            let row_offset = row_addrs_per_fragment.remove(&(f.id() as u32))?;
+            Some((f, row_offset))
         });
 
         let mut batches = futures::stream::iter(fragment_and_indices)
@@ -257,7 +274,7 @@ pub async fn take_rows(
             .await?;
 
         let one_batch = if batches.len() > 1 {
-            concat_batches(&schema_with_row_id, &batches)?
+            concat_batches(&schema_with_row_addr, &batches)?
         } else {
             batches.pop().unwrap()
         };
@@ -266,19 +283,19 @@ pub async fn take_rows(
         // get the results with row ids so that we can re-order the results
         // to match the requested order.
 
-        let returned_row_ids = one_batch
-            .column_by_name(ROW_ID)
+        let returned_row_addr = one_batch
+            .column_by_name(ROW_ADDR)
             .ok_or_else(|| Error::Internal {
-                message: "ROW_ID column not found".into(),
+                message: "_rowaddr column not found".into(),
                 location: location!(),
             })?
             .as_primitive::<UInt64Type>()
             .values();
 
-        let remapping_index: UInt64Array = row_ids
+        let remapping_index: UInt64Array = row_addrs
             .iter()
             .filter_map(|o| {
-                returned_row_ids
+                returned_row_addr
                     .iter()
                     .position(|id| id == o)
                     .map(|pos| pos as u64)
             })
             .collect();
@@ -287,7 +304,7 @@ pub async fn take_rows(
 
         debug_assert_eq!(remapping_index.len(), one_batch.num_rows());
 
-        // Remove the row id column.
+        // Remove the rowaddr column.
         let keep_indices = (0..one_batch.num_columns() - 1).collect::<Vec<_>>();
         let one_batch = one_batch.project(&keep_indices)?;
         let struct_arr: StructArray = one_batch.into();
@@ -329,17 +346,17 @@ pub fn take_scan(
     )))
 }
 
-struct RowIdMeta {
+struct RowAddressStats {
     sorted: bool,
     contiguous: bool,
 }
 
-fn check_row_ids(row_ids: &[u64]) -> RowIdMeta {
+fn check_row_addrs(row_ids: &[u64]) -> RowAddressStats {
     let mut sorted = true;
     let mut contiguous = true;
 
     if row_ids.is_empty() {
-        return RowIdMeta { sorted, contiguous };
+        return RowAddressStats { sorted, contiguous };
     }
 
     let mut last_id = row_ids[0];
@@ -353,13 +370,14 @@ fn check_row_ids(row_ids: &[u64]) -> RowIdMeta {
         last_id = *id;
     }
 
-    RowIdMeta { sorted, contiguous }
+    RowAddressStats { sorted, contiguous }
 }
 
 #[cfg(test)]
 mod test {
     use arrow_array::{Int32Array, RecordBatchIterator, StringArray};
     use arrow_schema::DataType;
+    use pretty_assertions::assert_eq;
     use rstest::rstest;
 
     use crate::dataset::{scanner::test_dataset::TestVectorDataset, WriteParams};
@@ -371,44 +389,44 @@ mod test {
         t
     }
 
-    #[rstest]
-    #[tokio::test]
-    async fn test_take(#[values(false, true)] use_legacy_format: bool) {
-        let test_dir = tempfile::tempdir().unwrap();
-
+    fn test_batch(i_range: Range<i32>) -> RecordBatch {
         let schema = Arc::new(ArrowSchema::new(vec![
             ArrowField::new("i", DataType::Int32, false),
             ArrowField::new("s", DataType::Utf8, false),
         ]));
-        let batches: Vec<RecordBatch> = (0..20)
-            .map(|i| {
-                RecordBatch::try_new(
-                    schema.clone(),
-                    vec![
-                        Arc::new(Int32Array::from_iter_values(i * 20..(i + 1) * 20)),
-                        Arc::new(StringArray::from_iter_values(
-                            (i * 20..(i + 1) * 20).map(|i| format!("str-{i}")),
-                        )),
-                    ],
-                )
-                .unwrap()
-            })
-            .collect();
-        let test_uri = test_dir.path().to_str().unwrap();
+        RecordBatch::try_new(
+            schema,
+            vec![
+                Arc::new(Int32Array::from_iter_values(i_range.clone())),
+                Arc::new(StringArray::from_iter_values(
+                    i_range.clone().map(|i| format!("str-{}", i)),
+                )),
+            ],
+        )
+        .unwrap()
+    }
+
+    #[rstest]
+    #[tokio::test]
+    async fn test_take(
+        #[values(false, true)] use_legacy_format: bool,
+        #[values(false, true)] enable_move_stable_row_ids: bool,
+    ) {
+        let data = test_batch(0..400);
         let write_params = WriteParams {
             max_rows_per_file: 40,
             max_rows_per_group: 10,
             use_legacy_format,
+            enable_move_stable_row_ids,
             ..Default::default()
         };
-        let batches = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
-        Dataset::write(batches, test_uri, Some(write_params))
+        let batches = RecordBatchIterator::new([Ok(data.clone())], data.schema());
+        let dataset = Dataset::write(batches, "memory://", Some(write_params))
             .await
             .unwrap();
-        let dataset = Dataset::open(test_uri).await.unwrap();
 
         assert_eq!(dataset.count_rows(None).await.unwrap(), 400);
-        let projection = Schema::try_from(schema.as_ref()).unwrap();
+        let projection = Schema::try_from(data.schema().as_ref()).unwrap();
         let values = dataset
             .take(
                 &[
@@ -426,7 +444,7 @@ mod test {
             .unwrap();
         assert_eq!(
             RecordBatch::try_new(
-                schema.clone(),
+                data.schema(),
                 vec![
                     Arc::new(Int32Array::from_iter_values([
                         200, 199, 39, 40, 199, 40, 125
@@ -451,7 +469,7 @@ mod test {
         let ds = test_ds.dataset;
 
         // take the last row of first fragment
-        // this triggeres the contiguous branch
+        // this triggers the contiguous branch
         let indices = &[(1 << 32) - 1];
         let fut = require_send(ds.take_rows(indices, ds.schema()));
         let err = fut.await.unwrap_err();
@@ -461,7 +479,7 @@ mod test {
            err.to_string()
        );
 
-        // this triggeres the sorted branch, but not continguous
+        // this triggers the sorted branch, but not contiguous
         let indices = &[(1 << 32) - 3, (1 << 32) - 1];
         let err = ds.take_rows(indices, ds.schema()).await.unwrap_err();
         assert!(
@@ -471,7 +489,7 @@ mod test {
             err.to_string()
         );
 
-        // this triggeres the catch all branch
+        // this triggers the catch all branch
         let indices = &[(1 << 32) - 1, (1 << 32) - 3];
         let err = ds.take_rows(indices, ds.schema()).await.unwrap_err();
         assert!(
@@ -485,40 +503,20 @@ mod test {
             err.to_string()
         );
     }
 
     #[rstest]
     #[tokio::test]
     async fn test_take_rows(#[values(false, true)] use_legacy_format: bool) {
-        let test_dir = tempfile::tempdir().unwrap();
-
-        let schema = Arc::new(ArrowSchema::new(vec![
-            ArrowField::new("i", DataType::Int32, false),
-            ArrowField::new("s", DataType::Utf8, false),
-        ]));
-        let batches: Vec<RecordBatch> = (0..20)
-            .map(|i| {
-                RecordBatch::try_new(
-                    schema.clone(),
-                    vec![
-                        Arc::new(Int32Array::from_iter_values(i * 20..(i + 1) * 20)),
-                        Arc::new(StringArray::from_iter_values(
-                            (i * 20..(i + 1) * 20).map(|i| format!("str-{i}")),
-                        )),
-                    ],
-                )
-                .unwrap()
-            })
-            .collect();
-        let test_uri = test_dir.path().to_str().unwrap();
+        let data = test_batch(0..400);
         let write_params = WriteParams {
             max_rows_per_file: 40,
             max_rows_per_group: 10,
             use_legacy_format,
             ..Default::default()
         };
-        let batches = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
-        let mut dataset = Dataset::write(batches, test_uri, Some(write_params))
+        let batches = RecordBatchIterator::new([Ok(data.clone())], data.schema());
+        let mut dataset = Dataset::write(batches, "memory://", Some(write_params))
             .await
             .unwrap();
 
         assert_eq!(dataset.count_rows(None).await.unwrap(), 400);
-        let projection = Schema::try_from(schema.as_ref()).unwrap();
+        let projection = Schema::try_from(data.schema().as_ref()).unwrap();
         let indices = &[
             5_u64 << 32,        // 200
             (4_u64 << 32) + 39, // 199
@@ -529,7 +527,7 @@ mod test {
         let values = dataset.take_rows(indices, &projection).await.unwrap();
         assert_eq!(
             RecordBatch::try_new(
-                schema.clone(),
+                data.schema(),
                 vec![
                     Arc::new(Int32Array::from_iter_values([200, 199, 39, 40, 100])),
                     Arc::new(StringArray::from_iter_values(
@@ -547,7 +545,7 @@ mod test {
         let values = dataset.take_rows(indices, &projection).await.unwrap();
         assert_eq!(
             RecordBatch::try_new(
-                schema.clone(),
+                data.schema(),
                 vec![
                     Arc::new(Int32Array::from_iter_values([200, 39, 40])),
                     Arc::new(StringArray::from_iter_values(
@@ -561,44 +559,25 @@ mod test {
 
         // Take an empty selection.
let values = dataset.take_rows(&[], &projection).await.unwrap(); - assert_eq!(RecordBatch::new_empty(schema.clone()), values); + assert_eq!(RecordBatch::new_empty(data.schema()), values); } #[rstest] #[tokio::test] async fn take_scan_dataset(#[values(false, true)] use_legacy_format: bool) { use arrow::datatypes::Int32Type; - use arrow_array::Float32Array; - - let schema = Arc::new(ArrowSchema::new(vec![ - ArrowField::new("i", DataType::Int32, false), - ArrowField::new("x", DataType::Float32, false), - ])); - let batch = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![1, 2, 3, 4])), - Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0, 4.0])), - ], - ) - .unwrap(); - - let test_dir = tempfile::tempdir().unwrap(); - let test_uri = test_dir.path().to_str().unwrap(); + let data = test_batch(1..5); let write_params = WriteParams { max_rows_per_group: 2, use_legacy_format, ..Default::default() }; - - let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema.clone()); - Dataset::write(batches, test_uri, Some(write_params.clone())) + let batches = RecordBatchIterator::new([Ok(data.clone())], data.schema()); + let dataset = Dataset::write(batches, "memory://", Some(write_params)) .await .unwrap(); - let dataset = Dataset::open(test_uri).await.unwrap(); - let projection = Arc::new(dataset.schema().project(&["i"]).unwrap()); let ranges = [0_u64..3, 1..4, 0..1]; let range_stream = futures::stream::iter(ranges).map(Ok).boxed(); @@ -625,4 +604,40 @@ mod test { &[1], ); } + + #[rstest] + #[tokio::test] + async fn test_take_rows_with_row_ids(#[values(false, true)] use_legacy_format: bool) { + let data = test_batch(0..8); + let write_params = WriteParams { + max_rows_per_group: 2, + use_legacy_format, + enable_move_stable_row_ids: true, + ..Default::default() + }; + let batches = RecordBatchIterator::new([Ok(data.clone())], data.schema()); + let mut dataset = Dataset::write(batches, "memory://", Some(write_params)) + .await + .unwrap(); + + dataset.delete("i in (1, 2, 3, 7)").await.unwrap(); + + let indices = &[0, 4, 6, 5]; + let result = dataset.take_rows(indices, dataset.schema()).await.unwrap(); + assert_eq!( + RecordBatch::try_new( + data.schema(), + vec![ + Arc::new(Int32Array::from_iter_values( + indices.iter().map(|x| *x as i32) + )), + Arc::new(StringArray::from_iter_values( + indices.iter().map(|v| format!("str-{v}")) + )), + ], + ) + .unwrap(), + result + ); + } }
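A closing note on the address layout the patch and its tests rely on: a row
address is a single `u64` with the fragment id in the upper 32 bits and the
local row offset in the lower 32 bits. That is what `RowAddress::new_from_parts`,
the `row_id >> 32` / `*start as u32` arithmetic, and test constants such as
`(3 << 32) + 1` and `5_u64 << 32  // 200` all encode. A small, self-contained
illustration of that packing (hypothetical helper names, not the actual
`RowAddress` API):

    /// Pack a fragment id and a local row offset into one u64 row address.
    fn pack(fragment_id: u32, offset: u32) -> u64 {
        ((fragment_id as u64) << 32) | offset as u64
    }

    /// Split a row address back into (fragment id, local row offset).
    fn unpack(addr: u64) -> (u32, u32) {
        ((addr >> 32) as u32, addr as u32)
    }

    fn main() {
        // The address asserted in the fragment test: row 1 of fragment 3.
        let addr = pack(3, 1);
        assert_eq!(addr, (3 << 32) + 1);
        assert_eq!(unpack(addr), (3, 1));

        // `5_u64 << 32` from test_take_rows is row 0 of fragment 5; with
        // 40 rows per file that row holds the value 200, matching the
        // `// 200` comment in the test.
        assert_eq!(unpack(5_u64 << 32), (5, 0));
    }

This layout is also why `check_row_addrs` can detect the contiguous fast path
with plain integer comparisons: within a fragment, consecutive rows differ by
exactly 1, while a jump to another fragment almost always differs by far more.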