[SPARK-42631][CONNECT] Support custom extensions in Scala client
### What changes were proposed in this pull request?

This PR adds public interfaces for creating `Dataset` and `Column` instances and for executing commands. These interfaces only allow creating `Plan`s and `Expression`s that carry an `extension` message, which limits how much of the client internals needs to be exposed.
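
As a rough sketch of how a client-side extension could drive the new entry points (the `demo` helper is hypothetical, and the `ExamplePlugin*` messages are the test-only protos used by the suites below; a real plugin would ship its own protobuf messages together with a matching server-side plugin):

```scala
import com.google.protobuf.Any

import org.apache.spark.connect.proto
import org.apache.spark.sql.{Column, SparkSession}

def demo(spark: SparkSession): Unit = {
  // Relation extension: the packed message becomes the `extension` field of the plan's
  // root relation and is resolved by a server-side relation plugin.
  val relation = proto.ExamplePluginRelation
    .newBuilder()
    .setInput(spark.range(10).plan.getRoot) // Dataset.plan is now @DeveloperApi
    .build()
  val df = spark.newDataFrame(Any.pack(relation))

  // Expression extension: the packed message becomes the `extension` field of an
  // expression, wrapped in a Column that can be used like any other.
  val expr = proto.ExamplePluginExpression.newBuilder().setCustomField("abc").build()
  val projected = df.select(Column(Any.pack(expr)))

  // Command extension: sent eagerly to the server through the new execute overload.
  val command = proto.ExamplePluginCommand.newBuilder().setCustomField("abc").build()
  spark.execute(Any.pack(command))
}
```

Since `Dataset.plan` is now annotated with `@DeveloperApi`, an extension can also feed an existing `Dataset`'s relation into its own plugin message, as the `relation extension` test below does with `simple.plan.getRoot`.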

### Why are the changes needed?

Required to implement extensions to the Scala Spark Connect client.

### Does this PR introduce _any_ user-facing change?

Yes, adds new public interfaces (see above).

### How was this patch tested?

Added unit tests.

Closes #40234 from tomvanbussel/SPARK-34827.

Authored-by: Tom van Bussel <tom.vanbussel@databricks.com>
Signed-off-by: Herman van Hovell <herman@databricks.com>
(cherry picked from commit a9c5efa)
Signed-off-by: Herman van Hovell <herman@databricks.com>
tomvanbussel authored and hvanhovell committed Mar 2, 2023
1 parent e194260 commit a1d5e89
Showing 14 changed files with 130 additions and 6 deletions.
@@ -18,6 +18,7 @@ package org.apache.spark.sql

import scala.collection.JavaConverters._

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.Expression.SortOrder.NullOrdering
import org.apache.spark.connect.proto.Expression.SortOrder.SortDirection
@@ -1312,6 +1313,11 @@ private[sql] object Column {
    new Column(builder.build())
  }

  @DeveloperApi
  def apply(extension: com.google.protobuf.Any): Column = {
    apply(_.setExtension(extension))
  }

  private[sql] def fn(name: String, inputs: Column*): Column = {
    fn(name, isDistinct = false, inputs: _*)
  }
@@ -22,6 +22,7 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.control.NonFatal

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveLongEncoder, StringEncoder, UnboundRowEncoder}
@@ -120,7 +121,7 @@ import org.apache.spark.util.Utils
*/
class Dataset[T] private[sql] (
    val sparkSession: SparkSession,
    private[sql] val plan: proto.Plan,
    @DeveloperApi val plan: proto.Plan,
    val encoder: AgnosticEncoder[T])
    extends Serializable {
  // Make sure we don't forget to set plan id.
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._

import org.apache.arrow.memory.RootAllocator

import org.apache.spark.annotation.Experimental
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
@@ -261,6 +261,18 @@ class SparkSession private[sql] (
    new Dataset[T](this, plan, encoder)
  }

  @DeveloperApi
  def newDataFrame(extension: com.google.protobuf.Any): DataFrame = {
    newDataset(extension, UnboundRowEncoder)
  }

  @DeveloperApi
  def newDataset[T](
      extension: com.google.protobuf.Any,
      encoder: AgnosticEncoder[T]): Dataset[T] = {
    newDataset(encoder)(_.setExtension(extension))
  }

  private[sql] def newCommand[T](f: proto.Command.Builder => Unit): proto.Command = {
    val builder = proto.Command.newBuilder()
    f(builder)
@@ -287,6 +299,12 @@ class SparkSession private[sql] (
    client.execute(plan).asScala.foreach(_ => ())
  }

  @DeveloperApi
  def execute(extension: com.google.protobuf.Any): Unit = {
    val command = proto.Command.newBuilder().setExtension(extension).build()
    execute(command)
  }

  /**
   * This resets the plan id generator so we can produce plans that are comparable.
   *
@@ -134,4 +134,16 @@ class DatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
      df.groupBy().pivot(Column("c"), Seq(Column("col")))
    }
  }

  test("command extension") {
    val extension = proto.ExamplePluginCommand.newBuilder().setCustomField("abc").build()
    val command = proto.Command
      .newBuilder()
      .setExtension(com.google.protobuf.Any.pack(extension))
      .build()
    val expectedPlan = proto.Plan.newBuilder().setCommand(command).build()
    ss.execute(com.google.protobuf.Any.pack(extension))
    val actualPlan = service.getAndClearLatestInputPlan()
    assert(actualPlan.equals(expectedPlan))
  }
}
@@ -24,6 +24,7 @@ import scala.collection.mutable
import scala.util.{Failure, Success, Try}

import com.google.protobuf.util.JsonFormat
import com.google.protobuf.util.JsonFormat.TypeRegistry
import io.grpc.inprocess.InProcessChannelBuilder
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}

@@ -100,7 +101,14 @@ class PlanGenerationTestSuite
"query-tests",
"test-data")

private val printer = JsonFormat.printer()
private val registry = TypeRegistry
.newBuilder()
.add(proto.ExamplePluginRelation.getDescriptor)
.add(proto.ExamplePluginExpression.getDescriptor)
.add(proto.ExamplePluginCommand.getDescriptor)
.build()

private val printer = JsonFormat.printer().usingTypeRegistry(registry)

private var session: SparkSession = _

@@ -2007,4 +2015,27 @@ class PlanGenerationTestSuite
fn.min("id").over(Window.orderBy("a").rangeBetween(2L, 3L)),
fn.count(Column("id")).over())
}

/* Extensions */
test("relation extension") {
val input = proto.ExamplePluginRelation
.newBuilder()
.setInput(simple.plan.getRoot)
.build()
session.newDataFrame(com.google.protobuf.Any.pack(input))
}

test("expression extension") {
val extension = proto.ExamplePluginExpression
.newBuilder()
.setChild(
proto.Expression
.newBuilder()
.setUnresolvedAttribute(proto.Expression.UnresolvedAttribute
.newBuilder()
.setUnparsedIdentifier("id")))
.setCustomField("abc")
.build()
simple.select(Column(com.google.protobuf.Any.pack(extension)))
}
}
@@ -0,0 +1,2 @@
Project [id#0L AS abc#0L]
+- LocalRelation <empty>, [id#0L, a#0, b#0]
@@ -0,0 +1 @@
LocalRelation <empty>, [id#0L, a#0, b#0]
@@ -11,7 +11,7 @@
"schema": "struct\u003cid:bigint,a:int,b:double\u003e"
}
},
"cols": [{
"columns": [{
"unresolvedAttribute": {
"unparsedIdentifier": "b"
}
@@ -11,7 +11,7 @@
"schema": "struct\u003cid:bigint,a:int,b:double\u003e"
}
},
"cols": [{
"columns": [{
"unresolvedAttribute": {
"unparsedIdentifier": "b"
}
@@ -0,0 +1,26 @@
{
  "common": {
    "planId": "1"
  },
  "project": {
    "input": {
      "common": {
        "planId": "0"
      },
      "localRelation": {
        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
      }
    },
    "expressions": [{
      "extension": {
        "@type": "type.googleapis.com/spark.connect.ExamplePluginExpression",
        "child": {
          "unresolvedAttribute": {
            "unparsedIdentifier": "id"
          }
        },
        "customField": "abc"
      }
    }]
  }
}
Binary file not shown.
@@ -0,0 +1,16 @@
{
  "common": {
    "planId": "1"
  },
  "extension": {
    "@type": "type.googleapis.com/spark.connect.ExamplePluginRelation",
    "input": {
      "common": {
        "planId": "0"
      },
      "localRelation": {
        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
      }
    }
  }
}
Binary file not shown.
@@ -23,12 +23,13 @@ import java.util

import scala.util.{Failure, Success, Try}

import org.apache.spark.SparkFunSuite
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.{catalog, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.{caseSensitiveResolution, Analyzer, FunctionRegistry, Resolver, TableFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier, InMemoryCatalog}
import org.apache.spark.sql.connector.expressions.Transform
@@ -56,6 +57,16 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
*/
// scalastyle:on
class ProtoToParsedPlanTestSuite extends SparkFunSuite with SharedSparkSession {
  override def sparkConf: SparkConf = {
    super.sparkConf
      .set(
        Connect.CONNECT_EXTENSIONS_RELATION_CLASSES.key,
        "org.apache.spark.sql.connect.plugin.ExampleRelationPlugin")
      .set(
        Connect.CONNECT_EXTENSIONS_EXPRESSION_CLASSES.key,
        "org.apache.spark.sql.connect.plugin.ExampleExpressionPlugin")
  }

  protected val baseResourcePath: Path = {
    getWorkspaceFilePath(
      "connector",
