
Add Interruptible Query Execution in Jupyter via KeyboardInterrupt Support #1141


Merged

56 commits merged on Jun 16, 2025

Commits
60ff7f2
fix: enhance error handling in async wait_for_future function
kosiew Jun 3, 2025
b168afb
feat: implement async execution for execution plans in PySessionContext
kosiew Jun 3, 2025
cca1cf8
fix: improve error message for execution failures in PySessionContext
kosiew Jun 3, 2025
81e4204
fix: enhance error handling and improve execution plan retrieval in P…
kosiew Jun 3, 2025
cfc3c40
fix: ensure 'static lifetime for futures in wait_for_future function
kosiew Jun 3, 2025
1ee861a
fix: handle potential errors when caching DataFrame and retrieving ex…
kosiew Jun 3, 2025
569dc26
fix: flatten batches in PyDataFrame to ensure proper schema conversion
kosiew Jun 3, 2025
45c9fa3
fix: correct error handling in batch processing for schema conversion
kosiew Jun 3, 2025
5336cf4
fix: flatten nested structure in PyDataFrame to ensure proper RecordB…
kosiew Jun 3, 2025
fb814cc
fix: improve error handling in PyDataFrame stream execution
kosiew Jun 3, 2025
a24e280
fix: add utility to get Tokio Runtime with time enabled and update wa…
kosiew Jun 3, 2025
6f0ef29
fix: store result of converting RecordBatches to PyArrow for debugging
kosiew Jun 3, 2025
ec2abf1
fix: handle error from wait_for_future in PyDataFrame collect method
kosiew Jun 3, 2025
aa71328
fix: propagate error from wait_for_future in collect_partitioned method
kosiew Jun 3, 2025
d9bfed5
fix: enable IO in Tokio runtime with time support
kosiew Jun 3, 2025
52a5efe
main register_listing_table
kosiew Jun 3, 2025
b778911
Revert "main register_listing_table"
kosiew Jun 3, 2025
f10652c
fix: propagate error correctly from wait_for_future in PySessionConte…
kosiew Jun 3, 2025
97f86dc
fix: simplify error handling in PySessionContext by unwrapping wait_f…
kosiew Jun 3, 2025
3f17c9c
test: add interruption handling test for long-running queries in Data…
kosiew Jun 3, 2025
dafa3a5
test: move test_collect_interrupted to test_dataframe.py
kosiew Jun 3, 2025
973e690
fix: add const for interval in wait_for_future utility
kosiew Jun 3, 2025
ca2d892
fix: use get_tokio_runtime instead of the custom get_runtime
kosiew Jun 4, 2025
ada6faa
Revert "fix: use get_tokio_runtime instead of the custom get_runtime"
kosiew Jun 4, 2025
f14b51f
fix: use get_tokio_runtime instead of the custom get_runtime
kosiew Jun 4, 2025
b8ce3e4
.
kosiew Jun 4, 2025
5073070
Revert "."
kosiew Jun 4, 2025
63a1ab7
fix: improve query interruption handling in test_collect_interrupted
kosiew Jun 4, 2025
c6e5205
fix: ensure proper handling of query interruption in test_collect_int…
kosiew Jun 4, 2025
d0f2d37
fix: improve error handling in database table retrieval
kosiew Jun 4, 2025
faabf6d
refactor: add helper for async move
kosiew Jun 4, 2025
010b4c6
Revert "refactor: add helper for async move"
kosiew Jun 4, 2025
88a15ee
move py_err_to_datafusion_err to errors.rs
kosiew Jun 4, 2025
f6a4ea4
add create_csv_read_options
kosiew Jun 4, 2025
14f9dd3
fix
kosiew Jun 4, 2025
40b4345
create_csv_read_options -> PyDataFusionResult
kosiew Jun 4, 2025
bcc4f81
revert to before create_csv_read_options
kosiew Jun 4, 2025
8b0e2e1
refactor: simplify file compression type parsing in PySessionContext
kosiew Jun 4, 2025
b1e67df
fix: parse_compression_type once only
kosiew Jun 4, 2025
393e7ca
add create_ndjson_read_options
kosiew Jun 4, 2025
268a855
refactor comment for clarity in wait_for_future function
kosiew Jun 4, 2025
bf91cf8
refactor wait_for_future to avoid spawn
kosiew Jun 4, 2025
c3d808c
remove unused py_err_to_datafusion_err function
kosiew Jun 4, 2025
ccb4a5c
add comment to clarify error handling in next method of PyRecordBatch…
kosiew Jun 4, 2025
ea3673b
handle error from wait_for_future in PySubstraitSerializer
kosiew Jun 4, 2025
8c50e7f
clarify comment on future pinning in wait_for_future function
kosiew Jun 4, 2025
aead2a0
refactor wait_for_future to use Duration for signal check interval
kosiew Jun 4, 2025
84f31bb
handle error from wait_for_future in count method of PyDataFrame
kosiew Jun 4, 2025
e2f1e1c
fix ruff errors
kosiew Jun 4, 2025
d5cf16d
fix clippy errors
kosiew Jun 4, 2025
48e4f7c
remove unused get_and_enter_tokio_runtime function and simplify wait_…
kosiew Jun 4, 2025
00f8041
Refactor async handling in PySessionContext and PyDataFrame
kosiew Jun 5, 2025
0179619
Organize imports in utils.rs for improved readability
kosiew Jun 5, 2025
bc6f5a6
map_err instead of panic
kosiew Jun 13, 2025
42ab687
Fix error handling in async stream execution for PySessionContext and…
kosiew Jun 14, 2025
79ebe1e
Merge branch 'main' into interrupt-1136
kosiew Jun 16, 2025
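
Taken together, these commits replace the old blocking wait with a signal-aware polling loop: the future is driven on a shared Tokio runtime in short slices, and between slices the code asks CPython whether a signal such as Ctrl-C is pending. A minimal sketch of the idea follows, assuming pyo3's Python::check_signals and tokio's timeout; the interval value, the runtime accessor, and the exact signature are assumptions, not the PR's literal code.

use std::future::Future;
use std::sync::OnceLock;
use std::time::Duration;

use pyo3::prelude::*;
use tokio::runtime::Runtime;

// Assumed value; the PR likewise introduces a named const for the interval.
const SIGNAL_CHECK_INTERVAL: Duration = Duration::from_millis(100);

// Assumed accessor: a lazily initialised shared runtime. Runtime::new()
// enables both IO and timers, matching "enable IO in Tokio runtime with
// time support" above.
fn get_tokio_runtime() -> &'static Runtime {
    static RT: OnceLock<Runtime> = OnceLock::new();
    RT.get_or_init(|| Runtime::new().expect("failed to build Tokio runtime"))
}

/// Drive `fut` to completion, waking at a fixed interval so Python can
/// deliver pending signals. On Ctrl-C, check_signals() returns
/// Err(KeyboardInterrupt); the future is dropped and the query cancelled.
pub fn wait_for_future<F: Future>(py: Python<'_>, fut: F) -> PyResult<F::Output> {
    let runtime = get_tokio_runtime();
    let mut fut = Box::pin(fut);
    loop {
        match runtime.block_on(tokio::time::timeout(SIGNAL_CHECK_INTERVAL, &mut fut)) {
            Ok(output) => return Ok(output),      // future finished
            Err(_elapsed) => py.check_signals()?, // raises KeyboardInterrupt if pending
        }
    }
}

Because the helper now returns the future's own output wrapped in a PyResult, call sites whose future already yields a Result unwrap two layers, which is why the wait_for_future(py, ...)?? pattern recurs throughout the diff below. A production version would likely also release the GIL while blocking (py.allow_threads) so other Python threads keep running.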
121 changes: 121 additions & 0 deletions python/tests/test_dataframe.py
@@ -14,9 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+ import ctypes
import datetime
import os
import re
+ import threading
+ import time
from typing import Any

import pyarrow as pa
@@ -2060,3 +2063,121 @@ def test_fill_null_all_null_column(ctx):
    # Check that all nulls were filled
    result = filled_df.collect()[0]
    assert result.column(1).to_pylist() == ["filled", "filled", "filled"]


def test_collect_interrupted():
    """Test that a long-running query can be interrupted with Ctrl-C.

    This test simulates a Ctrl-C keyboard interrupt by raising a KeyboardInterrupt
    exception in the main thread during a long-running query execution.
    """
    # Create a context and a DataFrame with a query that will run for a while
    ctx = SessionContext()

    # Create a recursive computation that will run for some time
    batches = []
    for i in range(10):
        batch = pa.RecordBatch.from_arrays(
            [
                pa.array(list(range(i * 1000, (i + 1) * 1000))),
                pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
            ],
            names=["a", "b"],
        )
        batches.append(batch)

    # Register tables
    ctx.register_record_batches("t1", [batches])
    ctx.register_record_batches("t2", [batches])

    # Create a large join operation that will take time to process
    df = ctx.sql("""
    WITH t1_expanded AS (
        SELECT
            a,
            b,
            CAST(a AS DOUBLE) / 1.5 AS c,
            CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
        FROM t1
        CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
    ),
    t2_expanded AS (
        SELECT
            a,
            b,
            CAST(a AS DOUBLE) * 2.5 AS e,
            CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
        FROM t2
        CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
    )
    SELECT
        t1.a, t1.b, t1.c, t1.d,
        t2.a AS a2, t2.b AS b2, t2.e, t2.f
    FROM t1_expanded t1
    JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
    WHERE t1.a > 100 AND t2.a > 100
    """)

    # Flag to track if the query was interrupted
    interrupted = False
    interrupt_error = None
    main_thread = threading.main_thread()

    # Shared flag to indicate query execution has started
    query_started = threading.Event()
    max_wait_time = 5.0  # Maximum wait time in seconds

    # This function will be run in a separate thread and will raise
    # KeyboardInterrupt in the main thread
    def trigger_interrupt():
        """Poll for query start, then raise KeyboardInterrupt in the main thread"""
        # Poll for query to start with small sleep intervals
        start_time = time.time()
        while not query_started.is_set():
            time.sleep(0.1)  # Small sleep between checks
            if time.time() - start_time > max_wait_time:
                msg = f"Query did not start within {max_wait_time} seconds"
                raise RuntimeError(msg)

        # Check if thread ID is available
        thread_id = main_thread.ident
        if thread_id is None:
            msg = "Cannot get main thread ID"
            raise RuntimeError(msg)

        # Use ctypes to raise exception in main thread
        exception = ctypes.py_object(KeyboardInterrupt)
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            ctypes.c_long(thread_id), exception
        )
        if res != 1:
            # If res is 0, the thread ID was invalid
            # If res > 1, we modified multiple threads
            ctypes.pythonapi.PyThreadState_SetAsyncExc(
                ctypes.c_long(thread_id), ctypes.py_object(0)
            )
            msg = "Failed to raise KeyboardInterrupt in main thread"
            raise RuntimeError(msg)

    # Start a thread to trigger the interrupt
    interrupt_thread = threading.Thread(target=trigger_interrupt)
    # we mark as daemon so the test process can exit even if this thread doesn't finish
    interrupt_thread.daemon = True
    interrupt_thread.start()

    # Execute the query and expect it to be interrupted
    try:
        # Signal that we're about to start the query
        query_started.set()
        df.collect()
    except KeyboardInterrupt:
        interrupted = True
    except Exception as e:
        interrupt_error = e

    # Assert that the query was interrupted properly
    if not interrupted:
        pytest.fail(f"Query was not interrupted; got error: {interrupt_error}")

    # Make sure the interrupt thread has finished
    interrupt_thread.join(timeout=1.0)
2 changes: 1 addition & 1 deletion src/catalog.rs
@@ -97,7 +97,7 @@ impl PyDatabase {
}

fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> {
- if let Some(table) = wait_for_future(py, self.database.table(name))? {
+ if let Some(table) = wait_for_future(py, self.database.table(name))?? {
Ok(PyTable::new(table))
} else {
Err(PyDataFusionError::Common(format!(
52 changes: 31 additions & 21 deletions src/context.rs
@@ -34,7 +34,7 @@ use pyo3::prelude::*;
use crate::catalog::{PyCatalog, PyTable};
use crate::dataframe::PyDataFrame;
use crate::dataset::Dataset;
- use crate::errors::{py_datafusion_err, PyDataFusionResult};
+ use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
use crate::expr::sort_expr::PySortExpr;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
@@ -375,7 +375,7 @@ impl PySessionContext {
None => {
let state = self.ctx.state();
let schema = options.infer_schema(&state, &table_path);
- wait_for_future(py, schema)?
+ wait_for_future(py, schema)??
}
};
let config = ListingTableConfig::new(table_path)
@@ -400,7 +400,7 @@
/// Returns a PyDataFrame whose plan corresponds to the SQL statement.
pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
let result = self.ctx.sql(query);
- let df = wait_for_future(py, result)?;
+ let df = wait_for_future(py, result)??;
Ok(PyDataFrame::new(df))
}

@@ -417,7 +417,7 @@
SQLOptions::new()
};
let result = self.ctx.sql_with_options(query, options);
- let df = wait_for_future(py, result)?;
+ let df = wait_for_future(py, result)??;
Ok(PyDataFrame::new(df))
}

@@ -451,7 +451,7 @@ impl PySessionContext {

self.ctx.register_table(&*table_name, Arc::new(table))?;

- let table = wait_for_future(py, self._table(&table_name))?;
+ let table = wait_for_future(py, self._table(&table_name))??;

let df = PyDataFrame::new(table);
Ok(df)
@@ -650,7 +650,7 @@ impl PySessionContext {
.collect();

let result = self.ctx.register_parquet(name, path, options);
- wait_for_future(py, result)?;
+ wait_for_future(py, result)??;
Ok(())
}

@@ -693,11 +693,11 @@ impl PySessionContext {
if path.is_instance_of::<PyList>() {
let paths = path.extract::<Vec<String>>()?;
let result = self.register_csv_from_multiple_paths(name, paths, options);
- wait_for_future(py, result)?;
+ wait_for_future(py, result)??;
} else {
let path = path.extract::<String>()?;
let result = self.ctx.register_csv(name, &path, options);
- wait_for_future(py, result)?;
+ wait_for_future(py, result)??;
}

Ok(())
@@ -734,7 +734,7 @@ impl PySessionContext {
options.schema = schema.as_ref().map(|x| &x.0);

let result = self.ctx.register_json(name, path, options);
- wait_for_future(py, result)?;
+ wait_for_future(py, result)??;

Ok(())
}
@@ -764,7 +764,7 @@ impl PySessionContext {
options.schema = schema.as_ref().map(|x| &x.0);

let result = self.ctx.register_avro(name, path, options);
- wait_for_future(py, result)?;
+ wait_for_future(py, result)??;

Ok(())
}
@@ -825,9 +825,19 @@ impl PySessionContext {
}

pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> {
- let x = wait_for_future(py, self.ctx.table(name))
+ let res = wait_for_future(py, self.ctx.table(name))
.map_err(|e| PyKeyError::new_err(e.to_string()))?;
- Ok(PyDataFrame::new(x))
+ match res {
+ Ok(df) => Ok(PyDataFrame::new(df)),
+ Err(e) => {
+ if let datafusion::error::DataFusionError::Plan(msg) = &e {
+ if msg.contains("No table named") {
+ return Err(PyKeyError::new_err(msg.to_string()));
+ }
+ }
+ Err(py_datafusion_err(e))
+ }
+ }
}

pub fn table_exist(&self, name: &str) -> PyDataFusionResult<bool> {
@@ -865,10 +875,10 @@ impl PySessionContext {
let df = if let Some(schema) = schema {
options.schema = Some(&schema.0);
let result = self.ctx.read_json(path, options);
- wait_for_future(py, result)?
+ wait_for_future(py, result)??
} else {
let result = self.ctx.read_json(path, options);
- wait_for_future(py, result)?
+ wait_for_future(py, result)??
};
Ok(PyDataFrame::new(df))
}
@@ -915,12 +925,12 @@ impl PySessionContext {
let paths = path.extract::<Vec<String>>()?;
let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
let result = self.ctx.read_csv(paths, options);
- let df = PyDataFrame::new(wait_for_future(py, result)?);
+ let df = PyDataFrame::new(wait_for_future(py, result)??);
Ok(df)
} else {
let path = path.extract::<String>()?;
let result = self.ctx.read_csv(path, options);
- let df = PyDataFrame::new(wait_for_future(py, result)?);
+ let df = PyDataFrame::new(wait_for_future(py, result)??);
Ok(df)
}
}
@@ -958,7 +968,7 @@ impl PySessionContext {
.collect();

let result = self.ctx.read_parquet(path, options);
- let df = PyDataFrame::new(wait_for_future(py, result)?);
+ let df = PyDataFrame::new(wait_for_future(py, result)??);
Ok(df)
}

@@ -978,10 +988,10 @@
let df = if let Some(schema) = schema {
options.schema = Some(&schema.0);
let read_future = self.ctx.read_avro(path, options);
- wait_for_future(py, read_future)?
+ wait_for_future(py, read_future)??
} else {
let read_future = self.ctx.read_avro(path, options);
- wait_for_future(py, read_future)?
+ wait_for_future(py, read_future)??
};
Ok(PyDataFrame::new(df))
}
@@ -1021,8 +1031,8 @@ impl PySessionContext {
let plan = plan.plan.clone();
let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
- let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
- Ok(PyRecordBatchStream::new(stream?))
+ let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
+ Ok(PyRecordBatchStream::new(stream))
}
}
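
The triple ? above peels three layers rather than two: rt.spawn adds a tokio JoinError (mapped to a DataFusion error via to_datafusion_err) between the signal check and the stream's own result. A toy illustration of the mechanics, with String standing in for all three error types:

// Hypothetical demo only: outer layer ~ PyErr from the signal check,
// middle ~ the mapped JoinError, inner ~ the plan's own execution error.
fn three_layers() -> Result<u32, String> {
    let nested: Result<Result<Result<u32, String>, String>, String> = Ok(Ok(Ok(42)));
    let value = nested???; // each `?` strips one error layer
    Ok(value)
}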
