fix(duckdb): thread udf parameters through
cpcloud committed Jan 21, 2025
1 parent 9ef82fa · commit 876590e
Showing 4 changed files with 69 additions and 48 deletions.
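
The tests below exercise the fix through a Python UDF (my_random). For orientation, here is a rough sketch of the kind of UDF registration the commit title refers to; the side_effects keyword, and forwarding it through the decorator, are assumptions based on DuckDB's create_function API rather than something shown in this diff:

import random

import ibis


# Hypothetical sketch: extra keyword arguments on the decorator are meant to be
# forwarded ("threaded through") to the backend's registration call, e.g. DuckDB's
# create_function(..., side_effects=True), so the engine does not cache results.
@ibis.udf.scalar.python(side_effects=True)
def my_random(x: float) -> float:
    # the input is ignored; each call should produce a fresh value
    return random.random()
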
ibis/backends/tests/errors.py (3 changes: 1 addition & 2 deletions)
@@ -124,13 +124,12 @@
from psycopg2.errors import ProgrammingError as PsycoPg2ProgrammingError
from psycopg2.errors import SyntaxError as PsycoPg2SyntaxError
from psycopg2.errors import UndefinedObject as PsycoPg2UndefinedObject
from psycopg2.errors import UniqueViolation as PsycoPg2UniqueViolation
except ImportError:
PsycoPg2SyntaxError = PsycoPg2IndeterminateDatatype = (
PsycoPg2InvalidTextRepresentation
) = PsycoPg2DivisionByZero = PsycoPg2InternalError = PsycoPg2ProgrammingError = (
PsycoPg2OperationalError
) = PsycoPg2UndefinedObject = PsycoPg2ArraySubscriptError = PsycoPg2UniqueViolation = None
) = PsycoPg2UndefinedObject = PsycoPg2ArraySubscriptError = None

try:
from psycopg.errors import ArraySubscriptError as PsycoPgArraySubscriptError
ibis/backends/tests/test_impure.py (106 changes: 61 additions & 45 deletions)
@@ -2,27 +2,20 @@

import sys

import pandas.testing as tm
import pytest

import ibis
import ibis.common.exceptions as com
from ibis import _
from ibis.backends.tests.errors import (
PsycoPg2InternalError,
Py4JJavaError,
PyDruidProgrammingError,
)
from ibis.backends.tests.errors import Py4JJavaError

tm = pytest.importorskip("pandas.testing")

pytestmark = pytest.mark.xdist_group("impure")

no_randoms = [
pytest.mark.notimpl(
["dask", "pandas", "polars"], raises=com.OperationNotDefinedError
),
pytest.mark.notimpl("druid", raises=PyDruidProgrammingError),
pytest.mark.notyet(
"risingwave",
raises=PsycoPg2InternalError,
reason="function random() does not exist",
["polars", "druid", "risingwave"], raises=com.OperationNotDefinedError
),
]

@@ -32,19 +25,16 @@
[
"bigquery",
"clickhouse",
"dask",
"druid",
"exasol",
"impala",
"mssql",
"mysql",
"oracle",
"pandas",
"trino",
"risingwave",
]
),
pytest.mark.notimpl("pyspark", reason="only supports pandas UDFs"),
pytest.mark.notyet(
"flink",
condition=sys.version_info >= (3, 11),
@@ -55,16 +45,7 @@

no_uuids = [
pytest.mark.notimpl(
[
"druid",
"exasol",
"oracle",
"polars",
"pyspark",
"risingwave",
"pandas",
"dask",
],
["druid", "exasol", "oracle", "polars", "pyspark", "risingwave"],
raises=com.OperationNotDefinedError,
),
pytest.mark.notyet("mssql", reason="Unrelated bug: Incorrect syntax near '('"),
@@ -82,11 +63,7 @@ def my_random(x: float) -> float:
mark_impures = pytest.mark.parametrize(
"impure",
[
pytest.param(
lambda _: ibis.random(),
marks=no_randoms,
id="random",
),
pytest.param(lambda _: ibis.random(), marks=no_randoms, id="random"),
pytest.param(
lambda _: ibis.uuid().cast(str).contains("a").ifelse(1, 0),
marks=[
@@ -107,6 +84,7 @@ def my_random(x: float) -> float:
)


# You can work around this by .cache()ing the table.
@pytest.mark.notyet("sqlite", reason="instances are uncorrelated")
@mark_impures
def test_impure_correlated(alltypes, impure):
@@ -120,14 +98,12 @@ def test_impure_correlated(alltypes, impure):
# t AS (SELECT random() AS common)
# SELECT common as x, common as y FROM t
# Then both x and y should have the same value.
df = (
alltypes.select(common=impure(alltypes))
.select(x=_.common, y=_.common)
.execute()
)
expr = alltypes.select(common=impure(alltypes)).select(x=_.common, y=_.common)
df = expr.execute()
tm.assert_series_equal(df.x, df.y, check_names=False)


# You can work around this by .cache()ing the table.
@pytest.mark.notyet("sqlite", reason="instances are uncorrelated")
@mark_impures
def test_chained_selections(alltypes, impure):
@@ -153,9 +129,7 @@ def test_chained_selections(alltypes, impure):
lambda _: ibis.random(),
marks=[
*no_randoms,
pytest.mark.notyet(
["impala", "trino"], reason="instances are correlated"
),
pytest.mark.notyet(["impala"], reason="instances are correlated"),
],
id="random",
),
@@ -164,24 +138,24 @@ def test_chained_selections(alltypes, impure):
lambda _: ibis.uuid().cast(str).contains("a").ifelse(1, 0),
marks=[
*no_uuids,
pytest.mark.notyet(
["mysql", "trino"], reason="instances are correlated"
),
pytest.mark.notyet(["mysql"], reason="instances are correlated"),
],
id="uuid",
),
pytest.param(
lambda table: my_random(table.float_col),
marks=[
*no_udfs,
pytest.mark.notyet("duckdb", reason="instances are correlated"),
# no "impure" argument for pyspark yet
pytest.mark.notimpl("pyspark"),
],
id="udf",
),
],
)


# You can work around this by doing .select().cache().select()
@pytest.mark.notyet(["clickhouse"], reason="instances are correlated")
@impure_params_uncorrelated
def test_impure_uncorrelated_different_id(alltypes, impure):
@@ -191,15 +165,57 @@ def test_impure_uncorrelated_different_id(alltypes, impure):
# eg if you look at the following SQL:
# select random() as x, random() as y
# Then x and y should be uncorrelated.
df = alltypes.select(x=impure(alltypes), y=impure(alltypes)).execute()
expr = alltypes.select(x=impure(alltypes), y=impure(alltypes))
df = expr.execute()
assert (df.x != df.y).any()


# You can work around this by doing .select().cache().select()
@pytest.mark.notyet(["clickhouse"], reason="instances are correlated")
@impure_params_uncorrelated
def test_impure_uncorrelated_same_id(alltypes, impure):
# Similar to test_impure_uncorrelated_different_id, but the two expressions
# have the same ID. Still, they should be uncorrelated.
common = impure(alltypes)
df = alltypes.select(x=common, y=common).execute()
expr = alltypes.select(x=common, y=common)
df = expr.execute()
assert (df.x != df.y).any()


@pytest.mark.notyet(
[
"duckdb",
"clickhouse",
"datafusion",
"mysql",
"impala",
"mssql",
"trino",
"flink",
"bigquery",
],
raises=AssertionError,
reason="instances are not correlated but ideally they would be",
)
@pytest.mark.notyet(
["sqlite"],
raises=AssertionError,
reason="instances are *sometimes* correlated but ideally they would always be",
strict=False,
)
@pytest.mark.notimpl(
["polars", "risingwave", "druid", "exasol", "oracle", "pyspark"],
raises=com.OperationNotDefinedError,
)
def test_self_join_with_generated_keys(con):
# Even with CTEs in the generated SQL, the backends still
# materialize a new value every time it is referenced.
# This isn't ideal behavior, but there is nothing we can do about it
# on the ibis side. The best you can do is to .cache() the table
# right after you assign the uuid().
# https://github.com/ibis-project/ibis/pull/9014#issuecomment-2399449665
left = ibis.memtable({"idx": list(range(5))}).mutate(key=ibis.uuid())
right = left.filter(left.idx < 3)
expr = left.join(right, "key")
result = con.execute(expr.count())
assert result == 3
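
As a concrete illustration of the .cache() workaround mentioned in the comments above, a minimal sketch (assuming a locally available DuckDB backend; this is illustrative and not part of the diff):

import ibis

ibis.set_backend("duckdb")

# Caching immediately after assigning the uuid() materializes the key column once,
# so the self-join sees stable key values and each row matches itself exactly once.
left = ibis.memtable({"idx": list(range(5))}).mutate(key=ibis.uuid()).cache()
right = left.filter(left.idx < 3)
expr = left.join(right, "key")
print(expr.count().execute())  # expected: 3
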
ibis/expr/decompile.py (6 changes: 5 additions & 1 deletion)
@@ -38,7 +38,6 @@
ops.StringContains: "contains",
ops.StringSQLILike: "ilike",
ops.StringSQLLike: "like",
ops.TimestampNow: "now",
}


@@ -84,6 +83,11 @@ def translate(op, *args, **kwargs):
raise NotImplementedError(op)


@translate.register(ops.TimestampNow)
def now(_):
return "ibis.now()"


@translate.register(ops.Value)
def value(op, *args, **kwargs):
method = _get_method_name(op)
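
A quick sanity check of the new registration, assuming the public ibis.decompile entry point (illustrative; the exact generated text may differ):

import ibis

expr = ibis.now()
print(ibis.decompile(expr))  # expected to emit a call to ibis.now()
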
ibis/expr/operations/generic.py (2 changes: 2 additions & 0 deletions)
@@ -193,13 +193,15 @@ class TimestampNow(Impure):
"""Return the current timestamp."""

dtype = dt.timestamp
shape = ds.scalar


@public
class DateNow(Impure):
"""Return the current date."""

dtype = dt.date
shape = ds.scalar


@public
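
A small sketch of how the explicit shape surfaces on the expression side; the datashape import path and the attribute access on the op are assumptions about internals and may change between versions:

import ibis
import ibis.expr.datashape as ds

now_op = ibis.now().op()      # ops.TimestampNow
today_op = ibis.today().op()  # ops.DateNow

print(now_op.dtype)               # timestamp
print(now_op.shape == ds.scalar)  # True, per the change above
print(today_op.dtype)             # date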
