From 7cf08b96f6a467c7c78a5827bb477063d35e2512 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Tue, 26 Apr 2022 10:29:18 -0500 Subject: [PATCH 01/36] GPU accelerated reads for Apache Iceberg Signed-off-by: Jason Lowe --- NOTICE | 30 + pom.xml | 30 + sql-plugin/pom.xml | 12 + .../shims/ShimSupportsRuntimeFiltering.java | 30 + .../shims/ShimSupportsRuntimeFiltering.java | 26 + .../spark/rapids/CloseableIterator.java | 23 + .../rapids/iceberg/data/GpuDeleteFilter.java | 129 +++ .../spark/rapids/iceberg/orc/GpuORC.java | 119 +++ .../iceberg/parquet/ApplyNameMapping.java | 165 ++++ .../rapids/iceberg/parquet/GpuParquet.java | 166 ++++ .../iceberg/parquet/GpuParquetReader.java | 319 +++++++ .../iceberg/parquet/ParquetConversions.java | 116 +++ .../ParquetDictionaryRowGroupFilter.java | 454 ++++++++++ .../rapids/iceberg/parquet/ParquetIO.java | 95 ++ .../parquet/ParquetMetricsRowGroupFilter.java | 565 ++++++++++++ .../iceberg/parquet/ParquetSchemaUtil.java | 211 +++++ .../iceberg/parquet/ParquetTypeVisitor.java | 263 ++++++ .../rapids/iceberg/parquet/ParquetUtil.java | 331 +++++++ .../rapids/iceberg/parquet/PruneColumns.java | 172 ++++ .../parquet/TypeWithSchemaVisitor.java | 216 +++++ .../rapids/iceberg/spark/Spark3Util.java | 821 ++++++++++++++++++ .../rapids/iceberg/spark/SparkConfParser.java | 202 +++++ .../rapids/iceberg/spark/SparkFilters.java | 270 ++++++ .../rapids/iceberg/spark/SparkReadConf.java | 234 +++++ .../iceberg/spark/SparkReadOptions.java | 73 ++ .../iceberg/spark/SparkSQLProperties.java | 43 + .../rapids/iceberg/spark/SparkSchemaUtil.java | 365 ++++++++ .../rapids/iceberg/spark/SparkTypeToType.java | 162 ++++ .../iceberg/spark/SparkTypeVisitor.java | 82 ++ .../spark/rapids/iceberg/spark/SparkUtil.java | 221 +++++ .../rapids/iceberg/spark/TypeToSparkType.java | 124 +++ .../iceberg/spark/source/BaseDataReader.java | 202 +++++ .../spark/source/GpuBatchDataReader.java | 151 ++++ .../spark/source/GpuIcebergReader.java | 170 ++++ .../spark/source/GpuSparkBatchQueryScan.java | 289 ++++++ .../iceberg/spark/source/GpuSparkScan.java | 289 ++++++ .../iceberg/spark/source/SparkBatch.java | 124 +++ .../rapids/iceberg/spark/source/Stats.java | 41 + .../com/nvidia/spark/rapids/GpuCSVScan.scala | 3 + .../nvidia/spark/rapids/GpuParquetScan.scala | 41 +- .../nvidia/spark/rapids/GpuParquetUtils.scala | 125 +++ .../com/nvidia/spark/rapids/RapidsConf.scala | 2 +- .../spark/sql/rapids/ExternalSource.scala | 71 +- .../sql/rapids/execution/TrampolineUtil.scala | 4 +- 44 files changed, 7519 insertions(+), 62 deletions(-) create mode 100644 sql-plugin/src/main/311until320-all/java/com/nvidia/spark/rapids/shims/ShimSupportsRuntimeFiltering.java create mode 100644 sql-plugin/src/main/320+/java/com/nvidia/spark/rapids/shims/ShimSupportsRuntimeFiltering.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/CloseableIterator.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/data/GpuDeleteFilter.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/orc/GpuORC.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ApplyNameMapping.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquet.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetConversions.java create mode 100644 
sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetDictionaryRowGroupFilter.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetIO.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetMetricsRowGroupFilter.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetSchemaUtil.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetTypeVisitor.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetUtil.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/PruneColumns.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/TypeWithSchemaVisitor.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/Spark3Util.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkConfParser.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkFilters.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadOptions.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSQLProperties.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSchemaUtil.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkTypeToType.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkTypeVisitor.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkUtil.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/TypeToSparkType.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/BaseDataReader.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuBatchDataReader.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuIcebergReader.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkBatchQueryScan.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/Stats.java create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetUtils.scala diff --git a/NOTICE b/NOTICE index 5086dc62912..16b2b943861 100644 --- a/NOTICE +++ b/NOTICE @@ -1,6 +1,8 @@ RAPIDS plugin for Apache Spark Copyright (c) 2019-2022, NVIDIA CORPORATION +-------------------------------------------------------------------------------- + // ------------------------------------------------------------------ // NOTICE file corresponding to the section 4d of The Apache License, // Version 2.0, in this case for @@ -12,6 +14,34 @@ Copyright 2014 and onwards The Apache Software Foundation This product includes software developed at The Apache Software Foundation (http://www.apache.org/). 
+-------------------------------------------------------------------------------- + +Apache Iceberg +Copyright 2017-2022 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +-------------------------------------------------------------------------------- + +This project includes code from Kite, developed at Cloudera, Inc. with +the following copyright notice: + +| Copyright 2013 Cloudera Inc. +| +| Licensed under the Apache License, Version 2.0 (the "License"); +| you may not use this file except in compliance with the License. +| You may obtain a copy of the License at +| +| http://www.apache.org/licenses/LICENSE-2.0 +| +| Unless required by applicable law or agreed to in writing, software +| distributed under the License is distributed on an "AS IS" BASIS, +| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +| See the License for the specific language governing permissions and +| limitations under the License. + +-------------------------------------------------------------------------------- This product bundles various third-party components under other open source licenses. diff --git a/pom.xml b/pom.xml index ca853f8051e..a4047380618 100644 --- a/pom.xml +++ b/pom.xml @@ -115,6 +115,7 @@ ${project.basedir}/src/main/311-nondb/scala ${project.basedir}/src/main/311+-nondb/scala + ${project.basedir}/src/main/311until320-all/java ${project.basedir}/src/main/311until320-all/scala ${project.basedir}/src/main/311until320-noncdh/scala ${project.basedir}/src/main/311until320-nondb/scala @@ -180,6 +181,7 @@ ${project.basedir}/src/main/312db/scala ${project.basedir}/src/main/311+-db/scala + ${project.basedir}/src/main/311until320-all/java ${project.basedir}/src/main/311until320-all/scala ${project.basedir}/src/main/311until320-noncdh/scala ${project.basedir}/src/main/311until330-all/scala @@ -230,6 +232,7 @@ ${project.basedir}/src/main/311+-nondb/scala ${project.basedir}/src/main/312-nondb/scala ${project.basedir}/src/main/311+-nondb/scala + ${project.basedir}/src/main/311until320-all/java ${project.basedir}/src/main/311until320-all/scala ${project.basedir}/src/main/311until320-noncdh/scala ${project.basedir}/src/main/311until320-nondb/scala @@ -282,6 +285,7 @@ ${project.basedir}/src/main/313/scala ${project.basedir}/src/main/311+-nondb/scala + ${project.basedir}/src/main/311until320-all/java ${project.basedir}/src/main/311until320-all/scala ${project.basedir}/src/main/311until320-noncdh/scala ${project.basedir}/src/main/311until320-nondb/scala @@ -334,6 +338,7 @@ ${project.basedir}/src/main/314/scala ${project.basedir}/src/main/311+-nondb/scala + ${project.basedir}/src/main/311until320-all/java ${project.basedir}/src/main/311until320-all/scala ${project.basedir}/src/main/311until320-noncdh/scala ${project.basedir}/src/main/311until320-nondb/scala @@ -399,6 +404,7 @@ ${project.basedir}/src/main/311until330-all/scala ${project.basedir}/src/main/311until330-nondb/scala ${project.basedir}/src/main/320/scala + ${project.basedir}/src/main/320+/java ${project.basedir}/src/main/320+/scala ${project.basedir}/src/main/320+-nondb/scala ${project.basedir}/src/main/320+-noncdh/scala @@ -463,6 +469,7 @@ ${project.basedir}/src/main/311+-nondb/scala ${project.basedir}/src/main/311until330-all/scala ${project.basedir}/src/main/311until330-nondb/scala + ${project.basedir}/src/main/320+/java ${project.basedir}/src/main/320+/scala ${project.basedir}/src/main/320+-nondb/scala 
${project.basedir}/src/main/320+-noncdh/scala @@ -529,6 +536,7 @@ ${project.basedir}/src/main/311+-nondb/scala ${project.basedir}/src/main/311until330-all/scala ${project.basedir}/src/main/311until330-nondb/scala + ${project.basedir}/src/main/320+/java ${project.basedir}/src/main/320+/scala ${project.basedir}/src/main/320+-nondb/scala ${project.basedir}/src/main/320until330-all/scala @@ -599,6 +607,7 @@ ${project.basedir}/src/main/311+-nondb/scala ${project.basedir}/src/main/311until330-all/scala ${project.basedir}/src/main/311until330-nondb/scala + ${project.basedir}/src/main/320+/java ${project.basedir}/src/main/320+/scala ${project.basedir}/src/main/320+-nondb/scala ${project.basedir}/src/main/320+-noncdh/scala @@ -677,6 +686,7 @@ ${project.basedir}/src/main/321db/scala ${project.basedir}/src/main/311until330-all/scala ${project.basedir}/src/main/311+-db/scala + ${project.basedir}/src/main/320+/java ${project.basedir}/src/main/320+/scala ${project.basedir}/src/main/320+-noncdh/scala ${project.basedir}/src/main/320until330-all/scala @@ -728,6 +738,7 @@ ${project.basedir}/src/main/330/scala ${project.basedir}/src/main/311+-nondb/scala + ${project.basedir}/src/main/320+/java ${project.basedir}/src/main/320+/scala ${project.basedir}/src/main/320+-nondb/scala ${project.basedir}/src/main/320+-noncdh/scala @@ -859,6 +870,7 @@ 1.7.30 1.11.0 3.3.1 + 0.13.1 org/scala-lang/scala-library/${scala.version}/scala-library-${scala.version}.jar ${spark.version.classifier} 3.1.0 @@ -890,6 +902,24 @@ ${scala.version} provided + + org.apache.iceberg + iceberg-api + ${iceberg.version} + provided + + + org.apache.iceberg + iceberg-bundled-guava + ${iceberg.version} + provided + + + org.apache.iceberg + iceberg-core + ${iceberg.version} + provided + org.apache.spark spark-annotation_${scala.binary.version} diff --git a/sql-plugin/pom.xml b/sql-plugin/pom.xml index 212410795cf..ddbc84e8682 100644 --- a/sql-plugin/pom.xml +++ b/sql-plugin/pom.xml @@ -56,6 +56,18 @@ com.google.flatbuffers flatbuffers-java + + org.apache.iceberg + iceberg-api + + + org.apache.iceberg + iceberg-bundled-guava + + + org.apache.iceberg + iceberg-core + org.apache.spark spark-avro_${scala.binary.version} diff --git a/sql-plugin/src/main/311until320-all/java/com/nvidia/spark/rapids/shims/ShimSupportsRuntimeFiltering.java b/sql-plugin/src/main/311until320-all/java/com/nvidia/spark/rapids/shims/ShimSupportsRuntimeFiltering.java new file mode 100644 index 00000000000..3c378f19d2d --- /dev/null +++ b/sql-plugin/src/main/311until320-all/java/com/nvidia/spark/rapids/shims/ShimSupportsRuntimeFiltering.java @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.shims; + +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.sources.Filter; + +/** + * Shim interface for Apache Spark's SupportsRuntimeFiltering interface + * which was added in Spark 3.2.0. 
+ */ +public interface ShimSupportsRuntimeFiltering { + NamedReference[] filterAttributes(); + + void filter(Filter[] filters); +} diff --git a/sql-plugin/src/main/320+/java/com/nvidia/spark/rapids/shims/ShimSupportsRuntimeFiltering.java b/sql-plugin/src/main/320+/java/com/nvidia/spark/rapids/shims/ShimSupportsRuntimeFiltering.java new file mode 100644 index 00000000000..1da2fd872ac --- /dev/null +++ b/sql-plugin/src/main/320+/java/com/nvidia/spark/rapids/shims/ShimSupportsRuntimeFiltering.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.shims; + +import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering; + +/** + * Shim interface for Apache Spark's SupportsRuntimeFiltering interface + * which was added in Spark 3.2.0. + */ +public interface ShimSupportsRuntimeFiltering extends SupportsRuntimeFiltering { +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/CloseableIterator.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/CloseableIterator.java new file mode 100644 index 00000000000..0864f7887d6 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/CloseableIterator.java @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids; + +import java.io.Closeable; +import java.util.Iterator; + +public interface CloseableIterator extends Iterator, Closeable { +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/data/GpuDeleteFilter.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/data/GpuDeleteFilter.java new file mode 100644 index 00000000000..6573367d4aa --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/data/GpuDeleteFilter.java @@ -0,0 +1,129 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.iceberg.data; + +import java.util.List; +import java.util.Set; + +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; + +/** + * GPU version of Apache Iceberg's DeleteFilter, operating on a ColumnarBatch and performing + * the row filtering on the GPU. + */ +public class GpuDeleteFilter { + private static final Schema POS_DELETE_SCHEMA = new Schema( + MetadataColumns.DELETE_FILE_PATH, + MetadataColumns.DELETE_FILE_POS); + + private final String filePath; + private final List posDeletes; + private final List eqDeletes; + private final Schema requiredSchema; + + public GpuDeleteFilter(String filePath, List deletes, + Schema tableSchema, Schema requestedSchema) { + this.filePath = filePath; + + ImmutableList.Builder posDeleteBuilder = ImmutableList.builder(); + ImmutableList.Builder eqDeleteBuilder = ImmutableList.builder(); + for (DeleteFile delete : deletes) { + switch (delete.content()) { + case POSITION_DELETES: + posDeleteBuilder.add(delete); + break; + case EQUALITY_DELETES: + eqDeleteBuilder.add(delete); + break; + default: + throw new UnsupportedOperationException("Unknown delete file content: " + delete.content()); + } + } + + this.posDeletes = posDeleteBuilder.build(); + this.eqDeletes = eqDeleteBuilder.build(); + this.requiredSchema = fileProjection(tableSchema, requestedSchema, posDeletes, eqDeletes); + } + + public Schema requiredSchema() { + return requiredSchema; + } + + public boolean hasPosDeletes() { + return !posDeletes.isEmpty(); + } + + public boolean hasEqDeletes() { + return !eqDeletes.isEmpty(); + } + + private static Schema fileProjection(Schema tableSchema, Schema requestedSchema, + List posDeletes, List eqDeletes) { + if (posDeletes.isEmpty() && eqDeletes.isEmpty()) { + return requestedSchema; + } + + Set requiredIds = Sets.newLinkedHashSet(); + if (!posDeletes.isEmpty()) { + requiredIds.add(MetadataColumns.ROW_POSITION.fieldId()); + } + + for (DeleteFile eqDelete : eqDeletes) { + requiredIds.addAll(eqDelete.equalityFieldIds()); + } + + requiredIds.add(MetadataColumns.IS_DELETED.fieldId()); + + Set missingIds = Sets.newLinkedHashSet( + Sets.difference(requiredIds, TypeUtil.getProjectedIds(requestedSchema))); + + if (missingIds.isEmpty()) { + return requestedSchema; + } + + // TODO: support adding nested columns. 
this will currently fail when finding nested columns to add + List columns = Lists.newArrayList(requestedSchema.columns()); + for (int fieldId : missingIds) { + if (fieldId == MetadataColumns.ROW_POSITION.fieldId() || fieldId == MetadataColumns.IS_DELETED.fieldId()) { + continue; // add _pos and _deleted at the end + } + + Types.NestedField field = tableSchema.asStruct().field(fieldId); + Preconditions.checkArgument(field != null, "Cannot find required field for ID %s", fieldId); + + columns.add(field); + } + + if (missingIds.contains(MetadataColumns.ROW_POSITION.fieldId())) { + columns.add(MetadataColumns.ROW_POSITION); + } + + if (missingIds.contains(MetadataColumns.IS_DELETED.fieldId())) { + columns.add(MetadataColumns.IS_DELETED); + } + + return new Schema(columns); + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/orc/GpuORC.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/orc/GpuORC.java new file mode 100644 index 00000000000..315e2431feb --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/orc/GpuORC.java @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.iceberg.orc; + +import java.util.Map; + +import org.apache.iceberg.Schema; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.hadoop.HadoopInputFile; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.orc.OrcConf; + +/** GPU version of Apache Iceberg's ORC class */ +public class GpuORC { + private GpuORC() { + } + + public static ReadBuilder read(InputFile file) { + return new ReadBuilder(file); + } + + public static class ReadBuilder { + private final InputFile file; + private final Configuration conf; + private Schema projectSchema = null; + private Schema readerExpectedSchema = null; + private Map idToConstant = null; + private Long start = null; + private Long length = null; + private Expression filter = null; + private boolean caseSensitive = true; + private NameMapping nameMapping = null; + + private ReadBuilder(InputFile file) { + Preconditions.checkNotNull(file, "Input file cannot be null"); + this.file = file; + if (file instanceof HadoopInputFile) { + this.conf = new Configuration(((HadoopInputFile) file).getConf()); + } else { + this.conf = new Configuration(); + } + + // We need to turn positional schema evolution off since we use column name based schema evolution for projection + this.conf.setBoolean(OrcConf.FORCE_POSITIONAL_EVOLUTION.getHiveConfName(), false); + } + + /** + * Restricts the read to the given range: [start, start + length). 
+ * + * @param newStart the start position for this read + * @param newLength the length of the range this read should scan + * @return this builder for method chaining + */ + public ReadBuilder split(long newStart, long newLength) { + this.start = newStart; + this.length = newLength; + return this; + } + + public ReadBuilder project(Schema newSchema) { + this.projectSchema = newSchema; + return this; + } + + public ReadBuilder readerExpectedSchema(Schema newSchema) { + this.readerExpectedSchema = newSchema; + return this; + } + + public ReadBuilder constants(Map constants) { + this.idToConstant = constants; + return this; + } + + public ReadBuilder caseSensitive(boolean newCaseSensitive) { + OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(this.conf, newCaseSensitive); + this.caseSensitive = newCaseSensitive; + return this; + } + + public ReadBuilder config(String property, String value) { + conf.set(property, value); + return this; + } + + public ReadBuilder filter(Expression newFilter) { + this.filter = newFilter; + return this; + } + + public ReadBuilder withNameMapping(NameMapping newNameMapping) { + this.nameMapping = newNameMapping; + return this; + } + + public CloseableIterable build() { + Preconditions.checkNotNull(projectSchema, "Schema is required"); + throw new UnsupportedOperationException(); + } + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ApplyNameMapping.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ApplyNameMapping.java new file mode 100644 index 00000000000..d71118e6708 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ApplyNameMapping.java @@ -0,0 +1,165 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.iceberg.parquet; + +import java.util.Deque; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; +import org.apache.iceberg.mapping.MappedField; +import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.OriginalType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.parquet.schema.Types; + +public class ApplyNameMapping extends ParquetTypeVisitor { + private static final String LIST_ELEMENT_NAME = "element"; + private static final String MAP_KEY_NAME = "key"; + private static final String MAP_VALUE_NAME = "value"; + private final NameMapping nameMapping; + private final Deque fieldNames = Lists.newLinkedList(); + + ApplyNameMapping(NameMapping nameMapping) { + this.nameMapping = nameMapping; + } + + @Override + public Type message(MessageType message, List fields) { + Types.MessageTypeBuilder builder = Types.buildMessage(); + fields.stream().filter(Objects::nonNull).forEach(builder::addField); + + return builder.named(message.getName()); + } + + @Override + public Type struct(GroupType struct, List types) { + MappedField field = nameMapping.find(currentPath()); + List actualTypes = types.stream().filter(Objects::nonNull).collect(Collectors.toList()); + Type structType = struct.withNewFields(actualTypes); + + return field == null ? structType : structType.withId(field.id()); + } + + @Override + public Type list(GroupType list, Type elementType) { + Preconditions.checkArgument(elementType != null, + "List type must have element field"); + + Type listElement = ParquetSchemaUtil.determineListElementType(list); + MappedField field = nameMapping.find(currentPath()); + + Types.GroupBuilder listBuilder = Types.buildGroup(list.getRepetition()) + // Spark 3.1 uses Parquet 1.10 which does not have LogicalTypeAnnotation +// .as(LogicalTypeAnnotation.listType()); + .as(OriginalType.LIST); + if (listElement.isRepetition(Type.Repetition.REPEATED)) { + listBuilder.addFields(elementType); + } else { + listBuilder.repeatedGroup().addFields(elementType).named(list.getFieldName(0)); + } + Type listType = listBuilder.named(list.getName()); + + return field == null ? listType : listType.withId(field.id()); + } + + @Override + public Type map(GroupType map, Type keyType, Type valueType) { + Preconditions.checkArgument(keyType != null && valueType != null, + "Map type must have both key field and value field"); + + MappedField field = nameMapping.find(currentPath()); + Type mapType = Types.buildGroup(map.getRepetition()) + // Spark 3.1 uses Parquet 1.10 which does not have LogicalTypeAnnotation +// .as(LogicalTypeAnnotation.mapType()) + .as(OriginalType.MAP) + .repeatedGroup().addFields(keyType, valueType).named(map.getFieldName(0)) + .named(map.getName()); + + return field == null ? mapType : mapType.withId(field.id()); + } + + @Override + public Type primitive(PrimitiveType primitive) { + MappedField field = nameMapping.find(currentPath()); + return field == null ? 
primitive : primitive.withId(field.id()); + } + + @Override + public void beforeField(Type type) { + fieldNames.push(type.getName()); + } + + @Override + public void afterField(Type type) { + fieldNames.pop(); + } + + @Override + public void beforeElementField(Type element) { + // normalize the name to "element" so that the mapping will match structures with alternative names + fieldNames.push(LIST_ELEMENT_NAME); + } + + @Override + public void beforeKeyField(Type key) { + // normalize the name to "key" so that the mapping will match structures with alternative names + fieldNames.push(MAP_KEY_NAME); + } + + @Override + public void beforeValueField(Type key) { + // normalize the name to "value" so that the mapping will match structures with alternative names + fieldNames.push(MAP_VALUE_NAME); + } + + @Override + public void beforeRepeatedElement(Type element) { + // do not add the repeated element's name + } + + @Override + public void afterRepeatedElement(Type element) { + // do not remove the repeated element's name + } + + @Override + public void beforeRepeatedKeyValue(Type keyValue) { + // do not add the repeated element's name + } + + @Override + public void afterRepeatedKeyValue(Type keyValue) { + // do not remove the repeated element's name + } + + @Override + protected String[] currentPath() { + return Lists.newArrayList(fieldNames.descendingIterator()).toArray(new String[0]); + } + + @Override + protected String[] path(String name) { + List list = Lists.newArrayList(fieldNames.descendingIterator()); + list.add(name); + return list.toArray(new String[0]); + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquet.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquet.java new file mode 100644 index 00000000000..61fe0a9e317 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquet.java @@ -0,0 +1,166 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.iceberg.parquet; + +import java.util.Collection; +import java.util.Map; + +import com.nvidia.spark.rapids.GpuMetric; +import com.nvidia.spark.rapids.iceberg.data.GpuDeleteFilter; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.Schema; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.hadoop.HadoopInputFile; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.parquet.HadoopReadOptions; +import org.apache.parquet.ParquetReadOptions; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.datasources.PartitionedFile; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** GPU version of Apache Iceberg's Parquet class */ +public class GpuParquet { + private static final Collection READ_PROPERTIES_TO_REMOVE = Sets.newHashSet( + "parquet.read.filter", "parquet.private.read.filter.predicate", "parquet.read.support.class"); + + private GpuParquet() { + } + + public static ReadBuilder read(InputFile file) { + return new ReadBuilder(file); + } + + public static class ReadBuilder { + private final InputFile file; + private Long start = null; + private Long length = null; + private Schema projectSchema = null; + private Map idToConstant = null; + private GpuDeleteFilter deleteFilter = null; + private Expression filter = null; + private boolean caseSensitive = true; + private NameMapping nameMapping = null; + private Configuration conf = null; + private int maxBatchSizeRows = 0; + private long maxBatchSizeBytes = Integer.MAX_VALUE; + private String debugDumpPrefix = null; + private scala.collection.immutable.Map metrics = null; + + private ReadBuilder(InputFile file) { + this.file = file; + } + + /** + * Restricts the read to the given range: [start, start + length). 
+ * + * @param newStart the start position for this read + * @param newLength the length of the range this read should scan + * @return this builder for method chaining + */ + public ReadBuilder split(long newStart, long newLength) { + this.start = newStart; + this.length = newLength; + return this; + } + + public ReadBuilder project(Schema newSchema) { + this.projectSchema = newSchema; + return this; + } + + public ReadBuilder caseSensitive(boolean newCaseSensitive) { + this.caseSensitive = newCaseSensitive; + return this; + } + + public ReadBuilder constants(Map constantsMap) { + this.idToConstant = constantsMap; + return this; + } + + public ReadBuilder deleteFilter(GpuDeleteFilter deleteFilter) { + this.deleteFilter = deleteFilter; + return this; + } + + public ReadBuilder filter(Expression newFilter) { + this.filter = newFilter; + return this; + } + + public ReadBuilder withNameMapping(NameMapping newNameMapping) { + this.nameMapping = newNameMapping; + return this; + } + + public ReadBuilder withConfiguration(Configuration conf) { + this.conf = conf; + return this; + } + + public ReadBuilder withMaxBatchSizeRows(int maxBatchSizeRows) { + this.maxBatchSizeRows = maxBatchSizeRows; + return this; + } + + public ReadBuilder withMaxBatchSizeBytes(long maxBatchSizeBytes) { + this.maxBatchSizeBytes = maxBatchSizeBytes; + return this; + } + + public ReadBuilder withDebugDumpPrefix(String dumpPrefix) { + this.debugDumpPrefix = dumpPrefix; + return this; + } + + public ReadBuilder withMetrics(scala.collection.immutable.Map metrics) { + this.metrics = metrics; + return this; + } + + public CloseableIterable build() { + ParquetReadOptions.Builder optionsBuilder; + if (file instanceof HadoopInputFile) { + // remove read properties already set that may conflict with this read + Configuration conf = new Configuration(((HadoopInputFile) file).getConf()); + for (String property : READ_PROPERTIES_TO_REMOVE) { + conf.unset(property); + } + optionsBuilder = HadoopReadOptions.builder(conf); + } else { + //optionsBuilder = ParquetReadOptions.builder(); + throw new UnsupportedOperationException("Only Hadoop files are supported for now"); + } + + if (start != null) { + optionsBuilder.withRange(start, start + length); + } + + ParquetReadOptions options = optionsBuilder.build(); + + PartitionedFile partFile = new PartitionedFile(InternalRow.empty(), file.location(), + start, length, null); + return new GpuParquetReader(file, projectSchema, options, nameMapping, filter, caseSensitive, + idToConstant, deleteFilter, partFile, conf, maxBatchSizeRows, maxBatchSizeBytes, + debugDumpPrefix, metrics); + } + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java new file mode 100644 index 00000000000..f6b789a2de2 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java @@ -0,0 +1,319 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.iceberg.parquet; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import scala.collection.JavaConverters; + +import com.nvidia.spark.rapids.GpuMetric; +import com.nvidia.spark.rapids.ParquetPartitionReader; +import com.nvidia.spark.rapids.iceberg.data.GpuDeleteFilter; +import com.nvidia.spark.rapids.iceberg.spark.SparkSchemaUtil; +import com.nvidia.spark.rapids.iceberg.spark.source.GpuIcebergReader; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.io.CloseableGroup; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.parquet.ParquetReadOptions; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.metadata.BlockMetaData; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.parquet.schema.Types.MessageTypeBuilder; + +import org.apache.spark.sql.execution.datasources.PartitionedFile; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** GPU version of Apache Iceberg's ParquetReader class */ +public class GpuParquetReader extends CloseableGroup implements CloseableIterable { + private final InputFile input; + private final Schema expectedSchema; + private final ParquetReadOptions options; + private final Expression filter; + private final boolean caseSensitive; + private final NameMapping nameMapping; + private final Map idToConstant; + private final GpuDeleteFilter deleteFilter; + private final PartitionedFile partFile; + private final Configuration conf; + private final int maxBatchSizeRows; + private final long maxBatchSizeBytes; + private final String debugDumpPrefix; + private final scala.collection.immutable.Map metrics; + + public GpuParquetReader( + InputFile input, Schema expectedSchema, ParquetReadOptions options, + NameMapping nameMapping, Expression filter, boolean caseSensitive, + Map idToConstant, GpuDeleteFilter deleteFilter, + PartitionedFile partFile, Configuration conf, int maxBatchSizeRows, + long maxBatchSizeBytes, String debugDumpPrefix, + scala.collection.immutable.Map metrics) { + this.input = input; + this.expectedSchema = expectedSchema; + this.options = options; + this.nameMapping = nameMapping; + this.filter = filter; + this.caseSensitive = caseSensitive; + this.idToConstant = idToConstant; + this.deleteFilter = deleteFilter; + this.partFile = partFile; + this.conf = conf; + this.maxBatchSizeRows = maxBatchSizeRows; + this.maxBatchSizeBytes = maxBatchSizeBytes; + this.debugDumpPrefix = debugDumpPrefix; + this.metrics = metrics; + } + + @Override + public org.apache.iceberg.io.CloseableIterator iterator() { + try 
(ParquetFileReader reader = newReader(input, options)) { + MessageType fileSchema = reader.getFileMetaData().getSchema(); + + MessageType typeWithIds; +// MessageType projection; + if (ParquetSchemaUtil.hasIds(fileSchema)) { + typeWithIds = fileSchema; +// projection = ParquetSchemaUtil.pruneColumns(fileSchema, expectedSchema); + } else if (nameMapping != null) { + typeWithIds = ParquetSchemaUtil.applyNameMapping(fileSchema, nameMapping); +// projection = ParquetSchemaUtil.pruneColumns(typeWithIds, expectedSchema); + } else { + typeWithIds = ParquetSchemaUtil.addFallbackIds(fileSchema); +// projection = ParquetSchemaUtil.pruneColumnsFallback(fileSchema, expectedSchema); + } + + List rowGroups = reader.getRowGroups(); + List filteredRowGroups = Lists.newArrayListWithCapacity(rowGroups.size()); + +// boolean[] startRowPositions[i] = new boolean[rowGroups.size()]; +// +// // Fetch all row groups starting positions to compute the row offsets of the filtered row groups +// Map offsetToStartPos = generateOffsetToStartPos(expectedSchema); + if (expectedSchema.findField(MetadataColumns.ROW_POSITION.fieldId()) != null) { + throw new UnsupportedOperationException("row position meta column not implemented"); + } + + ParquetMetricsRowGroupFilter statsFilter = null; + ParquetDictionaryRowGroupFilter dictFilter = null; + if (filter != null) { + statsFilter = new ParquetMetricsRowGroupFilter(expectedSchema, filter, caseSensitive); + dictFilter = new ParquetDictionaryRowGroupFilter(expectedSchema, filter, caseSensitive); + } + + for (BlockMetaData rowGroup : rowGroups) { +// startRowPositions[i] = offsetToStartPos == null ? 0 : offsetToStartPos.get(rowGroup.getStartingPos()); + boolean shouldRead = filter == null || ( + statsFilter.shouldRead(typeWithIds, rowGroup) && + dictFilter.shouldRead(typeWithIds, rowGroup, reader.getDictionaryReader(rowGroup))); + if (shouldRead) { + filteredRowGroups.add(rowGroup); + } + } + + StructType sparkSchema = SparkSchemaUtil.convertWithoutConstants(expectedSchema, idToConstant); + MessageType fileReadSchema = buildFileReadSchema(fileSchema); + + // reuse Parquet scan code to read the raw data from the file + ParquetPartitionReader partReader = new ParquetPartitionReader(conf, partFile, + new Path(input.location()), JavaConverters.collectionAsScalaIterable(filteredRowGroups), + fileReadSchema, caseSensitive, sparkSchema, debugDumpPrefix, + maxBatchSizeRows, maxBatchSizeBytes, metrics, true, true, true); + + return new GpuIcebergReader(expectedSchema, partReader, deleteFilter, idToConstant); + } catch (IOException e) { + throw new UncheckedIOException("Failed to create/close reader for file: " + input, e); + } + } + +// private Map generateOffsetToStartPos(Schema schema) { +// if (schema.findField(MetadataColumns.ROW_POSITION.fieldId()) == null) { +// return null; +// } +// +// try (ParquetFileReader fileReader = newReader(input, ParquetReadOptions.builder().build())) { +// Map offsetToStartPos = Maps.newHashMap(); +// +// long curRowCount = 0; +// for (int i = 0; i < fileReader.getRowGroups().size(); i += 1) { +// BlockMetaData meta = fileReader.getRowGroups().get(i); +// offsetToStartPos.put(meta.getStartingPos(), curRowCount); +// curRowCount += meta.getRowCount(); +// } +// +// return offsetToStartPos; +// +// } catch (IOException e) { +// throw new UncheckedIOException("Failed to create/close reader for file: " + input, e); +// } +// } + + private static ParquetFileReader newReader(InputFile file, ParquetReadOptions options) { + try { + return 
ParquetFileReader.open(ParquetIO.file(file), options);
+    } catch (IOException e) {
+      throw new UncheckedIOException("Failed to open Parquet file: " + file.location(), e);
+    }
+  }
+
+  // Filter out any unreferenced and metadata columns and reorder the columns
+  // to match the expected schema.
+  private MessageType buildFileReadSchema(MessageType fileSchema) {
+    if (ParquetSchemaUtil.hasIds(fileSchema)) {
+      return (MessageType)
+          TypeWithSchemaVisitor.visit(expectedSchema.asStruct(), fileSchema,
+              new ReorderColumns(fileSchema, idToConstant));
+    } else {
+      return (MessageType)
+          TypeWithSchemaVisitor.visit(expectedSchema.asStruct(), fileSchema,
+              new ReorderColumnsFallback(fileSchema, idToConstant));
+    }
+  }
+
+  private static class ReorderColumns extends TypeWithSchemaVisitor<Type> {
+    private final MessageType fileSchema;
+    private final Map<Integer, ?> idToConstant;
+
+    public ReorderColumns(MessageType fileSchema, Map<Integer, ?> idToConstant) {
+      this.fileSchema = fileSchema;
+      this.idToConstant = idToConstant;
+    }
+
+    @Override
+    public Type message(Types.StructType expected, MessageType message, List<Type> fields) {
+      MessageTypeBuilder builder = org.apache.parquet.schema.Types.buildMessage();
+      List<Type> newFields = filterAndReorder(expected, fields);
+      // TODO: Avoid re-creating type if nothing changed
+      for (Type type : newFields) {
+        builder.addField(type);
+      }
+      return builder.named(message.getName());
+    }
+
+    @Override
+    public Type struct(Types.StructType expected, GroupType struct, List<Type> fields) {
+      // TODO: Avoid re-creating type if nothing changed
+      List<Type> newFields = filterAndReorder(expected, fields);
+      return struct.withNewFields(newFields);
+    }
+
+    @Override
+    public Type list(Types.ListType expectedList, GroupType list, Type element) {
+      boolean hasConstant = expectedList.fields().stream()
+          .anyMatch(f -> idToConstant.containsKey(f.fieldId()));
+      if (hasConstant) {
+        throw new UnsupportedOperationException("constant column in list");
+      }
+      Type originalElement = list.getFields().get(0);
+      if (Objects.equals(element, originalElement)) {
+        return list;
+      } else if (originalElement.isRepetition(Type.Repetition.REPEATED)) {
+        return list.withNewFields(element);
+      }
+      return list.withNewFields(list.getType(0).asGroupType().withNewFields(element));
+    }
+
+    @Override
+    public Type map(Types.MapType expectedMap, GroupType map, Type key, Type value) {
+      boolean hasConstant = expectedMap.fields().stream()
+          .anyMatch(f -> idToConstant.containsKey(f.fieldId()));
+      if (hasConstant) {
+        throw new UnsupportedOperationException("constant column in map");
+      }
+      GroupType repeated = map.getFields().get(0).asGroupType();
+      Type originalKey = repeated.getType(0);
+      Type originalValue = repeated.getType(1);
+      if (Objects.equals(key, originalKey) && Objects.equals(value, originalValue)) {
+        return map;
+      }
+      return map.withNewFields(repeated.withNewFields(key, value));
+    }
+
+    @Override
+    public Type primitive(org.apache.iceberg.types.Type.PrimitiveType expected, PrimitiveType primitive) {
+      return primitive;
+    }
+
+    /** Returns true if a column with the specified ID should be ignored when loading the file data */
+    private boolean shouldIgnoreFileColumn(int id) {
+      return idToConstant.containsKey(id) ||
+          id == MetadataColumns.ROW_POSITION.fieldId() ||
+          id == MetadataColumns.IS_DELETED.fieldId();
+    }
+
+    private List<Type> filterAndReorder(Types.StructType expected, List<Type> fields) {
+      // match the expected struct's order
+      Map<Integer, Type> typesById = Maps.newHashMap();
+      for (Type fieldType : fields) {
+        if (fieldType.getId() != null) {
+          int id =
fieldType.getId().intValue(); + typesById.put(id, fieldType); + } + } + + List expectedFields = expected != null ? + expected.fields() : ImmutableList.of(); + List reorderedFields = Lists.newArrayListWithCapacity(expectedFields.size()); + for (Types.NestedField field : expectedFields) { + int id = field.fieldId(); + if (!shouldIgnoreFileColumn(id)) { + Type newField = typesById.get(id); + if (newField != null) { + reorderedFields.add(newField); + } + } + } + + return reorderedFields; + } + } + + private static class ReorderColumnsFallback extends ReorderColumns { + public ReorderColumnsFallback(MessageType fileSchema, Map idToConstant) { + super(fileSchema, idToConstant); + } + + @Override + public Type message(Types.StructType expected, MessageType message, List fields) { + // the top level matches by ID, but the remaining IDs are missing + return super.struct(expected, message, fields); + } + + @Override + public Type struct(Types.StructType ignored, GroupType struct, List fields) { + // the expected struct is ignored because nested fields are never found when the IDs are missing + return struct; + } + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetConversions.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetConversions.java new file mode 100644 index 00000000000..32126f415f8 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetConversions.java @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.iceberg.parquet; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.UUID; +import java.util.function.Function; + +import org.apache.iceberg.expressions.Literal; +import org.apache.iceberg.types.Type; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.PrimitiveType; + +public class ParquetConversions { + private ParquetConversions() { + } + + @SuppressWarnings("unchecked") + static Literal fromParquetPrimitive(Type type, PrimitiveType parquetType, Object value) { + switch (type.typeId()) { + case BOOLEAN: + return (Literal) Literal.of((Boolean) value); + case INTEGER: + case DATE: + return (Literal) Literal.of((Integer) value); + case LONG: + case TIME: + case TIMESTAMP: + return (Literal) Literal.of((Long) value); + case FLOAT: + return (Literal) Literal.of((Float) value); + case DOUBLE: + return (Literal) Literal.of((Double) value); + case STRING: + Function stringConversion = converterFromParquet(parquetType); + return (Literal) Literal.of((CharSequence) stringConversion.apply(value)); + case UUID: + Function uuidConversion = converterFromParquet(parquetType); + return (Literal) Literal.of((UUID) uuidConversion.apply(value)); + case FIXED: + case BINARY: + Function binaryConversion = converterFromParquet(parquetType); + return (Literal) Literal.of((ByteBuffer) binaryConversion.apply(value)); + case DECIMAL: + Function decimalConversion = converterFromParquet(parquetType); + return (Literal) Literal.of((BigDecimal) decimalConversion.apply(value)); + default: + throw new IllegalArgumentException("Unsupported primitive type: " + type); + } + } + + static Function converterFromParquet(PrimitiveType parquetType, Type icebergType) { + Function fromParquet = converterFromParquet(parquetType); + if (icebergType != null) { + if (icebergType.typeId() == Type.TypeID.LONG && + parquetType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT32) { + return value -> ((Integer) fromParquet.apply(value)).longValue(); + } else if (icebergType.typeId() == Type.TypeID.DOUBLE && + parquetType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.FLOAT) { + return value -> ((Float) fromParquet.apply(value)).doubleValue(); + } + } + + return fromParquet; + } + + static Function converterFromParquet(PrimitiveType type) { + if (type.getOriginalType() != null) { + switch (type.getOriginalType()) { + case UTF8: + // decode to CharSequence to avoid copying into a new String + return binary -> StandardCharsets.UTF_8.decode(((Binary) binary).toByteBuffer()); + case DECIMAL: + int scale = type.getDecimalMetadata().getScale(); + switch (type.getPrimitiveTypeName()) { + case INT32: + case INT64: + return num -> BigDecimal.valueOf(((Number) num).longValue(), scale); + case FIXED_LEN_BYTE_ARRAY: + case BINARY: + return bin -> new BigDecimal(new BigInteger(((Binary) bin).getBytes()), scale); + default: + throw new IllegalArgumentException( + "Unsupported primitive type for decimal: " + type.getPrimitiveTypeName()); + } + default: + } + } + + switch (type.getPrimitiveTypeName()) { + case FIXED_LEN_BYTE_ARRAY: + case BINARY: + return binary -> ByteBuffer.wrap(((Binary) binary).getBytes()); + default: + } + + return obj -> obj; + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetDictionaryRowGroupFilter.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetDictionaryRowGroupFilter.java 
new file mode 100644 index 00000000000..3f0940fc0da --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetDictionaryRowGroupFilter.java @@ -0,0 +1,454 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.iceberg.parquet; + +import java.io.IOException; +import java.util.Comparator; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + +import org.apache.iceberg.Schema; +import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.expressions.Binder; +import org.apache.iceberg.expressions.BoundReference; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.ExpressionVisitors; +import org.apache.iceberg.expressions.ExpressionVisitors.BoundExpressionVisitor; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.expressions.Literal; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.types.Comparators; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types.StructType; +import org.apache.iceberg.util.NaNUtil; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.Dictionary; +import org.apache.parquet.column.page.DictionaryPage; +import org.apache.parquet.column.page.DictionaryPageReadStore; +import org.apache.parquet.hadoop.metadata.BlockMetaData; +import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; + +public class ParquetDictionaryRowGroupFilter { + private final Schema schema; + private final Expression expr; + + public ParquetDictionaryRowGroupFilter(Schema schema, Expression unbound) { + this(schema, unbound, true); + } + + public ParquetDictionaryRowGroupFilter(Schema schema, Expression unbound, boolean caseSensitive) { + this.schema = schema; + StructType struct = schema.asStruct(); + this.expr = Binder.bind(struct, Expressions.rewriteNot(unbound), caseSensitive); + } + + /** + * Test whether the dictionaries for a row group may contain records that match the expression. + * + * @param fileSchema schema for the Parquet file + * @param dictionaries a dictionary page read store + * @return false if the file cannot contain rows that match the expression, true otherwise. 
+ */ + public boolean shouldRead(MessageType fileSchema, BlockMetaData rowGroup, + DictionaryPageReadStore dictionaries) { + return new EvalVisitor().eval(fileSchema, rowGroup, dictionaries); + } + + private static final boolean ROWS_MIGHT_MATCH = true; + private static final boolean ROWS_CANNOT_MATCH = false; + + private class EvalVisitor extends BoundExpressionVisitor { + private DictionaryPageReadStore dictionaries = null; + private Map> dictCache = null; + private Map isFallback = null; + private Map mayContainNulls = null; + private Map cols = null; + private Map> conversions = null; + + private boolean eval(MessageType fileSchema, BlockMetaData rowGroup, + DictionaryPageReadStore dictionaryReadStore) { + this.dictionaries = dictionaryReadStore; + this.dictCache = Maps.newHashMap(); + this.isFallback = Maps.newHashMap(); + this.mayContainNulls = Maps.newHashMap(); + this.cols = Maps.newHashMap(); + this.conversions = Maps.newHashMap(); + + for (ColumnDescriptor desc : fileSchema.getColumns()) { + PrimitiveType colType = fileSchema.getType(desc.getPath()).asPrimitiveType(); + if (colType.getId() != null) { + int id = colType.getId().intValue(); + Type icebergType = schema.findType(id); + cols.put(id, desc); + conversions.put(id, ParquetConversions.converterFromParquet(colType, icebergType)); + } + } + + for (ColumnChunkMetaData meta : rowGroup.getColumns()) { + PrimitiveType colType = fileSchema.getType(meta.getPath().toArray()).asPrimitiveType(); + if (colType.getId() != null) { + int id = colType.getId().intValue(); + isFallback.put(id, ParquetUtil.hasNonDictionaryPages(meta)); + mayContainNulls.put(id, mayContainNull(meta)); + } + } + + return ExpressionVisitors.visitEvaluator(expr, this); + } + + @Override + public Boolean alwaysTrue() { + return ROWS_MIGHT_MATCH; // all rows match + } + + @Override + public Boolean alwaysFalse() { + return ROWS_CANNOT_MATCH; // all rows fail + } + + @Override + public Boolean not(Boolean result) { + return !result; + } + + @Override + public Boolean and(Boolean leftResult, Boolean rightResult) { + return leftResult && rightResult; + } + + @Override + public Boolean or(Boolean leftResult, Boolean rightResult) { + return leftResult || rightResult; + } + + @Override + public Boolean isNull(BoundReference ref) { + // dictionaries only contain non-nulls and cannot eliminate based on isNull or NotNull + return ROWS_MIGHT_MATCH; + } + + @Override + public Boolean notNull(BoundReference ref) { + // dictionaries only contain non-nulls and cannot eliminate based on isNull or NotNull + return ROWS_MIGHT_MATCH; + } + + @Override + public Boolean isNaN(BoundReference ref) { + int id = ref.fieldId(); + + Boolean hasNonDictPage = isFallback.get(id); + if (hasNonDictPage == null || hasNonDictPage) { + return ROWS_MIGHT_MATCH; + } + + Set dictionary = dict(id, comparatorForNaNPredicate(ref)); + return dictionary.stream().anyMatch(NaNUtil::isNaN) ? ROWS_MIGHT_MATCH : ROWS_CANNOT_MATCH; + } + + @Override + public Boolean notNaN(BoundReference ref) { + int id = ref.fieldId(); + + Boolean hasNonDictPage = isFallback.get(id); + if (hasNonDictPage == null || hasNonDictPage) { + return ROWS_MIGHT_MATCH; + } + + Set dictionary = dict(id, comparatorForNaNPredicate(ref)); + return dictionary.stream().allMatch(NaNUtil::isNaN) ? 
ROWS_CANNOT_MATCH : ROWS_MIGHT_MATCH; + } + + private Comparator comparatorForNaNPredicate(BoundReference ref) { + // Construct the same comparator as in ComparableLiteral.comparator, ignoring null value order as dictionary + // cannot contain null values. + // No need to check type: incompatible types will be handled during expression binding. + return Comparators.forType(ref.type().asPrimitiveType()); + } + + @Override + public Boolean lt(BoundReference ref, Literal lit) { + int id = ref.fieldId(); + + Boolean hasNonDictPage = isFallback.get(id); + if (hasNonDictPage == null || hasNonDictPage) { + return ROWS_MIGHT_MATCH; + } + + Set dictionary = dict(id, lit.comparator()); + + // if any item in the dictionary matches the predicate, then at least one row does + for (T item : dictionary) { + int cmp = lit.comparator().compare(item, lit.value()); + if (cmp < 0) { + return ROWS_MIGHT_MATCH; + } + } + + return ROWS_CANNOT_MATCH; + } + + @Override + public Boolean ltEq(BoundReference ref, Literal lit) { + int id = ref.fieldId(); + + Boolean hasNonDictPage = isFallback.get(id); + if (hasNonDictPage == null || hasNonDictPage) { + return ROWS_MIGHT_MATCH; + } + + Set dictionary = dict(id, lit.comparator()); + + // if any item in the dictionary matches the predicate, then at least one row does + for (T item : dictionary) { + int cmp = lit.comparator().compare(item, lit.value()); + if (cmp <= 0) { + return ROWS_MIGHT_MATCH; + } + } + + return ROWS_CANNOT_MATCH; + } + + @Override + public Boolean gt(BoundReference ref, Literal lit) { + int id = ref.fieldId(); + + Boolean hasNonDictPage = isFallback.get(id); + if (hasNonDictPage == null || hasNonDictPage) { + return ROWS_MIGHT_MATCH; + } + + Set dictionary = dict(id, lit.comparator()); + + // if any item in the dictionary matches the predicate, then at least one row does + for (T item : dictionary) { + int cmp = lit.comparator().compare(item, lit.value()); + if (cmp > 0) { + return ROWS_MIGHT_MATCH; + } + } + + return ROWS_CANNOT_MATCH; + } + + @Override + public Boolean gtEq(BoundReference ref, Literal lit) { + int id = ref.fieldId(); + + Boolean hasNonDictPage = isFallback.get(id); + if (hasNonDictPage == null || hasNonDictPage) { + return ROWS_MIGHT_MATCH; + } + + Set dictionary = dict(id, lit.comparator()); + + // if any item in the dictionary matches the predicate, then at least one row does + for (T item : dictionary) { + int cmp = lit.comparator().compare(item, lit.value()); + if (cmp >= 0) { + return ROWS_MIGHT_MATCH; + } + } + + return ROWS_CANNOT_MATCH; + } + + @Override + public Boolean eq(BoundReference ref, Literal lit) { + int id = ref.fieldId(); + + Boolean hasNonDictPage = isFallback.get(id); + if (hasNonDictPage == null || hasNonDictPage) { + return ROWS_MIGHT_MATCH; + } + + Set dictionary = dict(id, lit.comparator()); + + return dictionary.contains(lit.value()) ? ROWS_MIGHT_MATCH : ROWS_CANNOT_MATCH; + } + + @Override + public Boolean notEq(BoundReference ref, Literal lit) { + int id = ref.fieldId(); + + Boolean hasNonDictPage = isFallback.get(id); + if (hasNonDictPage == null || hasNonDictPage) { + return ROWS_MIGHT_MATCH; + } + + Set dictionary = dict(id, lit.comparator()); + if (dictionary.size() > 1 || mayContainNulls.get(id)) { + return ROWS_MIGHT_MATCH; + } + + return dictionary.contains(lit.value()) ? 
ROWS_CANNOT_MATCH : ROWS_MIGHT_MATCH; + } + + @Override + public Boolean in(BoundReference ref, Set literalSet) { + int id = ref.fieldId(); + + Boolean hasNonDictPage = isFallback.get(id); + if (hasNonDictPage == null || hasNonDictPage) { + return ROWS_MIGHT_MATCH; + } + + Set dictionary = dict(id, ref.comparator()); + + // we need to find out the smaller set to iterate through + Set smallerSet; + Set biggerSet; + + if (literalSet.size() < dictionary.size()) { + smallerSet = literalSet; + biggerSet = dictionary; + } else { + smallerSet = dictionary; + biggerSet = literalSet; + } + + for (T e : smallerSet) { + if (biggerSet.contains(e)) { + // value sets intersect so rows match + return ROWS_MIGHT_MATCH; + } + } + + // value sets are disjoint so rows don't match + return ROWS_CANNOT_MATCH; + } + + @Override + public Boolean notIn(BoundReference ref, Set literalSet) { + int id = ref.fieldId(); + + Boolean hasNonDictPage = isFallback.get(id); + if (hasNonDictPage == null || hasNonDictPage) { + return ROWS_MIGHT_MATCH; + } + + Set dictionary = dict(id, ref.comparator()); + if (dictionary.size() > literalSet.size() || mayContainNulls.get(id)) { + return ROWS_MIGHT_MATCH; + } + + // ROWS_CANNOT_MATCH if no values in the dictionary that are not also in the set (the difference is empty) + return Sets.difference(dictionary, literalSet).isEmpty() ? ROWS_CANNOT_MATCH : ROWS_MIGHT_MATCH; + } + + @Override + public Boolean startsWith(BoundReference ref, Literal lit) { + int id = ref.fieldId(); + + Boolean hasNonDictPage = isFallback.get(id); + if (hasNonDictPage == null || hasNonDictPage) { + return ROWS_MIGHT_MATCH; + } + + Set dictionary = dict(id, lit.comparator()); + for (T item : dictionary) { + if (item.toString().startsWith(lit.value().toString())) { + return ROWS_MIGHT_MATCH; + } + } + + return ROWS_CANNOT_MATCH; + } + + @Override + public Boolean notStartsWith(BoundReference ref, Literal lit) { + int id = ref.fieldId(); + + Boolean hasNonDictPage = isFallback.get(id); + if (hasNonDictPage == null || hasNonDictPage) { + return ROWS_MIGHT_MATCH; + } + + Set dictionary = dict(id, lit.comparator()); + for (T item : dictionary) { + if (!item.toString().startsWith(lit.value().toString())) { + return ROWS_MIGHT_MATCH; + } + } + + return ROWS_CANNOT_MATCH; + } + + @SuppressWarnings("unchecked") + private Set dict(int id, Comparator comparator) { + Preconditions.checkNotNull(dictionaries, "Dictionary is required"); + + Set cached = dictCache.get(id); + if (cached != null) { + return (Set) cached; + } + + ColumnDescriptor col = cols.get(id); + DictionaryPage page = dictionaries.readDictionaryPage(col); + // may not be dictionary-encoded + if (page == null) { + throw new IllegalStateException("Failed to read required dictionary page for id: " + id); + } + + Function conversion = conversions.get(id); + + Dictionary dict; + try { + dict = page.getEncoding().initDictionary(col, page); + } catch (IOException e) { + throw new RuntimeIOException("Failed to create reader for dictionary page"); + } + + Set dictSet = Sets.newTreeSet(comparator); + + for (int i = 0; i <= dict.getMaxId(); i++) { + switch (col.getPrimitiveType().getPrimitiveTypeName()) { + case FIXED_LEN_BYTE_ARRAY: dictSet.add((T) conversion.apply(dict.decodeToBinary(i))); + break; + case BINARY: dictSet.add((T) conversion.apply(dict.decodeToBinary(i))); + break; + case INT32: dictSet.add((T) conversion.apply(dict.decodeToInt(i))); + break; + case INT64: dictSet.add((T) conversion.apply(dict.decodeToLong(i))); + break; + case FLOAT: dictSet.add((T) 
conversion.apply(dict.decodeToFloat(i))); + break; + case DOUBLE: dictSet.add((T) conversion.apply(dict.decodeToDouble(i))); + break; + default: + throw new IllegalArgumentException( + "Cannot decode dictionary of type: " + col.getPrimitiveType().getPrimitiveTypeName()); + } + } + + dictCache.put(id, dictSet); + + return dictSet; + } + } + + private static boolean mayContainNull(ColumnChunkMetaData meta) { + return meta.getStatistics() == null || meta.getStatistics().getNumNulls() != 0; + } + +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetIO.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetIO.java new file mode 100644 index 00000000000..4c6a91eb3dc --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetIO.java @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.iceberg.parquet; + +import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; + +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.iceberg.hadoop.HadoopInputFile; +import org.apache.iceberg.io.DelegatingInputStream; +import org.apache.parquet.hadoop.util.HadoopStreams; +import org.apache.parquet.io.DelegatingSeekableInputStream; +import org.apache.parquet.io.InputFile; +import org.apache.parquet.io.SeekableInputStream; + +public class ParquetIO { + private ParquetIO() { + } + + static InputFile file(org.apache.iceberg.io.InputFile file) { + // TODO: use reflection to avoid depending on classes from iceberg-hadoop + // TODO: use reflection to avoid depending on classes from hadoop + if (file instanceof HadoopInputFile) { + HadoopInputFile hfile = (HadoopInputFile) file; + try { + return org.apache.parquet.hadoop.util.HadoopInputFile.fromStatus(hfile.getStat(), hfile.getConf()); + } catch (IOException e) { + throw new UncheckedIOException("Failed to create Parquet input file for " + file, e); + } + } + return new ParquetInputFile(file); + } + + static SeekableInputStream stream(org.apache.iceberg.io.SeekableInputStream stream) { + if (stream instanceof DelegatingInputStream) { + InputStream wrapped = ((DelegatingInputStream) stream).getDelegate(); + if (wrapped instanceof FSDataInputStream) { + return HadoopStreams.wrap((FSDataInputStream) wrapped); + } + } + return new ParquetInputStreamAdapter(stream); + } + + private static class ParquetInputStreamAdapter extends DelegatingSeekableInputStream { + private final org.apache.iceberg.io.SeekableInputStream delegate; + + private ParquetInputStreamAdapter(org.apache.iceberg.io.SeekableInputStream delegate) { + super(delegate); + this.delegate = delegate; + } + + @Override + public long getPos() throws IOException { + return delegate.getPos(); + } + + @Override + public void seek(long newPos) throws IOException { + delegate.seek(newPos); + } + } + + private static class ParquetInputFile implements InputFile { + private final 
org.apache.iceberg.io.InputFile file; + + private ParquetInputFile(org.apache.iceberg.io.InputFile file) { + this.file = file; + } + + @Override + public long getLength() throws IOException { + return file.getLength(); + } + + @Override + public SeekableInputStream newStream() throws IOException { + return stream(file.newStream()); + } + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetMetricsRowGroupFilter.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetMetricsRowGroupFilter.java new file mode 100644 index 00000000000..bd60184f0c2 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetMetricsRowGroupFilter.java @@ -0,0 +1,565 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.iceberg.parquet; + +import java.nio.ByteBuffer; +import java.util.Collection; +import java.util.Comparator; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import org.apache.iceberg.Schema; +import org.apache.iceberg.expressions.Binder; +import org.apache.iceberg.expressions.BoundReference; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.ExpressionVisitors; +import org.apache.iceberg.expressions.ExpressionVisitors.BoundExpressionVisitor; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.expressions.Literal; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Comparators; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types.StructType; +import org.apache.iceberg.util.BinaryUtil; +import org.apache.parquet.column.statistics.Statistics; +import org.apache.parquet.hadoop.metadata.BlockMetaData; +import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; + +public class ParquetMetricsRowGroupFilter { + private static final int IN_PREDICATE_LIMIT = 200; + + private final Schema schema; + private final Expression expr; + + public ParquetMetricsRowGroupFilter(Schema schema, Expression unbound) { + this(schema, unbound, true); + } + + public ParquetMetricsRowGroupFilter(Schema schema, Expression unbound, boolean caseSensitive) { + this.schema = schema; + StructType struct = schema.asStruct(); + this.expr = Binder.bind(struct, Expressions.rewriteNot(unbound), caseSensitive); + } + + /** + * Test whether the file may contain records that match the expression. + * + * @param fileSchema schema for the Parquet file + * @param rowGroup metadata for a row group + * @return false if the file cannot contain rows that match the expression, true otherwise. 
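+   * A row group that reports no rows is always eliminated.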
+ */ + public boolean shouldRead(MessageType fileSchema, BlockMetaData rowGroup) { + return new MetricsEvalVisitor().eval(fileSchema, rowGroup); + } + + private static final boolean ROWS_MIGHT_MATCH = true; + private static final boolean ROWS_CANNOT_MATCH = false; + + private class MetricsEvalVisitor extends BoundExpressionVisitor { + private Map> stats = null; + private Map valueCounts = null; + private Map> conversions = null; + + private boolean eval(MessageType fileSchema, BlockMetaData rowGroup) { + if (rowGroup.getRowCount() <= 0) { + return ROWS_CANNOT_MATCH; + } + + this.stats = Maps.newHashMap(); + this.valueCounts = Maps.newHashMap(); + this.conversions = Maps.newHashMap(); + for (ColumnChunkMetaData col : rowGroup.getColumns()) { + PrimitiveType colType = fileSchema.getType(col.getPath().toArray()).asPrimitiveType(); + if (colType.getId() != null) { + int id = colType.getId().intValue(); + Type icebergType = schema.findType(id); + stats.put(id, col.getStatistics()); + valueCounts.put(id, col.getValueCount()); + conversions.put(id, ParquetConversions.converterFromParquet(colType, icebergType)); + } + } + + return ExpressionVisitors.visitEvaluator(expr, this); + } + + @Override + public Boolean alwaysTrue() { + return ROWS_MIGHT_MATCH; // all rows match + } + + @Override + public Boolean alwaysFalse() { + return ROWS_CANNOT_MATCH; // all rows fail + } + + @Override + public Boolean not(Boolean result) { + return !result; + } + + @Override + public Boolean and(Boolean leftResult, Boolean rightResult) { + return leftResult && rightResult; + } + + @Override + public Boolean or(Boolean leftResult, Boolean rightResult) { + return leftResult || rightResult; + } + + @Override + public Boolean isNull(BoundReference ref) { + // no need to check whether the field is required because binding evaluates that case + // if the column has no null values, the expression cannot match + int id = ref.fieldId(); + + Long valueCount = valueCounts.get(id); + if (valueCount == null) { + // the column is not present and is all nulls + return ROWS_MIGHT_MATCH; + } + + Statistics colStats = stats.get(id); + if (colStats != null && !colStats.isEmpty() && colStats.getNumNulls() == 0) { + // there are stats and no values are null => all values are non-null + return ROWS_CANNOT_MATCH; + } + + return ROWS_MIGHT_MATCH; + } + + @Override + public Boolean notNull(BoundReference ref) { + // no need to check whether the field is required because binding evaluates that case + // if the column has no non-null values, the expression cannot match + int id = ref.fieldId(); + + // When filtering nested types notNull() is implicit filter passed even though complex + // filters aren't pushed down in Parquet. Leave all nested column type filters to be + // evaluated post scan. 
+ if (schema.findType(id) instanceof Type.NestedType) { + return ROWS_MIGHT_MATCH; + } + + Long valueCount = valueCounts.get(id); + if (valueCount == null) { + // the column is not present and is all nulls + return ROWS_CANNOT_MATCH; + } + + Statistics colStats = stats.get(id); + if (colStats != null && valueCount - colStats.getNumNulls() == 0) { + // (num nulls == value count) => all values are null => no non-null values + return ROWS_CANNOT_MATCH; + } + + return ROWS_MIGHT_MATCH; + } + + @Override + public Boolean isNaN(BoundReference ref) { + int id = ref.fieldId(); + + Long valueCount = valueCounts.get(id); + if (valueCount == null) { + // the column is not present and is all nulls + return ROWS_CANNOT_MATCH; + } + + Statistics colStats = stats.get(id); + if (colStats != null && valueCount - colStats.getNumNulls() == 0) { + // (num nulls == value count) => all values are null => no nan values + return ROWS_CANNOT_MATCH; + } + + return ROWS_MIGHT_MATCH; + } + + @Override + public Boolean notNaN(BoundReference ref) { + return ROWS_MIGHT_MATCH; + } + + @Override + public Boolean lt(BoundReference ref, Literal lit) { + int id = ref.fieldId(); + + Long valueCount = valueCounts.get(id); + if (valueCount == null) { + // the column is not present and is all nulls + return ROWS_CANNOT_MATCH; + } + + Statistics colStats = stats.get(id); + if (colStats != null && !colStats.isEmpty()) { + if (hasNonNullButNoMinMax(colStats, valueCount)) { + return ROWS_MIGHT_MATCH; + } + + if (!colStats.hasNonNullValue()) { + return ROWS_CANNOT_MATCH; + } + + T lower = min(colStats, id); + int cmp = lit.comparator().compare(lower, lit.value()); + if (cmp >= 0) { + return ROWS_CANNOT_MATCH; + } + } + + return ROWS_MIGHT_MATCH; + } + + @Override + public Boolean ltEq(BoundReference ref, Literal lit) { + int id = ref.fieldId(); + + Long valueCount = valueCounts.get(id); + if (valueCount == null) { + // the column is not present and is all nulls + return ROWS_CANNOT_MATCH; + } + + Statistics colStats = stats.get(id); + if (colStats != null && !colStats.isEmpty()) { + if (hasNonNullButNoMinMax(colStats, valueCount)) { + return ROWS_MIGHT_MATCH; + } + + if (!colStats.hasNonNullValue()) { + return ROWS_CANNOT_MATCH; + } + + T lower = min(colStats, id); + int cmp = lit.comparator().compare(lower, lit.value()); + if (cmp > 0) { + return ROWS_CANNOT_MATCH; + } + } + + return ROWS_MIGHT_MATCH; + } + + @Override + public Boolean gt(BoundReference ref, Literal lit) { + int id = ref.fieldId(); + + Long valueCount = valueCounts.get(id); + if (valueCount == null) { + // the column is not present and is all nulls + return ROWS_CANNOT_MATCH; + } + + Statistics colStats = stats.get(id); + if (colStats != null && !colStats.isEmpty()) { + if (hasNonNullButNoMinMax(colStats, valueCount)) { + return ROWS_MIGHT_MATCH; + } + + if (!colStats.hasNonNullValue()) { + return ROWS_CANNOT_MATCH; + } + + T upper = max(colStats, id); + int cmp = lit.comparator().compare(upper, lit.value()); + if (cmp <= 0) { + return ROWS_CANNOT_MATCH; + } + } + + return ROWS_MIGHT_MATCH; + } + + @Override + public Boolean gtEq(BoundReference ref, Literal lit) { + int id = ref.fieldId(); + + Long valueCount = valueCounts.get(id); + if (valueCount == null) { + // the column is not present and is all nulls + return ROWS_CANNOT_MATCH; + } + + Statistics colStats = stats.get(id); + if (colStats != null && !colStats.isEmpty()) { + if (hasNonNullButNoMinMax(colStats, valueCount)) { + return ROWS_MIGHT_MATCH; + } + + if (!colStats.hasNonNullValue()) { + return 
ROWS_CANNOT_MATCH; + } + + T upper = max(colStats, id); + int cmp = lit.comparator().compare(upper, lit.value()); + if (cmp < 0) { + return ROWS_CANNOT_MATCH; + } + } + + return ROWS_MIGHT_MATCH; + } + + @Override + public Boolean eq(BoundReference ref, Literal lit) { + int id = ref.fieldId(); + + // When filtering nested types notNull() is implicit filter passed even though complex + // filters aren't pushed down in Parquet. Leave all nested column type filters to be + // evaluated post scan. + if (schema.findType(id) instanceof Type.NestedType) { + return ROWS_MIGHT_MATCH; + } + + Long valueCount = valueCounts.get(id); + if (valueCount == null) { + // the column is not present and is all nulls + return ROWS_CANNOT_MATCH; + } + + Statistics colStats = stats.get(id); + if (colStats != null && !colStats.isEmpty()) { + if (hasNonNullButNoMinMax(colStats, valueCount)) { + return ROWS_MIGHT_MATCH; + } + + if (!colStats.hasNonNullValue()) { + return ROWS_CANNOT_MATCH; + } + + T lower = min(colStats, id); + int cmp = lit.comparator().compare(lower, lit.value()); + if (cmp > 0) { + return ROWS_CANNOT_MATCH; + } + + T upper = max(colStats, id); + cmp = lit.comparator().compare(upper, lit.value()); + if (cmp < 0) { + return ROWS_CANNOT_MATCH; + } + } + + return ROWS_MIGHT_MATCH; + } + + @Override + public Boolean notEq(BoundReference ref, Literal lit) { + // because the bounds are not necessarily a min or max value, this cannot be answered using + // them. notEq(col, X) with (X, Y) doesn't guarantee that X is a value in col. + return ROWS_MIGHT_MATCH; + } + + @Override + public Boolean in(BoundReference ref, Set literalSet) { + int id = ref.fieldId(); + + // When filtering nested types notNull() is implicit filter passed even though complex + // filters aren't pushed down in Parquet. Leave all nested column type filters to be + // evaluated post scan. + if (schema.findType(id) instanceof Type.NestedType) { + return ROWS_MIGHT_MATCH; + } + + Long valueCount = valueCounts.get(id); + if (valueCount == null) { + // the column is not present and is all nulls + return ROWS_CANNOT_MATCH; + } + + Statistics colStats = stats.get(id); + if (colStats != null && !colStats.isEmpty()) { + if (hasNonNullButNoMinMax(colStats, valueCount)) { + return ROWS_MIGHT_MATCH; + } + + if (!colStats.hasNonNullValue()) { + return ROWS_CANNOT_MATCH; + } + + Collection literals = literalSet; + + if (literals.size() > IN_PREDICATE_LIMIT) { + // skip evaluating the predicate if the number of values is too big + return ROWS_MIGHT_MATCH; + } + + T lower = min(colStats, id); + literals = literals.stream().filter(v -> ref.comparator().compare(lower, v) <= 0).collect(Collectors.toList()); + if (literals.isEmpty()) { // if all values are less than lower bound, rows cannot match. + return ROWS_CANNOT_MATCH; + } + + T upper = max(colStats, id); + literals = literals.stream().filter(v -> ref.comparator().compare(upper, v) >= 0).collect(Collectors.toList()); + if (literals.isEmpty()) { // if all remaining values are greater than upper bound, rows cannot match. + return ROWS_CANNOT_MATCH; + } + } + + return ROWS_MIGHT_MATCH; + } + + @Override + public Boolean notIn(BoundReference ref, Set literalSet) { + // because the bounds are not necessarily a min or max value, this cannot be answered using + // them. notIn(col, {X, ...}) with (X, Y) doesn't guarantee that X is a value in col. 
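+      // For example, even when every literal lies inside the bounds, the column may hold only those
+      // literal values (no rows match) or additional values as well (some rows match); min/max
+      // statistics cannot tell the two cases apart, so the row group is never eliminated here.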
+ return ROWS_MIGHT_MATCH; + } + + @Override + @SuppressWarnings("unchecked") + public Boolean startsWith(BoundReference ref, Literal lit) { + int id = ref.fieldId(); + + Long valueCount = valueCounts.get(id); + if (valueCount == null) { + // the column is not present and is all nulls + return ROWS_CANNOT_MATCH; + } + + Statistics colStats = (Statistics) stats.get(id); + if (colStats != null && !colStats.isEmpty()) { + if (hasNonNullButNoMinMax(colStats, valueCount)) { + return ROWS_MIGHT_MATCH; + } + + if (!colStats.hasNonNullValue()) { + return ROWS_CANNOT_MATCH; + } + + ByteBuffer prefixAsBytes = lit.toByteBuffer(); + + Comparator comparator = Comparators.unsignedBytes(); + + Binary lower = colStats.genericGetMin(); + // truncate lower bound so that its length in bytes is not greater than the length of prefix + int lowerLength = Math.min(prefixAsBytes.remaining(), lower.length()); + int lowerCmp = comparator.compare(BinaryUtil.truncateBinary(lower.toByteBuffer(), lowerLength), prefixAsBytes); + if (lowerCmp > 0) { + return ROWS_CANNOT_MATCH; + } + + Binary upper = colStats.genericGetMax(); + // truncate upper bound so that its length in bytes is not greater than the length of prefix + int upperLength = Math.min(prefixAsBytes.remaining(), upper.length()); + int upperCmp = comparator.compare(BinaryUtil.truncateBinary(upper.toByteBuffer(), upperLength), prefixAsBytes); + if (upperCmp < 0) { + return ROWS_CANNOT_MATCH; + } + } + + return ROWS_MIGHT_MATCH; + } + + @Override + @SuppressWarnings("unchecked") + public Boolean notStartsWith(BoundReference ref, Literal lit) { + int id = ref.fieldId(); + Long valueCount = valueCounts.get(id); + + if (valueCount == null) { + // the column is not present and is all nulls + return ROWS_MIGHT_MATCH; + } + + Statistics colStats = (Statistics) stats.get(id); + if (colStats != null && !colStats.isEmpty()) { + if (mayContainNull(colStats)) { + return ROWS_MIGHT_MATCH; + } + + if (hasNonNullButNoMinMax(colStats, valueCount)) { + return ROWS_MIGHT_MATCH; + } + + Binary lower = colStats.genericGetMin(); + Binary upper = colStats.genericGetMax(); + + // notStartsWith will match unless all values must start with the prefix. this happens when the lower and upper + // bounds both start with the prefix. 
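+        // (Bounds are compared as unsigned bytes: if both the lower and upper bound start with the
+        // prefix, every value between them must also start with it, so no row can satisfy
+        // notStartsWith.)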
+ if (lower != null && upper != null) { + ByteBuffer prefix = lit.toByteBuffer(); + Comparator comparator = Comparators.unsignedBytes(); + + // if lower is shorter than the prefix, it can't start with the prefix + if (lower.length() < prefix.remaining()) { + return ROWS_MIGHT_MATCH; + } + + // truncate lower bound to the prefix and check for equality + int cmp = comparator.compare(BinaryUtil.truncateBinary(lower.toByteBuffer(), prefix.remaining()), prefix); + if (cmp == 0) { + // the lower bound starts with the prefix; check the upper bound + // if upper is shorter than the prefix, it can't start with the prefix + if (upper.length() < prefix.remaining()) { + return ROWS_MIGHT_MATCH; + } + + // truncate upper bound so that its length in bytes is not greater than the length of prefix + cmp = comparator.compare(BinaryUtil.truncateBinary(upper.toByteBuffer(), prefix.remaining()), prefix); + if (cmp == 0) { + // both bounds match the prefix, so all rows must match the prefix and none do not match + return ROWS_CANNOT_MATCH; + } + } + } + } + + return ROWS_MIGHT_MATCH; + } + + @SuppressWarnings("unchecked") + private T min(Statistics statistics, int id) { + return (T) conversions.get(id).apply(statistics.genericGetMin()); + } + + @SuppressWarnings("unchecked") + private T max(Statistics statistics, int id) { + return (T) conversions.get(id).apply(statistics.genericGetMax()); + } + } + + /** + * Checks against older versions of Parquet statistics which may have a null count but undefined min and max + * statistics. Returns true if nonNull values exist in the row group but no further statistics are available. + *
+ * We can't use {@code statistics.hasNonNullValue()} because it is inaccurate with older files and will return + * false if min and max are not set. + *
+ * This is specifically for 1.5.0-CDH Parquet builds and later which contain the different unusual hasNonNull + * behavior. OSS Parquet builds are not effected because PARQUET-251 prohibits the reading of these statistics + * from versions of Parquet earlier than 1.8.0. + * + * @param statistics Statistics to check + * @param valueCount Number of values in the row group + * @return true if nonNull values exist and no other stats can be used + */ + static boolean hasNonNullButNoMinMax(Statistics statistics, long valueCount) { + return statistics.getNumNulls() < valueCount && + (statistics.getMaxBytes() == null || statistics.getMinBytes() == null); + } + + private static boolean mayContainNull(Statistics statistics) { + return !statistics.isNumNullsSet() || statistics.getNumNulls() > 0; + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetSchemaUtil.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetSchemaUtil.java new file mode 100644 index 00000000000..51a12e55dac --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetSchemaUtil.java @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.iceberg.parquet; + +import java.util.List; + +import org.apache.iceberg.mapping.NameMapping; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.parquet.schema.Types.MessageTypeBuilder; + +public class ParquetSchemaUtil { + private ParquetSchemaUtil() { + } + +// public static MessageType convert(Schema schema, String name) { +// return new TypeToMessageType().convert(schema, name); +// } +// +// /** +// * Converts a Parquet schema to an Iceberg schema. Fields without IDs are kept and assigned fallback IDs. +// * +// * @param parquetSchema a Parquet schema +// * @return a matching Iceberg schema for the provided Parquet schema +// */ +// public static Schema convert(MessageType parquetSchema) { +// // if the Parquet schema does not contain ids, we assign fallback ids to top-level fields +// // all remaining fields will get ids >= 1000 to avoid pruning columns without ids +// MessageType parquetSchemaWithIds = hasIds(parquetSchema) ? parquetSchema : addFallbackIds(parquetSchema); +// AtomicInteger nextId = new AtomicInteger(1000); +// return convertInternal(parquetSchemaWithIds, name -> nextId.getAndIncrement()); +// } +// +// /** +// * Converts a Parquet schema to an Iceberg schema and prunes fields without IDs. 
+// * +// * @param parquetSchema a Parquet schema +// * @return a matching Iceberg schema for the provided Parquet schema +// */ +// public static Schema convertAndPrune(MessageType parquetSchema) { +// return convertInternal(parquetSchema, name -> null); +// } +// +// private static Schema convertInternal(MessageType parquetSchema, Function nameToIdFunc) { +// MessageTypeToType converter = new MessageTypeToType(nameToIdFunc); +// return new Schema( +// ParquetTypeVisitor.visit(parquetSchema, converter).asNestedType().fields(), +// converter.getAliases()); +// } +// +// public static MessageType pruneColumns(MessageType fileSchema, Schema expectedSchema) { +// // column order must match the incoming type, so it doesn't matter that the ids are unordered +// Set selectedIds = TypeUtil.getProjectedIds(expectedSchema); +// return (MessageType) ParquetTypeVisitor.visit(fileSchema, new PruneColumns(selectedIds)); +// } +// +// /** +// * Prunes columns from a Parquet file schema that was written without field ids. +// *
+// * Files that were written without field ids are read assuming that schema evolution preserved +// * column order. Deleting columns was not allowed. +// *
+// * The order of columns in the resulting Parquet schema matches the Parquet file. +// * +// * @param fileSchema schema from a Parquet file that does not have field ids. +// * @param expectedSchema expected schema +// * @return a parquet schema pruned using the expected schema +// */ +// public static MessageType pruneColumnsFallback(MessageType fileSchema, Schema expectedSchema) { +// Set selectedIds = Sets.newHashSet(); +// +// for (Types.NestedField field : expectedSchema.columns()) { +// selectedIds.add(field.fieldId()); +// } +// +// MessageTypeBuilder builder = org.apache.parquet.schema.Types.buildMessage(); +// +// int ordinal = 1; +// for (Type type : fileSchema.getFields()) { +// if (selectedIds.contains(ordinal)) { +// builder.addField(type.withId(ordinal)); +// } +// ordinal += 1; +// } +// +// return builder.named(fileSchema.getName()); +// } + + public static boolean hasIds(MessageType fileSchema) { + return ParquetTypeVisitor.visit(fileSchema, new HasIds()); + } + + public static MessageType addFallbackIds(MessageType fileSchema) { + MessageTypeBuilder builder = org.apache.parquet.schema.Types.buildMessage(); + + int ordinal = 1; // ids are assigned starting at 1 + for (Type type : fileSchema.getFields()) { + builder.addField(type.withId(ordinal)); + ordinal += 1; + } + + return builder.named(fileSchema.getName()); + } + + public static MessageType applyNameMapping(MessageType fileSchema, NameMapping nameMapping) { + return (MessageType) ParquetTypeVisitor.visit(fileSchema, new ApplyNameMapping(nameMapping)); + } + + public static class HasIds extends ParquetTypeVisitor { + @Override + public Boolean message(MessageType message, List fields) { + return struct(message, fields); + } + + @Override + public Boolean struct(GroupType struct, List hasIds) { + for (Boolean hasId : hasIds) { + if (hasId) { + return true; + } + } + return struct.getId() != null; + } + + @Override + public Boolean list(GroupType array, Boolean hasId) { + return hasId || array.getId() != null; + } + + @Override + public Boolean map(GroupType map, Boolean keyHasId, Boolean valueHasId) { + return keyHasId || valueHasId || map.getId() != null; + } + + @Override + public Boolean primitive(PrimitiveType primitive) { + return primitive.getId() != null; + } + } + + public static Type determineListElementType(GroupType array) { + Type repeated = array.getFields().get(0); + boolean isOldListElementType = isOldListElementType(array); + + return isOldListElementType ? repeated : repeated.asGroupType().getType(0); + } + + // Parquet LIST backwards-compatibility rules. 
+ // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules + static boolean isOldListElementType(GroupType list) { + Type repeatedType = list.getFields().get(0); + String parentName = list.getName(); + + return + // For legacy 2-level list types with primitive element type, e.g.: + // + // // ARRAY (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated int32 element; + // } + // + repeatedType.isPrimitive() || + // For legacy 2-level list types whose element type is a group type with 2 or more fields, + // e.g.: + // + // // ARRAY> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group element { + // required binary str (UTF8); + // required int32 num; + // }; + // } + // + repeatedType.asGroupType().getFieldCount() > 1 || + // For legacy 2-level list types generated by parquet-avro (Parquet version < 1.6.0), e.g.: + // + // // ARRAY> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group array { + // required binary str (UTF8); + // }; + // } + repeatedType.getName().equals("array") || + // For Parquet data generated by parquet-thrift, e.g.: + // + // // ARRAY> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group my_list_tuple { + // required binary str (UTF8); + // }; + // } + // + repeatedType.getName().equals(parentName + "_tuple"); + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetTypeVisitor.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetTypeVisitor.java new file mode 100644 index 00000000000..614688fb118 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetTypeVisitor.java @@ -0,0 +1,263 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.iceberg.parquet; + +import java.util.Deque; +import java.util.List; + +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.OriginalType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; + +public class ParquetTypeVisitor { + private final Deque fieldNames = Lists.newLinkedList(); + + public static T visit(Type type, ParquetTypeVisitor visitor) { + if (type instanceof MessageType) { + return visitor.message((MessageType) type, + visitFields(type.asGroupType(), visitor)); + + } else if (type.isPrimitive()) { + return visitor.primitive(type.asPrimitiveType()); + + } else { + // if not a primitive, the typeId must be a group + GroupType group = type.asGroupType(); + OriginalType annotation = group.getOriginalType(); + if (annotation != null) { + switch (annotation) { + case LIST: + return visitList(group, visitor); + + case MAP: + return visitMap(group, visitor); + + default: + } + } + + return visitor.struct(group, visitFields(group, visitor)); + } + } + + private static T visitList(GroupType list, ParquetTypeVisitor visitor) { + Preconditions.checkArgument(!list.isRepetition(Type.Repetition.REPEATED), + "Invalid list: top-level group is repeated: %s", list); + Preconditions.checkArgument(list.getFieldCount() == 1, + "Invalid list: does not contain single repeated field: %s", list); + + Type repeatedElement = list.getFields().get(0); + Preconditions.checkArgument(repeatedElement.isRepetition(Type.Repetition.REPEATED), + "Invalid list: inner group is not repeated"); + + Type listElement = ParquetSchemaUtil.determineListElementType(list); + if (listElement.isRepetition(Type.Repetition.REPEATED)) { + T elementResult = visitListElement(listElement, visitor); + return visitor.list(list, elementResult); + } else { + return visitThreeLevelList(list, repeatedElement, listElement, visitor); + } + } + + private static T visitThreeLevelList( + GroupType list, Type repeated, Type listElement, ParquetTypeVisitor visitor) { + visitor.beforeRepeatedElement(repeated); + try { + T elementResult = visitListElement(listElement, visitor); + return visitor.list(list, elementResult); + } finally { + visitor.afterRepeatedElement(repeated); + } + } + + private static T visitListElement(Type listElement, ParquetTypeVisitor visitor) { + T elementResult = null; + + visitor.beforeElementField(listElement); + try { + elementResult = visit(listElement, visitor); + } finally { + visitor.afterElementField(listElement); + } + + return elementResult; + } + + private static T visitMap(GroupType map, ParquetTypeVisitor visitor) { + Preconditions.checkArgument(!map.isRepetition(Type.Repetition.REPEATED), + "Invalid map: top-level group is repeated: %s", map); + Preconditions.checkArgument(map.getFieldCount() == 1, + "Invalid map: does not contain single repeated field: %s", map); + + GroupType repeatedKeyValue = map.getType(0).asGroupType(); + Preconditions.checkArgument(repeatedKeyValue.isRepetition(Type.Repetition.REPEATED), + "Invalid map: inner group is not repeated"); + Preconditions.checkArgument(repeatedKeyValue.getFieldCount() <= 2, + "Invalid map: repeated group does not have 2 fields"); + + visitor.beforeRepeatedKeyValue(repeatedKeyValue); + try { + T keyResult = null; + T valueResult = null; + switch 
(repeatedKeyValue.getFieldCount()) { + case 2: + // if there are 2 fields, both key and value are projected + Type keyType = repeatedKeyValue.getType(0); + visitor.beforeKeyField(keyType); + try { + keyResult = visit(keyType, visitor); + } finally { + visitor.afterKeyField(keyType); + } + Type valueType = repeatedKeyValue.getType(1); + visitor.beforeValueField(valueType); + try { + valueResult = visit(valueType, visitor); + } finally { + visitor.afterValueField(valueType); + } + break; + + case 1: + // if there is just one, use the name to determine what it is + Type keyOrValue = repeatedKeyValue.getType(0); + if (keyOrValue.getName().equalsIgnoreCase("key")) { + visitor.beforeKeyField(keyOrValue); + try { + keyResult = visit(keyOrValue, visitor); + } finally { + visitor.afterKeyField(keyOrValue); + } + // value result remains null + } else { + visitor.beforeValueField(keyOrValue); + try { + valueResult = visit(keyOrValue, visitor); + } finally { + visitor.afterValueField(keyOrValue); + } + // key result remains null + } + break; + + default: + // both results will remain null + } + + return visitor.map(map, keyResult, valueResult); + + } finally { + visitor.afterRepeatedKeyValue(repeatedKeyValue); + } + } + + private static List visitFields(GroupType group, ParquetTypeVisitor visitor) { + List results = Lists.newArrayListWithExpectedSize(group.getFieldCount()); + for (Type field : group.getFields()) { + visitor.beforeField(field); + try { + results.add(visit(field, visitor)); + } finally { + visitor.afterField(field); + } + } + + return results; + } + + public T message(MessageType message, List fields) { + return null; + } + + public T struct(GroupType struct, List fields) { + return null; + } + + public T list(GroupType array, T element) { + return null; + } + + public T map(GroupType map, T key, T value) { + return null; + } + + public T primitive(PrimitiveType primitive) { + return null; + } + + public void beforeField(Type type) { + fieldNames.push(type.getName()); + } + + public void afterField(Type type) { + fieldNames.pop(); + } + + public void beforeRepeatedElement(Type element) { + beforeField(element); + } + + public void afterRepeatedElement(Type element) { + afterField(element); + } + + public void beforeElementField(Type element) { + beforeField(element); + } + + public void afterElementField(Type element) { + afterField(element); + } + + public void beforeRepeatedKeyValue(Type keyValue) { + beforeField(keyValue); + } + + public void afterRepeatedKeyValue(Type keyValue) { + afterField(keyValue); + } + + public void beforeKeyField(Type key) { + beforeField(key); + } + + public void afterKeyField(Type key) { + afterField(key); + } + + public void beforeValueField(Type value) { + beforeField(value); + } + + public void afterValueField(Type value) { + afterField(value); + } + + protected String[] currentPath() { + return Lists.newArrayList(fieldNames.descendingIterator()).toArray(new String[0]); + } + + protected String[] path(String name) { + List list = Lists.newArrayList(fieldNames.descendingIterator()); + list.add(name); + return list.toArray(new String[0]); + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetUtil.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetUtil.java new file mode 100644 index 00000000000..512ce1d5907 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetUtil.java @@ -0,0 +1,331 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.iceberg.parquet; + +import java.util.Set; + +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.parquet.column.Encoding; +import org.apache.parquet.column.EncodingStats; +import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; + +public class ParquetUtil { + // not meant to be instantiated + private ParquetUtil() { + } + +// public static Metrics fileMetrics(InputFile file, MetricsConfig metricsConfig) { +// return fileMetrics(file, metricsConfig, null); +// } +// +// public static Metrics fileMetrics(InputFile file, MetricsConfig metricsConfig, NameMapping nameMapping) { +// try (ParquetFileReader reader = ParquetFileReader.open(ParquetIO.file(file))) { +// return footerMetrics(reader.getFooter(), Stream.empty(), metricsConfig, nameMapping); +// } catch (IOException e) { +// throw new RuntimeIOException(e, "Failed to read footer of file: %s", file); +// } +// } +// +// public static Metrics footerMetrics(ParquetMetadata metadata, Stream> fieldMetrics, +// MetricsConfig metricsConfig) { +// return footerMetrics(metadata, fieldMetrics, metricsConfig, null); +// } +// +// @SuppressWarnings("checkstyle:CyclomaticComplexity") +// public static Metrics footerMetrics(ParquetMetadata metadata, Stream> fieldMetrics, +// MetricsConfig metricsConfig, NameMapping nameMapping) { +// Preconditions.checkNotNull(fieldMetrics, "fieldMetrics should not be null"); +// +// long rowCount = 0; +// Map columnSizes = Maps.newHashMap(); +// Map valueCounts = Maps.newHashMap(); +// Map nullValueCounts = Maps.newHashMap(); +// Map> lowerBounds = Maps.newHashMap(); +// Map> upperBounds = Maps.newHashMap(); +// Set missingStats = Sets.newHashSet(); +// +// // ignore metrics for fields we failed to determine reliable IDs +// MessageType parquetTypeWithIds = getParquetTypeWithIds(metadata, nameMapping); +// Schema fileSchema = ParquetSchemaUtil.convertAndPrune(parquetTypeWithIds); +// +// Map> fieldMetricsMap = fieldMetrics.collect( +// Collectors.toMap(FieldMetrics::id, Function.identity())); +// +// List blocks = metadata.getBlocks(); +// for (BlockMetaData block : blocks) { +// rowCount += block.getRowCount(); +// for (ColumnChunkMetaData column : block.getColumns()) { +// +// Integer fieldId = fileSchema.aliasToId(column.getPath().toDotString()); +// if (fieldId == null) { +// // fileSchema may contain a subset of columns present in the file +// // as we prune columns we could not assign ids +// continue; +// } +// +// increment(columnSizes, fieldId, column.getTotalSize()); +// +// MetricsMode metricsMode = MetricsUtil.metricsMode(fileSchema, metricsConfig, fieldId); +// if (metricsMode == MetricsModes.None.get()) { +// continue; +// } +// increment(valueCounts, fieldId, column.getValueCount()); +// +// Statistics stats = column.getStatistics(); +// if (stats == null) { +// missingStats.add(fieldId); +// } else if (!stats.isEmpty()) { +// increment(nullValueCounts, fieldId, 
stats.getNumNulls()); +// +// // when there are metrics gathered by Iceberg for a column, we should use those instead +// // of the ones from Parquet +// if (metricsMode != MetricsModes.Counts.get() && !fieldMetricsMap.containsKey(fieldId)) { +// Types.NestedField field = fileSchema.findField(fieldId); +// if (field != null && stats.hasNonNullValue() && shouldStoreBounds(column, fileSchema)) { +// Literal min = ParquetConversions.fromParquetPrimitive( +// field.type(), column.getPrimitiveType(), stats.genericGetMin()); +// updateMin(lowerBounds, fieldId, field.type(), min, metricsMode); +// Literal max = ParquetConversions.fromParquetPrimitive( +// field.type(), column.getPrimitiveType(), stats.genericGetMax()); +// updateMax(upperBounds, fieldId, field.type(), max, metricsMode); +// } +// } +// } +// } +// } +// +// // discard accumulated values if any stats were missing +// for (Integer fieldId : missingStats) { +// nullValueCounts.remove(fieldId); +// lowerBounds.remove(fieldId); +// upperBounds.remove(fieldId); +// } +// +// updateFromFieldMetrics(fieldMetricsMap, metricsConfig, fileSchema, lowerBounds, upperBounds); +// +// return new Metrics(rowCount, columnSizes, valueCounts, nullValueCounts, +// MetricsUtil.createNanValueCounts(fieldMetricsMap.values().stream(), metricsConfig, fileSchema), +// toBufferMap(fileSchema, lowerBounds), +// toBufferMap(fileSchema, upperBounds)); +// } +// +// private static void updateFromFieldMetrics( +// Map> idToFieldMetricsMap, MetricsConfig metricsConfig, Schema schema, +// Map> lowerBounds, Map> upperBounds) { +// idToFieldMetricsMap.entrySet().forEach(entry -> { +// int fieldId = entry.getKey(); +// FieldMetrics metrics = entry.getValue(); +// MetricsMode metricsMode = MetricsUtil.metricsMode(schema, metricsConfig, fieldId); +// +// // only check for MetricsModes.None, since we don't truncate float/double values. +// if (metricsMode != MetricsModes.None.get()) { +// if (!metrics.hasBounds()) { +// lowerBounds.remove(fieldId); +// upperBounds.remove(fieldId); +// } else if (metrics.upperBound() instanceof Float) { +// lowerBounds.put(fieldId, Literal.of((Float) metrics.lowerBound())); +// upperBounds.put(fieldId, Literal.of((Float) metrics.upperBound())); +// } else if (metrics.upperBound() instanceof Double) { +// lowerBounds.put(fieldId, Literal.of((Double) metrics.lowerBound())); +// upperBounds.put(fieldId, Literal.of((Double) metrics.upperBound())); +// } else { +// throw new UnsupportedOperationException("Expected only float or double column metrics"); +// } +// } +// }); +// } +// +// private static MessageType getParquetTypeWithIds(ParquetMetadata metadata, NameMapping nameMapping) { +// MessageType type = metadata.getFileMetaData().getSchema(); +// +// if (ParquetSchemaUtil.hasIds(type)) { +// return type; +// } +// +// if (nameMapping != null) { +// return ParquetSchemaUtil.applyNameMapping(type, nameMapping); +// } +// +// return ParquetSchemaUtil.addFallbackIds(type); +// } +// +// /** +// * Returns a list of offsets in ascending order determined by the starting position of the row groups. 
+// */ +// public static List getSplitOffsets(ParquetMetadata md) { +// List splitOffsets = Lists.newArrayListWithExpectedSize(md.getBlocks().size()); +// for (BlockMetaData blockMetaData : md.getBlocks()) { +// splitOffsets.add(blockMetaData.getStartingPos()); +// } +// Collections.sort(splitOffsets); +// return splitOffsets; +// } +// +// // we allow struct nesting, but not maps or arrays +// private static boolean shouldStoreBounds(ColumnChunkMetaData column, Schema schema) { +// if (column.getPrimitiveType().getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) { +// // stats for INT96 are not reliable +// return false; +// } +// +// ColumnPath columnPath = column.getPath(); +// Iterator pathIterator = columnPath.iterator(); +// Type currentType = schema.asStruct(); +// +// while (pathIterator.hasNext()) { +// if (currentType == null || !currentType.isStructType()) { +// return false; +// } +// String fieldName = pathIterator.next(); +// currentType = currentType.asStructType().fieldType(fieldName); +// } +// +// return currentType != null && currentType.isPrimitiveType(); +// } +// +// private static void increment(Map columns, int fieldId, long amount) { +// if (columns != null) { +// if (columns.containsKey(fieldId)) { +// columns.put(fieldId, columns.get(fieldId) + amount); +// } else { +// columns.put(fieldId, amount); +// } +// } +// } +// +// @SuppressWarnings("unchecked") +// private static void updateMin(Map> lowerBounds, int id, Type type, +// Literal min, MetricsMode metricsMode) { +// Literal currentMin = (Literal) lowerBounds.get(id); +// if (currentMin == null || min.comparator().compare(min.value(), currentMin.value()) < 0) { +// if (metricsMode == MetricsModes.Full.get()) { +// lowerBounds.put(id, min); +// } else { +// MetricsModes.Truncate truncateMode = (MetricsModes.Truncate) metricsMode; +// int truncateLength = truncateMode.length(); +// switch (type.typeId()) { +// case STRING: +// lowerBounds.put(id, UnicodeUtil.truncateStringMin((Literal) min, truncateLength)); +// break; +// case FIXED: +// case BINARY: +// lowerBounds.put(id, BinaryUtil.truncateBinaryMin((Literal) min, truncateLength)); +// break; +// default: +// lowerBounds.put(id, min); +// } +// } +// } +// } +// +// @SuppressWarnings("unchecked") +// private static void updateMax(Map> upperBounds, int id, Type type, +// Literal max, MetricsMode metricsMode) { +// Literal currentMax = (Literal) upperBounds.get(id); +// if (currentMax == null || max.comparator().compare(max.value(), currentMax.value()) > 0) { +// if (metricsMode == MetricsModes.Full.get()) { +// upperBounds.put(id, max); +// } else { +// MetricsModes.Truncate truncateMode = (MetricsModes.Truncate) metricsMode; +// int truncateLength = truncateMode.length(); +// switch (type.typeId()) { +// case STRING: +// Literal truncatedMaxString = UnicodeUtil.truncateStringMax((Literal) max, +// truncateLength); +// if (truncatedMaxString != null) { +// upperBounds.put(id, truncatedMaxString); +// } +// break; +// case FIXED: +// case BINARY: +// Literal truncatedMaxBinary = BinaryUtil.truncateBinaryMax((Literal) max, +// truncateLength); +// if (truncatedMaxBinary != null) { +// upperBounds.put(id, truncatedMaxBinary); +// } +// break; +// default: +// upperBounds.put(id, max); +// } +// } +// } +// } +// +// private static Map toBufferMap(Schema schema, Map> map) { +// Map bufferMap = Maps.newHashMap(); +// for (Map.Entry> entry : map.entrySet()) { +// bufferMap.put(entry.getKey(), +// 
Conversions.toByteBuffer(schema.findType(entry.getKey()), entry.getValue().value())); +// } +// return bufferMap; +// } + + @SuppressWarnings("deprecation") + public static boolean hasNonDictionaryPages(ColumnChunkMetaData meta) { + EncodingStats stats = meta.getEncodingStats(); + if (stats != null) { + return stats.hasNonDictionaryEncodedPages(); + } + + // without EncodingStats, fall back to testing the encoding list + Set encodings = Sets.newHashSet(meta.getEncodings()); + if (encodings.remove(Encoding.PLAIN_DICTIONARY)) { + // if remove returned true, PLAIN_DICTIONARY was present, which means at + // least one page was dictionary encoded and 1.0 encodings are used + + // RLE and BIT_PACKED are only used for repetition or definition levels + encodings.remove(Encoding.RLE); + encodings.remove(Encoding.BIT_PACKED); + + // when empty, no encodings other than dictionary or rep/def levels + return !encodings.isEmpty(); + } else { + // if PLAIN_DICTIONARY wasn't present, then either the column is not + // dictionary-encoded, or the 2.0 encoding, RLE_DICTIONARY, was used. + // for 2.0, this cannot determine whether a page fell back without + // page encoding stats + return true; + } + } + +// public static Dictionary readDictionary(ColumnDescriptor desc, PageReader pageSource) { +// DictionaryPage dictionaryPage = pageSource.readDictionaryPage(); +// if (dictionaryPage != null) { +// try { +// return dictionaryPage.getEncoding().initDictionary(desc, dictionaryPage); +// } catch (IOException e) { +// throw new ParquetDecodingException("could not decode the dictionary for " + desc, e); +// } +// } +// return null; +// } +// +// public static boolean isIntType(PrimitiveType primitiveType) { +// if (primitiveType.getOriginalType() != null) { +// switch (primitiveType.getOriginalType()) { +// case INT_8: +// case INT_16: +// case INT_32: +// case DATE: +// return true; +// default: +// return false; +// } +// } +// return primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT32; +// } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/PruneColumns.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/PruneColumns.java new file mode 100644 index 00000000000..52ecc19ba70 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/PruneColumns.java @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.iceberg.parquet; + +import java.util.Collections; +import java.util.List; +import java.util.Set; +import org.apache.iceberg.relocated.com.google.common.base.Objects; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.OriginalType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.parquet.schema.Types; + +public class PruneColumns extends ParquetTypeVisitor { + private final Set selectedIds; + + PruneColumns(Set selectedIds) { + Preconditions.checkNotNull(selectedIds, "Selected field ids cannot be null"); + this.selectedIds = selectedIds; + } + + @Override + public Type message(MessageType message, List fields) { + Types.MessageTypeBuilder builder = Types.buildMessage(); + + boolean hasChange = false; + int fieldCount = 0; + for (int i = 0; i < fields.size(); i += 1) { + Type originalField = message.getType(i); + Type field = fields.get(i); + Integer fieldId = getId(originalField); + if (fieldId != null && selectedIds.contains(fieldId)) { + if (field != null) { + hasChange = true; + builder.addField(field); + } else { + if (isStruct(originalField)) { + hasChange = true; + builder.addField(originalField.asGroupType().withNewFields(Collections.emptyList())); + } else { + builder.addField(originalField); + } + } + fieldCount += 1; + } else if (field != null) { + hasChange = true; + builder.addField(field); + fieldCount += 1; + } + } + + if (hasChange) { + return builder.named(message.getName()); + } else if (message.getFieldCount() == fieldCount) { + return message; + } + + return builder.named(message.getName()); + } + + @Override + public Type struct(GroupType struct, List fields) { + boolean hasChange = false; + List filteredFields = Lists.newArrayListWithExpectedSize(fields.size()); + for (int i = 0; i < fields.size(); i += 1) { + Type originalField = struct.getType(i); + Type field = fields.get(i); + Integer fieldId = getId(originalField); + if (fieldId != null && selectedIds.contains(fieldId)) { + filteredFields.add(originalField); + } else if (field != null) { + filteredFields.add(originalField); + hasChange = true; + } + } + + if (hasChange) { + return struct.withNewFields(filteredFields); + } else if (struct.getFieldCount() == filteredFields.size()) { + return struct; + } else if (!filteredFields.isEmpty()) { + return struct.withNewFields(filteredFields); + } + + return null; + } + + @Override + public Type list(GroupType list, Type element) { + Type repeated = list.getType(0); + Type originalElement = ParquetSchemaUtil.determineListElementType(list); + Integer elementId = getId(originalElement); + + if (elementId != null && selectedIds.contains(elementId)) { + return list; + } else if (element != null) { + if (!Objects.equal(element, originalElement)) { + if (originalElement.isRepetition(Type.Repetition.REPEATED)) { + return list.withNewFields(element); + } else { + return list.withNewFields(repeated.asGroupType().withNewFields(element)); + } + } + return list; + } + + return null; + } + + @Override + public Type map(GroupType map, Type key, Type value) { + GroupType repeated = map.getType(0).asGroupType(); + Type originalKey = repeated.getType(0); + Type originalValue = repeated.getType(1); + + Integer keyId = getId(originalKey); + Integer valueId = getId(originalValue); 
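// A rough usage sketch for the PruneColumns visitor defined in this file. It
// assumes the ParquetTypeVisitor.visit(Type, ParquetTypeVisitor) entry point
// ported alongside this class and a caller in the same package (the constructor
// is package-private); the schema, field ids, and names are made up.
//
// import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet;
// import org.apache.parquet.schema.MessageType;
// import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
// import org.apache.parquet.schema.Types;
//
// MessageType fileSchema = Types.buildMessage()
//     .addField(Types.required(PrimitiveTypeName.INT64).id(1).named("id"))
//     .addField(Types.optional(PrimitiveTypeName.BINARY).id(2).named("data"))
//     .addField(Types.optional(PrimitiveTypeName.DOUBLE).id(3).named("measurement"))
//     .named("table");
//
// // keep only the columns whose Iceberg field ids were projected (1 and 3 here);
// // the visitor should return a message type with "data" dropped
// MessageType pruned = (MessageType) ParquetTypeVisitor.visit(
//     fileSchema, new PruneColumns(ImmutableSet.of(1, 3)));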
+ + if ((keyId != null && selectedIds.contains(keyId)) || (valueId != null && selectedIds.contains(valueId))) { + return map; + } else if (value != null) { + if (!Objects.equal(value, originalValue)) { + return map.withNewFields(repeated.withNewFields(originalKey, value)); + } + return map; + } + + return null; + } + + @Override + public Type primitive(PrimitiveType primitive) { + return null; + } + + private Integer getId(Type type) { + return type.getId() == null ? null : type.getId().intValue(); + } + + private boolean isStruct(Type field) { + if (field.isPrimitive()) { + return false; + } else { + GroupType groupType = field.asGroupType(); + // Spark 3.1 uses Parquet 1.10 which does not have LogicalTypeAnnotation +// LogicalTypeAnnotation logicalTypeAnnotation = groupType.getLogicalTypeAnnotation(); +// return !logicalTypeAnnotation.equals(LogicalTypeAnnotation.mapType()) && +// !logicalTypeAnnotation.equals(LogicalTypeAnnotation.listType()); + OriginalType originalType = groupType.getOriginalType(); + return !originalType.equals(OriginalType.MAP) && + !originalType.equals(OriginalType.LIST); + } + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/TypeWithSchemaVisitor.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/TypeWithSchemaVisitor.java new file mode 100644 index 00000000000..090b5e712f4 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/TypeWithSchemaVisitor.java @@ -0,0 +1,216 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.iceberg.parquet; + +import java.util.ArrayDeque; +import java.util.List; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.OriginalType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; + +/** + * Visitor for traversing a Parquet type with a companion Iceberg type. + * + * @param the Java class returned by the visitor + */ +public class TypeWithSchemaVisitor { + @SuppressWarnings("checkstyle:VisibilityModifier") + protected ArrayDeque fieldNames = new ArrayDeque<>(); + + @SuppressWarnings("checkstyle:CyclomaticComplexity") + public static T visit(org.apache.iceberg.types.Type iType, Type type, TypeWithSchemaVisitor visitor) { + if (type instanceof MessageType) { + Types.StructType struct = iType != null ? iType.asStructType() : null; + return visitor.message(struct, (MessageType) type, + visitFields(struct, type.asGroupType(), visitor)); + + } else if (type.isPrimitive()) { + org.apache.iceberg.types.Type.PrimitiveType iPrimitive = iType != null ? 
+ iType.asPrimitiveType() : null; + return visitor.primitive(iPrimitive, type.asPrimitiveType()); + + } else { + // if not a primitive, the typeId must be a group + GroupType group = type.asGroupType(); + OriginalType annotation = group.getOriginalType(); + if (annotation != null) { + switch (annotation) { + case LIST: + Preconditions.checkArgument(!group.isRepetition(Type.Repetition.REPEATED), + "Invalid list: top-level group is repeated: %s", group); + Preconditions.checkArgument(group.getFieldCount() == 1, + "Invalid list: does not contain single repeated field: %s", group); + + Type repeatedElement = group.getFields().get(0); + Preconditions.checkArgument(repeatedElement.isRepetition(Type.Repetition.REPEATED), + "Invalid list: inner group is not repeated"); + + Type listElement = ParquetSchemaUtil.determineListElementType(group); + Types.ListType list = null; + Types.NestedField element = null; + if (iType != null) { + list = iType.asListType(); + element = list.fields().get(0); + } + + if (listElement.isRepetition(Type.Repetition.REPEATED)) { + return visitTwoLevelList(list, element, group, listElement, visitor); + } else { + return visitThreeLevelList(list, element, group, listElement, visitor); + } + + case MAP: + Preconditions.checkArgument(!group.isRepetition(Type.Repetition.REPEATED), + "Invalid map: top-level group is repeated: %s", group); + Preconditions.checkArgument(group.getFieldCount() == 1, + "Invalid map: does not contain single repeated field: %s", group); + + GroupType repeatedKeyValue = group.getType(0).asGroupType(); + Preconditions.checkArgument(repeatedKeyValue.isRepetition(Type.Repetition.REPEATED), + "Invalid map: inner group is not repeated"); + Preconditions.checkArgument(repeatedKeyValue.getFieldCount() <= 2, + "Invalid map: repeated group does not have 2 fields"); + + Types.MapType map = null; + Types.NestedField keyField = null; + Types.NestedField valueField = null; + if (iType != null) { + map = iType.asMapType(); + keyField = map.fields().get(0); + valueField = map.fields().get(1); + } + + visitor.fieldNames.push(repeatedKeyValue.getName()); + try { + T keyResult = null; + T valueResult = null; + switch (repeatedKeyValue.getFieldCount()) { + case 2: + // if there are 2 fields, both key and value are projected + keyResult = visitField(keyField, repeatedKeyValue.getType(0), visitor); + valueResult = visitField(valueField, repeatedKeyValue.getType(1), visitor); + break; + case 1: + // if there is just one, use the name to determine what it is + Type keyOrValue = repeatedKeyValue.getType(0); + if (keyOrValue.getName().equalsIgnoreCase("key")) { + keyResult = visitField(keyField, keyOrValue, visitor); + // value result remains null + } else { + valueResult = visitField(valueField, keyOrValue, visitor); + // key result remains null + } + break; + default: + // both results will remain null + } + + return visitor.map(map, group, keyResult, valueResult); + + } finally { + visitor.fieldNames.pop(); + } + + default: + } + } + + Types.StructType struct = iType != null ? 
iType.asStructType() : null; + return visitor.struct(struct, group, visitFields(struct, group, visitor)); + } + } + + private static T visitTwoLevelList( + Types.ListType iListType, Types.NestedField iListElement, GroupType pListType, Type pListElement, + TypeWithSchemaVisitor visitor) { + T elementResult = visitField(iListElement, pListElement, visitor); + return visitor.list(iListType, pListType, elementResult); + } + + private static T visitThreeLevelList( + Types.ListType iListType, Types.NestedField iListElement, GroupType pListType, Type pListElement, + TypeWithSchemaVisitor visitor) { + visitor.fieldNames.push(pListType.getFieldName(0)); + + try { + T elementResult = visitField(iListElement, pListElement, visitor); + + return visitor.list(iListType, pListType, elementResult); + } finally { + visitor.fieldNames.pop(); + } + } + + private static T visitField(Types.NestedField iField, Type field, TypeWithSchemaVisitor visitor) { + visitor.fieldNames.push(field.getName()); + try { + return visit(iField != null ? iField.type() : null, field, visitor); + } finally { + visitor.fieldNames.pop(); + } + } + + private static List visitFields(Types.StructType struct, GroupType group, TypeWithSchemaVisitor visitor) { + List results = Lists.newArrayListWithExpectedSize(group.getFieldCount()); + for (Type field : group.getFields()) { + int id = -1; + if (field.getId() != null) { + id = field.getId().intValue(); + } + Types.NestedField iField = (struct != null && id >= 0) ? struct.field(id) : null; + results.add(visitField(iField, field, visitor)); + } + + return results; + } + + public T message(Types.StructType iStruct, MessageType message, List fields) { + return null; + } + + public T struct(Types.StructType iStruct, GroupType struct, List fields) { + return null; + } + + public T list(Types.ListType iList, GroupType array, T element) { + return null; + } + + public T map(Types.MapType iMap, GroupType map, T key, T value) { + return null; + } + + public T primitive(org.apache.iceberg.types.Type.PrimitiveType iPrimitive, + PrimitiveType primitive) { + return null; + } + + protected String[] currentPath() { + return Lists.newArrayList(fieldNames.descendingIterator()).toArray(new String[0]); + } + + protected String[] path(String name) { + List list = Lists.newArrayList(fieldNames.descendingIterator()); + list.add(name); + return list.toArray(new String[0]); + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/Spark3Util.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/Spark3Util.java new file mode 100644 index 00000000000..ada2586d4e8 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/Spark3Util.java @@ -0,0 +1,821 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
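// A toy subclass showing how TypeWithSchemaVisitor (defined above) can be used to
// walk a Parquet schema alongside its Iceberg schema; it simply records the dotted
// path of every leaf column. The names are made up, and the visit call is shown
// only as a commented usage hint.
//
// import java.util.List;
// import org.apache.iceberg.relocated.com.google.common.collect.Lists;
// import org.apache.parquet.schema.PrimitiveType;
//
// class CollectPrimitivePaths extends TypeWithSchemaVisitor<String> {
//   final List<String> paths = Lists.newArrayList();
//
//   @Override
//   public String primitive(org.apache.iceberg.types.Type.PrimitiveType iPrimitive,
//                           PrimitiveType primitive) {
//     String path = String.join(".", currentPath()); // e.g. "location.lat"
//     paths.add(path);
//     return path;
//   }
// }
//
// // CollectPrimitivePaths collector = new CollectPrimitivePaths();
// // TypeWithSchemaVisitor.visit(icebergSchema.asStruct(), parquetMessageType, collector);
// // collector.paths then holds one dotted name per leaf column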
+ */ + +package com.nvidia.spark.rapids.iceberg.spark; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.iceberg.expressions.BoundPredicate; +import org.apache.iceberg.expressions.ExpressionVisitors; +import org.apache.iceberg.expressions.Literal; +import org.apache.iceberg.expressions.UnboundPredicate; +import org.apache.iceberg.relocated.com.google.common.io.BaseEncoding; +import org.apache.iceberg.util.ByteBuffers; + +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.NamedReference; + +public class Spark3Util { + +// private static final Set RESERVED_PROPERTIES = ImmutableSet.of( +// TableCatalog.PROP_LOCATION, TableCatalog.PROP_PROVIDER); +// private static final Joiner DOT = Joiner.on("."); + + private Spark3Util() { + } + +// public static CaseInsensitiveStringMap setOption(String key, String value, CaseInsensitiveStringMap options) { +// Map newOptions = Maps.newHashMap(); +// newOptions.putAll(options); +// newOptions.put(key, value); +// return new CaseInsensitiveStringMap(newOptions); +// } +// +// public static Map rebuildCreateProperties(Map createProperties) { +// ImmutableMap.Builder tableProperties = ImmutableMap.builder(); +// createProperties.entrySet().stream() +// .filter(entry -> !RESERVED_PROPERTIES.contains(entry.getKey())) +// .forEach(tableProperties::put); +// +// String provider = createProperties.get(TableCatalog.PROP_PROVIDER); +// if ("parquet".equalsIgnoreCase(provider)) { +// tableProperties.put(TableProperties.DEFAULT_FILE_FORMAT, "parquet"); +// } else if ("avro".equalsIgnoreCase(provider)) { +// tableProperties.put(TableProperties.DEFAULT_FILE_FORMAT, "avro"); +// } else if ("orc".equalsIgnoreCase(provider)) { +// tableProperties.put(TableProperties.DEFAULT_FILE_FORMAT, "orc"); +// } else if (provider != null && !"iceberg".equalsIgnoreCase(provider)) { +// throw new IllegalArgumentException("Unsupported format in USING: " + provider); +// } +// +// return tableProperties.build(); +// } +// +// /** +// * Applies a list of Spark table changes to an {@link UpdateProperties} operation. +// * +// * @param pendingUpdate an uncommitted UpdateProperties operation to configure +// * @param changes a list of Spark table changes +// * @return the UpdateProperties operation configured with the changes +// */ +// public static UpdateProperties applyPropertyChanges(UpdateProperties pendingUpdate, List changes) { +// for (TableChange change : changes) { +// if (change instanceof TableChange.SetProperty) { +// TableChange.SetProperty set = (TableChange.SetProperty) change; +// pendingUpdate.set(set.property(), set.value()); +// +// } else if (change instanceof TableChange.RemoveProperty) { +// TableChange.RemoveProperty remove = (TableChange.RemoveProperty) change; +// pendingUpdate.remove(remove.property()); +// +// } else { +// throw new UnsupportedOperationException("Cannot apply unknown table change: " + change); +// } +// } +// +// return pendingUpdate; +// } +// +// /** +// * Applies a list of Spark table changes to an {@link UpdateSchema} operation. 
+// * +// * @param pendingUpdate an uncommitted UpdateSchema operation to configure +// * @param changes a list of Spark table changes +// * @return the UpdateSchema operation configured with the changes +// */ +// public static UpdateSchema applySchemaChanges(UpdateSchema pendingUpdate, List changes) { +// for (TableChange change : changes) { +// if (change instanceof TableChange.AddColumn) { +// apply(pendingUpdate, (TableChange.AddColumn) change); +// +// } else if (change instanceof TableChange.UpdateColumnType) { +// TableChange.UpdateColumnType update = (TableChange.UpdateColumnType) change; +// Type newType = SparkSchemaUtil.convert(update.newDataType()); +// Preconditions.checkArgument(newType.isPrimitiveType(), +// "Cannot update '%s', not a primitive type: %s", DOT.join(update.fieldNames()), update.newDataType()); +// pendingUpdate.updateColumn(DOT.join(update.fieldNames()), newType.asPrimitiveType()); +// +// } else if (change instanceof TableChange.UpdateColumnComment) { +// TableChange.UpdateColumnComment update = (TableChange.UpdateColumnComment) change; +// pendingUpdate.updateColumnDoc(DOT.join(update.fieldNames()), update.newComment()); +// +// } else if (change instanceof TableChange.RenameColumn) { +// TableChange.RenameColumn rename = (TableChange.RenameColumn) change; +// pendingUpdate.renameColumn(DOT.join(rename.fieldNames()), rename.newName()); +// +// } else if (change instanceof TableChange.DeleteColumn) { +// TableChange.DeleteColumn delete = (TableChange.DeleteColumn) change; +// pendingUpdate.deleteColumn(DOT.join(delete.fieldNames())); +// +// } else if (change instanceof TableChange.UpdateColumnNullability) { +// TableChange.UpdateColumnNullability update = (TableChange.UpdateColumnNullability) change; +// if (update.nullable()) { +// pendingUpdate.makeColumnOptional(DOT.join(update.fieldNames())); +// } else { +// pendingUpdate.requireColumn(DOT.join(update.fieldNames())); +// } +// +// } else if (change instanceof TableChange.UpdateColumnPosition) { +// apply(pendingUpdate, (TableChange.UpdateColumnPosition) change); +// +// } else { +// throw new UnsupportedOperationException("Cannot apply unknown table change: " + change); +// } +// } +// +// return pendingUpdate; +// } +// +// private static void apply(UpdateSchema pendingUpdate, TableChange.UpdateColumnPosition update) { +// Preconditions.checkArgument(update.position() != null, "Invalid position: null"); +// +// if (update.position() instanceof TableChange.After) { +// TableChange.After after = (TableChange.After) update.position(); +// String referenceField = peerName(update.fieldNames(), after.column()); +// pendingUpdate.moveAfter(DOT.join(update.fieldNames()), referenceField); +// +// } else if (update.position() instanceof TableChange.First) { +// pendingUpdate.moveFirst(DOT.join(update.fieldNames())); +// +// } else { +// throw new IllegalArgumentException("Unknown position for reorder: " + update.position()); +// } +// } +// +// private static void apply(UpdateSchema pendingUpdate, TableChange.AddColumn add) { +// Preconditions.checkArgument(add.isNullable(), +// "Incompatible change: cannot add required column: %s", leafName(add.fieldNames())); +// Type type = SparkSchemaUtil.convert(add.dataType()); +// pendingUpdate.addColumn(parentName(add.fieldNames()), leafName(add.fieldNames()), type, add.comment()); +// +// if (add.position() instanceof TableChange.After) { +// TableChange.After after = (TableChange.After) add.position(); +// String referenceField = peerName(add.fieldNames(), 
after.column()); +// pendingUpdate.moveAfter(DOT.join(add.fieldNames()), referenceField); +// +// } else if (add.position() instanceof TableChange.First) { +// pendingUpdate.moveFirst(DOT.join(add.fieldNames())); +// +// } else { +// Preconditions.checkArgument(add.position() == null, +// "Cannot add '%s' at unknown position: %s", DOT.join(add.fieldNames()), add.position()); +// } +// } +// +// public static org.apache.iceberg.Table toIcebergTable(Table table) { +// Preconditions.checkArgument(table instanceof SparkTable, "Table %s is not an Iceberg table", table); +// SparkTable sparkTable = (SparkTable) table; +// return sparkTable.table(); +// } +// +// /** +// * Converts a PartitionSpec to Spark transforms. +// * +// * @param spec a PartitionSpec +// * @return an array of Transforms +// */ +// public static Transform[] toTransforms(PartitionSpec spec) { +// Map quotedNameById = SparkSchemaUtil.indexQuotedNameById(spec.schema()); +// List transforms = PartitionSpecVisitor.visit(spec, +// new PartitionSpecVisitor() { +// @Override +// public Transform identity(String sourceName, int sourceId) { +// return Expressions.identity(quotedName(sourceId)); +// } +// +// @Override +// public Transform bucket(String sourceName, int sourceId, int numBuckets) { +// return Expressions.bucket(numBuckets, quotedName(sourceId)); +// } +// +// @Override +// public Transform truncate(String sourceName, int sourceId, int width) { +// return Expressions.apply("truncate", Expressions.column(quotedName(sourceId)), Expressions.literal(width)); +// } +// +// @Override +// public Transform year(String sourceName, int sourceId) { +// return Expressions.years(quotedName(sourceId)); +// } +// +// @Override +// public Transform month(String sourceName, int sourceId) { +// return Expressions.months(quotedName(sourceId)); +// } +// +// @Override +// public Transform day(String sourceName, int sourceId) { +// return Expressions.days(quotedName(sourceId)); +// } +// +// @Override +// public Transform hour(String sourceName, int sourceId) { +// return Expressions.hours(quotedName(sourceId)); +// } +// +// @Override +// public Transform alwaysNull(int fieldId, String sourceName, int sourceId) { +// // do nothing for alwaysNull, it doesn't need to be converted to a transform +// return null; +// } +// +// @Override +// public Transform unknown(int fieldId, String sourceName, int sourceId, String transform) { +// return Expressions.apply(transform, Expressions.column(quotedName(sourceId))); +// } +// +// private String quotedName(int id) { +// return quotedNameById.get(id); +// } +// }); +// +// return transforms.stream().filter(Objects::nonNull).toArray(Transform[]::new); +// } + + public static NamedReference toNamedReference(String name) { + return Expressions.column(name); + } + +// public static Term toIcebergTerm(Expression expr) { +// if (expr instanceof Transform) { +// Transform transform = (Transform) expr; +// Preconditions.checkArgument(transform.references().length == 1, +// "Cannot convert transform with more than one column reference: %s", transform); +// String colName = DOT.join(transform.references()[0].fieldNames()); +// switch (transform.name()) { +// case "identity": +// return org.apache.iceberg.expressions.Expressions.ref(colName); +// case "bucket": +// return org.apache.iceberg.expressions.Expressions.bucket(colName, findWidth(transform)); +// case "years": +// return org.apache.iceberg.expressions.Expressions.year(colName); +// case "months": +// return 
org.apache.iceberg.expressions.Expressions.month(colName); +// case "date": +// case "days": +// return org.apache.iceberg.expressions.Expressions.day(colName); +// case "date_hour": +// case "hours": +// return org.apache.iceberg.expressions.Expressions.hour(colName); +// case "truncate": +// return org.apache.iceberg.expressions.Expressions.truncate(colName, findWidth(transform)); +// default: +// throw new UnsupportedOperationException("Transform is not supported: " + transform); +// } +// +// } else if (expr instanceof NamedReference) { +// NamedReference ref = (NamedReference) expr; +// return org.apache.iceberg.expressions.Expressions.ref(DOT.join(ref.fieldNames())); +// +// } else { +// throw new UnsupportedOperationException("Cannot convert unknown expression: " + expr); +// } +// } +// +// /** +// * Converts Spark transforms into a {@link PartitionSpec}. +// * +// * @param schema the table schema +// * @param partitioning Spark Transforms +// * @return a PartitionSpec +// */ +// public static PartitionSpec toPartitionSpec(Schema schema, Transform[] partitioning) { +// if (partitioning == null || partitioning.length == 0) { +// return PartitionSpec.unpartitioned(); +// } +// +// PartitionSpec.Builder builder = PartitionSpec.builderFor(schema); +// for (Transform transform : partitioning) { +// Preconditions.checkArgument(transform.references().length == 1, +// "Cannot convert transform with more than one column reference: %s", transform); +// String colName = DOT.join(transform.references()[0].fieldNames()); +// switch (transform.name()) { +// case "identity": +// builder.identity(colName); +// break; +// case "bucket": +// builder.bucket(colName, findWidth(transform)); +// break; +// case "years": +// builder.year(colName); +// break; +// case "months": +// builder.month(colName); +// break; +// case "date": +// case "days": +// builder.day(colName); +// break; +// case "date_hour": +// case "hours": +// builder.hour(colName); +// break; +// case "truncate": +// builder.truncate(colName, findWidth(transform)); +// break; +// default: +// throw new UnsupportedOperationException("Transform is not supported: " + transform); +// } +// } +// +// return builder.build(); +// } +// +// @SuppressWarnings("unchecked") +// private static int findWidth(Transform transform) { +// for (Expression expr : transform.arguments()) { +// if (expr instanceof Literal) { +// if (((Literal) expr).dataType() instanceof IntegerType) { +// Literal lit = (Literal) expr; +// Preconditions.checkArgument(lit.value() > 0, +// "Unsupported width for transform: %s", transform.describe()); +// return lit.value(); +// +// } else if (((Literal) expr).dataType() instanceof LongType) { +// Literal lit = (Literal) expr; +// Preconditions.checkArgument(lit.value() > 0 && lit.value() < Integer.MAX_VALUE, +// "Unsupported width for transform: %s", transform.describe()); +// if (lit.value() > Integer.MAX_VALUE) { +// throw new IllegalArgumentException(); +// } +// return lit.value().intValue(); +// } +// } +// } +// +// throw new IllegalArgumentException("Cannot find width for transform: " + transform.describe()); +// } +// +// private static String leafName(String[] fieldNames) { +// Preconditions.checkArgument(fieldNames.length > 0, "Invalid field name: at least one name is required"); +// return fieldNames[fieldNames.length - 1]; +// } +// +// private static String peerName(String[] fieldNames, String fieldName) { +// if (fieldNames.length > 1) { +// String[] peerNames = Arrays.copyOf(fieldNames, fieldNames.length); +// 
peerNames[fieldNames.length - 1] = fieldName; +// return DOT.join(peerNames); +// } +// return fieldName; +// } +// +// private static String parentName(String[] fieldNames) { +// if (fieldNames.length > 1) { +// return DOT.join(Arrays.copyOfRange(fieldNames, 0, fieldNames.length - 1)); +// } +// return null; +// } + + public static String describe(org.apache.iceberg.expressions.Expression expr) { + return ExpressionVisitors.visit(expr, DescribeExpressionVisitor.INSTANCE); + } + +// public static String describe(Schema schema) { +// return TypeUtil.visit(schema, DescribeSchemaVisitor.INSTANCE); +// } +// +// public static String describe(Type type) { +// return TypeUtil.visit(type, DescribeSchemaVisitor.INSTANCE); +// } +// +// public static String describe(org.apache.iceberg.SortOrder order) { +// return Joiner.on(", ").join(SortOrderVisitor.visit(order, DescribeSortOrderVisitor.INSTANCE)); +// } +// +// public static boolean extensionsEnabled(SparkSession spark) { +// String extensions = spark.conf().get("spark.sql.extensions", ""); +// return extensions.contains("IcebergSparkSessionExtensions"); +// } +// +// public static class DescribeSchemaVisitor extends TypeUtil.SchemaVisitor { +// private static final Joiner COMMA = Joiner.on(','); +// private static final DescribeSchemaVisitor INSTANCE = new DescribeSchemaVisitor(); +// +// private DescribeSchemaVisitor() { +// } +// +// @Override +// public String schema(Schema schema, String structResult) { +// return structResult; +// } +// +// @Override +// public String struct(Types.StructType struct, List fieldResults) { +// return "struct<" + COMMA.join(fieldResults) + ">"; +// } +// +// @Override +// public String field(Types.NestedField field, String fieldResult) { +// return field.name() + ": " + fieldResult + (field.isRequired() ? 
" not null" : ""); +// } +// +// @Override +// public String list(Types.ListType list, String elementResult) { +// return "list<" + elementResult + ">"; +// } +// +// @Override +// public String map(Types.MapType map, String keyResult, String valueResult) { +// return "map<" + keyResult + ", " + valueResult + ">"; +// } +// +// @Override +// public String primitive(Type.PrimitiveType primitive) { +// switch (primitive.typeId()) { +// case BOOLEAN: +// return "boolean"; +// case INTEGER: +// return "int"; +// case LONG: +// return "bigint"; +// case FLOAT: +// return "float"; +// case DOUBLE: +// return "double"; +// case DATE: +// return "date"; +// case TIME: +// return "time"; +// case TIMESTAMP: +// return "timestamp"; +// case STRING: +// case UUID: +// return "string"; +// case FIXED: +// case BINARY: +// return "binary"; +// case DECIMAL: +// Types.DecimalType decimal = (Types.DecimalType) primitive; +// return "decimal(" + decimal.precision() + "," + decimal.scale() + ")"; +// } +// throw new UnsupportedOperationException("Cannot convert type to SQL: " + primitive); +// } +// } + + private static class DescribeExpressionVisitor extends ExpressionVisitors.ExpressionVisitor { + private static final DescribeExpressionVisitor INSTANCE = new DescribeExpressionVisitor(); + + private DescribeExpressionVisitor() { + } + + @Override + public String alwaysTrue() { + return "true"; + } + + @Override + public String alwaysFalse() { + return "false"; + } + + @Override + public String not(String result) { + return "NOT (" + result + ")"; + } + + @Override + public String and(String leftResult, String rightResult) { + return "(" + leftResult + " AND " + rightResult + ")"; + } + + @Override + public String or(String leftResult, String rightResult) { + return "(" + leftResult + " OR " + rightResult + ")"; + } + + @Override + public String predicate(BoundPredicate pred) { + throw new UnsupportedOperationException("Cannot convert bound predicates to SQL"); + } + + @Override + public String predicate(UnboundPredicate pred) { + switch (pred.op()) { + case IS_NULL: + return pred.ref().name() + " IS NULL"; + case NOT_NULL: + return pred.ref().name() + " IS NOT NULL"; + case IS_NAN: + return "is_nan(" + pred.ref().name() + ")"; + case NOT_NAN: + return "not_nan(" + pred.ref().name() + ")"; + case LT: + return pred.ref().name() + " < " + sqlString(pred.literal()); + case LT_EQ: + return pred.ref().name() + " <= " + sqlString(pred.literal()); + case GT: + return pred.ref().name() + " > " + sqlString(pred.literal()); + case GT_EQ: + return pred.ref().name() + " >= " + sqlString(pred.literal()); + case EQ: + return pred.ref().name() + " = " + sqlString(pred.literal()); + case NOT_EQ: + return pred.ref().name() + " != " + sqlString(pred.literal()); + case STARTS_WITH: + return pred.ref().name() + " LIKE '" + pred.literal() + "%'"; + case NOT_STARTS_WITH: + return pred.ref().name() + " NOT LIKE '" + pred.literal() + "%'"; + case IN: + return pred.ref().name() + " IN (" + sqlString(pred.literals()) + ")"; + case NOT_IN: + return pred.ref().name() + " NOT IN (" + sqlString(pred.literals()) + ")"; + default: + throw new UnsupportedOperationException("Cannot convert predicate to SQL: " + pred); + } + } + + private static String sqlString(List> literals) { + return literals.stream().map(DescribeExpressionVisitor::sqlString).collect(Collectors.joining(", ")); + } + + private static String sqlString(org.apache.iceberg.expressions.Literal lit) { + if (lit.value() instanceof String) { + return "'" + lit.value() + "'"; + 
} else if (lit.value() instanceof ByteBuffer) { + byte[] bytes = ByteBuffers.toByteArray((ByteBuffer) lit.value()); + return "X'" + BaseEncoding.base16().encode(bytes) + "'"; + } else { + return lit.value().toString(); + } + } + } + +// /** +// * Returns a metadata table as a Dataset based on the given Iceberg table. +// * +// * @param spark SparkSession where the Dataset will be created +// * @param table an Iceberg table +// * @param type the type of metadata table +// * @return a Dataset that will read the metadata table +// */ +// private static Dataset loadMetadataTable(SparkSession spark, org.apache.iceberg.Table table, +// MetadataTableType type) { +// Table metadataTable = new SparkTable(MetadataTableUtils.createMetadataTableInstance(table, type), false); +// return Dataset.ofRows(spark, DataSourceV2Relation.create(metadataTable, Some.empty(), Some.empty())); +// } +// +// /** +// * Returns an Iceberg Table by its name from a Spark V2 Catalog. If cache is enabled in {@link SparkCatalog}, +// * the {@link TableOperations} of the table may be stale, please refresh the table to get the latest one. +// * +// * @param spark SparkSession used for looking up catalog references and tables +// * @param name The multipart identifier of the Iceberg table +// * @return an Iceberg table +// */ +// public static org.apache.iceberg.Table loadIcebergTable(SparkSession spark, String name) +// throws ParseException, NoSuchTableException { +// CatalogAndIdentifier catalogAndIdentifier = catalogAndIdentifier(spark, name); +// +// TableCatalog catalog = asTableCatalog(catalogAndIdentifier.catalog); +// Table sparkTable = catalog.loadTable(catalogAndIdentifier.identifier); +// return toIcebergTable(sparkTable); +// } +// +// public static CatalogAndIdentifier catalogAndIdentifier(SparkSession spark, String name) throws ParseException { +// return catalogAndIdentifier(spark, name, spark.sessionState().catalogManager().currentCatalog()); +// } +// +// public static CatalogAndIdentifier catalogAndIdentifier(SparkSession spark, String name, +// CatalogPlugin defaultCatalog) throws ParseException { +// ParserInterface parser = spark.sessionState().sqlParser(); +// Seq multiPartIdentifier = parser.parseMultipartIdentifier(name).toIndexedSeq(); +// List javaMultiPartIdentifier = JavaConverters.seqAsJavaList(multiPartIdentifier); +// return catalogAndIdentifier(spark, javaMultiPartIdentifier, defaultCatalog); +// } +// +// public static CatalogAndIdentifier catalogAndIdentifier(String description, SparkSession spark, String name) { +// return catalogAndIdentifier(description, spark, name, spark.sessionState().catalogManager().currentCatalog()); +// } +// +// public static CatalogAndIdentifier catalogAndIdentifier(String description, SparkSession spark, +// String name, CatalogPlugin defaultCatalog) { +// try { +// return catalogAndIdentifier(spark, name, defaultCatalog); +// } catch (ParseException e) { +// throw new IllegalArgumentException("Cannot parse " + description + ": " + name, e); +// } +// } +// +// public static CatalogAndIdentifier catalogAndIdentifier(SparkSession spark, List nameParts) { +// return catalogAndIdentifier(spark, nameParts, spark.sessionState().catalogManager().currentCatalog()); +// } +// +// /** +// * A modified version of Spark's LookupCatalog.CatalogAndIdentifier.unapply +// * Attempts to find the catalog and identifier a multipart identifier represents +// * @param spark Spark session to use for resolution +// * @param nameParts Multipart identifier representing a table +// * 
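// A small illustration of the SQL-ish rendering produced by Spark3Util.describe
// and DescribeExpressionVisitor above; the column names and values are made up
// and the exact output string is only approximate.
//
// import org.apache.iceberg.expressions.Expressions;
//
// String sql = Spark3Util.describe(
//     Expressions.and(
//         Expressions.greaterThanOrEqual("ts", 1000L),
//         Expressions.notNull("name")));
// // expected to be something like: (ts >= 1000 AND name IS NOT NULL)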
@param defaultCatalog Catalog to use if none is specified +// * @return The CatalogPlugin and Identifier for the table +// */ +// public static CatalogAndIdentifier catalogAndIdentifier(SparkSession spark, List nameParts, +// CatalogPlugin defaultCatalog) { +// CatalogManager catalogManager = spark.sessionState().catalogManager(); +// +// String[] currentNamespace; +// if (defaultCatalog.equals(catalogManager.currentCatalog())) { +// currentNamespace = catalogManager.currentNamespace(); +// } else { +// currentNamespace = defaultCatalog.defaultNamespace(); +// } +// +// Pair catalogIdentifier = SparkUtil.catalogAndIdentifier(nameParts, +// catalogName -> { +// try { +// return catalogManager.catalog(catalogName); +// } catch (Exception e) { +// return null; +// } +// }, +// Identifier::of, +// defaultCatalog, +// currentNamespace +// ); +// return new CatalogAndIdentifier(catalogIdentifier); +// } +// +// private static TableCatalog asTableCatalog(CatalogPlugin catalog) { +// if (catalog instanceof TableCatalog) { +// return (TableCatalog) catalog; +// } +// +// throw new IllegalArgumentException(String.format( +// "Cannot use catalog %s(%s): not a TableCatalog", catalog.name(), catalog.getClass().getName())); +// } +// +// /** +// * This mimics a class inside of Spark which is private inside of LookupCatalog. +// */ +// public static class CatalogAndIdentifier { +// private final CatalogPlugin catalog; +// private final Identifier identifier; +// +// +// public CatalogAndIdentifier(CatalogPlugin catalog, Identifier identifier) { +// this.catalog = catalog; +// this.identifier = identifier; +// } +// +// public CatalogAndIdentifier(Pair identifier) { +// this.catalog = identifier.first(); +// this.identifier = identifier.second(); +// } +// +// public CatalogPlugin catalog() { +// return catalog; +// } +// +// public Identifier identifier() { +// return identifier; +// } +// } +// +// public static TableIdentifier identifierToTableIdentifier(Identifier identifier) { +// return TableIdentifier.of(Namespace.of(identifier.namespace()), identifier.name()); +// } +// +// /** +// * Use Spark to list all partitions in the table. 
+// * +// * @param spark a Spark session +// * @param rootPath a table identifier +// * @param format format of the file +// * @param partitionFilter partitionFilter of the file +// * @return all table's partitions +// */ +// public static List getPartitions(SparkSession spark, Path rootPath, String format, +// Map partitionFilter) { +// FileStatusCache fileStatusCache = FileStatusCache.getOrCreate(spark); +// +// InMemoryFileIndex fileIndex = new InMemoryFileIndex( +// spark, +// JavaConverters +// .collectionAsScalaIterableConverter(ImmutableList.of(rootPath)) +// .asScala() +// .toSeq(), +// scala.collection.immutable.Map$.MODULE$.empty(), +// Option.empty(), +// fileStatusCache, +// Option.empty(), +// Option.empty()); +// +// org.apache.spark.sql.execution.datasources.PartitionSpec spec = fileIndex.partitionSpec(); +// StructType schema = spec.partitionColumns(); +// if (schema.isEmpty()) { +// return Lists.newArrayList(); +// } +// +// List filterExpressions = +// SparkUtil.partitionMapToExpression(schema, partitionFilter); +// Seq scalaPartitionFilters = +// JavaConverters.asScalaBufferConverter(filterExpressions).asScala().toIndexedSeq(); +// +// List dataFilters = Lists.newArrayList(); +// Seq scalaDataFilters = +// JavaConverters.asScalaBufferConverter(dataFilters).asScala().toIndexedSeq(); +// +// Seq filteredPartitions = +// fileIndex.listFiles(scalaPartitionFilters, scalaDataFilters).toIndexedSeq(); +// +// return JavaConverters +// .seqAsJavaListConverter(filteredPartitions) +// .asJava() +// .stream() +// .map(partition -> { +// Map values = Maps.newHashMap(); +// JavaConverters.asJavaIterableConverter(schema).asJava().forEach(field -> { +// int fieldIndex = schema.fieldIndex(field.name()); +// Object catalystValue = partition.values().get(fieldIndex, field.dataType()); +// Object value = CatalystTypeConverters.convertToScala(catalystValue, field.dataType()); +// values.put(field.name(), String.valueOf(value)); +// }); +// +// FileStatus fileStatus = +// JavaConverters.seqAsJavaListConverter(partition.files()).asJava().get(0); +// +// return new SparkPartition(values, fileStatus.getPath().getParent().toString(), format); +// }).collect(Collectors.toList()); +// } +// +// public static org.apache.spark.sql.catalyst.TableIdentifier toV1TableIdentifier(Identifier identifier) { +// String[] namespace = identifier.namespace(); +// +// Preconditions.checkArgument(namespace.length <= 1, +// "Cannot convert %s to a Spark v1 identifier, namespace contains more than 1 part", identifier); +// +// String table = identifier.name(); +// Option database = namespace.length == 1 ? 
Option.apply(namespace[0]) : Option.empty(); +// return org.apache.spark.sql.catalyst.TableIdentifier.apply(table, database); +// } +// +// private static class DescribeSortOrderVisitor implements SortOrderVisitor { +// private static final DescribeSortOrderVisitor INSTANCE = new DescribeSortOrderVisitor(); +// +// private DescribeSortOrderVisitor() { +// } +// +// @Override +// public String field(String sourceName, int sourceId, +// org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { +// return String.format("%s %s %s", sourceName, direction, nullOrder); +// } +// +// @Override +// public String bucket(String sourceName, int sourceId, int numBuckets, +// org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { +// return String.format("bucket(%s, %s) %s %s", numBuckets, sourceName, direction, nullOrder); +// } +// +// @Override +// public String truncate(String sourceName, int sourceId, int width, +// org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { +// return String.format("truncate(%s, %s) %s %s", sourceName, width, direction, nullOrder); +// } +// +// @Override +// public String year(String sourceName, int sourceId, +// org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { +// return String.format("years(%s) %s %s", sourceName, direction, nullOrder); +// } +// +// @Override +// public String month(String sourceName, int sourceId, +// org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { +// return String.format("months(%s) %s %s", sourceName, direction, nullOrder); +// } +// +// @Override +// public String day(String sourceName, int sourceId, +// org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { +// return String.format("days(%s) %s %s", sourceName, direction, nullOrder); +// } +// +// @Override +// public String hour(String sourceName, int sourceId, +// org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { +// return String.format("hours(%s) %s %s", sourceName, direction, nullOrder); +// } +// +// @Override +// public String unknown(String sourceName, int sourceId, String transform, +// org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { +// return String.format("%s(%s) %s %s", transform, sourceName, direction, nullOrder); +// } +// } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkConfParser.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkConfParser.java new file mode 100644 index 00000000000..948c56dace4 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkConfParser.java @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.iceberg.spark; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.Function; + +import org.apache.iceberg.Table; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; + +import org.apache.spark.sql.RuntimeConfig; +import org.apache.spark.sql.SparkSession; + +public class SparkConfParser { + + private final Map properties; + private final RuntimeConfig sessionConf; + private final Map options; + + SparkConfParser(SparkSession spark, Table table, Map options) { + this.properties = table.properties(); + this.sessionConf = spark.conf(); + this.options = options; + } + + public BooleanConfParser booleanConf() { + return new BooleanConfParser(); + } + + public IntConfParser intConf() { + return new IntConfParser(); + } + + public LongConfParser longConf() { + return new LongConfParser(); + } + + public StringConfParser stringConf() { + return new StringConfParser(); + } + + class BooleanConfParser extends ConfParser { + private Boolean defaultValue; + + @Override + protected BooleanConfParser self() { + return this; + } + + public BooleanConfParser defaultValue(boolean value) { + this.defaultValue = value; + return self(); + } + + public BooleanConfParser defaultValue(String value) { + this.defaultValue = Boolean.parseBoolean(value); + return self(); + } + + public boolean parse() { + Preconditions.checkArgument(defaultValue != null, "Default value cannot be null"); + return parse(Boolean::parseBoolean, defaultValue); + } + } + + class IntConfParser extends ConfParser { + private Integer defaultValue; + + @Override + protected IntConfParser self() { + return this; + } + + public IntConfParser defaultValue(int value) { + this.defaultValue = value; + return self(); + } + + public int parse() { + Preconditions.checkArgument(defaultValue != null, "Default value cannot be null"); + return parse(Integer::parseInt, defaultValue); + } + + public Integer parseOptional() { + return parse(Integer::parseInt, null); + } + } + + class LongConfParser extends ConfParser { + private Long defaultValue; + + @Override + protected LongConfParser self() { + return this; + } + + public LongConfParser defaultValue(long value) { + this.defaultValue = value; + return self(); + } + + public long parse() { + Preconditions.checkArgument(defaultValue != null, "Default value cannot be null"); + return parse(Long::parseLong, defaultValue); + } + + public Long parseOptional() { + return parse(Long::parseLong, null); + } + } + + class StringConfParser extends ConfParser { + private String defaultValue; + + @Override + protected StringConfParser self() { + return this; + } + + public StringConfParser defaultValue(String value) { + this.defaultValue = value; + return self(); + } + + public String parse() { + Preconditions.checkArgument(defaultValue != null, "Default value cannot be null"); + return parse(Function.identity(), defaultValue); + } + + public String parseOptional() { + return parse(Function.identity(), null); + } + } + + abstract class ConfParser { + private final List optionNames = Lists.newArrayList(); + private String sessionConfName; + private String tablePropertyName; + + protected abstract ThisT self(); + + public ThisT option(String name) { + this.optionNames.add(name); + return self(); + } + + public ThisT sessionConf(String name) { + this.sessionConfName = name; + return self(); + } + + public ThisT tableProperty(String name) { + 
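// A rough sketch of how a caller in this package might resolve a numeric setting
// through the read-option > session-conf > table-property chain implemented by
// parse() below; the key names and default value are made up.
//
// long targetSize = new SparkConfParser(spark, icebergTable, scanOptions)
//     .longConf()
//     .option("split-size")                            // checked first: per-scan read option
//     .sessionConf("spark.rapids.iceberg.split-size")  // then: Spark session configuration
//     .tableProperty("read.split.target-size")         // then: Iceberg table property
//     .defaultValue(128L * 1024 * 1024)                // finally: the supplied default
//     .parse();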
this.tablePropertyName = name; + return self(); + } + + protected T parse(Function conversion, T defaultValue) { + if (!optionNames.isEmpty()) { + for (String optionName : optionNames) { + // use lower case comparison as DataSourceOptions.asMap() in Spark 2 returns a lower case map + String optionValue = options.get(optionName.toLowerCase(Locale.ROOT)); + if (optionValue != null) { + return conversion.apply(optionValue); + } + } + } + + if (sessionConfName != null) { + String sessionConfValue = sessionConf.get(sessionConfName, null); + if (sessionConfValue != null) { + return conversion.apply(sessionConfValue); + } + } + + if (tablePropertyName != null) { + String propertyValue = properties.get(tablePropertyName); + if (propertyValue != null) { + return conversion.apply(propertyValue); + } + } + + return defaultValue; + } + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkFilters.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkFilters.java new file mode 100644 index 00000000000..46ebbeec163 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkFilters.java @@ -0,0 +1,270 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.iceberg.spark; + +import java.sql.Date; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.util.Map; +import java.util.Objects; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expression.Operation; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.util.NaNUtil; + +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.sources.AlwaysFalse; +import org.apache.spark.sql.sources.AlwaysFalse$; +import org.apache.spark.sql.sources.AlwaysTrue; +import org.apache.spark.sql.sources.AlwaysTrue$; +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualNullSafe; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.GreaterThanOrEqual; +import org.apache.spark.sql.sources.In; +import org.apache.spark.sql.sources.IsNotNull; +import org.apache.spark.sql.sources.IsNull; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.LessThanOrEqual; +import org.apache.spark.sql.sources.Not; +import org.apache.spark.sql.sources.Or; +import org.apache.spark.sql.sources.StringStartsWith; + +import static org.apache.iceberg.expressions.Expressions.and; +import static org.apache.iceberg.expressions.Expressions.equal; +import static org.apache.iceberg.expressions.Expressions.greaterThan; +import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual; +import static org.apache.iceberg.expressions.Expressions.in; +import static org.apache.iceberg.expressions.Expressions.isNaN; +import static org.apache.iceberg.expressions.Expressions.isNull; +import static org.apache.iceberg.expressions.Expressions.lessThan; +import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual; +import static org.apache.iceberg.expressions.Expressions.not; +import static org.apache.iceberg.expressions.Expressions.notIn; +import static org.apache.iceberg.expressions.Expressions.notNull; +import static org.apache.iceberg.expressions.Expressions.or; +import static org.apache.iceberg.expressions.Expressions.startsWith; + +public class SparkFilters { + + private static final Pattern BACKTICKS_PATTERN = Pattern.compile("([`])(.|$)"); + + private SparkFilters() { + } + + private static final Map, Operation> FILTERS = ImmutableMap + ., Expression.Operation>builder() + .put(AlwaysTrue.class, Expression.Operation.TRUE) + .put(AlwaysTrue$.class, Expression.Operation.TRUE) + .put(AlwaysFalse$.class, Expression.Operation.FALSE) + .put(AlwaysFalse.class, Expression.Operation.FALSE) + .put(EqualTo.class, Expression.Operation.EQ) + .put(EqualNullSafe.class, Expression.Operation.EQ) + .put(GreaterThan.class, Expression.Operation.GT) + .put(GreaterThanOrEqual.class, Expression.Operation.GT_EQ) + .put(LessThan.class, Expression.Operation.LT) + .put(LessThanOrEqual.class, Expression.Operation.LT_EQ) + .put(In.class, Expression.Operation.IN) + .put(IsNull.class, Expression.Operation.IS_NULL) + .put(IsNotNull.class, Expression.Operation.NOT_NULL) + .put(And.class, Expression.Operation.AND) + 
.put(Or.class, Expression.Operation.OR) + .put(Not.class, Expression.Operation.NOT) + .put(StringStartsWith.class, Expression.Operation.STARTS_WITH) + .build(); + + public static Expression convert(Filter[] filters) { + Expression expression = Expressions.alwaysTrue(); + for (Filter filter : filters) { + Expression converted = convert(filter); + Preconditions.checkArgument(converted != null, "Cannot convert filter to Iceberg: %s", filter); + expression = Expressions.and(expression, converted); + } + return expression; + } + + public static Expression convert(Filter filter) { + // avoid using a chain of if instanceof statements by mapping to the expression enum. + Operation op = FILTERS.get(filter.getClass()); + if (op != null) { + switch (op) { + case TRUE: + return Expressions.alwaysTrue(); + + case FALSE: + return Expressions.alwaysFalse(); + + case IS_NULL: + IsNull isNullFilter = (IsNull) filter; + return isNull(unquote(isNullFilter.attribute())); + + case NOT_NULL: + IsNotNull notNullFilter = (IsNotNull) filter; + return notNull(unquote(notNullFilter.attribute())); + + case LT: + LessThan lt = (LessThan) filter; + return lessThan(unquote(lt.attribute()), convertLiteral(lt.value())); + + case LT_EQ: + LessThanOrEqual ltEq = (LessThanOrEqual) filter; + return lessThanOrEqual(unquote(ltEq.attribute()), convertLiteral(ltEq.value())); + + case GT: + GreaterThan gt = (GreaterThan) filter; + return greaterThan(unquote(gt.attribute()), convertLiteral(gt.value())); + + case GT_EQ: + GreaterThanOrEqual gtEq = (GreaterThanOrEqual) filter; + return greaterThanOrEqual(unquote(gtEq.attribute()), convertLiteral(gtEq.value())); + + case EQ: // used for both eq and null-safe-eq + if (filter instanceof EqualTo) { + EqualTo eq = (EqualTo) filter; + // comparison with null in normal equality is always null. this is probably a mistake. 
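// A short illustration of the mapping convert() performs, using hand-built Spark
// data source filters; the column names and values are made up and the resulting
// Iceberg expressions are described only informally in the comments.
//
// import org.apache.iceberg.expressions.Expression;
// import org.apache.spark.sql.sources.EqualTo;
// import org.apache.spark.sql.sources.Filter;
// import org.apache.spark.sql.sources.IsNotNull;
//
// Expression expr = SparkFilters.convert(new Filter[] {
//     new IsNotNull("id"),             // becomes notNull("id")
//     new EqualTo("category", "gpu")   // becomes equal("category", "gpu")
// });
// // the converted predicates are AND-ed together, starting from alwaysTrue()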
+ Preconditions.checkNotNull(eq.value(), + "Expression is always false (eq is not null-safe): %s", filter); + return handleEqual(unquote(eq.attribute()), eq.value()); + } else { + EqualNullSafe eq = (EqualNullSafe) filter; + if (eq.value() == null) { + return isNull(unquote(eq.attribute())); + } else { + return handleEqual(unquote(eq.attribute()), eq.value()); + } + } + + case IN: + In inFilter = (In) filter; + return in(unquote(inFilter.attribute()), + Stream.of(inFilter.values()) + .filter(Objects::nonNull) + .map(SparkFilters::convertLiteral) + .collect(Collectors.toList())); + + case NOT: + Not notFilter = (Not) filter; + Filter childFilter = notFilter.child(); + Expression.Operation childOp = FILTERS.get(childFilter.getClass()); + if (childOp == Expression.Operation.IN) { + // infer an extra notNull predicate for Spark NOT IN filters + // as Iceberg expressions don't follow the 3-value SQL boolean logic + // col NOT IN (1, 2) in Spark is equivalent to notNull(col) && notIn(col, 1, 2) in Iceberg + In childInFilter = (In) childFilter; + Expression notIn = notIn(unquote(childInFilter.attribute()), + Stream.of(childInFilter.values()) + .map(SparkFilters::convertLiteral) + .collect(Collectors.toList())); + return and(notNull(childInFilter.attribute()), notIn); + } else if (hasNoInFilter(childFilter)) { + Expression child = convert(childFilter); + if (child != null) { + return not(child); + } + } + return null; + + case AND: { + And andFilter = (And) filter; + Expression left = convert(andFilter.left()); + Expression right = convert(andFilter.right()); + if (left != null && right != null) { + return and(left, right); + } + return null; + } + + case OR: { + Or orFilter = (Or) filter; + Expression left = convert(orFilter.left()); + Expression right = convert(orFilter.right()); + if (left != null && right != null) { + return or(left, right); + } + return null; + } + + case STARTS_WITH: { + StringStartsWith stringStartsWith = (StringStartsWith) filter; + return startsWith(unquote(stringStartsWith.attribute()), stringStartsWith.value()); + } + } + } + + return null; + } + + private static Object convertLiteral(Object value) { + if (value instanceof Timestamp) { + return DateTimeUtils.fromJavaTimestamp((Timestamp) value); + } else if (value instanceof Date) { + return DateTimeUtils.fromJavaDate((Date) value); + } else if (value instanceof Instant) { + return DateTimeUtils.instantToMicros((Instant) value); + } else if (value instanceof LocalDate) { + return DateTimeUtils.localDateToDays((LocalDate) value); + } + return value; + } + + private static Expression handleEqual(String attribute, Object value) { + if (NaNUtil.isNaN(value)) { + return isNaN(attribute); + } else { + return equal(attribute, convertLiteral(value)); + } + } + + private static String unquote(String attributeName) { + Matcher matcher = BACKTICKS_PATTERN.matcher(attributeName); + return matcher.replaceAll("$2"); + } + + private static boolean hasNoInFilter(Filter filter) { + Expression.Operation op = FILTERS.get(filter.getClass()); + + if (op != null) { + switch (op) { + case AND: + And andFilter = (And) filter; + return hasNoInFilter(andFilter.left()) && hasNoInFilter(andFilter.right()); + case OR: + Or orFilter = (Or) filter; + return hasNoInFilter(orFilter.left()) && hasNoInFilter(orFilter.right()); + case NOT: + Not notFilter = (Not) filter; + return hasNoInFilter(notFilter.child()); + case IN: + return false; + default: + return true; + } + } + + return false; + } +} diff --git 
a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java new file mode 100644 index 00000000000..7dd5f2db46a --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java @@ -0,0 +1,234 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.iceberg.spark; + +import java.util.Map; +import java.util.Set; + +import org.apache.commons.lang3.reflect.FieldUtils; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.hadoop.HadoopInputFile; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.util.PropertyUtil; + +import org.apache.spark.sql.SparkSession; + +/** + * A class for common Iceberg configs for Spark reads. + *

+ * If a config is set at multiple levels, the following order of precedence is used (top to bottom):
+ * <ol>
+ *   <li>Read options</li>
+ *   <li>Session configuration</li>
+ *   <li>Table metadata</li>
+ * </ol>
+ * The most specific value is set in read options and takes precedence over all other configs. + * If no read option is provided, this class checks the session configuration for any overrides. + * If no applicable value is found in the session configuration, this class uses the table metadata. + *
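+ * <p>
+ * A minimal illustration (the table name below is hypothetical): a value passed as a read option
+ * wins over any session- or table-level setting.
+ * <pre>{@code
+ *   spark.read()
+ *       .format("iceberg")
+ *       .option("split-size", "134217728")  // read option: highest precedence
+ *       .load("db.events");                 // hypothetical table name
+ * }</pre>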

+ * Note this class is NOT meant to be serialized and sent to executors. + */ +public class SparkReadConf { + private static final Set LOCALITY_WHITELIST_FS = ImmutableSet.of("hdfs"); + + private final SparkSession spark; + private final Table table; + private final Map readOptions; + private final SparkConfParser confParser; + + public static SparkReadConf fromReflect(Object obj) throws IllegalAccessException { + SparkSession spark = (SparkSession) FieldUtils.readField(obj, "spark", true); + Table table = (Table) FieldUtils.readField(obj, "table", true); + Map readOptions = (Map) FieldUtils.readField(obj, "readOptions", true); + return new SparkReadConf(spark, table, readOptions); + } + + public SparkReadConf(SparkSession spark, Table table, Map readOptions) { + this.spark = spark; + this.table = table; + this.readOptions = readOptions; + this.confParser = new SparkConfParser(spark, table, readOptions); + } + + public boolean caseSensitive() { + return Boolean.parseBoolean(spark.conf().get("spark.sql.caseSensitive")); + } + + public boolean localityEnabled() { + InputFile file = table.io().newInputFile(table.location()); + + if (file instanceof HadoopInputFile) { + String scheme = ((HadoopInputFile) file).getFileSystem().getScheme(); + boolean defaultValue = LOCALITY_WHITELIST_FS.contains(scheme); + return PropertyUtil.propertyAsBoolean( + readOptions, + SparkReadOptions.LOCALITY, + defaultValue); + } + + return false; + } + + public Long snapshotId() { + return confParser.longConf() + .option(SparkReadOptions.SNAPSHOT_ID) + .parseOptional(); + } + + public Long asOfTimestamp() { + return confParser.longConf() + .option(SparkReadOptions.AS_OF_TIMESTAMP) + .parseOptional(); + } + + public Long startSnapshotId() { + return confParser.longConf() + .option(SparkReadOptions.START_SNAPSHOT_ID) + .parseOptional(); + } + + public Long endSnapshotId() { + return confParser.longConf() + .option(SparkReadOptions.END_SNAPSHOT_ID) + .parseOptional(); + } + + public String fileScanTaskSetId() { + return confParser.stringConf() + .option(SparkReadOptions.FILE_SCAN_TASK_SET_ID) + .parseOptional(); + } + + public boolean streamingSkipDeleteSnapshots() { + return confParser.booleanConf() + .option(SparkReadOptions.STREAMING_SKIP_DELETE_SNAPSHOTS) + .defaultValue(SparkReadOptions.STREAMING_SKIP_DELETE_SNAPSHOTS_DEFAULT) + .parse(); + } + + public boolean streamingSkipOverwriteSnapshots() { + return confParser.booleanConf() + .option(SparkReadOptions.STREAMING_SKIP_OVERWRITE_SNAPSHOTS) + .defaultValue(SparkReadOptions.STREAMING_SKIP_OVERWRITE_SNAPSHOTS_DEFAULT) + .parse(); + } + + public boolean parquetVectorizationEnabled() { + return confParser.booleanConf() + .option(SparkReadOptions.VECTORIZATION_ENABLED) + .sessionConf(SparkSQLProperties.VECTORIZATION_ENABLED) + .tableProperty(TableProperties.PARQUET_VECTORIZATION_ENABLED) + .defaultValue(TableProperties.PARQUET_VECTORIZATION_ENABLED_DEFAULT) + .parse(); + } + + public int parquetBatchSize() { + return confParser.intConf() + .option(SparkReadOptions.VECTORIZATION_BATCH_SIZE) + .tableProperty(TableProperties.PARQUET_BATCH_SIZE) + .defaultValue(TableProperties.PARQUET_BATCH_SIZE_DEFAULT) + .parse(); + } + + public boolean orcVectorizationEnabled() { + return confParser.booleanConf() + .option(SparkReadOptions.VECTORIZATION_ENABLED) + .sessionConf(SparkSQLProperties.VECTORIZATION_ENABLED) + .tableProperty(TableProperties.ORC_VECTORIZATION_ENABLED) + .defaultValue(TableProperties.ORC_VECTORIZATION_ENABLED_DEFAULT) + .parse(); + } + + public int 
orcBatchSize() { + return confParser.intConf() + .option(SparkReadOptions.VECTORIZATION_BATCH_SIZE) + .tableProperty(TableProperties.ORC_BATCH_SIZE) + .defaultValue(TableProperties.ORC_BATCH_SIZE_DEFAULT) + .parse(); + } + + public Long splitSizeOption() { + return confParser.longConf() + .option(SparkReadOptions.SPLIT_SIZE) + .parseOptional(); + } + + public long splitSize() { + return confParser.longConf() + .option(SparkReadOptions.SPLIT_SIZE) + .tableProperty(TableProperties.SPLIT_SIZE) + .defaultValue(TableProperties.SPLIT_SIZE_DEFAULT) + .parse(); + } + + public Integer splitLookbackOption() { + return confParser.intConf() + .option(SparkReadOptions.LOOKBACK) + .parseOptional(); + } + + public int splitLookback() { + return confParser.intConf() + .option(SparkReadOptions.LOOKBACK) + .tableProperty(TableProperties.SPLIT_LOOKBACK) + .defaultValue(TableProperties.SPLIT_LOOKBACK_DEFAULT) + .parse(); + } + + public Long splitOpenFileCostOption() { + return confParser.longConf() + .option(SparkReadOptions.FILE_OPEN_COST) + .parseOptional(); + } + + public long splitOpenFileCost() { + return confParser.longConf() + .option(SparkReadOptions.FILE_OPEN_COST) + .tableProperty(TableProperties.SPLIT_OPEN_FILE_COST) + .defaultValue(TableProperties.SPLIT_OPEN_FILE_COST_DEFAULT) + .parse(); + } + + /** + * Enables reading a timestamp without time zone as a timestamp with time zone. + *

+ * Generally, this is not safe as a timestamp without time zone is supposed to represent the wall-clock time, + * i.e. no matter the reader/writer timezone 3PM should always be read as 3PM, + * but a timestamp with time zone represents instant semantics, i.e. the timestamp + * is adjusted so that the corresponding time in the reader timezone is displayed. + *
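+ * <p>
+ * As an illustrative sketch (the table name is hypothetical), the check can be relaxed per session via
+ * {@code spark.sql.iceberg.handle-timestamp-without-timezone} or per read via the
+ * {@code handle-timestamp-without-timezone} read option:
+ * <pre>{@code
+ *   spark.conf().set("spark.sql.iceberg.handle-timestamp-without-timezone", "true");
+ *   // or, per read:
+ *   spark.read().format("iceberg")
+ *       .option("handle-timestamp-without-timezone", "true")
+ *       .load("db.events");
+ * }</pre>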

+ * When set to false (default), an exception must be thrown while reading a timestamp without time zone. + * + * @return boolean indicating if reading timestamps without timezone is allowed + */ + public boolean handleTimestampWithoutZone() { + return confParser.booleanConf() + .option(SparkReadOptions.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE) + .sessionConf(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE) + .defaultValue(SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE_DEFAULT) + .parse(); + } + + public Long streamFromTimestamp() { + return confParser.longConf() + .option(SparkReadOptions.STREAM_FROM_TIMESTAMP) + .defaultValue(Long.MIN_VALUE) + .parse(); + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadOptions.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadOptions.java new file mode 100644 index 00000000000..92a72e56c76 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadOptions.java @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.iceberg.spark; + +/** + * Spark DF read options + */ +public class SparkReadOptions { + + private SparkReadOptions() { + } + + // Snapshot ID of the table snapshot to read + public static final String SNAPSHOT_ID = "snapshot-id"; + + // Start snapshot ID used in incremental scans (exclusive) + public static final String START_SNAPSHOT_ID = "start-snapshot-id"; + + // End snapshot ID used in incremental scans (inclusive) + public static final String END_SNAPSHOT_ID = "end-snapshot-id"; + + // A timestamp in milliseconds; the snapshot used will be the snapshot current at this time. 
+ public static final String AS_OF_TIMESTAMP = "as-of-timestamp"; + + // Overrides the table's read.split.target-size and read.split.metadata-target-size + public static final String SPLIT_SIZE = "split-size"; + + // Overrides the table's read.split.planning-lookback + public static final String LOOKBACK = "lookback"; + + // Overrides the table's read.split.open-file-cost + public static final String FILE_OPEN_COST = "file-open-cost"; + + // Overrides the table's read.split.open-file-cost + public static final String VECTORIZATION_ENABLED = "vectorization-enabled"; + + // Overrides the table's read.parquet.vectorization.batch-size + public static final String VECTORIZATION_BATCH_SIZE = "batch-size"; + + // Set ID that is used to fetch file scan tasks + public static final String FILE_SCAN_TASK_SET_ID = "file-scan-task-set-id"; + + // skip snapshots of type delete while reading stream out of iceberg table + public static final String STREAMING_SKIP_DELETE_SNAPSHOTS = "streaming-skip-delete-snapshots"; + public static final boolean STREAMING_SKIP_DELETE_SNAPSHOTS_DEFAULT = false; + + // skip snapshots of type overwrite while reading stream out of iceberg table + public static final String STREAMING_SKIP_OVERWRITE_SNAPSHOTS = "streaming-skip-overwrite-snapshots"; + public static final boolean STREAMING_SKIP_OVERWRITE_SNAPSHOTS_DEFAULT = false; + + // Controls whether to allow reading timestamps without zone info + public static final String HANDLE_TIMESTAMP_WITHOUT_TIMEZONE = "handle-timestamp-without-timezone"; + + // Controls whether to report locality information to Spark while allocating input partitions + public static final String LOCALITY = "locality"; + + // Timestamp in milliseconds; start a stream from the snapshot that occurs after this timestamp + public static final String STREAM_FROM_TIMESTAMP = "stream-from-timestamp"; +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSQLProperties.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSQLProperties.java new file mode 100644 index 00000000000..c24fe950ded --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSQLProperties.java @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.iceberg.spark; + +public class SparkSQLProperties { + + private SparkSQLProperties() { + } + + // Controls whether vectorized reads are enabled + public static final String VECTORIZATION_ENABLED = "spark.sql.iceberg.vectorization.enabled"; + + // Controls whether reading/writing timestamps without timezones is allowed + public static final String HANDLE_TIMESTAMP_WITHOUT_TIMEZONE = "spark.sql.iceberg.handle-timestamp-without-timezone"; + public static final boolean HANDLE_TIMESTAMP_WITHOUT_TIMEZONE_DEFAULT = false; + + // Controls whether timestamp types for new tables should be stored with timezone info + public static final String USE_TIMESTAMP_WITHOUT_TIME_ZONE_IN_NEW_TABLES = + "spark.sql.iceberg.use-timestamp-without-timezone-in-new-tables"; + public static final boolean USE_TIMESTAMP_WITHOUT_TIME_ZONE_IN_NEW_TABLES_DEFAULT = false; + + // Controls whether to perform the nullability check during writes + public static final String CHECK_NULLABILITY = "spark.sql.iceberg.check-nullability"; + public static final boolean CHECK_NULLABILITY_DEFAULT = true; + + // Controls whether to check the order of fields during writes + public static final String CHECK_ORDERING = "spark.sql.iceberg.check-ordering"; + public static final boolean CHECK_ORDERING_DEFAULT = true; +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSchemaUtil.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSchemaUtil.java new file mode 100644 index 00000000000..fce3a1ccac6 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSchemaUtil.java @@ -0,0 +1,365 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.iceberg.spark; + +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.math.LongMath; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; + +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.StructType$; + +/** + * Helper methods for working with Spark/Hive metadata. + */ +public class SparkSchemaUtil { + private SparkSchemaUtil() { + } + +// /** +// * Returns a {@link Schema} for the given table with fresh field ids. +// *

+// * This creates a Schema for an existing table by looking up the table's schema with Spark and +// * converting that schema. Spark/Hive partition columns are included in the schema. +// * +// * @param spark a Spark session +// * @param name a table name and (optional) database +// * @return a Schema for the table, if found +// */ +// public static Schema schemaForTable(SparkSession spark, String name) { +// StructType sparkType = spark.table(name).schema(); +// Type converted = SparkTypeVisitor.visit(sparkType, new SparkTypeToType(sparkType)); +// return new Schema(converted.asNestedType().asStructType().fields()); +// } +// +// /** +// * Returns a {@link PartitionSpec} for the given table. +// *

+// * This creates a partition spec for an existing table by looking up the table's schema and +// * creating a spec with identity partitions for each partition column. +// * +// * @param spark a Spark session +// * @param name a table name and (optional) database +// * @return a PartitionSpec for the table +// * @throws AnalysisException if thrown by the Spark catalog +// */ +// public static PartitionSpec specForTable(SparkSession spark, String name) throws AnalysisException { +// List parts = Lists.newArrayList(Splitter.on('.').limit(2).split(name)); +// String db = parts.size() == 1 ? "default" : parts.get(0); +// String table = parts.get(parts.size() == 1 ? 0 : 1); +// +// PartitionSpec spec = identitySpec( +// schemaForTable(spark, name), +// spark.catalog().listColumns(db, table).collectAsList()); +// return spec == null ? PartitionSpec.unpartitioned() : spec; +// } + + /** + * Convert a {@link Schema} to a {@link DataType Spark type}. + * + * @param schema a Schema + * @return the equivalent Spark type + * @throws IllegalArgumentException if the type cannot be converted to Spark + */ + public static StructType convert(Schema schema) { + return (StructType) TypeUtil.visit(schema, new TypeToSparkType()); + } + + /** + * Convert a {@link Type} to a {@link DataType Spark type}. + * + * @param type a Type + * @return the equivalent Spark type + * @throws IllegalArgumentException if the type cannot be converted to Spark + */ + public static DataType convert(Type type) { + return TypeUtil.visit(type, new TypeToSparkType()); + } + + public static StructType convertWithoutConstants(Schema schema, Map idToConstant) { + return (StructType) TypeUtil.visit(schema, new TypeToSparkType() { + @Override + public DataType struct(Types.StructType struct, List fieldResults) { + List fields = struct.fields(); + + List sparkFields = Lists.newArrayListWithExpectedSize(fieldResults.size()); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + // skip fields that are constants + if (idToConstant.containsKey(field.fieldId())) { + continue; + } + DataType type = fieldResults.get(i); + StructField sparkField = StructField.apply( + field.name(), type, field.isOptional(), Metadata.empty()); + if (field.doc() != null) { + sparkField = sparkField.withComment(field.doc()); + } + sparkFields.add(sparkField); + } + + return StructType$.MODULE$.apply(sparkFields); + } + }); + } + +// /** +// * Convert a Spark {@link StructType struct} to a {@link Schema} with new field ids. +// *

+// * This conversion assigns fresh ids. +// *

+// * Some data types are represented as the same Spark type. These are converted to a default type. +// *

+// * To convert using a reference schema for field ids and ambiguous types, use +// * {@link #convert(Schema, StructType)}. +// * +// * @param sparkType a Spark StructType +// * @return the equivalent Schema +// * @throws IllegalArgumentException if the type cannot be converted +// */ +// public static Schema convert(StructType sparkType) { +// return convert(sparkType, false); +// } +// +// /** +// * Convert a Spark {@link StructType struct} to a {@link Schema} with new field ids. +// *

+// * This conversion assigns fresh ids. +// *

+// * Some data types are represented as the same Spark type. These are converted to a default type. +// *

+// * To convert using a reference schema for field ids and ambiguous types, use +// * {@link #convert(Schema, StructType)}. +// * +// * @param sparkType a Spark StructType +// * @param useTimestampWithoutZone boolean flag indicates that timestamp should be stored without timezone +// * @return the equivalent Schema +// * @throws IllegalArgumentException if the type cannot be converted +// */ +// public static Schema convert(StructType sparkType, boolean useTimestampWithoutZone) { +// Type converted = SparkTypeVisitor.visit(sparkType, new SparkTypeToType(sparkType)); +// Schema schema = new Schema(converted.asNestedType().asStructType().fields()); +// if (useTimestampWithoutZone) { +// schema = SparkFixupTimestampType.fixup(schema); +// } +// return schema; +// } +// +// /** +// * Convert a Spark {@link DataType struct} to a {@link Type} with new field ids. +// *

+// * This conversion assigns fresh ids. +// *

+// * Some data types are represented as the same Spark type. These are converted to a default type. +// *

+// * To convert using a reference schema for field ids and ambiguous types, use +// * {@link #convert(Schema, StructType)}. +// * +// * @param sparkType a Spark DataType +// * @return the equivalent Type +// * @throws IllegalArgumentException if the type cannot be converted +// */ +// public static Type convert(DataType sparkType) { +// return SparkTypeVisitor.visit(sparkType, new SparkTypeToType()); +// } +// +// /** +// * Convert a Spark {@link StructType struct} to a {@link Schema} based on the given schema. +// *

+// * This conversion does not assign new ids; it uses ids from the base schema. +// *

+// * Data types, field order, and nullability will match the spark type. This conversion may return +// * a schema that is not compatible with base schema. +// * +// * @param baseSchema a Schema on which conversion is based +// * @param sparkType a Spark StructType +// * @return the equivalent Schema +// * @throws IllegalArgumentException if the type cannot be converted or there are missing ids +// */ +// public static Schema convert(Schema baseSchema, StructType sparkType) { +// // convert to a type with fresh ids +// Types.StructType struct = SparkTypeVisitor.visit(sparkType, new SparkTypeToType(sparkType)).asStructType(); +// // reassign ids to match the base schema +// Schema schema = TypeUtil.reassignIds(new Schema(struct.fields()), baseSchema); +// // fix types that can't be represented in Spark (UUID and Fixed) +// return SparkFixupTypes.fixup(schema, baseSchema); +// } +// +// /** +// * Convert a Spark {@link StructType struct} to a {@link Schema} based on the given schema. +// *

+// * This conversion will assign new ids for fields that are not found in the base schema. +// *

+// * Data types, field order, and nullability will match the spark type. This conversion may return +// * a schema that is not compatible with base schema. +// * +// * @param baseSchema a Schema on which conversion is based +// * @param sparkType a Spark StructType +// * @return the equivalent Schema +// * @throws IllegalArgumentException if the type cannot be converted or there are missing ids +// */ +// public static Schema convertWithFreshIds(Schema baseSchema, StructType sparkType) { +// // convert to a type with fresh ids +// Types.StructType struct = SparkTypeVisitor.visit(sparkType, new SparkTypeToType(sparkType)).asStructType(); +// // reassign ids to match the base schema +// Schema schema = TypeUtil.reassignOrRefreshIds(new Schema(struct.fields()), baseSchema); +// // fix types that can't be represented in Spark (UUID and Fixed) +// return SparkFixupTypes.fixup(schema, baseSchema); +// } +// +// /** +// * Prune columns from a {@link Schema} using a {@link StructType Spark type} projection. +// *

+// * This requires that the Spark type is a projection of the Schema. Nullability and types must +// * match. +// * +// * @param schema a Schema +// * @param requestedType a projection of the Spark representation of the Schema +// * @return a Schema corresponding to the Spark projection +// * @throws IllegalArgumentException if the Spark type does not match the Schema +// */ +// public static Schema prune(Schema schema, StructType requestedType) { +// return new Schema(TypeUtil.visit(schema, new PruneColumnsWithoutReordering(requestedType, ImmutableSet.of())) +// .asNestedType() +// .asStructType() +// .fields()); +// } +// +// /** +// * Prune columns from a {@link Schema} using a {@link StructType Spark type} projection. +// *

+// * This requires that the Spark type is a projection of the Schema. Nullability and types must +// * match. +// *

+// * The filters list of {@link Expression} is used to ensure that columns referenced by filters +// * are projected. +// * +// * @param schema a Schema +// * @param requestedType a projection of the Spark representation of the Schema +// * @param filters a list of filters +// * @return a Schema corresponding to the Spark projection +// * @throws IllegalArgumentException if the Spark type does not match the Schema +// */ +// public static Schema prune(Schema schema, StructType requestedType, List filters) { +// Set filterRefs = Binder.boundReferences(schema.asStruct(), filters, true); +// return new Schema(TypeUtil.visit(schema, new PruneColumnsWithoutReordering(requestedType, filterRefs)) +// .asNestedType() +// .asStructType() +// .fields()); +// } +// +// /** +// * Prune columns from a {@link Schema} using a {@link StructType Spark type} projection. +// *

+// * This requires that the Spark type is a projection of the Schema. Nullability and types must +// * match. +// *

+// * The filters list of {@link Expression} is used to ensure that columns referenced by filters +// * are projected. +// * +// * @param schema a Schema +// * @param requestedType a projection of the Spark representation of the Schema +// * @param filter a filters +// * @return a Schema corresponding to the Spark projection +// * @throws IllegalArgumentException if the Spark type does not match the Schema +// */ +// public static Schema prune(Schema schema, StructType requestedType, Expression filter, boolean caseSensitive) { +// Set filterRefs = +// Binder.boundReferences(schema.asStruct(), Collections.singletonList(filter), caseSensitive); +// +// return new Schema(TypeUtil.visit(schema, new PruneColumnsWithoutReordering(requestedType, filterRefs)) +// .asNestedType() +// .asStructType() +// .fields()); +// } +// +// private static PartitionSpec identitySpec(Schema schema, Collection columns) { +// List names = Lists.newArrayList(); +// for (Column column : columns) { +// if (column.isPartition()) { +// names.add(column.name()); +// } +// } +// +// return identitySpec(schema, names); +// } +// +// private static PartitionSpec identitySpec(Schema schema, List partitionNames) { +// if (partitionNames == null || partitionNames.isEmpty()) { +// return null; +// } +// +// PartitionSpec.Builder builder = PartitionSpec.builderFor(schema); +// for (String partitionName : partitionNames) { +// builder.identity(partitionName); +// } +// +// return builder.build(); +// } + + /** + * Estimate approximate table size based on Spark schema and total records. + * + * @param tableSchema Spark schema + * @param totalRecords total records in the table + * @return approximate size based on table schema + */ + public static long estimateSize(StructType tableSchema, long totalRecords) { + if (totalRecords == Long.MAX_VALUE) { + return totalRecords; + } + + long result; + try { + result = LongMath.checkedMultiply(tableSchema.defaultSize(), totalRecords); + } catch (ArithmeticException e) { + result = Long.MAX_VALUE; + } + return result; + } + + public static void validateMetadataColumnReferences(Schema tableSchema, Schema readSchema) { + List conflictingColumnNames = readSchema.columns().stream() + .map(Types.NestedField::name) + .filter(name -> MetadataColumns.isMetadataColumn(name) && tableSchema.findField(name) != null) + .collect(Collectors.toList()); + + ValidationException.check( + conflictingColumnNames.isEmpty(), + "Table column names conflict with names reserved for Iceberg metadata columns: %s.\n" + + "Please, use ALTER TABLE statements to rename the conflicting table columns.", + conflictingColumnNames); + } + + public static Map indexQuotedNameById(Schema schema) { + Function quotingFunc = name -> String.format("`%s`", name.replace("`", "``")); + return TypeUtil.indexQuotedNameById(schema.asStruct(), quotingFunc); + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkTypeToType.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkTypeToType.java new file mode 100644 index 00000000000..bc170644184 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkTypeToType.java @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.iceberg.spark; + +import java.util.List; + +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; + +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.CharType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.sql.types.VarcharType; + +public class SparkTypeToType extends SparkTypeVisitor { + private final StructType root; + private int nextId = 0; + + SparkTypeToType() { + this.root = null; + } + + SparkTypeToType(StructType root) { + this.root = root; + // the root struct's fields use the first ids + this.nextId = root.fields().length; + } + + private int getNextId() { + int next = nextId; + nextId += 1; + return next; + } + + @Override + @SuppressWarnings("ReferenceEquality") + public Type struct(StructType struct, List types) { + StructField[] fields = struct.fields(); + List newFields = Lists.newArrayListWithExpectedSize(fields.length); + boolean isRoot = root == struct; + for (int i = 0; i < fields.length; i += 1) { + StructField field = fields[i]; + Type type = types.get(i); + + int id; + if (isRoot) { + // for new conversions, use ordinals for ids in the root struct + id = i; + } else { + id = getNextId(); + } + + String doc = field.getComment().isDefined() ? 
field.getComment().get() : null; + + if (field.nullable()) { + newFields.add(Types.NestedField.optional(id, field.name(), type, doc)); + } else { + newFields.add(Types.NestedField.required(id, field.name(), type, doc)); + } + } + + return Types.StructType.of(newFields); + } + + @Override + public Type field(StructField field, Type typeResult) { + return typeResult; + } + + @Override + public Type array(ArrayType array, Type elementType) { + if (array.containsNull()) { + return Types.ListType.ofOptional(getNextId(), elementType); + } else { + return Types.ListType.ofRequired(getNextId(), elementType); + } + } + + @Override + public Type map(MapType map, Type keyType, Type valueType) { + if (map.valueContainsNull()) { + return Types.MapType.ofOptional(getNextId(), getNextId(), keyType, valueType); + } else { + return Types.MapType.ofRequired(getNextId(), getNextId(), keyType, valueType); + } + } + + @SuppressWarnings("checkstyle:CyclomaticComplexity") + @Override + public Type atomic(DataType atomic) { + if (atomic instanceof BooleanType) { + return Types.BooleanType.get(); + + } else if ( + atomic instanceof IntegerType || + atomic instanceof ShortType || + atomic instanceof ByteType) { + return Types.IntegerType.get(); + + } else if (atomic instanceof LongType) { + return Types.LongType.get(); + + } else if (atomic instanceof FloatType) { + return Types.FloatType.get(); + + } else if (atomic instanceof DoubleType) { + return Types.DoubleType.get(); + + } else if ( + atomic instanceof StringType || + atomic instanceof CharType || + atomic instanceof VarcharType) { + return Types.StringType.get(); + + } else if (atomic instanceof DateType) { + return Types.DateType.get(); + + } else if (atomic instanceof TimestampType) { + return Types.TimestampType.withZone(); + + } else if (atomic instanceof DecimalType) { + return Types.DecimalType.of( + ((DecimalType) atomic).precision(), + ((DecimalType) atomic).scale()); + } else if (atomic instanceof BinaryType) { + return Types.BinaryType.get(); + } + + throw new UnsupportedOperationException( + "Not a supported type: " + atomic.catalogString()); + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkTypeVisitor.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkTypeVisitor.java new file mode 100644 index 00000000000..2407d9c77dd --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkTypeVisitor.java @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.iceberg.spark; + +import java.util.List; + +import org.apache.iceberg.relocated.com.google.common.collect.Lists; + +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.UserDefinedType; + +public class SparkTypeVisitor { + static T visit(DataType type, SparkTypeVisitor visitor) { + if (type instanceof StructType) { + StructField[] fields = ((StructType) type).fields(); + List fieldResults = Lists.newArrayListWithExpectedSize(fields.length); + + for (StructField field : fields) { + fieldResults.add(visitor.field( + field, + visit(field.dataType(), visitor))); + } + + return visitor.struct((StructType) type, fieldResults); + + } else if (type instanceof MapType) { + return visitor.map((MapType) type, + visit(((MapType) type).keyType(), visitor), + visit(((MapType) type).valueType(), visitor)); + + } else if (type instanceof ArrayType) { + return visitor.array( + (ArrayType) type, + visit(((ArrayType) type).elementType(), visitor)); + + } else if (type instanceof UserDefinedType) { + throw new UnsupportedOperationException( + "User-defined types are not supported"); + + } else { + return visitor.atomic(type); + } + } + + public T struct(StructType struct, List fieldResults) { + return null; + } + + public T field(StructField field, T typeResult) { + return null; + } + + public T array(ArrayType array, T elementResult) { + return null; + } + + public T map(MapType map, T keyResult, T valueResult) { + return null; + } + + public T atomic(DataType atomic) { + return null; + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkUtil.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkUtil.java new file mode 100644 index 00000000000..04c1b76428b --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkUtil.java @@ -0,0 +1,221 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.iceberg.spark; + +import org.apache.iceberg.Schema; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; + +public class SparkUtil { + + public static final String TIMESTAMP_WITHOUT_TIMEZONE_ERROR = String.format("Cannot handle timestamp without" + + " timezone fields in Spark. Spark does not natively support this type but if you would like to handle all" + + " timestamps as timestamp with timezone set '%s' to true. This will not change the underlying values stored" + + " but will change their displayed values in Spark. 
For more information please see" + + " https://docs.databricks.com/spark/latest/dataframes-datasets/dates-timestamps.html#ansi-sql-and" + + "-spark-sql-timestamps", SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE); + +// private static final String SPARK_CATALOG_CONF_PREFIX = "spark.sql.catalog"; +// // Format string used as the prefix for spark configuration keys to override hadoop configuration values +// // for Iceberg tables from a given catalog. These keys can be specified as `spark.sql.catalog.$catalogName.hadoop.*`, +// // similar to using `spark.hadoop.*` to override hadoop configurations globally for a given spark session. +// private static final String SPARK_CATALOG_HADOOP_CONF_OVERRIDE_FMT_STR = SPARK_CATALOG_CONF_PREFIX + ".%s.hadoop."; + + private SparkUtil() { + } + +// public static FileIO serializableFileIO(Table table) { +// if (table.io() instanceof HadoopConfigurable) { +// // we need to use Spark's SerializableConfiguration to avoid issues with Kryo serialization +// ((HadoopConfigurable) table.io()).serializeConfWith(conf -> new SerializableConfiguration(conf)::value); +// } +// +// return table.io(); +// } +// +// /** +// * Check whether the partition transforms in a spec can be used to write data. +// * +// * @param spec a PartitionSpec +// * @throws UnsupportedOperationException if the spec contains unknown partition transforms +// */ +// public static void validatePartitionTransforms(PartitionSpec spec) { +// if (spec.fields().stream().anyMatch(field -> field.transform() instanceof UnknownTransform)) { +// String unsupported = spec.fields().stream() +// .map(PartitionField::transform) +// .filter(transform -> transform instanceof UnknownTransform) +// .map(Transform::toString) +// .collect(Collectors.joining(", ")); +// +// throw new UnsupportedOperationException( +// String.format("Cannot write using unsupported transforms: %s", unsupported)); +// } +// } +// +// /** +// * A modified version of Spark's LookupCatalog.CatalogAndIdentifier.unapply +// * Attempts to find the catalog and identifier a multipart identifier represents +// * @param nameParts Multipart identifier representing a table +// * @return The CatalogPlugin and Identifier for the table +// */ +// public static Pair catalogAndIdentifier(List nameParts, +// Function catalogProvider, +// BiFunction identiferProvider, +// C currentCatalog, +// String[] currentNamespace) { +// Preconditions.checkArgument(!nameParts.isEmpty(), +// "Cannot determine catalog and identifier from empty name"); +// +// int lastElementIndex = nameParts.size() - 1; +// String name = nameParts.get(lastElementIndex); +// +// if (nameParts.size() == 1) { +// // Only a single element, use current catalog and namespace +// return Pair.of(currentCatalog, identiferProvider.apply(currentNamespace, name)); +// } else { +// C catalog = catalogProvider.apply(nameParts.get(0)); +// if (catalog == null) { +// // The first element was not a valid catalog, treat it like part of the namespace +// String[] namespace = nameParts.subList(0, lastElementIndex).toArray(new String[0]); +// return Pair.of(currentCatalog, identiferProvider.apply(namespace, name)); +// } else { +// // Assume the first element is a valid catalog +// String[] namespace = nameParts.subList(1, lastElementIndex).toArray(new String[0]); +// return Pair.of(catalog, identiferProvider.apply(namespace, name)); +// } +// } +// } + + /** + * Responsible for checking if the table schema has a timestamp without timezone column + * @param schema table schema to check if it 
contains a timestamp without timezone column + * @return boolean indicating if the schema passed in has a timestamp field without a timezone + */ + public static boolean hasTimestampWithoutZone(Schema schema) { + return TypeUtil.find(schema, t -> Types.TimestampType.withoutZone().equals(t)) != null; + } + +// /** +// * Checks whether timestamp types for new tables should be stored with timezone info. +// *

+// * The default value is false and all timestamp fields are stored as {@link Types.TimestampType#withZone()}. +// * If enabled, all timestamp fields in new tables will be stored as {@link Types.TimestampType#withoutZone()}. +// * +// * @param sessionConf a Spark runtime config +// * @return true if timestamp types for new tables should be stored with timezone info +// */ +// public static boolean useTimestampWithoutZoneInNewTables(RuntimeConfig sessionConf) { +// String sessionConfValue = sessionConf.get(SparkSQLProperties.USE_TIMESTAMP_WITHOUT_TIME_ZONE_IN_NEW_TABLES, null); +// if (sessionConfValue != null) { +// return Boolean.parseBoolean(sessionConfValue); +// } +// return SparkSQLProperties.USE_TIMESTAMP_WITHOUT_TIME_ZONE_IN_NEW_TABLES_DEFAULT; +// } +// +// /** +// * Pulls any Catalog specific overrides for the Hadoop conf from the current SparkSession, which can be +// * set via `spark.sql.catalog.$catalogName.hadoop.*` +// * +// * Mirrors the override of hadoop configurations for a given spark session using `spark.hadoop.*`. +// * +// * The SparkCatalog allows for hadoop configurations to be overridden per catalog, by setting +// * them on the SQLConf, where the following will add the property "fs.default.name" with value +// * "hdfs://hanksnamenode:8020" to the catalog's hadoop configuration. +// * SparkSession.builder() +// * .config(s"spark.sql.catalog.$catalogName.hadoop.fs.default.name", "hdfs://hanksnamenode:8020") +// * .getOrCreate() +// * @param spark The current Spark session +// * @param catalogName Name of the catalog to find overrides for. +// * @return the Hadoop Configuration that should be used for this catalog, with catalog specific overrides applied. +// */ +// public static Configuration hadoopConfCatalogOverrides(SparkSession spark, String catalogName) { +// // Find keys for the catalog intended to be hadoop configurations +// final String hadoopConfCatalogPrefix = hadoopConfPrefixForCatalog(catalogName); +// final Configuration conf = spark.sessionState().newHadoopConf(); +// spark.sqlContext().conf().settings().forEach((k, v) -> { +// // These checks are copied from `spark.sessionState().newHadoopConfWithOptions()`, which we +// // avoid using to not have to convert back and forth between scala / java map types. +// if (v != null && k != null && k.startsWith(hadoopConfCatalogPrefix)) { +// conf.set(k.substring(hadoopConfCatalogPrefix.length()), v); +// } +// }); +// return conf; +// } +// +// private static String hadoopConfPrefixForCatalog(String catalogName) { +// return String.format(SPARK_CATALOG_HADOOP_CONF_OVERRIDE_FMT_STR, catalogName); +// } +// +// /** +// * Get a List of Spark filter Expression. +// * +// * @param schema table schema +// * @param filters filters in the format of a Map, where key is one of the table column name, +// * and value is the specific value to be filtered on the column. +// * @return a List of filters in the format of Spark Expression. 
+// */ +// public static List partitionMapToExpression(StructType schema, +// Map filters) { +// List filterExpressions = Lists.newArrayList(); +// for (Map.Entry entry : filters.entrySet()) { +// try { +// int index = schema.fieldIndex(entry.getKey()); +// DataType dataType = schema.fields()[index].dataType(); +// BoundReference ref = new BoundReference(index, dataType, true); +// switch (dataType.typeName()) { +// case "integer": +// filterExpressions.add(new EqualTo(ref, +// Literal.create(Integer.parseInt(entry.getValue()), DataTypes.IntegerType))); +// break; +// case "string": +// filterExpressions.add(new EqualTo(ref, Literal.create(entry.getValue(), DataTypes.StringType))); +// break; +// case "short": +// filterExpressions.add(new EqualTo(ref, +// Literal.create(Short.parseShort(entry.getValue()), DataTypes.ShortType))); +// break; +// case "long": +// filterExpressions.add(new EqualTo(ref, +// Literal.create(Long.parseLong(entry.getValue()), DataTypes.LongType))); +// break; +// case "float": +// filterExpressions.add(new EqualTo(ref, +// Literal.create(Float.parseFloat(entry.getValue()), DataTypes.FloatType))); +// break; +// case "double": +// filterExpressions.add(new EqualTo(ref, +// Literal.create(Double.parseDouble(entry.getValue()), DataTypes.DoubleType))); +// break; +// case "date": +// filterExpressions.add(new EqualTo(ref, +// Literal.create(new Date(DateTime.parse(entry.getValue()).getMillis()), DataTypes.DateType))); +// break; +// case "timestamp": +// filterExpressions.add(new EqualTo(ref, +// Literal.create(new Timestamp(DateTime.parse(entry.getValue()).getMillis()), DataTypes.TimestampType))); +// break; +// default: +// throw new IllegalStateException("Unexpected data type in partition filters: " + dataType); +// } +// } catch (IllegalArgumentException e) { +// // ignore if filter is not on table columns +// } +// } +// +// return filterExpressions; +// } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/TypeToSparkType.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/TypeToSparkType.java new file mode 100644 index 00000000000..862f5a15c6f --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/TypeToSparkType.java @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.iceberg.spark; + +import java.util.List; + +import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; + +import org.apache.spark.sql.types.ArrayType$; +import org.apache.spark.sql.types.BinaryType$; +import org.apache.spark.sql.types.BooleanType$; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType$; +import org.apache.spark.sql.types.DecimalType$; +import org.apache.spark.sql.types.DoubleType$; +import org.apache.spark.sql.types.FloatType$; +import org.apache.spark.sql.types.IntegerType$; +import org.apache.spark.sql.types.LongType$; +import org.apache.spark.sql.types.MapType$; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StringType$; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType$; +import org.apache.spark.sql.types.TimestampType$; + +public class TypeToSparkType extends TypeUtil.SchemaVisitor { + TypeToSparkType() { + } + + @Override + public DataType schema(Schema schema, DataType structType) { + return structType; + } + + @Override + public DataType struct(Types.StructType struct, List fieldResults) { + List fields = struct.fields(); + + List sparkFields = Lists.newArrayListWithExpectedSize(fieldResults.size()); + for (int i = 0; i < fields.size(); i += 1) { + Types.NestedField field = fields.get(i); + DataType type = fieldResults.get(i); + StructField sparkField = StructField.apply( + field.name(), type, field.isOptional(), Metadata.empty()); + if (field.doc() != null) { + sparkField = sparkField.withComment(field.doc()); + } + sparkFields.add(sparkField); + } + + return StructType$.MODULE$.apply(sparkFields); + } + + @Override + public DataType field(Types.NestedField field, DataType fieldResult) { + return fieldResult; + } + + @Override + public DataType list(Types.ListType list, DataType elementResult) { + return ArrayType$.MODULE$.apply(elementResult, list.isElementOptional()); + } + + @Override + public DataType map(Types.MapType map, DataType keyResult, DataType valueResult) { + return MapType$.MODULE$.apply(keyResult, valueResult, map.isValueOptional()); + } + + @Override + public DataType primitive(Type.PrimitiveType primitive) { + switch (primitive.typeId()) { + case BOOLEAN: + return BooleanType$.MODULE$; + case INTEGER: + return IntegerType$.MODULE$; + case LONG: + return LongType$.MODULE$; + case FLOAT: + return FloatType$.MODULE$; + case DOUBLE: + return DoubleType$.MODULE$; + case DATE: + return DateType$.MODULE$; + case TIME: + throw new UnsupportedOperationException( + "Spark does not support time fields"); + case TIMESTAMP: + return TimestampType$.MODULE$; + case STRING: + return StringType$.MODULE$; + case UUID: + // use String + return StringType$.MODULE$; + case FIXED: + return BinaryType$.MODULE$; + case BINARY: + return BinaryType$.MODULE$; + case DECIMAL: + Types.DecimalType decimal = (Types.DecimalType) primitive; + return DecimalType$.MODULE$.apply(decimal.precision(), decimal.scale()); + default: + throw new UnsupportedOperationException( + "Cannot convert unknown type to Spark: " + primitive); + } + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/BaseDataReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/BaseDataReader.java new file mode 100644 index 
00000000000..7b638625bf3 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/BaseDataReader.java @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.iceberg.spark.source; + +import java.io.Closeable; +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import org.apache.avro.generic.GenericData; +import org.apache.avro.util.Utf8; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.encryption.EncryptedFiles; +import org.apache.iceberg.encryption.EncryptedInputFile; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.types.Types.StructType; +import org.apache.iceberg.util.ByteBuffers; +import org.apache.iceberg.util.PartitionUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Base class of Spark readers. + * + * @param is the Java class returned by this reader whose objects contain one or more rows. 
+ */ +abstract class BaseDataReader implements Closeable { + private static final Logger LOG = LoggerFactory.getLogger(BaseDataReader.class); + + private final Table table; + private final Iterator tasks; + private final Map inputFiles; + + private CloseableIterator currentIterator; + private T current = null; + private FileScanTask currentTask = null; + + BaseDataReader(Table table, CombinedScanTask task) { + this.table = table; + this.tasks = task.files().iterator(); + Map keyMetadata = Maps.newHashMap(); + task.files().stream() + .flatMap(fileScanTask -> Stream.concat(Stream.of(fileScanTask.file()), fileScanTask.deletes().stream())) + .forEach(file -> keyMetadata.put(file.path().toString(), file.keyMetadata())); + Stream encrypted = keyMetadata.entrySet().stream() + .map(entry -> EncryptedFiles.encryptedInput(table.io().newInputFile(entry.getKey()), entry.getValue())); + + // decrypt with the batch call to avoid multiple RPCs to a key server, if possible + Iterable decryptedFiles = table.encryption().decrypt(encrypted::iterator); + + Map files = Maps.newHashMapWithExpectedSize(task.files().size()); + decryptedFiles.forEach(decrypted -> files.putIfAbsent(decrypted.location(), decrypted)); + this.inputFiles = ImmutableMap.copyOf(files); + + this.currentIterator = CloseableIterator.empty(); + } + + protected Table table() { + return table; + } + + public boolean next() throws IOException { + try { + while (true) { + if (currentIterator.hasNext()) { + this.current = currentIterator.next(); + return true; + } else if (tasks.hasNext()) { + this.currentIterator.close(); + this.currentTask = tasks.next(); + this.currentIterator = open(currentTask); + } else { + this.currentIterator.close(); + return false; + } + } + } catch (IOException | RuntimeException e) { + if (currentTask != null && !currentTask.isDataTask()) { + LOG.error("Error reading file: {}", getInputFile(currentTask).location(), e); + } + throw e; + } + } + + public T get() { + return current; + } + + abstract CloseableIterator open(FileScanTask task); + + @Override + public void close() throws IOException { + InputFileBlockHolder.unset(); + + // close the current iterator + this.currentIterator.close(); + + // exhaust the task iterator + while (tasks.hasNext()) { + tasks.next(); + } + } + + protected InputFile getInputFile(FileScanTask task) { + Preconditions.checkArgument(!task.isDataTask(), "Invalid task type"); + return inputFiles.get(task.file().path().toString()); + } + + protected InputFile getInputFile(String location) { + return inputFiles.get(location); + } + + protected Map constantsMap(FileScanTask task, Schema readSchema) { + if (readSchema.findField(MetadataColumns.PARTITION_COLUMN_ID) != null) { + StructType partitionType = Partitioning.partitionType(table); + return PartitionUtil.constantsMap(task, partitionType, BaseDataReader::convertConstant); + } else { + return PartitionUtil.constantsMap(task, BaseDataReader::convertConstant); + } + } + + protected static Object convertConstant(Type type, Object value) { + if (value == null) { + return null; + } + + switch (type.typeId()) { + case DECIMAL: + return Decimal.apply((BigDecimal) value); + case STRING: + if (value instanceof Utf8) { + Utf8 utf8 = (Utf8) value; + return UTF8String.fromBytes(utf8.getBytes(), 0, utf8.getByteLength()); + } + return UTF8String.fromString(value.toString()); + case FIXED: + if (value instanceof byte[]) { + return value; + } else if (value instanceof GenericData.Fixed) { + return ((GenericData.Fixed) value).bytes(); + } + return 
ByteBuffers.toByteArray((ByteBuffer) value); + case BINARY: + return ByteBuffers.toByteArray((ByteBuffer) value); + case STRUCT: + StructType structType = (StructType) type; + + if (structType.fields().isEmpty()) { + return new GenericInternalRow(); + } + + List fields = structType.fields(); + Object[] values = new Object[fields.size()]; + StructLike struct = (StructLike) value; + + for (int index = 0; index < fields.size(); index++) { + NestedField field = fields.get(index); + Type fieldType = field.type(); + values[index] = convertConstant(fieldType, struct.get(index, fieldType.typeId().javaClass())); + } + + return new GenericInternalRow(values); + default: + } + return value; + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuBatchDataReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuBatchDataReader.java new file mode 100644 index 00000000000..3dd8edac776 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuBatchDataReader.java @@ -0,0 +1,151 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.iceberg.spark.source; + +import java.util.Map; +import java.util.Set; + +import com.nvidia.spark.rapids.GpuMetric; +import com.nvidia.spark.rapids.iceberg.data.GpuDeleteFilter; +import com.nvidia.spark.rapids.iceberg.orc.GpuORC; +import com.nvidia.spark.rapids.iceberg.parquet.GpuParquet; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.mapping.NameMappingParser; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.types.TypeUtil; + +import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** GPU version of Apache Iceberg's BatchDataReader */ +class GpuBatchDataReader extends BaseDataReader { + private final Schema expectedSchema; + private final String nameMapping; + private final boolean caseSensitive; + private final Configuration conf; + private final int maxBatchSizeRows; + private final long maxBatchSizeBytes; + private final String parquetDebugDumpPrefix; + private final scala.collection.immutable.Map metrics; + + GpuBatchDataReader(CombinedScanTask task, Table table, Schema expectedSchema, boolean caseSensitive, + Configuration conf, int maxBatchSizeRows, long maxBatchSizeBytes, + String parquetDebugDumpPrefix, + 
scala.collection.immutable.Map metrics) { + super(table, task); + this.expectedSchema = expectedSchema; + this.nameMapping = table.properties().get(TableProperties.DEFAULT_NAME_MAPPING); + this.caseSensitive = caseSensitive; + this.conf = conf; + this.maxBatchSizeRows = maxBatchSizeRows; + this.maxBatchSizeBytes = maxBatchSizeBytes; + this.parquetDebugDumpPrefix = parquetDebugDumpPrefix; + this.metrics = metrics; + } + + @Override + CloseableIterator open(FileScanTask task) { + DataFile file = task.file(); + + // update the current file for Spark's filename() function + InputFileBlockHolder.set(file.path().toString(), task.start(), task.length()); + + Map idToConstant = constantsMap(task, expectedSchema); + + CloseableIterable iter; + InputFile location = getInputFile(task); + Preconditions.checkNotNull(location, "Could not find InputFile associated with FileScanTask"); + if (task.file().format() == FileFormat.PARQUET) { + GpuDeleteFilter deleteFilter = deleteFilter(task); + // get required schema for filtering out equality-delete rows in case equality-delete uses columns are + // not selected. + Schema requiredSchema = requiredSchema(deleteFilter); + + GpuParquet.ReadBuilder builder = GpuParquet.read(location) + .project(requiredSchema) + .split(task.start(), task.length()) + .constants(idToConstant) + .deleteFilter(deleteFilter) + .filter(task.residual()) + .caseSensitive(caseSensitive) + .withConfiguration(conf) + .withMaxBatchSizeRows(maxBatchSizeRows) + .withMaxBatchSizeBytes(maxBatchSizeBytes) + .withDebugDumpPrefix(parquetDebugDumpPrefix) + .withMetrics(metrics); + + if (nameMapping != null) { + builder.withNameMapping(NameMappingParser.fromJson(nameMapping)); + } + + iter = builder.build(); + } else if (task.file().format() == FileFormat.ORC) { + Set constantFieldIds = idToConstant.keySet(); + Set metadataFieldIds = MetadataColumns.metadataFieldIds(); + Sets.SetView constantAndMetadataFieldIds = Sets.union(constantFieldIds, metadataFieldIds); + Schema schemaWithoutConstantAndMetadataFields = TypeUtil.selectNot(expectedSchema, constantAndMetadataFieldIds); + GpuORC.ReadBuilder builder = GpuORC.read(location) + .project(schemaWithoutConstantAndMetadataFields) + .split(task.start(), task.length()) + .readerExpectedSchema(expectedSchema) + .constants(idToConstant) + .filter(task.residual()) + .caseSensitive(caseSensitive); + + if (nameMapping != null) { + builder.withNameMapping(NameMappingParser.fromJson(nameMapping)); + } + + iter = builder.build(); + } else { + throw new UnsupportedOperationException( + "Format: " + task.file().format() + " not supported for batched reads"); + } + return iter.iterator(); + } + + private GpuDeleteFilter deleteFilter(FileScanTask task) { + if (task.deletes().isEmpty()) { + return null; + } + return new GpuDeleteFilter( + task.file().path().toString(), + task.deletes(), + table().schema(), + expectedSchema); + } + + private Schema requiredSchema(GpuDeleteFilter deleteFilter) { + if (deleteFilter != null && deleteFilter.hasEqDeletes()) { + return deleteFilter.requiredSchema(); + } else { + return expectedSchema; + } + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuIcebergReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuIcebergReader.java new file mode 100644 index 00000000000..8cc5b7aaafc --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuIcebergReader.java @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.iceberg.spark.source; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; + +import ai.rapids.cudf.Scalar; +import com.nvidia.spark.rapids.GpuColumnVector; +import com.nvidia.spark.rapids.GpuScalar; +import com.nvidia.spark.rapids.iceberg.data.GpuDeleteFilter; +import com.nvidia.spark.rapids.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; + +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +public class GpuIcebergReader implements CloseableIterator { + private final Schema expectedSchema; + private final PartitionReader partReader; + private final GpuDeleteFilter deleteFilter; + private final Map idToConstant; + private boolean needNext = true; + private boolean isBatchPending; + + public GpuIcebergReader(Schema expectedSchema, + PartitionReader partReader, + GpuDeleteFilter deleteFilter, + Map idToConstant) { + this.expectedSchema = expectedSchema; + this.partReader = partReader; + this.deleteFilter = deleteFilter; + this.idToConstant = idToConstant; + } + + @Override + public void close() throws IOException { + partReader.close(); + } + + @Override + public boolean hasNext() { + if (needNext) { + try { + isBatchPending = partReader.next(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + needNext = false; + } + return isBatchPending; + } + + @Override + public ColumnarBatch next() { + if (!hasNext()) { + throw new NoSuchElementException("No more batches to iterate"); + } + isBatchPending = false; + needNext = true; + try (ColumnarBatch batch = partReader.get()) { + if (deleteFilter != null) { + throw new UnsupportedOperationException("Delete filter is not supported"); + } + return addConstantColumns(batch); + } + } + + private ColumnarBatch addConstantColumns(ColumnarBatch batch) { + ColumnVector[] columns = new ColumnVector[expectedSchema.columns().size()]; + ColumnarBatch result = null; + final ConstantDetector constantDetector = new ConstantDetector(idToConstant); + try { + int inputIdx = 0; + int outputIdx = 0; + for (Types.NestedField field : expectedSchema.columns()) { + // need to check for key presence since associated value could be null + if (idToConstant.containsKey(field.fieldId())) { + DataType type = SparkSchemaUtil.convert(field.type()); + try (Scalar scalar = GpuScalar.from(idToConstant.get(field.fieldId()), type)) { + columns[outputIdx++] = GpuColumnVector.from(scalar, batch.numRows(), type); + } + } else { + if (TypeUtil.visit(field.type(), constantDetector)) { + throw new 
UnsupportedOperationException("constants not implemented for nested field"); + } + GpuColumnVector gpuColumn = (GpuColumnVector) batch.column(inputIdx++); + columns[outputIdx++] = gpuColumn.incRefCount(); + } + } + if (inputIdx != batch.numCols()) { + throw new IllegalStateException("Did not consume all input batch columns"); + } + result = new ColumnarBatch(columns, batch.numRows()); + } finally { + if (result == null) { + // TODO: Update safeClose to be reusable by Java code + for (ColumnVector c : columns) { + if (c != null) { + c.close(); + } + } + } + } + return result; + } + + private static class ConstantDetector extends TypeUtil.SchemaVisitor { + private final Map idToConstant; + + ConstantDetector(Map idToConstant) { + this.idToConstant = idToConstant; + } + + @Override + public Boolean schema(Schema schema, Boolean structResult) { + return structResult; + } + + @Override + public Boolean struct(Types.StructType struct, List fieldResults) { + return fieldResults.stream().anyMatch(b -> b); + } + + @Override + public Boolean field(Types.NestedField field, Boolean fieldResult) { + return idToConstant.containsKey(field.fieldId()); + } + + @Override + public Boolean list(Types.ListType list, Boolean elementResult) { + return list.fields().stream() + .anyMatch(f -> idToConstant.containsKey(f.fieldId())); + } + + @Override + public Boolean map(Types.MapType map, Boolean keyResult, Boolean valueResult) { + return map.fields().stream() + .anyMatch(f -> idToConstant.containsKey(f.fieldId())); + } + + @Override + public Boolean primitive(Type.PrimitiveType primitive) { + return false; + } + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkBatchQueryScan.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkBatchQueryScan.java new file mode 100644 index 00000000000..6aa2b368ae1 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkBatchQueryScan.java @@ -0,0 +1,289 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.iceberg.spark.source; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +import com.nvidia.spark.rapids.RapidsConf; +import com.nvidia.spark.rapids.iceberg.spark.Spark3Util; +import com.nvidia.spark.rapids.iceberg.spark.SparkFilters; +import com.nvidia.spark.rapids.iceberg.spark.SparkReadConf; +import com.nvidia.spark.rapids.iceberg.spark.SparkSchemaUtil; +import com.nvidia.spark.rapids.shims.ShimSupportsRuntimeFiltering; +import org.apache.commons.lang3.reflect.FieldUtils; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableScan; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Binder; +import org.apache.iceberg.expressions.Evaluator; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.expressions.Projections; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.util.SnapshotUtil; +import org.apache.iceberg.util.TableScanUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.sources.Filter; + +/** + * GPU-accelerated Iceberg batch scan. + * This is derived from Apache Iceberg's BatchQueryScan class. 
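+ * Runtime filters pushed down by Spark are converted to Iceberg expressions, bound against the expected schema, and used to prune the cached file scan tasks before read tasks are planned.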
+ */ +public class GpuSparkBatchQueryScan extends GpuSparkScan implements ShimSupportsRuntimeFiltering { + private static final Logger LOG = LoggerFactory.getLogger(GpuSparkBatchQueryScan.class); + + private final TableScan scan; + private final Long snapshotId; + private final Long startSnapshotId; + private final Long endSnapshotId; + private final Long asOfTimestamp; + private final List runtimeFilterExpressions; + + private Set specIds = null; // lazy cache of scanned spec IDs + private List files = null; // lazy cache of files + private List tasks = null; // lazy cache of tasks + + public static GpuSparkBatchQueryScan fromCpu(Scan cpuInstance, RapidsConf rapidsConf) throws IllegalAccessException { + Table table = (Table) FieldUtils.readField(cpuInstance, "table", true); + TableScan scan = (TableScan) FieldUtils.readField(cpuInstance, "scan", true); + SparkReadConf readConf = SparkReadConf.fromReflect(FieldUtils.readField(cpuInstance, "readConf", true)); + Schema expectedSchema = (Schema) FieldUtils.readField(cpuInstance, "expectedSchema", true); + List filters = (List) FieldUtils.readField(cpuInstance, "filterExpressions", true); + return new GpuSparkBatchQueryScan(SparkSession.active(), table, scan, readConf, expectedSchema, filters, rapidsConf); + } + + GpuSparkBatchQueryScan(SparkSession spark, Table table, TableScan scan, SparkReadConf readConf, + Schema expectedSchema, List filters, RapidsConf rapidsConf) { + + super(spark, table, readConf, expectedSchema, filters, rapidsConf); + + this.scan = scan; + this.snapshotId = readConf.snapshotId(); + this.startSnapshotId = readConf.startSnapshotId(); + this.endSnapshotId = readConf.endSnapshotId(); + this.asOfTimestamp = readConf.asOfTimestamp(); + this.runtimeFilterExpressions = Lists.newArrayList(); + + if (scan == null) { + this.specIds = Collections.emptySet(); + this.files = Collections.emptyList(); + this.tasks = Collections.emptyList(); + } + } + + Long snapshotId() { + return snapshotId; + } + + private Set specIds() { + if (specIds == null) { + Set specIdSet = Sets.newHashSet(); + for (FileScanTask file : files()) { + specIdSet.add(file.spec().specId()); + } + this.specIds = specIdSet; + } + + return specIds; + } + + private List files() { + if (files == null) { + try (CloseableIterable filesIterable = scan.planFiles()) { + this.files = Lists.newArrayList(filesIterable); + } catch (IOException e) { + throw new UncheckedIOException("Failed to close table scan: " + scan, e); + } + } + + return files; + } + + @Override + protected List tasks() { + if (tasks == null) { + CloseableIterable splitFiles = TableScanUtil.splitFiles( + CloseableIterable.withNoopClose(files()), + scan.targetSplitSize()); + CloseableIterable scanTasks = TableScanUtil.planTasks( + splitFiles, scan.targetSplitSize(), + scan.splitLookback(), scan.splitOpenFileCost()); + tasks = Lists.newArrayList(scanTasks); + } + + return tasks; + } + + @Override + public NamedReference[] filterAttributes() { + Set partitionFieldSourceIds = Sets.newHashSet(); + + for (Integer specId : specIds()) { + PartitionSpec spec = table().specs().get(specId); + for (PartitionField field : spec.fields()) { + partitionFieldSourceIds.add(field.sourceId()); + } + } + + Map quotedNameById = SparkSchemaUtil.indexQuotedNameById(expectedSchema()); + + // the optimizer will look for an equality condition with filter attributes in a join + // as the scan has been already planned, filtering can only be done on projected attributes + // that's why only partition source fields that are part of the 
read schema can be reported + + return partitionFieldSourceIds.stream() + .filter(fieldId -> expectedSchema().findField(fieldId) != null) + .map(fieldId -> Spark3Util.toNamedReference(quotedNameById.get(fieldId))) + .toArray(NamedReference[]::new); + } + + @Override + public void filter(Filter[] filters) { + Expression runtimeFilterExpr = convertRuntimeFilters(filters); + + if (runtimeFilterExpr != Expressions.alwaysTrue()) { + Map evaluatorsBySpecId = Maps.newHashMap(); + + for (Integer specId : specIds()) { + PartitionSpec spec = table().specs().get(specId); + Expression inclusiveExpr = Projections.inclusive(spec, caseSensitive()).project(runtimeFilterExpr); + Evaluator inclusive = new Evaluator(spec.partitionType(), inclusiveExpr); + evaluatorsBySpecId.put(specId, inclusive); + } + + LOG.info("Trying to filter {} files using runtime filter {}", files().size(), runtimeFilterExpr); + + List filteredFiles = files().stream() + .filter(file -> { + Evaluator evaluator = evaluatorsBySpecId.get(file.spec().specId()); + return evaluator.eval(file.file().partition()); + }) + .collect(Collectors.toList()); + + LOG.info("{}/{} files matched runtime filter {}", filteredFiles.size(), files().size(), runtimeFilterExpr); + + // don't invalidate tasks if the runtime filter had no effect to avoid planning splits again + if (filteredFiles.size() < files().size()) { + this.specIds = null; + this.files = filteredFiles; + this.tasks = null; + } + + // save the evaluated filter for equals/hashCode + runtimeFilterExpressions.add(runtimeFilterExpr); + } + } + + // at this moment, Spark can only pass IN filters for a single attribute + // if there are multiple filter attributes, Spark will pass two separate IN filters + private Expression convertRuntimeFilters(Filter[] filters) { + Expression runtimeFilterExpr = Expressions.alwaysTrue(); + + for (Filter filter : filters) { + Expression expr = SparkFilters.convert(filter); + if (expr != null) { + try { + Binder.bind(expectedSchema().asStruct(), expr, caseSensitive()); + runtimeFilterExpr = Expressions.and(runtimeFilterExpr, expr); + } catch (ValidationException e) { + LOG.warn("Failed to bind {} to expected schema, skipping runtime filter", expr, e); + } + } else { + LOG.warn("Unsupported runtime filter {}", filter); + } + } + + return runtimeFilterExpr; + } + + @Override + public Statistics estimateStatistics() { + if (scan == null) { + return estimateStatistics(null); + + } else if (snapshotId != null) { + Snapshot snapshot = table().snapshot(snapshotId); + return estimateStatistics(snapshot); + + } else if (asOfTimestamp != null) { + long snapshotIdAsOfTime = SnapshotUtil.snapshotIdAsOfTime(table(), asOfTimestamp); + Snapshot snapshot = table().snapshot(snapshotIdAsOfTime); + return estimateStatistics(snapshot); + + } else { + Snapshot snapshot = table().currentSnapshot(); + return estimateStatistics(snapshot); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + GpuSparkBatchQueryScan that = (GpuSparkBatchQueryScan) o; + return table().name().equals(that.table().name()) && + readSchema().equals(that.readSchema()) && // compare Spark schemas to ignore field ids + filterExpressions().toString().equals(that.filterExpressions().toString()) && + runtimeFilterExpressions.toString().equals(that.runtimeFilterExpressions.toString()) && + Objects.equals(snapshotId, that.snapshotId) && + Objects.equals(startSnapshotId, that.startSnapshotId) && + 
Objects.equals(endSnapshotId, that.endSnapshotId) && + Objects.equals(asOfTimestamp, that.asOfTimestamp); + } + + @Override + public int hashCode() { + return Objects.hash( + table().name(), readSchema(), filterExpressions().toString(), runtimeFilterExpressions.toString(), + snapshotId, startSnapshotId, endSnapshotId, asOfTimestamp); + } + + @Override + public String toString() { + return String.format( + "IcebergScan(table=%s, type=%s, filters=%s, runtimeFilters=%s, caseSensitive=%s)", + table(), expectedSchema().asStruct(), filterExpressions(), runtimeFilterExpressions, caseSensitive()); + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java new file mode 100644 index 00000000000..b45e1a6ae03 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java @@ -0,0 +1,289 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.iceberg.spark.source; + +import java.io.Serializable; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import com.nvidia.spark.rapids.GpuMetric; +import com.nvidia.spark.rapids.RapidsConf; +import com.nvidia.spark.rapids.ScanWithMetricsWrapper; +import com.nvidia.spark.rapids.iceberg.spark.Spark3Util; +import com.nvidia.spark.rapids.iceberg.spark.SparkReadConf; +import com.nvidia.spark.rapids.iceberg.spark.SparkSchemaUtil; +import com.nvidia.spark.rapids.iceberg.spark.SparkUtil; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SchemaParser; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotSummary; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.hadoop.HadoopInputFile; +import org.apache.iceberg.hadoop.Util; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.util.PropertyUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.SupportsReportStatistics; +import org.apache.spark.sql.connector.read.streaming.MicroBatchStream; +import 
org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.util.SerializableConfiguration; + +/** + * GPU-accelerated Iceberg Scan. + * This is derived from Apache Iceberg's SparkScan class. + */ +abstract class GpuSparkScan extends ScanWithMetricsWrapper + implements Scan, SupportsReportStatistics { + private static final Logger LOG = LoggerFactory.getLogger(GpuSparkScan.class); + + private final JavaSparkContext sparkContext; + private final Table table; + private final SparkReadConf readConf; + private final boolean caseSensitive; + private final Schema expectedSchema; + private final List filterExpressions; + private final boolean readTimestampWithoutZone; + private final RapidsConf rapidsConf; + + // lazy variables + private StructType readSchema = null; + + GpuSparkScan(SparkSession spark, Table table, SparkReadConf readConf, + Schema expectedSchema, List filters, + RapidsConf rapidsConf) { + + SparkSchemaUtil.validateMetadataColumnReferences(table.schema(), expectedSchema); + + this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + this.table = table; + this.readConf = readConf; + this.caseSensitive = readConf.caseSensitive(); + this.expectedSchema = expectedSchema; + this.filterExpressions = filters != null ? filters : Collections.emptyList(); + this.readTimestampWithoutZone = readConf.handleTimestampWithoutZone(); + this.rapidsConf = rapidsConf; + } + + protected Table table() { + return table; + } + + protected boolean caseSensitive() { + return caseSensitive; + } + + protected Schema expectedSchema() { + return expectedSchema; + } + + protected List filterExpressions() { + return filterExpressions; + } + + protected abstract List tasks(); + + @Override + public Batch toBatch() { + return new SparkBatch(sparkContext, table, readConf, tasks(), expectedSchema, + rapidsConf, metrics()); + } + + @Override + public MicroBatchStream toMicroBatchStream(String checkpointLocation) { + throw new IllegalStateException("Unexpected micro batch stream read"); + } + + @Override + public StructType readSchema() { + if (readSchema == null) { + Preconditions.checkArgument(readTimestampWithoutZone || !SparkUtil.hasTimestampWithoutZone(expectedSchema), + SparkUtil.TIMESTAMP_WITHOUT_TIMEZONE_ERROR); + this.readSchema = SparkSchemaUtil.convert(expectedSchema); + } + return readSchema; + } + + @Override + public Statistics estimateStatistics() { + return estimateStatistics(table.currentSnapshot()); + } + + protected Statistics estimateStatistics(Snapshot snapshot) { + // its a fresh table, no data + if (snapshot == null) { + return new Stats(0L, 0L); + } + + // estimate stats using snapshot summary only for partitioned tables (metadata tables are unpartitioned) + if (!table.spec().isUnpartitioned() && filterExpressions.isEmpty()) { + LOG.debug("using table metadata to estimate table statistics"); + long totalRecords = PropertyUtil.propertyAsLong(snapshot.summary(), + SnapshotSummary.TOTAL_RECORDS_PROP, Long.MAX_VALUE); + return new Stats( + SparkSchemaUtil.estimateSize(readSchema(), totalRecords), + totalRecords); + } + + long numRows = 0L; + + for (CombinedScanTask task : tasks()) { + for (FileScanTask file : task.files()) { + // TODO: if possible, take deletes also into consideration. 
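+ // estimate rows for a partial file scan by scaling the file's record count by the fraction of its bytes covered by this task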
+ double fractionOfFileScanned = ((double) file.length()) / file.file().fileSizeInBytes(); + numRows += (fractionOfFileScanned * file.file().recordCount()); + } + } + + long sizeInBytes = SparkSchemaUtil.estimateSize(readSchema(), numRows); + return new Stats(sizeInBytes, numRows); + } + + @Override + public String description() { + String filters = filterExpressions.stream().map(Spark3Util::describe).collect(Collectors.joining(", ")); + return String.format("%s [filters=%s]", table, filters); + } + + static class ReaderFactory implements PartitionReaderFactory { + @Override + public PartitionReader createReader(InputPartition partition) { + throw new IllegalStateException("non-columnar read"); + } + + @Override + public PartitionReader createColumnarReader(InputPartition partition) { + if (partition instanceof ReadTask) { + return new BatchReader((ReadTask) partition); + } else { + throw new UnsupportedOperationException("Incorrect input partition type: " + partition); + } + } + + @Override + public boolean supportColumnarReads(InputPartition partition) { + return true; + } + } + + private static class BatchReader extends GpuBatchDataReader implements PartitionReader { + BatchReader(ReadTask task) { + super(task.task, task.table(), task.expectedSchema(), task.isCaseSensitive(), + task.getConfiguration(), task.getMaxBatchSizeRows(), task.getMaxBatchSizeBytes(), + task.getParquetDebugDumpPrefix(), task.getMetrics()); + } + } + + static class ReadTask implements InputPartition, Serializable { + private final CombinedScanTask task; + private final Broadcast tableBroadcast; + private final String expectedSchemaString; + private final boolean caseSensitive; + + private final Broadcast confBroadcast; + private final int maxBatchSizeRows; + private final long maxBatchSizeBytes; + private final String parquetDebugDumpPrefix; + private final scala.collection.immutable.Map metrics; + + private transient Schema expectedSchema = null; + private transient String[] preferredLocations = null; + + ReadTask(CombinedScanTask task, Broadcast
tableBroadcast, String expectedSchemaString, + boolean caseSensitive, boolean localityPreferred, RapidsConf rapidsConf, + Broadcast confBroadcast, + scala.collection.immutable.Map metrics) { + this.task = task; + this.tableBroadcast = tableBroadcast; + this.expectedSchemaString = expectedSchemaString; + this.caseSensitive = caseSensitive; + if (localityPreferred) { + Table table = tableBroadcast.value(); + this.preferredLocations = Util.blockLocations(table.io(), task); + } else { + this.preferredLocations = HadoopInputFile.NO_LOCATION_PREFERENCE; + } + this.confBroadcast = confBroadcast; + this.maxBatchSizeRows = rapidsConf.maxReadBatchSizeRows(); + this.maxBatchSizeBytes = rapidsConf.maxReadBatchSizeBytes(); + this.parquetDebugDumpPrefix = rapidsConf.parquetDebugDumpPrefix(); + this.metrics = metrics; + } + + @Override + public String[] preferredLocations() { + return preferredLocations; + } + + public Collection files() { + return task.files(); + } + + public Table table() { + return tableBroadcast.value(); + } + + public boolean isCaseSensitive() { + return caseSensitive; + } + + public Configuration getConfiguration() { + return confBroadcast.value().value(); + } + + public int getMaxBatchSizeRows() { + return maxBatchSizeRows; + } + + public long getMaxBatchSizeBytes() { + return maxBatchSizeBytes; + } + + public String getParquetDebugDumpPrefix() { + return parquetDebugDumpPrefix; + } + + public scala.collection.immutable.Map getMetrics() { + return metrics; + } + + private Schema expectedSchema() { + if (expectedSchema == null) { + this.expectedSchema = SchemaParser.fromJson(expectedSchemaString); + } + return expectedSchema; + } + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java new file mode 100644 index 00000000000..42f4265d22c --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.iceberg.spark.source; + +import java.util.List; + +import com.nvidia.spark.rapids.GpuMetric; +import com.nvidia.spark.rapids.RapidsConf; +import com.nvidia.spark.rapids.iceberg.spark.SparkReadConf; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SchemaParser; +import org.apache.iceberg.SerializableTable; +import org.apache.iceberg.Table; +import org.apache.iceberg.util.Tasks; +import org.apache.iceberg.util.ThreadPools; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.util.SerializableConfiguration; + +public class SparkBatch implements Batch { + + private final JavaSparkContext sparkContext; + private final Table table; + private final SparkReadConf readConf; + private final List tasks; + private final Schema expectedSchema; + private final boolean caseSensitive; + private final boolean localityEnabled; + private final RapidsConf rapidsConf; + private final scala.collection.immutable.Map metrics; + + SparkBatch(JavaSparkContext sparkContext, Table table, SparkReadConf readConf, + List tasks, Schema expectedSchema, + RapidsConf rapidsConf, + scala.collection.immutable.Map metrics) { + this.sparkContext = sparkContext; + this.table = table; + this.readConf = readConf; + this.tasks = tasks; + this.expectedSchema = expectedSchema; + this.caseSensitive = readConf.caseSensitive(); + this.localityEnabled = readConf.localityEnabled(); + this.rapidsConf = rapidsConf; + this.metrics = metrics; + } + + @Override + public InputPartition[] planInputPartitions() { + // broadcast the table metadata as input partitions will be sent to executors + Broadcast
tableBroadcast = sparkContext.broadcast(SerializableTable.copyOf(table)); + Broadcast confBroadcast = sparkContext.broadcast( + new SerializableConfiguration(sparkContext.hadoopConfiguration())); + String expectedSchemaString = SchemaParser.toJson(expectedSchema); + + InputPartition[] readTasks = new InputPartition[tasks.size()]; + + Tasks.range(readTasks.length) + .stopOnFailure() + .executeWith(localityEnabled ? ThreadPools.getWorkerPool() : null) + .run(index -> readTasks[index] = new GpuSparkScan.ReadTask( + tasks.get(index), tableBroadcast, expectedSchemaString, + caseSensitive, localityEnabled, rapidsConf, confBroadcast, + metrics)); + + return readTasks; + } + + @Override + public PartitionReaderFactory createReaderFactory() { + return new GpuSparkScan.ReaderFactory(); + } + +// private int batchSize() { +// if (parquetOnly() && parquetBatchReadsEnabled()) { +// return readConf.parquetBatchSize(); +// } else if (orcOnly() && orcBatchReadsEnabled()) { +// return readConf.orcBatchSize(); +// } else { +// return 0; +// } +// } +// +// private boolean parquetOnly() { +// return tasks.stream().allMatch(task -> !task.isDataTask() && onlyFileFormat(task, FileFormat.PARQUET)); +// } +// +// private boolean parquetBatchReadsEnabled() { +// return readConf.parquetVectorizationEnabled() && // vectorization enabled +// expectedSchema.columns().size() > 0 && // at least one column is projected +// expectedSchema.columns().stream().allMatch(c -> c.type().isPrimitiveType()); // only primitives +// } +// +// private boolean orcOnly() { +// return tasks.stream().allMatch(task -> !task.isDataTask() && onlyFileFormat(task, FileFormat.ORC)); +// } +// +// private boolean orcBatchReadsEnabled() { +// return readConf.orcVectorizationEnabled() && // vectorization enabled +// tasks.stream().noneMatch(TableScanUtil::hasDeletes); // no delete files +// } +// +// private boolean onlyFileFormat(CombinedScanTask task, FileFormat fileFormat) { +// return task.files().stream().allMatch(fileScanTask -> fileScanTask.file().format().equals(fileFormat)); +// } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/Stats.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/Stats.java new file mode 100644 index 00000000000..e5ec0567c05 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/Stats.java @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.iceberg.spark.source; + +import java.util.OptionalLong; + +import org.apache.spark.sql.connector.read.Statistics; + +public class Stats implements Statistics { + private final OptionalLong sizeInBytes; + private final OptionalLong numRows; + + Stats(long sizeInBytes, long numRows) { + this.sizeInBytes = OptionalLong.of(sizeInBytes); + this.numRows = OptionalLong.of(numRows); + } + + @Override + public OptionalLong sizeInBytes() { + return sizeInBytes; + } + + @Override + public OptionalLong numRows() { + return numRows; + } +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCSVScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCSVScan.scala index e34f2b898e0..59d342ac7c7 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCSVScan.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCSVScan.scala @@ -48,6 +48,9 @@ trait ScanWithMetrics { var metrics : Map[String, GpuMetric] = Map.empty } +// Allows use of ScanWithMetrics from Java code +class ScanWithMetricsWrapper extends ScanWithMetrics + object GpuCSVScan { def tagSupport(scanMeta: ScanMeta[CSVScan]) : Unit = { val scan = scanMeta.wrapped diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala index c7ab100ecc9..30bc6380a2b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala @@ -692,9 +692,7 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte val clippedSchema = GpuParquetPartitionReaderFactoryBase.filterClippedSchema(clippedSchemaTmp, fileSchema, isCaseSensitive) - val columnPaths = clippedSchema.getPaths.asScala.map(x => ColumnPath.get(x: _*)) - val clipped = - ParquetPartitionReader.clipBlocks(columnPaths, blocks.asScala, isCaseSensitive) + val clipped = GpuParquetUtils.clipBlocksToSchema(clippedSchema, blocks, isCaseSensitive) (clipped, clippedSchema) } @@ -1193,7 +1191,7 @@ trait ParquetPartitionReaderBase extends Logging with Arm with ScanWithMetrics currentCopyEnd += column.getTotalSize totalBytesToCopy += column.getTotalSize } - outputBlocks += ParquetPartitionReader.newParquetBlock(block.getRowCount, outputColumns) + outputBlocks += GpuParquetUtils.newBlockMeta(block.getRowCount, outputColumns) } if (currentCopyEnd != currentCopyStart) { @@ -2009,7 +2007,7 @@ class ParquetPartitionReader( override val conf: Configuration, split: PartitionedFile, filePath: Path, - clippedBlocks: Seq[BlockMetaData], + clippedBlocks: Iterable[BlockMetaData], clippedParquetSchema: MessageType, override val isSchemaCaseSensitive: Boolean, override val readDataSchema: StructType, @@ -2146,38 +2144,5 @@ object ParquetPartitionReader { block } - - /** - * Trim block metadata to contain only the column chunks that occur in the specified columns. - * The column chunks that are returned are preserved verbatim - * (i.e.: file offsets remain unchanged). 
- * - * @param columnPaths the paths of columns to preserve - * @param blocks the block metadata from the original Parquet file - * @param isCaseSensitive indicate if it is case sensitive - * @return the updated block metadata with undesired column chunks removed - */ - @scala.annotation.nowarn( - "msg=method getPath in class ColumnChunkMetaData is deprecated" - ) - private[spark] def clipBlocks(columnPaths: Seq[ColumnPath], - blocks: Seq[BlockMetaData], isCaseSensitive: Boolean): Seq[BlockMetaData] = { - val pathSet = if (isCaseSensitive) { - columnPaths.map(cp => cp.toDotString).toSet - } else { - columnPaths.map(cp => cp.toDotString.toLowerCase(Locale.ROOT)).toSet - } - blocks.map(oldBlock => { - //noinspection ScalaDeprecation - val newColumns = if (isCaseSensitive) { - oldBlock.getColumns.asScala.filter(c => - pathSet.contains(c.getPath.toDotString)) - } else { - oldBlock.getColumns.asScala.filter(c => - pathSet.contains(c.getPath.toDotString.toLowerCase(Locale.ROOT))) - } - ParquetPartitionReader.newParquetBlock(oldBlock.getRowCount, newColumns) - }) - } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetUtils.scala new file mode 100644 index 00000000000..ddbfa9a1564 --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetUtils.scala @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import java.util.Locale + +import scala.annotation.tailrec +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.parquet.hadoop.metadata.{BlockMetaData, ColumnChunkMetaData, ColumnPath} +import org.apache.parquet.schema.MessageType + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.types.StructType + +object GpuParquetUtils extends Logging { + /** + * Trim block metadata to contain only the column chunks that occur in the specified schema. + * The column chunks that are returned are preserved verbatim + * (i.e.: file offsets remain unchanged). 
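+ * Column chunks are matched by dotted column path, with paths lower-cased on both sides when the match is case-insensitive.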
+ * + * @param readSchema the schema to preserve + * @param blocks the block metadata from the original Parquet file + * @param isCaseSensitive indicate if it is case sensitive + * @return the updated block metadata with undesired column chunks removed + */ + @scala.annotation.nowarn( + "msg=method getPath in class ColumnChunkMetaData is deprecated" + ) + def clipBlocksToSchema( + readSchema: MessageType, + blocks: java.util.List[BlockMetaData], + isCaseSensitive: Boolean): Seq[BlockMetaData] = { + val columnPaths = readSchema.getPaths.asScala.map(x => ColumnPath.get(x: _*)) + val pathSet = if (isCaseSensitive) { + columnPaths.map(cp => cp.toDotString).toSet + } else { + columnPaths.map(cp => cp.toDotString.toLowerCase(Locale.ROOT)).toSet + } + blocks.asScala.map { oldBlock => + //noinspection ScalaDeprecation + val newColumns = if (isCaseSensitive) { + oldBlock.getColumns.asScala.filter(c => pathSet.contains(c.getPath.toDotString)) + } else { + oldBlock.getColumns.asScala.filter(c => + pathSet.contains(c.getPath.toDotString.toLowerCase(Locale.ROOT))) + } + newBlockMeta(oldBlock.getRowCount, newColumns) + } + } + + /** + * Build a new BlockMetaData + * + * @param rowCount the number of rows in this block + * @param columns the new column chunks to reference in the new BlockMetaData + * @return the new BlockMetaData + */ + def newBlockMeta( + rowCount: Long, + columns: Seq[ColumnChunkMetaData]): BlockMetaData = { + val block = new BlockMetaData + block.setRowCount(rowCount) + + var totalSize: Long = 0 + columns.foreach { column => + block.addColumn(column) + totalSize += column.getTotalUncompressedSize + } + block.setTotalByteSize(totalSize) + + block + } + + def getBlocksInBatch( + blockIter: BufferedIterator[BlockMetaData], + readSchema: StructType, + maxBatchSizeRows: Int, + maxBatchSizeBytes: Long): Seq[BlockMetaData] = { + val currentChunk = new ArrayBuffer[BlockMetaData] + var numRows: Long = 0 + var numBytes: Long = 0 + var numParquetBytes: Long = 0 + + @tailrec + def readNextBatch(): Unit = { + if (blockIter.hasNext) { + val peekedRowGroup = blockIter.head + if (peekedRowGroup.getRowCount > Integer.MAX_VALUE) { + throw new UnsupportedOperationException("Too many rows in split") + } + if (numRows == 0 || numRows + peekedRowGroup.getRowCount <= maxBatchSizeRows) { + val estimatedBytes = GpuBatchUtils.estimateGpuMemory(readSchema, + peekedRowGroup.getRowCount) + if (numBytes == 0 || numBytes + estimatedBytes <= maxBatchSizeBytes) { + currentChunk += blockIter.next() + numRows += currentChunk.last.getRowCount + numParquetBytes += currentChunk.last.getTotalByteSize + numBytes += estimatedBytes + readNextBatch() + } + } + } + } + readNextBatch() + logDebug(s"Loaded $numRows rows from Parquet. Parquet bytes read: $numParquetBytes. 
" + + s"Estimated GPU bytes: $numBytes") + currentChunk + } +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 2d86f4ed315..66beee4627c 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -224,7 +224,7 @@ class TypedConfBuilder[T]( } } - def createWithDefault(value: T): ConfEntry[T] = { + def createWithDefault(value: T): ConfEntryWithDefault[T] = { val ret = new ConfEntryWithDefault[T](parent.key, converter, parent.doc, parent.isInternal, value) parent.register(ret) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala index cfd02b7424c..157823f5ed3 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala @@ -16,9 +16,11 @@ package org.apache.spark.sql.rapids +import scala.reflect.ClassTag import scala.util.{Failure, Success, Try} import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.iceberg.spark.source.GpuSparkBatchQueryScan import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.avro.{AvroFileFormat, AvroOptions} @@ -31,7 +33,16 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} object ExternalSource { - lazy val hasSparkAvroJar = { + private lazy val icebergBatchQueryScanClass: Option[Class[_ <: Scan]] = { + val className = "org.apache.iceberg.spark.source.SparkBatchQueryScan" + val loader = Utils.getContextOrSparkClassLoader + Try(loader.loadClass(className)) match { + case Failure(_) => None + case Success(clz) => Some(clz.asSubclass(classOf[Scan])) + } + } + + lazy val hasSparkAvroJar: Boolean = { val loader = Utils.getContextOrSparkClassLoader /** spark-avro is an optional package for Spark, so the RAPIDS Accelerator @@ -137,27 +148,43 @@ object ExternalSource { } def getScans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]] = { + var scans: Seq[ScanRule[_ <: Scan]] = icebergBatchQueryScanClass.map { clz => + Seq(new ScanRule[Scan]( + (a, conf, p, r) => new ScanMeta[Scan](a, conf, p, r) { + override def tagSelfForGpu(): Unit = { + // table could be a mix of Parquet and ORC, so check for ability to load for both + val readSchema = a.readSchema() + FileFormatChecks.tag(this, readSchema, ParquetFormatType, ReadFileOp) + FileFormatChecks.tag(this, readSchema, OrcFormatType, ReadFileOp) + } + + override def convertToGpu(): Scan = GpuSparkBatchQueryScan.fromCpu(a, conf) + }, + "Iceberg scan", + ClassTag(clz))) + }.getOrElse(Seq.empty) + if (hasSparkAvroJar) { - Seq( - GpuOverrides.scan[AvroScan]( - "Avro parsing", - (a, conf, p, r) => new ScanMeta[AvroScan](a, conf, p, r) { - override def tagSelfForGpu(): Unit = GpuAvroScan.tagSupport(this) - - override def convertToGpu(): Scan = - GpuAvroScan(a.sparkSession, - a.fileIndex, - a.dataSchema, - a.readDataSchema, - a.readPartitionSchema, - a.options, - a.pushedFilters, - conf, - a.partitionFilters, - a.dataFilters) - }) - ).map(r => (r.getClassFor.asSubclass(classOf[Scan]), r)).toMap - } else Map.empty + scans = scans :+ GpuOverrides.scan[AvroScan]( + "Avro parsing", + (a, conf, p, r) => new ScanMeta[AvroScan](a, conf, p, r) { + override def tagSelfForGpu(): Unit = GpuAvroScan.tagSupport(this) + + override def convertToGpu(): Scan = + GpuAvroScan(a.sparkSession, + a.fileIndex, + 
a.dataSchema, + a.readDataSchema, + a.readPartitionSchema, + a.options, + a.pushedFilters, + conf, + a.partitionFilters, + a.dataFilters) + }) + } + + scans.map(r => (r.getClassFor.asSubclass(classOf[Scan]), r)).toMap } /** If the scan is supported as an external source */ @@ -177,7 +204,7 @@ object ExternalSource { def copyScanWithInputFileTrue(scan: Scan): Scan = { if (hasSparkAvroJar) { scan match { - case avroScan: GpuAvroScan => avroScan.copy(queryUsesInputFile=true) + case avroScan: GpuAvroScan => avroScan.copy(queryUsesInputFile = true) case _ => throw new RuntimeException(s"Unsupported scan type: ${scan.getClass.getSimpleName}") } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala index a5ad414d487..89f8a1a7eb7 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -151,4 +151,6 @@ object TrampolineUtil { /** Remove the task context for the current thread */ def unsetTaskContext(): Unit = TaskContext.unset() + + def getContextOrSparkClassLoader: ClassLoader = Utils.getContextOrSparkClassLoader } From 2f2435f15e2ee70182f878c668c889a14467e3a6 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Thu, 19 May 2022 16:33:32 -0500 Subject: [PATCH 02/36] Clip Parquet block data to read schema --- .../spark/rapids/iceberg/parquet/GpuParquetReader.java | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java index f6b789a2de2..d7783a32b94 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java @@ -23,8 +23,10 @@ import java.util.Objects; import scala.collection.JavaConverters; +import scala.collection.Seq; import com.nvidia.spark.rapids.GpuMetric; +import com.nvidia.spark.rapids.GpuParquetUtils; import com.nvidia.spark.rapids.ParquetPartitionReader; import com.nvidia.spark.rapids.iceberg.data.GpuDeleteFilter; import com.nvidia.spark.rapids.iceberg.spark.SparkSchemaUtil; @@ -144,12 +146,14 @@ public org.apache.iceberg.io.CloseableIterator iterator() { StructType sparkSchema = SparkSchemaUtil.convertWithoutConstants(expectedSchema, idToConstant); MessageType fileReadSchema = buildFileReadSchema(fileSchema); + Seq clippedBlocks = GpuParquetUtils.clipBlocksToSchema( + fileReadSchema, filteredRowGroups, caseSensitive); // reuse Parquet scan code to read the raw data from the file ParquetPartitionReader partReader = new ParquetPartitionReader(conf, partFile, - new Path(input.location()), JavaConverters.collectionAsScalaIterable(filteredRowGroups), - fileReadSchema, caseSensitive, sparkSchema, debugDumpPrefix, - maxBatchSizeRows, maxBatchSizeBytes, metrics, true, true, true); + new Path(input.location()), clippedBlocks, fileReadSchema, caseSensitive, + sparkSchema, debugDumpPrefix, maxBatchSizeRows, maxBatchSizeBytes, metrics, + true, true, true); return new 
GpuIcebergReader(expectedSchema, partReader, deleteFilter, idToConstant); } catch (IOException e) { From 873443ae05c9a953785824bcc072af4ca96a0d86 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Thu, 19 May 2022 16:33:51 -0500 Subject: [PATCH 03/36] Add configs to disable Iceberg --- .../com/nvidia/spark/rapids/GpuOverrides.scala | 11 ++++++++++- .../com/nvidia/spark/rapids/RapidsConf.scala | 14 ++++++++++++++ .../com/nvidia/spark/rapids/TypeChecks.scala | 1 + .../spark/sql/rapids/ExternalSource.scala | 17 +++++++++++++---- 4 files changed, 38 insertions(+), 5 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 77574e8dd95..5ae15e96dcb 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -435,6 +435,9 @@ object JsonFormatType extends FileFormatType { object AvroFormatType extends FileFormatType { override def toString = "Avro" } +object IcebergFormatType extends FileFormatType { + override def toString = "Iceberg" +} sealed trait FileFormatOp object ReadFileOp extends FileFormatOp { @@ -839,7 +842,13 @@ object GpuOverrides extends Logging { TypeSig.FLOAT + TypeSig.DOUBLE + TypeSig.STRING, cudfWrite = TypeSig.none, sparkSig = (TypeSig.cpuAtomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + - TypeSig.UDT).nested()))) + TypeSig.UDT).nested())), + (IcebergFormatType, FileFormatChecks( + cudfRead = (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT + + TypeSig.ARRAY + TypeSig.MAP + GpuTypeShims.additionalParquetSupportedTypes).nested(), + cudfWrite = TypeSig.none, + sparkSig = (TypeSig.cpuAtomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + + TypeSig.UDT + GpuTypeShims.additionalParquetSupportedTypes).nested()))) val commonExpressions: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq( expr[Literal]( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 66beee4627c..898fb36928c 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -985,6 +985,16 @@ object RapidsConf { .integerConf .createWithDefault(20) + val ENABLE_ICEBERG = conf("spark.rapids.sql.format.iceberg.enabled") + .doc("When set to false disables all Iceberg acceleration") + .booleanConf + .createWithDefault(true) + + val ENABLE_ICEBERG_READ = conf("spark.rapids.sql.format.iceberg.enabled") + .doc("When set to false disables Iceberg input acceleration") + .booleanConf + .createWithDefault(true) + val ENABLE_RANGE_WINDOW_BYTES = conf("spark.rapids.sql.window.range.byte.enabled") .doc("When the order-by column of a range based window is byte type and " + "the range boundary calculated for a value has overflow, CPU and GPU will get " + @@ -1778,6 +1788,10 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val avroMultiThreadReadNumThreads: Int = get(AVRO_MULTITHREAD_READ_NUM_THREADS) + lazy val isIcebergEnabled: Boolean = get(ENABLE_ICEBERG) + + lazy val isIcebergReadEnabled: Boolean = get(ENABLE_ICEBERG_READ) + lazy val shuffleManagerEnabled: Boolean = get(SHUFFLE_MANAGER_ENABLED) lazy val shuffleTransportEnabled: Boolean = get(SHUFFLE_TRANSPORT_ENABLE) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala 
b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index eacace1241f..90f286fe195 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -2188,6 +2188,7 @@ object SupportedOpsForTools { case "orc" => conf.isOrcEnabled && conf.isOrcReadEnabled case "json" => conf.isJsonEnabled && conf.isJsonReadEnabled case "avro" => conf.isAvroEnabled && conf.isAvroReadEnabled + case "iceberg" => conf.isIcebergEnabled && conf.isIcebergReadEnabled case _ => throw new IllegalArgumentException("Format is unknown we need to add it here!") } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala index 157823f5ed3..c100ef2e19a 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala @@ -152,10 +152,19 @@ object ExternalSource { Seq(new ScanRule[Scan]( (a, conf, p, r) => new ScanMeta[Scan](a, conf, p, r) { override def tagSelfForGpu(): Unit = { - // table could be a mix of Parquet and ORC, so check for ability to load for both - val readSchema = a.readSchema() - FileFormatChecks.tag(this, readSchema, ParquetFormatType, ReadFileOp) - FileFormatChecks.tag(this, readSchema, OrcFormatType, ReadFileOp) + // TODO: Should this be tied to Parquet/ORC formats as well since underlying files + // could be that format? + if (!conf.isIcebergEnabled) { + willNotWorkOnGpu("Iceberg input and output has been disabled. To enable set " + + s"${RapidsConf.ENABLE_ICEBERG} to true") + } + + if (!conf.isIcebergReadEnabled) { + willNotWorkOnGpu("Iceberg input has been disabled. 
To enable set " + + s"${RapidsConf.ENABLE_ICEBERG_READ} to true") + } + + FileFormatChecks.tag(this, a.readSchema(), IcebergFormatType, ReadFileOp) } override def convertToGpu(): Scan = GpuSparkBatchQueryScan.fromCpu(a, conf) From 49e278799f98441da3fa88904d0f810aea4329bb Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Fri, 20 May 2022 14:57:23 -0500 Subject: [PATCH 04/36] DPP filtering working, still not getting reuse --- .../spark/rapids/shims/Spark320PlusShims.scala | 12 +++++++++--- .../nvidia/spark/rapids/shims/GpuBatchScanExec.scala | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala index 91389a31cb2..f5f93a5954a 100644 --- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala +++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala @@ -519,13 +519,19 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging { } } + override val childExprs: Seq[BaseExprMeta[_]] = { + // We want to leave the runtime filters as CPU expressions, so leave them out of the expressions + p.output.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + } + override val childScans: scala.Seq[ScanMeta[_]] = Seq(GpuOverrides.wrapScan(p.scan, conf, Some(this))) override def tagPlanForGpu(): Unit = { - if (!p.runtimeFilters.isEmpty) { - willNotWorkOnGpu("runtime filtering (DPP) on datasource V2 is not supported") - } + // TODO: Implement support for runtimeFilters for all supported scans +// if (!p.runtimeFilters.isEmpty) { +// willNotWorkOnGpu("runtime filtering (DPP) on datasource V2 is not supported") +// } } override def convertToCpu(): SparkPlan = { diff --git a/sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala b/sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala index 5049c3b090b..a78975a3504 100644 --- a/sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala +++ b/sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala @@ -93,7 +93,7 @@ case class GpuBatchScanExec( // return an empty RDD with 1 partition if dynamic filtering removed the only split sparkContext.parallelize(Array.empty[InternalRow], 1) } else { - new GpuDataSourceRDD(sparkContext, partitions, readerFactory) + new GpuDataSourceRDD(sparkContext, filteredPartitions, readerFactory) } } From 69f6df16e9fa1a884563afa2b7d76c4945d5d7a7 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Tue, 24 May 2022 11:17:16 -0500 Subject: [PATCH 05/36] Fix lack of exchange reuse --- .../iceberg/spark/source/GpuSparkScan.java | 2 +- .../iceberg/spark/source/SparkBatch.java | 26 ++++++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java index b45e1a6ae03..1ac371b2b6d 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java @@ -118,7 +118,7 @@ protected List filterExpressions() { @Override public Batch toBatch() { return new SparkBatch(sparkContext, table, readConf, tasks(), expectedSchema, - rapidsConf, 
metrics()); + rapidsConf, metrics(), this); } @Override diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java index 42f4265d22c..5c5a1a7d26e 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java @@ -17,6 +17,7 @@ package com.nvidia.spark.rapids.iceberg.spark.source; import java.util.List; +import java.util.Objects; import com.nvidia.spark.rapids.GpuMetric; import com.nvidia.spark.rapids.RapidsConf; @@ -47,11 +48,13 @@ public class SparkBatch implements Batch { private final boolean localityEnabled; private final RapidsConf rapidsConf; private final scala.collection.immutable.Map metrics; + private final GpuSparkScan parentScan; SparkBatch(JavaSparkContext sparkContext, Table table, SparkReadConf readConf, List tasks, Schema expectedSchema, RapidsConf rapidsConf, - scala.collection.immutable.Map metrics) { + scala.collection.immutable.Map metrics, + GpuSparkScan parentScan) { this.sparkContext = sparkContext; this.table = table; this.readConf = readConf; @@ -61,6 +64,7 @@ public class SparkBatch implements Batch { this.localityEnabled = readConf.localityEnabled(); this.rapidsConf = rapidsConf; this.metrics = metrics; + this.parentScan = parentScan; } @Override @@ -121,4 +125,24 @@ public PartitionReaderFactory createReaderFactory() { // private boolean onlyFileFormat(CombinedScanTask task, FileFormat fileFormat) { // return task.files().stream().allMatch(fileScanTask -> fileScanTask.file().format().equals(fileFormat)); // } + + // TODO: See if latest Iceberg code has the same issues with lacking equals/hashCode on batch + // causing exchange to not be reused + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SparkBatch that = (SparkBatch) o; + // Emulating Apache Iceberg old SparkScan behavior where the scan was the batch + // to fix exchange reuse with DPP. + return this.parentScan.equals(that.parentScan); + } + + @Override + public int hashCode() { + // Emulating Apache Iceberg old SparkScan behavior where the scan was the batch + // to fix exchange reuse with DPP. 
+ return Objects.hash(parentScan); + } } From a953bbd8498e8e28bbcd493f81f7890e07dd2cfc Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Thu, 26 May 2022 14:51:22 -0500 Subject: [PATCH 06/36] Add support for Iceberg on Spark 3.1 --- .../rapids/iceberg/spark/SparkReadConf.java | 2 +- .../spark/source/GpuSparkBatchQueryScan.java | 63 ++++++++++++++++++- 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java index 7dd5f2db46a..abe45a91076 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java @@ -53,7 +53,7 @@ public class SparkReadConf { private final SparkConfParser confParser; public static SparkReadConf fromReflect(Object obj) throws IllegalAccessException { - SparkSession spark = (SparkSession) FieldUtils.readField(obj, "spark", true); + SparkSession spark = SparkSession.active(); Table table = (Table) FieldUtils.readField(obj, "table", true); Map readOptions = (Map) FieldUtils.readField(obj, "readOptions", true); return new SparkReadConf(spark, table, readOptions); diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkBatchQueryScan.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkBatchQueryScan.java index 6aa2b368ae1..4974007ef0b 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkBatchQueryScan.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkBatchQueryScan.java @@ -39,6 +39,7 @@ import org.apache.iceberg.Schema; import org.apache.iceberg.Snapshot; import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; import org.apache.iceberg.TableScan; import org.apache.iceberg.exceptions.ValidationException; import org.apache.iceberg.expressions.Binder; @@ -81,13 +82,73 @@ public class GpuSparkBatchQueryScan extends GpuSparkScan implements ShimSupports public static GpuSparkBatchQueryScan fromCpu(Scan cpuInstance, RapidsConf rapidsConf) throws IllegalAccessException { Table table = (Table) FieldUtils.readField(cpuInstance, "table", true); - TableScan scan = (TableScan) FieldUtils.readField(cpuInstance, "scan", true); SparkReadConf readConf = SparkReadConf.fromReflect(FieldUtils.readField(cpuInstance, "readConf", true)); Schema expectedSchema = (Schema) FieldUtils.readField(cpuInstance, "expectedSchema", true); List filters = (List) FieldUtils.readField(cpuInstance, "filterExpressions", true); + TableScan scan; + try { + scan = (TableScan) FieldUtils.readField(cpuInstance, "scan", true); + } catch (IllegalArgumentException ignored) { + // No TableScan instance, so try to build one now + scan = buildScan(cpuInstance, table, readConf, expectedSchema, filters); + } return new GpuSparkBatchQueryScan(SparkSession.active(), table, scan, readConf, expectedSchema, filters, rapidsConf); } + // Try to build an Iceberg TableScan when one was not found in the CPU instance + private static TableScan buildScan(Scan cpuInstance, + Table table, + SparkReadConf readConf, + Schema expectedSchema, + List filterExpressions) throws IllegalAccessException { + Long snapshotId = (Long) FieldUtils.readField(cpuInstance, "snapshotId", true); + Long startSnapshotId = (Long) FieldUtils.readField(cpuInstance, "startSnapshotId", true); + Long endSnapshotId = 
(Long) FieldUtils.readField(cpuInstance, "endSnapshotId", true); + Long asOfTimestamp = (Long) FieldUtils.readField(cpuInstance, "asOfTimestamp", true); + Long splitSize = (Long) FieldUtils.readField(cpuInstance, "splitSize", true); + Integer splitLookback = (Integer) FieldUtils.readField(cpuInstance, "splitLookback", true); + Long splitOpenFileCost = (Long) FieldUtils.readField(cpuInstance, "splitOpenFileCost", true); + + TableScan scan = table + .newScan() + .caseSensitive(readConf.caseSensitive()) + .project(expectedSchema); + + if (snapshotId != null) { + scan = scan.useSnapshot(snapshotId); + } + + if (asOfTimestamp != null) { + scan = scan.asOfTime(asOfTimestamp); + } + + if (startSnapshotId != null) { + if (endSnapshotId != null) { + scan = scan.appendsBetween(startSnapshotId, endSnapshotId); + } else { + scan = scan.appendsAfter(startSnapshotId); + } + } + + if (splitSize != null) { + scan = scan.option(TableProperties.SPLIT_SIZE, splitSize.toString()); + } + + if (splitLookback != null) { + scan = scan.option(TableProperties.SPLIT_LOOKBACK, splitLookback.toString()); + } + + if (splitOpenFileCost != null) { + scan = scan.option(TableProperties.SPLIT_OPEN_FILE_COST, splitOpenFileCost.toString()); + } + + for (Expression filter : filterExpressions) { + scan = scan.filter(filter); + } + + return scan; + } + GpuSparkBatchQueryScan(SparkSession spark, Table table, TableScan scan, SparkReadConf readConf, Schema expectedSchema, List filters, RapidsConf rapidsConf) { From c81cb86b5a91975e45485afbf1cfdd1b3ba8fcc3 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Thu, 26 May 2022 14:52:40 -0500 Subject: [PATCH 07/36] Use metrics from parent SparkScan --- .../spark/rapids/iceberg/spark/source/GpuSparkScan.java | 2 +- .../spark/rapids/iceberg/spark/source/SparkBatch.java | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java index 1ac371b2b6d..d990ab98cf9 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java @@ -118,7 +118,7 @@ protected List filterExpressions() { @Override public Batch toBatch() { return new SparkBatch(sparkContext, table, readConf, tasks(), expectedSchema, - rapidsConf, metrics(), this); + rapidsConf, this); } @Override diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java index 5c5a1a7d26e..40712e3ec7c 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java @@ -19,7 +19,6 @@ import java.util.List; import java.util.Objects; -import com.nvidia.spark.rapids.GpuMetric; import com.nvidia.spark.rapids.RapidsConf; import com.nvidia.spark.rapids.iceberg.spark.SparkReadConf; import org.apache.iceberg.CombinedScanTask; @@ -47,13 +46,11 @@ public class SparkBatch implements Batch { private final boolean caseSensitive; private final boolean localityEnabled; private final RapidsConf rapidsConf; - private final scala.collection.immutable.Map metrics; private final GpuSparkScan parentScan; SparkBatch(JavaSparkContext sparkContext, Table table, SparkReadConf readConf, List tasks, Schema 
expectedSchema, RapidsConf rapidsConf, - scala.collection.immutable.Map metrics, GpuSparkScan parentScan) { this.sparkContext = sparkContext; this.table = table; @@ -63,7 +60,6 @@ public class SparkBatch implements Batch { this.caseSensitive = readConf.caseSensitive(); this.localityEnabled = readConf.localityEnabled(); this.rapidsConf = rapidsConf; - this.metrics = metrics; this.parentScan = parentScan; } @@ -83,7 +79,7 @@ public InputPartition[] planInputPartitions() { .run(index -> readTasks[index] = new GpuSparkScan.ReadTask( tasks.get(index), tableBroadcast, expectedSchemaString, caseSensitive, localityEnabled, rapidsConf, confBroadcast, - metrics)); + parentScan.metrics())); return readTasks; } From d18ac4f884111097d5126904f5444c25b78d71c8 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Mon, 13 Jun 2022 15:32:42 -0500 Subject: [PATCH 08/36] Fix metrics --- .../iceberg/parquet/GpuParquetReader.java | 5 ++--- .../iceberg/spark/source/GpuSparkScan.java | 21 +++++++++---------- .../iceberg/spark/source/SparkBatch.java | 5 ++--- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java index d7783a32b94..4c21b18b6c0 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java @@ -22,7 +22,6 @@ import java.util.Map; import java.util.Objects; -import scala.collection.JavaConverters; import scala.collection.Seq; import com.nvidia.spark.rapids.GpuMetric; @@ -43,7 +42,6 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Maps; -import org.apache.iceberg.types.TypeUtil; import org.apache.iceberg.types.Types; import org.apache.parquet.ParquetReadOptions; import org.apache.parquet.hadoop.ParquetFileReader; @@ -100,6 +98,7 @@ public GpuParquetReader( @Override public org.apache.iceberg.io.CloseableIterator iterator() { + scala.collection.immutable.Map localMetrics = metrics; try (ParquetFileReader reader = newReader(input, options)) { MessageType fileSchema = reader.getFileMetaData().getSchema(); @@ -152,7 +151,7 @@ public org.apache.iceberg.io.CloseableIterator iterator() { // reuse Parquet scan code to read the raw data from the file ParquetPartitionReader partReader = new ParquetPartitionReader(conf, partFile, new Path(input.location()), clippedBlocks, fileReadSchema, caseSensitive, - sparkSchema, debugDumpPrefix, maxBatchSizeRows, maxBatchSizeBytes, metrics, + sparkSchema, debugDumpPrefix, maxBatchSizeRows, maxBatchSizeBytes, localMetrics, true, true, true); return new GpuIcebergReader(expectedSchema, partReader, deleteFilter, idToConstant); diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java index d990ab98cf9..51083876607 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java @@ -178,6 +178,12 @@ public String description() { } static class ReaderFactory implements PartitionReaderFactory { + private scala.collection.immutable.Map metrics; + + public 
ReaderFactory(scala.collection.immutable.Map metrics) { + this.metrics = metrics; + } + @Override public PartitionReader createReader(InputPartition partition) { throw new IllegalStateException("non-columnar read"); @@ -186,7 +192,7 @@ public PartitionReader createReader(InputPartition partition) { @Override public PartitionReader createColumnarReader(InputPartition partition) { if (partition instanceof ReadTask) { - return new BatchReader((ReadTask) partition); + return new BatchReader((ReadTask) partition, metrics); } else { throw new UnsupportedOperationException("Incorrect input partition type: " + partition); } @@ -199,10 +205,10 @@ public boolean supportColumnarReads(InputPartition partition) { } private static class BatchReader extends GpuBatchDataReader implements PartitionReader { - BatchReader(ReadTask task) { + BatchReader(ReadTask task, scala.collection.immutable.Map metrics) { super(task.task, task.table(), task.expectedSchema(), task.isCaseSensitive(), task.getConfiguration(), task.getMaxBatchSizeRows(), task.getMaxBatchSizeBytes(), - task.getParquetDebugDumpPrefix(), task.getMetrics()); + task.getParquetDebugDumpPrefix(), metrics); } } @@ -216,15 +222,13 @@ static class ReadTask implements InputPartition, Serializable { private final int maxBatchSizeRows; private final long maxBatchSizeBytes; private final String parquetDebugDumpPrefix; - private final scala.collection.immutable.Map metrics; private transient Schema expectedSchema = null; private transient String[] preferredLocations = null; ReadTask(CombinedScanTask task, Broadcast
tableBroadcast, String expectedSchemaString, boolean caseSensitive, boolean localityPreferred, RapidsConf rapidsConf, - Broadcast confBroadcast, - scala.collection.immutable.Map metrics) { + Broadcast confBroadcast) { this.task = task; this.tableBroadcast = tableBroadcast; this.expectedSchemaString = expectedSchemaString; @@ -239,7 +243,6 @@ static class ReadTask implements InputPartition, Serializable { this.maxBatchSizeRows = rapidsConf.maxReadBatchSizeRows(); this.maxBatchSizeBytes = rapidsConf.maxReadBatchSizeBytes(); this.parquetDebugDumpPrefix = rapidsConf.parquetDebugDumpPrefix(); - this.metrics = metrics; } @Override @@ -275,10 +278,6 @@ public String getParquetDebugDumpPrefix() { return parquetDebugDumpPrefix; } - public scala.collection.immutable.Map getMetrics() { - return metrics; - } - private Schema expectedSchema() { if (expectedSchema == null) { this.expectedSchema = SchemaParser.fromJson(expectedSchemaString); diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java index 40712e3ec7c..c6c8383442e 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java @@ -78,15 +78,14 @@ public InputPartition[] planInputPartitions() { .executeWith(localityEnabled ? ThreadPools.getWorkerPool() : null) .run(index -> readTasks[index] = new GpuSparkScan.ReadTask( tasks.get(index), tableBroadcast, expectedSchemaString, - caseSensitive, localityEnabled, rapidsConf, confBroadcast, - parentScan.metrics())); + caseSensitive, localityEnabled, rapidsConf, confBroadcast)); return readTasks; } @Override public PartitionReaderFactory createReaderFactory() { - return new GpuSparkScan.ReaderFactory(); + return new GpuSparkScan.ReaderFactory(parentScan.metrics()); } // private int batchSize() { From 9f6321288850871fbeb181ce49aab3b916917707 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Tue, 14 Jun 2022 16:18:49 -0500 Subject: [PATCH 09/36] Fix DPP test --- integration_tests/src/main/python/iceberg_test.py | 5 +++-- .../com/nvidia/spark/rapids/shims/Spark320PlusShims.scala | 7 +++---- .../com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala | 3 +++ .../com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala | 3 +++ .../main/scala/com/nvidia/spark/rapids/RapidsMeta.scala | 2 ++ .../scala/org/apache/spark/sql/rapids/ExternalSource.scala | 2 ++ 6 files changed, 16 insertions(+), 6 deletions(-) diff --git a/integration_tests/src/main/python/iceberg_test.py b/integration_tests/src/main/python/iceberg_test.py index c507b7e5d66..2ab82977049 100644 --- a/integration_tests/src/main/python/iceberg_test.py +++ b/integration_tests/src/main/python/iceberg_test.py @@ -28,9 +28,10 @@ def setup_iceberg_table(spark): spark.sql("INSERT INTO {} VALUES (1, 'a'), (2, 'b'), (3, 'c')".format(table)) with_cpu_session(setup_iceberg_table) assert_gpu_and_cpu_are_equal_collect( - lambda spark : spark.sql("SELECT COUNT(DISTINCT id) from {}".format(table))) + lambda spark : spark.sql("SELECT COUNT(DISTINCT id) from {}".format(table)), + conf={"spark.rapids.sql.format.iceberg.enabled": "false"} + ) -@allow_non_gpu('BatchScanExec') @iceberg @ignore_order(local=True) @pytest.mark.skipif(is_before_spark_320() or is_databricks_runtime(), diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala 
b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala index b38b98ff8d3..53be6740faf 100644 --- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala +++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala @@ -469,10 +469,9 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging { Seq(GpuOverrides.wrapScan(p.scan, conf, Some(this))) override def tagPlanForGpu(): Unit = { - // TODO: Implement support for runtimeFilters for all supported scans -// if (!p.runtimeFilters.isEmpty) { -// willNotWorkOnGpu("runtime filtering (DPP) on datasource V2 is not supported") -// } + if (!p.runtimeFilters.isEmpty && !childScans.head.supportsRuntimeFilters) { + willNotWorkOnGpu("runtime filtering (DPP) is not supported") + } } override def convertToCpu(): SparkPlan = { diff --git a/sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala b/sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala index db3a610a314..6046b24318f 100644 --- a/sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala +++ b/sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala @@ -37,6 +37,9 @@ case class GpuBatchScanExec( extends DataSourceV2ScanExecBase with GpuBatchScanExecMetrics { @transient lazy val batch: Batch = scan.toBatch + // All expressions are filter expressions used on the CPU. + override def gpuExpressions: Seq[Expression] = Nil + // TODO: unify the equal/hashCode implementation for all data source v2 query plans. override def equals(other: Any): Boolean = other match { case other: GpuBatchScanExec => diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala index ac143331880..a80e4b90392 100644 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala @@ -39,6 +39,9 @@ case class GpuBatchScanExec( extends DataSourceV2ScanExecBase with GpuBatchScanExecMetrics { @transient lazy val batch: Batch = scan.toBatch + // All expressions are filter expressions used on the CPU. + override def gpuExpressions: Seq[Expression] = Nil + // TODO: unify the equal/hashCode implementation for all data source v2 query plans. 
override def equals(other: Any): Boolean = other match { case other: BatchScanExec => diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala index 7b23d6a5b78..4deac9f4ba6 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala @@ -488,6 +488,8 @@ abstract class ScanMeta[INPUT <: Scan](scan: INPUT, override val childDataWriteCmds: Seq[DataWritingCommandMeta[_]] = Seq.empty override def tagSelfForGpu(): Unit = {} + + def supportsRuntimeFilters: Boolean = false } /** diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala index 8ec333771e5..7a775036ddf 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala @@ -143,6 +143,8 @@ object ExternalSource extends Logging { var scans: Seq[ScanRule[_ <: Scan]] = icebergBatchQueryScanClass.map { clz => Seq(new ScanRule[Scan]( (a, conf, p, r) => new ScanMeta[Scan](a, conf, p, r) { + override def supportsRuntimeFilters: Boolean = true + override def tagSelfForGpu(): Unit = { // TODO: Should this be tied to Parquet/ORC formats as well since underlying files // could be that format? From a2086bf2f8c437e48ca3cafbac282d8255995bb8 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Tue, 14 Jun 2022 16:27:46 -0500 Subject: [PATCH 10/36] Update NOTICE-binary --- NOTICE-binary | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/NOTICE-binary b/NOTICE-binary index 02899872b48..57c3552cc11 100644 --- a/NOTICE-binary +++ b/NOTICE-binary @@ -12,7 +12,35 @@ Copyright 2014 and onwards The Apache Software Foundation This product includes software developed at The Apache Software Foundation (http://www.apache.org/). ---------------------------------------------------------------------- +-------------------------------------------------------------------------------- + +Apache Iceberg +Copyright 2017-2022 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +-------------------------------------------------------------------------------- + +This project includes code from Kite, developed at Cloudera, Inc. with +the following copyright notice: + +| Copyright 2013 Cloudera Inc. +| +| Licensed under the Apache License, Version 2.0 (the "License"); +| you may not use this file except in compliance with the License. +| You may obtain a copy of the License at +| +| http://www.apache.org/licenses/LICENSE-2.0 +| +| Unless required by applicable law or agreed to in writing, software +| distributed under the License is distributed on an "AS IS" BASIS, +| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +| See the License for the specific language governing permissions and +| limitations under the License. + +-------------------------------------------------------------------------------- + UCF Consortium - Unified Communication X (UCX) Copyright (c) 2014-2015 UT-Battelle, LLC. All rights reserved. 
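Before the cleanup patch below, a brief usage sketch of the Iceberg enable flags added in PATCH 03 and exercised by the DPP test in PATCH 09. This is illustrative only and not part of the patch series: the session setup and table name are hypothetical, and the read-specific key is assumed to be spark.rapids.sql.format.iceberg.read.enabled, following the format-specific read.enabled naming used by the Parquet, ORC, JSON, and Avro flags.

// Hypothetical Scala usage sketch; assumes the RAPIDS plugin and an Iceberg catalog are on the classpath.
import org.apache.spark.sql.SparkSession

object IcebergGpuReadDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("iceberg-gpu-read-demo")
      .getOrCreate()

    // Master switch for GPU Iceberg acceleration (defaults to true).
    spark.conf.set("spark.rapids.sql.format.iceberg.enabled", "true")
    // Read-side switch; key assumed to follow the other formats' read.enabled pattern.
    spark.conf.set("spark.rapids.sql.format.iceberg.read.enabled", "true")

    // With both flags on, an Iceberg scan is eligible for the GpuSparkBatchQueryScan
    // replacement; with either off, the ScanMeta calls willNotWorkOnGpu and the
    // CPU BatchScanExec is kept. The table name here is a placeholder.
    spark.sql("SELECT COUNT(DISTINCT id) FROM demo_iceberg_table").show()

    spark.stop()
  }
}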
From e006f18495fc850aafb7fe330836cf5d7846e9b7 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Fri, 24 Jun 2022 13:42:05 -0500 Subject: [PATCH 11/36] Remove unused code --- .../spark/rapids/CloseableIterator.java | 23 - .../iceberg/parquet/GpuParquetReader.java | 46 +- .../iceberg/parquet/ParquetSchemaUtil.java | 73 -- .../rapids/iceberg/parquet/ParquetUtil.java | 273 ------- .../rapids/iceberg/spark/Spark3Util.java | 688 ------------------ .../rapids/iceberg/spark/SparkSchemaUtil.java | 232 ------ .../spark/rapids/iceberg/spark/SparkUtil.java | 178 ----- .../iceberg/spark/source/SparkBatch.java | 38 - .../nvidia/spark/rapids/GpuParquetUtils.scala | 39 - 9 files changed, 6 insertions(+), 1584 deletions(-) delete mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/CloseableIterator.java diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/CloseableIterator.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/CloseableIterator.java deleted file mode 100644 index 0864f7887d6..00000000000 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/CloseableIterator.java +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids; - -import java.io.Closeable; -import java.util.Iterator; - -public interface CloseableIterator extends Iterator, Closeable { -} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java index 4c21b18b6c0..239b5887312 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java @@ -98,30 +98,21 @@ public GpuParquetReader( @Override public org.apache.iceberg.io.CloseableIterator iterator() { - scala.collection.immutable.Map localMetrics = metrics; try (ParquetFileReader reader = newReader(input, options)) { MessageType fileSchema = reader.getFileMetaData().getSchema(); MessageType typeWithIds; -// MessageType projection; if (ParquetSchemaUtil.hasIds(fileSchema)) { typeWithIds = fileSchema; -// projection = ParquetSchemaUtil.pruneColumns(fileSchema, expectedSchema); } else if (nameMapping != null) { typeWithIds = ParquetSchemaUtil.applyNameMapping(fileSchema, nameMapping); -// projection = ParquetSchemaUtil.pruneColumns(typeWithIds, expectedSchema); } else { typeWithIds = ParquetSchemaUtil.addFallbackIds(fileSchema); -// projection = ParquetSchemaUtil.pruneColumnsFallback(fileSchema, expectedSchema); } List rowGroups = reader.getRowGroups(); List filteredRowGroups = Lists.newArrayListWithCapacity(rowGroups.size()); -// boolean[] startRowPositions[i] = new boolean[rowGroups.size()]; -// -// // Fetch all row groups starting positions to compute the row offsets of the filtered row groups -// Map offsetToStartPos = generateOffsetToStartPos(expectedSchema); if 
(expectedSchema.findField(MetadataColumns.ROW_POSITION.fieldId()) != null) { throw new UnsupportedOperationException("row position meta column not implemented"); } @@ -134,7 +125,6 @@ public org.apache.iceberg.io.CloseableIterator iterator() { } for (BlockMetaData rowGroup : rowGroups) { -// startRowPositions[i] = offsetToStartPos == null ? 0 : offsetToStartPos.get(rowGroup.getStartingPos()); boolean shouldRead = filter == null || ( statsFilter.shouldRead(typeWithIds, rowGroup) && dictFilter.shouldRead(typeWithIds, rowGroup, reader.getDictionaryReader(rowGroup))); @@ -151,7 +141,7 @@ public org.apache.iceberg.io.CloseableIterator iterator() { // reuse Parquet scan code to read the raw data from the file ParquetPartitionReader partReader = new ParquetPartitionReader(conf, partFile, new Path(input.location()), clippedBlocks, fileReadSchema, caseSensitive, - sparkSchema, debugDumpPrefix, maxBatchSizeRows, maxBatchSizeBytes, localMetrics, + sparkSchema, debugDumpPrefix, maxBatchSizeRows, maxBatchSizeBytes, metrics, true, true, true); return new GpuIcebergReader(expectedSchema, partReader, deleteFilter, idToConstant); @@ -160,28 +150,6 @@ public org.apache.iceberg.io.CloseableIterator iterator() { } } -// private Map generateOffsetToStartPos(Schema schema) { -// if (schema.findField(MetadataColumns.ROW_POSITION.fieldId()) == null) { -// return null; -// } -// -// try (ParquetFileReader fileReader = newReader(input, ParquetReadOptions.builder().build())) { -// Map offsetToStartPos = Maps.newHashMap(); -// -// long curRowCount = 0; -// for (int i = 0; i < fileReader.getRowGroups().size(); i += 1) { -// BlockMetaData meta = fileReader.getRowGroups().get(i); -// offsetToStartPos.put(meta.getStartingPos(), curRowCount); -// curRowCount += meta.getRowCount(); -// } -// -// return offsetToStartPos; -// -// } catch (IOException e) { -// throw new UncheckedIOException("Failed to create/close reader for file: " + input, e); -// } -// } - private static ParquetFileReader newReader(InputFile file, ParquetReadOptions options) { try { return ParquetFileReader.open(ParquetIO.file(file), options); @@ -196,20 +164,18 @@ private MessageType buildFileReadSchema(MessageType fileSchema) { if (ParquetSchemaUtil.hasIds(fileSchema)) { return (MessageType) TypeWithSchemaVisitor.visit(expectedSchema.asStruct(), fileSchema, - new ReorderColumns(fileSchema, idToConstant)); + new ReorderColumns(idToConstant)); } else { return (MessageType) TypeWithSchemaVisitor.visit(expectedSchema.asStruct(), fileSchema, - new ReorderColumnsFallback(fileSchema, idToConstant)); + new ReorderColumnsFallback(idToConstant)); } } private static class ReorderColumns extends TypeWithSchemaVisitor { - private final MessageType fileSchema; private final Map idToConstant; - public ReorderColumns(MessageType fileSchema, Map idToConstant) { - this.fileSchema = fileSchema; + public ReorderColumns(Map idToConstant) { this.idToConstant = idToConstant; } @@ -303,8 +269,8 @@ private List filterAndReorder(Types.StructType expected, List fields } private static class ReorderColumnsFallback extends ReorderColumns { - public ReorderColumnsFallback(MessageType fileSchema, Map idToConstant) { - super(fileSchema, idToConstant); + public ReorderColumnsFallback(Map idToConstant) { + super(idToConstant); } @Override diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetSchemaUtil.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetSchemaUtil.java index 51a12e55dac..e89301a67e4 100644 --- 
a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetSchemaUtil.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetSchemaUtil.java @@ -29,79 +29,6 @@ public class ParquetSchemaUtil { private ParquetSchemaUtil() { } -// public static MessageType convert(Schema schema, String name) { -// return new TypeToMessageType().convert(schema, name); -// } -// -// /** -// * Converts a Parquet schema to an Iceberg schema. Fields without IDs are kept and assigned fallback IDs. -// * -// * @param parquetSchema a Parquet schema -// * @return a matching Iceberg schema for the provided Parquet schema -// */ -// public static Schema convert(MessageType parquetSchema) { -// // if the Parquet schema does not contain ids, we assign fallback ids to top-level fields -// // all remaining fields will get ids >= 1000 to avoid pruning columns without ids -// MessageType parquetSchemaWithIds = hasIds(parquetSchema) ? parquetSchema : addFallbackIds(parquetSchema); -// AtomicInteger nextId = new AtomicInteger(1000); -// return convertInternal(parquetSchemaWithIds, name -> nextId.getAndIncrement()); -// } -// -// /** -// * Converts a Parquet schema to an Iceberg schema and prunes fields without IDs. -// * -// * @param parquetSchema a Parquet schema -// * @return a matching Iceberg schema for the provided Parquet schema -// */ -// public static Schema convertAndPrune(MessageType parquetSchema) { -// return convertInternal(parquetSchema, name -> null); -// } -// -// private static Schema convertInternal(MessageType parquetSchema, Function nameToIdFunc) { -// MessageTypeToType converter = new MessageTypeToType(nameToIdFunc); -// return new Schema( -// ParquetTypeVisitor.visit(parquetSchema, converter).asNestedType().fields(), -// converter.getAliases()); -// } -// -// public static MessageType pruneColumns(MessageType fileSchema, Schema expectedSchema) { -// // column order must match the incoming type, so it doesn't matter that the ids are unordered -// Set selectedIds = TypeUtil.getProjectedIds(expectedSchema); -// return (MessageType) ParquetTypeVisitor.visit(fileSchema, new PruneColumns(selectedIds)); -// } -// -// /** -// * Prunes columns from a Parquet file schema that was written without field ids. -// *
<p>
-// * Files that were written without field ids are read assuming that schema evolution preserved -// * column order. Deleting columns was not allowed. -// *
<p>
-// * The order of columns in the resulting Parquet schema matches the Parquet file. -// * -// * @param fileSchema schema from a Parquet file that does not have field ids. -// * @param expectedSchema expected schema -// * @return a parquet schema pruned using the expected schema -// */ -// public static MessageType pruneColumnsFallback(MessageType fileSchema, Schema expectedSchema) { -// Set selectedIds = Sets.newHashSet(); -// -// for (Types.NestedField field : expectedSchema.columns()) { -// selectedIds.add(field.fieldId()); -// } -// -// MessageTypeBuilder builder = org.apache.parquet.schema.Types.buildMessage(); -// -// int ordinal = 1; -// for (Type type : fileSchema.getFields()) { -// if (selectedIds.contains(ordinal)) { -// builder.addField(type.withId(ordinal)); -// } -// ordinal += 1; -// } -// -// return builder.named(fileSchema.getName()); -// } - public static boolean hasIds(MessageType fileSchema) { return ParquetTypeVisitor.visit(fileSchema, new HasIds()); } diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetUtil.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetUtil.java index 512ce1d5907..d060748f965 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetUtil.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetUtil.java @@ -28,252 +28,6 @@ public class ParquetUtil { private ParquetUtil() { } -// public static Metrics fileMetrics(InputFile file, MetricsConfig metricsConfig) { -// return fileMetrics(file, metricsConfig, null); -// } -// -// public static Metrics fileMetrics(InputFile file, MetricsConfig metricsConfig, NameMapping nameMapping) { -// try (ParquetFileReader reader = ParquetFileReader.open(ParquetIO.file(file))) { -// return footerMetrics(reader.getFooter(), Stream.empty(), metricsConfig, nameMapping); -// } catch (IOException e) { -// throw new RuntimeIOException(e, "Failed to read footer of file: %s", file); -// } -// } -// -// public static Metrics footerMetrics(ParquetMetadata metadata, Stream> fieldMetrics, -// MetricsConfig metricsConfig) { -// return footerMetrics(metadata, fieldMetrics, metricsConfig, null); -// } -// -// @SuppressWarnings("checkstyle:CyclomaticComplexity") -// public static Metrics footerMetrics(ParquetMetadata metadata, Stream> fieldMetrics, -// MetricsConfig metricsConfig, NameMapping nameMapping) { -// Preconditions.checkNotNull(fieldMetrics, "fieldMetrics should not be null"); -// -// long rowCount = 0; -// Map columnSizes = Maps.newHashMap(); -// Map valueCounts = Maps.newHashMap(); -// Map nullValueCounts = Maps.newHashMap(); -// Map> lowerBounds = Maps.newHashMap(); -// Map> upperBounds = Maps.newHashMap(); -// Set missingStats = Sets.newHashSet(); -// -// // ignore metrics for fields we failed to determine reliable IDs -// MessageType parquetTypeWithIds = getParquetTypeWithIds(metadata, nameMapping); -// Schema fileSchema = ParquetSchemaUtil.convertAndPrune(parquetTypeWithIds); -// -// Map> fieldMetricsMap = fieldMetrics.collect( -// Collectors.toMap(FieldMetrics::id, Function.identity())); -// -// List blocks = metadata.getBlocks(); -// for (BlockMetaData block : blocks) { -// rowCount += block.getRowCount(); -// for (ColumnChunkMetaData column : block.getColumns()) { -// -// Integer fieldId = fileSchema.aliasToId(column.getPath().toDotString()); -// if (fieldId == null) { -// // fileSchema may contain a subset of columns present in the file -// // as we prune columns we could not assign ids -// continue; 
-// } -// -// increment(columnSizes, fieldId, column.getTotalSize()); -// -// MetricsMode metricsMode = MetricsUtil.metricsMode(fileSchema, metricsConfig, fieldId); -// if (metricsMode == MetricsModes.None.get()) { -// continue; -// } -// increment(valueCounts, fieldId, column.getValueCount()); -// -// Statistics stats = column.getStatistics(); -// if (stats == null) { -// missingStats.add(fieldId); -// } else if (!stats.isEmpty()) { -// increment(nullValueCounts, fieldId, stats.getNumNulls()); -// -// // when there are metrics gathered by Iceberg for a column, we should use those instead -// // of the ones from Parquet -// if (metricsMode != MetricsModes.Counts.get() && !fieldMetricsMap.containsKey(fieldId)) { -// Types.NestedField field = fileSchema.findField(fieldId); -// if (field != null && stats.hasNonNullValue() && shouldStoreBounds(column, fileSchema)) { -// Literal min = ParquetConversions.fromParquetPrimitive( -// field.type(), column.getPrimitiveType(), stats.genericGetMin()); -// updateMin(lowerBounds, fieldId, field.type(), min, metricsMode); -// Literal max = ParquetConversions.fromParquetPrimitive( -// field.type(), column.getPrimitiveType(), stats.genericGetMax()); -// updateMax(upperBounds, fieldId, field.type(), max, metricsMode); -// } -// } -// } -// } -// } -// -// // discard accumulated values if any stats were missing -// for (Integer fieldId : missingStats) { -// nullValueCounts.remove(fieldId); -// lowerBounds.remove(fieldId); -// upperBounds.remove(fieldId); -// } -// -// updateFromFieldMetrics(fieldMetricsMap, metricsConfig, fileSchema, lowerBounds, upperBounds); -// -// return new Metrics(rowCount, columnSizes, valueCounts, nullValueCounts, -// MetricsUtil.createNanValueCounts(fieldMetricsMap.values().stream(), metricsConfig, fileSchema), -// toBufferMap(fileSchema, lowerBounds), -// toBufferMap(fileSchema, upperBounds)); -// } -// -// private static void updateFromFieldMetrics( -// Map> idToFieldMetricsMap, MetricsConfig metricsConfig, Schema schema, -// Map> lowerBounds, Map> upperBounds) { -// idToFieldMetricsMap.entrySet().forEach(entry -> { -// int fieldId = entry.getKey(); -// FieldMetrics metrics = entry.getValue(); -// MetricsMode metricsMode = MetricsUtil.metricsMode(schema, metricsConfig, fieldId); -// -// // only check for MetricsModes.None, since we don't truncate float/double values. 
-// if (metricsMode != MetricsModes.None.get()) { -// if (!metrics.hasBounds()) { -// lowerBounds.remove(fieldId); -// upperBounds.remove(fieldId); -// } else if (metrics.upperBound() instanceof Float) { -// lowerBounds.put(fieldId, Literal.of((Float) metrics.lowerBound())); -// upperBounds.put(fieldId, Literal.of((Float) metrics.upperBound())); -// } else if (metrics.upperBound() instanceof Double) { -// lowerBounds.put(fieldId, Literal.of((Double) metrics.lowerBound())); -// upperBounds.put(fieldId, Literal.of((Double) metrics.upperBound())); -// } else { -// throw new UnsupportedOperationException("Expected only float or double column metrics"); -// } -// } -// }); -// } -// -// private static MessageType getParquetTypeWithIds(ParquetMetadata metadata, NameMapping nameMapping) { -// MessageType type = metadata.getFileMetaData().getSchema(); -// -// if (ParquetSchemaUtil.hasIds(type)) { -// return type; -// } -// -// if (nameMapping != null) { -// return ParquetSchemaUtil.applyNameMapping(type, nameMapping); -// } -// -// return ParquetSchemaUtil.addFallbackIds(type); -// } -// -// /** -// * Returns a list of offsets in ascending order determined by the starting position of the row groups. -// */ -// public static List getSplitOffsets(ParquetMetadata md) { -// List splitOffsets = Lists.newArrayListWithExpectedSize(md.getBlocks().size()); -// for (BlockMetaData blockMetaData : md.getBlocks()) { -// splitOffsets.add(blockMetaData.getStartingPos()); -// } -// Collections.sort(splitOffsets); -// return splitOffsets; -// } -// -// // we allow struct nesting, but not maps or arrays -// private static boolean shouldStoreBounds(ColumnChunkMetaData column, Schema schema) { -// if (column.getPrimitiveType().getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) { -// // stats for INT96 are not reliable -// return false; -// } -// -// ColumnPath columnPath = column.getPath(); -// Iterator pathIterator = columnPath.iterator(); -// Type currentType = schema.asStruct(); -// -// while (pathIterator.hasNext()) { -// if (currentType == null || !currentType.isStructType()) { -// return false; -// } -// String fieldName = pathIterator.next(); -// currentType = currentType.asStructType().fieldType(fieldName); -// } -// -// return currentType != null && currentType.isPrimitiveType(); -// } -// -// private static void increment(Map columns, int fieldId, long amount) { -// if (columns != null) { -// if (columns.containsKey(fieldId)) { -// columns.put(fieldId, columns.get(fieldId) + amount); -// } else { -// columns.put(fieldId, amount); -// } -// } -// } -// -// @SuppressWarnings("unchecked") -// private static void updateMin(Map> lowerBounds, int id, Type type, -// Literal min, MetricsMode metricsMode) { -// Literal currentMin = (Literal) lowerBounds.get(id); -// if (currentMin == null || min.comparator().compare(min.value(), currentMin.value()) < 0) { -// if (metricsMode == MetricsModes.Full.get()) { -// lowerBounds.put(id, min); -// } else { -// MetricsModes.Truncate truncateMode = (MetricsModes.Truncate) metricsMode; -// int truncateLength = truncateMode.length(); -// switch (type.typeId()) { -// case STRING: -// lowerBounds.put(id, UnicodeUtil.truncateStringMin((Literal) min, truncateLength)); -// break; -// case FIXED: -// case BINARY: -// lowerBounds.put(id, BinaryUtil.truncateBinaryMin((Literal) min, truncateLength)); -// break; -// default: -// lowerBounds.put(id, min); -// } -// } -// } -// } -// -// @SuppressWarnings("unchecked") -// private static void updateMax(Map> upperBounds, int id, 
Type type, -// Literal max, MetricsMode metricsMode) { -// Literal currentMax = (Literal) upperBounds.get(id); -// if (currentMax == null || max.comparator().compare(max.value(), currentMax.value()) > 0) { -// if (metricsMode == MetricsModes.Full.get()) { -// upperBounds.put(id, max); -// } else { -// MetricsModes.Truncate truncateMode = (MetricsModes.Truncate) metricsMode; -// int truncateLength = truncateMode.length(); -// switch (type.typeId()) { -// case STRING: -// Literal truncatedMaxString = UnicodeUtil.truncateStringMax((Literal) max, -// truncateLength); -// if (truncatedMaxString != null) { -// upperBounds.put(id, truncatedMaxString); -// } -// break; -// case FIXED: -// case BINARY: -// Literal truncatedMaxBinary = BinaryUtil.truncateBinaryMax((Literal) max, -// truncateLength); -// if (truncatedMaxBinary != null) { -// upperBounds.put(id, truncatedMaxBinary); -// } -// break; -// default: -// upperBounds.put(id, max); -// } -// } -// } -// } -// -// private static Map toBufferMap(Schema schema, Map> map) { -// Map bufferMap = Maps.newHashMap(); -// for (Map.Entry> entry : map.entrySet()) { -// bufferMap.put(entry.getKey(), -// Conversions.toByteBuffer(schema.findType(entry.getKey()), entry.getValue().value())); -// } -// return bufferMap; -// } - @SuppressWarnings("deprecation") public static boolean hasNonDictionaryPages(ColumnChunkMetaData meta) { EncodingStats stats = meta.getEncodingStats(); @@ -301,31 +55,4 @@ public static boolean hasNonDictionaryPages(ColumnChunkMetaData meta) { return true; } } - -// public static Dictionary readDictionary(ColumnDescriptor desc, PageReader pageSource) { -// DictionaryPage dictionaryPage = pageSource.readDictionaryPage(); -// if (dictionaryPage != null) { -// try { -// return dictionaryPage.getEncoding().initDictionary(desc, dictionaryPage); -// } catch (IOException e) { -// throw new ParquetDecodingException("could not decode the dictionary for " + desc, e); -// } -// } -// return null; -// } -// -// public static boolean isIntType(PrimitiveType primitiveType) { -// if (primitiveType.getOriginalType() != null) { -// switch (primitiveType.getOriginalType()) { -// case INT_8: -// case INT_16: -// case INT_32: -// case DATE: -// return true; -// default: -// return false; -// } -// } -// return primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT32; -// } } diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/Spark3Util.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/Spark3Util.java index ada2586d4e8..472b2ccfc7b 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/Spark3Util.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/Spark3Util.java @@ -32,442 +32,17 @@ public class Spark3Util { -// private static final Set RESERVED_PROPERTIES = ImmutableSet.of( -// TableCatalog.PROP_LOCATION, TableCatalog.PROP_PROVIDER); -// private static final Joiner DOT = Joiner.on("."); - private Spark3Util() { } -// public static CaseInsensitiveStringMap setOption(String key, String value, CaseInsensitiveStringMap options) { -// Map newOptions = Maps.newHashMap(); -// newOptions.putAll(options); -// newOptions.put(key, value); -// return new CaseInsensitiveStringMap(newOptions); -// } -// -// public static Map rebuildCreateProperties(Map createProperties) { -// ImmutableMap.Builder tableProperties = ImmutableMap.builder(); -// createProperties.entrySet().stream() -// .filter(entry -> !RESERVED_PROPERTIES.contains(entry.getKey())) -// 
.forEach(tableProperties::put); -// -// String provider = createProperties.get(TableCatalog.PROP_PROVIDER); -// if ("parquet".equalsIgnoreCase(provider)) { -// tableProperties.put(TableProperties.DEFAULT_FILE_FORMAT, "parquet"); -// } else if ("avro".equalsIgnoreCase(provider)) { -// tableProperties.put(TableProperties.DEFAULT_FILE_FORMAT, "avro"); -// } else if ("orc".equalsIgnoreCase(provider)) { -// tableProperties.put(TableProperties.DEFAULT_FILE_FORMAT, "orc"); -// } else if (provider != null && !"iceberg".equalsIgnoreCase(provider)) { -// throw new IllegalArgumentException("Unsupported format in USING: " + provider); -// } -// -// return tableProperties.build(); -// } -// -// /** -// * Applies a list of Spark table changes to an {@link UpdateProperties} operation. -// * -// * @param pendingUpdate an uncommitted UpdateProperties operation to configure -// * @param changes a list of Spark table changes -// * @return the UpdateProperties operation configured with the changes -// */ -// public static UpdateProperties applyPropertyChanges(UpdateProperties pendingUpdate, List changes) { -// for (TableChange change : changes) { -// if (change instanceof TableChange.SetProperty) { -// TableChange.SetProperty set = (TableChange.SetProperty) change; -// pendingUpdate.set(set.property(), set.value()); -// -// } else if (change instanceof TableChange.RemoveProperty) { -// TableChange.RemoveProperty remove = (TableChange.RemoveProperty) change; -// pendingUpdate.remove(remove.property()); -// -// } else { -// throw new UnsupportedOperationException("Cannot apply unknown table change: " + change); -// } -// } -// -// return pendingUpdate; -// } -// -// /** -// * Applies a list of Spark table changes to an {@link UpdateSchema} operation. -// * -// * @param pendingUpdate an uncommitted UpdateSchema operation to configure -// * @param changes a list of Spark table changes -// * @return the UpdateSchema operation configured with the changes -// */ -// public static UpdateSchema applySchemaChanges(UpdateSchema pendingUpdate, List changes) { -// for (TableChange change : changes) { -// if (change instanceof TableChange.AddColumn) { -// apply(pendingUpdate, (TableChange.AddColumn) change); -// -// } else if (change instanceof TableChange.UpdateColumnType) { -// TableChange.UpdateColumnType update = (TableChange.UpdateColumnType) change; -// Type newType = SparkSchemaUtil.convert(update.newDataType()); -// Preconditions.checkArgument(newType.isPrimitiveType(), -// "Cannot update '%s', not a primitive type: %s", DOT.join(update.fieldNames()), update.newDataType()); -// pendingUpdate.updateColumn(DOT.join(update.fieldNames()), newType.asPrimitiveType()); -// -// } else if (change instanceof TableChange.UpdateColumnComment) { -// TableChange.UpdateColumnComment update = (TableChange.UpdateColumnComment) change; -// pendingUpdate.updateColumnDoc(DOT.join(update.fieldNames()), update.newComment()); -// -// } else if (change instanceof TableChange.RenameColumn) { -// TableChange.RenameColumn rename = (TableChange.RenameColumn) change; -// pendingUpdate.renameColumn(DOT.join(rename.fieldNames()), rename.newName()); -// -// } else if (change instanceof TableChange.DeleteColumn) { -// TableChange.DeleteColumn delete = (TableChange.DeleteColumn) change; -// pendingUpdate.deleteColumn(DOT.join(delete.fieldNames())); -// -// } else if (change instanceof TableChange.UpdateColumnNullability) { -// TableChange.UpdateColumnNullability update = (TableChange.UpdateColumnNullability) change; -// if (update.nullable()) { -// 
pendingUpdate.makeColumnOptional(DOT.join(update.fieldNames())); -// } else { -// pendingUpdate.requireColumn(DOT.join(update.fieldNames())); -// } -// -// } else if (change instanceof TableChange.UpdateColumnPosition) { -// apply(pendingUpdate, (TableChange.UpdateColumnPosition) change); -// -// } else { -// throw new UnsupportedOperationException("Cannot apply unknown table change: " + change); -// } -// } -// -// return pendingUpdate; -// } -// -// private static void apply(UpdateSchema pendingUpdate, TableChange.UpdateColumnPosition update) { -// Preconditions.checkArgument(update.position() != null, "Invalid position: null"); -// -// if (update.position() instanceof TableChange.After) { -// TableChange.After after = (TableChange.After) update.position(); -// String referenceField = peerName(update.fieldNames(), after.column()); -// pendingUpdate.moveAfter(DOT.join(update.fieldNames()), referenceField); -// -// } else if (update.position() instanceof TableChange.First) { -// pendingUpdate.moveFirst(DOT.join(update.fieldNames())); -// -// } else { -// throw new IllegalArgumentException("Unknown position for reorder: " + update.position()); -// } -// } -// -// private static void apply(UpdateSchema pendingUpdate, TableChange.AddColumn add) { -// Preconditions.checkArgument(add.isNullable(), -// "Incompatible change: cannot add required column: %s", leafName(add.fieldNames())); -// Type type = SparkSchemaUtil.convert(add.dataType()); -// pendingUpdate.addColumn(parentName(add.fieldNames()), leafName(add.fieldNames()), type, add.comment()); -// -// if (add.position() instanceof TableChange.After) { -// TableChange.After after = (TableChange.After) add.position(); -// String referenceField = peerName(add.fieldNames(), after.column()); -// pendingUpdate.moveAfter(DOT.join(add.fieldNames()), referenceField); -// -// } else if (add.position() instanceof TableChange.First) { -// pendingUpdate.moveFirst(DOT.join(add.fieldNames())); -// -// } else { -// Preconditions.checkArgument(add.position() == null, -// "Cannot add '%s' at unknown position: %s", DOT.join(add.fieldNames()), add.position()); -// } -// } -// -// public static org.apache.iceberg.Table toIcebergTable(Table table) { -// Preconditions.checkArgument(table instanceof SparkTable, "Table %s is not an Iceberg table", table); -// SparkTable sparkTable = (SparkTable) table; -// return sparkTable.table(); -// } -// -// /** -// * Converts a PartitionSpec to Spark transforms. 
-// * -// * @param spec a PartitionSpec -// * @return an array of Transforms -// */ -// public static Transform[] toTransforms(PartitionSpec spec) { -// Map quotedNameById = SparkSchemaUtil.indexQuotedNameById(spec.schema()); -// List transforms = PartitionSpecVisitor.visit(spec, -// new PartitionSpecVisitor() { -// @Override -// public Transform identity(String sourceName, int sourceId) { -// return Expressions.identity(quotedName(sourceId)); -// } -// -// @Override -// public Transform bucket(String sourceName, int sourceId, int numBuckets) { -// return Expressions.bucket(numBuckets, quotedName(sourceId)); -// } -// -// @Override -// public Transform truncate(String sourceName, int sourceId, int width) { -// return Expressions.apply("truncate", Expressions.column(quotedName(sourceId)), Expressions.literal(width)); -// } -// -// @Override -// public Transform year(String sourceName, int sourceId) { -// return Expressions.years(quotedName(sourceId)); -// } -// -// @Override -// public Transform month(String sourceName, int sourceId) { -// return Expressions.months(quotedName(sourceId)); -// } -// -// @Override -// public Transform day(String sourceName, int sourceId) { -// return Expressions.days(quotedName(sourceId)); -// } -// -// @Override -// public Transform hour(String sourceName, int sourceId) { -// return Expressions.hours(quotedName(sourceId)); -// } -// -// @Override -// public Transform alwaysNull(int fieldId, String sourceName, int sourceId) { -// // do nothing for alwaysNull, it doesn't need to be converted to a transform -// return null; -// } -// -// @Override -// public Transform unknown(int fieldId, String sourceName, int sourceId, String transform) { -// return Expressions.apply(transform, Expressions.column(quotedName(sourceId))); -// } -// -// private String quotedName(int id) { -// return quotedNameById.get(id); -// } -// }); -// -// return transforms.stream().filter(Objects::nonNull).toArray(Transform[]::new); -// } - public static NamedReference toNamedReference(String name) { return Expressions.column(name); } -// public static Term toIcebergTerm(Expression expr) { -// if (expr instanceof Transform) { -// Transform transform = (Transform) expr; -// Preconditions.checkArgument(transform.references().length == 1, -// "Cannot convert transform with more than one column reference: %s", transform); -// String colName = DOT.join(transform.references()[0].fieldNames()); -// switch (transform.name()) { -// case "identity": -// return org.apache.iceberg.expressions.Expressions.ref(colName); -// case "bucket": -// return org.apache.iceberg.expressions.Expressions.bucket(colName, findWidth(transform)); -// case "years": -// return org.apache.iceberg.expressions.Expressions.year(colName); -// case "months": -// return org.apache.iceberg.expressions.Expressions.month(colName); -// case "date": -// case "days": -// return org.apache.iceberg.expressions.Expressions.day(colName); -// case "date_hour": -// case "hours": -// return org.apache.iceberg.expressions.Expressions.hour(colName); -// case "truncate": -// return org.apache.iceberg.expressions.Expressions.truncate(colName, findWidth(transform)); -// default: -// throw new UnsupportedOperationException("Transform is not supported: " + transform); -// } -// -// } else if (expr instanceof NamedReference) { -// NamedReference ref = (NamedReference) expr; -// return org.apache.iceberg.expressions.Expressions.ref(DOT.join(ref.fieldNames())); -// -// } else { -// throw new UnsupportedOperationException("Cannot convert unknown 
expression: " + expr); -// } -// } -// -// /** -// * Converts Spark transforms into a {@link PartitionSpec}. -// * -// * @param schema the table schema -// * @param partitioning Spark Transforms -// * @return a PartitionSpec -// */ -// public static PartitionSpec toPartitionSpec(Schema schema, Transform[] partitioning) { -// if (partitioning == null || partitioning.length == 0) { -// return PartitionSpec.unpartitioned(); -// } -// -// PartitionSpec.Builder builder = PartitionSpec.builderFor(schema); -// for (Transform transform : partitioning) { -// Preconditions.checkArgument(transform.references().length == 1, -// "Cannot convert transform with more than one column reference: %s", transform); -// String colName = DOT.join(transform.references()[0].fieldNames()); -// switch (transform.name()) { -// case "identity": -// builder.identity(colName); -// break; -// case "bucket": -// builder.bucket(colName, findWidth(transform)); -// break; -// case "years": -// builder.year(colName); -// break; -// case "months": -// builder.month(colName); -// break; -// case "date": -// case "days": -// builder.day(colName); -// break; -// case "date_hour": -// case "hours": -// builder.hour(colName); -// break; -// case "truncate": -// builder.truncate(colName, findWidth(transform)); -// break; -// default: -// throw new UnsupportedOperationException("Transform is not supported: " + transform); -// } -// } -// -// return builder.build(); -// } -// -// @SuppressWarnings("unchecked") -// private static int findWidth(Transform transform) { -// for (Expression expr : transform.arguments()) { -// if (expr instanceof Literal) { -// if (((Literal) expr).dataType() instanceof IntegerType) { -// Literal lit = (Literal) expr; -// Preconditions.checkArgument(lit.value() > 0, -// "Unsupported width for transform: %s", transform.describe()); -// return lit.value(); -// -// } else if (((Literal) expr).dataType() instanceof LongType) { -// Literal lit = (Literal) expr; -// Preconditions.checkArgument(lit.value() > 0 && lit.value() < Integer.MAX_VALUE, -// "Unsupported width for transform: %s", transform.describe()); -// if (lit.value() > Integer.MAX_VALUE) { -// throw new IllegalArgumentException(); -// } -// return lit.value().intValue(); -// } -// } -// } -// -// throw new IllegalArgumentException("Cannot find width for transform: " + transform.describe()); -// } -// -// private static String leafName(String[] fieldNames) { -// Preconditions.checkArgument(fieldNames.length > 0, "Invalid field name: at least one name is required"); -// return fieldNames[fieldNames.length - 1]; -// } -// -// private static String peerName(String[] fieldNames, String fieldName) { -// if (fieldNames.length > 1) { -// String[] peerNames = Arrays.copyOf(fieldNames, fieldNames.length); -// peerNames[fieldNames.length - 1] = fieldName; -// return DOT.join(peerNames); -// } -// return fieldName; -// } -// -// private static String parentName(String[] fieldNames) { -// if (fieldNames.length > 1) { -// return DOT.join(Arrays.copyOfRange(fieldNames, 0, fieldNames.length - 1)); -// } -// return null; -// } - public static String describe(org.apache.iceberg.expressions.Expression expr) { return ExpressionVisitors.visit(expr, DescribeExpressionVisitor.INSTANCE); } -// public static String describe(Schema schema) { -// return TypeUtil.visit(schema, DescribeSchemaVisitor.INSTANCE); -// } -// -// public static String describe(Type type) { -// return TypeUtil.visit(type, DescribeSchemaVisitor.INSTANCE); -// } -// -// public static String 
describe(org.apache.iceberg.SortOrder order) { -// return Joiner.on(", ").join(SortOrderVisitor.visit(order, DescribeSortOrderVisitor.INSTANCE)); -// } -// -// public static boolean extensionsEnabled(SparkSession spark) { -// String extensions = spark.conf().get("spark.sql.extensions", ""); -// return extensions.contains("IcebergSparkSessionExtensions"); -// } -// -// public static class DescribeSchemaVisitor extends TypeUtil.SchemaVisitor { -// private static final Joiner COMMA = Joiner.on(','); -// private static final DescribeSchemaVisitor INSTANCE = new DescribeSchemaVisitor(); -// -// private DescribeSchemaVisitor() { -// } -// -// @Override -// public String schema(Schema schema, String structResult) { -// return structResult; -// } -// -// @Override -// public String struct(Types.StructType struct, List fieldResults) { -// return "struct<" + COMMA.join(fieldResults) + ">"; -// } -// -// @Override -// public String field(Types.NestedField field, String fieldResult) { -// return field.name() + ": " + fieldResult + (field.isRequired() ? " not null" : ""); -// } -// -// @Override -// public String list(Types.ListType list, String elementResult) { -// return "list<" + elementResult + ">"; -// } -// -// @Override -// public String map(Types.MapType map, String keyResult, String valueResult) { -// return "map<" + keyResult + ", " + valueResult + ">"; -// } -// -// @Override -// public String primitive(Type.PrimitiveType primitive) { -// switch (primitive.typeId()) { -// case BOOLEAN: -// return "boolean"; -// case INTEGER: -// return "int"; -// case LONG: -// return "bigint"; -// case FLOAT: -// return "float"; -// case DOUBLE: -// return "double"; -// case DATE: -// return "date"; -// case TIME: -// return "time"; -// case TIMESTAMP: -// return "timestamp"; -// case STRING: -// case UUID: -// return "string"; -// case FIXED: -// case BINARY: -// return "binary"; -// case DECIMAL: -// Types.DecimalType decimal = (Types.DecimalType) primitive; -// return "decimal(" + decimal.precision() + "," + decimal.scale() + ")"; -// } -// throw new UnsupportedOperationException("Cannot convert type to SQL: " + primitive); -// } -// } - private static class DescribeExpressionVisitor extends ExpressionVisitors.ExpressionVisitor { private static final DescribeExpressionVisitor INSTANCE = new DescribeExpressionVisitor(); @@ -555,267 +130,4 @@ private static String sqlString(org.apache.iceberg.expressions.Literal lit) { } } } - -// /** -// * Returns a metadata table as a Dataset based on the given Iceberg table. -// * -// * @param spark SparkSession where the Dataset will be created -// * @param table an Iceberg table -// * @param type the type of metadata table -// * @return a Dataset that will read the metadata table -// */ -// private static Dataset loadMetadataTable(SparkSession spark, org.apache.iceberg.Table table, -// MetadataTableType type) { -// Table metadataTable = new SparkTable(MetadataTableUtils.createMetadataTableInstance(table, type), false); -// return Dataset.ofRows(spark, DataSourceV2Relation.create(metadataTable, Some.empty(), Some.empty())); -// } -// -// /** -// * Returns an Iceberg Table by its name from a Spark V2 Catalog. If cache is enabled in {@link SparkCatalog}, -// * the {@link TableOperations} of the table may be stale, please refresh the table to get the latest one. 
-// * -// * @param spark SparkSession used for looking up catalog references and tables -// * @param name The multipart identifier of the Iceberg table -// * @return an Iceberg table -// */ -// public static org.apache.iceberg.Table loadIcebergTable(SparkSession spark, String name) -// throws ParseException, NoSuchTableException { -// CatalogAndIdentifier catalogAndIdentifier = catalogAndIdentifier(spark, name); -// -// TableCatalog catalog = asTableCatalog(catalogAndIdentifier.catalog); -// Table sparkTable = catalog.loadTable(catalogAndIdentifier.identifier); -// return toIcebergTable(sparkTable); -// } -// -// public static CatalogAndIdentifier catalogAndIdentifier(SparkSession spark, String name) throws ParseException { -// return catalogAndIdentifier(spark, name, spark.sessionState().catalogManager().currentCatalog()); -// } -// -// public static CatalogAndIdentifier catalogAndIdentifier(SparkSession spark, String name, -// CatalogPlugin defaultCatalog) throws ParseException { -// ParserInterface parser = spark.sessionState().sqlParser(); -// Seq multiPartIdentifier = parser.parseMultipartIdentifier(name).toIndexedSeq(); -// List javaMultiPartIdentifier = JavaConverters.seqAsJavaList(multiPartIdentifier); -// return catalogAndIdentifier(spark, javaMultiPartIdentifier, defaultCatalog); -// } -// -// public static CatalogAndIdentifier catalogAndIdentifier(String description, SparkSession spark, String name) { -// return catalogAndIdentifier(description, spark, name, spark.sessionState().catalogManager().currentCatalog()); -// } -// -// public static CatalogAndIdentifier catalogAndIdentifier(String description, SparkSession spark, -// String name, CatalogPlugin defaultCatalog) { -// try { -// return catalogAndIdentifier(spark, name, defaultCatalog); -// } catch (ParseException e) { -// throw new IllegalArgumentException("Cannot parse " + description + ": " + name, e); -// } -// } -// -// public static CatalogAndIdentifier catalogAndIdentifier(SparkSession spark, List nameParts) { -// return catalogAndIdentifier(spark, nameParts, spark.sessionState().catalogManager().currentCatalog()); -// } -// -// /** -// * A modified version of Spark's LookupCatalog.CatalogAndIdentifier.unapply -// * Attempts to find the catalog and identifier a multipart identifier represents -// * @param spark Spark session to use for resolution -// * @param nameParts Multipart identifier representing a table -// * @param defaultCatalog Catalog to use if none is specified -// * @return The CatalogPlugin and Identifier for the table -// */ -// public static CatalogAndIdentifier catalogAndIdentifier(SparkSession spark, List nameParts, -// CatalogPlugin defaultCatalog) { -// CatalogManager catalogManager = spark.sessionState().catalogManager(); -// -// String[] currentNamespace; -// if (defaultCatalog.equals(catalogManager.currentCatalog())) { -// currentNamespace = catalogManager.currentNamespace(); -// } else { -// currentNamespace = defaultCatalog.defaultNamespace(); -// } -// -// Pair catalogIdentifier = SparkUtil.catalogAndIdentifier(nameParts, -// catalogName -> { -// try { -// return catalogManager.catalog(catalogName); -// } catch (Exception e) { -// return null; -// } -// }, -// Identifier::of, -// defaultCatalog, -// currentNamespace -// ); -// return new CatalogAndIdentifier(catalogIdentifier); -// } -// -// private static TableCatalog asTableCatalog(CatalogPlugin catalog) { -// if (catalog instanceof TableCatalog) { -// return (TableCatalog) catalog; -// } -// -// throw new 
IllegalArgumentException(String.format( -// "Cannot use catalog %s(%s): not a TableCatalog", catalog.name(), catalog.getClass().getName())); -// } -// -// /** -// * This mimics a class inside of Spark which is private inside of LookupCatalog. -// */ -// public static class CatalogAndIdentifier { -// private final CatalogPlugin catalog; -// private final Identifier identifier; -// -// -// public CatalogAndIdentifier(CatalogPlugin catalog, Identifier identifier) { -// this.catalog = catalog; -// this.identifier = identifier; -// } -// -// public CatalogAndIdentifier(Pair identifier) { -// this.catalog = identifier.first(); -// this.identifier = identifier.second(); -// } -// -// public CatalogPlugin catalog() { -// return catalog; -// } -// -// public Identifier identifier() { -// return identifier; -// } -// } -// -// public static TableIdentifier identifierToTableIdentifier(Identifier identifier) { -// return TableIdentifier.of(Namespace.of(identifier.namespace()), identifier.name()); -// } -// -// /** -// * Use Spark to list all partitions in the table. -// * -// * @param spark a Spark session -// * @param rootPath a table identifier -// * @param format format of the file -// * @param partitionFilter partitionFilter of the file -// * @return all table's partitions -// */ -// public static List getPartitions(SparkSession spark, Path rootPath, String format, -// Map partitionFilter) { -// FileStatusCache fileStatusCache = FileStatusCache.getOrCreate(spark); -// -// InMemoryFileIndex fileIndex = new InMemoryFileIndex( -// spark, -// JavaConverters -// .collectionAsScalaIterableConverter(ImmutableList.of(rootPath)) -// .asScala() -// .toSeq(), -// scala.collection.immutable.Map$.MODULE$.empty(), -// Option.empty(), -// fileStatusCache, -// Option.empty(), -// Option.empty()); -// -// org.apache.spark.sql.execution.datasources.PartitionSpec spec = fileIndex.partitionSpec(); -// StructType schema = spec.partitionColumns(); -// if (schema.isEmpty()) { -// return Lists.newArrayList(); -// } -// -// List filterExpressions = -// SparkUtil.partitionMapToExpression(schema, partitionFilter); -// Seq scalaPartitionFilters = -// JavaConverters.asScalaBufferConverter(filterExpressions).asScala().toIndexedSeq(); -// -// List dataFilters = Lists.newArrayList(); -// Seq scalaDataFilters = -// JavaConverters.asScalaBufferConverter(dataFilters).asScala().toIndexedSeq(); -// -// Seq filteredPartitions = -// fileIndex.listFiles(scalaPartitionFilters, scalaDataFilters).toIndexedSeq(); -// -// return JavaConverters -// .seqAsJavaListConverter(filteredPartitions) -// .asJava() -// .stream() -// .map(partition -> { -// Map values = Maps.newHashMap(); -// JavaConverters.asJavaIterableConverter(schema).asJava().forEach(field -> { -// int fieldIndex = schema.fieldIndex(field.name()); -// Object catalystValue = partition.values().get(fieldIndex, field.dataType()); -// Object value = CatalystTypeConverters.convertToScala(catalystValue, field.dataType()); -// values.put(field.name(), String.valueOf(value)); -// }); -// -// FileStatus fileStatus = -// JavaConverters.seqAsJavaListConverter(partition.files()).asJava().get(0); -// -// return new SparkPartition(values, fileStatus.getPath().getParent().toString(), format); -// }).collect(Collectors.toList()); -// } -// -// public static org.apache.spark.sql.catalyst.TableIdentifier toV1TableIdentifier(Identifier identifier) { -// String[] namespace = identifier.namespace(); -// -// Preconditions.checkArgument(namespace.length <= 1, -// "Cannot convert %s to a Spark v1 
identifier, namespace contains more than 1 part", identifier); -// -// String table = identifier.name(); -// Option database = namespace.length == 1 ? Option.apply(namespace[0]) : Option.empty(); -// return org.apache.spark.sql.catalyst.TableIdentifier.apply(table, database); -// } -// -// private static class DescribeSortOrderVisitor implements SortOrderVisitor { -// private static final DescribeSortOrderVisitor INSTANCE = new DescribeSortOrderVisitor(); -// -// private DescribeSortOrderVisitor() { -// } -// -// @Override -// public String field(String sourceName, int sourceId, -// org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { -// return String.format("%s %s %s", sourceName, direction, nullOrder); -// } -// -// @Override -// public String bucket(String sourceName, int sourceId, int numBuckets, -// org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { -// return String.format("bucket(%s, %s) %s %s", numBuckets, sourceName, direction, nullOrder); -// } -// -// @Override -// public String truncate(String sourceName, int sourceId, int width, -// org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { -// return String.format("truncate(%s, %s) %s %s", sourceName, width, direction, nullOrder); -// } -// -// @Override -// public String year(String sourceName, int sourceId, -// org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { -// return String.format("years(%s) %s %s", sourceName, direction, nullOrder); -// } -// -// @Override -// public String month(String sourceName, int sourceId, -// org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { -// return String.format("months(%s) %s %s", sourceName, direction, nullOrder); -// } -// -// @Override -// public String day(String sourceName, int sourceId, -// org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { -// return String.format("days(%s) %s %s", sourceName, direction, nullOrder); -// } -// -// @Override -// public String hour(String sourceName, int sourceId, -// org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { -// return String.format("hours(%s) %s %s", sourceName, direction, nullOrder); -// } -// -// @Override -// public String unknown(String sourceName, int sourceId, String transform, -// org.apache.iceberg.SortDirection direction, NullOrder nullOrder) { -// return String.format("%s(%s) %s %s", transform, sourceName, direction, nullOrder); -// } -// } } diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSchemaUtil.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSchemaUtil.java index fce3a1ccac6..123de4276a5 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSchemaUtil.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSchemaUtil.java @@ -43,44 +43,6 @@ public class SparkSchemaUtil { private SparkSchemaUtil() { } -// /** -// * Returns a {@link Schema} for the given table with fresh field ids. -// *
-// * This creates a Schema for an existing table by looking up the table's schema with Spark and
-// * converting that schema. Spark/Hive partition columns are included in the schema.
-// *
-// * @param spark a Spark session
-// * @param name a table name and (optional) database
-// * @return a Schema for the table, if found
-// */
-// public static Schema schemaForTable(SparkSession spark, String name) {
-// StructType sparkType = spark.table(name).schema();
-// Type converted = SparkTypeVisitor.visit(sparkType, new SparkTypeToType(sparkType));
-// return new Schema(converted.asNestedType().asStructType().fields());
-// }
-//
-// /**
-// * Returns a {@link PartitionSpec} for the given table.
-// *
-// * This creates a partition spec for an existing table by looking up the table's schema and -// * creating a spec with identity partitions for each partition column. -// * -// * @param spark a Spark session -// * @param name a table name and (optional) database -// * @return a PartitionSpec for the table -// * @throws AnalysisException if thrown by the Spark catalog -// */ -// public static PartitionSpec specForTable(SparkSession spark, String name) throws AnalysisException { -// List parts = Lists.newArrayList(Splitter.on('.').limit(2).split(name)); -// String db = parts.size() == 1 ? "default" : parts.get(0); -// String table = parts.get(parts.size() == 1 ? 0 : 1); -// -// PartitionSpec spec = identitySpec( -// schemaForTable(spark, name), -// spark.catalog().listColumns(db, table).collectAsList()); -// return spec == null ? PartitionSpec.unpartitioned() : spec; -// } - /** * Convert a {@link Schema} to a {@link DataType Spark type}. * @@ -130,200 +92,6 @@ public DataType struct(Types.StructType struct, List fieldResults) { }); } -// /** -// * Convert a Spark {@link StructType struct} to a {@link Schema} with new field ids. -// *
-// * This conversion assigns fresh ids.
-// *
-// * Some data types are represented as the same Spark type. These are converted to a default type.
-// *
-// * To convert using a reference schema for field ids and ambiguous types, use
-// * {@link #convert(Schema, StructType)}.
-// *
-// * @param sparkType a Spark StructType
-// * @return the equivalent Schema
-// * @throws IllegalArgumentException if the type cannot be converted
-// */
-// public static Schema convert(StructType sparkType) {
-// return convert(sparkType, false);
-// }
-//
-// /**
-// * Convert a Spark {@link StructType struct} to a {@link Schema} with new field ids.
-// *
-// * This conversion assigns fresh ids.
-// *
-// * Some data types are represented as the same Spark type. These are converted to a default type.
-// *
-// * To convert using a reference schema for field ids and ambiguous types, use
-// * {@link #convert(Schema, StructType)}.
-// *
-// * @param sparkType a Spark StructType
-// * @param useTimestampWithoutZone boolean flag indicates that timestamp should be stored without timezone
-// * @return the equivalent Schema
-// * @throws IllegalArgumentException if the type cannot be converted
-// */
-// public static Schema convert(StructType sparkType, boolean useTimestampWithoutZone) {
-// Type converted = SparkTypeVisitor.visit(sparkType, new SparkTypeToType(sparkType));
-// Schema schema = new Schema(converted.asNestedType().asStructType().fields());
-// if (useTimestampWithoutZone) {
-// schema = SparkFixupTimestampType.fixup(schema);
-// }
-// return schema;
-// }
-//
-// /**
-// * Convert a Spark {@link DataType struct} to a {@link Type} with new field ids.
-// *
-// * This conversion assigns fresh ids.
-// *
-// * Some data types are represented as the same Spark type. These are converted to a default type.
-// *
-// * To convert using a reference schema for field ids and ambiguous types, use
-// * {@link #convert(Schema, StructType)}.
-// *
-// * @param sparkType a Spark DataType
-// * @return the equivalent Type
-// * @throws IllegalArgumentException if the type cannot be converted
-// */
-// public static Type convert(DataType sparkType) {
-// return SparkTypeVisitor.visit(sparkType, new SparkTypeToType());
-// }
-//
-// /**
-// * Convert a Spark {@link StructType struct} to a {@link Schema} based on the given schema.
-// *
-// * This conversion does not assign new ids; it uses ids from the base schema.
-// *
-// * Data types, field order, and nullability will match the spark type. This conversion may return
-// * a schema that is not compatible with base schema.
-// *
-// * @param baseSchema a Schema on which conversion is based
-// * @param sparkType a Spark StructType
-// * @return the equivalent Schema
-// * @throws IllegalArgumentException if the type cannot be converted or there are missing ids
-// */
-// public static Schema convert(Schema baseSchema, StructType sparkType) {
-// // convert to a type with fresh ids
-// Types.StructType struct = SparkTypeVisitor.visit(sparkType, new SparkTypeToType(sparkType)).asStructType();
-// // reassign ids to match the base schema
-// Schema schema = TypeUtil.reassignIds(new Schema(struct.fields()), baseSchema);
-// // fix types that can't be represented in Spark (UUID and Fixed)
-// return SparkFixupTypes.fixup(schema, baseSchema);
-// }
-//
-// /**
-// * Convert a Spark {@link StructType struct} to a {@link Schema} based on the given schema.
-// *
-// * This conversion will assign new ids for fields that are not found in the base schema.
-// *
-// * Data types, field order, and nullability will match the spark type. This conversion may return
-// * a schema that is not compatible with base schema.
-// *
-// * @param baseSchema a Schema on which conversion is based
-// * @param sparkType a Spark StructType
-// * @return the equivalent Schema
-// * @throws IllegalArgumentException if the type cannot be converted or there are missing ids
-// */
-// public static Schema convertWithFreshIds(Schema baseSchema, StructType sparkType) {
-// // convert to a type with fresh ids
-// Types.StructType struct = SparkTypeVisitor.visit(sparkType, new SparkTypeToType(sparkType)).asStructType();
-// // reassign ids to match the base schema
-// Schema schema = TypeUtil.reassignOrRefreshIds(new Schema(struct.fields()), baseSchema);
-// // fix types that can't be represented in Spark (UUID and Fixed)
-// return SparkFixupTypes.fixup(schema, baseSchema);
-// }
-//
-// /**
-// * Prune columns from a {@link Schema} using a {@link StructType Spark type} projection.
-// *
-// * This requires that the Spark type is a projection of the Schema. Nullability and types must
-// * match.
-// *
-// * @param schema a Schema
-// * @param requestedType a projection of the Spark representation of the Schema
-// * @return a Schema corresponding to the Spark projection
-// * @throws IllegalArgumentException if the Spark type does not match the Schema
-// */
-// public static Schema prune(Schema schema, StructType requestedType) {
-// return new Schema(TypeUtil.visit(schema, new PruneColumnsWithoutReordering(requestedType, ImmutableSet.of()))
-// .asNestedType()
-// .asStructType()
-// .fields());
-// }
-//
-// /**
-// * Prune columns from a {@link Schema} using a {@link StructType Spark type} projection.
-// *
-// * This requires that the Spark type is a projection of the Schema. Nullability and types must
-// * match.
-// *
-// * The filters list of {@link Expression} is used to ensure that columns referenced by filters
-// * are projected.
-// *
-// * @param schema a Schema
-// * @param requestedType a projection of the Spark representation of the Schema
-// * @param filters a list of filters
-// * @return a Schema corresponding to the Spark projection
-// * @throws IllegalArgumentException if the Spark type does not match the Schema
-// */
-// public static Schema prune(Schema schema, StructType requestedType, List filters) {
-// Set filterRefs = Binder.boundReferences(schema.asStruct(), filters, true);
-// return new Schema(TypeUtil.visit(schema, new PruneColumnsWithoutReordering(requestedType, filterRefs))
-// .asNestedType()
-// .asStructType()
-// .fields());
-// }
-//
-// /**
-// * Prune columns from a {@link Schema} using a {@link StructType Spark type} projection.
-// *
-// * This requires that the Spark type is a projection of the Schema. Nullability and types must
-// * match.
-// *
-// * The filters list of {@link Expression} is used to ensure that columns referenced by filters -// * are projected. -// * -// * @param schema a Schema -// * @param requestedType a projection of the Spark representation of the Schema -// * @param filter a filters -// * @return a Schema corresponding to the Spark projection -// * @throws IllegalArgumentException if the Spark type does not match the Schema -// */ -// public static Schema prune(Schema schema, StructType requestedType, Expression filter, boolean caseSensitive) { -// Set filterRefs = -// Binder.boundReferences(schema.asStruct(), Collections.singletonList(filter), caseSensitive); -// -// return new Schema(TypeUtil.visit(schema, new PruneColumnsWithoutReordering(requestedType, filterRefs)) -// .asNestedType() -// .asStructType() -// .fields()); -// } -// -// private static PartitionSpec identitySpec(Schema schema, Collection columns) { -// List names = Lists.newArrayList(); -// for (Column column : columns) { -// if (column.isPartition()) { -// names.add(column.name()); -// } -// } -// -// return identitySpec(schema, names); -// } -// -// private static PartitionSpec identitySpec(Schema schema, List partitionNames) { -// if (partitionNames == null || partitionNames.isEmpty()) { -// return null; -// } -// -// PartitionSpec.Builder builder = PartitionSpec.builderFor(schema); -// for (String partitionName : partitionNames) { -// builder.identity(partitionName); -// } -// -// return builder.build(); -// } - /** * Estimate approximate table size based on Spark schema and total records. * diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkUtil.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkUtil.java index 04c1b76428b..1a558d3b955 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkUtil.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkUtil.java @@ -29,77 +29,9 @@ public class SparkUtil { " https://docs.databricks.com/spark/latest/dataframes-datasets/dates-timestamps.html#ansi-sql-and" + "-spark-sql-timestamps", SparkSQLProperties.HANDLE_TIMESTAMP_WITHOUT_TIMEZONE); -// private static final String SPARK_CATALOG_CONF_PREFIX = "spark.sql.catalog"; -// // Format string used as the prefix for spark configuration keys to override hadoop configuration values -// // for Iceberg tables from a given catalog. These keys can be specified as `spark.sql.catalog.$catalogName.hadoop.*`, -// // similar to using `spark.hadoop.*` to override hadoop configurations globally for a given spark session. -// private static final String SPARK_CATALOG_HADOOP_CONF_OVERRIDE_FMT_STR = SPARK_CATALOG_CONF_PREFIX + ".%s.hadoop."; - private SparkUtil() { } -// public static FileIO serializableFileIO(Table table) { -// if (table.io() instanceof HadoopConfigurable) { -// // we need to use Spark's SerializableConfiguration to avoid issues with Kryo serialization -// ((HadoopConfigurable) table.io()).serializeConfWith(conf -> new SerializableConfiguration(conf)::value); -// } -// -// return table.io(); -// } -// -// /** -// * Check whether the partition transforms in a spec can be used to write data. 
-// * -// * @param spec a PartitionSpec -// * @throws UnsupportedOperationException if the spec contains unknown partition transforms -// */ -// public static void validatePartitionTransforms(PartitionSpec spec) { -// if (spec.fields().stream().anyMatch(field -> field.transform() instanceof UnknownTransform)) { -// String unsupported = spec.fields().stream() -// .map(PartitionField::transform) -// .filter(transform -> transform instanceof UnknownTransform) -// .map(Transform::toString) -// .collect(Collectors.joining(", ")); -// -// throw new UnsupportedOperationException( -// String.format("Cannot write using unsupported transforms: %s", unsupported)); -// } -// } -// -// /** -// * A modified version of Spark's LookupCatalog.CatalogAndIdentifier.unapply -// * Attempts to find the catalog and identifier a multipart identifier represents -// * @param nameParts Multipart identifier representing a table -// * @return The CatalogPlugin and Identifier for the table -// */ -// public static Pair catalogAndIdentifier(List nameParts, -// Function catalogProvider, -// BiFunction identiferProvider, -// C currentCatalog, -// String[] currentNamespace) { -// Preconditions.checkArgument(!nameParts.isEmpty(), -// "Cannot determine catalog and identifier from empty name"); -// -// int lastElementIndex = nameParts.size() - 1; -// String name = nameParts.get(lastElementIndex); -// -// if (nameParts.size() == 1) { -// // Only a single element, use current catalog and namespace -// return Pair.of(currentCatalog, identiferProvider.apply(currentNamespace, name)); -// } else { -// C catalog = catalogProvider.apply(nameParts.get(0)); -// if (catalog == null) { -// // The first element was not a valid catalog, treat it like part of the namespace -// String[] namespace = nameParts.subList(0, lastElementIndex).toArray(new String[0]); -// return Pair.of(currentCatalog, identiferProvider.apply(namespace, name)); -// } else { -// // Assume the first element is a valid catalog -// String[] namespace = nameParts.subList(1, lastElementIndex).toArray(new String[0]); -// return Pair.of(catalog, identiferProvider.apply(namespace, name)); -// } -// } -// } - /** * Responsible for checking if the table schema has a timestamp without timezone column * @param schema table schema to check if it contains a timestamp without timezone column @@ -108,114 +40,4 @@ private SparkUtil() { public static boolean hasTimestampWithoutZone(Schema schema) { return TypeUtil.find(schema, t -> Types.TimestampType.withoutZone().equals(t)) != null; } - -// /** -// * Checks whether timestamp types for new tables should be stored with timezone info. -// *
-// * The default value is false and all timestamp fields are stored as {@link Types.TimestampType#withZone()}. -// * If enabled, all timestamp fields in new tables will be stored as {@link Types.TimestampType#withoutZone()}. -// * -// * @param sessionConf a Spark runtime config -// * @return true if timestamp types for new tables should be stored with timezone info -// */ -// public static boolean useTimestampWithoutZoneInNewTables(RuntimeConfig sessionConf) { -// String sessionConfValue = sessionConf.get(SparkSQLProperties.USE_TIMESTAMP_WITHOUT_TIME_ZONE_IN_NEW_TABLES, null); -// if (sessionConfValue != null) { -// return Boolean.parseBoolean(sessionConfValue); -// } -// return SparkSQLProperties.USE_TIMESTAMP_WITHOUT_TIME_ZONE_IN_NEW_TABLES_DEFAULT; -// } -// -// /** -// * Pulls any Catalog specific overrides for the Hadoop conf from the current SparkSession, which can be -// * set via `spark.sql.catalog.$catalogName.hadoop.*` -// * -// * Mirrors the override of hadoop configurations for a given spark session using `spark.hadoop.*`. -// * -// * The SparkCatalog allows for hadoop configurations to be overridden per catalog, by setting -// * them on the SQLConf, where the following will add the property "fs.default.name" with value -// * "hdfs://hanksnamenode:8020" to the catalog's hadoop configuration. -// * SparkSession.builder() -// * .config(s"spark.sql.catalog.$catalogName.hadoop.fs.default.name", "hdfs://hanksnamenode:8020") -// * .getOrCreate() -// * @param spark The current Spark session -// * @param catalogName Name of the catalog to find overrides for. -// * @return the Hadoop Configuration that should be used for this catalog, with catalog specific overrides applied. -// */ -// public static Configuration hadoopConfCatalogOverrides(SparkSession spark, String catalogName) { -// // Find keys for the catalog intended to be hadoop configurations -// final String hadoopConfCatalogPrefix = hadoopConfPrefixForCatalog(catalogName); -// final Configuration conf = spark.sessionState().newHadoopConf(); -// spark.sqlContext().conf().settings().forEach((k, v) -> { -// // These checks are copied from `spark.sessionState().newHadoopConfWithOptions()`, which we -// // avoid using to not have to convert back and forth between scala / java map types. -// if (v != null && k != null && k.startsWith(hadoopConfCatalogPrefix)) { -// conf.set(k.substring(hadoopConfCatalogPrefix.length()), v); -// } -// }); -// return conf; -// } -// -// private static String hadoopConfPrefixForCatalog(String catalogName) { -// return String.format(SPARK_CATALOG_HADOOP_CONF_OVERRIDE_FMT_STR, catalogName); -// } -// -// /** -// * Get a List of Spark filter Expression. -// * -// * @param schema table schema -// * @param filters filters in the format of a Map, where key is one of the table column name, -// * and value is the specific value to be filtered on the column. -// * @return a List of filters in the format of Spark Expression. 
-// */ -// public static List partitionMapToExpression(StructType schema, -// Map filters) { -// List filterExpressions = Lists.newArrayList(); -// for (Map.Entry entry : filters.entrySet()) { -// try { -// int index = schema.fieldIndex(entry.getKey()); -// DataType dataType = schema.fields()[index].dataType(); -// BoundReference ref = new BoundReference(index, dataType, true); -// switch (dataType.typeName()) { -// case "integer": -// filterExpressions.add(new EqualTo(ref, -// Literal.create(Integer.parseInt(entry.getValue()), DataTypes.IntegerType))); -// break; -// case "string": -// filterExpressions.add(new EqualTo(ref, Literal.create(entry.getValue(), DataTypes.StringType))); -// break; -// case "short": -// filterExpressions.add(new EqualTo(ref, -// Literal.create(Short.parseShort(entry.getValue()), DataTypes.ShortType))); -// break; -// case "long": -// filterExpressions.add(new EqualTo(ref, -// Literal.create(Long.parseLong(entry.getValue()), DataTypes.LongType))); -// break; -// case "float": -// filterExpressions.add(new EqualTo(ref, -// Literal.create(Float.parseFloat(entry.getValue()), DataTypes.FloatType))); -// break; -// case "double": -// filterExpressions.add(new EqualTo(ref, -// Literal.create(Double.parseDouble(entry.getValue()), DataTypes.DoubleType))); -// break; -// case "date": -// filterExpressions.add(new EqualTo(ref, -// Literal.create(new Date(DateTime.parse(entry.getValue()).getMillis()), DataTypes.DateType))); -// break; -// case "timestamp": -// filterExpressions.add(new EqualTo(ref, -// Literal.create(new Timestamp(DateTime.parse(entry.getValue()).getMillis()), DataTypes.TimestampType))); -// break; -// default: -// throw new IllegalStateException("Unexpected data type in partition filters: " + dataType); -// } -// } catch (IllegalArgumentException e) { -// // ignore if filter is not on table columns -// } -// } -// -// return filterExpressions; -// } } diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java index c6c8383442e..ec39e370c2c 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java @@ -40,7 +40,6 @@ public class SparkBatch implements Batch { private final JavaSparkContext sparkContext; private final Table table; - private final SparkReadConf readConf; private final List tasks; private final Schema expectedSchema; private final boolean caseSensitive; @@ -54,7 +53,6 @@ public class SparkBatch implements Batch { GpuSparkScan parentScan) { this.sparkContext = sparkContext; this.table = table; - this.readConf = readConf; this.tasks = tasks; this.expectedSchema = expectedSchema; this.caseSensitive = readConf.caseSensitive(); @@ -88,42 +86,6 @@ public PartitionReaderFactory createReaderFactory() { return new GpuSparkScan.ReaderFactory(parentScan.metrics()); } -// private int batchSize() { -// if (parquetOnly() && parquetBatchReadsEnabled()) { -// return readConf.parquetBatchSize(); -// } else if (orcOnly() && orcBatchReadsEnabled()) { -// return readConf.orcBatchSize(); -// } else { -// return 0; -// } -// } -// -// private boolean parquetOnly() { -// return tasks.stream().allMatch(task -> !task.isDataTask() && onlyFileFormat(task, FileFormat.PARQUET)); -// } -// -// private boolean parquetBatchReadsEnabled() { -// return readConf.parquetVectorizationEnabled() && // vectorization enabled -// 
expectedSchema.columns().size() > 0 && // at least one column is projected -// expectedSchema.columns().stream().allMatch(c -> c.type().isPrimitiveType()); // only primitives -// } -// -// private boolean orcOnly() { -// return tasks.stream().allMatch(task -> !task.isDataTask() && onlyFileFormat(task, FileFormat.ORC)); -// } -// -// private boolean orcBatchReadsEnabled() { -// return readConf.orcVectorizationEnabled() && // vectorization enabled -// tasks.stream().noneMatch(TableScanUtil::hasDeletes); // no delete files -// } -// -// private boolean onlyFileFormat(CombinedScanTask task, FileFormat fileFormat) { -// return task.files().stream().allMatch(fileScanTask -> fileScanTask.file().format().equals(fileFormat)); -// } - - // TODO: See if latest Iceberg code has the same issues with lacking equals/hashCode on batch - // causing exchange to not be reused - @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetUtils.scala index ddbfa9a1564..8808255dcab 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetUtils.scala @@ -18,15 +18,12 @@ package com.nvidia.spark.rapids import java.util.Locale -import scala.annotation.tailrec import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import org.apache.parquet.hadoop.metadata.{BlockMetaData, ColumnChunkMetaData, ColumnPath} import org.apache.parquet.schema.MessageType import org.apache.spark.internal.Logging -import org.apache.spark.sql.types.StructType object GpuParquetUtils extends Logging { /** @@ -86,40 +83,4 @@ object GpuParquetUtils extends Logging { block } - - def getBlocksInBatch( - blockIter: BufferedIterator[BlockMetaData], - readSchema: StructType, - maxBatchSizeRows: Int, - maxBatchSizeBytes: Long): Seq[BlockMetaData] = { - val currentChunk = new ArrayBuffer[BlockMetaData] - var numRows: Long = 0 - var numBytes: Long = 0 - var numParquetBytes: Long = 0 - - @tailrec - def readNextBatch(): Unit = { - if (blockIter.hasNext) { - val peekedRowGroup = blockIter.head - if (peekedRowGroup.getRowCount > Integer.MAX_VALUE) { - throw new UnsupportedOperationException("Too many rows in split") - } - if (numRows == 0 || numRows + peekedRowGroup.getRowCount <= maxBatchSizeRows) { - val estimatedBytes = GpuBatchUtils.estimateGpuMemory(readSchema, - peekedRowGroup.getRowCount) - if (numBytes == 0 || numBytes + estimatedBytes <= maxBatchSizeBytes) { - currentChunk += blockIter.next() - numRows += currentChunk.last.getRowCount - numParquetBytes += currentChunk.last.getTotalByteSize - numBytes += estimatedBytes - readNextBatch() - } - } - } - } - readNextBatch() - logDebug(s"Loaded $numRows rows from Parquet. Parquet bytes read: $numParquetBytes. 
" + - s"Estimated GPU bytes: $numBytes") - currentChunk - } } From 352eb927c31c609c1da070f82afc5d84ebddc56d Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Mon, 27 Jun 2022 10:42:50 -0500 Subject: [PATCH 12/36] Add Iceberg support to ExternalSource --- .../rapids/iceberg/IcebergProvider.scala | 32 +++++++++++ .../rapids/iceberg/IcebergProviderImpl.scala | 55 +++++++++++++++++++ .../iceberg/parquet/GpuParquetReader.java | 2 +- .../nvidia/spark/rapids/AvroProvider.scala | 4 ++ .../spark/sql/rapids/AvroProviderImpl.scala | 9 +++ .../spark/sql/rapids/ExternalSource.scala | 44 ++++++++++----- 6 files changed, 130 insertions(+), 16 deletions(-) create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProvider.scala create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProviderImpl.scala diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProvider.scala b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProvider.scala new file mode 100644 index 00000000000..ced6c6ab58b --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProvider.scala @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.iceberg + +import com.nvidia.spark.rapids.ScanRule + +import org.apache.spark.sql.connector.read.Scan + +/** Interfaces to avoid accessing the optional Apache Iceberg jars directly in common code. */ +trait IcebergProvider { + def isSupportedScan(scan: Scan): Boolean + + def getScans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]] +} + +object IcebergProvider { + val cpuScanClassName: String = "org.apache.iceberg.spark.source.SparkBatchQueryScan" +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProviderImpl.scala b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProviderImpl.scala new file mode 100644 index 00000000000..4b503b1e726 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProviderImpl.scala @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.iceberg + +import scala.reflect.ClassTag + +import com.nvidia.spark.rapids.{FileFormatChecks, IcebergFormatType, RapidsConf, ReadFileOp, ScanMeta, ScanRule, ShimLoader} +import com.nvidia.spark.rapids.iceberg.spark.source.GpuSparkBatchQueryScan + +import org.apache.spark.sql.connector.read.Scan + +class IcebergProviderImpl extends IcebergProvider { + override def isSupportedScan(scan: Scan): Boolean = scan.isInstanceOf[GpuSparkBatchQueryScan] + + override def getScans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]] = { + val cpuIcebergScanClass = ShimLoader.loadClass(IcebergProvider.cpuScanClassName) + Seq(new ScanRule[Scan]( + (a, conf, p, r) => new ScanMeta[Scan](a, conf, p, r) { + override def supportsRuntimeFilters: Boolean = true + + override def tagSelfForGpu(): Unit = { + if (!conf.isIcebergEnabled) { + willNotWorkOnGpu("Iceberg input and output has been disabled. To enable set " + + s"${RapidsConf.ENABLE_ICEBERG} to true") + } + + if (!conf.isIcebergReadEnabled) { + willNotWorkOnGpu("Iceberg input has been disabled. To enable set " + + s"${RapidsConf.ENABLE_ICEBERG_READ} to true") + } + + FileFormatChecks.tag(this, a.readSchema(), IcebergFormatType, ReadFileOp) + } + + override def convertToGpu(): Scan = GpuSparkBatchQueryScan.fromCpu(a, conf) + }, + "Iceberg scan", + ClassTag(cpuIcebergScanClass)) + ).map(r => (r.getClassFor.asSubclass(classOf[Scan]), r)).toMap + } +} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java index 239b5887312..a5eb31c7c46 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java @@ -142,7 +142,7 @@ public org.apache.iceberg.io.CloseableIterator iterator() { ParquetPartitionReader partReader = new ParquetPartitionReader(conf, partFile, new Path(input.location()), clippedBlocks, fileReadSchema, caseSensitive, sparkSchema, debugDumpPrefix, maxBatchSizeRows, maxBatchSizeBytes, metrics, - true, true, true); + true, true, true, false); return new GpuIcebergReader(expectedSchema, partReader, deleteFilter, idToConstant); } catch (IOException e) { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AvroProvider.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AvroProvider.scala index a7934050d75..03bf3362062 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AvroProvider.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AvroProvider.scala @@ -49,4 +49,8 @@ trait AvroProvider { fileScan: GpuFileSourceScanExec): PartitionReaderFactory def getScans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]] + + def isSupportedScan(scan: Scan): Boolean + + def copyScanWithInputFileTrue(scan: Scan): Scan } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AvroProviderImpl.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AvroProviderImpl.scala index 12e513df1be..0771b4aa614 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AvroProviderImpl.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AvroProviderImpl.scala @@ -113,4 +113,13 @@ class AvroProviderImpl extends AvroProvider { }) ).map(r => (r.getClassFor.asSubclass(classOf[Scan]), r)).toMap } + + def isSupportedScan(scan: Scan): Boolean = scan.isInstanceOf[GpuAvroScan] + + def copyScanWithInputFileTrue(scan: Scan): Scan = 
scan match { + case avroScan: GpuAvroScan => + avroScan.copy(queryUsesInputFile=true) + case _ => + throw new RuntimeException(s"Unsupported scan type: ${scan.getClass.getSimpleName}") + } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala index 1b92f0ef029..167296456a6 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.rapids import scala.util.Try import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.iceberg.IcebergProvider import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging @@ -49,7 +50,15 @@ object ExternalSource extends Logging { } } - lazy val avroProvider = ShimLoader.newAvroProvider + lazy val avroProvider = ShimLoader.newAvroProvider() + + private lazy val hasIcebergJar = { + Utils.classIsLoadable(IcebergProvider.cpuScanClassName) && + Try(ShimLoader.loadClass(IcebergProvider.cpuScanClassName)).isSuccess + } + + private lazy val icebergProvider = ShimLoader.newInstanceOf[IcebergProvider]( + "com.nvidia.spark.rapids.iceberg.IcebergProviderImpl") /** If the file format is supported as an external source */ def isSupportedFormat(format: FileFormat): Boolean = { @@ -100,19 +109,25 @@ object ExternalSource extends Logging { } def getScans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]] = { + var scans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]] = Map.empty if (hasSparkAvroJar) { - avroProvider.getScans - } else Map.empty + scans = scans ++ avroProvider.getScans + } + if (hasIcebergJar) { + scans = scans ++ icebergProvider.getScans + } + scans } /** If the scan is supported as an external source */ def isSupportedScan(scan: Scan): Boolean = { - if (hasSparkAvroJar) { - scan match { - case _: GpuAvroScan => true - case _ => false - } - } else false + if (hasSparkAvroJar && avroProvider.isSupportedScan(scan)) { + true + } else if (hasIcebergJar && icebergProvider.isSupportedScan(scan)) { + true + } else { + false + } } /** @@ -120,12 +135,11 @@ object ExternalSource extends Logging { * Better to check if the scan is supported first by calling 'isSupportedScan'. 
*/ def copyScanWithInputFileTrue(scan: Scan): Scan = { - if (hasSparkAvroJar) { - scan match { - case avroScan: GpuAvroScan => avroScan.copy(queryUsesInputFile=true) - case _ => - throw new RuntimeException(s"Unsupported scan type: ${scan.getClass.getSimpleName}") - } + if (hasSparkAvroJar && avroProvider.isSupportedScan(scan)) { + avroProvider.copyScanWithInputFileTrue(scan) + } else if (hasIcebergJar && icebergProvider.isSupportedScan(scan)) { + // Iceberg does not yet support a coalescing reader, so nothing to change + scan } else { throw new RuntimeException(s"Unsupported scan type: ${scan.getClass.getSimpleName}") } From 0a91f4fe8c2a2b58a3cc166162614b8b026dabd4 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Mon, 27 Jun 2022 11:03:14 -0500 Subject: [PATCH 13/36] Fix missing bytes read metric from stages reading from Iceberg --- .../nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java index a5eb31c7c46..84da4b72fd1 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java @@ -27,6 +27,7 @@ import com.nvidia.spark.rapids.GpuMetric; import com.nvidia.spark.rapids.GpuParquetUtils; import com.nvidia.spark.rapids.ParquetPartitionReader; +import com.nvidia.spark.rapids.PartitionReaderWithBytesRead; import com.nvidia.spark.rapids.iceberg.data.GpuDeleteFilter; import com.nvidia.spark.rapids.iceberg.spark.SparkSchemaUtil; import com.nvidia.spark.rapids.iceberg.spark.source.GpuIcebergReader; @@ -139,10 +140,11 @@ public org.apache.iceberg.io.CloseableIterator iterator() { fileReadSchema, filteredRowGroups, caseSensitive); // reuse Parquet scan code to read the raw data from the file - ParquetPartitionReader partReader = new ParquetPartitionReader(conf, partFile, + ParquetPartitionReader parquetPartReader = new ParquetPartitionReader(conf, partFile, new Path(input.location()), clippedBlocks, fileReadSchema, caseSensitive, sparkSchema, debugDumpPrefix, maxBatchSizeRows, maxBatchSizeBytes, metrics, true, true, true, false); + PartitionReaderWithBytesRead partReader = new PartitionReaderWithBytesRead(parquetPartReader); return new GpuIcebergReader(expectedSchema, partReader, deleteFilter, idToConstant); } catch (IOException e) { From eaa3c0999ee2f7a7b906af8a7c9e6afe8d5405c5 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Mon, 27 Jun 2022 15:07:47 -0500 Subject: [PATCH 14/36] Fix handling of list columns, add round trip Parquet read test --- .../src/main/python/iceberg_test.py | 29 +++++++++++++++++++ .../iceberg/parquet/GpuParquetReader.java | 7 ++--- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/integration_tests/src/main/python/iceberg_test.py b/integration_tests/src/main/python/iceberg_test.py index 2ab82977049..787fc0e51ed 100644 --- a/integration_tests/src/main/python/iceberg_test.py +++ b/integration_tests/src/main/python/iceberg_test.py @@ -19,6 +19,22 @@ from marks import allow_non_gpu, iceberg, ignore_order from spark_session import is_before_spark_320, is_databricks_runtime, with_cpu_session +iceberg_map_gens = [MapGen(f(nullable=False), f()) for f in [ + BooleanGen, ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen, DateGen, TimestampGen ]] + \ + [simple_string_to_string_map_gen, + 
MapGen(StringGen(pattern='key_[0-9]', nullable=False), ArrayGen(string_gen), max_length=10), + MapGen(RepeatSeqGen(IntegerGen(nullable=False), 10), long_gen, max_length=10), + MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen)] + +iceberg_gens_list = [ + [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, + string_gen, boolean_gen, date_gen, timestamp_gen, + ArrayGen(byte_gen), ArrayGen(long_gen), ArrayGen(string_gen), ArrayGen(date_gen), + ArrayGen(timestamp_gen), ArrayGen(decimal_gen_64bit), ArrayGen(ArrayGen(byte_gen)), + StructGen([['child0', ArrayGen(byte_gen)], ['child1', byte_gen], ['child2', float_gen], ['child3', decimal_gen_64bit]]), + ArrayGen(StructGen([['child0', string_gen], ['child1', double_gen], ['child2', int_gen]])) + ] + iceberg_map_gens + decimal_gens ] + @allow_non_gpu('BatchScanExec') @iceberg def test_iceberg_fallback_not_unsafe_row(spark_tmp_table_factory): @@ -48,3 +64,16 @@ def setup_iceberg_table(spark): lambda spark : spark.sql("SELECT * from {} as X JOIN {} as Y ON X.a = Y.a WHERE Y.a > 0".format(table, table)), conf={"spark.sql.adaptive.enabled": "true", "spark.sql.optimizer.dynamicPartitionPruning.enabled": "true"}) + +@iceberg +@pytest.mark.parametrize('iceberg_gens', iceberg_gens_list, ids=idfn) +def test_iceberg_parquet_read_round_trip(spark_tmp_table_factory, iceberg_gens): + gen_list = [('_c' + str(i), gen) for i, gen in enumerate(iceberg_gens)] + table_name = spark_tmp_table_factory.get() + def setup_iceberg_table(spark): + df = gen_df(spark, gen_list) + df.createOrReplaceTempView("df") + spark.sql("CREATE TABLE {} USING ICEBERG AS SELECT * FROM df".format(table_name)) + with_cpu_session(setup_iceberg_table) + assert_gpu_and_cpu_are_equal_collect( + lambda spark : spark.sql("SELECT * FROM {}".format(table_name))) diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java index 84da4b72fd1..0e30e30c94d 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java @@ -206,13 +206,12 @@ public Type list(Types.ListType expectedList, GroupType list, Type element) { if (hasConstant) { throw new UnsupportedOperationException("constant column in list"); } - Type originalElement = list.getFields().get(0); + GroupType repeated = list.getType(0).asGroupType(); + Type originalElement = repeated.getType(0); if (Objects.equals(element, originalElement)) { return list; - } else if (originalElement.isRepetition(Type.Repetition.REPEATED)) { - return list.withNewFields(element); } - return list.withNewFields(list.getType(0).asGroupType().withNewFields(element)); + return list.withNewFields(repeated.withNewFields(element)); } @Override From 6c7b0e0057dbe8336b4104cf0aab5da0bd76bee8 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Mon, 27 Jun 2022 15:23:04 -0500 Subject: [PATCH 15/36] Fix Iceberg read enable config --- .../src/main/python/iceberg_test.py | 17 ++++++++++++++++- .../com/nvidia/spark/rapids/RapidsConf.scala | 2 +- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/integration_tests/src/main/python/iceberg_test.py b/integration_tests/src/main/python/iceberg_test.py index 787fc0e51ed..80f97f9b3ba 100644 --- a/integration_tests/src/main/python/iceberg_test.py +++ b/integration_tests/src/main/python/iceberg_test.py @@ -14,7 +14,7 @@ import pytest 
-from asserts import assert_gpu_and_cpu_are_equal_collect +from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect from data_gen import * from marks import allow_non_gpu, iceberg, ignore_order from spark_session import is_before_spark_320, is_databricks_runtime, with_cpu_session @@ -77,3 +77,18 @@ def setup_iceberg_table(spark): with_cpu_session(setup_iceberg_table) assert_gpu_and_cpu_are_equal_collect( lambda spark : spark.sql("SELECT * FROM {}".format(table_name))) + +@iceberg +@allow_non_gpu("BatchScanExec") +@pytest.mark.parametrize("disable_conf", ["spark.rapids.sql.format.iceberg.enabled", + "spark.rapids.sql.format.iceberg.read.enabled"]) +def test_iceberg_read_fallback(spark_tmp_table_factory, disable_conf): + table = spark_tmp_table_factory.get() + def setup_iceberg_table(spark): + spark.sql("CREATE TABLE {} (id BIGINT, data STRING) USING ICEBERG".format(table)) + spark.sql("INSERT INTO {} VALUES (1, 'a'), (2, 'b'), (3, 'c')".format(table)) + with_cpu_session(setup_iceberg_table) + assert_gpu_fallback_collect( + lambda spark : spark.sql("SELECT * from {}".format(table)), + "BatchScanExec", + conf = {disable_conf : "false"}) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 422779187d0..d1012a3d28a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -1016,7 +1016,7 @@ object RapidsConf { .booleanConf .createWithDefault(true) - val ENABLE_ICEBERG_READ = conf("spark.rapids.sql.format.iceberg.enabled") + val ENABLE_ICEBERG_READ = conf("spark.rapids.sql.format.iceberg.read.enabled") .doc("When set to false disables Iceberg input acceleration") .booleanConf .createWithDefault(true) From 41bdc8628dcbf9a318561318c47dd00ad014df2c Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Tue, 28 Jun 2022 09:00:54 -0500 Subject: [PATCH 16/36] Add more Iceberg tests --- .../src/main/python/iceberg_test.py | 74 +++++++++++++++---- 1 file changed, 61 insertions(+), 13 deletions(-) diff --git a/integration_tests/src/main/python/iceberg_test.py b/integration_tests/src/main/python/iceberg_test.py index 80f97f9b3ba..2b15994c6b9 100644 --- a/integration_tests/src/main/python/iceberg_test.py +++ b/integration_tests/src/main/python/iceberg_test.py @@ -14,10 +14,10 @@ import pytest -from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect +from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_py4j_exception from data_gen import * from marks import allow_non_gpu, iceberg, ignore_order -from spark_session import is_before_spark_320, is_databricks_runtime, with_cpu_session +from spark_session import is_before_spark_320, is_databricks_runtime, with_cpu_session, with_gpu_session iceberg_map_gens = [MapGen(f(nullable=False), f()) for f in [ BooleanGen, ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen, DateGen, TimestampGen ]] + \ @@ -35,7 +35,7 @@ ArrayGen(StructGen([['child0', string_gen], ['child1', double_gen], ['child2', int_gen]])) ] + iceberg_map_gens + decimal_gens ] -@allow_non_gpu('BatchScanExec') +@allow_non_gpu("BatchScanExec") @iceberg def test_iceberg_fallback_not_unsafe_row(spark_tmp_table_factory): table = spark_tmp_table_factory.get() @@ -54,11 +54,12 @@ def setup_iceberg_table(spark): reason="AQE+DPP not supported until Spark 3.2.0+ and AQE+DPP not supported on Databricks") def 
test_iceberg_aqe_dpp(spark_tmp_table_factory): table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() def setup_iceberg_table(spark): df = two_col_df(spark, int_gen, int_gen) - df.createOrReplaceTempView("df") + df.createOrReplaceTempView(tmpview) spark.sql("CREATE TABLE {} (a INT, b INT) USING ICEBERG PARTITIONED BY (a)".format(table)) - spark.sql("INSERT INTO {} SELECT * FROM df".format(table)) + spark.sql("INSERT INTO {} SELECT * FROM {}".format(table, tmpview)) with_cpu_session(setup_iceberg_table) assert_gpu_and_cpu_are_equal_collect( lambda spark : spark.sql("SELECT * from {} as X JOIN {} as Y ON X.a = Y.a WHERE Y.a > 0".format(table, table)), @@ -66,17 +67,36 @@ def setup_iceberg_table(spark): "spark.sql.optimizer.dynamicPartitionPruning.enabled": "true"}) @iceberg -@pytest.mark.parametrize('iceberg_gens', iceberg_gens_list, ids=idfn) -def test_iceberg_parquet_read_round_trip(spark_tmp_table_factory, iceberg_gens): - gen_list = [('_c' + str(i), gen) for i, gen in enumerate(iceberg_gens)] - table_name = spark_tmp_table_factory.get() +@pytest.mark.parametrize("data_gens", iceberg_gens_list, ids=idfn) +def test_iceberg_parquet_read_round_trip(spark_tmp_table_factory, data_gens): + gen_list = [('_c' + str(i), gen) for i, gen in enumerate(data_gens)] + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() def setup_iceberg_table(spark): df = gen_df(spark, gen_list) - df.createOrReplaceTempView("df") - spark.sql("CREATE TABLE {} USING ICEBERG AS SELECT * FROM df".format(table_name)) + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} USING ICEBERG AS SELECT * FROM {}".format(table, tmpview)) with_cpu_session(setup_iceberg_table) assert_gpu_and_cpu_are_equal_collect( - lambda spark : spark.sql("SELECT * FROM {}".format(table_name))) + lambda spark : spark.sql("SELECT * FROM {}".format(table))) + +@iceberg +@pytest.mark.parametrize("data_gens", [[long_gen]], ids=idfn) +@pytest.mark.parametrize("iceberg_format", ["orc", "avro"], ids=idfn) +def test_iceberg_unsupported_formats(spark_tmp_table_factory, data_gens, iceberg_format): + gen_list = [('_c' + str(i), gen) for i, gen in enumerate(data_gens)] + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() + def setup_iceberg_table(spark): + df = gen_df(spark, gen_list) + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} USING ICEBERG ".format(table) + \ + "TBLPROPERTIES('write.format.default' = '{}') ".format(iceberg_format) + \ + "AS SELECT * FROM {}".format(tmpview)) + with_cpu_session(setup_iceberg_table) + assert_py4j_exception( + lambda : with_gpu_session(lambda spark : spark.sql("SELECT * FROM {}".format(table)).collect()), + "UnsupportedOperationException") @iceberg @allow_non_gpu("BatchScanExec") @@ -89,6 +109,34 @@ def setup_iceberg_table(spark): spark.sql("INSERT INTO {} VALUES (1, 'a'), (2, 'b'), (3, 'c')".format(table)) with_cpu_session(setup_iceberg_table) assert_gpu_fallback_collect( - lambda spark : spark.sql("SELECT * from {}".format(table)), + lambda spark : spark.sql("SELECT * FROM {}".format(table)), "BatchScanExec", conf = {disable_conf : "false"}) + +@iceberg +# Compression codec to test and whether the codec is supported by cudf +# Note that compression codecs brotli and lzo need extra jars +# https://githbub.com/NVIDIA/spark-rapids/issues/143 +@pytest.mark.parametrize("codec_info", [ + ("uncompressed", None), + ("snappy", None), + ("gzip", None), + ("lz4", "Unsupported compression type"), + ("zstd", "Zstandard 
compression is experimental")]) +def test_iceberg_read_parquet_compression_codec(spark_tmp_table_factory, codec_info): + codec, error_msg = codec_info + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() + def setup_iceberg_table(spark): + df = binary_op_df(spark, long_gen) + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} (id BIGINT, data BIGINT) USING ICEBERG ".format(table) + \ + "TBLPROPERTIES('write.parquet.compression-codec' = '{}')".format(codec)) + spark.sql("INSERT INTO {} SELECT * FROM {}".format(table, tmpview)) + with_cpu_session(setup_iceberg_table) + query = "SELECT * FROM {}".format(table) + if error_msg: + assert_py4j_exception( + lambda : with_gpu_session(lambda spark : spark.sql(query).collect()), error_msg) + else: + assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql(query)) From 8f2361d6e05b33eb6bccf5e073d0a01472ae984b Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Tue, 28 Jun 2022 09:31:32 -0500 Subject: [PATCH 17/36] Remove unused code --- .../spark/rapids/iceberg/orc/GpuORC.java | 119 ------------ .../iceberg/parquet/ParquetConversions.java | 36 ---- .../rapids/iceberg/parquet/PruneColumns.java | 172 ------------------ .../rapids/iceberg/spark/SparkTypeToType.java | 162 ----------------- .../spark/source/GpuBatchDataReader.java | 23 --- .../iceberg/spark/source/GpuSparkScan.java | 2 +- 6 files changed, 1 insertion(+), 513 deletions(-) delete mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/orc/GpuORC.java delete mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/PruneColumns.java delete mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkTypeToType.java diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/orc/GpuORC.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/orc/GpuORC.java deleted file mode 100644 index 315e2431feb..00000000000 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/orc/GpuORC.java +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.nvidia.spark.rapids.iceberg.orc; - -import java.util.Map; - -import org.apache.iceberg.Schema; -import org.apache.iceberg.expressions.Expression; -import org.apache.iceberg.hadoop.HadoopInputFile; -import org.apache.iceberg.io.CloseableIterable; -import org.apache.iceberg.io.InputFile; -import org.apache.iceberg.mapping.NameMapping; -import org.apache.iceberg.relocated.com.google.common.base.Preconditions; -import org.apache.hadoop.conf.Configuration; -import org.apache.orc.OrcConf; - -/** GPU version of Apache Iceberg's ORC class */ -public class GpuORC { - private GpuORC() { - } - - public static ReadBuilder read(InputFile file) { - return new ReadBuilder(file); - } - - public static class ReadBuilder { - private final InputFile file; - private final Configuration conf; - private Schema projectSchema = null; - private Schema readerExpectedSchema = null; - private Map idToConstant = null; - private Long start = null; - private Long length = null; - private Expression filter = null; - private boolean caseSensitive = true; - private NameMapping nameMapping = null; - - private ReadBuilder(InputFile file) { - Preconditions.checkNotNull(file, "Input file cannot be null"); - this.file = file; - if (file instanceof HadoopInputFile) { - this.conf = new Configuration(((HadoopInputFile) file).getConf()); - } else { - this.conf = new Configuration(); - } - - // We need to turn positional schema evolution off since we use column name based schema evolution for projection - this.conf.setBoolean(OrcConf.FORCE_POSITIONAL_EVOLUTION.getHiveConfName(), false); - } - - /** - * Restricts the read to the given range: [start, start + length). - * - * @param newStart the start position for this read - * @param newLength the length of the range this read should scan - * @return this builder for method chaining - */ - public ReadBuilder split(long newStart, long newLength) { - this.start = newStart; - this.length = newLength; - return this; - } - - public ReadBuilder project(Schema newSchema) { - this.projectSchema = newSchema; - return this; - } - - public ReadBuilder readerExpectedSchema(Schema newSchema) { - this.readerExpectedSchema = newSchema; - return this; - } - - public ReadBuilder constants(Map constants) { - this.idToConstant = constants; - return this; - } - - public ReadBuilder caseSensitive(boolean newCaseSensitive) { - OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(this.conf, newCaseSensitive); - this.caseSensitive = newCaseSensitive; - return this; - } - - public ReadBuilder config(String property, String value) { - conf.set(property, value); - return this; - } - - public ReadBuilder filter(Expression newFilter) { - this.filter = newFilter; - return this; - } - - public ReadBuilder withNameMapping(NameMapping newNameMapping) { - this.nameMapping = newNameMapping; - return this; - } - - public CloseableIterable build() { - Preconditions.checkNotNull(projectSchema, "Schema is required"); - throw new UnsupportedOperationException(); - } - } -} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetConversions.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetConversions.java index 32126f415f8..9dd5d24ed3a 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetConversions.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetConversions.java @@ -20,10 +20,8 @@ import java.math.BigInteger; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; 
-import java.util.UUID; import java.util.function.Function; -import org.apache.iceberg.expressions.Literal; import org.apache.iceberg.types.Type; import org.apache.parquet.io.api.Binary; import org.apache.parquet.schema.PrimitiveType; @@ -32,40 +30,6 @@ public class ParquetConversions { private ParquetConversions() { } - @SuppressWarnings("unchecked") - static Literal fromParquetPrimitive(Type type, PrimitiveType parquetType, Object value) { - switch (type.typeId()) { - case BOOLEAN: - return (Literal) Literal.of((Boolean) value); - case INTEGER: - case DATE: - return (Literal) Literal.of((Integer) value); - case LONG: - case TIME: - case TIMESTAMP: - return (Literal) Literal.of((Long) value); - case FLOAT: - return (Literal) Literal.of((Float) value); - case DOUBLE: - return (Literal) Literal.of((Double) value); - case STRING: - Function stringConversion = converterFromParquet(parquetType); - return (Literal) Literal.of((CharSequence) stringConversion.apply(value)); - case UUID: - Function uuidConversion = converterFromParquet(parquetType); - return (Literal) Literal.of((UUID) uuidConversion.apply(value)); - case FIXED: - case BINARY: - Function binaryConversion = converterFromParquet(parquetType); - return (Literal) Literal.of((ByteBuffer) binaryConversion.apply(value)); - case DECIMAL: - Function decimalConversion = converterFromParquet(parquetType); - return (Literal) Literal.of((BigDecimal) decimalConversion.apply(value)); - default: - throw new IllegalArgumentException("Unsupported primitive type: " + type); - } - } - static Function converterFromParquet(PrimitiveType parquetType, Type icebergType) { Function fromParquet = converterFromParquet(parquetType); if (icebergType != null) { diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/PruneColumns.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/PruneColumns.java deleted file mode 100644 index 52ecc19ba70..00000000000 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/PruneColumns.java +++ /dev/null @@ -1,172 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.nvidia.spark.rapids.iceberg.parquet; - -import java.util.Collections; -import java.util.List; -import java.util.Set; -import org.apache.iceberg.relocated.com.google.common.base.Objects; -import org.apache.iceberg.relocated.com.google.common.base.Preconditions; -import org.apache.iceberg.relocated.com.google.common.collect.Lists; -import org.apache.parquet.schema.GroupType; -import org.apache.parquet.schema.MessageType; -import org.apache.parquet.schema.OriginalType; -import org.apache.parquet.schema.PrimitiveType; -import org.apache.parquet.schema.Type; -import org.apache.parquet.schema.Types; - -public class PruneColumns extends ParquetTypeVisitor { - private final Set selectedIds; - - PruneColumns(Set selectedIds) { - Preconditions.checkNotNull(selectedIds, "Selected field ids cannot be null"); - this.selectedIds = selectedIds; - } - - @Override - public Type message(MessageType message, List fields) { - Types.MessageTypeBuilder builder = Types.buildMessage(); - - boolean hasChange = false; - int fieldCount = 0; - for (int i = 0; i < fields.size(); i += 1) { - Type originalField = message.getType(i); - Type field = fields.get(i); - Integer fieldId = getId(originalField); - if (fieldId != null && selectedIds.contains(fieldId)) { - if (field != null) { - hasChange = true; - builder.addField(field); - } else { - if (isStruct(originalField)) { - hasChange = true; - builder.addField(originalField.asGroupType().withNewFields(Collections.emptyList())); - } else { - builder.addField(originalField); - } - } - fieldCount += 1; - } else if (field != null) { - hasChange = true; - builder.addField(field); - fieldCount += 1; - } - } - - if (hasChange) { - return builder.named(message.getName()); - } else if (message.getFieldCount() == fieldCount) { - return message; - } - - return builder.named(message.getName()); - } - - @Override - public Type struct(GroupType struct, List fields) { - boolean hasChange = false; - List filteredFields = Lists.newArrayListWithExpectedSize(fields.size()); - for (int i = 0; i < fields.size(); i += 1) { - Type originalField = struct.getType(i); - Type field = fields.get(i); - Integer fieldId = getId(originalField); - if (fieldId != null && selectedIds.contains(fieldId)) { - filteredFields.add(originalField); - } else if (field != null) { - filteredFields.add(originalField); - hasChange = true; - } - } - - if (hasChange) { - return struct.withNewFields(filteredFields); - } else if (struct.getFieldCount() == filteredFields.size()) { - return struct; - } else if (!filteredFields.isEmpty()) { - return struct.withNewFields(filteredFields); - } - - return null; - } - - @Override - public Type list(GroupType list, Type element) { - Type repeated = list.getType(0); - Type originalElement = ParquetSchemaUtil.determineListElementType(list); - Integer elementId = getId(originalElement); - - if (elementId != null && selectedIds.contains(elementId)) { - return list; - } else if (element != null) { - if (!Objects.equal(element, originalElement)) { - if (originalElement.isRepetition(Type.Repetition.REPEATED)) { - return list.withNewFields(element); - } else { - return list.withNewFields(repeated.asGroupType().withNewFields(element)); - } - } - return list; - } - - return null; - } - - @Override - public Type map(GroupType map, Type key, Type value) { - GroupType repeated = map.getType(0).asGroupType(); - Type originalKey = repeated.getType(0); - Type originalValue = repeated.getType(1); - - Integer keyId = getId(originalKey); - Integer valueId = getId(originalValue); 
- - if ((keyId != null && selectedIds.contains(keyId)) || (valueId != null && selectedIds.contains(valueId))) { - return map; - } else if (value != null) { - if (!Objects.equal(value, originalValue)) { - return map.withNewFields(repeated.withNewFields(originalKey, value)); - } - return map; - } - - return null; - } - - @Override - public Type primitive(PrimitiveType primitive) { - return null; - } - - private Integer getId(Type type) { - return type.getId() == null ? null : type.getId().intValue(); - } - - private boolean isStruct(Type field) { - if (field.isPrimitive()) { - return false; - } else { - GroupType groupType = field.asGroupType(); - // Spark 3.1 uses Parquet 1.10 which does not have LogicalTypeAnnotation -// LogicalTypeAnnotation logicalTypeAnnotation = groupType.getLogicalTypeAnnotation(); -// return !logicalTypeAnnotation.equals(LogicalTypeAnnotation.mapType()) && -// !logicalTypeAnnotation.equals(LogicalTypeAnnotation.listType()); - OriginalType originalType = groupType.getOriginalType(); - return !originalType.equals(OriginalType.MAP) && - !originalType.equals(OriginalType.LIST); - } - } -} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkTypeToType.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkTypeToType.java deleted file mode 100644 index bc170644184..00000000000 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkTypeToType.java +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.nvidia.spark.rapids.iceberg.spark; - -import java.util.List; - -import org.apache.iceberg.relocated.com.google.common.collect.Lists; -import org.apache.iceberg.types.Type; -import org.apache.iceberg.types.Types; - -import org.apache.spark.sql.types.ArrayType; -import org.apache.spark.sql.types.BinaryType; -import org.apache.spark.sql.types.BooleanType; -import org.apache.spark.sql.types.ByteType; -import org.apache.spark.sql.types.CharType; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.DateType; -import org.apache.spark.sql.types.DecimalType; -import org.apache.spark.sql.types.DoubleType; -import org.apache.spark.sql.types.FloatType; -import org.apache.spark.sql.types.IntegerType; -import org.apache.spark.sql.types.LongType; -import org.apache.spark.sql.types.MapType; -import org.apache.spark.sql.types.ShortType; -import org.apache.spark.sql.types.StringType; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.types.TimestampType; -import org.apache.spark.sql.types.VarcharType; - -public class SparkTypeToType extends SparkTypeVisitor { - private final StructType root; - private int nextId = 0; - - SparkTypeToType() { - this.root = null; - } - - SparkTypeToType(StructType root) { - this.root = root; - // the root struct's fields use the first ids - this.nextId = root.fields().length; - } - - private int getNextId() { - int next = nextId; - nextId += 1; - return next; - } - - @Override - @SuppressWarnings("ReferenceEquality") - public Type struct(StructType struct, List types) { - StructField[] fields = struct.fields(); - List newFields = Lists.newArrayListWithExpectedSize(fields.length); - boolean isRoot = root == struct; - for (int i = 0; i < fields.length; i += 1) { - StructField field = fields[i]; - Type type = types.get(i); - - int id; - if (isRoot) { - // for new conversions, use ordinals for ids in the root struct - id = i; - } else { - id = getNextId(); - } - - String doc = field.getComment().isDefined() ? 
field.getComment().get() : null; - - if (field.nullable()) { - newFields.add(Types.NestedField.optional(id, field.name(), type, doc)); - } else { - newFields.add(Types.NestedField.required(id, field.name(), type, doc)); - } - } - - return Types.StructType.of(newFields); - } - - @Override - public Type field(StructField field, Type typeResult) { - return typeResult; - } - - @Override - public Type array(ArrayType array, Type elementType) { - if (array.containsNull()) { - return Types.ListType.ofOptional(getNextId(), elementType); - } else { - return Types.ListType.ofRequired(getNextId(), elementType); - } - } - - @Override - public Type map(MapType map, Type keyType, Type valueType) { - if (map.valueContainsNull()) { - return Types.MapType.ofOptional(getNextId(), getNextId(), keyType, valueType); - } else { - return Types.MapType.ofRequired(getNextId(), getNextId(), keyType, valueType); - } - } - - @SuppressWarnings("checkstyle:CyclomaticComplexity") - @Override - public Type atomic(DataType atomic) { - if (atomic instanceof BooleanType) { - return Types.BooleanType.get(); - - } else if ( - atomic instanceof IntegerType || - atomic instanceof ShortType || - atomic instanceof ByteType) { - return Types.IntegerType.get(); - - } else if (atomic instanceof LongType) { - return Types.LongType.get(); - - } else if (atomic instanceof FloatType) { - return Types.FloatType.get(); - - } else if (atomic instanceof DoubleType) { - return Types.DoubleType.get(); - - } else if ( - atomic instanceof StringType || - atomic instanceof CharType || - atomic instanceof VarcharType) { - return Types.StringType.get(); - - } else if (atomic instanceof DateType) { - return Types.DateType.get(); - - } else if (atomic instanceof TimestampType) { - return Types.TimestampType.withZone(); - - } else if (atomic instanceof DecimalType) { - return Types.DecimalType.of( - ((DecimalType) atomic).precision(), - ((DecimalType) atomic).scale()); - } else if (atomic instanceof BinaryType) { - return Types.BinaryType.get(); - } - - throw new UnsupportedOperationException( - "Not a supported type: " + atomic.catalogString()); - } -} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuBatchDataReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuBatchDataReader.java index 3dd8edac776..0ba8696a13f 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuBatchDataReader.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuBatchDataReader.java @@ -17,18 +17,15 @@ package com.nvidia.spark.rapids.iceberg.spark.source; import java.util.Map; -import java.util.Set; import com.nvidia.spark.rapids.GpuMetric; import com.nvidia.spark.rapids.iceberg.data.GpuDeleteFilter; -import com.nvidia.spark.rapids.iceberg.orc.GpuORC; import com.nvidia.spark.rapids.iceberg.parquet.GpuParquet; import org.apache.hadoop.conf.Configuration; import org.apache.iceberg.CombinedScanTask; import org.apache.iceberg.DataFile; import org.apache.iceberg.FileFormat; import org.apache.iceberg.FileScanTask; -import org.apache.iceberg.MetadataColumns; import org.apache.iceberg.Schema; import org.apache.iceberg.Table; import org.apache.iceberg.TableProperties; @@ -37,8 +34,6 @@ import org.apache.iceberg.io.InputFile; import org.apache.iceberg.mapping.NameMappingParser; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; -import org.apache.iceberg.relocated.com.google.common.collect.Sets; -import 
org.apache.iceberg.types.TypeUtil; import org.apache.spark.rdd.InputFileBlockHolder; import org.apache.spark.sql.vectorized.ColumnarBatch; @@ -104,24 +99,6 @@ CloseableIterator open(FileScanTask task) { builder.withNameMapping(NameMappingParser.fromJson(nameMapping)); } - iter = builder.build(); - } else if (task.file().format() == FileFormat.ORC) { - Set constantFieldIds = idToConstant.keySet(); - Set metadataFieldIds = MetadataColumns.metadataFieldIds(); - Sets.SetView constantAndMetadataFieldIds = Sets.union(constantFieldIds, metadataFieldIds); - Schema schemaWithoutConstantAndMetadataFields = TypeUtil.selectNot(expectedSchema, constantAndMetadataFieldIds); - GpuORC.ReadBuilder builder = GpuORC.read(location) - .project(schemaWithoutConstantAndMetadataFields) - .split(task.start(), task.length()) - .readerExpectedSchema(expectedSchema) - .constants(idToConstant) - .filter(task.residual()) - .caseSensitive(caseSensitive); - - if (nameMapping != null) { - builder.withNameMapping(NameMappingParser.fromJson(nameMapping)); - } - iter = builder.build(); } else { throw new UnsupportedOperationException( diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java index 51083876607..df33fa4e165 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkScan.java @@ -178,7 +178,7 @@ public String description() { } static class ReaderFactory implements PartitionReaderFactory { - private scala.collection.immutable.Map metrics; + private final scala.collection.immutable.Map metrics; public ReaderFactory(scala.collection.immutable.Map metrics) { this.metrics = metrics; From 4964085874990f9094d080b0068761f5fc6d9dd0 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Tue, 28 Jun 2022 13:59:30 -0500 Subject: [PATCH 18/36] More Iceberg tests --- .../src/main/python/iceberg_test.py | 96 ++++++++++++++++++- 1 file changed, 93 insertions(+), 3 deletions(-) diff --git a/integration_tests/src/main/python/iceberg_test.py b/integration_tests/src/main/python/iceberg_test.py index 2b15994c6b9..6c1f5d89ade 100644 --- a/integration_tests/src/main/python/iceberg_test.py +++ b/integration_tests/src/main/python/iceberg_test.py @@ -101,7 +101,7 @@ def setup_iceberg_table(spark): @iceberg @allow_non_gpu("BatchScanExec") @pytest.mark.parametrize("disable_conf", ["spark.rapids.sql.format.iceberg.enabled", - "spark.rapids.sql.format.iceberg.read.enabled"]) + "spark.rapids.sql.format.iceberg.read.enabled"], ids=idfn) def test_iceberg_read_fallback(spark_tmp_table_factory, disable_conf): table = spark_tmp_table_factory.get() def setup_iceberg_table(spark): @@ -121,8 +121,10 @@ def setup_iceberg_table(spark): ("uncompressed", None), ("snappy", None), ("gzip", None), - ("lz4", "Unsupported compression type"), - ("zstd", "Zstandard compression is experimental")]) + pytest.param(("lz4", "Unsupported compression type"), + marks=pytest.mark.skipif(is_before_spark_320(), + reason="Hadoop with Spark 3.1.x does not support lz4 by default")), + ("zstd", "Zstandard compression is experimental")], ids=idfn) def test_iceberg_read_parquet_compression_codec(spark_tmp_table_factory, codec_info): codec, error_msg = codec_info table = spark_tmp_table_factory.get() @@ -140,3 +142,91 @@ def setup_iceberg_table(spark): lambda : with_gpu_session(lambda spark : spark.sql(query).collect()), 
error_msg) else: assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql(query)) + +@iceberg +@pytest.mark.parametrize("key_gen", [int_gen, long_gen, string_gen, boolean_gen, date_gen, timestamp_gen, decimal_gen_64bit], ids=idfn) +def test_iceberg_read_partition_key(spark_tmp_table_factory, key_gen): + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() + def setup_iceberg_table(spark): + df = two_col_df(spark, key_gen, long_gen).orderBy("a") + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} USING ICEBERG PARTITIONED BY (a) ".format(table) + \ + "AS SELECT * FROM {}".format(tmpview)) + with_cpu_session(setup_iceberg_table) + assert_gpu_and_cpu_are_equal_collect( + lambda spark : spark.sql("SELECT a FROM {}".format(table))) + +@iceberg +def test_iceberg_input_meta(spark_tmp_table_factory): + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() + def setup_iceberg_table(spark): + df = binary_op_df(spark, long_gen).orderBy("a") + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} USING ICEBERG PARTITIONED BY (a) ".format(table) + \ + "AS SELECT * FROM {}".format(tmpview)) + with_cpu_session(setup_iceberg_table) + assert_gpu_and_cpu_are_equal_collect( + lambda spark : spark.sql( + "SELECT a, input_file_name(), input_file_block_start(), input_file_block_length() " + \ + "FROM {}".format(table))) + +@iceberg +def test_iceberg_disorder_read_schema(spark_tmp_table_factory): + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() + def setup_iceberg_table(spark): + df = three_col_df(spark, long_gen, string_gen, float_gen) + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} USING ICEBERG ".format(table) + \ + "AS SELECT * FROM {}".format(tmpview)) + with_cpu_session(setup_iceberg_table) + assert_gpu_and_cpu_are_equal_collect( + lambda spark : spark.sql("SELECT b,c,a FROM {}".format(table))) + +@iceberg +def test_iceberg_read_appended_table(spark_tmp_table_factory): + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() + def setup_iceberg_table(spark): + df = binary_op_df(spark, long_gen) + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} USING ICEBERG ".format(table) + \ + "AS SELECT * FROM {}".format(tmpview)) + df = binary_op_df(spark, long_gen, seed=1) + df.createOrReplaceTempView(tmpview) + spark.sql("INSERT INTO {} ".format(table) + \ + "SELECT * FROM {}".format(tmpview)) + with_cpu_session(setup_iceberg_table) + assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql("SELECT * FROM {}".format(table))) + +@iceberg +def test_iceberg_read_history(spark_tmp_table_factory): + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() + def setup_iceberg_table(spark): + df = binary_op_df(spark, long_gen) + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} USING ICEBERG ".format(table) + \ + "AS SELECT * FROM {}".format(tmpview)) + df = binary_op_df(spark, long_gen, seed=1) + df.createOrReplaceTempView(tmpview) + spark.sql("INSERT INTO {} ".format(table) + \ + "SELECT * FROM {}".format(tmpview)) + with_cpu_session(setup_iceberg_table) + assert_gpu_and_cpu_are_equal_collect( + # SQL does not have syntax to read history table + lambda spark : spark.read.format("iceberg").load("default.{}.history".format(table))) + +# test appended data +# test column removed and more data appended +# test column added and more data appended +# test column type changed and more data appended +# test 
reordering struct fields and appending more data +# test column names swapped +# test time-travel with snapshot IDs and timestamps +# iceberg metadata queries (metadata table select, etc.) +# test reading data between two snapshot IDs +# https://iceberg.apache.org/docs/latest/spark-queries/ +# test v2 deletes (position and equality) From e06865cf197c2385fa44dbd52622b315f39fe480 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Tue, 28 Jun 2022 17:07:17 -0500 Subject: [PATCH 19/36] Fix Iceberg metadata queries --- integration_tests/src/main/python/iceberg_test.py | 15 +++++++++------ .../rapids/iceberg/IcebergProviderImpl.scala | 4 ++++ .../spark/source/GpuSparkBatchQueryScan.java | 14 ++++++++++++++ 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/integration_tests/src/main/python/iceberg_test.py b/integration_tests/src/main/python/iceberg_test.py index 6c1f5d89ade..ccb1beed050 100644 --- a/integration_tests/src/main/python/iceberg_test.py +++ b/integration_tests/src/main/python/iceberg_test.py @@ -202,7 +202,9 @@ def setup_iceberg_table(spark): assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql("SELECT * FROM {}".format(table))) @iceberg -def test_iceberg_read_history(spark_tmp_table_factory): +# Some metadata files have types that are not supported on the GPU yet (e.g.: BinaryType) +@allow_non_gpu("BatchScanExec", "ProjectExec") +def test_iceberg_read_metadata_fallback(spark_tmp_table_factory): table = spark_tmp_table_factory.get() tmpview = spark_tmp_table_factory.get() def setup_iceberg_table(spark): @@ -215,18 +217,19 @@ def setup_iceberg_table(spark): spark.sql("INSERT INTO {} ".format(table) + \ "SELECT * FROM {}".format(tmpview)) with_cpu_session(setup_iceberg_table) - assert_gpu_and_cpu_are_equal_collect( - # SQL does not have syntax to read history table - lambda spark : spark.read.format("iceberg").load("default.{}.history".format(table))) + for subtable in ["all_data_files", "all_manifests", "files", "history", + "manifests", "partitions", "snapshots"]: + # SQL does not have syntax to read table metadata + assert_gpu_fallback_collect( + lambda spark : spark.read.format("iceberg").load("default.{}.{}".format(table, subtable)), + "BatchScanExec") -# test appended data # test column removed and more data appended # test column added and more data appended # test column type changed and more data appended # test reordering struct fields and appending more data # test column names swapped # test time-travel with snapshot IDs and timestamps -# iceberg metadata queries (metadata table select, etc.) 
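For reference, the metadata-fallback test above goes through the DataFrame reader because Iceberg metadata tables are addressed as an extra path segment rather than through SQL. A minimal sketch of that read pattern, assuming an active SparkSession `spark` with an Iceberg catalog configured and a hypothetical table default.my_table:

# Reads the snapshots metadata table; such scans stay on the CPU and show up
# as BatchScanExec in the plan.
snapshots_df = spark.read.format("iceberg").load("default.my_table.snapshots")
snapshots_df.select("snapshot_id", "committed_at").show()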
# test reading data between two snapshot IDs # https://iceberg.apache.org/docs/latest/spark-queries/ # test v2 deletes (position and equality) diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProviderImpl.scala b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProviderImpl.scala index 4b503b1e726..edde8111a94 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProviderImpl.scala +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProviderImpl.scala @@ -44,6 +44,10 @@ class IcebergProviderImpl extends IcebergProvider { } FileFormatChecks.tag(this, a.readSchema(), IcebergFormatType, ReadFileOp) + + if (GpuSparkBatchQueryScan.isMetadataScan(a)) { + willNotWorkOnGpu("scan is a metadata scan") + } } override def convertToGpu(): Scan = GpuSparkBatchQueryScan.fromCpu(a, conf) diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkBatchQueryScan.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkBatchQueryScan.java index 4974007ef0b..eca3d7ed0fc 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkBatchQueryScan.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkBatchQueryScan.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -80,6 +81,19 @@ public class GpuSparkBatchQueryScan extends GpuSparkScan implements ShimSupports private List files = null; // lazy cache of files private List tasks = null; // lazy cache of tasks + // Check for file scan tasks that are reported as data tasks. + // Null/empty tasks are assumed to be for scans not best performed by the GPU. 
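+  // Metadata tables such as history, snapshots and files are planned as data tasks
+  // backed by manifest metadata rather than Parquet data files, so those scans are
+  // tagged unsupported and remain on the CPU.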
+ @SuppressWarnings("unchecked") + public static boolean isMetadataScan(Scan cpuInstance) throws IllegalAccessException { + List tasks = (List) FieldUtils.readField(cpuInstance, "tasks", true); + if (tasks == null || tasks.isEmpty()) { + return true; + } + Iterator taskIter = tasks.get(0).files().iterator(); + return !taskIter.hasNext() || taskIter.next().isDataTask(); + } + + @SuppressWarnings("unchecked") public static GpuSparkBatchQueryScan fromCpu(Scan cpuInstance, RapidsConf rapidsConf) throws IllegalAccessException { Table table = (Table) FieldUtils.readField(cpuInstance, "table", true); SparkReadConf readConf = SparkReadConf.fromReflect(FieldUtils.readField(cpuInstance, "readConf", true)); From f9590474a6c494ef90f71099a9eb4ddc7c5660e3 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Wed, 29 Jun 2022 14:58:05 -0500 Subject: [PATCH 20/36] Fix reads of Iceberg tables with renamed columns --- .../src/main/python/iceberg_test.py | 120 +++++++++++++++++- .../iceberg/parquet/GpuParquetReader.java | 73 ++++++++++- .../rapids/iceberg/spark/SparkSchemaUtil.java | 31 ----- 3 files changed, 183 insertions(+), 41 deletions(-) diff --git a/integration_tests/src/main/python/iceberg_test.py b/integration_tests/src/main/python/iceberg_test.py index ccb1beed050..983f9b87549 100644 --- a/integration_tests/src/main/python/iceberg_test.py +++ b/integration_tests/src/main/python/iceberg_test.py @@ -37,6 +37,7 @@ @allow_non_gpu("BatchScanExec") @iceberg +@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering def test_iceberg_fallback_not_unsafe_row(spark_tmp_table_factory): table = spark_tmp_table_factory.get() def setup_iceberg_table(spark): @@ -67,6 +68,7 @@ def setup_iceberg_table(spark): "spark.sql.optimizer.dynamicPartitionPruning.enabled": "true"}) @iceberg +@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering @pytest.mark.parametrize("data_gens", iceberg_gens_list, ids=idfn) def test_iceberg_parquet_read_round_trip(spark_tmp_table_factory, data_gens): gen_list = [('_c' + str(i), gen) for i, gen in enumerate(data_gens)] @@ -100,6 +102,7 @@ def setup_iceberg_table(spark): @iceberg @allow_non_gpu("BatchScanExec") +@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering @pytest.mark.parametrize("disable_conf", ["spark.rapids.sql.format.iceberg.enabled", "spark.rapids.sql.format.iceberg.read.enabled"], ids=idfn) def test_iceberg_read_fallback(spark_tmp_table_factory, disable_conf): @@ -114,6 +117,7 @@ def setup_iceberg_table(spark): conf = {disable_conf : "false"}) @iceberg +@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering # Compression codec to test and whether the codec is supported by cudf # Note that compression codecs brotli and lzo need extra jars # https://githbub.com/NVIDIA/spark-rapids/issues/143 @@ -144,6 +148,7 @@ def setup_iceberg_table(spark): assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql(query)) @iceberg +@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering @pytest.mark.parametrize("key_gen", [int_gen, long_gen, string_gen, boolean_gen, date_gen, timestamp_gen, decimal_gen_64bit], ids=idfn) def test_iceberg_read_partition_key(spark_tmp_table_factory, key_gen): table = spark_tmp_table_factory.get() @@ -158,6 +163,7 @@ def setup_iceberg_table(spark): lambda spark : spark.sql("SELECT a FROM {}".format(table))) @iceberg 
+@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering def test_iceberg_input_meta(spark_tmp_table_factory): table = spark_tmp_table_factory.get() tmpview = spark_tmp_table_factory.get() @@ -173,6 +179,7 @@ def setup_iceberg_table(spark): "FROM {}".format(table))) @iceberg +@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering def test_iceberg_disorder_read_schema(spark_tmp_table_factory): table = spark_tmp_table_factory.get() tmpview = spark_tmp_table_factory.get() @@ -186,6 +193,7 @@ def setup_iceberg_table(spark): lambda spark : spark.sql("SELECT b,c,a FROM {}".format(table))) @iceberg +@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering def test_iceberg_read_appended_table(spark_tmp_table_factory): table = spark_tmp_table_factory.get() tmpview = spark_tmp_table_factory.get() @@ -204,6 +212,7 @@ def setup_iceberg_table(spark): @iceberg # Some metadata files have types that are not supported on the GPU yet (e.g.: BinaryType) @allow_non_gpu("BatchScanExec", "ProjectExec") +@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering def test_iceberg_read_metadata_fallback(spark_tmp_table_factory): table = spark_tmp_table_factory.get() tmpview = spark_tmp_table_factory.get() @@ -224,12 +233,111 @@ def setup_iceberg_table(spark): lambda spark : spark.read.format("iceberg").load("default.{}.{}".format(table, subtable)), "BatchScanExec") +@iceberg +@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering +def test_iceberg_read_timetravel(spark_tmp_table_factory): + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() + def setup_snapshots(spark): + df = binary_op_df(spark, long_gen) + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} USING ICEBERG ".format(table) + \ + "AS SELECT * FROM {}".format(tmpview)) + df = binary_op_df(spark, long_gen, seed=1) + df.createOrReplaceTempView(tmpview) + spark.sql("INSERT INTO {} ".format(table) + \ + "SELECT * FROM {}".format(tmpview)) + return spark.sql("SELECT snapshot_id FROM default.{}.snapshots ".format(table) + \ + "ORDER BY committed_at").head()[0] + first_snapshot_id = with_cpu_session(setup_snapshots) + assert_gpu_and_cpu_are_equal_collect( + lambda spark : spark.read.option("snapshot-id", first_snapshot_id) \ + .format("iceberg").load("default.{}".format(table))) + +@iceberg +@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering +def test_iceberg_incremental_read(spark_tmp_table_factory): + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() + def setup_snapshots(spark): + df = binary_op_df(spark, long_gen) + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} USING ICEBERG ".format(table) + \ + "AS SELECT * FROM {}".format(tmpview)) + df = binary_op_df(spark, long_gen, seed=1) + df.createOrReplaceTempView(tmpview) + spark.sql("INSERT INTO {} ".format(table) + \ + "SELECT * FROM {}".format(tmpview)) + df = binary_op_df(spark, long_gen, seed=2) + df.createOrReplaceTempView(tmpview) + spark.sql("INSERT INTO {} ".format(table) + \ + "SELECT * FROM {}".format(tmpview)) + return spark.sql("SELECT snapshot_id FROM default.{}.snapshots ".format(table) + \ + "ORDER BY committed_at").collect() + snapshots = with_cpu_session(setup_snapshots) + start_snapshot, end_snapshot = [ row[0] for row in 
snapshots[:2] ] + assert_gpu_and_cpu_are_equal_collect( + lambda spark : spark.read \ + .option("start-snapshot-id", start_snapshot) \ + .option("end-snapshot-id", end_snapshot) \ + .format("iceberg").load("default.{}".format(table))) + +@iceberg +@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering +def test_iceberg_reorder_columns(spark_tmp_table_factory): + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() + def setup_iceberg_table(spark): + df = binary_op_df(spark, long_gen) + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} USING ICEBERG ".format(table) + \ + "AS SELECT * FROM {}".format(tmpview)) + spark.sql("ALTER TABLE {} ALTER COLUMN b FIRST".format(table)) + df = binary_op_df(spark, long_gen, seed=1) + df.createOrReplaceTempView(tmpview) + spark.sql("INSERT INTO {} ".format(table) + \ + "SELECT * FROM {}".format(tmpview)) + with_cpu_session(setup_iceberg_table) + assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql("SELECT * FROM {}".format(table))) + +@iceberg +@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering +def test_iceberg_rename_column(spark_tmp_table_factory): + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() + def setup_iceberg_table(spark): + df = binary_op_df(spark, long_gen) + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} USING ICEBERG ".format(table) + \ + "AS SELECT * FROM {}".format(tmpview)) + spark.sql("ALTER TABLE {} RENAME COLUMN a TO c".format(table)) + df = binary_op_df(spark, long_gen, seed=1) + df.createOrReplaceTempView(tmpview) + spark.sql("INSERT INTO {} ".format(table) + \ + "SELECT * FROM {}".format(tmpview)) + with_cpu_session(setup_iceberg_table) + assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql("SELECT * FROM {}".format(table))) + +@iceberg +@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering +def test_iceberg_column_names_swapped(spark_tmp_table_factory): + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() + def setup_iceberg_table(spark): + df = binary_op_df(spark, long_gen) + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} USING ICEBERG ".format(table) + \ + "AS SELECT * FROM {}".format(tmpview)) + spark.sql("ALTER TABLE {} RENAME COLUMN a TO c".format(table)) + spark.sql("ALTER TABLE {} RENAME COLUMN b TO a".format(table)) + spark.sql("ALTER TABLE {} RENAME COLUMN c TO b".format(table)) + df = binary_op_df(spark, long_gen, seed=1) + df.createOrReplaceTempView(tmpview) + spark.sql("INSERT INTO {} ".format(table) + \ + "SELECT * FROM {}".format(tmpview)) + with_cpu_session(setup_iceberg_table) + assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql("SELECT * FROM {}".format(table))) + # test column removed and more data appended # test column added and more data appended -# test column type changed and more data appended -# test reordering struct fields and appending more data -# test column names swapped -# test time-travel with snapshot IDs and timestamps -# test reading data between two snapshot IDs -# https://iceberg.apache.org/docs/latest/spark-queries/ # test v2 deletes (position and equality) diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java index 0e30e30c94d..9863e2f104d 100644 
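Aside on the time-travel and incremental-read tests above: both use Iceberg's snapshot read options rather than SQL syntax. A condensed sketch of the two patterns, with a placeholder table name, placeholder snapshot ids, and an active SparkSession `spark` assumed:

# Time travel: read the table as of a specific snapshot.
old_df = (spark.read.option("snapshot-id", 1234567890)
          .format("iceberg").load("default.my_table"))

# Incremental read: only the rows appended between two snapshots.
delta_df = (spark.read.option("start-snapshot-id", 1234567890)
            .option("end-snapshot-id", 1234567999)
            .format("iceberg").load("default.my_table"))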
--- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java @@ -40,6 +40,7 @@ import org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.io.InputFile; import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Maps; @@ -54,6 +55,11 @@ import org.apache.parquet.schema.Types.MessageTypeBuilder; import org.apache.spark.sql.execution.datasources.PartitionedFile; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.vectorized.ColumnarBatch; @@ -134,15 +140,16 @@ public org.apache.iceberg.io.CloseableIterator iterator() { } } - StructType sparkSchema = SparkSchemaUtil.convertWithoutConstants(expectedSchema, idToConstant); MessageType fileReadSchema = buildFileReadSchema(fileSchema); Seq clippedBlocks = GpuParquetUtils.clipBlocksToSchema( fileReadSchema, filteredRowGroups, caseSensitive); + StructType partReaderSparkSchema = (StructType) TypeWithSchemaVisitor.visit( + expectedSchema.asStruct(), fileReadSchema, new SparkSchemaConverter()); // reuse Parquet scan code to read the raw data from the file ParquetPartitionReader parquetPartReader = new ParquetPartitionReader(conf, partFile, new Path(input.location()), clippedBlocks, fileReadSchema, caseSensitive, - sparkSchema, debugDumpPrefix, maxBatchSizeRows, maxBatchSizeBytes, metrics, + partReaderSparkSchema, debugDumpPrefix, maxBatchSizeRows, maxBatchSizeBytes, metrics, true, true, true, false); PartitionReaderWithBytesRead partReader = new PartitionReaderWithBytesRead(parquetPartReader); @@ -174,6 +181,66 @@ private MessageType buildFileReadSchema(MessageType fileSchema) { } } + /** Generate the Spark schema corresponding to a Parquet schema and expected Iceberg schema */ + private static class SparkSchemaConverter extends TypeWithSchemaVisitor { + @Override + public DataType message(Types.StructType iStruct, MessageType message, List fields) { + return struct(iStruct, message, fields); + } + + @Override + public DataType struct(Types.StructType iStruct, GroupType struct, List fieldTypes) { + List parquetFields = struct.getFields(); + List fields = Lists.newArrayListWithExpectedSize(fieldTypes.size()); + + for (int i = 0; i < parquetFields.size(); i += 1) { + Type parquetField = parquetFields.get(i); + + Preconditions.checkArgument( + !parquetField.isRepetition(Type.Repetition.REPEATED), + "Fields cannot have repetition REPEATED: %s", parquetField); + + boolean isNullable = parquetField.isRepetition(Type.Repetition.OPTIONAL); + StructField field = new StructField(parquetField.getName(), fieldTypes.get(i), + isNullable, Metadata.empty()); + fields.add(field); + } + + return new StructType(fields.toArray(new StructField[0])); + } + + @Override + public DataType list(Types.ListType iList, GroupType array, DataType elementType) { + GroupType repeated = array.getType(0).asGroupType(); + Type element = repeated.getType(0); + + Preconditions.checkArgument( + 
!element.isRepetition(Type.Repetition.REPEATED), + "Elements cannot have repetition REPEATED: %s", element); + + boolean isNullable = element.isRepetition(Type.Repetition.OPTIONAL); + return new ArrayType(elementType, isNullable); + } + + @Override + public DataType map(Types.MapType iMap, GroupType map, DataType keyType, DataType valueType) { + GroupType keyValue = map.getType(0).asGroupType(); + Type value = keyValue.getType(1); + + Preconditions.checkArgument( + !value.isRepetition(Type.Repetition.REPEATED), + "Values cannot have repetition REPEATED: %s", value); + + boolean isValueNullable = value.isRepetition(Type.Repetition.OPTIONAL); + return new MapType(keyType, valueType, isValueNullable); + } + + @Override + public DataType primitive(org.apache.iceberg.types.Type.PrimitiveType iPrimitive, PrimitiveType primitiveType) { + return SparkSchemaUtil.convert(iPrimitive); + } + } + private static class ReorderColumns extends TypeWithSchemaVisitor { private final Map idToConstant; @@ -185,7 +252,6 @@ public ReorderColumns(Map idToConstant) { public Type message(Types.StructType expected, MessageType message, List fields) { MessageTypeBuilder builder = org.apache.parquet.schema.Types.buildMessage(); List newFields = filterAndReorder(expected, fields); - // TODO: Avoid re-creating type if nothing changed for (Type type : newFields) { builder.addField(type); } @@ -194,7 +260,6 @@ public Type message(Types.StructType expected, MessageType message, List f @Override public Type struct(Types.StructType expected, GroupType struct, List fields) { - // TODO: Avoid re-creating type if nothing changed List newFields = filterAndReorder(expected, fields); return struct.withNewFields(newFields); } diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSchemaUtil.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSchemaUtil.java index 123de4276a5..5e04c82d6d1 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSchemaUtil.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSchemaUtil.java @@ -24,17 +24,13 @@ import org.apache.iceberg.MetadataColumns; import org.apache.iceberg.Schema; import org.apache.iceberg.exceptions.ValidationException; -import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.math.LongMath; import org.apache.iceberg.types.Type; import org.apache.iceberg.types.TypeUtil; import org.apache.iceberg.types.Types; import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.types.StructType$; /** * Helper methods for working with Spark/Hive metadata. 
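The net effect of the schema converter above is that the Spark read schema is derived from the file's Parquet schema matched against the expected Iceberg schema by field ID, so a renamed or reordered column still binds to the right data. A toy illustration of id-based resolution, plain Python rather than plugin code:

# Field ids are stable across renames; column names are not.
file_columns = {1: [10, 20, 30], 2: [1.5, 2.5, 3.5]}   # id -> data as written in the file
expected = [(1, "c"), (2, "a")]                        # current schema after RENAME COLUMN
# Binding by id returns the right values even though every name has changed.
result = {name: file_columns[fid] for fid, name in expected}
assert result == {"c": [10, 20, 30], "a": [1.5, 2.5, 3.5]}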
@@ -65,33 +61,6 @@ public static DataType convert(Type type) { return TypeUtil.visit(type, new TypeToSparkType()); } - public static StructType convertWithoutConstants(Schema schema, Map idToConstant) { - return (StructType) TypeUtil.visit(schema, new TypeToSparkType() { - @Override - public DataType struct(Types.StructType struct, List fieldResults) { - List fields = struct.fields(); - - List sparkFields = Lists.newArrayListWithExpectedSize(fieldResults.size()); - for (int i = 0; i < fields.size(); i += 1) { - Types.NestedField field = fields.get(i); - // skip fields that are constants - if (idToConstant.containsKey(field.fieldId())) { - continue; - } - DataType type = fieldResults.get(i); - StructField sparkField = StructField.apply( - field.name(), type, field.isOptional(), Metadata.empty()); - if (field.doc() != null) { - sparkField = sparkField.withComment(field.doc()); - } - sparkFields.add(sparkField); - } - - return StructType$.MODULE$.apply(sparkFields); - } - }); - } - /** * Estimate approximate table size based on Spark schema and total records. * From 168042753b45273ced0edc1804a369859a2a23bf Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Wed, 29 Jun 2022 16:39:47 -0500 Subject: [PATCH 21/36] Fix Iceberg reads for missing columns --- .../src/main/python/iceberg_test.py | 18 +++++++++ .../iceberg/parquet/GpuParquetReader.java | 37 ++++++++++++------- 2 files changed, 42 insertions(+), 13 deletions(-) diff --git a/integration_tests/src/main/python/iceberg_test.py b/integration_tests/src/main/python/iceberg_test.py index 983f9b87549..b4162b317a5 100644 --- a/integration_tests/src/main/python/iceberg_test.py +++ b/integration_tests/src/main/python/iceberg_test.py @@ -338,6 +338,24 @@ def setup_iceberg_table(spark): with_cpu_session(setup_iceberg_table) assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql("SELECT * FROM {}".format(table))) +@iceberg +@ignore_order(local=True) +def test_iceberg_add_column(spark_tmp_table_factory): + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() + def setup_iceberg_table(spark): + df = binary_op_df(spark, long_gen) + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} USING ICEBERG ".format(table) + \ + "AS SELECT * FROM {}".format(tmpview)) + spark.sql("ALTER TABLE {} ADD COLUMNS (c DOUBLE)".format(table)) + df = three_col_df(spark, long_gen, long_gen, double_gen) + df.createOrReplaceTempView(tmpview) + spark.sql("INSERT INTO {} ".format(table) + \ + "SELECT * FROM {}".format(tmpview)) + with_cpu_session(setup_iceberg_table) + assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql("SELECT * FROM {}".format(table))) + # test column removed and more data appended # test column added and more data appended # test v2 deletes (position and equality) diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java index 9863e2f104d..41550eab820 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import scala.collection.Seq; @@ -44,6 +45,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import 
org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.types.Types; import org.apache.parquet.ParquetReadOptions; import org.apache.parquet.hadoop.ParquetFileReader; @@ -140,7 +142,11 @@ public org.apache.iceberg.io.CloseableIterator iterator() { } } - MessageType fileReadSchema = buildFileReadSchema(fileSchema); + ReorderColumns reorder = ParquetSchemaUtil.hasIds(fileSchema) ? new ReorderColumns(idToConstant) + : new ReorderColumnsFallback(idToConstant); + MessageType fileReadSchema = (MessageType) TypeWithSchemaVisitor.visit( + expectedSchema.asStruct(), fileSchema, reorder); + Seq clippedBlocks = GpuParquetUtils.clipBlocksToSchema( fileReadSchema, filteredRowGroups, caseSensitive); StructType partReaderSparkSchema = (StructType) TypeWithSchemaVisitor.visit( @@ -153,7 +159,8 @@ public org.apache.iceberg.io.CloseableIterator iterator() { true, true, true, false); PartitionReaderWithBytesRead partReader = new PartitionReaderWithBytesRead(parquetPartReader); - return new GpuIcebergReader(expectedSchema, partReader, deleteFilter, idToConstant); + Map updatedConstants = addNullsForMissingFields(idToConstant, reorder.getMissingFields()); + return new GpuIcebergReader(expectedSchema, partReader, deleteFilter, updatedConstants); } catch (IOException e) { throw new UncheckedIOException("Failed to create/close reader for file: " + input, e); } @@ -167,18 +174,15 @@ private static ParquetFileReader newReader(InputFile file, ParquetReadOptions op } } - // Filter out any unreferenced and metadata columns and reorder the columns - // to match the expected schema. - private MessageType buildFileReadSchema(MessageType fileSchema) { - if (ParquetSchemaUtil.hasIds(fileSchema)) { - return (MessageType) - TypeWithSchemaVisitor.visit(expectedSchema.asStruct(), fileSchema, - new ReorderColumns(idToConstant)); - } else { - return (MessageType) - TypeWithSchemaVisitor.visit(expectedSchema.asStruct(), fileSchema, - new ReorderColumnsFallback(idToConstant)); + private Map addNullsForMissingFields(Map idToConstant, Set missingFields) { + if (missingFields.isEmpty()) { + return idToConstant; + } + Map updated = Maps.newHashMap(idToConstant); + for (Integer field : missingFields) { + updated.put(field, null); } + return updated; } /** Generate the Spark schema corresponding to a Parquet schema and expected Iceberg schema */ @@ -243,11 +247,16 @@ public DataType primitive(org.apache.iceberg.types.Type.PrimitiveType iPrimitive private static class ReorderColumns extends TypeWithSchemaVisitor { private final Map idToConstant; + private final Set missingFields = Sets.newHashSet(); public ReorderColumns(Map idToConstant) { this.idToConstant = idToConstant; } + public Set getMissingFields() { + return missingFields; + } + @Override public Type message(Types.StructType expected, MessageType message, List fields) { MessageTypeBuilder builder = org.apache.parquet.schema.Types.buildMessage(); @@ -326,6 +335,8 @@ private List filterAndReorder(Types.StructType expected, List fields Type newField = typesById.get(id); if (newField != null) { reorderedFields.add(newField); + } else { + missingFields.add(id); } } } From 793f6adb2667d58c013ebf60502aa2fb6c02a7ce Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Thu, 30 Jun 2022 11:30:23 -0500 Subject: [PATCH 22/36] Add Iceberg partition update and delete tests --- .../src/main/python/iceberg_test.py | 88 ++++++++++++++++++- 1 file changed, 84 insertions(+), 4 deletions(-) diff --git 
a/integration_tests/src/main/python/iceberg_test.py b/integration_tests/src/main/python/iceberg_test.py index b4162b317a5..4308ab2e5ad 100644 --- a/integration_tests/src/main/python/iceberg_test.py +++ b/integration_tests/src/main/python/iceberg_test.py @@ -339,7 +339,7 @@ def setup_iceberg_table(spark): assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql("SELECT * FROM {}".format(table))) @iceberg -@ignore_order(local=True) +@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering def test_iceberg_add_column(spark_tmp_table_factory): table = spark_tmp_table_factory.get() tmpview = spark_tmp_table_factory.get() @@ -356,6 +356,86 @@ def setup_iceberg_table(spark): with_cpu_session(setup_iceberg_table) assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql("SELECT * FROM {}".format(table))) -# test column removed and more data appended -# test column added and more data appended -# test v2 deletes (position and equality) +@iceberg +@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering +def test_iceberg_remove_column(spark_tmp_table_factory): + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() + def setup_iceberg_table(spark): + df = binary_op_df(spark, long_gen) + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} USING ICEBERG ".format(table) + \ + "AS SELECT * FROM {}".format(tmpview)) + spark.sql("ALTER TABLE {} DROP COLUMN a".format(table)) + df = unary_op_df(spark, long_gen) + df.createOrReplaceTempView(tmpview) + spark.sql("INSERT INTO {} ".format(table) + \ + "SELECT * FROM {}".format(tmpview)) + with_cpu_session(setup_iceberg_table) + assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql("SELECT * FROM {}".format(table))) + +@iceberg +@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering +def test_iceberg_add_partition_field(spark_tmp_table_factory): + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() + def setup_iceberg_table(spark): + df = binary_op_df(spark, int_gen) + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} USING ICEBERG ".format(table) + \ + "AS SELECT * FROM {}".format(tmpview)) + spark.sql("ALTER TABLE {} ADD PARTITION FIELD b".format(table)) + df = binary_op_df(spark, int_gen) + df.createOrReplaceTempView(tmpview) + spark.sql("INSERT INTO {} ".format(table) + \ + "SELECT * FROM {}".format(tmpview)) + with_cpu_session(setup_iceberg_table) + assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql("SELECT * FROM {}".format(table))) + +@iceberg +@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering +def test_iceberg_drop_partition_field(spark_tmp_table_factory): + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() + def setup_iceberg_table(spark): + df = binary_op_df(spark, int_gen) + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} (a INT, b INT) USING ICEBERG PARTITIONED BY (b)".format(table)) + spark.sql("INSERT INTO {} SELECT * FROM {}".format(table, tmpview)) + spark.sql("ALTER TABLE {} DROP PARTITION FIELD b".format(table)) + df = binary_op_df(spark, int_gen) + df.createOrReplaceTempView(tmpview) + spark.sql("INSERT INTO {} ".format(table) + \ + "SELECT * FROM {}".format(tmpview)) + with_cpu_session(setup_iceberg_table) + assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql("SELECT * FROM 
{}".format(table))) + +@iceberg +@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering +def test_iceberg_v1_delete(spark_tmp_table_factory): + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() + def setup_iceberg_table(spark): + df = binary_op_df(spark, long_gen) + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} USING ICEBERG ".format(table) + \ + "AS SELECT * FROM {}".format(tmpview)) + spark.sql("DELETE FROM {} WHERE a < 0".format(table)) + with_cpu_session(setup_iceberg_table) + assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql("SELECT * FROM {}".format(table))) + +@iceberg +def test_iceberg_v2_delete_unsupported(spark_tmp_table_factory): + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() + def setup_iceberg_table(spark): + df = binary_op_df(spark, long_gen) + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} USING ICEBERG ".format(table) + \ + "TBLPROPERTIES('format-version' = 2, 'write.delete.mode' = 'merge-on-read') " + \ + "AS SELECT * FROM {}".format(tmpview)) + spark.sql("DELETE FROM {} WHERE a < 0".format(table)) + with_cpu_session(setup_iceberg_table) + assert_py4j_exception( + lambda : with_gpu_session(lambda spark : spark.sql("SELECT * FROM {}".format(table)).collect()), + "UnsupportedOperationException: Delete filter is not supported") From 94ca5d6a2097c17079830e5ec44adc9ea93f7626 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Thu, 30 Jun 2022 15:49:43 -0500 Subject: [PATCH 23/36] Fix Iceberg upcasting during reads --- .../src/main/python/iceberg_test.py | 20 ++++++++++++ .../iceberg/parquet/GpuParquetReader.java | 26 ++++++++++++++- .../spark/source/GpuIcebergReader.java | 32 ++++++++++++++++++- 3 files changed, 76 insertions(+), 2 deletions(-) diff --git a/integration_tests/src/main/python/iceberg_test.py b/integration_tests/src/main/python/iceberg_test.py index 4308ab2e5ad..acbd83ff4d3 100644 --- a/integration_tests/src/main/python/iceberg_test.py +++ b/integration_tests/src/main/python/iceberg_test.py @@ -338,6 +338,26 @@ def setup_iceberg_table(spark): with_cpu_session(setup_iceberg_table) assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql("SELECT * FROM {}".format(table))) +@iceberg +@ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering +def test_iceberg_alter_column_type(spark_tmp_table_factory): + table = spark_tmp_table_factory.get() + tmpview = spark_tmp_table_factory.get() + def setup_iceberg_table(spark): + df = three_col_df(spark, int_gen, float_gen, DecimalGen(precision=7, scale=3)) + df.createOrReplaceTempView(tmpview) + spark.sql("CREATE TABLE {} USING ICEBERG ".format(table) + \ + "AS SELECT * FROM {}".format(tmpview)) + spark.sql("ALTER TABLE {} ALTER COLUMN a TYPE BIGINT".format(table)) + spark.sql("ALTER TABLE {} ALTER COLUMN b TYPE DOUBLE".format(table)) + spark.sql("ALTER TABLE {} ALTER COLUMN c TYPE DECIMAL(17, 3)".format(table)) + df = three_col_df(spark, long_gen, double_gen, DecimalGen(precision=17, scale=3)) + df.createOrReplaceTempView(tmpview) + spark.sql("INSERT INTO {} ".format(table) + \ + "SELECT * FROM {}".format(tmpview)) + with_cpu_session(setup_iceberg_table) + assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql("SELECT * FROM {}".format(table))) + @iceberg @ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering def test_iceberg_add_column(spark_tmp_table_factory): diff 
--git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java index 41550eab820..02db5c4c9c9 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java @@ -50,6 +50,7 @@ import org.apache.parquet.ParquetReadOptions; import org.apache.parquet.hadoop.ParquetFileReader; import org.apache.parquet.hadoop.metadata.BlockMetaData; +import org.apache.parquet.schema.DecimalMetadata; import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.PrimitiveType; @@ -59,6 +60,12 @@ import org.apache.spark.sql.execution.datasources.PartitionedFile; import org.apache.spark.sql.types.ArrayType; import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DecimalType$; +import org.apache.spark.sql.types.DoubleType$; +import org.apache.spark.sql.types.FloatType$; +import org.apache.spark.sql.types.IntegerType$; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.LongType$; import org.apache.spark.sql.types.MapType; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; @@ -241,7 +248,24 @@ public DataType map(Types.MapType iMap, GroupType map, DataType keyType, DataTyp @Override public DataType primitive(org.apache.iceberg.types.Type.PrimitiveType iPrimitive, PrimitiveType primitiveType) { - return SparkSchemaUtil.convert(iPrimitive); + // If up-casts are needed, load as the pre-cast Spark type, and this will be up-cast in GpuIcebergReader. + switch (iPrimitive.typeId()) { + case LONG: + if (primitiveType.getPrimitiveTypeName().equals(PrimitiveType.PrimitiveTypeName.INT32)) { + return IntegerType$.MODULE$; + } + return LongType$.MODULE$; + case DOUBLE: + if (primitiveType.getPrimitiveTypeName().equals(PrimitiveType.PrimitiveTypeName.FLOAT)) { + return FloatType$.MODULE$; + } + return DoubleType$.MODULE$; + case DECIMAL: + DecimalMetadata metadata = primitiveType.getDecimalMetadata(); + return DecimalType$.MODULE$.apply(metadata.getPrecision(), metadata.getScale()); + default: + return SparkSchemaUtil.convert(iPrimitive); + } } } diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuIcebergReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuIcebergReader.java index 8cc5b7aaafc..5a2fca10461 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuIcebergReader.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuIcebergReader.java @@ -23,12 +23,14 @@ import java.util.NoSuchElementException; import ai.rapids.cudf.Scalar; +import com.nvidia.spark.rapids.GpuCast; import com.nvidia.spark.rapids.GpuColumnVector; import com.nvidia.spark.rapids.GpuScalar; import com.nvidia.spark.rapids.iceberg.data.GpuDeleteFilter; import com.nvidia.spark.rapids.iceberg.spark.SparkSchemaUtil; import org.apache.iceberg.Schema; import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.types.Type; import org.apache.iceberg.types.TypeUtil; import org.apache.iceberg.types.Types; @@ -85,7 +87,8 @@ public ColumnarBatch next() { if (deleteFilter != null) { throw new UnsupportedOperationException("Delete filter is not supported"); } - 
return addConstantColumns(batch); + ColumnarBatch updatedBatch = addConstantColumns(batch); + return addUpcastsIfNeeded(updatedBatch); } } @@ -128,6 +131,33 @@ private ColumnarBatch addConstantColumns(ColumnarBatch batch) { return result; } + private ColumnarBatch addUpcastsIfNeeded(ColumnarBatch batch) { + GpuColumnVector[] columns = null; + try { + List expectedColumnTypes = expectedSchema.columns(); + Preconditions.checkState(expectedColumnTypes.size() == batch.numCols(), + "Expected to load " + expectedColumnTypes.size() + " columns, found " + batch.numCols()); + columns = GpuColumnVector.extractColumns(batch); + for (int i = 0; i < batch.numCols(); i++) { + DataType expectedSparkType = SparkSchemaUtil.convert(expectedColumnTypes.get(i).type()); + GpuColumnVector oldColumn = columns[i]; + columns[i] = GpuColumnVector.from( + GpuCast.doCast(oldColumn.getBase(), oldColumn.dataType(), expectedSparkType, false, false, false), + expectedSparkType); + } + ColumnarBatch newBatch = new ColumnarBatch(columns, batch.numRows()); + columns = null; + return newBatch; + } finally { + batch.close(); + if (columns != null) { + for (ColumnVector c : columns) { + c.close(); + } + } + } + } + private static class ConstantDetector extends TypeUtil.SchemaVisitor { private final Map idToConstant; From 4a8e296388be977a352a74b0575ef2fdded8f814 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Thu, 30 Jun 2022 16:15:10 -0500 Subject: [PATCH 24/36] Suppress some warnings --- .../iceberg/parquet/ParquetDictionaryRowGroupFilter.java | 5 +++-- .../rapids/iceberg/parquet/ParquetMetricsRowGroupFilter.java | 5 +++-- .../com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java | 1 + 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetDictionaryRowGroupFilter.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetDictionaryRowGroupFilter.java index 3f0940fc0da..8daefd2732a 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetDictionaryRowGroupFilter.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetDictionaryRowGroupFilter.java @@ -17,13 +17,13 @@ package com.nvidia.spark.rapids.iceberg.parquet; import java.io.IOException; +import java.io.UncheckedIOException; import java.util.Comparator; import java.util.Map; import java.util.Set; import java.util.function.Function; import org.apache.iceberg.Schema; -import org.apache.iceberg.exceptions.RuntimeIOException; import org.apache.iceberg.expressions.Binder; import org.apache.iceberg.expressions.BoundReference; import org.apache.iceberg.expressions.Expression; @@ -104,6 +104,7 @@ private boolean eval(MessageType fileSchema, BlockMetaData rowGroup, } for (ColumnChunkMetaData meta : rowGroup.getColumns()) { + @SuppressWarnings("deprecation") PrimitiveType colType = fileSchema.getType(meta.getPath().toArray()).asPrimitiveType(); if (colType.getId() != null) { int id = colType.getId().intValue(); @@ -416,7 +417,7 @@ private Set dict(int id, Comparator comparator) { try { dict = page.getEncoding().initDictionary(col, page); } catch (IOException e) { - throw new RuntimeIOException("Failed to create reader for dictionary page"); + throw new UncheckedIOException(new IOException("Failed to create reader for dictionary page")); } Set dictSet = Sets.newTreeSet(comparator); diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetMetricsRowGroupFilter.java 
b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetMetricsRowGroupFilter.java index bd60184f0c2..4590385ac7f 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetMetricsRowGroupFilter.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetMetricsRowGroupFilter.java @@ -88,6 +88,7 @@ private boolean eval(MessageType fileSchema, BlockMetaData rowGroup) { this.valueCounts = Maps.newHashMap(); this.conversions = Maps.newHashMap(); for (ColumnChunkMetaData col : rowGroup.getColumns()) { + @SuppressWarnings("deprecation") PrimitiveType colType = fileSchema.getType(col.getPath().toArray()).asPrimitiveType(); if (colType.getId() != null) { int id = colType.getId().intValue(); @@ -554,12 +555,12 @@ private T max(Statistics statistics, int id) { * @param valueCount Number of values in the row group * @return true if nonNull values exist and no other stats can be used */ - static boolean hasNonNullButNoMinMax(Statistics statistics, long valueCount) { + static boolean hasNonNullButNoMinMax(Statistics statistics, long valueCount) { return statistics.getNumNulls() < valueCount && (statistics.getMaxBytes() == null || statistics.getMinBytes() == null); } - private static boolean mayContainNull(Statistics statistics) { + private static boolean mayContainNull(Statistics statistics) { return !statistics.isNumNullsSet() || statistics.getNumNulls() > 0; } } diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java index abe45a91076..bcfabeb276a 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java @@ -52,6 +52,7 @@ public class SparkReadConf { private final Map readOptions; private final SparkConfParser confParser; + @SuppressWarnings("unchecked") public static SparkReadConf fromReflect(Object obj) throws IllegalAccessException { SparkSession spark = SparkSession.active(); Table table = (Table) FieldUtils.readField(obj, "table", true); From 83e0bb82919449280e57097ae6e553247de7e2cf Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Fri, 1 Jul 2022 10:14:52 -0500 Subject: [PATCH 25/36] Update to Iceberg 0.13.2 --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index ab26e37a790..4f49e8019d3 100644 --- a/pom.xml +++ b/pom.xml @@ -1033,7 +1033,7 @@ 1.7.30 1.11.0 3.3.1 - 0.13.1 + 0.13.2 org/scala-lang/scala-library/${scala.version}/scala-library-${scala.version}.jar ${spark.version.classifier} 3.1.0 From d65c2a46fe97b89fa77c82bdcbda933a70f41259 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Fri, 1 Jul 2022 11:11:09 -0500 Subject: [PATCH 26/36] Skip tests not supported on Spark 3.1.x --- integration_tests/src/main/python/iceberg_test.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/integration_tests/src/main/python/iceberg_test.py b/integration_tests/src/main/python/iceberg_test.py index acbd83ff4d3..3264e4d98b8 100644 --- a/integration_tests/src/main/python/iceberg_test.py +++ b/integration_tests/src/main/python/iceberg_test.py @@ -235,6 +235,7 @@ def setup_iceberg_table(spark): @iceberg @ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering +@pytest.mark.skipif(is_before_spark_320(), reason="Spark 3.1.x has a catalog bug precluding scope prefix in table names") def 
test_iceberg_read_timetravel(spark_tmp_table_factory): table = spark_tmp_table_factory.get() tmpview = spark_tmp_table_factory.get() @@ -256,6 +257,7 @@ def setup_snapshots(spark): @iceberg @ignore_order(local=True) # Iceberg plans with a thread pool and is not deterministic in file ordering +@pytest.mark.skipif(is_before_spark_320(), reason="Spark 3.1.x has a catalog bug precluding scope prefix in table names") def test_iceberg_incremental_read(spark_tmp_table_factory): table = spark_tmp_table_factory.get() tmpview = spark_tmp_table_factory.get() @@ -408,7 +410,7 @@ def setup_iceberg_table(spark): df = binary_op_df(spark, int_gen) df.createOrReplaceTempView(tmpview) spark.sql("INSERT INTO {} ".format(table) + \ - "SELECT * FROM {}".format(tmpview)) + "SELECT * FROM {} ORDER BY b".format(tmpview)) with_cpu_session(setup_iceberg_table) assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql("SELECT * FROM {}".format(table))) @@ -421,7 +423,7 @@ def setup_iceberg_table(spark): df = binary_op_df(spark, int_gen) df.createOrReplaceTempView(tmpview) spark.sql("CREATE TABLE {} (a INT, b INT) USING ICEBERG PARTITIONED BY (b)".format(table)) - spark.sql("INSERT INTO {} SELECT * FROM {}".format(table, tmpview)) + spark.sql("INSERT INTO {} SELECT * FROM {} ORDER BY b".format(table, tmpview)) spark.sql("ALTER TABLE {} DROP PARTITION FIELD b".format(table)) df = binary_op_df(spark, int_gen) df.createOrReplaceTempView(tmpview) @@ -445,6 +447,7 @@ def setup_iceberg_table(spark): assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.sql("SELECT * FROM {}".format(table))) @iceberg +@pytest.mark.skipif(is_before_spark_320(), reason="merge-on-read not supported on Spark 3.1.x") def test_iceberg_v2_delete_unsupported(spark_tmp_table_factory): table = spark_tmp_table_factory.get() tmpview = spark_tmp_table_factory.get() From 97b281d88d466e7d8443e8790027970fa4957eb5 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Fri, 1 Jul 2022 11:19:27 -0500 Subject: [PATCH 27/36] Add docs detailing Iceberg support --- .../iceberg-support.md | 62 ++++++++++++++ docs/configs.md | 2 + docs/supported_ops.md | 46 +++++++++++ jenkins/databricks/test.sh | 2 +- jenkins/spark-tests.sh | 2 +- .../rapids/shims/Spark320PlusShims.scala | 2 +- .../iceberg/parquet/ApplyNameMapping.java | 1 + .../iceberg/parquet/ParquetConversions.java | 1 + .../ParquetDictionaryRowGroupFilter.java | 1 + .../rapids/iceberg/parquet/ParquetIO.java | 1 + .../parquet/ParquetMetricsRowGroupFilter.java | 1 + .../iceberg/parquet/ParquetSchemaUtil.java | 1 + .../iceberg/parquet/ParquetTypeVisitor.java | 1 + .../rapids/iceberg/parquet/ParquetUtil.java | 1 + .../parquet/TypeWithSchemaVisitor.java | 1 + .../rapids/iceberg/spark/Spark3Util.java | 1 + .../rapids/iceberg/spark/SparkConfParser.java | 1 + .../rapids/iceberg/spark/SparkFilters.java | 1 + .../rapids/iceberg/spark/SparkReadConf.java | 1 + .../iceberg/spark/SparkReadOptions.java | 3 +- .../iceberg/spark/SparkSQLProperties.java | 1 + .../rapids/iceberg/spark/SparkSchemaUtil.java | 1 + .../iceberg/spark/SparkTypeVisitor.java | 82 ------------------- .../spark/rapids/iceberg/spark/SparkUtil.java | 1 + .../rapids/iceberg/spark/TypeToSparkType.java | 1 + .../iceberg/spark/source/BaseDataReader.java | 1 + .../spark/source/GpuBatchDataReader.java | 2 +- .../spark/source/GpuIcebergReader.java | 4 + .../iceberg/spark/source/SparkBatch.java | 1 + .../rapids/iceberg/spark/source/Stats.java | 1 + .../com/nvidia/spark/rapids/TypeChecks.scala | 3 + .../main/resources/supportedDataSource.csv | 1 + 32 files changed, 
144 insertions(+), 87 deletions(-) create mode 100644 docs/additional-functionality/iceberg-support.md delete mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkTypeVisitor.java diff --git a/docs/additional-functionality/iceberg-support.md b/docs/additional-functionality/iceberg-support.md new file mode 100644 index 00000000000..9c4910e624c --- /dev/null +++ b/docs/additional-functionality/iceberg-support.md @@ -0,0 +1,62 @@ +--- +layout: page +title: Apache Iceberg Support +parent: Additional Functionality +nav_order: 7 +--- + +# Apache Iceberg Support + +The RAPIDS Accelerator for Apache Spark provides limited support for Apache Iceberg tables. +This document details the Apache Iceberg features that are supported. + +## Apache Iceberg Versions + +The RAPIDS Accelerator supports Apache Iceberg 0.13.x. Earlier versions of Apache Iceberg are +not supported. + +## Reading Tables + +### Metadata Queries + +Reads of Apache Iceberg metadata, i.e.: the `history`, `snapshots`, and other metadata tables +associated with a table, will not be GPU-accelerated. The CPU will continue to process these +metadata-level queries. + +### Row-level Delete and Update Support + +Apache Iceberg supports row-level deletions and updates. Tables that are using a configuration of +`write.delete.mode=merge-on-read` are not supported. + +### Schema Evolution + +Columns that are added and removed at the top level of the table schema are supported. Columns +that are added or removed within struct columns are not supported. + +### Data Formats + +Apache Iceberg can store data in various formats. Each section below details the levels of support +for each of the underlying data formats. + +#### Parquet + +Data stored in Parquet is supported with the same limitations for loading data from raw Parquet +files. See the [Input/Output](../supported_ops.md#inputoutput) documentation for details. The +following compression codecs applied to the Parquet data are supported: +- gzip (Apache Iceberg default) +- snappy +- uncompressed +- zstd + +#### ORC + +The RAPIDS Accelerator does not support Apache Iceberg tables using the ORC data format. + +#### Avro + +The RAPIDS Accelerator does not support Apache Iceberg tables using the ORC data format. + +## Writing Tables + +The RAPIDS Accelerator for Apache Spark does not accelerate Apache Iceberg writes. Writes +to Iceberg tables will be processed by the CPU. diff --git a/docs/configs.md b/docs/configs.md index 66164fc8d7e..cd2f042c15a 100644 --- a/docs/configs.md +++ b/docs/configs.md @@ -81,6 +81,8 @@ Name | Description | Default Value spark.rapids.sql.format.avro.reader.type|Sets the Avro reader type. We support different types that are optimized for different environments. The original Spark style reader can be selected by setting this to PERFILE which individually reads and copies files to the GPU. Loading many small files individually has high overhead, and using either COALESCING or MULTITHREADED is recommended instead. The COALESCING reader is good when using a local file system where the executors are on the same nodes or close to the nodes the data is being read on. This reader coalesces all the files assigned to a task into a single host buffer before sending it down to the GPU. It copies blocks from a single file into a host buffer in separate threads in parallel, see spark.rapids.sql.multiThreadedRead.numThreads. 
MULTITHREADED is good for cloud environments where you are reading from a blobstore that is totally separate and likely has a higher I/O read cost. Many times the cloud environments also get better throughput when you have multiple readers in parallel. This reader uses multiple threads to read each file in parallel and each file is sent to the GPU separately. This allows the CPU to keep reading while GPU is also doing work. See spark.rapids.sql.multiThreadedRead.numThreads and spark.rapids.sql.format.avro.multiThreadedRead.maxNumFilesParallel to control the number of threads and amount of memory used. By default this is set to AUTO so we select the reader we think is best. This will either be the COALESCING or the MULTITHREADED based on whether we think the file is in the cloud. See spark.rapids.cloudSchemes.|AUTO spark.rapids.sql.format.csv.enabled|When set to false disables all csv input and output acceleration. (only input is currently supported anyways)|true spark.rapids.sql.format.csv.read.enabled|When set to false disables csv input acceleration|true +spark.rapids.sql.format.iceberg.enabled|When set to false disables all Iceberg acceleration|true +spark.rapids.sql.format.iceberg.read.enabled|When set to false disables Iceberg input acceleration|true spark.rapids.sql.format.json.enabled|When set to true enables all json input and output acceleration. (only input is currently supported anyways)|false spark.rapids.sql.format.json.read.enabled|When set to true enables json input acceleration|false spark.rapids.sql.format.orc.enabled|When set to false disables all orc input and output acceleration|true diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 66946d040ab..c50da5407d9 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -18218,6 +18218,49 @@ dates or timestamps, or for a lack of type coercion support.

[supported_ops.md table residue: this hunk adds 43 lines of HTML table markup for a new "Iceberg / Read" row in the Input/Output support matrix, marking the primitive types S, TIMESTAMP PS (UTC is only supported TZ), ARRAY/MAP/STRUCT PS (UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, UDT), and BINARY/UDT NS; the surrounding Write and JSON Read rows appear as diff context]
@@ -18347,3 +18390,6 @@ dates or timestamps, or for a lack of type coercion support.
+### Apache Iceberg Support +Support for Apache Iceberg has additional limitations. See the +[Apache Iceberg Support](additional-functionality/iceberg-support.md) document. diff --git a/jenkins/databricks/test.sh b/jenkins/databricks/test.sh index 0ff4381005c..5b8c4a7f8df 100755 --- a/jenkins/databricks/test.sh +++ b/jenkins/databricks/test.sh @@ -76,7 +76,7 @@ TEST_MODE=${TEST_MODE:-'IT_ONLY'} TEST_TYPE="nightly" PCBS_CONF="com.nvidia.spark.ParquetCachedBatchSerializer" -ICEBERG_VERSION=0.13.1 +ICEBERG_VERSION=0.13.2 ICEBERG_SPARK_VER=$(echo $BASE_SPARK_VER | cut -d. -f1,2) ICEBERG_CONFS="--packages org.apache.iceberg:iceberg-spark-runtime-${ICEBERG_SPARK_VER}_2.12:${ICEBERG_VERSION} \ --conf spark.sql.extensions=org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions \ diff --git a/jenkins/spark-tests.sh b/jenkins/spark-tests.sh index 64877f8f949..20a21d4214e 100755 --- a/jenkins/spark-tests.sh +++ b/jenkins/spark-tests.sh @@ -166,7 +166,7 @@ export TARGET_DIR="$SCRIPT_PATH/target" mkdir -p $TARGET_DIR run_iceberg_tests() { - ICEBERG_VERSION="0.13.1" + ICEBERG_VERSION="0.13.2" # get the major/minor version of Spark ICEBERG_SPARK_VER=$(echo $SPARK_VER | cut -d. -f1,2) diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala index 8ad3b64d6b3..60f92353ce1 100644 --- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala +++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala @@ -460,7 +460,7 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging { } override val childExprs: Seq[BaseExprMeta[_]] = { - // We want to leave the runtime filters as CPU expressions, so leave them out of the expressions + // We want to leave the runtime filters as CPU expressions p.output.map(GpuOverrides.wrapExpr(_, conf, Some(this))) } diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ApplyNameMapping.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ApplyNameMapping.java index d71118e6708..55d22bfbd96 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ApplyNameMapping.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ApplyNameMapping.java @@ -31,6 +31,7 @@ import org.apache.parquet.schema.Type; import org.apache.parquet.schema.Types; +/** Derived from Apache Iceberg's ApplyNameMapping Parquet support class. */ public class ApplyNameMapping extends ParquetTypeVisitor { private static final String LIST_ELEMENT_NAME = "element"; private static final String MAP_KEY_NAME = "key"; diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetConversions.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetConversions.java index 9dd5d24ed3a..0d81521cc47 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetConversions.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetConversions.java @@ -26,6 +26,7 @@ import org.apache.parquet.io.api.Binary; import org.apache.parquet.schema.PrimitiveType; +/** Derived from Apache Iceberg's ParquetConversions class. 
*/ public class ParquetConversions { private ParquetConversions() { } diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetDictionaryRowGroupFilter.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetDictionaryRowGroupFilter.java index 8daefd2732a..cba801addde 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetDictionaryRowGroupFilter.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetDictionaryRowGroupFilter.java @@ -47,6 +47,7 @@ import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.PrimitiveType; +/** Derived from Apache Iceberg's ParquetDictionaryRowGroupFilter class. */ public class ParquetDictionaryRowGroupFilter { private final Schema schema; private final Expression expr; diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetIO.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetIO.java index 4c6a91eb3dc..595c0794703 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetIO.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetIO.java @@ -28,6 +28,7 @@ import org.apache.parquet.io.InputFile; import org.apache.parquet.io.SeekableInputStream; +/** Derived from Apache Iceberg's ParquetIO class. */ public class ParquetIO { private ParquetIO() { } diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetMetricsRowGroupFilter.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetMetricsRowGroupFilter.java index 4590385ac7f..39698eef738 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetMetricsRowGroupFilter.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetMetricsRowGroupFilter.java @@ -44,6 +44,7 @@ import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.PrimitiveType; +/** Derived from Apache Iceberg's ParquetMetricsRowGroupFilter class. */ public class ParquetMetricsRowGroupFilter { private static final int IN_PREDICATE_LIMIT = 200; diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetSchemaUtil.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetSchemaUtil.java index e89301a67e4..71f0385e100 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetSchemaUtil.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetSchemaUtil.java @@ -25,6 +25,7 @@ import org.apache.parquet.schema.Type; import org.apache.parquet.schema.Types.MessageTypeBuilder; +/** Derived from Apache Iceberg's ParquetSchemaUtil class. */ public class ParquetSchemaUtil { private ParquetSchemaUtil() { } diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetTypeVisitor.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetTypeVisitor.java index 614688fb118..b668d412ae5 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetTypeVisitor.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetTypeVisitor.java @@ -27,6 +27,7 @@ import org.apache.parquet.schema.PrimitiveType; import org.apache.parquet.schema.Type; +/** Derived from Apache Iceberg's ParquetTypeVisitor class. 
*/ public class ParquetTypeVisitor { private final Deque fieldNames = Lists.newLinkedList(); diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetUtil.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetUtil.java index d060748f965..786a6079c97 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetUtil.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/ParquetUtil.java @@ -23,6 +23,7 @@ import org.apache.parquet.column.EncodingStats; import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; +/** Derived from Apache Iceberg's ParquetUtil class. */ public class ParquetUtil { // not meant to be instantiated private ParquetUtil() { diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/TypeWithSchemaVisitor.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/TypeWithSchemaVisitor.java index 090b5e712f4..b2962a94f61 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/TypeWithSchemaVisitor.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/TypeWithSchemaVisitor.java @@ -29,6 +29,7 @@ /** * Visitor for traversing a Parquet type with a companion Iceberg type. + * Derived from Apache Iceberg's TypeWithSchemaVisitor class. * * @param the Java class returned by the visitor */ diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/Spark3Util.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/Spark3Util.java index 472b2ccfc7b..6a3e2ce48cd 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/Spark3Util.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/Spark3Util.java @@ -30,6 +30,7 @@ import org.apache.spark.sql.connector.expressions.Expressions; import org.apache.spark.sql.connector.expressions.NamedReference; +/** Derived from Apache Iceberg's Spark3Util class. */ public class Spark3Util { private Spark3Util() { diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkConfParser.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkConfParser.java index 948c56dace4..a640c2810a3 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkConfParser.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkConfParser.java @@ -28,6 +28,7 @@ import org.apache.spark.sql.RuntimeConfig; import org.apache.spark.sql.SparkSession; +/** Derived from Apache Iceberg's SparkConfParser class. */ public class SparkConfParser { private final Map properties; diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkFilters.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkFilters.java index 46ebbeec163..a141e6e6e59 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkFilters.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkFilters.java @@ -69,6 +69,7 @@ import static org.apache.iceberg.expressions.Expressions.or; import static org.apache.iceberg.expressions.Expressions.startsWith; +/** Derived from Apache Iceberg's SparkFilters class. 
*/ public class SparkFilters { private static final Pattern BACKTICKS_PATTERN = Pattern.compile("([`])(.|$)"); diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java index bcfabeb276a..f80a7be7c7f 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java @@ -31,6 +31,7 @@ /** * A class for common Iceberg configs for Spark reads. + * Derived from Apache Iceberg's SparkReadConf class. *

* If a config is set at multiple levels, the following order of precedence is used (top to bottom):

    diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadOptions.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadOptions.java index 92a72e56c76..5034f57894e 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadOptions.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadOptions.java @@ -17,7 +17,8 @@ package com.nvidia.spark.rapids.iceberg.spark; /** - * Spark DF read options + * Spark DF read options. + * Derived from Apache Iceberg's SparkReadOptions class. */ public class SparkReadOptions { diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSQLProperties.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSQLProperties.java index c24fe950ded..866ff7ff0c5 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSQLProperties.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSQLProperties.java @@ -16,6 +16,7 @@ package com.nvidia.spark.rapids.iceberg.spark; +/** Derived from Apache Iceberg's SparkSQLProperties class. */ public class SparkSQLProperties { private SparkSQLProperties() { diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSchemaUtil.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSchemaUtil.java index 5e04c82d6d1..284e589acb3 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSchemaUtil.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkSchemaUtil.java @@ -34,6 +34,7 @@ /** * Helper methods for working with Spark/Hive metadata. + * Derived from Apache Iceberg's SparkSchemaUtil class. */ public class SparkSchemaUtil { private SparkSchemaUtil() { diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkTypeVisitor.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkTypeVisitor.java deleted file mode 100644 index 2407d9c77dd..00000000000 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkTypeVisitor.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.nvidia.spark.rapids.iceberg.spark; - -import java.util.List; - -import org.apache.iceberg.relocated.com.google.common.collect.Lists; - -import org.apache.spark.sql.types.ArrayType; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.MapType; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.types.UserDefinedType; - -public class SparkTypeVisitor { - static T visit(DataType type, SparkTypeVisitor visitor) { - if (type instanceof StructType) { - StructField[] fields = ((StructType) type).fields(); - List fieldResults = Lists.newArrayListWithExpectedSize(fields.length); - - for (StructField field : fields) { - fieldResults.add(visitor.field( - field, - visit(field.dataType(), visitor))); - } - - return visitor.struct((StructType) type, fieldResults); - - } else if (type instanceof MapType) { - return visitor.map((MapType) type, - visit(((MapType) type).keyType(), visitor), - visit(((MapType) type).valueType(), visitor)); - - } else if (type instanceof ArrayType) { - return visitor.array( - (ArrayType) type, - visit(((ArrayType) type).elementType(), visitor)); - - } else if (type instanceof UserDefinedType) { - throw new UnsupportedOperationException( - "User-defined types are not supported"); - - } else { - return visitor.atomic(type); - } - } - - public T struct(StructType struct, List fieldResults) { - return null; - } - - public T field(StructField field, T typeResult) { - return null; - } - - public T array(ArrayType array, T elementResult) { - return null; - } - - public T map(MapType map, T keyResult, T valueResult) { - return null; - } - - public T atomic(DataType atomic) { - return null; - } -} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkUtil.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkUtil.java index 1a558d3b955..7a2a3089b9b 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkUtil.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkUtil.java @@ -20,6 +20,7 @@ import org.apache.iceberg.types.TypeUtil; import org.apache.iceberg.types.Types; +/** Derived from Apache Iceberg's SparkUtil class. */ public class SparkUtil { public static final String TIMESTAMP_WITHOUT_TIMEZONE_ERROR = String.format("Cannot handle timestamp without" + diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/TypeToSparkType.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/TypeToSparkType.java index 862f5a15c6f..652021d9342 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/TypeToSparkType.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/TypeToSparkType.java @@ -41,6 +41,7 @@ import org.apache.spark.sql.types.StructType$; import org.apache.spark.sql.types.TimestampType$; +/** Derived from Apache Iceberg's TypeToSparkType class. 
*/ public class TypeToSparkType extends TypeUtil.SchemaVisitor { TypeToSparkType() { } diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/BaseDataReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/BaseDataReader.java index 7b638625bf3..c4639a0625d 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/BaseDataReader.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/BaseDataReader.java @@ -56,6 +56,7 @@ /** * Base class of Spark readers. + * Derived from Apache Spark's BaseDataReader class. * * @param is the Java class returned by this reader whose objects contain one or more rows. */ diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuBatchDataReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuBatchDataReader.java index 0ba8696a13f..bc014228d41 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuBatchDataReader.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuBatchDataReader.java @@ -38,7 +38,7 @@ import org.apache.spark.rdd.InputFileBlockHolder; import org.apache.spark.sql.vectorized.ColumnarBatch; -/** GPU version of Apache Iceberg's BatchDataReader */ +/** GPU version of Apache Iceberg's BatchDataReader. */ class GpuBatchDataReader extends BaseDataReader { private final Schema expectedSchema; private final String nameMapping; diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuIcebergReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuIcebergReader.java index 5a2fca10461..228a03cf483 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuIcebergReader.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuIcebergReader.java @@ -40,6 +40,10 @@ import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; +/** + * Takes a partition reader output and adds any constant columns and deletion filters + * that need to be applied after the data is loaded from the raw data files. + */ public class GpuIcebergReader implements CloseableIterator { private final Schema expectedSchema; private final PartitionReader partReader; diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java index ec39e370c2c..bc00688ca31 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/SparkBatch.java @@ -36,6 +36,7 @@ import org.apache.spark.sql.connector.read.PartitionReaderFactory; import org.apache.spark.util.SerializableConfiguration; +/** Derived from Apache Iceberg's SparkBatch class. 
*/ public class SparkBatch implements Batch { private final JavaSparkContext sparkContext; diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/Stats.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/Stats.java index e5ec0567c05..bb01fa0ae34 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/Stats.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/Stats.java @@ -20,6 +20,7 @@ import org.apache.spark.sql.connector.read.Statistics; +/** Derived from Apache Iceberg's Stats class. */ public class Stats implements Statistics { private final OptionalLong sizeInBytes; private final OptionalLong numRows; diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index 536005f74da..cc77f8b8960 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -2145,6 +2145,9 @@ object SupportedOpsDocs { totalCount += 2 } println("") + println("### Apache Iceberg Support") + println("Support for Apache Iceberg has additional limitations. See the") + println("[Apache Iceberg Support](additional-functionality/iceberg-support.md) document.") // scalastyle:on line.size.limit } diff --git a/tools/src/main/resources/supportedDataSource.csv b/tools/src/main/resources/supportedDataSource.csv index 45128d0d069..71cb5819e54 100644 --- a/tools/src/main/resources/supportedDataSource.csv +++ b/tools/src/main/resources/supportedDataSource.csv @@ -1,6 +1,7 @@ Format,Direction,BOOLEAN,BYTE,SHORT,INT,LONG,FLOAT,DOUBLE,DATE,TIMESTAMP,STRING,DECIMAL,NULL,BINARY,CALENDAR,ARRAY,MAP,STRUCT,UDT Avro,read,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO CSV,read,S,S,S,S,S,S,S,S,PS,S,S,NA,NS,NA,NA,NA,NA,NA +Iceberg,read,S,S,S,S,S,S,S,S,PS,S,S,NA,NS,NA,PS,PS,PS,NS JSON,read,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO ORC,read,S,S,S,S,S,S,S,S,PS,S,S,NA,NS,NA,PS,PS,PS,NS ORC,write,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA From 1b3e594836dda5556e00b21d04b1043c51e30b64 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Fri, 1 Jul 2022 15:24:44 -0500 Subject: [PATCH 28/36] Fix Iceberg doc reference --- docs/supported_ops.md | 1 + .../src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 124ac3b840c..6e8f5594019 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -18390,6 +18390,7 @@ dates or timestamps, or for a lack of type coercion support. NS + ### Apache Iceberg Support Support for Apache Iceberg has additional limitations. See the [Apache Iceberg Support](additional-functionality/iceberg-support.md) document. diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index cc77f8b8960..ac5b2c90a02 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -2145,6 +2145,7 @@ object SupportedOpsDocs { totalCount += 2 } println("") + println() println("### Apache Iceberg Support") println("Support for Apache Iceberg has additional limitations. 
See the") println("[Apache Iceberg Support](additional-functionality/iceberg-support.md) document.") From 06433abdd85679a189e7f4be654d8619a81c91d8 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Mon, 11 Jul 2022 11:07:23 -0500 Subject: [PATCH 29/36] Fix spark330 build --- .../rapids/iceberg/parquet/GpuParquet.java | 5 ++-- .../iceberg/parquet/GpuParquetReader.java | 1 - .../spark/rapids/PartitionedFileUtils.scala | 30 +++++++++++++++++++ 3 files changed, 33 insertions(+), 3 deletions(-) create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/PartitionedFileUtils.scala diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquet.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquet.java index 61fe0a9e317..1ba555a37e0 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquet.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquet.java @@ -20,6 +20,7 @@ import java.util.Map; import com.nvidia.spark.rapids.GpuMetric; +import com.nvidia.spark.rapids.PartitionedFileUtils; import com.nvidia.spark.rapids.iceberg.data.GpuDeleteFilter; import org.apache.hadoop.conf.Configuration; import org.apache.iceberg.Schema; @@ -156,8 +157,8 @@ public CloseableIterable build() { ParquetReadOptions options = optionsBuilder.build(); - PartitionedFile partFile = new PartitionedFile(InternalRow.empty(), file.location(), - start, length, null); + PartitionedFile partFile = PartitionedFileUtils.newPartitionedFile( + InternalRow.empty(), file.location(), start, length); return new GpuParquetReader(file, projectSchema, options, nameMapping, filter, caseSensitive, idToConstant, deleteFilter, partFile, conf, maxBatchSizeRows, maxBatchSizeBytes, debugDumpPrefix, metrics); diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java index 02db5c4c9c9..1ebdc6c4e8c 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java @@ -64,7 +64,6 @@ import org.apache.spark.sql.types.DoubleType$; import org.apache.spark.sql.types.FloatType$; import org.apache.spark.sql.types.IntegerType$; -import org.apache.spark.sql.types.LongType; import org.apache.spark.sql.types.LongType$; import org.apache.spark.sql.types.MapType; import org.apache.spark.sql.types.Metadata; diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/PartitionedFileUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/PartitionedFileUtils.scala new file mode 100644 index 00000000000..ad7d58d211d --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/PartitionedFileUtils.scala @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.PartitionedFile + +object PartitionedFileUtils { + // Wrapper for case class constructor so Java code can access + // the default values across Spark versions. + def newPartitionedFile( + partitionValues: InternalRow, + filePath: String, + start: Long, + length: Long): PartitionedFile = PartitionedFile(partitionValues, filePath, start, length) +} From ed5e5a5e974aa7e83129eb2a401c709eecdce3df Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Mon, 11 Jul 2022 11:08:46 -0500 Subject: [PATCH 30/36] Fix paste error in Iceberg support docs --- docs/additional-functionality/iceberg-support.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/additional-functionality/iceberg-support.md b/docs/additional-functionality/iceberg-support.md index 9c4910e624c..f32a55be59f 100644 --- a/docs/additional-functionality/iceberg-support.md +++ b/docs/additional-functionality/iceberg-support.md @@ -54,7 +54,7 @@ The RAPIDS Accelerator does not support Apache Iceberg tables using the ORC data #### Avro -The RAPIDS Accelerator does not support Apache Iceberg tables using the ORC data format. +The RAPIDS Accelerator does not support Apache Iceberg tables using the Avro data format. ## Writing Tables From e7fe43a3769f9afd0b9561a30b76f6cbca9715b8 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Thu, 14 Jul 2022 14:43:12 -0500 Subject: [PATCH 31/36] Add new 320+/java directory for recently added shims --- pom.xml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pom.xml b/pom.xml index 4f49e8019d3..72e7eba1d53 100644 --- a/pom.xml +++ b/pom.xml @@ -831,6 +831,7 @@ ${project.basedir}/src/main/331/scala ${project.basedir}/src/main/311+-nondb/scala ${project.basedir}/src/main/311until340-all/scala + ${project.basedir}/src/main/320+/java ${project.basedir}/src/main/320+/scala ${project.basedir}/src/main/320+-nondb/scala ${project.basedir}/src/main/320+-noncdh/scala @@ -895,6 +896,7 @@ ${project.basedir}/src/main/340/scala ${project.basedir}/src/main/311+-nondb/scala + ${project.basedir}/src/main/320+/java ${project.basedir}/src/main/320+/scala ${project.basedir}/src/main/320+-nondb/scala ${project.basedir}/src/main/320+-noncdh/scala From 495d3a57320661408f9d3f4a443fd661c4022339 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Mon, 18 Jul 2022 11:06:14 -0500 Subject: [PATCH 32/36] Add protections for errors during conversion of CPU scan --- .../spark/rapids/iceberg/IcebergProviderImpl.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProviderImpl.scala b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProviderImpl.scala index edde8111a94..9d440885e91 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProviderImpl.scala +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProviderImpl.scala @@ -17,6 +17,7 @@ package com.nvidia.spark.rapids.iceberg import scala.reflect.ClassTag +import scala.util.{Failure, Try} import com.nvidia.spark.rapids.{FileFormatChecks, IcebergFormatType, RapidsConf, ReadFileOp, ScanMeta, ScanRule, ShimLoader} import com.nvidia.spark.rapids.iceberg.spark.source.GpuSparkBatchQueryScan @@ -30,6 +31,10 @@ class IcebergProviderImpl extends IcebergProvider { val cpuIcebergScanClass = ShimLoader.loadClass(IcebergProvider.cpuScanClassName) Seq(new ScanRule[Scan]( (a, conf, p, r) => new 
ScanMeta[Scan](a, conf, p, r) { + private lazy val convertedScan: Try[GpuSparkBatchQueryScan] = Try { + GpuSparkBatchQueryScan.fromCpu(a, conf) + } + override def supportsRuntimeFilters: Boolean = true override def tagSelfForGpu(): Unit = { @@ -48,9 +53,14 @@ class IcebergProviderImpl extends IcebergProvider { if (GpuSparkBatchQueryScan.isMetadataScan(a)) { willNotWorkOnGpu("scan is a metadata scan") } + + convertedScan match { + case Failure(e) => willNotWorkOnGpu(s"conversion to GPU scan failed: ${e.getMessage}") + case _ => + } } - override def convertToGpu(): Scan = GpuSparkBatchQueryScan.fromCpu(a, conf) + override def convertToGpu(): Scan = convertedScan.get }, "Iceberg scan", ClassTag(cpuIcebergScanClass)) From bd539c51e77d3d82e23b92a2feb5b276040bdd66 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Mon, 18 Jul 2022 11:13:36 -0500 Subject: [PATCH 33/36] Update test for zstd being fully supported in libcudf --- integration_tests/src/main/python/iceberg_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration_tests/src/main/python/iceberg_test.py b/integration_tests/src/main/python/iceberg_test.py index 3264e4d98b8..fb015546ad9 100644 --- a/integration_tests/src/main/python/iceberg_test.py +++ b/integration_tests/src/main/python/iceberg_test.py @@ -128,7 +128,7 @@ def setup_iceberg_table(spark): pytest.param(("lz4", "Unsupported compression type"), marks=pytest.mark.skipif(is_before_spark_320(), reason="Hadoop with Spark 3.1.x does not support lz4 by default")), - ("zstd", "Zstandard compression is experimental")], ids=idfn) + ("zstd", None)], ids=idfn) def test_iceberg_read_parquet_compression_codec(spark_tmp_table_factory, codec_info): codec, error_msg = codec_info table = spark_tmp_table_factory.get() From 98daf5bc3ab6801c4adbd71105c187551f7b2515 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Mon, 18 Jul 2022 15:13:48 -0500 Subject: [PATCH 34/36] Address review comments --- .../nvidia/spark/rapids/shims/Spark320PlusShims.scala | 2 +- .../nvidia/spark/rapids/iceberg/IcebergProvider.scala | 5 ++++- .../spark/rapids/iceberg/parquet/GpuParquet.java | 9 +++++++-- .../rapids/iceberg/parquet/GpuParquetReader.java | 11 ++++++++--- .../spark/rapids/iceberg/spark/SparkReadConf.java | 7 ------- .../spark/rapids/iceberg/spark/SparkReadOptions.java | 4 ---- .../iceberg/spark/source/GpuSparkBatchQueryScan.java | 3 ++- .../org/apache/spark/sql/rapids/ExternalSource.scala | 3 +-- 8 files changed, 23 insertions(+), 21 deletions(-) diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala index 60f92353ce1..5b8803cbea7 100644 --- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala +++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala @@ -469,7 +469,7 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging { override def tagPlanForGpu(): Unit = { if (!p.runtimeFilters.isEmpty && !childScans.head.supportsRuntimeFilters) { - willNotWorkOnGpu("runtime filtering (DPP) is not supported") + willNotWorkOnGpu("runtime filtering (DPP) is not supported for this scan") } } diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProvider.scala b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProvider.scala index ced6c6ab58b..29219922b58 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProvider.scala +++ 
b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProvider.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.iceberg -import com.nvidia.spark.rapids.ScanRule +import com.nvidia.spark.rapids.{ScanRule, ShimLoader} import org.apache.spark.sql.connector.read.Scan @@ -28,5 +28,8 @@ trait IcebergProvider { } object IcebergProvider { + def apply(): IcebergProvider = ShimLoader.newInstanceOf[IcebergProvider]( + "com.nvidia.spark.rapids.iceberg.IcebergProviderImpl") + val cpuScanClassName: String = "org.apache.iceberg.spark.source.SparkBatchQueryScan" } diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquet.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquet.java index 1ba555a37e0..cb4963aea3b 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquet.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquet.java @@ -37,7 +37,12 @@ import org.apache.spark.sql.execution.datasources.PartitionedFile; import org.apache.spark.sql.vectorized.ColumnarBatch; -/** GPU version of Apache Iceberg's Parquet class */ +/** + * GPU version of Apache Iceberg's Parquet class. + * The Iceberg version originally accepted a callback function to create the reader to handle + * vectorized batch vs. row readers, but since the GPU only reads vectorized that abstraction + * has been removed. + */ public class GpuParquet { private static final Collection READ_PROPERTIES_TO_REMOVE = Sets.newHashSet( "parquet.read.filter", "parquet.private.read.filter.predicate", "parquet.read.support.class"); @@ -60,7 +65,7 @@ public static class ReadBuilder { private boolean caseSensitive = true; private NameMapping nameMapping = null; private Configuration conf = null; - private int maxBatchSizeRows = 0; + private int maxBatchSizeRows = Integer.MAX_VALUE; private long maxBatchSizeBytes = Integer.MAX_VALUE; private String debugDumpPrefix = null; private scala.collection.immutable.Map metrics = null; diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java index 1ebdc6c4e8c..584cb94a447 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/parquet/GpuParquetReader.java @@ -162,7 +162,11 @@ public org.apache.iceberg.io.CloseableIterator iterator() { ParquetPartitionReader parquetPartReader = new ParquetPartitionReader(conf, partFile, new Path(input.location()), clippedBlocks, fileReadSchema, caseSensitive, partReaderSparkSchema, debugDumpPrefix, maxBatchSizeRows, maxBatchSizeBytes, metrics, - true, true, true, false); + true, // isCorrectedInt96RebaseMode + true, // isCorrectedRebaseMode + true, // hasInt96Timestamps + false // useFieldId + ); PartitionReaderWithBytesRead partReader = new PartitionReaderWithBytesRead(parquetPartReader); Map updatedConstants = addNullsForMissingFields(idToConstant, reorder.getMissingFields()); @@ -180,7 +184,8 @@ private static ParquetFileReader newReader(InputFile file, ParquetReadOptions op } } - private Map addNullsForMissingFields(Map idToConstant, Set missingFields) { + private static Map addNullsForMissingFields(Map idToConstant, + Set missingFields) { if (missingFields.isEmpty()) { return idToConstant; } @@ -203,7 +208,7 @@ public DataType struct(Types.StructType iStruct, GroupType struct, List parquetFields = 
struct.getFields(); List fields = Lists.newArrayListWithExpectedSize(fieldTypes.size()); - for (int i = 0; i < parquetFields.size(); i += 1) { + for (int i = 0; i < parquetFields.size(); i++) { Type parquetField = parquetFields.get(i); Preconditions.checkArgument( diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java index f80a7be7c7f..04c56c52431 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadConf.java @@ -124,13 +124,6 @@ public boolean streamingSkipDeleteSnapshots() { .parse(); } - public boolean streamingSkipOverwriteSnapshots() { - return confParser.booleanConf() - .option(SparkReadOptions.STREAMING_SKIP_OVERWRITE_SNAPSHOTS) - .defaultValue(SparkReadOptions.STREAMING_SKIP_OVERWRITE_SNAPSHOTS_DEFAULT) - .parse(); - } - public boolean parquetVectorizationEnabled() { return confParser.booleanConf() .option(SparkReadOptions.VECTORIZATION_ENABLED) diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadOptions.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadOptions.java index 5034f57894e..b425125f291 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadOptions.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/SparkReadOptions.java @@ -59,10 +59,6 @@ private SparkReadOptions() { public static final String STREAMING_SKIP_DELETE_SNAPSHOTS = "streaming-skip-delete-snapshots"; public static final boolean STREAMING_SKIP_DELETE_SNAPSHOTS_DEFAULT = false; - // skip snapshots of type overwrite while reading stream out of iceberg table - public static final String STREAMING_SKIP_OVERWRITE_SNAPSHOTS = "streaming-skip-overwrite-snapshots"; - public static final boolean STREAMING_SKIP_OVERWRITE_SNAPSHOTS_DEFAULT = false; - // Controls whether to allow reading timestamps without zone info public static final String HANDLE_TIMESTAMP_WITHOUT_TIMEZONE = "handle-timestamp-without-timezone"; diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkBatchQueryScan.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkBatchQueryScan.java index eca3d7ed0fc..d1fa44c9469 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkBatchQueryScan.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuSparkBatchQueryScan.java @@ -109,7 +109,8 @@ public static GpuSparkBatchQueryScan fromCpu(Scan cpuInstance, RapidsConf rapids return new GpuSparkBatchQueryScan(SparkSession.active(), table, scan, readConf, expectedSchema, filters, rapidsConf); } - // Try to build an Iceberg TableScan when one was not found in the CPU instance + // Try to build an Iceberg TableScan when one was not found in the CPU instance. + // This happens on Spark 3.1 where Iceberg's SparkBatchQueryScan does not have a TableScan. 
private static TableScan buildScan(Scan cpuInstance, Table table, SparkReadConf readConf, diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala index 167296456a6..43ddbf440e6 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala @@ -57,8 +57,7 @@ object ExternalSource extends Logging { Try(ShimLoader.loadClass(IcebergProvider.cpuScanClassName)).isSuccess } - private lazy val icebergProvider = ShimLoader.newInstanceOf[IcebergProvider]( - "com.nvidia.spark.rapids.iceberg.IcebergProviderImpl") + private lazy val icebergProvider = IcebergProvider() /** If the file format is supported as an external source */ def isSupportedFormat(format: FileFormat): Boolean = { From abccc6b78374dab329717f63b44daeeb5c8ddf32 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Thu, 21 Jul 2022 17:21:57 -0500 Subject: [PATCH 35/36] Work around classloader issues in distributed setups --- jenkins/databricks/test.sh | 6 +++++- jenkins/spark-tests.sh | 4 ++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/jenkins/databricks/test.sh b/jenkins/databricks/test.sh index 6523b458ea6..700c39d975b 100755 --- a/jenkins/databricks/test.sh +++ b/jenkins/databricks/test.sh @@ -78,7 +78,11 @@ PCBS_CONF="com.nvidia.spark.ParquetCachedBatchSerializer" ICEBERG_VERSION=0.13.2 ICEBERG_SPARK_VER=$(echo $BASE_SPARK_VER | cut -d. -f1,2) -ICEBERG_CONFS="--packages org.apache.iceberg:iceberg-spark-runtime-${ICEBERG_SPARK_VER}_2.12:${ICEBERG_VERSION} \ +# Classloader config is here to work around classloader issues with +# --packages in distributed setups, should be fixed by +# https://github.com/NVIDIA/spark-rapids/pull/5646 +ICEBERG_CONFS="--conf spark.rapids.force.caller.classloader=false \ + --packages org.apache.iceberg:iceberg-spark-runtime-${ICEBERG_SPARK_VER}_2.12:${ICEBERG_VERSION} \ --conf spark.sql.extensions=org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions \ --conf spark.sql.catalog.spark_catalog=org.apache.iceberg.spark.SparkSessionCatalog \ --conf spark.sql.catalog.spark_catalog.type=hadoop \ diff --git a/jenkins/spark-tests.sh b/jenkins/spark-tests.sh index 34e6770e961..4f656545f72 100755 --- a/jenkins/spark-tests.sh +++ b/jenkins/spark-tests.sh @@ -179,7 +179,11 @@ run_iceberg_tests() { # Iceberg does not support Spark 3.3+ yet if [[ "$ICEBERG_SPARK_VER" < "3.3" ]]; then + # Classloader config is here to work around classloader issues with + # --packages in distributed setups, should be fixed by + # https://github.com/NVIDIA/spark-rapids/pull/5646 SPARK_SUBMIT_FLAGS="$BASE_SPARK_SUBMIT_ARGS $SEQ_CONF \ + --conf spark.rapids.force.caller.classloader=false \ --packages org.apache.iceberg:iceberg-spark-runtime-${ICEBERG_SPARK_VER}_2.12:${ICEBERG_VERSION} \ --conf spark.sql.extensions=org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions \ --conf spark.sql.catalog.spark_catalog=org.apache.iceberg.spark.SparkSessionCatalog \ From 5d2fb2718b8d84b1d8683c19f3327757990c6ba1 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Thu, 21 Jul 2022 17:52:11 -0500 Subject: [PATCH 36/36] Update to ShimLoader convention --- dist/unshimmed-from-each-spark3xx.txt | 1 + .../java/com/nvidia/spark/rapids/iceberg/IcebergProvider.scala | 3 +-- .../src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git 
a/dist/unshimmed-from-each-spark3xx.txt b/dist/unshimmed-from-each-spark3xx.txt index 639d486a514..35e7b334b21 100644 --- a/dist/unshimmed-from-each-spark3xx.txt +++ b/dist/unshimmed-from-each-spark3xx.txt @@ -1,5 +1,6 @@ com/nvidia/spark/rapids/*/RapidsShuffleManager* com/nvidia/spark/rapids/AvroProvider.class com/nvidia/spark/rapids/HiveProvider.class +com/nvidia/spark/rapids/IcebergProvider.class org/apache/spark/sql/rapids/shims/*/ProxyRapidsShuffleInternalManager* spark-*-info.properties diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProvider.scala b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProvider.scala index 29219922b58..cd9cc9666c0 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProvider.scala +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/IcebergProvider.scala @@ -28,8 +28,7 @@ trait IcebergProvider { } object IcebergProvider { - def apply(): IcebergProvider = ShimLoader.newInstanceOf[IcebergProvider]( - "com.nvidia.spark.rapids.iceberg.IcebergProviderImpl") + def apply(): IcebergProvider = ShimLoader.newIcebergProvider() val cpuScanClassName: String = "org.apache.iceberg.spark.source.SparkBatchQueryScan" } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala index 6ac86d24a5b..bcdc0b1f997 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala @@ -19,6 +19,7 @@ package com.nvidia.spark.rapids import java.net.URL import com.nvidia.spark.GpuCachedBatchSerializer +import com.nvidia.spark.rapids.iceberg.IcebergProvider import org.apache.commons.lang3.reflect.MethodUtils import scala.annotation.tailrec import scala.collection.JavaConverters._ @@ -462,4 +463,6 @@ object ShimLoader extends Logging { def newAvroProvider(): AvroProvider = ShimLoader.newInstanceOf[AvroProvider]( "org.apache.spark.sql.rapids.AvroProviderImpl") + def newIcebergProvider(): IcebergProvider = ShimLoader.newInstanceOf[IcebergProvider]( + "com.nvidia.spark.rapids.iceberg.IcebergProviderImpl") }
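
Editor's note (not part of the patch): the last two commits settle on a provider-loading convention — a small trait, a companion `apply()` that delegates to a `ShimLoader` factory, and a `Try`-based classpath probe before the provider is ever instantiated. The sketch below is a minimal, self-contained illustration of that pattern only; every name in it (`DemoProvider`, `ReflectiveLoader`, `com.example.DemoProviderImpl`, `com.example.iceberg.SparkBatchQueryScan`) is hypothetical and stands in for the plugin's real `IcebergProvider`, `ShimLoader.newIcebergProvider()`, and the `ExternalSource` probe of `IcebergProvider.cpuScanClassName`.

```scala
import scala.util.Try

// Callers program against the trait; the implementation lives elsewhere
// (in the real plugin, inside a shimmed, version-specific location).
trait DemoProvider {
  def description: String
}

object ReflectiveLoader {
  // Reflectively instantiate an implementation by class name; a stand-in for
  // ShimLoader.newInstanceOf, which additionally routes through a shim classloader.
  def newInstanceOf[T](className: String): T =
    Class.forName(className).getDeclaredConstructor().newInstance().asInstanceOf[T]
}

object DemoProvider {
  // The companion apply() hides the implementation class name from callers,
  // mirroring IcebergProvider() -> ShimLoader.newIcebergProvider() in the final commit.
  def apply(): DemoProvider =
    ReflectiveLoader.newInstanceOf[DemoProvider]("com.example.DemoProviderImpl")
}

object Caller {
  // Probe for a prerequisite class before loading the provider at all, the same
  // Try(...).isSuccess style of check ExternalSource uses for the CPU Iceberg scan class.
  lazy val provider: Option[DemoProvider] =
    if (Try(Class.forName("com.example.iceberg.SparkBatchQueryScan")).isSuccess) {
      Some(DemoProvider())
    } else {
      None
    }
}
```

The value of the indirection, as the patch applies it, is that nothing outside the provider implementation links against Iceberg classes directly: if the Iceberg runtime jar is absent, the probe simply reports the format as unsupported instead of failing at class-load time, and the implementation class itself can be listed in the unshimmed-classes file so the shim classloader resolves it consistently across Spark versions.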