diff --git a/integration_tests/src/main/python/json_tuple_test.py b/integration_tests/src/main/python/json_tuple_test.py
index 7de605dda8a..22705585445 100644
--- a/integration_tests/src/main/python/json_tuple_test.py
+++ b/integration_tests/src/main/python/json_tuple_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023-2024, NVIDIA CORPORATION.
+# Copyright (c) 2023-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -26,6 +26,20 @@ def mk_json_str_gen(pattern):
     r'"email":"[a-z]{1,5}\@[a-z]{3,10}\.com","owner":"[a-z]{3,8}"\}',
     r'\{"a": "[a-z]{1,3}", "b\$":"[b-z]{1,3}"\}']
 
+json_int_dict_patterns = [r'\{"1": [1-9]{1,6}, "2": -[1-9]{1,6}, ' \
+    r'"3": \{ "[1-9]{1,6}": [1-9]{1,6}, "-[1,10]": [1-9]{1,6}, '\
+    r'"-45": -[1-9]{1,6}\}\}']
+
+json_whitespace_patterns = [r'\{"\\r\\n":"value\\n!", ' \
+    r'"cheddar\rcheese":"\\n[a-z]{0,10}\\r!", ' \
+    r'"fried\\nchicken":"[a-z]{0,2}\\n[a-z]{0,10}\\r[a-z]{0,2}!",' \
+    r'"fish":"salmon\\r\\ncod"\}']
+
+json_eol_garbage_patterns = [r'\{"store":"Albertsons"\}this should not break',
+    r'\{"1":2\} freedom',
+    r'\{"email":gmail@outlook.com, "2":-5\}gmail better']
+
+
 @pytest.mark.parametrize('json_str_pattern', json_str_patterns, ids=idfn)
 def test_json_tuple(json_str_pattern):
     gen = mk_json_str_gen(json_str_pattern)
@@ -35,6 +49,33 @@ def test_json_tuple(json_str_pattern):
         conf={'spark.sql.parser.escapedStringLiterals': 'true',
             'spark.rapids.sql.expression.JsonTuple': 'true'})
 
+@pytest.mark.parametrize('json_int_dict_pattern', json_int_dict_patterns, ids=idfn)
+def test_int_dict_json(json_int_dict_pattern):
+    gen = mk_json_str_gen(json_int_dict_pattern)
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark: unary_op_df(spark, gen, length=10).selectExpr(
+            'json_tuple(a, "1", "2", "-45", "3.-45")'),
+        conf={'spark.sql.parser.escapedStringLiterals': 'true',
+            'spark.rapids.sql.expression.JsonTuple': 'true'})
+
+@pytest.mark.parametrize('json_whitespace_pattern', json_whitespace_patterns, ids=idfn)
+def test_whitespace_json(json_whitespace_pattern):
+    gen = mk_json_str_gen(json_whitespace_pattern)
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark: unary_op_df(spark, gen, length=10).selectExpr(
+            'json_tuple(a, "\\r\\n", "fish", "fried\\nchicken", "cheddar\\rcheese")'),
+        conf={'spark.sql.parser.escapedStringLiterals': 'true',
+            'spark.rapids.sql.expression.JsonTuple': 'true'})
+
+@pytest.mark.parametrize('json_eol_garbage_pattern', json_eol_garbage_patterns, ids=idfn)
+def test_json_eol_garbage_json(json_eol_garbage_pattern):
+    gen = mk_json_str_gen(json_eol_garbage_pattern)
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark: unary_op_df(spark, gen, length=10).selectExpr(
+            'json_tuple(a, "store", "email", "1", "cheddar\\rcheese")'),
+        conf={'spark.sql.parser.escapedStringLiterals': 'true',
+            'spark.rapids.sql.expression.JsonTuple': 'true'})
+
 def test_json_tuple_select_non_generator_col():
     gen = StringGen(pattern="{\"Zipcode\":\"abc\",\"ZipCodeType\":\"STANDARD\",\"City\":\"PARC PARQUE\",\"State\":\"PR\"}")
     assert_gpu_and_cpu_are_equal_sql(