diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index fe5b016d75c37..e6c8be1397e1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -107,6 +107,7 @@ case class AdaptiveSparkPlanExec(
     // `EnsureRequirements` to not optimize out the user-specified repartition-by-col to work
     // around this case.
     EnsureRequirements(optimizeOutRepartition = requiredDistribution.isDefined),
+    ValidateSparkPlan,
     RemoveRedundantSorts,
     DisableUnnecessaryBucketedScan
   ) ++ context.session.sessionState.queryStagePrepRules
@@ -295,16 +296,19 @@ case class AdaptiveSparkPlanExec(
         // plans are updated, we can clear the query stage list because at this point the two plans
         // are semantically and physically in sync again.
         val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan, stagesToReplace)
-        val (newPhysicalPlan, newLogicalPlan) = reOptimize(logicalPlan)
-        val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
-        val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
-        if (newCost < origCost ||
+        val afterReOptimize = reOptimize(logicalPlan)
+        if (afterReOptimize.isDefined) {
+          val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
+          val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
+          val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
+          if (newCost < origCost ||
             (newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
-          logOnLevel(s"Plan changed from $currentPhysicalPlan to $newPhysicalPlan")
-          cleanUpTempTags(newPhysicalPlan)
-          currentPhysicalPlan = newPhysicalPlan
-          currentLogicalPlan = newLogicalPlan
-          stagesToReplace = Seq.empty[QueryStageExec]
+            logOnLevel(s"Plan changed from $currentPhysicalPlan to $newPhysicalPlan")
+            cleanUpTempTags(newPhysicalPlan)
+            currentPhysicalPlan = newPhysicalPlan
+            currentLogicalPlan = newLogicalPlan
+            stagesToReplace = Seq.empty[QueryStageExec]
+          }
         }
         // Now that some stages have finished, we can try creating new stages.
         result = createQueryStages(currentPhysicalPlan)
@@ -637,29 +641,35 @@ case class AdaptiveSparkPlanExec(
   /**
    * Re-optimize and run physical planning on the current logical plan based on the latest stats.
    */
-  private def reOptimize(logicalPlan: LogicalPlan): (SparkPlan, LogicalPlan) = {
-    logicalPlan.invalidateStatsCache()
-    val optimized = optimizer.execute(logicalPlan)
-    val sparkPlan = context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
-    val newPlan = applyPhysicalRules(
-      sparkPlan,
-      preprocessingRules ++ queryStagePreparationRules,
-      Some((planChangeLogger, "AQE Replanning")))
-
-    // When both enabling AQE and DPP, `PlanAdaptiveDynamicPruningFilters` rule will
-    // add the `BroadcastExchangeExec` node manually in the DPP subquery,
-    // not through `EnsureRequirements` rule. Therefore, when the DPP subquery is complicated
-    // and need to be re-optimized, AQE also need to manually insert the `BroadcastExchangeExec`
-    // node to prevent the loss of the `BroadcastExchangeExec` node in DPP subquery.
-    // Here, we also need to avoid to insert the `BroadcastExchangeExec` node when the newPlan
-    // is already the `BroadcastExchangeExec` plan after apply the `LogicalQueryStageStrategy` rule.
-    val finalPlan = currentPhysicalPlan match {
-      case b: BroadcastExchangeLike
-        if (!newPlan.isInstanceOf[BroadcastExchangeLike]) => b.withNewChildren(Seq(newPlan))
-      case _ => newPlan
-    }
+  private def reOptimize(logicalPlan: LogicalPlan): Option[(SparkPlan, LogicalPlan)] = {
+    try {
+      logicalPlan.invalidateStatsCache()
+      val optimized = optimizer.execute(logicalPlan)
+      val sparkPlan = context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
+      val newPlan = applyPhysicalRules(
+        sparkPlan,
+        preprocessingRules ++ queryStagePreparationRules,
+        Some((planChangeLogger, "AQE Replanning")))
+
+      // When both AQE and DPP are enabled, the `PlanAdaptiveDynamicPruningFilters` rule adds
+      // the `BroadcastExchangeExec` node to the DPP subquery manually, rather than through the
+      // `EnsureRequirements` rule. Therefore, when the DPP subquery is complicated and needs to
+      // be re-optimized, AQE also needs to insert the `BroadcastExchangeExec` node manually, to
+      // prevent the loss of the `BroadcastExchangeExec` node in the DPP subquery.
+      // We also need to avoid inserting the `BroadcastExchangeExec` node when newPlan is already
+      // a `BroadcastExchangeExec` plan after applying the `LogicalQueryStageStrategy` rule.
+      val finalPlan = currentPhysicalPlan match {
+        case b: BroadcastExchangeLike
+          if (!newPlan.isInstanceOf[BroadcastExchangeLike]) => b.withNewChildren(Seq(newPlan))
+        case _ => newPlan
+      }

-    (finalPlan, optimized)
+      Some((finalPlan, optimized))
+    } catch {
+      case e: InvalidAQEPlanException[_] =>
+        logOnLevel(s"Re-optimize - ${e.getMessage()}:\n${e.plan}")
+        None
+    }
   }

   /**
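
The AdaptiveSparkPlanExec change above makes replanning fallible by design: `reOptimize` now returns an `Option`, and when validation rejects the candidate plan, AQE keeps the latest valid plan instead of failing the query. Below is a minimal, self-contained sketch of this validate-or-fall-back pattern; all names in it (`Plan`, `replan`, `validate`, `InvalidPlanException`) are illustrative stand-ins, not Spark APIs.

    // Illustrative stand-ins for SparkPlan, reOptimize and ValidateSparkPlan.
    final case class Plan(description: String)

    final case class InvalidPlanException(message: String, plan: Plan)
      extends Exception(message)

    object ReplanWithFallback {
      // Rejects bad candidates, mirroring what ValidateSparkPlan does for real plans.
      private def validate(plan: Plan): Unit = {
        if (plan.description.contains("invalid")) {
          throw InvalidPlanException("invalid broadcast query stage", plan)
        }
      }

      // Returns Some(plan) only when replanning produced a valid plan.
      def replan(candidate: Plan): Option[Plan] = {
        try {
          validate(candidate)
          Some(candidate)
        } catch {
          case e: InvalidPlanException =>
            println(s"Re-optimize - ${e.getMessage}:\n${e.plan}")
            None
        }
      }

      def main(args: Array[String]): Unit = {
        var current = Plan("valid plan v1")
        // A rejected replan leaves the current plan untouched; a valid one replaces it.
        current = replan(Plan("invalid plan v2")).getOrElse(current)
        println(current) // Plan(valid plan v1)
        current = replan(Plan("valid plan v3")).getOrElse(current)
        println(current) // Plan(valid plan v3)
      }
    }

Note that the cost comparison still runs only on valid candidates, so a rejected replan is indistinguishable from a replan that lost the cost comparison: in both cases `currentPhysicalPlan` stays in place.
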
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InvalidAQEPlanException.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InvalidAQEPlanException.scala
new file mode 100644
index 0000000000000..71f6db8b2b9cb
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InvalidAQEPlanException.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+
+/**
+ * Exception thrown when an invalid query plan is detected in AQE replanning, in which case AQE
+ * will stop the current replanning process and keep using the latest valid plan.
+ *
+ * @param message The reason why the plan is considered invalid.
+ * @param plan The invalid plan/sub-plan.
+ */
+case class InvalidAQEPlanException[QueryType <: QueryPlan[_]](message: String, plan: QueryType)
+  extends Exception(message)
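
`InvalidAQEPlanException` is generic in the plan type, so one class can carry either a logical or a physical plan for logging, and because the JVM erases the type argument, the catch clause in `reOptimize` matches `InvalidAQEPlanException[_]`. A small sketch of that pattern, with a hypothetical `Node` hierarchy standing in for `QueryPlan`:

    // Stand-in plan hierarchy; Spark's QueryPlan plays this role in the patch.
    sealed trait Node
    final case class Logical(name: String) extends Node
    final case class Physical(name: String) extends Node

    final case class InvalidPlanException[T <: Node](message: String, plan: T)
      extends Exception(message)

    object ErasureDemo {
      def main(args: Array[String]): Unit = {
        try {
          throw InvalidPlanException("bad broadcast stage", Physical("BroadcastStage"))
        } catch {
          // The type parameter is erased at runtime, so [_] matches any plan type.
          case e: InvalidPlanException[_] => println(s"${e.getMessage}: ${e.plan}")
        }
      }
    }
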
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ValidateSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ValidateSparkPlan.scala
new file mode 100644
index 0000000000000..0fdc50e2acc8d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ValidateSparkPlan.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}
+
+/**
+ * Detects invalid physical plans generated by AQE replanning and throws `InvalidAQEPlanException`
+ * if such plans are detected. This rule should be called after EnsureRequirements where all
+ * necessary Exchange nodes are added.
+ */
+object ValidateSparkPlan extends Rule[SparkPlan] {
+
+  def apply(plan: SparkPlan): SparkPlan = {
+    validate(plan)
+    plan
+  }
+
+  /**
+   * Validate that the plan satisfies the following condition:
+   * - BroadcastQueryStage only appears as the immediate child and the build side of a broadcast
+   *   hash join or broadcast nested loop join.
+   */
+  private def validate(plan: SparkPlan): Unit = plan match {
+    case b: BroadcastHashJoinExec =>
+      val (buildPlan, probePlan) = b.buildSide match {
+        case BuildLeft => (b.left, b.right)
+        case BuildRight => (b.right, b.left)
+      }
+      if (!buildPlan.isInstanceOf[BroadcastQueryStageExec]) {
+        validate(buildPlan)
+      }
+      validate(probePlan)
+    case b: BroadcastNestedLoopJoinExec =>
+      val (buildPlan, probePlan) = b.buildSide match {
+        case BuildLeft => (b.left, b.right)
+        case BuildRight => (b.right, b.left)
+      }
+      if (!buildPlan.isInstanceOf[BroadcastQueryStageExec]) {
+        validate(buildPlan)
+      }
+      validate(probePlan)
+    case q: BroadcastQueryStageExec => errorOnInvalidBroadcastQueryStage(q)
+    case _ => plan.children.foreach(validate)
+  }
+
+  private def errorOnInvalidBroadcastQueryStage(plan: SparkPlan): Unit = {
+    throw InvalidAQEPlanException("Invalid broadcast query stage", plan)
+  }
+}
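
The traversal in `validate` encodes the invariant directly: at a broadcast hash join or broadcast nested loop join, a `BroadcastQueryStageExec` build side is legal and recursion skips it, while a broadcast query stage reached by any other path triggers the exception. The same recursion over a toy plan tree, as a sketch; `BroadcastStage`, `BroadcastJoin` and `Other` are hypothetical stand-ins for the Spark operators:

    // Toy plan tree mirroring the shape that ValidateSparkPlan inspects.
    sealed trait Plan { def children: Seq[Plan] }
    case object BroadcastStage extends Plan { val children: Seq[Plan] = Nil }
    final case class BroadcastJoin(build: Plan, probe: Plan) extends Plan {
      val children: Seq[Plan] = Seq(build, probe)
    }
    final case class Other(children: Plan*) extends Plan

    object ValidateDemo {
      // A broadcast stage is legal only as the build side of a broadcast join.
      def validate(plan: Plan): Unit = plan match {
        case j: BroadcastJoin =>
          if (j.build != BroadcastStage) validate(j.build)
          validate(j.probe)
        case BroadcastStage =>
          throw new IllegalStateException("Invalid broadcast query stage")
        case p => p.children.foreach(validate)
      }

      def main(args: Array[String]): Unit = {
        validate(BroadcastJoin(BroadcastStage, Other())) // passes
        validate(Other(BroadcastStage))                  // throws IllegalStateException
      }
    }

The two join cases in the rule are textually identical; they stay separate likely because `BroadcastHashJoinExec` and `BroadcastNestedLoopJoinExec` share no common parent that exposes `buildSide` together with `left` and `right`.
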
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 7ae162ca8ad49..2635c55dedf7a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
 import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, REPARTITION_BY_COL, REPARTITION_BY_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
-import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec}
 import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
 import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
 import org.apache.spark.sql.functions._
@@ -102,6 +102,12 @@ class AdaptiveQueryExecSuite
     }
   }

+  def findTopLevelBroadcastNestedLoopJoin(plan: SparkPlan): Seq[BaseJoinExec] = {
+    collect(plan) {
+      case j: BroadcastNestedLoopJoinExec => j
+    }
+  }
+
   private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = {
     collect(plan) {
       case j: SortMergeJoinExec => j
@@ -2085,6 +2091,23 @@ class AdaptiveQueryExecSuite
       assert(bhj.length == 1)
     }
   }
+
+  test("SPARK-39551: Invalid plan check - invalid broadcast query stage") {
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
+        """
+          |SELECT /*+ BROADCAST(t3) */ t3.b, count(t3.a) FROM testData2 t1
+          |INNER JOIN testData2 t2
+          |ON t1.b = t2.b AND t1.a = 0
+          |RIGHT OUTER JOIN testData2 t3
+          |ON t1.a > t3.a
+          |GROUP BY t3.b
+        """.stripMargin
+      )
+      assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1)
+    }
+  }
 }

 /**
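
The regression test steers AQE toward exactly the shape the validator rejects: with the broadcast hint on t3 under a right outer join, replanning can produce a `BroadcastQueryStageExec` that no broadcast join consumes, and with the check in place AQE falls back to the last valid plan, finishing with a single broadcast nested loop join. The new helper follows the suite's existing `collect`-based lookup style; the same idea in generic form, as a sketch with illustrative names that mirrors the contract of Catalyst's `TreeNode.collect`:

    // Toy operator tree; SparkPlan nodes play this role in the real suite.
    sealed trait Node { def children: Seq[Node] }
    final case class Join(kind: String, children: Node*) extends Node
    final case class Scan(table: String) extends Node { val children: Seq[Node] = Nil }

    object CollectDemo {
      // Pre-order traversal collecting every node the partial function accepts.
      def collect[T](node: Node)(pf: PartialFunction[Node, T]): Seq[T] =
        pf.lift(node).toSeq ++ node.children.flatMap(collect(_)(pf))

      def main(args: Array[String]): Unit = {
        val plan = Join("BroadcastNestedLoopJoin", Scan("t1"),
          Join("SortMergeJoin", Scan("t2"), Scan("t3")))
        val bnlj = collect(plan) { case j @ Join("BroadcastNestedLoopJoin", _*) => j }
        assert(bnlj.size == 1)
      }
    }
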