From e27029d278d4d7dba9fa9591646847f1096b36a4 Mon Sep 17 00:00:00 2001 From: Johan Lasperas Date: Tue, 20 Jun 2023 15:43:46 +0200 Subject: [PATCH] Data skipping and column pruning in merge --- .../apache/spark/sql/delta/DeltaTable.scala | 80 ++++++ .../sql/delta/commands/MergeIntoCommand.scala | 186 +++++-------- .../delta/commands/MergeIntoCommandBase.scala | 154 ++++++++++- .../spark/sql/delta/DeltaTestUtils.scala | 11 - .../spark/sql/delta/MergeIntoSuiteBase.scala | 247 ++++++++++++++++-- .../sql/delta/test/ScanReportHelper.scala | 3 +- 6 files changed, 514 insertions(+), 167 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaTable.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaTable.scala index 01cfd0ea9e5..8cb90e94b3c 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaTable.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaTable.scala @@ -29,8 +29,10 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, UnresolvedTable} import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.planning.NodeWithOnlyDeterministicProjectAndFilter import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform} import org.apache.spark.sql.execution.datasources.{FileFormat, FileIndex, HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.internal.SQLConf @@ -325,6 +327,84 @@ object DeltaTableUtils extends PredicateHelper } } + /** + * Replace the file index in a logical plan and return the updated plan. + * It's a common pattern that, in Delta commands, we use data skipping to determine a subset of + * files that can be affected by the command, so we replace the whole-table file index in the + * original logical plan with a new index of potentially affected files, while everything else in + * the original plan, e.g., resolved references, remains unchanged. + * + * Many Delta meta-queries involve nondeterministic functions, which interfere with automatic + * column pruning, so columns can be manually pruned from the new scan. Note that partition + * columns can never be dropped even if they're not referenced in the rest of the query. + * + * @param spark the spark session to use + * @param target the logical plan in which we replace the file index + * @param fileIndex the new file index + * @param columnsToDrop columns to drop from the scan + * @param newOutput If specified, new logical output to replace the current LogicalRelation. + * Used for schema evolution to produce the new schema-evolved types from + * old files, because `target` will have the old types.
+ */ + def replaceFileIndex( + spark: SparkSession, + target: LogicalPlan, + fileIndex: FileIndex, + columnsToDrop: Seq[String], + newOutput: Option[Seq[AttributeReference]]): LogicalPlan = { + val resolver = spark.sessionState.analyzer.resolver + + var actualNewOutput = newOutput + var hasChar = false + var newTarget = target transformDown { + case l @ LogicalRelation(hfsr: HadoopFsRelation, _, _, _) => + val finalOutput = actualNewOutput.getOrElse(l.output).filterNot { col => + columnsToDrop.exists(resolver(_, col.name)) + } + + // If the output columns were changed e.g. by schema evolution, we need to update + // the relation to expose all the columns that are expected after schema evolution. + val newDataSchema = StructType(finalOutput.map(attr => + StructField(attr.name, attr.dataType, attr.nullable, attr.metadata))) + val newBaseRelation = hfsr.copy( + location = fileIndex, dataSchema = newDataSchema)( + hfsr.sparkSession) + l.copy(relation = newBaseRelation, output = finalOutput) + + case p @ Project(projectList, _) => + def hasCharPadding(e: Expression): Boolean = e.exists { + case s: StaticInvoke => s.staticObject == classOf[CharVarcharCodegenUtils] && + s.functionName == "readSidePadding" + case _ => false + } + val charColMapping = AttributeMap(projectList.collect { + case a: Alias if hasCharPadding(a.child) && a.references.size == 1 => + hasChar = true + val tableCol = a.references.head.asInstanceOf[AttributeReference] + a.toAttribute -> tableCol + }) + actualNewOutput = newOutput.map(_.map { attr => + charColMapping.get(attr).map { tableCol => + attr.withExprId(tableCol.exprId) + }.getOrElse(attr) + }) + p + } + + if (hasChar) { + newTarget = newTarget.transformUp { + case p @ Project(projectList, child) => + val newProjectList = projectList.filter { e => + // Spark does char type read-side padding via an additional Project over the scan node, + // and we need to apply column pruning for the Project as well, otherwise the Project + // will contain missing attributes. 
+ e.references.subsetOf(child.outputSet) + } + p.copy(projectList = newProjectList) + } + } + newTarget + } /** * Update FileFormat for a plan and return the updated plan diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommand.scala b/spark/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommand.scala index 3d219dd5467..8f5bb5172e8 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommand.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommand.scala @@ -25,17 +25,14 @@ import org.apache.spark.sql.delta.files._ import org.apache.spark.sql.delta.schema.{ImplicitMetadataOperation, SchemaUtils} import org.apache.spark.sql.delta.sources.DeltaSQLConf import org.apache.spark.sql.delta.util.{AnalysisHelper, SetAccumulator} -import com.fasterxml.jackson.databind.annotation.JsonDeserialize import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, BasePredicate, Expression, Literal, NamedExpression, PredicateHelper, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataTypes, LongType, StructType} @@ -75,7 +72,6 @@ case class MergeIntoCommand( notMatchedClauses: Seq[DeltaMergeIntoNotMatchedClause], notMatchedBySourceClauses: Seq[DeltaMergeIntoNotMatchedBySourceClause], migratedSchema: Option[StructType]) extends MergeIntoCommandBase - with PredicateHelper with AnalysisHelper with ImplicitMetadataOperation { @@ -92,26 +88,9 @@ case class MergeIntoCommand( AttributeReference("num_deleted_rows", LongType)(), AttributeReference("num_inserted_rows", LongType)()) - /** - * Map to get target output attributes by name. - * The case sensitivity of the map is set accordingly to Spark configuration. - */ - @transient private lazy val targetOutputAttributesMap: Map[String, Attribute] = { - val attrMap: Map[String, Attribute] = target - .outputSet.view - .map(attr => attr.name -> attr).toMap - if (conf.caseSensitiveAnalysis) { - attrMap - } else { - CaseInsensitiveMap(attrMap) - } - } - /** Whether this merge statement has only a single insert (NOT MATCHED) clause. */ private def isSingleInsertOnly: Boolean = matchedClauses.isEmpty && notMatchedBySourceClauses.isEmpty && notMatchedClauses.length == 1 - /** Whether this merge statement has no insert (NOT MATCHED) clause. */ - private def hasNoInserts: Boolean = notMatchedClauses.isEmpty // We over-count numTargetRowsDeleted when there are multiple matches; // this is the amount of the overcount, so we can subtract it to get a correct final metric. 
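The sketch below (not part of the diff) illustrates how the new helper is meant to be driven by a command: data skipping produces the file list, a TahoeBatchFileIndex wraps it, and replaceFileIndex swaps it into the resolved target plan while pruning columns. The wrapper name prunedTargetScan and its parameters are hypothetical; only the replaceFileIndex, filterFiles and TahoeBatchFileIndex signatures come from this patch.

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.delta.{DeltaTableUtils, OptimisticTransaction}
import org.apache.spark.sql.delta.files.TahoeBatchFileIndex

// Hypothetical helper, for illustration only: combine data skipping with the new
// file-index replacement and manual column pruning.
def prunedTargetScan(
    spark: SparkSession,
    deltaTxn: OptimisticTransaction,
    target: LogicalPlan,                    // resolved scan of the full target table
    targetPath: Path,
    targetOnlyPredicates: Seq[Expression],  // predicates referencing only target columns
    columnsToDrop: Seq[String]): LogicalPlan = {
  // Data skipping first: keep only the files that can contain matching rows.
  val prunedFiles = deltaTxn.filterFiles(targetOnlyPredicates)
  val fileIndex = new TahoeBatchFileIndex(
    spark, actionType = "batch", prunedFiles, deltaTxn.deltaLog, targetPath, deltaTxn.snapshot)
  // Swap the whole-table index for the pruned one and drop unneeded non-partition columns;
  // exprIds of the remaining output attributes are preserved so resolved expressions still apply.
  DeltaTableUtils.replaceFileIndex(spark, target, fileIndex, columnsToDrop, newOutput = None)
}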
@@ -217,29 +196,30 @@ case class MergeIntoCommand( status = "MERGE operation - scanning files for matches", sqlMetricName = "scanTimeMs") { + val columnComparator = spark.sessionState.analyzer.resolver + // Accumulator to collect all the distinct touched files val touchedFilesAccum = new SetAccumulator[String]() spark.sparkContext.register(touchedFilesAccum, TOUCHED_FILES_ACCUM_NAME) // UDFs to records touched files names and add them to the accumulator - val recordTouchedFileName = DeltaUDF.intFromString { fileName => - touchedFilesAccum.add(fileName) - 1 - }.asNondeterministic() + val recordTouchedFileName = + DeltaUDF.intFromStringBoolean { (fileName, shouldRecord) => { + if (shouldRecord) { + touchedFilesAccum.add(fileName) + } + 1 + }}.asNondeterministic() // Prune non-matching files if we don't need to collect them for NOT MATCHED BY SOURCE clauses. val dataSkippedFiles = if (notMatchedBySourceClauses.isEmpty) { - val targetOnlyPredicates = - splitConjunctivePredicates(condition).filter(_.references.subsetOf(target.outputSet)) - deltaTxn.filterFiles(targetOnlyPredicates) + deltaTxn.filterFiles(getTargetOnlyPredicates(spark)) } else { deltaTxn.filterFiles() } val incrSourceRowCountExpr = incrementMetricAndReturnBool("numSourceRows", valueToReturn = true) - val sourceDF = getSourceDF() - .filter(new Column(incrSourceRowCountExpr)) // Join the source and target table using the merge condition to find touched files. An inner // join collects all candidate files for MATCHED clauses, a right outer join also includes // target row is modified by multiple user or not // - the target file name the row is from to later identify the files touched by matched rows val joinType = if (notMatchedBySourceClauses.isEmpty) "inner" else "right_outer" - val targetDF = buildTargetPlanWithFiles(spark, deltaTxn, dataSkippedFiles) + + // When there are only MATCHED clauses, we prune after the join the files that have no rows that + // satisfy any of the clause conditions. + val matchedPredicate = + if (isMatchedOnly) { + matchedClauses + .map(clause => clause.condition.getOrElse(Literal(true))) + .reduce((a, b) => Or(a, b)) + } else Literal(true) + + // Compute the columns needed for the inner join. + val targetColsNeeded = { + condition.references.map(_.name) ++ deltaTxn.snapshot.metadata.partitionColumns ++ + matchedPredicate.references.map(_.name) + } + + val columnsToDrop = deltaTxn.snapshot.metadata.schema.map(_.name) + .filterNot { field => + targetColsNeeded.exists { name => columnComparator(name, field) } + } + + // We can't use filter() directly on the expression because that will prevent + // column pruning. We don't need the SOURCE_ROW_PRESENT_COL, so we immediately drop it.
+ val sourceDF = getSourceDF() + .withColumn(SOURCE_ROW_PRESENT_COL, new Column(incrSourceRowCountExpr)) + .filter(SOURCE_ROW_PRESENT_COL) + .drop(SOURCE_ROW_PRESENT_COL) + val targetPlan = + buildTargetPlanWithFiles( + spark, + deltaTxn, + dataSkippedFiles, + columnsToDrop) + val targetDF = Dataset.ofRows(spark, targetPlan) .withColumn(ROW_ID_COL, monotonically_increasing_id()) .withColumn(FILE_NAME_COL, input_file_name()) + val joinToFindTouchedFiles = sourceDF.join(targetDF, new Column(condition), joinType) // Process the matches from the inner join to record touched files and find multiple matches val collectTouchedFiles = joinToFindTouchedFiles - .select(col(ROW_ID_COL), recordTouchedFileName(col(FILE_NAME_COL)).as("one")) + .select(col(ROW_ID_COL), + recordTouchedFileName(col(FILE_NAME_COL), new Column(matchedPredicate)).as("one")) // Calculate frequency of matches per source row val matchedRowCounts = collectTouchedFiles.groupBy(ROW_ID_COL).agg(sum("one").as("count")) @@ -368,8 +383,12 @@ case class MergeIntoCommand( val dataSkippedFiles = deltaTxn.filterFiles(targetOnlyPredicates) // target DataFrame - val targetDF = buildTargetPlanWithFiles(spark, deltaTxn, dataSkippedFiles) - + val targetPlan = buildTargetPlanWithFiles( + spark, + deltaTxn, + dataSkippedFiles, + columnsToDrop = Nil) + val targetDF = Dataset.ofRows(spark, targetPlan) val insertDf = sourceDF.join(targetDF, new Column(condition), "leftanti") .select(outputCols: _*) .filter(new Column(incrInsertedCountExpr)) @@ -417,14 +436,16 @@ case class MergeIntoCommand( deltaTxn: OptimisticTransaction, filesToRewrite: Seq[AddFile]) : Seq[FileAction] = recordMergeOperation( - extraOpType = "writeAllChanges", + extraOpType = + if (shouldOptimizeMatchedOnlyMerge(spark)) "writeAllUpdatesAndDeletes" + else "writeAllChanges", status = s"MERGE operation - Rewriting ${filesToRewrite.size} files", sqlMetricName = "rewriteTimeMs") { import org.apache.spark.sql.catalyst.expressions.Literal.{TrueLiteral, FalseLiteral} val cdcEnabled = DeltaConfigs.CHANGE_DATA_FEED.fromMetaData(deltaTxn.metadata) - var targetOutputCols = getTargetOutputCols(deltaTxn) + var targetOutputCols = getTargetOutputCols(deltaTxn, makeNullable = true) var outputRowSchema = deltaTxn.metadata.schema // When we have duplicate matches (only allowed when the whenMatchedCondition is a delete with @@ -451,9 +472,13 @@ case class MergeIntoCommand( // Generate a new target dataframe that has same output attributes exprIds as the target plan. // This allows us to apply the existing resolved update/insert expressions. - val baseTargetDF = buildTargetPlanWithFiles(spark, deltaTxn, filesToRewrite) - val joinType = if (hasNoInserts && - spark.conf.get(DeltaSQLConf.MERGE_MATCHED_ONLY_ENABLED)) { + val targetPlan = buildTargetPlanWithFiles( + spark, + deltaTxn, + filesToRewrite, + columnsToDrop = Nil) + val baseTargetDF = Dataset.ofRows(spark, targetPlan) + val joinType = if (shouldOptimizeMatchedOnlyMerge(spark)) { "rightOuter" } else { "fullOuter" @@ -725,87 +750,6 @@ case class MergeIntoCommand( newFiles } - - /** - * Build a new logical plan using the given `files` that has the same output columns (exprIds) - * as the `target` logical plan, so that existing update/insert expressions can be applied - * on this new plan. 
- */ - private def buildTargetPlanWithFiles( - spark: SparkSession, - deltaTxn: OptimisticTransaction, - files: Seq[AddFile]): DataFrame = { - val targetOutputCols = getTargetOutputCols(deltaTxn) - val targetOutputColsMap = { - val colsMap: Map[String, NamedExpression] = targetOutputCols.view - .map(col => col.name -> col).toMap - if (conf.caseSensitiveAnalysis) { - colsMap - } else { - CaseInsensitiveMap(colsMap) - } - } - - val plan = { - // We have to do surgery to use the attributes from `targetOutputCols` to scan the table. - // In cases of schema evolution, they may not be the same type as the original attributes. - val original = - deltaTxn.deltaLog.createDataFrame(deltaTxn.snapshot, files).queryExecution.analyzed - val transformed = original.transform { - case LogicalRelation(base, output, catalogTbl, isStreaming) => - LogicalRelation( - base, - // We can ignore the new columns which aren't yet AttributeReferences. - targetOutputCols.collect { case a: AttributeReference => a }, - catalogTbl, - isStreaming) - } - - // In case of schema evolution & column mapping, we would also need to rebuild the file format - // because under column mapping, the reference schema within DeltaParquetFileFormat - // that is used to populate metadata needs to be updated - if (deltaTxn.metadata.columnMappingMode != NoMapping) { - val updatedFileFormat = deltaTxn.deltaLog.fileFormat(deltaTxn.protocol, deltaTxn.metadata) - DeltaTableUtils.replaceFileFormat(transformed, updatedFileFormat) - } else { - transformed - } - } - - // For each plan output column, find the corresponding target output column (by name) and - // create an alias - val aliases = plan.output.map { - case newAttrib: AttributeReference => - val existingTargetAttrib = targetOutputColsMap.get(newAttrib.name) - .getOrElse { - throw DeltaErrors.failedFindAttributeInOutputColumns( - newAttrib.name, targetOutputCols.mkString(",")) - }.asInstanceOf[AttributeReference] - - if (existingTargetAttrib.exprId == newAttrib.exprId) { - // It's not valid to alias an expression to its own exprId (this is considered a - // non-unique exprId by the analyzer), so we just use the attribute directly. - newAttrib - } else { - Alias(newAttrib, existingTargetAttrib.name)(exprId = existingTargetAttrib.exprId) - } - } - - Dataset.ofRows(spark, Project(aliases, plan)) - } - - private def getTargetOutputCols(txn: OptimisticTransaction): Seq[NamedExpression] = { - txn.metadata.schema.map { col => - targetOutputAttributesMap - .get(col.name) - .map { a => - AttributeReference(col.name, col.dataType, col.nullable)(a.exprId) - } - .getOrElse(Alias(Literal(null), col.name)() - ) - } - } - /** * Repartitions the output DataFrame by the partition columns if table is partitioned * and `merge.repartitionBeforeWrite.enabled` is set to true. 
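As a reference for reviewers, the pruning rule that findTouchedFiles now applies before the inner join can be restated as a small standalone function. This is a sketch only; columnsToDropForTouchedFiles is a hypothetical name, but the rule (keep columns referenced by the merge condition, the OR-ed matched-clause conditions, and the partition columns; drop everything else) mirrors the code above.

import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types.StructType

// Hypothetical restatement of the column pruning applied before the inner join.
def columnsToDropForTouchedFiles(
    tableSchema: StructType,
    partitionColumns: Seq[String],
    condition: Expression,        // merge condition
    matchedPredicate: Expression, // OR of the matched-clause conditions, Literal(true) otherwise
    resolver: Resolver): Seq[String] = {
  val neededCols = condition.references.map(_.name).toSeq ++
    partitionColumns ++
    matchedPredicate.references.map(_.name)
  // Anything not needed by the join condition, the matched predicates or partitioning is dropped.
  tableSchema.map(_.name).filterNot(field => neededCols.exists(resolver(_, field)))
}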
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommandBase.scala b/spark/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommandBase.scala index 016e856ce34..4e53eae1938 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommandBase.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommandBase.scala @@ -19,18 +19,18 @@ package org.apache.spark.sql.delta.commands import java.util.concurrent.TimeUnit import org.apache.spark.sql.delta.metric.IncrementMetric -import org.apache.spark.sql.delta.{DeltaErrors, DeltaLog, DeltaOperations, OptimisticTransaction} -import org.apache.spark.sql.delta.actions.Action -import org.apache.spark.sql.delta.actions.FileAction +import org.apache.spark.sql.delta._ +import org.apache.spark.sql.delta.actions.{Action, AddFile, FileAction} import org.apache.spark.sql.delta.commands.merge.{MergeIntoMaterializeSource, MergeIntoMaterializeSourceReason, MergeStats} -import org.apache.spark.sql.delta.files.TahoeFileIndex +import org.apache.spark.sql.delta.files.{TahoeBatchFileIndex, TahoeFileIndex} import org.apache.spark.sql.delta.metering.DeltaLogging import org.apache.spark.sql.delta.sources.DeltaSQLConf import org.apache.spark.SparkContext import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{DeltaMergeIntoMatchedClause, DeltaMergeIntoNotMatchedBySourceClause, DeltaMergeIntoNotMatchedClause, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.command.LeafRunnableCommand import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.types.StructType @@ -38,6 +38,7 @@ import org.apache.spark.sql.types.StructType abstract class MergeIntoCommandBase extends LeafRunnableCommand with DeltaCommand with DeltaLogging + with PredicateHelper with MergeIntoMaterializeSource { @transient val source: LogicalPlan @@ -53,6 +54,24 @@ abstract class MergeIntoCommandBase extends LeafRunnableCommand @transient protected lazy val sc: SparkContext = SparkContext.getOrCreate() @transient protected lazy val targetDeltaLog: DeltaLog = targetFileIndex.deltaLog + /** + * Map to get target output attributes by name. + * The case sensitivity of the map is set according to the Spark configuration. + */ + @transient private lazy val targetOutputAttributesMap: Map[String, Attribute] = { + val attrMap: Map[String, Attribute] = target + .outputSet.view + .map(attr => attr.name -> attr).toMap + if (conf.caseSensitiveAnalysis) { + attrMap + } else { + CaseInsensitiveMap(attrMap) + } + } + + /** Whether this merge statement has only MATCHED clauses.
*/ + protected def isMatchedOnly: Boolean = notMatchedClauses.isEmpty && matchedClauses.nonEmpty && + notMatchedBySourceClauses.isEmpty import SQLMetrics._ override lazy val metrics: Map[String, SQLMetric] = baseMetrics @@ -147,9 +166,132 @@ abstract class MergeIntoCommandBase extends LeafRunnableCommand materializeSourceReason = Some(materializeSourceReason.toString), materializeSourceAttempts = Some(attempt)) } + + protected def shouldOptimizeMatchedOnlyMerge(spark: SparkSession): Boolean = { + isMatchedOnly && spark.conf.get(DeltaSQLConf.MERGE_MATCHED_ONLY_ENABLED) + } + /** + * Build a new logical plan to read the given `files` instead of the whole target table. + * The plan returned has the same output columns (exprIds) as the `target` logical plan, so that + * existing update/insert expressions can be applied on this new plan. Unneeded non-partition + * columns may be dropped. + */ + protected def buildTargetPlanWithFiles( + spark: SparkSession, + deltaTxn: OptimisticTransaction, + files: Seq[AddFile], + columnsToDrop: Seq[String]): LogicalPlan = { + // Action type "batch" is a historical artifact; the original implementation used it. + val fileIndex = new TahoeBatchFileIndex( + spark, + actionType = "batch", + files, + deltaTxn.deltaLog, + targetFileIndex.path, + deltaTxn.snapshot) + + buildTargetPlanWithIndex( + spark, + deltaTxn, + fileIndex, + columnsToDrop + ) + } + + /** + * Build a new logical plan to read the target table using the given `fileIndex`. + * The plan returned has the same output columns (exprIds) as the `target` logical plan, so that + * existing update/insert expressions can be applied on this new plan. Unneeded non-partition + * columns may be dropped. + */ + protected def buildTargetPlanWithIndex( + spark: SparkSession, + deltaTxn: OptimisticTransaction, + fileIndex: TahoeFileIndex, + columnsToDrop: Seq[String]): LogicalPlan = { + + val targetOutputCols = getTargetOutputCols(deltaTxn) + + val plan = { + + // In case of schema evolution & column mapping, we need to rebuild the file format + // because under column mapping, the reference schema within DeltaParquetFileFormat + // that is used to populate metadata needs to be updated. + // + // WARNING: We must do this before replacing the file index, or we risk invalidating the + // metadata column expression ids that replaceFileIndex might inject into the plan. + val planWithReplacedFileFormat = if (deltaTxn.metadata.columnMappingMode != NoMapping) { + val updatedFileFormat = deltaTxn.deltaLog.fileFormat(deltaTxn.protocol, deltaTxn.metadata) + DeltaTableUtils.replaceFileFormat(target, updatedFileFormat) + } else { + target + } + + // We have to do surgery to use the attributes from `targetOutputCols` to scan the table. + // In cases of schema evolution, they may not be the same type as the original attributes. + // We can ignore the new columns which aren't yet AttributeReferences. + val newReadCols = targetOutputCols.collect { case a: AttributeReference => a } + DeltaTableUtils.replaceFileIndex( + spark, + planWithReplacedFileFormat, + fileIndex, + columnsToDrop, + newOutput = Some(newReadCols)) + } + + // Add back the null expression aliases for columns that are new to the target schema + // and don't exist in the input snapshot. + // These have been added in `getTargetOutputCols` but have been removed in `newReadCols` above + // and are thus not in `plan.output`. 
+ val newColumnsWithNulls = targetOutputCols.filter(_.isInstanceOf[Alias]) + Project(plan.output ++ newColumnsWithNulls, plan) + } + + /** + * Get the expression references for the output columns of the target table relative to + * the transaction. Due to schema evolution, there are two kinds of expressions here: + * * References to columns in the target dataframe. Note that these references may have a + * different data type than they originally did due to schema evolution, but the exprId + * will be the same. These references will be marked as nullable if `makeNullable` is set + * to true. + * * Literal nulls, for new columns which are being added to the target table as part of + * this transaction, since new columns will have a value of null for all existing rows. + */ + protected def getTargetOutputCols( + txn: OptimisticTransaction, makeNullable: Boolean = false): Seq[NamedExpression] = { + txn.metadata.schema.map { col => + targetOutputAttributesMap + .get(col.name) + .map { a => + AttributeReference(col.name, col.dataType, makeNullable || col.nullable)(a.exprId) + } + .getOrElse(Alias(Literal(null), col.name)()) + } + } + /** Expressions to increment SQL metrics */ protected def incrementMetricAndReturnBool(name: String, valueToReturn: Boolean): Expression = IncrementMetric(Literal(valueToReturn), metrics(name)) + + protected def getTargetOnlyPredicates(spark: SparkSession): Seq[Expression] = { + val targetOnlyPredicatesOnCondition = + splitConjunctivePredicates(condition).filter(_.references.subsetOf(target.outputSet)) + + if (!isMatchedOnly) { + targetOnlyPredicatesOnCondition + } else { + val targetOnlyMatchedPredicate = matchedClauses + .map(clause => clause.condition.getOrElse(Literal(true))) + .map { condition => + splitConjunctivePredicates(condition) + .filter(_.references.subsetOf(target.outputSet)) + .reduceOption(And) + .getOrElse(Literal(true)) + } + .reduceOption(Or) + targetOnlyPredicatesOnCondition ++ targetOnlyMatchedPredicate + } + } /** * Execute the given `thunk` and return its result while recording the time taken to do it * and setting additional local properties for better UI visibility. 
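The following is a simplified, self-contained restatement of getTargetOnlyPredicates for the matched-only case, useful for reasoning about which predicates reach filterFiles. The object name TargetOnlyPredicates is hypothetical, and it omits the isMatchedOnly guard that the real method applies.

import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, Literal, Or, PredicateHelper}

// Hypothetical standalone version of the target-only predicate extraction, for illustration.
object TargetOnlyPredicates extends PredicateHelper {
  def apply(
      condition: Expression,
      matchedConditions: Seq[Expression], // conditions of the WHEN MATCHED clauses
      targetOutput: AttributeSet): Seq[Expression] = {
    // Conjuncts of the merge condition that reference only target columns.
    val fromCondition =
      splitConjunctivePredicates(condition).filter(_.references.subsetOf(targetOutput))
    // Target-only conjuncts of each matched condition (defaulting to `true`), OR-ed together:
    // a file may only be skipped if no matched clause can apply to any of its rows.
    val fromMatchedClauses = matchedConditions
      .map { cond =>
        splitConjunctivePredicates(cond)
          .filter(_.references.subsetOf(targetOutput))
          .reduceOption(And)
          .getOrElse(Literal(true))
      }
      .reduceOption(Or)
    fromCondition ++ fromMatchedClauses
  }
}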
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaTestUtils.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaTestUtils.scala index cfb0c1505cd..df0c485519d 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaTestUtils.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaTestUtils.scala @@ -300,17 +300,6 @@ trait DeltaTestUtilsForTempViews } } - def testOssOnlyWithTempView(testName: String)(testFun: Boolean => Any): Unit = { - Seq(true, false).foreach { isSQLTempView => - val tempViewUsed = if (isSQLTempView) "SQL TempView" else "Dataset TempView" - test(s"$testName - $tempViewUsed") { - withTempView("v") { - testFun(isSQLTempView) - } - } - } - } - def testQuietlyWithTempView(testName: String)(testFun: Boolean => Any): Unit = { Seq(true, false).foreach { isSQLTempView => val tempViewUsed = if (isSQLTempView) "SQL TempView" else "Dataset TempView" diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala b/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala index be9df05e410..ec44e5cba5b 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala @@ -22,13 +22,17 @@ import java.util.Locale import scala.language.implicitConversions -import com.databricks.spark.util.{Log4jUsageLogger, UsageRecord} +import com.databricks.spark.util.{Log4jUsageLogger, MetricDefinitions, UsageRecord} import org.apache.spark.sql.delta.DeltaTestUtils.BOOLEAN_DOMAIN +import org.apache.spark.sql.delta.commands.merge.MergeStats import org.apache.spark.sql.delta.sources.DeltaSQLConf import org.apache.spark.sql.delta.test.DeltaTestImplicits._ +import org.apache.spark.sql.delta.test.ScanReportHelper +import org.apache.spark.sql.delta.util.JsonUtils +import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach -import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} +import org.apache.spark.sql.{functions, AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData} import org.apache.spark.sql.catalyst.util.FailFastMode @@ -43,6 +47,7 @@ abstract class MergeIntoSuiteBase extends QueryTest with SharedSparkSession with BeforeAndAfterEach with SQLTestUtils + with ScanReportHelper with DeltaTestUtilsForTempViews with MergeHelpers { @@ -2454,6 +2459,210 @@ abstract class MergeIntoSuiteBase ) ) + test("data skipping - target-only condition") { + withKeyValueData( + source = (1, 10) :: Nil, + target = (1, 1) :: (2, 2) :: Nil, + isKeyPartitioned = true) { case (sourceName, targetName) => + + val report = getScanReport { + executeMerge( + target = s"$targetName t", + source = s"$sourceName s", + condition = "s.key = t.key AND t.key <= 1", + update = "t.key = s.key, t.value = s.value", + insert = "(key, value) VALUES (s.key, s.value)") + }.head + + checkAnswer(sql(getDeltaFileStmt(tempPath)), + Row(1, 10) :: // Updated + Row(2, 2) :: // File should be skipped + Nil) + + assert(report.size("scanned").bytesCompressed != report.size("total").bytesCompressed) + } + } + + test("insert only merge - target data skipping") { + val tblName = "merge_target" + withTable(tblName) { + spark.range(10).withColumn("part", 'id % 5).withColumn("value", 'id + 'id) + .write.format("delta").partitionBy("part").mode("append").saveAsTable(tblName) + + val source = "source" + withTable(source) { + 
spark.range(20).withColumn("part", functions.lit(1)).withColumn("value", 'id + 'id) + .write.format("delta").saveAsTable(source) + + val scans = getScanReport { + withSQLConf(DeltaSQLConf.MERGE_INSERT_ONLY_ENABLED.key -> "true") { + executeMerge( + s"$tblName t", + s"$source s", + "s.id = t.id AND t.part = 1", + insert(condition = "s.id % 5 = s.part", values = "*")) + } + } + checkAnswer( + spark.table(tblName).where("part = 1"), + Row(1, 1, 2) :: Row(6, 1, 12) :: Row(11, 1, 22) :: Row(16, 1, 32) :: Nil + ) + + assert(scans.length === 2, "We should scan the source and target " + + "data once in an insert only optimization") + + // check if the source and target tables are scanned just once + val sourceRoot = DeltaTableUtils.findDeltaTableRoot( + spark, new Path(spark.table(source).inputFiles.head)).get.toString + val targetRoot = DeltaTableUtils.findDeltaTableRoot( + spark, new Path(spark.table(tblName).inputFiles.head)).get.toString + assert(scans.map(_.path).toSet == Set(sourceRoot, targetRoot)) + + // check scanned files + val targetScans = scans.find(_.path == targetRoot) + val deltaLog = DeltaLog.forTable(spark, targetScans.get.path) + val numTargetFiles = deltaLog.snapshot.numOfFiles + assert(targetScans.get.metrics("numFiles") < numTargetFiles) + // check scanned sizes + val scanSizes = targetScans.head.size + assert(scanSizes("total").bytesCompressed.get > scanSizes("scanned").bytesCompressed.get, + "Should have partition pruned target table") + } + } + } + + /** + * Test whether data skipping on matched predicates of a merge command is performed. + * @param name The name of the test case. + * @param source The source for merge. + * @param target The target for merge. + * @param dataSkippingOnTargetOnly Whether the conditions of the when matched clauses reference + * target fields only. If true, data skipping should be + * performed before the inner join. + * @param isMatchedOnly Whether the merge command contains only when matched clauses. + * @param mergeClauses The merge clauses to execute. + */ + protected def testMergeDataSkippingOnMatchPredicates( + name: String)( + source: Seq[(Int, Int)], + target: Seq[(Int, Int)], + dataSkippingOnTargetOnly: Boolean, + isMatchedOnly: Boolean, + mergeClauses: MergeClause*)( + result: Seq[(Int, Int)]): Unit = { + test(s"data skipping with matched predicates - $name") { + withKeyValueData(source, target) { case (sourceName, targetName) => + val stats = performMergeAndCollectStatsForDataSkippingOnMatchPredicates( + sourceName, + targetName, + result, + mergeClauses) + // Data skipping on match predicates should only be performed when it's a + // matched only merge. + if (isMatchedOnly) { + // The number of files removed/added should be 0 because of the additional predicates. + assert(stats.targetFilesRemoved == 0) + assert(stats.targetFilesAdded == 0) + // Verify that the additional data skipping predicates applied before the + // inner join filter out files when the match predicates reference only + // target columns. + if (dataSkippingOnTargetOnly) { + assert(stats.targetBeforeSkipping.files.get > stats.targetAfterSkipping.files.get) + } + } else { + assert(stats.targetFilesRemoved > 0) + // If there is no insert clause and the flag is enabled, data skipping should be + // performed on targetOnly predicates. + // However, with insert clauses, it's expected that no additional data skipping + // is performed on matched clauses.
+ assert(stats.targetBeforeSkipping.files.get == stats.targetAfterSkipping.files.get) + assert(stats.targetRowsUpdated == 0) + } + } + } + } + + protected def performMergeAndCollectStatsForDataSkippingOnMatchPredicates( + sourceName: String, + targetName: String, + result: Seq[(Int, Int)], + mergeClauses: Seq[MergeClause]): MergeStats = { + var events: Seq[UsageRecord] = Seq.empty + // Perform merge on merge condition with matched clauses. + events = Log4jUsageLogger.track { + executeMerge(s"$targetName t", s"$sourceName s", "s.key = t.key", mergeClauses: _*) + } + val deltaPath = if (targetName.startsWith("delta.`")) { + targetName.stripPrefix("delta.`").stripSuffix("`") + } else targetName + + checkAnswer( + readDeltaTable(deltaPath), + result.map { case (k, v) => Row(k, v) }) + + // Verify merge stats from usage events + val mergeStats = events.filter { e => + e.metric == MetricDefinitions.EVENT_TAHOE.name && + e.tags.get("opType").contains("delta.dml.merge.stats") + } + + assert(mergeStats.size == 1) + + JsonUtils.fromJson[MergeStats](mergeStats.head.blob) + } + + testMergeDataSkippingOnMatchPredicates("match conditions on target fields only")( + source = Seq((1, 100), (3, 300), (5, 500)), + target = Seq((1, 10), (2, 20), (3, 30)), + dataSkippingOnTargetOnly = true, + isMatchedOnly = true, + update(condition = "t.key == 10", set = "*"), + update(condition = "t.value == 100", set = "*"))( + result = Seq((1, 10), (2, 20), (3, 30)) + ) + + testMergeDataSkippingOnMatchPredicates("match conditions on source fields only")( + source = Seq((1, 100), (3, 300), (5, 500)), + target = Seq((1, 10), (2, 20), (3, 30)), + dataSkippingOnTargetOnly = false, + isMatchedOnly = true, + update(condition = "s.key == 10", set = "*"), + update(condition = "s.value == 10", set = "*"))( + result = Seq((1, 10), (2, 20), (3, 30)) + ) + + testMergeDataSkippingOnMatchPredicates("match on source and target fields")( + source = Seq((1, 100), (3, 300), (5, 500)), + target = Seq((1, 10), (2, 20), (3, 30)), + dataSkippingOnTargetOnly = false, + isMatchedOnly = true, + update(condition = "s.key == 10", set = "*"), + update(condition = "s.value == 10", set = "*"), + delete(condition = "t.key == 4"))( + result = Seq((1, 10), (2, 20), (3, 30)) + ) + + testMergeDataSkippingOnMatchPredicates("with insert clause")( + source = Seq((1, 100), (3, 300), (5, 500)), + target = Seq((1, 10), (2, 20), (3, 30)), + dataSkippingOnTargetOnly = false, + isMatchedOnly = false, + update(condition = "t.key == 10", set = "*"), + insert(condition = null, values = "(key, value) VALUES (s.key, s.value)"))( + result = Seq((1, 10), (2, 20), (3, 30), (5, 500)) + ) + + testMergeDataSkippingOnMatchPredicates("when matched and conjunction")( + source = Seq((1, 100), (3, 300), (5, 500)), + target = Seq((1, 10), (2, 20), (3, 30)), + dataSkippingOnTargetOnly = true, + isMatchedOnly = true, + update(condition = "t.key == 1 AND t.value == 5", set = "*"))( + result = Seq((1, 10), (2, 20), (3, 30))) + + /** * Parse the input JSON data into a dataframe, one row per input element. * Throws an exception on malformed inputs or records that don't comply with the provided schema. 
@@ -5202,34 +5411,16 @@ abstract class MergeIntoSuiteBase "The schema of your Delta table has changed in an incompatible way" ) - - private def testComplexTempViewOnMerge(name: String)(text: String, expectedResult: Seq[Row]) = { - testOssOnlyWithTempView(s"test merge on temp view - $name") { isSQLTempView => - withTable("tab") { - withTempView("src") { - Seq((0, 3), (1, 2)).toDF("key", "value").write.format("delta").saveAsTable("tab") - createTempViewFromSelect(text, isSQLTempView) - sql("CREATE TEMP VIEW src AS SELECT * FROM VALUES (1, 2), (3, 4) AS t(a, b)") - executeMerge( - target = "v", - source = "src", - condition = "src.a = v.key AND src.b = v.value", - update = "v.value = src.b + 1", - insert = "(v.key, v.value) VALUES (src.a, src.b)") - checkAnswer(spark.table("v"), expectedResult) - } - } - } - } - - testComplexTempViewOnMerge("nontrivial projection")( - "SELECT value as key, key as value FROM tab", - Seq(Row(3, 0), Row(3, 1), Row(4, 3)) + testInvalidTempViews("nontrivial projection")( + text = "SELECT value as key, key as value FROM tab", + expectedErrorMsgForSQLTempView = "Attribute(s) with the same name appear", + expectedErrorMsgForDataSetTempView = "Attribute(s) with the same name appear" ) - testComplexTempViewOnMerge("view with too many internal aliases")( - "SELECT * FROM (SELECT * FROM tab AS t1) AS t2", - Seq(Row(0, 3), Row(1, 3), Row(3, 4)) + testInvalidTempViews("view with too many internal aliases")( + text = "SELECT * FROM (SELECT * FROM tab AS t1) AS t2", + expectedErrorMsgForSQLTempView = "Attribute(s) with the same name appear", + expectedErrorMsgForDataSetTempView = null ) test("UDT Data Types - simple and nested") { diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/test/ScanReportHelper.scala b/spark/src/test/scala/org/apache/spark/sql/delta/test/ScanReportHelper.scala index cc8f5163c7f..0a9f607a03a 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/test/ScanReportHelper.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/test/ScanReportHelper.scala @@ -110,7 +110,8 @@ trait ScanReportHelper extends SharedSparkSession with AdaptiveSparkPlanHelper { unusedFilters = Nil, size = Map( "total" -> DataSize( - bytesCompressed = Some(deltaTable.deltaLog.unsafeVolatileSnapshot.sizeInBytes)) + bytesCompressed = Some(deltaTable.deltaLog.unsafeVolatileSnapshot.sizeInBytes)), + "scanned" -> DataSize(bytesCompressed = Some(deltaTable.sizeInBytes)) ), metrics = scanExec.metrics.mapValues(_.value).toMap, versionScanned = None,
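With the "scanned" size now reported alongside "total", suites that mix in ScanReportHelper can assert end-to-end that a merge skipped target files. A condensed sketch, assuming hypothetical source and target Delta tables already registered by the test; executeMerge and getScanReport are the existing test helpers used above.

val report = getScanReport {
  executeMerge(
    target = "target t",
    source = "source s",
    condition = "s.key = t.key AND t.key <= 1", // target-only conjunct enables file skipping
    update = "t.key = s.key, t.value = s.value",
    insert = "(key, value) VALUES (s.key, s.value)")
}.head

// The merge should have read strictly less than the full target table.
val scanned = report.size("scanned").bytesCompressed.get
val total = report.size("total").bytesCompressed.get
assert(scanned < total, "expected data skipping on the target table")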