Fix ANSI mode test failures in url_test.py [databricks] #11194

Merged: 1 commit, Jul 18, 2024
integration_tests/src/main/python/data_gen.py (1 change: 1 addition & 0 deletions)

@@ -1164,6 +1164,7 @@ def gen_scalars_for_sql(data_gen, count, seed=None, force_no_nulls=False):
 nested_gens_sample = array_gens_sample + struct_gens_sample_with_decimal128 + map_gens_sample + decimal_128_map_gens
 
 ansi_enabled_conf = {'spark.sql.ansi.enabled': 'true'}
+ansi_disabled_conf = {'spark.sql.ansi.enabled': 'false'}
 legacy_interval_enabled_conf = {'spark.sql.legacy.interval.enabled': 'true'}
 
 def copy_and_update(conf, *more_confs):
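
For context on how these conf dicts are consumed: tests combine them via copy_and_update() before handing the result to the assertion helpers. Below is a minimal sketch of that pattern, assuming (as the signature above suggests) that copy_and_update() copies the first conf and shallow-merges the rest on top, with later confs winning; the implementation shown is hypothetical, not copied from the PR.

# Hypothetical usage sketch, not part of this PR. Assumes copy_and_update()
# shallow-merges dicts left to right, later confs overriding earlier keys.
ansi_disabled_conf = {'spark.sql.ansi.enabled': 'false'}
legacy_interval_enabled_conf = {'spark.sql.legacy.interval.enabled': 'true'}

def copy_and_update(conf, *more_confs):
    merged = conf.copy()        # copy so the caller's dict is never mutated
    for more in more_confs:
        merged.update(more)     # later confs take precedence on key clashes
    return merged

combined = copy_and_update(ansi_disabled_conf, legacy_interval_enabled_conf)
assert combined == {'spark.sql.ansi.enabled': 'false',
                    'spark.sql.legacy.interval.enabled': 'true'}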
integration_tests/src/main/python/url_test.py (36 changes: 30 additions & 6 deletions)

@@ -155,7 +155,9 @@
 @pytest.mark.parametrize('part', supported_parts, ids=idfn)
 def test_parse_url_supported(data_gen, part):
     assert_gpu_and_cpu_are_equal_collect(
-        lambda spark: unary_op_df(spark, data_gen).selectExpr("a", "parse_url(a, '" + part + "')"))
+        lambda spark: unary_op_df(spark, data_gen).selectExpr("a", "parse_url(a, '" + part + "')"),
+        ansi_disabled_conf  # ANSI mode failures are tested in test_parse_url_query_ansi_mode.
+    )

 @allow_non_gpu('ProjectExec', 'ParseUrl')
 @pytest.mark.parametrize('part', unsupported_parts, ids=idfn)
@@ -168,16 +170,36 @@ def test_parse_url_query_with_key():
     url_gen = StringGen(url_pattern_with_key)
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark: unary_op_df(spark, url_gen)
-            .selectExpr("a", "parse_url(a, 'QUERY', 'abc')", "parse_url(a, 'QUERY', 'a')")
-    )
+            .selectExpr("a", "parse_url(a, 'QUERY', 'abc')", "parse_url(a, 'QUERY', 'a')"),
+        ansi_disabled_conf  # ANSI mode failures are tested in test_parse_url_query_ansi_mode.
+    )
+
+
+@pytest.mark.xfail(reason="https://github.com/NVIDIA/spark-rapids/issues/11193")
+def test_parse_url_query_ansi_mode():
+    """
+    Tests parse_url()'s behaviour when ANSI mode is enabled.
+    Specifically, the query is expected to fail with an error if parse_url()
+    fails in ANSI mode.
+    This test currently xfails because "fail on error" is not yet supported
+    for parse_url().
+    """
+    url_gen = StringGen(url_pattern_with_key)
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark: unary_op_df(spark, url_gen)
+            .selectExpr("a", "parse_url(a, 'QUERY', 'abc')", "parse_url(a, 'QUERY', 'a')"),
+        conf=ansi_enabled_conf
+    )
 
 
 def test_parse_url_query_with_key_column():
     url_gen = StringGen(url_pattern_with_key)
     key_gen = StringGen('[a-d]{1,3}')
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark: two_col_df(spark, url_gen, key_gen)
-            .selectExpr("a", "parse_url(a, 'QUERY', b)")
-    )
+            .selectExpr("a", "parse_url(a, 'QUERY', b)"),
+        ansi_disabled_conf  # ANSI mode failures are tested in test_parse_url_query_ansi_mode.
+    )

 @pytest.mark.parametrize('key', ['a?c', '*'], ids=idfn)
 @allow_non_gpu('ProjectExec', 'ParseUrl')
@@ -191,7 +213,9 @@ def test_parse_url_query_with_key_regex_fallback(key):
 @pytest.mark.parametrize('part', supported_parts, ids=idfn)
 def test_parse_url_with_key(part):
     assert_gpu_and_cpu_are_equal_collect(
-        lambda spark: unary_op_df(spark, url_gen).selectExpr("parse_url(a, '" + part + "', 'key')"))
+        lambda spark: unary_op_df(spark, url_gen).selectExpr("parse_url(a, '" + part + "', 'key')"),
+        ansi_disabled_conf  # ANSI mode failures are tested in test_parse_url_query_ansi_mode.
+    )

 @allow_non_gpu('ProjectExec', 'ParseUrl')
 @pytest.mark.parametrize('part', unsupported_parts, ids=idfn)
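
For background on what these confs toggle: in Apache Spark, parse_url() honours ANSI mode's fail-on-error behaviour, so an input that quietly yields NULL with ANSI disabled is instead expected to raise an error with ANSI enabled. That is the behaviour test_parse_url_query_ansi_mode pins down, with issue #11193 tracking GPU support. A minimal plain-PySpark sketch of the difference follows; it is illustrative only, and the malformed URL and exact exception class are assumptions that vary by Spark version.

# Illustrative sketch, not part of this PR.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# An embedded space makes the URL unparseable, enough to trip fail-on-error.
df = spark.createDataFrame([("http://exa mple.com/?key=1",)], ["a"])

spark.conf.set("spark.sql.ansi.enabled", "false")
df.selectExpr("parse_url(a, 'QUERY', 'key')").show()       # NULL, no error

spark.conf.set("spark.sql.ansi.enabled", "true")
try:
    df.selectExpr("parse_url(a, 'QUERY', 'key')").show()
except Exception as e:                                      # e.g. an INVALID_URL error
    print("parse_url failed under ANSI mode:", type(e).__name__)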