Skip to content

Commit

Permalink
Add prototype SQL interface
Browse files Browse the repository at this point in the history
  • Loading branch information
westonpace committed Nov 16, 2024
1 parent afa711c commit faedd7e
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 18 deletions.
18 changes: 14 additions & 4 deletions python/python/lance/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,31 @@


class SqlQueryBuilder:
    """A tool for building and executing SQL queries against Lance datasets."""

    def __init__(self, query_str: str):
        """Create a new SQL query builder from a query string.

        Parameters
        ----------
        query_str: str
            The SQL query to execute.
        """
        self.inner = LanceSqlQueryBuilder(query_str)

    def with_dataset(self, alias: str, ds: LanceDataset) -> "SqlQueryBuilder":
        """Adds a dataset to the query's context with a given alias.

        Parameters
        ----------
        alias: str
            The table name the SQL query uses to refer to the dataset.
        ds: LanceDataset
            The dataset to register under ``alias``.

        Returns
        -------
        SqlQueryBuilder
            The builder itself, so calls can be chained.
        """
        # `_ds` is the underlying native dataset handle — TODO confirm
        # this is the expected argument for the native builder.
        self.inner.with_dataset(alias, ds._ds)
        return self

    def to_table(self) -> pa.Table:
        """Execute the query and return the result as a pyarrow Table."""
        return self.inner.execute().read_all()


