From 036705c58f160b2a6ad3c1a7497de5f5e8eb0951 Mon Sep 17 00:00:00 2001
From: Raza Jafri
Date: Fri, 3 Jun 2022 11:34:06 -0700
Subject: [PATCH 1/4] Write Scala test for ORC encrypted write

Signed-off-by: Raza Jafri
---
 .../src/main/python/orc_write_test.py         | 27 -------
 tests/pom.xml                                 | 76 +++++++++++++++++++
 .../spark/rapids/OrcEncryptionSuite.scala     | 48 ++++++++++++
 3 files changed, 124 insertions(+), 27 deletions(-)
 create mode 100644 tests/src/test/320+/scala/com/nvidia/spark/rapids/OrcEncryptionSuite.scala

diff --git a/integration_tests/src/main/python/orc_write_test.py b/integration_tests/src/main/python/orc_write_test.py
index 2d58d198903..7afe00076a3 100644
--- a/integration_tests/src/main/python/orc_write_test.py
+++ b/integration_tests/src/main/python/orc_write_test.py
@@ -167,30 +167,3 @@ def create_empty_df(spark, path):
             lambda spark, path: spark.read.orc(path),
             data_path,
             conf={'spark.rapids.sql.format.orc.write.enabled': True})
-
-@allow_non_gpu('DataWritingCommandExec')
-@pytest.mark.parametrize("path", ["", "kms://http@localhost:9600/kms"])
-@pytest.mark.parametrize("provider", ["", "hadoop"])
-@pytest.mark.parametrize("encrypt", ["", "pii:a"])
-@pytest.mark.parametrize("mask", ["", "sha256:a"])
-@pytest.mark.skipif(is_databricks104_or_later(), reason="The test will fail on Databricks10.4 because `HadoopShimsPre2_3$NullKeyProvider` is loaded")
-def test_orc_write_encryption_fallback(spark_tmp_path, spark_tmp_table_factory, path, provider, encrypt, mask):
-    def write_func(spark, write_path):
-        writer = unary_op_df(spark, gen).coalesce(1).write
-        if path != "":
-            writer.option("hadoop.security.key.provider.path", path)
-        if provider != "":
-            writer.option("orc.key.provider", provider)
-        if encrypt != "":
-            writer.option("orc.encrypt", encrypt)
-        if mask != "":
-            writer.option("orc.mask", mask)
-        writer.format("orc").mode('overwrite').option("path", write_path).saveAsTable(spark_tmp_table_factory.get())
-    if path == "" and provider == "" and encrypt == "" and mask == "":
-        pytest.skip("Skip this test when none of the encryption confs are set")
-    gen = IntegerGen()
-    data_path = spark_tmp_path + '/ORC_DATA'
-    assert_gpu_fallback_write(write_func,
-                              lambda spark, path: spark.read.orc(path),
-                              data_path,
-                              'DataWritingCommandExec')
diff --git a/tests/pom.xml b/tests/pom.xml
index 4026d6865d7..e19039dc61e 100644
--- a/tests/pom.xml
+++ b/tests/pom.xml
@@ -148,6 +148,25 @@
+
+      <plugins>
+        <plugin>
+          <groupId>org.codehaus.mojo</groupId>
+          <artifactId>build-helper-maven-plugin</artifactId>
+          <executions>
+            <execution>
+              <id>add-321cdh-test-src</id>
+              <goals><goal>add-test-source</goal></goals>
+              <configuration>
+                <sources>
+                  <source>${project.basedir}/src/test/320+/scala</source>
+                </sources>
+              </configuration>
+            </execution>
+          </executions>
+        </plugin>
+      </plugins>
+
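
Note: the OrcEncryptionSuite.scala hunk announced in the diffstat is not included in this excerpt.
As a rough sketch only, a standalone encrypted ORC write in Scala could look like the code below. It
reuses the option names and values exercised by the removed Python fallback test
(hadoop.security.key.provider.path, orc.key.provider, orc.encrypt, orc.mask) and assumes a Hadoop KMS
is reachable at kms://http@localhost:9600/kms; the object and paths are illustrative, not the contents
of the actual test file.

    import org.apache.spark.sql.SparkSession

    object OrcEncryptedWriteSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[1]")
          .appName("orc-encrypted-write-sketch")
          .getOrCreate()
        import spark.implicits._

        // Single integer column, mirroring unary_op_df(spark, IntegerGen()) in the removed test.
        val df = (1 to 100).toDF("a")

        df.coalesce(1).write
          .option("hadoop.security.key.provider.path", "kms://http@localhost:9600/kms")
          .option("orc.key.provider", "hadoop")
          .option("orc.encrypt", "pii:a")    // encrypt column "a" with the key named "pii"
          .option("orc.mask", "sha256:a")    // readers without the key see a SHA-256 mask of "a"
          .mode("overwrite")
          .orc("/tmp/ORC_DATA")              // hypothetical output path

        // Reading the data back in plaintext requires access to the same key provider.
        spark.read.orc("/tmp/ORC_DATA").show()
        spark.stop()
      }
    }

For context, the removed Python test asserted a CPU fallback to DataWritingCommandExec whenever any
of these encryption options was set.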