Skip to content

Commit c7ea90d

Browse files
Pyo3 Bound<'py, T> api (#734)
* remove gil-refs feature from pyo3
* migrate module instantiation to Bound api
* migrate utils.rs to Bound api
* migrate config.rs to Bound api
* migrate context.rs to Bound api
* migrate udaf.rs to Bound api
* migrate pyarrow_filter_expression to Bound api
* migrate dataframe.rs to Bound api
* migrate dataset and dataset_exec to Bound api
* migrate substrait.rs to Bound api
1 parent 1f49d46 commit c7ea90d

15 files changed

+136
-115
lines changed

Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ substrait = ["dep:datafusion-substrait"]
3636
[dependencies]
3737
tokio = { version = "1.35", features = ["macros", "rt", "rt-multi-thread", "sync"] }
3838
rand = "0.8"
39-
pyo3 = { version = "0.21", features = ["extension-module", "abi3", "abi3-py38", "gil-refs"] }
39+
pyo3 = { version = "0.21", features = ["extension-module", "abi3", "abi3-py38"] }
4040
arrow = { version = "52", feature = ["pyarrow"] }
4141
datafusion = { version = "39.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
4242
datafusion-common = { version = "39.0.0", features = ["pyarrow"] }
@@ -67,3 +67,4 @@ crate-type = ["cdylib", "rlib"]
6767
[profile.release]
6868
lto = true
6969
codegen-units = 1
70+

src/common.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ pub mod function;
2323
pub mod schema;
2424

2525
/// Initializes the `common` module to match the pattern of `datafusion-common` https://docs.rs/datafusion-common/18.0.0/datafusion_common/index.html
26-
pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
26+
pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
2727
m.add_class::<df_schema::PyDFSchema>()?;
2828
m.add_class::<data_type::PyDataType>()?;
2929
m.add_class::<data_type::DataTypeMap>()?;

src/config.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ impl PyConfig {
6565

6666
/// Get all configuration options
6767
pub fn get_all(&mut self, py: Python) -> PyResult<PyObject> {
68-
let dict = PyDict::new(py);
68+
let dict = PyDict::new_bound(py);
6969
let options = self.config.to_owned();
7070
for entry in options.entries() {
7171
dict.set_item(entry.key, entry.value.clone().into_py(py))?;

src/context.rs

+14-9
Original file line numberDiff line numberDiff line change
@@ -291,11 +291,11 @@ impl PySessionContext {
291291
pub fn register_object_store(
292292
&mut self,
293293
scheme: &str,
294-
store: &PyAny,
294+
store: &Bound<'_, PyAny>,
295295
host: Option<&str>,
296296
) -> PyResult<()> {
297297
let res: Result<(Arc<dyn ObjectStore>, String), PyErr> =
298-
match StorageContexts::extract(store) {
298+
match StorageContexts::extract_bound(store) {
299299
Ok(store) => match store {
300300
StorageContexts::AmazonS3(s3) => Ok((s3.inner, s3.bucket_name)),
301301
StorageContexts::GoogleCloudStorage(gcs) => Ok((gcs.inner, gcs.bucket_name)),
@@ -443,8 +443,8 @@ impl PySessionContext {
443443
) -> PyResult<PyDataFrame> {
444444
Python::with_gil(|py| {
445445
// Instantiate pyarrow Table object & convert to Arrow Table
446-
let table_class = py.import("pyarrow")?.getattr("Table")?;
447-
let args = PyTuple::new(py, &[data]);
446+
let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
447+
let args = PyTuple::new_bound(py, &[data]);
448448
let table = table_class.call_method1("from_pylist", args)?.into();
449449

450450
// Convert Arrow Table to datafusion DataFrame
@@ -463,8 +463,8 @@ impl PySessionContext {
463463
) -> PyResult<PyDataFrame> {
464464
Python::with_gil(|py| {
465465
// Instantiate pyarrow Table object & convert to Arrow Table
466-
let table_class = py.import("pyarrow")?.getattr("Table")?;
467-
let args = PyTuple::new(py, &[data]);
466+
let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
467+
let args = PyTuple::new_bound(py, &[data]);
468468
let table = table_class.call_method1("from_pydict", args)?.into();
469469

470470
// Convert Arrow Table to datafusion DataFrame
@@ -507,8 +507,8 @@ impl PySessionContext {
507507
) -> PyResult<PyDataFrame> {
508508
Python::with_gil(|py| {
509509
// Instantiate pyarrow Table object & convert to Arrow Table
510-
let table_class = py.import("pyarrow")?.getattr("Table")?;
511-
let args = PyTuple::new(py, &[data]);
510+
let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
511+
let args = PyTuple::new_bound(py, &[data]);
512512
let table = table_class.call_method1("from_pandas", args)?.into();
513513

514514
// Convert Arrow Table to datafusion DataFrame
@@ -710,7 +710,12 @@ impl PySessionContext {
710710
}
711711

712712
// Registers a PyArrow.Dataset
713-
pub fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> PyResult<()> {
713+
pub fn register_dataset(
714+
&self,
715+
name: &str,
716+
dataset: &Bound<'_, PyAny>,
717+
py: Python,
718+
) -> PyResult<()> {
714719
let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
715720

716721
self.ctx

src/dataframe.rs

+41-26
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use datafusion::prelude::*;
2828
use datafusion_common::UnnestOptions;
2929
use pyo3::exceptions::{PyTypeError, PyValueError};
3030
use pyo3::prelude::*;
31+
use pyo3::pybacked::PyBackedStr;
3132
use pyo3::types::PyTuple;
3233
use tokio::task::JoinHandle;
3334

@@ -56,23 +57,25 @@ impl PyDataFrame {
5657

5758
#[pymethods]
5859
impl PyDataFrame {
59-
fn __getitem__(&self, key: PyObject) -> PyResult<Self> {
60-
Python::with_gil(|py| {
61-
if let Ok(key) = key.extract::<&str>(py) {
62-
self.select_columns(vec![key])
63-
} else if let Ok(tuple) = key.extract::<&PyTuple>(py) {
64-
let keys = tuple
65-
.iter()
66-
.map(|item| item.extract::<&str>())
67-
.collect::<PyResult<Vec<&str>>>()?;
68-
self.select_columns(keys)
69-
} else if let Ok(keys) = key.extract::<Vec<&str>>(py) {
70-
self.select_columns(keys)
71-
} else {
72-
let message = "DataFrame can only be indexed by string index or indices";
73-
Err(PyTypeError::new_err(message))
74-
}
75-
})
60+
/// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]`
61+
fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyResult<Self> {
62+
if let Ok(key) = key.extract::<PyBackedStr>() {
63+
// df[col]
64+
self.select_columns(vec![key])
65+
} else if let Ok(tuple) = key.extract::<&PyTuple>() {
66+
// df[col1, col2, col3]
67+
let keys = tuple
68+
.iter()
69+
.map(|item| item.extract::<PyBackedStr>())
70+
.collect::<PyResult<Vec<PyBackedStr>>>()?;
71+
self.select_columns(keys)
72+
} else if let Ok(keys) = key.extract::<Vec<PyBackedStr>>() {
73+
// df[[col1, col2, col3]]
74+
self.select_columns(keys)
75+
} else {
76+
let message = "DataFrame can only be indexed by string index or indices";
77+
Err(PyTypeError::new_err(message))
78+
}
7679
}
7780

7881
fn __repr__(&self, py: Python) -> PyResult<String> {
@@ -98,7 +101,8 @@ impl PyDataFrame {
98101
}
99102

100103
#[pyo3(signature = (*args))]
101-
fn select_columns(&self, args: Vec<&str>) -> PyResult<Self> {
104+
fn select_columns(&self, args: Vec<PyBackedStr>) -> PyResult<Self> {
105+
let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
102106
let df = self.df.as_ref().clone().select_columns(&args)?;
103107
Ok(Self::new(df))
104108
}
@@ -194,7 +198,7 @@ impl PyDataFrame {
194198
fn join(
195199
&self,
196200
right: PyDataFrame,
197-
join_keys: (Vec<&str>, Vec<&str>),
201+
join_keys: (Vec<PyBackedStr>, Vec<PyBackedStr>),
198202
how: &str,
199203
) -> PyResult<Self> {
200204
let join_type = match how {
@@ -212,11 +216,22 @@ impl PyDataFrame {
212216
}
213217
};
214218

219+
let left_keys = join_keys
220+
.0
221+
.iter()
222+
.map(|s| s.as_ref())
223+
.collect::<Vec<&str>>();
224+
let right_keys = join_keys
225+
.1
226+
.iter()
227+
.map(|s| s.as_ref())
228+
.collect::<Vec<&str>>();
229+
215230
let df = self.df.as_ref().clone().join(
216231
right.df.as_ref().clone(),
217232
join_type,
218-
&join_keys.0,
219-
&join_keys.1,
233+
&left_keys,
234+
&right_keys,
220235
None,
221236
)?;
222237
Ok(Self::new(df))
@@ -414,8 +429,8 @@ impl PyDataFrame {
414429

415430
Python::with_gil(|py| {
416431
// Instantiate pyarrow Table object and use its from_batches method
417-
let table_class = py.import("pyarrow")?.getattr("Table")?;
418-
let args = PyTuple::new(py, &[batches, schema]);
432+
let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
433+
let args = PyTuple::new_bound(py, &[batches, schema]);
419434
let table: PyObject = table_class.call_method1("from_batches", args)?.into();
420435
Ok(table)
421436
})
@@ -489,8 +504,8 @@ impl PyDataFrame {
489504
let table = self.to_arrow_table(py)?;
490505

491506
Python::with_gil(|py| {
492-
let dataframe = py.import("polars")?.getattr("DataFrame")?;
493-
let args = PyTuple::new(py, &[table]);
507+
let dataframe = py.import_bound("polars")?.getattr("DataFrame")?;
508+
let args = PyTuple::new_bound(py, &[table]);
494509
let result: PyObject = dataframe.call1(args)?.into();
495510
Ok(result)
496511
})
@@ -514,7 +529,7 @@ fn print_dataframe(py: Python, df: DataFrame) -> PyResult<()> {
514529

515530
// Import the Python 'builtins' module to access the print function
516531
// Note that println! does not print to the Python debug console and is not visible in notebooks for instance
517-
let print = py.import("builtins")?.getattr("print")?;
532+
let print = py.import_bound("builtins")?.getattr("print")?;
518533
print.call1((result,))?;
519534
Ok(())
520535
}

src/dataset.rs

+7-6
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,14 @@ pub(crate) struct Dataset {
4646

4747
impl Dataset {
4848
// Creates a Python PyArrow.Dataset
49-
pub fn new(dataset: &PyAny, py: Python) -> PyResult<Self> {
49+
pub fn new(dataset: &Bound<'_, PyAny>, py: Python) -> PyResult<Self> {
5050
// Ensure that we were passed an instance of pyarrow.dataset.Dataset
51-
let ds = PyModule::import(py, "pyarrow.dataset")?;
52-
let ds_type: &PyType = ds.getattr("Dataset")?.downcast()?;
51+
let ds = PyModule::import_bound(py, "pyarrow.dataset")?;
52+
let ds_attr = ds.getattr("Dataset")?;
53+
let ds_type = ds_attr.downcast::<PyType>()?;
5354
if dataset.is_instance(ds_type)? {
5455
Ok(Dataset {
55-
dataset: dataset.into(),
56+
dataset: dataset.clone().unbind(),
5657
})
5758
} else {
5859
Err(PyValueError::new_err(
@@ -73,7 +74,7 @@ impl TableProvider for Dataset {
7374
/// Get a reference to the schema for this table
7475
fn schema(&self) -> SchemaRef {
7576
Python::with_gil(|py| {
76-
let dataset = self.dataset.as_ref(py);
77+
let dataset = self.dataset.bind(py);
7778
// This can panic but since we checked that self.dataset is a pyarrow.dataset.Dataset it should never
7879
Arc::new(
7980
dataset
@@ -108,7 +109,7 @@ impl TableProvider for Dataset {
108109
) -> DFResult<Arc<dyn ExecutionPlan>> {
109110
Python::with_gil(|py| {
110111
let plan: Arc<dyn ExecutionPlan> = Arc::new(
111-
DatasetExec::new(py, self.dataset.as_ref(py), projection.cloned(), filters)
112+
DatasetExec::new(py, self.dataset.bind(py), projection.cloned(), filters)
112113
.map_err(|err| DataFusionError::External(Box::new(err)))?,
113114
);
114115
Ok(plan)

src/dataset_exec.rs

+17-19
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ impl Iterator for PyArrowBatchesAdapter {
5353

5454
fn next(&mut self) -> Option<Self::Item> {
5555
Python::with_gil(|py| {
56-
let mut batches: &PyIterator = self.batches.as_ref(py);
56+
let mut batches = self.batches.clone().into_bound(py);
5757
Some(
5858
batches
5959
.next()?
@@ -79,7 +79,7 @@ pub(crate) struct DatasetExec {
7979
impl DatasetExec {
8080
pub fn new(
8181
py: Python,
82-
dataset: &PyAny,
82+
dataset: &Bound<'_, PyAny>,
8383
projection: Option<Vec<usize>>,
8484
filters: &[Expr],
8585
) -> Result<Self, DataFusionError> {
@@ -103,15 +103,15 @@ impl DatasetExec {
103103
})
104104
.transpose()?;
105105

106-
let kwargs = PyDict::new(py);
106+
let kwargs = PyDict::new_bound(py);
107107

108108
kwargs.set_item("columns", columns.clone())?;
109109
kwargs.set_item(
110110
"filter",
111111
filter_expr.as_ref().map(|expr| expr.clone_ref(py)),
112112
)?;
113113

114-
let scanner = dataset.call_method("scanner", (), Some(kwargs))?;
114+
let scanner = dataset.call_method("scanner", (), Some(&kwargs))?;
115115

116116
let schema = Arc::new(
117117
scanner
@@ -120,19 +120,17 @@ impl DatasetExec {
120120
.0,
121121
);
122122

123-
let builtins = Python::import(py, "builtins")?;
123+
let builtins = Python::import_bound(py, "builtins")?;
124124
let pylist = builtins.getattr("list")?;
125125

126126
// Get the fragments or partitions of the dataset
127-
let fragments_iterator: &PyAny = dataset.call_method1(
127+
let fragments_iterator: Bound<'_, PyAny> = dataset.call_method1(
128128
"get_fragments",
129129
(filter_expr.as_ref().map(|expr| expr.clone_ref(py)),),
130130
)?;
131131

132-
let fragments: &PyList = pylist
133-
.call1((fragments_iterator,))?
134-
.downcast()
135-
.map_err(PyErr::from)?;
132+
let fragments_iter = pylist.call1((fragments_iterator,))?;
133+
let fragments = fragments_iter.downcast::<PyList>().map_err(PyErr::from)?;
136134

137135
let projected_statistics = Statistics::new_unknown(&schema);
138136
let plan_properties = datafusion::physical_plan::PlanProperties::new(
@@ -142,9 +140,9 @@ impl DatasetExec {
142140
);
143141

144142
Ok(DatasetExec {
145-
dataset: dataset.into(),
143+
dataset: dataset.clone().unbind(),
146144
schema,
147-
fragments: fragments.into(),
145+
fragments: fragments.clone().unbind(),
148146
columns,
149147
filter_expr,
150148
projected_statistics,
@@ -183,8 +181,8 @@ impl ExecutionPlan for DatasetExec {
183181
) -> DFResult<SendableRecordBatchStream> {
184182
let batch_size = context.session_config().batch_size();
185183
Python::with_gil(|py| {
186-
let dataset = self.dataset.as_ref(py);
187-
let fragments = self.fragments.as_ref(py);
184+
let dataset = self.dataset.bind(py);
185+
let fragments = self.fragments.bind(py);
188186
let fragment = fragments
189187
.get_item(partition)
190188
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
@@ -193,7 +191,7 @@ impl ExecutionPlan for DatasetExec {
193191
let dataset_schema = dataset
194192
.getattr("schema")
195193
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
196-
let kwargs = PyDict::new(py);
194+
let kwargs = PyDict::new_bound(py);
197195
kwargs
198196
.set_item("columns", self.columns.clone())
199197
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
@@ -207,15 +205,15 @@ impl ExecutionPlan for DatasetExec {
207205
.set_item("batch_size", batch_size)
208206
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
209207
let scanner = fragment
210-
.call_method("scanner", (dataset_schema,), Some(kwargs))
208+
.call_method("scanner", (dataset_schema,), Some(&kwargs))
211209
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
212210
let schema: SchemaRef = Arc::new(
213211
scanner
214212
.getattr("projected_schema")
215213
.and_then(|schema| Ok(schema.extract::<PyArrowType<_>>()?.0))
216214
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?,
217215
);
218-
let record_batches: &PyIterator = scanner
216+
let record_batches: Bound<'_, PyIterator> = scanner
219217
.call_method0("to_batches")
220218
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?
221219
.iter()
@@ -264,7 +262,7 @@ impl ExecutionPlanProperties for DatasetExec {
264262
impl DisplayAs for DatasetExec {
265263
fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
266264
Python::with_gil(|py| {
267-
let number_of_fragments = self.fragments.as_ref(py).len();
265+
let number_of_fragments = self.fragments.bind(py).len();
268266
match t {
269267
DisplayFormatType::Default | DisplayFormatType::Verbose => {
270268
let projected_columns: Vec<String> = self
@@ -274,7 +272,7 @@ impl DisplayAs for DatasetExec {
274272
.map(|x| x.name().to_owned())
275273
.collect();
276274
if let Some(filter_expr) = &self.filter_expr {
277-
let filter_expr = filter_expr.as_ref(py).str().or(Err(std::fmt::Error))?;
275+
let filter_expr = filter_expr.bind(py).str().or(Err(std::fmt::Error))?;
278276
write!(
279277
f,
280278
"DatasetExec: number_of_fragments={}, filter_expr={}, projection=[{}]",

src/expr.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ impl PyExpr {
553553
}
554554

555555
/// Initializes the `expr` module to match the pattern of `datafusion-expr` https://docs.rs/datafusion-expr/latest/datafusion_expr/
556-
pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
556+
pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
557557
m.add_class::<PyExpr>()?;
558558
m.add_class::<PyColumn>()?;
559559
m.add_class::<PyLiteral>()?;

0 commit comments

Comments
 (0)