diff --git a/python/python/lance/sql.py b/python/python/lance/sql.py index 9fea40ceb4..e826ff1d37 100644 --- a/python/python/lance/sql.py +++ b/python/python/lance/sql.py @@ -7,14 +7,19 @@ class SqlQueryBuilder: + """A tool for building SQL queries""" + def __init__(self, query_str: str): + """Create a new SQL query builder from a query string.""" self.inner = LanceSqlQueryBuilder(query_str) - def with_lance_dataset(self, ds: LanceDataset) -> "SqlQueryBuilder": - self.inner.with_lance_dataset(ds._ds) + def with_dataset(self, alias: str, ds: LanceDataset) -> "SqlQueryBuilder": + """Adds a dataset to the query's context with a given alias.""" + self.inner.with_dataset(alias, ds._ds) return self def to_table(self) -> pa.Table: + """Execute the query and return the result as a table.""" return self.inner.execute().read_all() @@ -22,6 +27,11 @@ def query(query_str: str) -> SqlQueryBuilder: """ Create an SQL query builder from a query string. + Note: This is an experimental feature. The API may change in future + versions or be removed entirely. The most stable way to execute SQL + against Lance datasets is to use another tool such as DuckDB. This + tool is primarily intended for simple exploration and prototyping. 
+ Parameters ---------- query_str: str The SQL query to run Returns ------- SqlQueryBuilder A builder that can be used to configure the query Examples -------- >>> import lance >>> >>> ds = lance.write_dataset(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]})) - >>> query = lance.query("SELECT SUM(a) FROM ds WHERE b > 4") - >>> table = query.to_table() + >>> query = lance.query("SELECT SUM(a) FROM d1 WHERE b > 4") + >>> table = query.with_dataset("d1", ds).to_table() >>> print(table) # pyarrow.Table a diff --git a/python/python/tests/test_sql.py b/python/python/tests/test_sql.py index cd1552f6cc..7efb07131d 100644 --- a/python/python/tests/test_sql.py +++ b/python/python/tests/test_sql.py @@ -10,3 +10,30 @@ def test_no_dataset(): assert lance.sql.query("SELECT 5").to_table() == pa.table( {"Int64(5)": [5]}, schema=schema ) + + +def test_aggregation(tmp_path): + ds = lance.write_dataset(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}), tmp_path) + schema = pa.schema([pa.field("sum(d1.a)", pa.int64(), nullable=True)]) + assert lance.sql.query("SELECT SUM(a) FROM d1 WHERE b > 4").with_dataset( + "d1", ds ).to_table() == pa.table({"sum(d1.a)": [5]}, schema=schema) + + +def test_join(tmp_path): + ds1 = lance.write_dataset( + pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}), tmp_path / "d1" + ) + ds2 = lance.write_dataset( + pa.table({"c": [4, 5, 6], "d": ["x", "y", "z"]}), tmp_path / "d2" + ) + expected = pa.table({"a": [3, 2, 1], "d": ["z", "y", "x"]}) + assert ( + lance.sql.query( + "SELECT d1.a, d2.d FROM d1 INNER JOIN d2 ON d1.b = d2.c ORDER BY d1.a DESC" + ) + .with_dataset("d1", ds1) + .with_dataset("d2", ds2) + .to_table() + == expected + ) diff --git a/python/src/sql.rs b/python/src/sql.rs index cd45a8c97e..0b134ff783 100644 --- a/python/src/sql.rs +++ b/python/src/sql.rs @@ -1,34 +1,58 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; + use arrow::pyarrow::PyArrowType; use arrow_array::RecordBatchReader; use pyo3::{pyclass,
pymethods, PyRef, PyResult}; use lance::datafusion::sql::SqlPlan; +use lance::Dataset as LanceDataset; + +use crate::{error::PythonErrorExt, Dataset, LanceReader, RT}; -use crate::{error::PythonErrorExt, LanceReader, RT}; +struct QueryBuilderState { + datasets: HashMap<String, Arc<LanceDataset>>, +} #[pyclass] pub struct SqlQueryBuilder { pub query: String, + state: Arc<Mutex<QueryBuilderState>>, } #[pymethods] impl SqlQueryBuilder { #[new] pub fn new(query: String) -> Self { - Self { query } + Self { + query, + state: Arc::new(Mutex::new(QueryBuilderState { + datasets: HashMap::new(), + })), + } } - fn with_lance_dataset(self: PyRef<'_, Self>) -> PyRef<'_, Self> { - self + fn with_dataset<'a>(slf: PyRef<'a, Self>, alias: String, dataset: &Dataset) -> PyRef<'a, Self> { + { + let mut state = slf.state.lock().unwrap(); + state.datasets.insert(alias, dataset.ds.clone()); + } + slf } - fn execute(self_: PyRef<'_, Self>) -> PyResult<PyArrowType<Box<dyn RecordBatchReader + Send>>> { - let query = SqlPlan::new(self_.query.clone()); + fn execute(slf: PyRef<'_, Self>) -> PyResult<PyArrowType<Box<dyn RecordBatchReader + Send>>> { + let context = { + let state = slf.state.lock().unwrap(); + state.datasets.clone() + }; + let query = SqlPlan::new(slf.query.clone(), context); let reader = RT - .spawn(Some(self_.py()), async move { + .spawn(Some(slf.py()), async move { Ok(LanceReader::from_stream(query.execute().await?)) })?
.infer_error()?; diff --git a/rust/lance/src/datafusion/dataframe.rs b/rust/lance/src/datafusion/dataframe.rs index 2f9a850bac..b6b8ca6b15 100644 --- a/rust/lance/src/datafusion/dataframe.rs +++ b/rust/lance/src/datafusion/dataframe.rs @@ -30,7 +30,7 @@ pub struct LanceTableProvider { } impl LanceTableProvider { - fn new(dataset: Arc<Dataset>, with_row_id: bool, with_row_addr: bool) -> Self { + pub fn new(dataset: Arc<Dataset>, with_row_id: bool, with_row_addr: bool) -> Self { let mut full_schema = Schema::from(dataset.schema()); let mut row_id_idx = None; let mut row_addr_idx = None; diff --git a/rust/lance/src/datafusion/sql.rs b/rust/lance/src/datafusion/sql.rs index cdf574f025..7b3eef88b1 100644 --- a/rust/lance/src/datafusion/sql.rs +++ b/rust/lance/src/datafusion/sql.rs @@ -1,26 +1,43 @@ -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; -use lance_datafusion::exec::{get_session_context, LanceExecutionOptions}; +use lance_datafusion::exec::{new_session_context, LanceExecutionOptions}; use crate::{dataset::scanner::DatasetRecordBatchStream, Dataset, Result}; +use super::dataframe::LanceTableProvider; + +/// An SQL query that can be executed pub struct SqlPlan { query: String, - context: HashMap<String, String>, + context: HashMap<String, Arc<Dataset>>, execution_options: LanceExecutionOptions, } impl SqlPlan { - pub fn new(query: String) -> Self { + /// Creates a new SQL plan with a given query string and context + /// + /// The context is a mapping of dataset aliases to datasets. + /// This is how the SQL query can reference datasets.
+ pub fn new(query: String, context: HashMap<String, Arc<Dataset>>) -> Self { Self { query, - context: HashMap::new(), + context, execution_options: LanceExecutionOptions::default(), } } + /// Executes the SQL query and returns a stream of record batches pub async fn execute(&self) -> Result<DatasetRecordBatchStream> { - let session_context = get_session_context(&self.execution_options); + let session_context = new_session_context(&self.execution_options); + + for (alias, dataset) in &self.context { + let provider = Arc::new(LanceTableProvider::new( + dataset.clone(), + /*with_row_id= */ true, + /*with_row_addr= */ true, + )); + session_context.register_table(alias, provider)?; + } let df = session_context.sql(&self.query).await?; let stream = df.execute_stream().await?;