From b55d8b04e007a112fff090b50db6b7c06c7bb63e Mon Sep 17 00:00:00 2001 From: Andrew Coleman Date: Thu, 3 Oct 2024 23:27:08 +0100 Subject: [PATCH] feat(spark): support UNION ALL in SparkSql (#301) --- .../io/substrait/spark/logical/ToLogicalPlan.scala | 11 +++++++++++ .../io/substrait/spark/logical/ToSubstraitRel.scala | 11 +++++++++++ .../src/test/scala/io/substrait/spark/TPCDSPlan.scala | 10 +++++----- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index 45b6c2205..24c57bf64 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -36,6 +36,7 @@ import io.substrait.plan.Plan import io.substrait.relation import io.substrait.relation.Expand.{ConsistentField, SwitchingField} import io.substrait.relation.LocalFiles +import io.substrait.relation.Set.SetOp import org.apache.hadoop.fs.Path import scala.collection.JavaConverters.asScalaBufferConverter @@ -225,6 +226,16 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] } } + override def visit(set: relation.Set): LogicalPlan = { + val children = set.getInputs.asScala.map(_.accept(this)) + withOutput(children.flatMap(_.output)) { + set.getSetOp match { + case SetOp.UNION_ALL => Union(children, byName = false, allowMissingCol = false) + case op => throw new UnsupportedOperationException(s"Operation not currently supported: $op") + } + } + } + override def visit(emptyScan: relation.EmptyScan): LogicalPlan = { LocalRelation(ToSubstraitType.toAttribute(emptyScan.getInitialSchema)) } diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index 08a06c2e4..01ae14396 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -35,6 +35,7 @@ import io.substrait.extension.ExtensionCollector import io.substrait.hint.Hint import io.substrait.plan.{ImmutablePlan, ImmutableRoot, Plan} import io.substrait.relation.RelProtoConverter +import io.substrait.relation.Set.SetOp import io.substrait.relation.files.{FileFormat, ImmutableFileOrFiles} import io.substrait.relation.files.FileOrFiles.PathType @@ -268,6 +269,16 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { relation.Sort.builder.addAllSortFields(fields).input(input).build } + override def visitUnion(union: Union): relation.Rel = { + if (union.byName) { + throw new UnsupportedOperationException("Union by column name is not supported") + } + relation.Set.builder + .inputs(union.children.map(c => visit(c)).asJava) + .setOp(SetOp.UNION_ALL) + .build() + } + private def toExpression(output: Seq[Attribute])(e: Expression): SExpression = { toSubstraitExp(e, output) } diff --git a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala index f880b25a4..826d7200c 100644 --- a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala +++ b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala @@ -34,12 +34,12 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase { // "q9" failed in spark 3.3 val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q7", "q11", "q13", "q15", "q16", "q18", "q19", - "q22", "q25", "q26", "q28", "q29", - "q30", "q31", "q32", "q37", + "q22", "q23a", "q23b", "q25", "q26", "q28", "q29", + "q30", "q31", "q32", "q33", "q37", "q41", "q42", "q43", "q46", "q48", - "q50", "q52", "q55", "q58", "q59", - "q61", "q62", "q65", "q68", "q69", - "q79", + "q50", "q52", "q54", "q55", "q56", "q58", "q59", + "q60", "q61", "q62", "q65", "q66", "q68", "q69", + "q71", "q76", "q79", "q81", "q82", "q85", "q88", "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q99")