Skip to content

Commit be891ad

Browse files
maryannxue authored and dongjoon-hyun committed
[SPARK-39551][SQL][3.2] Add AQE invalid plan check
### What changes were proposed in this pull request? This is a backport of #36953 This PR adds a check for invalid plans in AQE replanning process. The check will throw exceptions when it detects an invalid plan, causing AQE to void the current replanning result and keep using the latest valid plan. ### Why are the changes needed? AQE logical optimization rules can lead to invalid physical plans and cause runtime exceptions as certain physical plan nodes are not compatible with others. E.g., `BroadcastExchangeExec` can only work as a direct child of broadcast join nodes, but it could appear under other incompatible physical plan nodes because of empty relation propagation. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added UT. Closes #37108 from dongjoon-hyun/SPARK-39551. Authored-by: Maryann Xue <maryann.xue@gmail.com> Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
1 parent 1c0bd4c commit be891ad

File tree

4 files changed

+163
-32
lines changed

4 files changed

+163
-32
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

+41-31
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ case class AdaptiveSparkPlanExec(
107107
// `EnsureRequirements` to not optimize out the user-specified repartition-by-col to work
108108
// around this case.
109109
EnsureRequirements(optimizeOutRepartition = requiredDistribution.isDefined),
110+
ValidateSparkPlan,
110111
RemoveRedundantSorts,
111112
DisableUnnecessaryBucketedScan
112113
) ++ context.session.sessionState.queryStagePrepRules
@@ -295,16 +296,19 @@ case class AdaptiveSparkPlanExec(
295296
// plans are updated, we can clear the query stage list because at this point the two plans
296297
// are semantically and physically in sync again.
297298
val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan, stagesToReplace)
298-
val (newPhysicalPlan, newLogicalPlan) = reOptimize(logicalPlan)
299-
val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
300-
val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
301-
if (newCost < origCost ||
299+
val afterReOptimize = reOptimize(logicalPlan)
300+
if (afterReOptimize.isDefined) {
301+
val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
302+
val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
303+
val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
304+
if (newCost < origCost ||
302305
(newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
303-
logOnLevel(s"Plan changed from $currentPhysicalPlan to $newPhysicalPlan")
304-
cleanUpTempTags(newPhysicalPlan)
305-
currentPhysicalPlan = newPhysicalPlan
306-
currentLogicalPlan = newLogicalPlan
307-
stagesToReplace = Seq.empty[QueryStageExec]
306+
logOnLevel(s"Plan changed from $currentPhysicalPlan to $newPhysicalPlan")
307+
cleanUpTempTags(newPhysicalPlan)
308+
currentPhysicalPlan = newPhysicalPlan
309+
currentLogicalPlan = newLogicalPlan
310+
stagesToReplace = Seq.empty[QueryStageExec]
311+
}
308312
}
309313
// Now that some stages have finished, we can try creating new stages.
310314
result = createQueryStages(currentPhysicalPlan)
@@ -637,29 +641,35 @@ case class AdaptiveSparkPlanExec(
637641
/**
638642
* Re-optimize and run physical planning on the current logical plan based on the latest stats.
639643
*/
640-
private def reOptimize(logicalPlan: LogicalPlan): (SparkPlan, LogicalPlan) = {
641-
logicalPlan.invalidateStatsCache()
642-
val optimized = optimizer.execute(logicalPlan)
643-
val sparkPlan = context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
644-
val newPlan = applyPhysicalRules(
645-
sparkPlan,
646-
preprocessingRules ++ queryStagePreparationRules,
647-
Some((planChangeLogger, "AQE Replanning")))
648-
649-
// When both enabling AQE and DPP, `PlanAdaptiveDynamicPruningFilters` rule will
650-
// add the `BroadcastExchangeExec` node manually in the DPP subquery,
651-
// not through `EnsureRequirements` rule. Therefore, when the DPP subquery is complicated
652-
// and need to be re-optimized, AQE also need to manually insert the `BroadcastExchangeExec`
653-
// node to prevent the loss of the `BroadcastExchangeExec` node in DPP subquery.
654-
// Here, we also need to avoid to insert the `BroadcastExchangeExec` node when the newPlan
655-
// is already the `BroadcastExchangeExec` plan after apply the `LogicalQueryStageStrategy` rule.
656-
val finalPlan = currentPhysicalPlan match {
657-
case b: BroadcastExchangeLike
658-
if (!newPlan.isInstanceOf[BroadcastExchangeLike]) => b.withNewChildren(Seq(newPlan))
659-
case _ => newPlan
660-
}
644+
/**
 * Re-optimizes and re-plans the given logical plan based on the latest runtime statistics.
 *
 * @param logicalPlan the current logical plan; its cached stats are invalidated first so the
 *                    optimizer sees fresh statistics from finished query stages.
 * @return `Some((newPhysicalPlan, newLogicalPlan))` on success, or `None` when the replanned
 *         physical plan is detected to be invalid, in which case the caller keeps using the
 *         latest valid plan.
 */
private def reOptimize(logicalPlan: LogicalPlan): Option[(SparkPlan, LogicalPlan)] = {
  try {
    // Runtime stats may have changed since the last round; drop any cached values first.
    logicalPlan.invalidateStatsCache()
    val optimized = optimizer.execute(logicalPlan)
    val candidate = context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
    val replanned = applyPhysicalRules(
      candidate,
      preprocessingRules ++ queryStagePreparationRules,
      Some((planChangeLogger, "AQE Replanning")))

    // When both AQE and DPP are enabled, the `PlanAdaptiveDynamicPruningFilters` rule adds the
    // `BroadcastExchangeExec` node manually in the DPP subquery, not through the
    // `EnsureRequirements` rule. Therefore, when a complicated DPP subquery needs to be
    // re-optimized, AQE also needs to manually re-insert the `BroadcastExchangeExec` node to
    // prevent its loss — unless the replanned result is already a `BroadcastExchangeExec`
    // produced by the `LogicalQueryStageStrategy` rule.
    val finalPlan =
      if (currentPhysicalPlan.isInstanceOf[BroadcastExchangeLike] &&
          !replanned.isInstanceOf[BroadcastExchangeLike]) {
        currentPhysicalPlan.withNewChildren(Seq(replanned))
      } else {
        replanned
      }

    Some((finalPlan, optimized))
  } catch {
    // An invalid plan was produced: void this replanning round and keep the current plan.
    case e: InvalidAQEPlanException[_] =>
      logOnLevel(s"Re-optimize - ${e.getMessage()}:\n${e.plan}")
      None
  }
}
664674

665675
/**
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.adaptive
19+
20+
import org.apache.spark.sql.catalyst.plans.QueryPlan
21+
22+
/**
23+
* Exception thrown when an invalid query plan is detected in AQE replanning,
24+
* in which case AQE will stop the current replanning process and keep using the latest valid plan.
25+
*
26+
* @param message The reason why the plan is considered invalid.
27+
* @param plan The invalid plan/sub-plan.
28+
*/
29+
case class InvalidAQEPlanException[QueryType <: QueryPlan[_]](message: String, plan: QueryType)
30+
extends Exception(message)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.adaptive
19+
20+
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
21+
import org.apache.spark.sql.catalyst.rules.Rule
22+
import org.apache.spark.sql.execution.SparkPlan
23+
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}
24+
25+
/**
26+
* Detects invalid physical plans generated by AQE replanning and throws `InvalidAQEPlanException`
27+
* if such plans are detected. This rule should be called after EnsureRequirements where all
28+
* necessary Exchange nodes are added.
29+
*/
30+
object ValidateSparkPlan extends Rule[SparkPlan] {

  /** Walks the whole plan, throwing `InvalidAQEPlanException` if it is invalid. */
  def apply(plan: SparkPlan): SparkPlan = {
    validate(plan)
    plan
  }

  /**
   * Validate that the plan satisfies the following condition:
   * - BroadcastQueryStage only appears as the immediate child and the build side of a broadcast
   *   hash join or broadcast nested loop join.
   */
  private def validate(plan: SparkPlan): Unit = plan match {
    case b: BroadcastHashJoinExec =>
      b.buildSide match {
        case BuildLeft => validateJoinSides(b.left, b.right)
        case BuildRight => validateJoinSides(b.right, b.left)
      }
    case b: BroadcastNestedLoopJoinExec =>
      b.buildSide match {
        case BuildLeft => validateJoinSides(b.left, b.right)
        case BuildRight => validateJoinSides(b.right, b.left)
      }
    // A broadcast stage reached outside a broadcast join's build side is invalid.
    case q: BroadcastQueryStageExec => errorOnInvalidBroadcastQueryStage(q)
    case _ => plan.children.foreach(validate)
  }

  /**
   * Validates the two sides of a broadcast join: a `BroadcastQueryStageExec` directly under the
   * build side is the one legal placement and is not descended into; everything else (a non-stage
   * build side, and the whole probe side) is validated recursively.
   */
  private def validateJoinSides(buildPlan: SparkPlan, probePlan: SparkPlan): Unit = {
    if (!buildPlan.isInstanceOf[BroadcastQueryStageExec]) {
      validate(buildPlan)
    }
    validate(probePlan)
  }

  private def errorOnInvalidBroadcastQueryStage(plan: SparkPlan): Unit = {
    throw InvalidAQEPlanException("Invalid broadcast query stage", plan)
  }
}

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

+24-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command.DataWritingCommandExec
3232
import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
3333
import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
3434
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, REPARTITION_BY_COL, REPARTITION_BY_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
35-
import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec}
35+
import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec}
3636
import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
3737
import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
3838
import org.apache.spark.sql.functions._
@@ -102,6 +102,12 @@ class AdaptiveQueryExecSuite
102102
}
103103
}
104104

