
Commit aa8aa9c

Python wrapper classes for all user interfaces (#750)
* Expose missing functions to python
* Initial commit for creating wrapper classes and functions for all user-facing python features
* Remove extra level of python path that is no longer required
* Move import to only happen for type checking for hints
* Comment out classes from __all__ in the top level that are not currently exposed
* Add license comments
* Add missing import
* Functions now only has one level of depth
* Applying google docstring formatting
* Addressing PR request to add google formatted docstrings
* Small docstring for ruff
* Linting
* Add docstring format checking to pre-commit stage
* Set explicit return types on UDFs
* Add options of passing either a path or a string
* Switch to google docstring style
* Update unit tests to include registering via path or string
* Add py.typed file
* Resolve deprecation warnings in unit tests
* Add path to unit test
* Expose an option in write_csv to include header and add unit test
* Update write_parquet unit test to include paths or strings
* Add unit test for write_json
* Add unit test for substrait serialization to a file
* Add unit tests for runtime config
* Setting return type to typing_extensions.Self per PR recommendation
* Correcting __next__ to not return None since it will raise an exception instead
* Add optional parameter of decimal places to round and add unit test
* Improve docstrings
* Set default to None instead of empty dict
* User request to allow passing multiple arguments to filter()
* Enhance Expr comparison operators to accept any python value and attempt to convert it to a literal
* Expose overlay and add unit test
* Allow select() to take either str for column names or a full expr
* Update comments on regexp and add unit tests
* Remove TODO markings no longer applicable
* Update udf documentation
* Docstring formatting
* Updating docstring formatting
* Updating docstring formatting
* Updating docstring formatting
* Updating docstring formatting
* Updating docstring formatting
* Updating docstring formatting
* Cleaning up docstring line lengths
* Add pre-commit check of docstring line length
* Do not emit doc entry for __init__ of some classes
* Correct errors on code blocks generated in sphinx
* Resolve conflict with
* Add license info to py.typed
* Clean up some docstring too long errors in CI
* Correct ruff complaint in unit tests
* Temporarily install google test to get clippy to pass
* Adding gmock to build step due to upstream error
* Add type_extensions to conda meta file
* Small comment suggestions from PR
1 parent faa5a3f commit aa8aa9c
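To make the scope of the rework concrete, here is a minimal sketch of the new user-facing API described in the commit message; the CSV path and column names are hypothetical, used only for illustration:

from datafusion import SessionContext, col, lit

ctx = SessionContext()
df = ctx.read_csv("data.csv")  # hypothetical input file

# filter() now accepts multiple predicates, and comparison operators
# convert plain Python values into literals automatically.
df = df.filter(col("a") > 2, col("b") == "passing")

# select() accepts either plain column names or full expressions.
df = df.select("a", col("a") + lit(1))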

40 files changed (+4441, -288 lines)

.github/workflows/build.yml (+4)

@@ -89,6 +89,10 @@ jobs:
           name: python-wheel-license
           path: .
 
+      # To remove once https://github.com/MaterializeInc/rust-protobuf-native/issues/20 is resolved
+      - name: Install gtest
+        uses: MarkusJx/googletest-installer@v1.1
+
       - name: Install Protoc
         uses: arduino/setup-protoc@v1
         with:

.github/workflows/test.yaml (+4)

@@ -55,6 +55,10 @@ jobs:
           version: '3.20.2'
           repo-token: ${{ secrets.GITHUB_TOKEN }}
 
+      # To remove once https://github.com/MaterializeInc/rust-protobuf-native/issues/20 is resolved
+      - name: Install gtest
+        uses: MarkusJx/googletest-installer@v1.1
+
       - name: Setup Python
         uses: actions/setup-python@v5
         with:

benchmarks/db-benchmark/join-datafusion.py (+2, -1)

@@ -74,7 +74,8 @@ def ans_shape(batches):
 ctx = df.SessionContext()
 print(ctx)
 
-# TODO we should be applying projections to these table reads to crete relations of different sizes
+# TODO we should be applying projections to these table reads to create relations
+# of different sizes
 
 x_data = pacsv.read_csv(
     src_jn_x, convert_options=pacsv.ConvertOptions(auto_dict_encode=True)

conda/recipes/meta.yaml (+1)

@@ -51,6 +51,7 @@ requirements:
   run:
     - python
     - pyarrow >=11.0.0
+    - typing_extensions
 
 test:
   imports:

docs/source/api/functions.rst (+1, -1)

@@ -24,4 +24,4 @@ Functions
 .. autosummary::
    :toctree: ../generated/
 
-   functions.functions
+   functions

docs/source/conf.py (+21)

@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+"""Documenation generation."""
+
 # Configuration file for the Sphinx documentation builder.
 #
 # This file only contains a selection of the most common options. For a full
@@ -78,6 +80,25 @@
 
 autosummary_generate = True
 
+
+def autodoc_skip_member(app, what, name, obj, skip, options):
+    exclude_functions = "__init__"
+    exclude_classes = ("Expr", "DataFrame")
+
+    class_name = ""
+    if hasattr(obj, "__qualname__"):
+        if obj.__qualname__ is not None:
+            class_name = obj.__qualname__.split(".")[0]
+
+    should_exclude = name in exclude_functions and class_name in exclude_classes
+
+    return True if should_exclude else None
+
+
+def setup(app):
+    app.connect("autodoc-skip-member", autodoc_skip_member)
+
+
 # -- Options for HTML output -------------------------------------------------
 
 # The theme to use for HTML and HTML Help pages. See the documentation for
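For reference on the hook's return values: Sphinx's autodoc-skip-member event treats True as "skip this member" and None as "defer to the default behavior", so the handler above hides __init__ only on Expr and DataFrame while leaving every other membership decision to autodoc.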

examples/substrait.py (+5, -10)

@@ -18,16 +18,13 @@
 from datafusion import SessionContext
 from datafusion import substrait as ss
 
-
 # Create a DataFusion context
 ctx = SessionContext()
 
 # Register table with context
 ctx.register_csv("aggregate_test_data", "./testing/data/csv/aggregate_test_100.csv")
 
-substrait_plan = ss.substrait.serde.serialize_to_plan(
-    "SELECT * FROM aggregate_test_data", ctx
-)
+substrait_plan = ss.Serde.serialize_to_plan("SELECT * FROM aggregate_test_data", ctx)
 # type(substrait_plan) -> <class 'datafusion.substrait.plan'>
 
 # Encode it to bytes
@@ -38,17 +35,15 @@
 # Alternative serialization approaches
 # type(substrait_bytes) -> <class 'bytes'>, at this point the bytes can be distributed to file, network, etc safely
 # where they could subsequently be deserialized on the receiving end.
-substrait_bytes = ss.substrait.serde.serialize_bytes(
-    "SELECT * FROM aggregate_test_data", ctx
-)
+substrait_bytes = ss.Serde.serialize_bytes("SELECT * FROM aggregate_test_data", ctx)
 
 # Imagine here bytes would be read from network, file, etc ... for example brevity this is omitted and variable is simply reused
 # type(substrait_plan) -> <class 'datafusion.substrait.plan'>
-substrait_plan = ss.substrait.serde.deserialize_bytes(substrait_bytes)
+substrait_plan = ss.Serde.deserialize_bytes(substrait_bytes)
 
 # type(df_logical_plan) -> <class 'substrait.LogicalPlan'>
-df_logical_plan = ss.substrait.consumer.from_substrait_plan(ctx, substrait_plan)
+df_logical_plan = ss.Consumer.from_substrait_plan(ctx, substrait_plan)
 
 # Back to Substrait Plan just for demonstration purposes
 # type(substrait_plan) -> <class 'datafusion.substrait.plan'>
-substrait_plan = ss.substrait.producer.to_substrait_plan(df_logical_plan)
+substrait_plan = ss.Producer.to_substrait_plan(df_logical_plan)
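The commit message also mentions a unit test for substrait serialization to a file and the option of passing either a path or a string. A hedged sketch of that round trip, assuming Serde.serialize and Serde.deserialize accept a file path (the file name is illustrative):

from datafusion import SessionContext
from datafusion import substrait as ss

ctx = SessionContext()
ctx.register_csv("aggregate_test_data", "./testing/data/csv/aggregate_test_100.csv")

# Write the serialized plan to disk; per the commit message, the path
# may be given as either a str or a pathlib.Path ("plan.bin" is made up).
ss.Serde.serialize("SELECT * FROM aggregate_test_data", ctx, "plan.bin")

# Later (or on another machine), read the plan back.
substrait_plan = ss.Serde.deserialize("plan.bin")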

examples/tpch/_tests.py (+3, -2)

@@ -96,8 +96,9 @@ def test_tpch_query_vs_answer_file(query_code: str, answer_file: str):
     module = import_module(query_code)
     df = module.df
 
-    # Treat q17 as a special case. The answer file does not match the spec. Running at
-    # scale factor 1, we have manually verified this result does match the expected value.
+    # Treat q17 as a special case. The answer file does not match the spec.
+    # Running at scale factor 1, we have manually verified this result does
+    # match the expected value.
     if answer_file == "q17":
         return check_q17(df)

pyproject.toml (+18)

@@ -64,3 +64,21 @@ exclude = [".github/**", "ci/**", ".asf.yaml"]
 # Require Cargo.lock is up to date
 locked = true
 features = ["substrait"]
+
+# Enable docstring linting using the google style guide
+[tool.ruff.lint]
+select = ["E4", "E7", "E9", "F", "D", "W"]
+
+[tool.ruff.lint.pydocstyle]
+convention = "google"
+
+[tool.ruff.lint.pycodestyle]
+max-doc-length = 88
+
+# Disable docstring checking for these directories
+[tool.ruff.lint.per-file-ignores]
+"python/datafusion/tests/*" = ["D"]
+"examples/*" = ["D", "W505"]
+"dev/*" = ["D"]
+"benchmarks/*" = ["D", "F"]
+"docs/*" = ["D"]

python/datafusion/__init__.py (+19, -151)

@@ -15,206 +15,74 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from abc import ABCMeta, abstractmethod
-from typing import List
+"""DataFusion python package.
+
+This is a Python library that binds to Apache Arrow in-memory query engine DataFusion.
+See https://datafusion.apache.org/python for more information.
+"""
 
 try:
     import importlib.metadata as importlib_metadata
 except ImportError:
     import importlib_metadata
 
-import pyarrow as pa
-
-from ._internal import (
-    AggregateUDF,
-    Config,
-    DataFrame,
+from .context import (
     SessionContext,
     SessionConfig,
     RuntimeConfig,
-    ScalarUDF,
     SQLOptions,
 )
 
+# The following imports are okay to remain as opaque to the user.
+from ._internal import Config
+
+from .udf import ScalarUDF, AggregateUDF, Accumulator
+
 from .common import (
     DFSchema,
 )
 
+from .dataframe import DataFrame
+
 from .expr import (
-    Alias,
-    Analyze,
     Expr,
-    Filter,
-    Limit,
-    Like,
-    ILike,
-    Projection,
-    SimilarTo,
-    ScalarVariable,
-    Sort,
-    TableScan,
-    Not,
-    IsNotNull,
-    IsTrue,
-    IsFalse,
-    IsUnknown,
-    IsNotTrue,
-    IsNotFalse,
-    IsNotUnknown,
-    Negative,
-    InList,
-    Exists,
-    Subquery,
-    InSubquery,
-    ScalarSubquery,
-    GroupingSet,
-    Placeholder,
-    Case,
-    Cast,
-    TryCast,
-    Between,
-    Explain,
-    CreateMemoryTable,
-    SubqueryAlias,
-    Extension,
-    CreateView,
-    Distinct,
-    DropTable,
-    Repartition,
-    Partitioning,
-    Window,
     WindowFrame,
 )
 
 __version__ = importlib_metadata.version(__name__)
 
 __all__ = [
+    "Accumulator",
     "Config",
     "DataFrame",
     "SessionContext",
     "SessionConfig",
     "SQLOptions",
     "RuntimeConfig",
     "Expr",
-    "AggregateUDF",
     "ScalarUDF",
-    "Window",
     "WindowFrame",
     "column",
     "literal",
-    "TableScan",
-    "Projection",
     "DFSchema",
-    "DFField",
-    "Analyze",
-    "Sort",
-    "Limit",
-    "Filter",
-    "Like",
-    "ILike",
-    "SimilarTo",
-    "ScalarVariable",
-    "Alias",
-    "Not",
-    "IsNotNull",
-    "IsTrue",
-    "IsFalse",
-    "IsUnknown",
-    "IsNotTrue",
-    "IsNotFalse",
-    "IsNotUnknown",
-    "Negative",
-    "ScalarFunction",
-    "BuiltinScalarFunction",
-    "InList",
-    "Exists",
-    "Subquery",
-    "InSubquery",
-    "ScalarSubquery",
-    "GroupingSet",
-    "Placeholder",
-    "Case",
-    "Cast",
-    "TryCast",
-    "Between",
-    "Explain",
-    "SubqueryAlias",
-    "Extension",
-    "CreateMemoryTable",
-    "CreateView",
-    "Distinct",
-    "DropTable",
-    "Repartition",
-    "Partitioning",
 ]
 
 
-class Accumulator(metaclass=ABCMeta):
-    @abstractmethod
-    def state(self) -> List[pa.Scalar]:
-        pass
-
-    @abstractmethod
-    def update(self, values: pa.Array) -> None:
-        pass
-
-    @abstractmethod
-    def merge(self, states: pa.Array) -> None:
-        pass
-
-    @abstractmethod
-    def evaluate(self) -> pa.Scalar:
-        pass
-
-
-def column(value):
+def column(value: str):
+    """Create a column expression."""
    return Expr.column(value)
 
 
 col = column
 
 
 def literal(value):
-    if not isinstance(value, pa.Scalar):
-        value = pa.scalar(value)
+    """Create a literal expression."""
     return Expr.literal(value)
 
 
 lit = literal
 
+udf = ScalarUDF.udf
 
-def udf(func, input_types, return_type, volatility, name=None):
-    """
-    Create a new User Defined Function
-    """
-    if not callable(func):
-        raise TypeError("`func` argument must be callable")
-    if name is None:
-        name = func.__qualname__.lower()
-    return ScalarUDF(
-        name=name,
-        func=func,
-        input_types=input_types,
-        return_type=return_type,
-        volatility=volatility,
-    )
-
-
-def udaf(accum, input_type, return_type, state_type, volatility, name=None):
-    """
-    Create a new User Defined Aggregate Function
-    """
-    if not issubclass(accum, Accumulator):
-        raise TypeError("`accum` must implement the abstract base class Accumulator")
-    if name is None:
-        name = accum.__qualname__.lower()
-    if isinstance(input_type, pa.lib.DataType):
-        input_type = [input_type]
-    return AggregateUDF(
-        name=name,
-        accumulator=accum,
-        input_type=input_type,
-        return_type=return_type,
-        state_type=state_type,
-        volatility=volatility,
-    )
+udaf = AggregateUDF.udaf
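With udf and udaf now aliased to ScalarUDF.udf and AggregateUDF.udaf, usage stays close to the old helpers. A minimal sketch, assuming the signatures carry over from the removed functions; the function bodies and names here are illustrative, not part of the commit:

import pyarrow as pa
import pyarrow.compute as pc
from datafusion import Accumulator, udf, udaf

# Scalar UDF: (func, input_types, return_type, volatility, name=None)
is_null = udf(lambda arr: arr.is_null(), [pa.int64()], pa.bool_(), "stable")

# Aggregate UDF built on the re-exported Accumulator ABC.
class Summation(Accumulator):
    def __init__(self) -> None:
        self._sum = pa.scalar(0.0)

    def state(self) -> list[pa.Scalar]:
        return [self._sum]

    def update(self, values: pa.Array) -> None:
        self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())

    def merge(self, states: pa.Array) -> None:
        self._sum = pa.scalar(self._sum.as_py() + pc.sum(states).as_py())

    def evaluate(self) -> pa.Scalar:
        return self._sum

# Aggregate UDF: (accum, input_type, return_type, state_type, volatility, name=None)
my_sum = udaf(Summation, pa.float64(), pa.float64(), [pa.float64()], "stable")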