def query(query_str: str) -> SqlQueryBuilder:
"""
Create an SQL query builder from a query string.
Note: This is an experimental feature. The API may change in future
versions or be removed entirely. The most stable way to execute SQL
against Lance datasets is to use another tool such as DuckDB. This
tool is primarily intended for simple exploration and prototyping.
Parameters
----------
query_str: str
Expand All @@ -38,8 +48,8 @@ def query(query_str: str) -> SqlQueryBuilder:
>>> import lance
>>>
>>> ds = lance.write_dataset(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}))
>>> query = lance.query("SELECT SUM(a) FROM d1 WHERE b > 4")
>>> table = query.with_dataset("d1", ds).to_table()
>>> print(table)
# pyarrow.Table
a
Expand Down
27 changes: 27 additions & 0 deletions python/python/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,30 @@ def test_no_dataset():
assert lance.sql.query("SELECT 5").to_table() == pa.table(
{"Int64(5)": [5]}, schema=schema
)


def test_aggregation(tmp_path):
    # Register a small dataset under the alias "d1" and run a filtered SUM.
    dataset = lance.write_dataset(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}), tmp_path)
    builder = lance.sql.query("SELECT SUM(a) FROM d1 WHERE b > 4")
    result = builder.with_dataset("d1", dataset).to_table()
    # Only rows with b in {5, 6} survive the filter, so the sum is 2 + 3 = 5.
    expected_schema = pa.schema([pa.field("sum(d1.a)", pa.int64(), nullable=True)])
    assert result == pa.table({"sum(d1.a)": [5]}, schema=expected_schema)


def test_join(tmp_path):
    # Two datasets joined on d1.b = d2.c, ordered descending by d1.a.
    left = lance.write_dataset(
        pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}), tmp_path / "d1"
    )
    right = lance.write_dataset(
        pa.table({"c": [4, 5, 6], "d": ["x", "y", "z"]}), tmp_path / "d2"
    )
    sql = "SELECT d1.a, d2.d FROM d1 INNER JOIN d2 ON d1.b = d2.c ORDER BY d1.a DESC"
    result = (
        lance.sql.query(sql)
        .with_dataset("d1", left)
        .with_dataset("d2", right)
        .to_table()
    )
    assert result == pa.table({"a": [3, 2, 1], "d": ["z", "y", "x"]})
38 changes: 31 additions & 7 deletions python/src/sql.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,58 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::{
collections::HashMap,
sync::{Arc, Mutex},
};

use arrow::pyarrow::PyArrowType;
use arrow_array::RecordBatchReader;
use pyo3::{pyclass, pymethods, PyRef, PyResult};

use lance::datafusion::sql::SqlPlan;
use lance::Dataset as LanceDataset;

use crate::{error::PythonErrorExt, Dataset, LanceReader, RT};

use crate::{error::PythonErrorExt, LanceReader, RT};
/// Mutable state shared by a `SqlQueryBuilder`: the datasets registered
/// so far, keyed by the alias the SQL query uses to refer to them.
struct QueryBuilderState {
    datasets: HashMap<String, Arc<LanceDataset>>,
}

/// Python-exposed builder for executing SQL queries against Lance datasets.
#[pyclass]
pub struct SqlQueryBuilder {
    // The SQL query string to execute.
    pub query: String,
    // Datasets registered via `with_dataset`; behind Arc<Mutex<..>> because
    // pyo3 methods only receive shared references to the pyclass.
    state: Arc<Mutex<QueryBuilderState>>,
}

#[pymethods]
impl SqlQueryBuilder {
#[new]
pub fn new(query: String) -> Self {
Self { query }
Self {
query,
state: Arc::new(Mutex::new(QueryBuilderState {
datasets: HashMap::new(),
})),
}
}

fn with_lance_dataset(self: PyRef<'_, Self>) -> PyRef<'_, Self> {
self
fn with_dataset<'a>(slf: PyRef<'a, Self>, alias: String, dataset: &Dataset) -> PyRef<'a, Self> {
{
let mut state = slf.state.lock().unwrap();
state.datasets.insert(alias, dataset.ds.clone());
}
slf
}

fn execute(self_: PyRef<'_, Self>) -> PyResult<PyArrowType<Box<dyn RecordBatchReader + Send>>> {
let query = SqlPlan::new(self_.query.clone());
fn execute(slf: PyRef<'_, Self>) -> PyResult<PyArrowType<Box<dyn RecordBatchReader + Send>>> {
let context = {
let state = slf.state.lock().unwrap();
state.datasets.clone()
};
let query = SqlPlan::new(slf.query.clone(), context);
let reader = RT
.spawn(Some(self_.py()), async move {
.spawn(Some(slf.py()), async move {
Ok(LanceReader::from_stream(query.execute().await?))
})?
.infer_error()?;
Expand Down
2 changes: 1 addition & 1 deletion rust/lance/src/datafusion/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub struct LanceTableProvider {
}

impl LanceTableProvider {
fn new(dataset: Arc<Dataset>, with_row_id: bool, with_row_addr: bool) -> Self {
pub fn new(dataset: Arc<Dataset>, with_row_id: bool, with_row_addr: bool) -> Self {
let mut full_schema = Schema::from(dataset.schema());
let mut row_id_idx = None;
let mut row_addr_idx = None;
Expand Down
29 changes: 23 additions & 6 deletions rust/lance/src/datafusion/sql.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,43 @@
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};

use lance_datafusion::exec::{get_session_context, LanceExecutionOptions};
use lance_datafusion::exec::{new_session_context, LanceExecutionOptions};

use crate::{dataset::scanner::DatasetRecordBatchStream, Dataset, Result};

use super::dataframe::LanceTableProvider;

/// An SQL query that can be executed
pub struct SqlPlan {
query: String,
context: HashMap<String, Dataset>,
context: HashMap<String, Arc<Dataset>>,
execution_options: LanceExecutionOptions,
}

impl SqlPlan {
pub fn new(query: String) -> Self {
/// Creates a new SQL with a given query string and context
///
/// The context is a mapping of dataset aliases to datasets.
/// This is how the SQL query can reference datasets.
pub fn new(query: String, context: HashMap<String, Arc<Dataset>>) -> Self {
Self {
query,
context: HashMap::new(),
context,
execution_options: LanceExecutionOptions::default(),
}
}

/// Executes the SQL query and returns a stream of record batches
pub async fn execute(&self) -> Result<DatasetRecordBatchStream> {
let session_context = get_session_context(&self.execution_options);
let session_context = new_session_context(&self.execution_options);

for (alias, dataset) in &self.context {
let provider = Arc::new(LanceTableProvider::new(
dataset.clone(),
/*with_row_id= */ true,
/*with_row_addr= */ true,
));
session_context.register_table(alias, provider)?;
}

let df = session_context.sql(&self.query).await?;
let stream = df.execute_stream().await?;
Expand Down

0 comments on commit faedd7e

Please # to comment.