From 864076475b9f42d1bd13e936b52765a03104e033 Mon Sep 17 00:00:00 2001 From: Joris Roovers Date: Fri, 31 Mar 2023 02:33:01 +0200 Subject: [PATCH] Support for SqlMagic.polars_dataframe_kwargs (#325) * Support for SqlMagic.polars_dataframe_kwargs Allows for passing of custom keyword arguments to the Polars DataFrame constructor. Fixes #312 * Improvements for SqlMagic.polars_dataframe_kwargs - Updated docs - Fixed failing tests - Passing polars_dataframe_kwargs as kwargs instead of dict * Black formatting fixes --- CHANGELOG.md | 2 ++ doc/api/configuration.md | 28 ++++++++++++++++++++++++++++ src/sql/magic.py | 10 +++++++++- src/sql/run.py | 8 +++++--- src/tests/test_magic.py | 36 ++++++++++++++++++++++++++++++++---- src/tests/test_run.py | 1 + 6 files changed, 77 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f72ec29c8..3f0c6a823 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ * [Fix] Clearer error when using `--with` with snippets that do not exist (#257) * [Fix] Pytds now automatically compatible * [Doc] SQL keywords autocompletion +* [Feature] Adds `%%config SqlMagic.polars_dataframe_kwargs = {...}` +* [Fix] Jupysql with autopolars crashes when schema cannot be inferred from the first 100 rows ([#312](https://github.com/ploomber/jupysql/issues/312)) ## 0.6.6 (2023-03-16) diff --git a/doc/api/configuration.md b/doc/api/configuration.md index 6e3673325..52f569424 100644 --- a/doc/api/configuration.md +++ b/doc/api/configuration.md @@ -161,6 +161,34 @@ df = %sql SELECT * FROM languages type(df) ``` +## `polars_dataframe_kwargs` + +Default: `{}` + +Polars [DataFrame constructor](https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/index.html) keyword arguments (e.g. 
infer_schema_length, nan_to_null, schema_overrides, etc.) + +```{code-cell} ipython3 +# By default Polars will only look at the first 100 rows to infer schema +# Disable this limit by setting infer_schema_length to None +%config SqlMagic.polars_dataframe_kwargs = { "infer_schema_length": None} + +# Create a table with 101 rows, last row has a string which will cause the +# column type to be inferred as a string (rather than crashing polars) +%sql CREATE TABLE points (x, y); +insert_stmt = "" +for _ in range(100): + insert_stmt += "INSERT INTO points VALUES (1, 2);" +%sql {{insert_stmt}} +%sql INSERT INTO points VALUES (1, "foo"); + + +%sql SELECT * FROM points +``` +To unset: +```{code-cell} ipython3 +%config SqlMagic.polars_dataframe_kwargs = {} +``` + ## `feedback` Default: `True` diff --git a/src/sql/magic.py b/src/sql/magic.py index 76250aa0f..ad1b08c18 100644 --- a/src/sql/magic.py +++ b/src/sql/magic.py @@ -28,7 +28,7 @@ from ploomber_core.dependencies import check_installed from traitlets.config.configurable import Configurable -from traitlets import Bool, Int, Unicode, observe +from traitlets import Bool, Int, Unicode, Dict, observe try: from pandas.core.frame import DataFrame, Series @@ -110,6 +110,14 @@ class SqlMagic(Magics, Configurable): config=True, help="Return Polars DataFrames instead of regular result sets", ) + polars_dataframe_kwargs = Dict( + {}, + config=True, + help=( + "Polars DataFrame constructor keyword arguments" + "(e.g. 
infer_schema_length, nan_to_null, schema_overrides, etc)" + ), + ) column_local_vars = Bool( False, config=True, help="Return data into local variables from column names" ) diff --git a/src/sql/run.py b/src/sql/run.py index 3906f568d..bc565060b 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -184,11 +184,13 @@ def DataFrame(self, payload): return frame @telemetry.log_call("polars-data-frame") - def PolarsDataFrame(self): + def PolarsDataFrame(self, **polars_dataframe_kwargs): "Returns a Polars DataFrame instance built from the result set." import polars as pl - frame = pl.DataFrame((tuple(row) for row in self), schema=self.keys) + frame = pl.DataFrame( + (tuple(row) for row in self), schema=self.keys, **polars_dataframe_kwargs + ) return frame @telemetry.log_call("pie") @@ -444,7 +446,7 @@ def select_df_type(resultset, config): if config.autopandas: return resultset.DataFrame() elif config.autopolars: - return resultset.PolarsDataFrame() + return resultset.PolarsDataFrame(**config.polars_dataframe_kwargs) else: return resultset # returning only last result, intentionally diff --git a/src/tests/test_magic.py b/src/tests/test_magic.py index 578c1b43a..76a57422b 100644 --- a/src/tests/test_magic.py +++ b/src/tests/test_magic.py @@ -7,6 +7,7 @@ from textwrap import dedent from unittest.mock import patch +import polars as pl import pytest from sqlalchemy import create_engine from IPython.core.error import UsageError @@ -277,20 +278,47 @@ def test_autopolars(ip): ip.run_line_magic("config", "SqlMagic.autopolars = True") dframe = runsql(ip, "SELECT * FROM test;") - import polars as pl - assert type(dframe) == pl.DataFrame assert not dframe.is_empty() assert len(dframe.shape) == 2 assert dframe["name"][0] == "foo" +def test_autopolars_infer_schema_length(ip): + """Test for `SqlMagic.polars_dataframe_kwargs = {"infer_schema_length": None}` + Without this config, polars will raise an exception when it cannot infer the + correct schema from the first 100 rows. 
+ """ + # Create a table with 100 rows with a NULL value and one row with a non-NULL value + ip.run_line_magic("config", "SqlMagic.autopolars = True") + sql = ["CREATE TABLE test_autopolars_infer_schema (n INT, name TEXT)"] + for i in range(100): + sql.append(f"INSERT INTO test_autopolars_infer_schema VALUES ({i}, NULL)") + sql.append("INSERT INTO test_autopolars_infer_schema VALUES (100, 'foo')") + runsql(ip, sql) + + # By default, this dataset should raise a ComputeError + with pytest.raises(pl.exceptions.ComputeError): + runsql(ip, "SELECT * FROM test_autopolars_infer_schema;") + + # To avoid this error, pass the `infer_schema_length` argument to polars.DataFrame + line_magic = 'SqlMagic.polars_dataframe_kwargs = {"infer_schema_length": None}' + ip.run_line_magic("config", line_magic) + dframe = runsql(ip, "SELECT * FROM test_autopolars_infer_schema;") + assert dframe.schema == {"n": pl.Int64, "name": pl.Utf8} + + # Assert that if we unset the dataframe kwargs, the error is raised again + ip.run_line_magic("config", "SqlMagic.polars_dataframe_kwargs = {}") + with pytest.raises(pl.exceptions.ComputeError): + runsql(ip, "SELECT * FROM test_autopolars_infer_schema;") + + runsql(ip, "DROP TABLE test_autopolars_infer_schema") + + def test_mutex_autopolars_autopandas(ip): dframe = runsql(ip, "SELECT * FROM test;") assert type(dframe) == ResultSet - import polars as pl - ip.run_line_magic("config", "SqlMagic.autopolars = True") dframe = runsql(ip, "SELECT * FROM test;") assert type(dframe) == pl.DataFrame diff --git a/src/tests/test_run.py b/src/tests/test_run.py index 168948655..f1eaf5598 100644 --- a/src/tests/test_run.py +++ b/src/tests/test_run.py @@ -32,6 +32,7 @@ class Config: autopolars = None autocommit = True feedback = True + polars_dataframe_kwargs = {} return Config