105+
/**
 * Collects all `BroadcastNestedLoopJoinExec` nodes in the given plan tree.
 *
 * Returns the concrete node type (rather than the wider `BaseJoinExec`) for consistency with
 * the sibling helpers such as `findTopLevelSortMergeJoin`; `Seq` is covariant, so this narrowing
 * is backward-compatible for existing callers.
 */
def findTopLevelBroadcastNestedLoopJoin(plan: SparkPlan): Seq[BroadcastNestedLoopJoinExec] = {
  collect(plan) {
    case j: BroadcastNestedLoopJoinExec => j
  }
}
110+
105111
private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = {
106112
collect(plan) {
107113
case j: SortMergeJoinExec => j
@@ -2085,6 +2091,23 @@ class AdaptiveQueryExecSuite
20852091
assert(bhj.length == 1)
20862092
}
20872093
}
2094+
2095+
test("SPARK-39551: Invalid plan check - invalid broadcast query stage") {
  withSQLConf(
    SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
    // The BROADCAST hint on t3 yields a broadcast nested loop join; the query must still run
    // and produce one such join in the final adaptive plan (AQE keeps the last valid plan
    // instead of failing when replanning produces an invalid broadcast query stage).
    val query =
      """
        |SELECT /*+ BROADCAST(t3) */ t3.b, count(t3.a) FROM testData2 t1
        |INNER JOIN testData2 t2
        |ON t1.b = t2.b AND t1.a = 0
        |RIGHT OUTER JOIN testData2 t3
        |ON t1.a > t3.a
        |GROUP BY t3.b
      """.stripMargin
    val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query)
    assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1)
  }
}
20882111
}
20892112

20902113
/**

0 commit comments

Comments
 (0)