diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index c37e1e92c8576..07e19f6285d27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan}
 import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan}
 import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, PruneShuffle, ReuseExchange}
 import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
@@ -344,6 +344,7 @@ object QueryExecution {
       PlanSubqueries,
       RemoveRedundantProjects,
       EnsureRequirements,
+      PruneShuffle(),
       DisableUnnecessaryBucketedScan,
       ApplyColumnarRulesAndInsertTransitions(sparkSession.sessionState.columnarRules),
       CollapseCodegenStages(),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/PruneShuffle.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/PruneShuffle.scala
new file mode 100644
index 0000000000000..06bfcd155a14e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/PruneShuffle.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.exchange
+
+import scala.annotation.tailrec
+
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * Removes redundant shuffles that sit directly beneath another shuffle.
+ *
+ * [[Rule]]s for [[SparkPlan]]s, such as [[EnsureRequirements]], can introduce a shuffle
+ * right on top of an existing one. This rule matches a shuffle whose immediate child is
+ * also a shuffle and re-attaches it to the first non-shuffle node below, so only the
+ * outermost shuffle of each such chain is kept.
+ */
+case class PruneShuffle() extends Rule[SparkPlan] {
+
+  /**
+   * Collapses every chain of directly nested shuffles in `plan` into its outermost shuffle.
+   *
+   * @param plan the physical plan to rewrite
+   * @return the rewritten plan; nodes outside such chains are left as they are
+   */
+  override def apply(plan: SparkPlan): SparkPlan = plan.transform {
+    // `transform` takes a partial function and keeps nodes it is not defined for, so no
+    // catch-all case is needed.
+    case op @ ShuffleExchangeExec(_, child: ShuffleExchangeExec, _) =>
+      op.withNewChildren(Seq(pruneShuffle(child)))
+  }
+
+  /**
+   * Skips past consecutive shuffles, starting at `plan`.
+   *
+   * @param plan the node to start from
+   * @return `plan` itself when it is not a shuffle; otherwise the first non-shuffle node
+   *         reached by repeatedly following `child`
+   */
+  @tailrec
+  private def pruneShuffle(plan: SparkPlan): SparkPlan = plan match {
+    case shuffle: ShuffleExchangeExec => pruneShuffle(shuffle.child)
+    case other => other
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 048466b3d8637..5079173f64e6f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, Disable
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import
org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledJoin, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -1001,6 +1001,44 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
     val numPartitions = range.rdd.getNumPartitions
     assert(numPartitions == 0)
   }
+
+  test("SPARK-32820: Remove redundant shuffle exchange") {
+    val numShufflePartitions = 200
+
+    // Checks the `childIndex`-th child of `parent`: it must contain exactly one shuffle
+    // (a shuffle nested beneath another one has to be pruned), and that shuffle must use
+    // `numShufflePartitions` partitions and satisfy the distribution `parent` requires.
+    def assertSingleShuffle(parent: SparkPlan, childIndex: Int): Unit = {
+      val shuffles = parent.children(childIndex).collect {
+        case shuffle: ShuffleExchangeExec => shuffle
+      }
+      assert(shuffles.size == 1)
+      val partitioning = shuffles.head.outputPartitioning
+      assert(partitioning.numPartitions == numShufflePartitions)
+      assert(partitioning.satisfies(parent.requiredChildDistribution(childIndex)))
+    }
+
+    withSQLConf(
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0",
+      SQLConf.SHUFFLE_PARTITIONS.key -> numShufflePartitions.toString) {
+      // A sort on top of an explicit range repartitioning with a different ordering.
+      val ordered = spark.range(1, 100).repartitionByRange(10, $"id".desc).orderBy($"id")
+      val orderedPlan = ordered.queryExecution.executedPlan.collectFirst {
+        case sort: SortExec => sort
+      }.get
+      assertSingleShuffle(orderedPlan, 0)
+
+      // A join whose two inputs were each explicitly repartitioned beforehand.
+      val left = Seq(1, 2, 3).toDF.repartition(10)
+      val right = Seq(1, 2, 3).toDF.repartition(30, $"value")
+      val joined = left.join(right, left("value") + 1 === right("value") + 2)
+      val joinedPlan = joined.queryExecution.executedPlan.collectFirst {
+        case shuffledJoin: ShuffledJoin => shuffledJoin
+      }.get
+      assertSingleShuffle(joinedPlan, 0)
+      assertSingleShuffle(joinedPlan, 1)
+    }
+  }
 }
 
 // Used for unit-testing EnsureRequirements