diff --git a/core/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommand.scala b/core/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommand.scala index b5ea5337faa..3807dd6e917 100644 --- a/core/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommand.scala +++ b/core/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommand.scala @@ -23,6 +23,7 @@ import scala.collection.mutable import org.apache.spark.sql.delta._ import org.apache.spark.sql.delta.actions.{AddCDCFile, AddFile, FileAction} +import org.apache.spark.sql.delta.commands.merge.MergeIntoMaterializeSource import org.apache.spark.sql.delta.files._ import org.apache.spark.sql.delta.schema.{ImplicitMetadataOperation, SchemaUtils} import org.apache.spark.sql.delta.sources.DeltaSQLConf @@ -121,7 +122,12 @@ case class MergeStats( targetRowsCopied: Long, targetRowsUpdated: Long, targetRowsInserted: Long, - targetRowsDeleted: Long + targetRowsDeleted: Long, + + // MergeMaterializeSource stats + materializeSourceReason: Option[String] = None, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + materializeSourceAttempts: Option[Long] = None ) object MergeStats { @@ -220,7 +226,11 @@ case class MergeIntoCommand( matchedClauses: Seq[DeltaMergeIntoMatchedClause], notMatchedClauses: Seq[DeltaMergeIntoNotMatchedClause], migratedSchema: Option[StructType]) extends LeafRunnableCommand - with DeltaCommand with PredicateHelper with AnalysisHelper with ImplicitMetadataOperation { + with DeltaCommand + with PredicateHelper + with AnalysisHelper + with ImplicitMetadataOperation + with MergeIntoMaterializeSource { import MergeIntoCommand._ @@ -313,7 +323,18 @@ case class MergeIntoCommand( ) } } + val (materializeSource, _) = shouldMaterializeSource(spark, source, isSingleInsertOnly) + if (!materializeSource) { + runMerge(spark) + } else { + // If it is determined that source should be materialized, wrap the execution with retries, + // in case the data of the materialized source is lost. + runWithMaterializedSourceLostRetries( + spark, targetFileIndex.deltaLog, metrics, runMerge) + } + } + protected def runMerge(spark: SparkSession): Seq[Row] = { recordDeltaOperation(targetDeltaLog, "delta.dml.merge") { val startTime = System.nanoTime() targetDeltaLog.withNewTransaction { deltaTxn => @@ -329,6 +350,16 @@ case class MergeIntoCommand( isOverwriteMode = false, rearrangeOnly = false) } + // If materialized, prepare the DF reading the materialize source + // Otherwise, prepare a regular DF from source plan. 
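Reviewer sketch (not part of the patch, condensed from the hunk above): the materialization decision is made once in `run()`, and only the materialized path pays for the retry wrapper; the wrapper resets the SQL metrics between attempts and the final attempt number is reported through `MergeStats.materializeSourceAttempts`.

```scala
// Condensed sketch of the dispatch added above; names are taken from the patch.
def run(spark: SparkSession): Seq[Row] = {
  val (materializeSource, _) = shouldMaterializeSource(spark, source, isSingleInsertOnly)
  if (!materializeSource) {
    // The source is safe to re-read: run the merge once, as before.
    runMerge(spark)
  } else {
    // Materialized source blocks live on executor-local disks; if one is lost mid-merge,
    // the wrapper restarts runMerge from scratch, up to the configured number of attempts.
    runWithMaterializedSourceLostRetries(
      spark, targetFileIndex.deltaLog, metrics, runMerge)
  }
}
```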
+ val materializeSourceReason = prepareSourceDFAndReturnMaterializeReason( + spark, + source, + condition, + matchedClauses, + notMatchedClauses, + isSingleInsertOnly) + val deltaActions = { if (isSingleInsertOnly && spark.conf.get(DeltaSQLConf.MERGE_INSERT_ONLY_ENABLED)) { writeInsertsOnlyWhenNoMatchedClauses(spark, deltaTxn) @@ -363,9 +394,13 @@ case class MergeIntoCommand( notMatchedClauses.map(DeltaOperations.MergePredicate(_)))) // Record metrics - val stats = MergeStats.fromMergeSQLMetrics( + var stats = MergeStats.fromMergeSQLMetrics( metrics, condition, matchedClauses, notMatchedClauses, deltaTxn.metadata.partitionColumns.nonEmpty) + stats = stats.copy( + materializeSourceReason = Some(materializeSourceReason.toString), + materializeSourceAttempts = Some(attempt)) + recordDeltaEvent(targetFileIndex.deltaLog, "delta.dml.merge.stats", data = stats) } @@ -407,7 +442,7 @@ case class MergeIntoCommand( // UDF to increment metrics val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRows") - val sourceDF = Dataset.ofRows(spark, source) + val sourceDF = getSourceDF() .filter(new Column(incrSourceRowCountExpr)) // Apply inner join to between source and target using the merge condition to find matches @@ -520,7 +555,7 @@ case class MergeIntoCommand( } // source DataFrame - val sourceDF = Dataset.ofRows(spark, source) + val sourceDF = getSourceDF() .filter(new Column(incrSourceRowCountExpr)) .filter(new Column(notMatchedClauses.head.condition.getOrElse(Literal.TrueLiteral))) @@ -641,7 +676,7 @@ case class MergeIntoCommand( // We add row IDs to the targetDF if we have a delete-when-matched clause with duplicate // matches and CDC is enabled, and additionally add row IDs to the source if we also have an // insert clause. See above at isDeleteWithDuplicateMatchesAndCdc definition for more details. - var sourceDF = Dataset.ofRows(spark, source) + var sourceDF = getSourceDF() .withColumn(SOURCE_ROW_PRESENT_COL, new Column(incrSourceRowCountExpr)) var targetDF = Dataset.ofRows(spark, newTarget) .withColumn(TARGET_ROW_PRESENT_COL, lit(true)) diff --git a/core/src/main/scala/org/apache/spark/sql/delta/commands/merge/MergeIntoMaterializeSource.scala b/core/src/main/scala/org/apache/spark/sql/delta/commands/merge/MergeIntoMaterializeSource.scala new file mode 100644 index 00000000000..fafcac617f8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/sql/delta/commands/merge/MergeIntoMaterializeSource.scala @@ -0,0 +1,443 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.delta.commands.merge + +import java.util.UUID + +import scala.util.control.NonFatal + +import org.apache.spark.sql.delta.{DeltaLog, DeltaTable} +import org.apache.spark.sql.delta.files.TahoeFileIndex +import org.apache.spark.sql.delta.metering.DeltaLogging +import org.apache.spark.sql.delta.sources.DeltaSQLConf + +import org.apache.spark.SparkException +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Dataset, Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression, ExpressionSet, Literal} +import org.apache.spark.sql.catalyst.optimizer.{EliminateResolvedHint, JoinSelectionHelper} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.{LogicalRDD, SQLExecution} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.storage.StorageLevel + +/** + * Trait with logic and utilities used for materializing a snapshot of the MERGE source + * in case we can't guarantee deterministic repeated reads from it. + * + * We materialize the source if it is not safe to assume that it's deterministic + * (override with MERGE_MATERIALIZE_SOURCE). + * Otherwise, if the source changes between the phases of the MERGE, it can produce wrong results. + * We use local checkpointing for the materialization, which saves the source as a + * materialized RDD[InternalRow] on the executor local disks. + * + * 1st concern is that if an executor is lost, this data can be lost. + * When the Spark executor decommissioning API is used, it should attempt to move this + * materialized data safely out before removing the executor. + * + * 2nd concern is that if an executor is lost for another reason (e.g. spot kill), we will + * still lose that data. To mitigate that, we implement a retry loop. + * The whole Merge operation needs to be restarted from the beginning in this case. + * When we retry, we increase the replication level of the materialized data from 1 to 2. + * (override with MERGE_MATERIALIZE_SOURCE_RDD_STORAGE_LEVEL_RETRY). + * If it still fails after the maximum number of attempts (MERGE_MATERIALIZE_SOURCE_MAX_ATTEMPTS), + * we record the failure for tracking purposes. + * + * 3rd concern is that executors run out of disk space with the extra materialization. + * We record such failures for tracking purposes. + */ +trait MergeIntoMaterializeSource extends DeltaLogging { + import MergeIntoMaterializeSource._ + + /** + * Prepared DataFrame with the source data. + * If needed, it is materialized; @see prepareSourceDFAndReturnMaterializeReason + */ + private var sourceDF: Option[Dataset[Row]] = None + + /** + * If the source was materialized, reference to the checkpointed RDD. + */ + protected var materializedSourceRDD: Option[RDD[InternalRow]] = None + + /** + * Track which attempt or retry it is in runWithMaterializedSourceLostRetries. + */ + protected var attempt: Int = 0 + + /** + * Run the Merge with retries in case it detects an RDD block lost error of the + * materialized source RDD. + * It will also record out-of-disk errors, if such happen - possibly because of increased disk + * pressure from the materialized source RDD. 
+ */ + protected def runWithMaterializedSourceLostRetries( + spark: SparkSession, + deltaLog: DeltaLog, + metrics: Map[String, SQLMetric], + runMergeFunc: SparkSession => Seq[Row]): Seq[Row] = { + var doRetry = false + var runResult: Seq[Row] = null + attempt = 1 + do { + doRetry = false + metrics.values.foreach(_.reset()) + try { + runResult = runMergeFunc(spark) + } catch { + case NonFatal(ex) => + val isLastAttempt = + (attempt == spark.conf.get(DeltaSQLConf.MERGE_MATERIALIZE_SOURCE_MAX_ATTEMPTS)) + if (!handleExceptionAndReturnTrueIfMergeShouldRetry(ex, isLastAttempt, deltaLog)) { + logInfo(s"Fatal error in MERGE with materialized source in attempt $attempt.") + throw ex + } else { + logInfo(s"Retrying MERGE with materialized source. Attempt $attempt failed.") + doRetry = true + attempt += 1 + } + } finally { + // Remove source from RDD cache (noop if wasn't cached) + materializedSourceRDD.foreach { rdd => + rdd.unpersist() + } + materializedSourceRDD = None + sourceDF = null + } + } while (doRetry) + + runResult + } + + /** + * Handle exception that was thrown from runMerge(). + * Search for errors to log, or that can be handled by retry. + * It may need to descend into ex.getCause() to find the errors, as Spark may have wrapped them. + * @param isLastAttempt indicates that it's the last allowed attempt and there shall be no retry. + * @return true if the exception is handled and merge should retry + * false if the caller should rethrow the error + */ + private def handleExceptionAndReturnTrueIfMergeShouldRetry( + ex: Throwable, isLastAttempt: Boolean, deltaLog: DeltaLog): Boolean = ex match { + // If Merge failed because the materialized source lost blocks from the + // locally checkpointed RDD, we want to retry the whole operation. + // If a checkpointed RDD block is lost, it throws + // SparkCoreErrors.checkpointRDDBlockIdNotFoundError from LocalCheckpointRDD.compute. + case s: SparkException + if !materializedSourceRDD.isEmpty && + s.getMessage.matches( + mergeMaterializedSourceRddBlockLostErrorRegex(materializedSourceRDD.get.id)) => + log.warn("Materialized Merge source RDD block lost. Merge needs to be restarted. " + + s"This was attempt number $attempt.") + if (!isLastAttempt) { + true // retry + } else { + // Record situations where we lost RDD materialized source blocks, despite retries. + recordDeltaEvent( + deltaLog, + MergeIntoMaterializeSourceError.OP_TYPE, + data = MergeIntoMaterializeSourceError( + errorType = MergeIntoMaterializeSourceErrorType.RDD_BLOCK_LOST.toString, + attempt = attempt, + materializedSourceRDDStorageLevel = + materializedSourceRDD.get.getStorageLevel.toString + ) + ) + false + } + + // Record if we ran out of executor disk space. + case s: SparkException + if s.getMessage.contains("java.io.IOException: No space left on device") => + // Record situations where we ran out of disk space, possibly because of the space took + // by the materialized RDD. + recordDeltaEvent( + deltaLog, + MergeIntoMaterializeSourceError.OP_TYPE, + data = MergeIntoMaterializeSourceError( + errorType = MergeIntoMaterializeSourceErrorType.OUT_OF_DISK.toString, + attempt = attempt, + materializedSourceRDDStorageLevel = + materializedSourceRDD.get.getStorageLevel.toString + ) + ) + false + + // Descend into ex.getCause. + // The errors that we are looking for above might have been wrapped inside another exception. 
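To make the retry condition above concrete, here is a small self-contained check of the block-lost classification. The regex is the one defined later in this file (`mergeMaterializedSourceRddBlockLostErrorRegex`); the sample message only approximates what Spark raises from `LocalCheckpointRDD` when a locally checkpointed block disappears, so treat its exact wording as an assumption.

```scala
// Same pattern as mergeMaterializedSourceRddBlockLostErrorRegex(rddId) below.
def blockLostRegex(rddId: Int): String =
  s"(?s).*Checkpoint block rdd_${rddId}_[0-9]+ not found!.*"

// Hypothetical wrapped error message, modeled on checkpointRDDBlockIdNotFoundError.
val sampleMessage =
  "Job aborted due to stage failure: Checkpoint block rdd_42_3 not found! Either the " +
    "executor that originally checkpointed this partition is no longer alive, or the " +
    "original RDD is unpersisted."

assert(sampleMessage.matches(blockLostRegex(42)))  // our materialized RDD: retryable
assert(!sampleMessage.matches(blockLostRegex(43))) // a different RDD: not handled here
```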
+ case NonFatal(ex) if ex.getCause() != null => + handleExceptionAndReturnTrueIfMergeShouldRetry(ex.getCause(), isLastAttempt, deltaLog) + + // Descended to the bottom of the causes without finding a retryable error + case _ => false + } + + /** + * @return a pair of: whether the source should be materialized, + * and the source materialization reason + */ + protected def shouldMaterializeSource( + spark: SparkSession, source: LogicalPlan, isInsertOnly: Boolean + ): (Boolean, MergeIntoMaterializeSourceReason.MergeIntoMaterializeSourceReason) = { + val materializeType = spark.conf.get(DeltaSQLConf.MERGE_MATERIALIZE_SOURCE) + materializeType match { + case DeltaSQLConf.MergeMaterializeSource.ALL => + (true, MergeIntoMaterializeSourceReason.MATERIALIZE_ALL) + case DeltaSQLConf.MergeMaterializeSource.NONE => + (false, MergeIntoMaterializeSourceReason.NOT_MATERIALIZED_NONE) + case DeltaSQLConf.MergeMaterializeSource.AUTO => + if (isInsertOnly && spark.conf.get(DeltaSQLConf.MERGE_INSERT_ONLY_ENABLED)) { + (false, MergeIntoMaterializeSourceReason.NOT_MATERIALIZED_AUTO_INSERT_ONLY) + } else if (!sourceContainsOnlyDeltaScans(source)) { + (true, MergeIntoMaterializeSourceReason.NON_DETERMINISTIC_SOURCE_NON_DELTA) + } else if (!isDeterministic(source)) { + (true, MergeIntoMaterializeSourceReason.NON_DETERMINISTIC_SOURCE_OPERATORS) + } else { + (false, MergeIntoMaterializeSourceReason.NOT_MATERIALIZED_AUTO) + } + case _ => + // If the config is set to an invalid value, also materialize. + (true, MergeIntoMaterializeSourceReason.INVALID_CONFIG) + } + } + /** + * If the source needs to be materialized, prepare the materialized DataFrame in sourceDF. + * Otherwise, prepare a regular DataFrame. + * @return the source materialization reason + */ + protected def prepareSourceDFAndReturnMaterializeReason( + spark: SparkSession, + source: LogicalPlan, + condition: Expression, + matchedClauses: Seq[DeltaMergeIntoMatchedClause], + notMatchedClauses: Seq[DeltaMergeIntoNotMatchedClause], + isInsertOnly: Boolean): MergeIntoMaterializeSourceReason.MergeIntoMaterializeSourceReason = { + val (materialize, materializeReason) = + shouldMaterializeSource(spark, source, isInsertOnly) + if (!materialize) { + // Do not materialize; simply return the DataFrame built from the source plan. + sourceDF = Some(Dataset.ofRows(spark, source)) + return materializeReason + } + + val referencedSourceColumns = + getReferencedSourceColumns(source, condition, matchedClauses, notMatchedClauses) + // When we materialize the source, we want to make sure that columns get pruned before caching. + val sourceWithSelectedColumns = Project(referencedSourceColumns, source) + val baseSourcePlanDF = Dataset.ofRows(spark, sourceWithSelectedColumns) + + // Cache the source in the RDD cache using localCheckpoint, which cuts away the RDD lineage + // and thus ensures that the source cannot be recomputed and become inconsistent. + val checkpointedSourcePlanDF = baseSourcePlanDF + // eager = false means it is executed and materialized the first time it is used. + // Doing it lazily inside the query lets it interleave this work better with other work. + // On the other hand, it makes it impossible to measure the time it took in a metric. + .localCheckpoint(eager = false) + + // We have to reach through the crust and into the plan of the checkpointed DF + // to get the RDD that was actually checkpointed, to be able to unpersist it later... 
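Usage note for the `shouldMaterializeSource` decision above. The full key name is an assumption inferred from `buildConf("merge.materializeSource")` plus the usual `spark.databricks.delta.` prefix that the checkValue error strings in `DeltaSQLConf` spell out.

```scala
// Force materialization of the source for every MERGE in this session:
spark.conf.set("spark.databricks.delta.merge.materializeSource", "all")

// Never materialize (previous behaviour; repeated reads of the source must be deterministic):
spark.conf.set("spark.databricks.delta.merge.materializeSource", "none")

// Default: materialize only when the source is not provably a deterministic, Delta-scan-only
// plan and the insert-only fast path does not apply:
spark.conf.set("spark.databricks.delta.merge.materializeSource", "auto")
```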
+ var checkpointedPlan = checkpointedSourcePlanDF.queryExecution.analyzed + val rdd = checkpointedPlan.asInstanceOf[LogicalRDD].rdd + materializedSourceRDD = Some(rdd) + rdd.setName("mergeMaterializedSource") + + // We should still keep the hints from the input plan. + checkpointedPlan = addHintsToPlan(source, checkpointedPlan) + + // FIXME(SPARK-39834): Can be removed once Delta adopts Spark 3.4 and constraints are propagated + // Add filters to the logical plan so the optimizer can pick up the constraints even though + // they are lost when materializing. + checkpointedPlan = addFiltersForConstraintsToPlan( + sourceWithSelectedColumns.constraints, checkpointedPlan) + + sourceDF = Some(Dataset.ofRows(spark, checkpointedPlan)) + + // FIXME(SPARK-39834): This can be removed once Delta adopts Spark 3.4 as the statistics + // will be materialized + // and the optimal join will be picked during planning + sourceDF = Some(addBroadcastHintToDF(sourceWithSelectedColumns, sourceDF.get)) + + // Sets appropriate StorageLevel + val storageLevel = StorageLevel.fromString( + if (attempt == 1) { + spark.conf.get(DeltaSQLConf.MERGE_MATERIALIZE_SOURCE_RDD_STORAGE_LEVEL) + } else { + // If it failed the first time, potentially use a different storage level on retry. + spark.conf.get(DeltaSQLConf.MERGE_MATERIALIZE_SOURCE_RDD_STORAGE_LEVEL_RETRY) + } + ) + rdd.persist(storageLevel) + + logDebug(s"Materializing MERGE with pruned columns $referencedSourceColumns. ") + logDebug(s"Materialized MERGE source plan:\n${sourceDF.get.queryExecution}") + materializeReason + } + + protected def getSourceDF(): Dataset[Row] = { + if (sourceDF.isEmpty) { + throw new IllegalStateException( + "sourceDF was not initialized! Call prepareSourceDFAndReturnMaterializeReason before.") + } + sourceDF.get + } + + private def addHintsToPlan(sourcePlan: LogicalPlan, plan: LogicalPlan): LogicalPlan = { + val hints = EliminateResolvedHint.extractHintsFromPlan(sourcePlan)._2 + // This follows similar code in CacheManager from https://github.com/apache/spark/pull/24580 + if (hints.nonEmpty) { + // The returned hint list is in top-down order, we should create the hint nodes from + // right to left. 
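A stand-alone sketch of the materialization mechanics used above, mirroring the "missing RDD blocks error message" test added later in this PR: lazily local-checkpoint a DataFrame, pull the checkpointed RDD out of its analyzed `LogicalRDD`, then name and persist it. The local session setup is illustrative only.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.storage.StorageLevel

val spark = SparkSession.builder().master("local[2]").appName("demo").getOrCreate()

// eager = false: the checkpoint is materialized by the first action that uses it.
val checkpointed = spark.range(0, 1000).toDF("id").localCheckpoint(eager = false)

// The analyzed plan of a checkpointed Dataset is a LogicalRDD wrapping the checkpointed
// RDD[InternalRow]; keeping a handle to it allows naming, persisting and unpersisting it.
val rdd = checkpointed.queryExecution.analyzed.asInstanceOf[LogicalRDD].rdd
rdd.setName("mergeMaterializedSourceDemo")
// For a locally checkpointed RDD, persist() overrides the storage level that
// localCheckpoint() picked, adapted to always use disk.
rdd.persist(StorageLevel.DISK_ONLY)

checkpointed.count()          // first action materializes the local checkpoint
println(rdd.getStorageLevel)  // disk only, replication 1
rdd.unpersist()
spark.stop()
```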
+ val planWithHints = + hints.foldRight[LogicalPlan](plan) { case (hint, p) => + ResolvedHint(p, hint) + } + planWithHints + } else { + plan + } + } + + private def addFiltersForConstraintsToPlan( + constraints: ExpressionSet, + plan: LogicalPlan): LogicalPlan = { + if (constraints.nonEmpty) { + val planWithConstraints = + constraints.foldRight[LogicalPlan](plan) {(expr, updatedPlan) => + Filter(expr, updatedPlan) + } + planWithConstraints + } else { + plan + } + } + + private def addBroadcastHintToDF(sourcePlan: LogicalPlan, df: Dataset[Row]): Dataset[Row] = { + val joinSelectionHelper = new Object with JoinSelectionHelper + if (joinSelectionHelper.canBroadcastBySize(sourcePlan, sourcePlan.conf)) { + df.hint("broadcast") + } else { + df + } + } + + /** + * Return columns from the source plan that are used in the MERGE + */ + private def getReferencedSourceColumns( + source: LogicalPlan, + condition: Expression, + matchedClauses: Seq[DeltaMergeIntoMatchedClause], + notMatchedClauses: Seq[DeltaMergeIntoNotMatchedClause]) = { + val conditionCols = condition.references + val matchedCondCols = matchedClauses.flatMap { clause => + clause.condition.getOrElse(Literal(true)).flatMap(_.references) + } + val notMatchedCondCols = notMatchedClauses.flatMap { clause => + clause.condition.getOrElse(Literal(true)).flatMap(_.references) + } + val matchedActionsCols = matchedClauses.flatMap { clause => + clause.resolvedActions.flatMap(_.expr.references) + } + val notMatchedActionsCols = notMatchedClauses.flatMap { clause => + clause.resolvedActions.flatMap(_.expr.references) + } + val allCols = AttributeSet(conditionCols ++ matchedCondCols ++ notMatchedCondCols ++ + matchedActionsCols ++ notMatchedActionsCols) + + source.output.filter(allCols.contains(_)) + } + + private def sourceContainsOnlyDeltaScans(source: LogicalPlan): Boolean = { + !source.exists { + case l: LogicalRelation => + l match { + case DeltaTable(_) => false + case _ => true + } + case _: LeafNode => true // Any other LeafNode is a non Delta scan. + case _ => false + } + } + + /** + * `true` if `source` has a safe level of determinism. + * This is a conservative approximation of `source` being a truly deterministic query. + */ + private def isDeterministic(plan: LogicalPlan): Boolean = plan match { + // This is very restrictive, allowing only deterministic filters and projections directly + // on top of a Delta Table. + case Project(projectList, child) if projectList.forall(_.deterministic) => + isDeterministic(child) + case Filter(cond, child) if cond.deterministic => isDeterministic(child) + case Union(children, _, _) => children.forall(isDeterministic) + case SubqueryAlias(_, child) => isDeterministic(child) + case DeltaTable(_) => true + case _ => false + } +} + +object MergeIntoMaterializeSource { + // This depends on SparkCoreErrors.checkpointRDDBlockIdNotFoundError msg + def mergeMaterializedSourceRddBlockLostErrorRegex(rddId: Int): String = + s"(?s).*Checkpoint block rdd_${rddId}_[0-9]+ not found!.*" +} + +/** + * Enumeration with possible reasons that source may be materialized in a MERGE command. + */ +object MergeIntoMaterializeSourceReason extends Enumeration { + type MergeIntoMaterializeSourceReason = Value + // It was determined to not materialize on auto config. + val NOT_MATERIALIZED_AUTO = Value("notMaterializedAuto") + // Config was set to never materialize source. 
+ val NOT_MATERIALIZED_NONE = Value("notMaterializedNone") + // Insert-only merge is a single pass, no need for materialization + val NOT_MATERIALIZED_AUTO_INSERT_ONLY = Value("notMaterializedAutoInsertOnly") + // Config was set to always materialize source. + val MATERIALIZE_ALL = Value("materializeAll") + // The source query is considered non-deterministic, because it contains a non-delta scan. + val NON_DETERMINISTIC_SOURCE_NON_DELTA = Value("materializeNonDeterministicSourceNonDelta") + // The source query is considered non-deterministic, because it contains non-deterministic + // operators. + val NON_DETERMINISTIC_SOURCE_OPERATORS = Value("materializeNonDeterministicSourceOperators") + // Materialize when the configuration is invalid + val INVALID_CONFIG = Value("invalidConfigurationFailsafe") + // Catch-all case. + val UNKNOWN = Value("unknown") +} + +/** + * Structure with data for "delta.dml.merge.materializeSourceError" event. + * Note: We log only errors that we want to track (out of disk or lost RDD blocks). + */ +case class MergeIntoMaterializeSourceError( + errorType: String, + attempt: Int, + materializedSourceRDDStorageLevel: String +) + +object MergeIntoMaterializeSourceError { + val OP_TYPE = "delta.dml.merge.materializeSourceError" +} + +object MergeIntoMaterializeSourceErrorType extends Enumeration { + type MergeIntoMaterializeSourceError = Value + val RDD_BLOCK_LOST = Value("materializeSourceRDDBlockLostRetriesFailure") + val OUT_OF_DISK = Value("materializeSourceOutOfDiskFailure") +} diff --git a/core/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala b/core/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala index 76001a277b2..4b15c011165 100644 --- a/core/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala +++ b/core/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.delta.sources // scalastyle:off import.ordering.noEmptyLine +import java.util.Locale import java.util.concurrent.TimeUnit import org.apache.spark.internal.config.ConfigBuilder import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.storage.StorageLevel /** * [[SQLConf]] entries for Delta features. @@ -431,6 +433,68 @@ trait DeltaSQLConfBase { .booleanConf .createWithDefault(false) + final object MergeMaterializeSource { + // See value explanations in the doc below. + final val NONE = "none" + final val ALL = "all" + final val AUTO = "auto" + + final val list = Set(NONE, ALL, AUTO) + } + + val MERGE_MATERIALIZE_SOURCE = + buildConf("merge.materializeSource") + .internal() + .doc("When to materialize the source plan during MERGE execution. " + + "The value 'none' means the source will never be materialized. " + + "The value 'all' means the source will always be materialized. " + + "The value 'auto' means the source will not be materialized when it is certain to be " + + "deterministic." + ) + .stringConf + .transform(_.toLowerCase(Locale.ROOT)) + .checkValues(MergeMaterializeSource.list) + .createWithDefault(MergeMaterializeSource.AUTO) + + val MERGE_MATERIALIZE_SOURCE_RDD_STORAGE_LEVEL = + buildConf("merge.materializeSource.rddStorageLevel") + .internal() + .doc("What StorageLevel to use to persist the source RDD. 
Note: will always use disk.") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValue( v => + try { + StorageLevel.fromString(v).isInstanceOf[StorageLevel] + } catch { + case _: IllegalArgumentException => false + }, + """"spark.databricks.delta.merge.materializeSource.rddStorageLevel" """ + + "must be a valid StorageLevel") + .createWithDefault("DISK_ONLY") + + val MERGE_MATERIALIZE_SOURCE_RDD_STORAGE_LEVEL_RETRY = + buildConf("merge.materializeSource.rddStorageLevelRetry") + .internal() + .doc("What StorageLevel to use to persist the source RDD when MERGE is retried. " + + "Note: will always use disk.") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValue( v => + try { + StorageLevel.fromString(v).isInstanceOf[StorageLevel] + } catch { + case _: IllegalArgumentException => false + }, + """"spark.databricks.delta.merge.materializeSource.rddStorageLevelRetry" """ + + "must be a valid StorageLevel") + .createWithDefault("DISK_ONLY_2") + + val MERGE_MATERIALIZE_SOURCE_MAX_ATTEMPTS = + buildStaticConf("merge.materializeSource.maxAttempts") + .doc("How many times to try MERGE in case of lost materialized source RDD data") + .intConf + .createWithDefault(4) + val DELTA_LAST_COMMIT_VERSION_IN_SESSION = buildConf("lastCommitVersionInSession") .doc("The version of the last commit made in the SparkSession for any table.") diff --git a/core/src/test/scala/org/apache/spark/sql/delta/DeltaTestUtils.scala b/core/src/test/scala/org/apache/spark/sql/delta/DeltaTestUtils.scala index 2443bfc5685..8b85e1f2117 100644 --- a/core/src/test/scala/org/apache/spark/sql/delta/DeltaTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/sql/delta/DeltaTestUtils.scala @@ -18,15 +18,15 @@ package org.apache.spark.sql.delta import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable.ArrayBuffer - +import org.apache.spark.sql.delta.DeltaTestUtils.Plans import org.apache.spark.sql.delta.test.DeltaSQLCommandTest import org.apache.spark.SparkContext import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} +import org.apache.spark.sql.execution.{FileSourceScanExec, QueryExecution, RDDScanExec, SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.util.QueryExecutionListener @@ -34,20 +34,18 @@ trait DeltaTestUtilsBase { final val BOOLEAN_DOMAIN: Seq[Boolean] = Seq(true, false) - class LogicalPlanCapturingListener(optimized: Boolean) extends QueryExecutionListener { - val plans = new ArrayBuffer[LogicalPlan] - override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { - if (optimized) plans.append(qe.optimizedPlan) else plans.append(qe.analyzed) - } + class PlanCapturingListener() extends QueryExecutionListener { - override def onFailure( - funcName: String, qe: QueryExecution, error: Exception): Unit = {} - } + private[this] var capturedPlans = List.empty[Plans] + + def plans: Seq[Plans] = capturedPlans.reverse - class PhysicalPlanCapturingListener() extends QueryExecutionListener { - val plans = new ArrayBuffer[SparkPlan] override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { - plans.append(qe.sparkPlan) + capturedPlans ::= Plans( + qe.analyzed, + qe.optimizedPlan, + 
qe.sparkPlan, + qe.executedPlan) } override def onFailure( @@ -60,15 +58,17 @@ trait DeltaTestUtilsBase { def withLogicalPlansCaptured[T]( spark: SparkSession, optimizedPlan: Boolean)( - thunk: => Unit): ArrayBuffer[LogicalPlan] = { - val planCapturingListener = new LogicalPlanCapturingListener(optimizedPlan) + thunk: => Unit): Seq[LogicalPlan] = { + val planCapturingListener = new PlanCapturingListener spark.sparkContext.listenerBus.waitUntilEmpty(15000) spark.listenerManager.register(planCapturingListener) try { thunk spark.sparkContext.listenerBus.waitUntilEmpty(15000) - planCapturingListener.plans + planCapturingListener.plans.map { plans => + if (optimizedPlan) plans.optimized else plans.analyzed + } } finally { spark.listenerManager.unregister(planCapturingListener) } @@ -79,8 +79,28 @@ trait DeltaTestUtilsBase { */ def withPhysicalPlansCaptured[T]( spark: SparkSession)( - thunk: => Unit): ArrayBuffer[SparkPlan] = { - val planCapturingListener = new PhysicalPlanCapturingListener() + thunk: => Unit): Seq[SparkPlan] = { + val planCapturingListener = new PlanCapturingListener + + spark.sparkContext.listenerBus.waitUntilEmpty(15000) + spark.listenerManager.register(planCapturingListener) + try { + thunk + spark.sparkContext.listenerBus.waitUntilEmpty(15000) + planCapturingListener.plans.map(_.sparkPlan) + } finally { + spark.listenerManager.unregister(planCapturingListener) + } + } + + /** + * Run a thunk with logical and physical plans for all queries captured and passed + * into a provided buffer. + */ + def withAllPlansCaptured[T]( + spark: SparkSession)( + thunk: => Unit): Seq[Plans] = { + val planCapturingListener = new PlanCapturingListener spark.sparkContext.listenerBus.waitUntilEmpty(15000) spark.listenerManager.register(planCapturingListener) @@ -114,9 +134,50 @@ trait DeltaTestUtilsBase { } jobCount.get() } + + protected def getfindTouchedFilesJobPlans(plans: Seq[Plans]): SparkPlan = { + // The expected plan for touched file computation is of the format below. + // The data column should be pruned from both leaves. + // HashAggregate(output=[count#3463L]) + // +- HashAggregate(output=[count#3466L]) + // +- Project + // +- Filter (isnotnull(count#3454L) AND (count#3454L > 1)) + // +- HashAggregate(output=[count#3454L]) + // +- HashAggregate(output=[_row_id_#3418L, sum#3468L]) + // +- Project [_row_id_#3418L, UDF(_file_name_#3422) AS one#3448] + // +- BroadcastHashJoin [id#3342L], [id#3412L], Inner, BuildLeft + // :- Project [id#3342L] + // : +- Filter isnotnull(id#3342L) + // : +- FileScan parquet [id#3342L,part#3343L] + // +- Filter isnotnull(id#3412L) + // +- Project [...] + // +- Project [...] + // +- FileScan parquet [id#3412L,part#3413L] + // Note: It can be RDDScanExec instead of FileScan if the source was materialized. + // We pick the first plan starting from FileScan and ending in HashAggregate as a + // stable heuristic for the one we want. 
+ plans.map(_.executedPlan) + .filter { + case WholeStageCodegenExec(hash: HashAggregateExec) => + hash.collectLeaves().size == 2 && + hash.collectLeaves() + .forall { s => + s.isInstanceOf[FileSourceScanExec] || + s.isInstanceOf[RDDScanExec] + } + case _ => false + }.head + } +} + +object DeltaTestUtils extends DeltaTestUtilsBase { + case class Plans( + analyzed: LogicalPlan, + optimized: LogicalPlan, + sparkPlan: SparkPlan, + executedPlan: SparkPlan) } -object DeltaTestUtils extends DeltaTestUtilsBase trait DeltaTestUtilsForTempViews extends SharedSparkSession { diff --git a/core/src/test/scala/org/apache/spark/sql/delta/MergeIntoMaterializeSourceSuite.scala b/core/src/test/scala/org/apache/spark/sql/delta/MergeIntoMaterializeSourceSuite.scala new file mode 100644 index 00000000000..e1b6d69b5d4 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/sql/delta/MergeIntoMaterializeSourceSuite.scala @@ -0,0 +1,303 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta + +import scala.collection.mutable +import scala.util.control.NonFatal + +import org.apache.spark.sql.delta.DeltaTestUtils._ +import org.apache.spark.sql.delta.commands.MergeStats +import org.apache.spark.sql.delta.commands.merge.{MergeIntoMaterializeSourceError, MergeIntoMaterializeSourceErrorType, MergeIntoMaterializeSourceReason} +import org.apache.spark.sql.delta.commands.merge.MergeIntoMaterializeSource.mergeMaterializedSourceRddBlockLostErrorRegex +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.delta.test.DeltaSQLCommandTest +import org.apache.spark.sql.delta.util.JsonUtils +import org.scalactic.source.Position +import org.scalatest.Tag + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.{FilterExec, LogicalRDD, RDDScanExec, SQLExecution} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} +import org.apache.spark.sql.types._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + +trait MergeIntoMaterializeSourceTests + extends QueryTest + with SharedSparkSession + with DeltaSQLCommandTest + with SQLTestUtils + with DeltaTestUtilsBase + { + + import testImplicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + // trigger source materialization in all tests + spark.conf.set(DeltaSQLConf.MERGE_MATERIALIZE_SOURCE.key, "all") + } + + + // Test error message that we check if blocks of materialized source RDD were evicted. 
+ test("missing RDD blocks error message") { + val checkpointedDf = sql("select * from range(10)") + .localCheckpoint(eager = false) + val rdd = checkpointedDf.queryExecution.analyzed.asInstanceOf[LogicalRDD].rdd + checkpointedDf.collect() // trigger lazy materialization + rdd.unpersist() + val ex = intercept[Exception] { + checkpointedDf.collect() + } + assert(ex.isInstanceOf[SparkException], ex) + assert( + ex.getMessage().matches(mergeMaterializedSourceRddBlockLostErrorRegex(rdd.id)), + s"RDD id ${rdd.id}: Message: ${ex.getMessage}") + } + + + def getHints(df: => DataFrame): Seq[(Seq[ResolvedHint], JoinHint)] = { + val plans = withAllPlansCaptured(spark) { + df + } + var plansWithMaterializedSource = 0 + val hints = plans.flatMap { p => + val materializedSourceExists = p.analyzed.exists { + case l: LogicalRDD if l.rdd.name == "mergeMaterializedSource" => true + case _ => false + } + if (materializedSourceExists) { + // If it is a plan with materialized source, there should be exactly one join + // of target and source. We collect resolved hints from analyzed plans, and the hint + // applied to the join from optimized plan. + plansWithMaterializedSource += 1 + val hints = p.analyzed.collect { + case h: ResolvedHint => h + } + val joinHints = p.optimized.collect { + case j: Join => j.hint + } + assert(joinHints.length == 1, s"Got $joinHints") + val joinHint = joinHints.head + + // Only preserve join strategy hints, because we are testing with these. + // Other hints may be added by MERGE internally, e.g. hints to force DFP/DPP, that + // we don't want to be considering here. + val retHints = hints + .filter(_.hints.strategy.nonEmpty) + def retJoinHintInfo(hintInfo: Option[HintInfo]): Option[HintInfo] = hintInfo match { + case Some(h) if h.strategy.nonEmpty => Some(HintInfo(strategy = h.strategy)) + case _ => None + } + val retJoinHint = joinHint.copy( + leftHint = retJoinHintInfo(joinHint.leftHint), + rightHint = retJoinHintInfo(joinHint.rightHint) + ) + + Some((retHints, retJoinHint)) + } else { + None + } + } + assert(plansWithMaterializedSource == 2, + s"2 plans should have materialized source, but got: $plans") + hints + } + + test("materialize source preserves dataframe hints") { + withTable("A", "B", "T") { + sql("select id, id as v from range(50000)").write.format("delta").saveAsTable("T") + sql("select id, id+2 as v from range(10000)").write.format("csv").saveAsTable("A") + sql("select id, id*2 as v from range(1000)").write.format("csv").saveAsTable("B") + + // Manually added broadcast hint will mess up the expected hints hence disable it + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + // Simple BROADCAST hint + val hSimple = getHints( + sql("MERGE INTO T USING (SELECT /*+ BROADCAST */ * FROM A) s ON T.id = s.id" + + " WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *") + ) + hSimple.foreach { case (hints, joinHint) => + assert(hints.length == 1) + assert(hints.head.hints == HintInfo(strategy = Some(BROADCAST))) + assert(joinHint == JoinHint(Some(HintInfo(strategy = Some(BROADCAST))), None)) + } + + // Simple MERGE hint + val hSimpleMerge = getHints( + sql("MERGE INTO T USING (SELECT /*+ MERGE */ * FROM A) s ON T.id = s.id" + + " WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *") + ) + hSimpleMerge.foreach { case (hints, joinHint) => + assert(hints.length == 1) + assert(hints.head.hints == HintInfo(strategy = Some(SHUFFLE_MERGE))) + assert(joinHint == JoinHint(Some(HintInfo(strategy = Some(SHUFFLE_MERGE))), None)) + } + + // Aliased 
hint + val hAliased = getHints( + sql("MERGE INTO T USING " + + "(SELECT /*+ BROADCAST(FOO) */ * FROM (SELECT * FROM A) FOO) s ON T.id = s.id" + + " WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *") + ) + hAliased.foreach { case (hints, joinHint) => + assert(hints.length == 1) + assert(hints.head.hints == HintInfo(strategy = Some(BROADCAST))) + assert(joinHint == JoinHint(Some(HintInfo(strategy = Some(BROADCAST))), None)) + } + + // Aliased hint - hint propagation does not work from under an alias + // (remove if this ever gets implemented in the hint framework) + val hAliasedInner = getHints( + sql("MERGE INTO T USING " + + "(SELECT /*+ BROADCAST(A) */ * FROM (SELECT * FROM A) FOO) s ON T.id = s.id" + + " WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *") + ) + hAliasedInner.foreach { case (hints, joinHint) => + assert(hints.length == 0) + assert(joinHint == JoinHint(None, None)) + } + + // This hint applies to the join inside the source, not to the source as a whole + val hJoinInner = getHints( + sql("MERGE INTO T USING " + + "(SELECT /*+ BROADCAST(A) */ A.* FROM A JOIN B WHERE A.id = B.id) s ON T.id = s.id" + + " WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *") + ) + hJoinInner.foreach { case (hints, joinHint) => + assert(hints.length == 0) + assert(joinHint == JoinHint(None, None)) + } + + // Two hints - top one takes effect + val hTwo = getHints( + sql("MERGE INTO T USING (SELECT /*+ BROADCAST, MERGE */ * FROM A) s ON T.id = s.id" + + " WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *") + ) + hTwo.foreach { case (hints, joinHint) => + assert(hints.length == 2) + assert(hints(0).hints == HintInfo(strategy = Some(BROADCAST))) + assert(hints(1).hints == HintInfo(strategy = Some(SHUFFLE_MERGE))) + // top one takes effect + assert(joinHint == JoinHint(Some(HintInfo(strategy = Some(BROADCAST))), None)) + } + } + } + } + + // FIXME: Tests can be removed once Delta adopts Spark 3.4 as constraints and statistics are + // automatically propagated when materializing + // The following test should fail as soon as statistics are correctly propagated, and acts as a + // reminder to remove the manually added filter and broadcast hint once Spark 3.4 is adopted + test("Source in materialized merge has missing stats") { + // AQE has to be disabled as we might not find the Join in the adaptive plan + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + withTable("A", "T") { + sql("select id, id as v from range(50)").write.format("delta").saveAsTable("T") + sql("select id, id+2 as v from range(10)").write.format("csv").saveAsTable("A") + val plans = DeltaTestUtils.withAllPlansCaptured(spark) { + sql("MERGE INTO T USING A as s ON T.id = s.id" + + " WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *") + } + plans.map(_.optimized).foreach { p => + p.foreach { + case j: Join => + // The source is very small, the only way we'd be above the broadcast join threshold + // is if we lost statistics on the size of the source. 
+ val sourceStats = j.left.stats.sizeInBytes + val broadcastJoinThreshold = spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD) + assert(sourceStats >= broadcastJoinThreshold) + case _ => + } + } + } + } + } + + test("Filter gets added if there is a constraint") { + // AQE has to be disabled as we might not find the filter in the adaptive plan + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + withTable("A", "T") { + spark.range(50).toDF("tgtid").write.format("delta").saveAsTable("T") + spark.range(50).toDF("srcid").write.format("delta").saveAsTable("A") + + val plans = DeltaTestUtils.withAllPlansCaptured(spark) { + sql("MERGE INTO T USING (SELECT * FROM A WHERE srcid = 10) as s ON T.tgtid = s.srcid" + + " WHEN MATCHED THEN UPDATE SET tgtid = s.srcid" + + " WHEN NOT MATCHED THEN INSERT (tgtid) values (s.srcid)") + } + // Check whether the executed plan contains a filter that filters by tgtId that could be + // used to infer constraints lost during materialization + val hastgtIdCondition = (condition: Expression) => { + condition.find { + case EqualTo(AttributeReference("tgtid", _, _, _), Literal(10, _)) => true + case _ => false + }.isDefined + } + val touchedFilesPlan = getfindTouchedFilesJobPlans(plans) + val filter = touchedFilesPlan.find { + case f: FilterExec => hastgtIdCondition(f.condition) + case _ => false + } + assert(filter.isDefined, + s"Didn't find Filter on tgtid=10 in touched files plan:\n$touchedFilesPlan") + } + } + } + + test("Broadcast hint gets added when there is a small source table") { + withTable("A", "T") { + sql("select id, id as v from range(50000)").write.format("delta").saveAsTable("T") + sql("select id, id+2 as v from range(10000)").write.format("csv").saveAsTable("A") + val hints = getHints( + sql("MERGE INTO T USING A as s ON T.id = s.id" + + " WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *") + ) + hints.foreach { case (hints, joinHint) => + assert(hints.length == 1) + assert(hints.head.hints == HintInfo(strategy = Some(BROADCAST))) + assert(joinHint == JoinHint(Some(HintInfo(strategy = Some(BROADCAST))), None)) + } + } + } + + test("Broadcast hint does not get added when there is a large table") { + withTable("A", "T") { + sql("select id, id as v from range(50000)").write.format("delta").saveAsTable("T") + sql("select id, id+2 as v from range(10000)").write.format("csv").saveAsTable("A") + withSQLConf((SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "1KB")) { + val hints = getHints( + sql("MERGE INTO T USING A as s ON T.id = s.id" + + " WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *") + ) + hints.foreach { case (hints, joinHint) => + assert(hints.length == 0) + assert(joinHint == JoinHint(None, None)) + } + } + } + } +} + +// MERGE + materialize +class MergeIntoMaterializeSourceSuite extends MergeIntoMaterializeSourceTests + diff --git a/core/src/test/scala/org/apache/spark/sql/delta/MergeIntoSQLSuite.scala b/core/src/test/scala/org/apache/spark/sql/delta/MergeIntoSQLSuite.scala index 1d27e3ebd3e..4ee4caca3d1 100644 --- a/core/src/test/scala/org/apache/spark/sql/delta/MergeIntoSQLSuite.scala +++ b/core/src/test/scala/org/apache/spark/sql/delta/MergeIntoSQLSuite.scala @@ -241,6 +241,8 @@ class MergeIntoSQLSuite extends MergeIntoSuiteBase with DeltaSQLCommandTest test("detect nondeterministic source - flag on") { withSQLConf( + // materializing source would fix determinism + DeltaSQLConf.MERGE_MATERIALIZE_SOURCE.key -> DeltaSQLConf.MergeMaterializeSource.NONE, DeltaSQLConf.MERGE_FAIL_IF_SOURCE_CHANGED.key -> "true" ) 
{ val e = intercept[UnsupportedOperationException]( @@ -252,12 +254,23 @@ class MergeIntoSQLSuite extends MergeIntoSuiteBase with DeltaSQLCommandTest test("detect nondeterministic source - flag off") { withSQLConf( + // materializing source would fix determinism + DeltaSQLConf.MERGE_MATERIALIZE_SOURCE.key -> DeltaSQLConf.MergeMaterializeSource.NONE, DeltaSQLConf.MERGE_FAIL_IF_SOURCE_CHANGED.key -> "false" ) { testNondeterministicOrder } } + test("detect nondeterministic source - flag on, materialized") { + withSQLConf( + // materializing source fixes determinism, so the source is no longer nondeterministic + DeltaSQLConf.MERGE_MATERIALIZE_SOURCE.key -> DeltaSQLConf.MergeMaterializeSource.ALL, + DeltaSQLConf.MERGE_FAIL_IF_SOURCE_CHANGED.key -> "true" + ) { + testNondeterministicOrder + } + } test("merge into a dataset temp views with star") { withTempView("v") {
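Closing note on the new knobs, as a hedged sketch: key names are inferred from `DeltaSQLConf` with the standard `spark.databricks.delta.` prefix, and `maxAttempts` is declared with `buildStaticConf`, so it is assumed to be settable only at session startup.

```scala
import org.apache.spark.storage.StorageLevel

// Storage level of the materialized source on the first attempt (default DISK_ONLY) ...
spark.conf.set("spark.databricks.delta.merge.materializeSource.rddStorageLevel", "DISK_ONLY")
// ... and on retries after a lost block (default DISK_ONLY_2, i.e. a second replica):
spark.conf.set(
  "spark.databricks.delta.merge.materializeSource.rddStorageLevelRetry", "DISK_ONLY_2")

// Values must parse via StorageLevel.fromString, which is what the conf validation relies on:
assert(StorageLevel.fromString("DISK_ONLY_2").replication == 2)

// merge.materializeSource.maxAttempts (default 4) bounds the retry loop; as a static conf it
// would be passed at startup, e.g.:
//   --conf spark.databricks.delta.merge.materializeSource.maxAttempts=4
```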