From d22c07f474bcc6d7894f455f8b4f723b313638a8 Mon Sep 17 00:00:00 2001 From: Andrew Coleman Date: Fri, 11 Oct 2024 20:45:41 +0100 Subject: [PATCH] feat(spark): support for FetchRel offset field (#296) --- .../spark/logical/ToLogicalPlan.scala | 19 +++++++--- .../spark/logical/ToSubstraitRel.scala | 36 +++++++++++++------ .../scala/io/substrait/spark/TPCHPlan.scala | 8 ++++- 3 files changed, 46 insertions(+), 17 deletions(-) diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index 24c57bf64..3740babee 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -161,11 +161,20 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] } override def visit(fetch: relation.Fetch): LogicalPlan = { val child = fetch.getInput.accept(this) - val limit = Literal(fetch.getCount.getAsLong.intValue(), IntegerType) - fetch.getOffset match { - case 1L => GlobalLimit(limitExpr = limit, child = child) - case -1L => LocalLimit(limitExpr = limit, child = child) - case _ => visitFallback(fetch) + val limit = fetch.getCount.getAsLong.intValue() + val offset = fetch.getOffset.intValue() + val toLiteral = (i: Int) => Literal(i, IntegerType) + if (limit >= 0) { + val limitExpr = toLiteral(limit) + if (offset > 0) { + GlobalLimit(limitExpr, + Offset(toLiteral(offset), + LocalLimit(toLiteral(offset + limit), child))) + } else { + GlobalLimit(limitExpr, LocalLimit(limitExpr, child)) + } + } else { + Offset(toLiteral(offset), child) } } override def visit(sort: relation.Sort): LogicalPlan = { diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index 01ae14396..46d00f8cd 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -171,23 +171,37 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { case other => throw new UnsupportedOperationException(s"Unknown type: $other") } - private def fetchBuilder(limit: Long, global: Boolean): relation.ImmutableFetch.Builder = { - val offset = if (global) 1L else -1L - relation.Fetch - .builder() - .count(limit) + private def fetch(child: LogicalPlan, offset: Long, limit: Long = -1): relation.Fetch = { + relation.Fetch.builder() + .input(visit(child)) .offset(offset) + .count(limit) + .build() } + override def visitGlobalLimit(p: GlobalLimit): relation.Rel = { - fetchBuilder(asLong(p.limitExpr), global = true) - .input(visit(p.child)) - .build() + p match { + case OffsetAndLimit((offset, limit, child)) => fetch(child, offset, limit) + case GlobalLimit(IntegerLiteral(globalLimit), LocalLimit(IntegerLiteral(localLimit), child)) + if globalLimit == localLimit => fetch(child, 0, localLimit) + case _ => + throw new UnsupportedOperationException(s"Unable to convert the limit expression: $p") + } } override def visitLocalLimit(p: LocalLimit): relation.Rel = { - fetchBuilder(asLong(p.limitExpr), global = false) - .input(visit(p.child)) - .build() + val localLimit = asLong(p.limitExpr) + p.child match { + case OffsetAndLimit((offset, limit, child)) if localLimit >= limit => + fetch(child, offset, limit) + case GlobalLimit(IntegerLiteral(globalLimit), child) if localLimit >= globalLimit => + fetch(child, 0, globalLimit) + case _ => fetch(p.child, 0, localLimit) + } + } + + override def visitOffset(p: Offset): relation.Rel = { + fetch(p.child, asLong(p.offsetExpr)) } override def visitFilter(p: Filter): relation.Rel = { diff --git a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala index 224ac2e8d..76c5b9a6a 100644 --- a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala +++ b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala @@ -73,10 +73,16 @@ class TPCHPlan extends TPCHBase with SubstraitPlanTestBase { "order by l_shipdate asc, l_discount desc nulls last") } - ignore("simpleOffsetClause") { // TODO need to implement the 'offset' clause for this to pass + test("simpleOffsetClause") { assertSqlSubstraitRelRoundTrip( "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " + "order by l_shipdate asc, l_discount desc limit 100 offset 1000") + assertSqlSubstraitRelRoundTrip( + "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " + + "order by l_shipdate asc, l_discount desc offset 1000") + assertSqlSubstraitRelRoundTrip( + "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " + + "order by l_shipdate asc, l_discount desc limit 100") } test("simpleTest") {