[SPARK-39840][SQL][PYTHON] Factor PythonArrowInput out as a symmetry to PythonArrowOutput

### What changes were proposed in this pull request?

This PR factors the Arrow input code path out as a `PythonArrowInput` trait, in symmetry with `PythonArrowOutput`. The current class hierarchy is not affected:

```
    └── BasePythonRunner
        ├── ArrowPythonRunner with PythonArrowOutput with PythonArrowInput
        ├── CoGroupedArrowPythonRunner with PythonArrowOutput
        ├── PythonRunner
        └── PythonUDFRunner
```
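
Concretely, `ArrowPythonRunner` now just mixes both traits into `BasePythonRunner`. A trimmed sketch of its new declaration, taken from the diff below with member bodies omitted, looks like this:

```scala
// Trimmed from the ArrowPythonRunner diff below (member bodies omitted).
// The parameters the traits need are promoted to protected vals so that
// PythonArrowInput/PythonArrowOutput can read them.
class ArrowPythonRunner(
    funcs: Seq[ChainedPythonFunctions],
    evalType: Int,
    argOffsets: Array[Array[Int]],
    protected override val schema: StructType,
    protected override val timeZoneId: String,
    protected override val workerConf: Map[String, String])
  extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets)
  with PythonArrowInput   // JVM internal rows -> Arrow stream to the Python worker
  with PythonArrowOutput  // Arrow stream from the Python worker -> ColumnarBatch
```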

In addition, this PR factors out `handleMetadataBeforeExec` and `handleMetadataAfterExec`, which contain the logic to send and receive metadata, such as runtime configurations specific to Arrow input/output, before and after execution (see the sketch below).
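
To illustrate how the two hooks compose, here is a minimal, self-contained Scala sketch of the same pattern. The `Mini*` names and the simplified wire format are hypothetical stand-ins for Spark's internal classes, not the real API:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// Minimal stand-in for the shape of this refactor (not Spark's real API):
// a base runner whose input and output sides live in two mixin traits, each
// exposing a protected metadata hook that concrete runners may override.
abstract class MiniRunner {
  // Called before execution; the input trait uses it to send worker config.
  protected def writeCommand(out: DataOutputStream): Unit
  // Called when the worker signals end of data; the output trait may read
  // trailing metadata here.
  protected def handleEndOfDataSection(in: DataInputStream): Unit

  def run(out: DataOutputStream, in: DataInputStream): Unit = {
    writeCommand(out)
    handleEndOfDataSection(in)
  }
}

trait MiniArrowInput extends MiniRunner {
  protected val workerConf: Map[String, String]

  // Analogue of handleMetadataBeforeExec: write config as key/value pairs.
  protected def handleMetadataBeforeExec(out: DataOutputStream): Unit = {
    out.writeInt(workerConf.size)
    workerConf.foreach { case (k, v) => out.writeUTF(k); out.writeUTF(v) }
  }

  protected override def writeCommand(out: DataOutputStream): Unit = {
    handleMetadataBeforeExec(out)
    // ... the serialized UDFs would follow here in the real runner
  }
}

trait MiniArrowOutput extends MiniRunner {
  // Analogue of handleMetadataAfterExec: a no-op unless a runner needs to
  // read Arrow-specific metadata after execution.
  protected def handleMetadataAfterExec(in: DataInputStream): Unit = ()

  protected override def handleEndOfDataSection(in: DataInputStream): Unit =
    handleMetadataAfterExec(in)
}

// A concrete runner just mixes in both sides, as ArrowPythonRunner does.
class MiniArrowRunner(protected val workerConf: Map[String, String])
  extends MiniRunner with MiniArrowInput with MiniArrowOutput

object MiniDemo {
  def main(args: Array[String]): Unit = {
    val bytes = new ByteArrayOutputStream()
    new MiniArrowRunner(Map("spark.sql.session.timeZone" -> "UTC")).run(
      new DataOutputStream(bytes),
      new DataInputStream(new ByteArrayInputStream(Array.emptyByteArray)))
    println(s"wrote ${bytes.size} bytes of worker config")
  }
}
```

In the real code the hooks hang off `BasePythonRunner`'s writer thread and reader iterator (see the diff below), but the division of responsibilities is the same.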

### Why are the changes needed?

Commit 40485f4 factored `PythonArrowOutput` out. It is better to factor `PythonArrowInput` out as well, for consistency.

### Does this PR introduce _any_ user-facing change?

No, this is refactoring.

### How was this patch tested?

Existing test cases should cover this change.

Closes #37253 from HyukjinKwon/pyarrow-output-trait.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
HyukjinKwon committed Jul 25, 2022
1 parent 1b6cdf1 commit 2e1467f
Showing 4 changed files with 123 additions and 81 deletions.
@@ -84,9 +84,9 @@ private object BasePythonRunner {
* functions (from bottom to top).
*/
private[spark] abstract class BasePythonRunner[IN, OUT](
funcs: Seq[ChainedPythonFunctions],
evalType: Int,
argOffsets: Array[Array[Int]])
protected val funcs: Seq[ChainedPythonFunctions],
protected val evalType: Int,
protected val argOffsets: Array[Array[Int]])
extends Logging {

require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
@@ -17,21 +17,11 @@

package org.apache.spark.sql.execution.python

import java.io._
import java.net._

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter

import org.apache.spark._
import org.apache.spark.api.python._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriter
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.Utils

/**
* Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream.
@@ -40,10 +30,11 @@ class ArrowPythonRunner(
funcs: Seq[ChainedPythonFunctions],
evalType: Int,
argOffsets: Array[Array[Int]],
schema: StructType,
timeZoneId: String,
conf: Map[String, String])
protected override val schema: StructType,
protected override val timeZoneId: String,
protected override val workerConf: Map[String, String])
extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets)
with PythonArrowInput
with PythonArrowOutput {

override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
@@ -53,69 +44,4 @@
bufferSize >= 4,
"Pandas execution requires more than 4 bytes. Please set higher buffer. " +
s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.")

protected override def newWriterThread(
env: SparkEnv,
worker: Socket,
inputIterator: Iterator[Iterator[InternalRow]],
partitionIndex: Int,
context: TaskContext): WriterThread = {
new WriterThread(env, worker, inputIterator, partitionIndex, context) {

protected override def writeCommand(dataOut: DataOutputStream): Unit = {

// Write config for the worker as a number of key -> value pairs of strings
dataOut.writeInt(conf.size)
for ((k, v) <- conf) {
PythonRDD.writeUTF(k, dataOut)
PythonRDD.writeUTF(v, dataOut)
}

PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
}

protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
s"stdout writer for $pythonExec", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)

Utils.tryWithSafeFinally {
val arrowWriter = ArrowWriter.create(root)
val writer = new ArrowStreamWriter(root, null, dataOut)
writer.start()

while (inputIterator.hasNext) {
val nextBatch = inputIterator.next()

while (nextBatch.hasNext) {
arrowWriter.write(nextBatch.next())
}

arrowWriter.finish()
writer.writeBatch()
arrowWriter.reset()
}
// end writes footer to the output stream and doesn't clean any resources.
// It could throw exception if the output stream is closed, so it should be
// in the try block.
writer.end()
} {
// If we close root and allocator in TaskCompletionListener, there could be a race
// condition where the writer thread keeps writing to the VectorSchemaRoot while
// it's being closed by the TaskCompletion listener.
// Closing root and allocator here is cleaner because root and allocator is owned
// by the writer thread and is only visible to the writer thread.
//
// If the writer thread is interrupted by TaskCompletionListener, it should either
// (1) in the try block, in which case it will get an InterruptedException when
// performing io, and goes into the finally block or (2) in the finally block,
// in which case it will ignore the interruption and close the resources.
root.close()
allocator.close()
}
}
}
}

}
@@ -0,0 +1,109 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.python

import java.io.DataOutputStream
import java.net.Socket

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, PythonRDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils

/**
* A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from
* JVM (an iterator of internal rows) to Python (Arrow).
*/
private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[InternalRow], _] =>
protected val workerConf: Map[String, String]

protected val schema: StructType

protected val timeZoneId: String

protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
// Write config for the worker as a number of key -> value pairs of strings
stream.writeInt(workerConf.size)
for ((k, v) <- workerConf) {
PythonRDD.writeUTF(k, stream)
PythonRDD.writeUTF(v, stream)
}
}

protected override def newWriterThread(
env: SparkEnv,
worker: Socket,
inputIterator: Iterator[Iterator[InternalRow]],
partitionIndex: Int,
context: TaskContext): WriterThread = {
new WriterThread(env, worker, inputIterator, partitionIndex, context) {

protected override def writeCommand(dataOut: DataOutputStream): Unit = {
handleMetadataBeforeExec(dataOut)
PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
}

protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
s"stdout writer for $pythonExec", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)

Utils.tryWithSafeFinally {
val arrowWriter = ArrowWriter.create(root)
val writer = new ArrowStreamWriter(root, null, dataOut)
writer.start()

while (inputIterator.hasNext) {
val nextBatch = inputIterator.next()

while (nextBatch.hasNext) {
arrowWriter.write(nextBatch.next())
}

arrowWriter.finish()
writer.writeBatch()
arrowWriter.reset()
}
// end writes footer to the output stream and doesn't clean any resources.
// It could throw exception if the output stream is closed, so it should be
// in the try block.
writer.end()
} {
// If we close root and allocator in TaskCompletionListener, there could be a race
// condition where the writer thread keeps writing to the VectorSchemaRoot while
// it's being closed by the TaskCompletion listener.
// Closing root and allocator here is cleaner because root and allocator is owned
// by the writer thread and is only visible to the writer thread.
//
// If the writer thread is interrupted by TaskCompletionListener, it should either
// (1) in the try block, in which case it will get an InterruptedException when
// performing io, and goes into the finally block or (2) in the finally block,
// in which case it will ignore the interruption and close the resources.
root.close()
allocator.close()
}
}
}
}
}
@@ -37,6 +37,8 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, Column
*/
private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatch] =>

protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { }

protected def newReaderIterator(
stream: DataInputStream,
writerThread: WriterThread,
@@ -67,6 +69,11 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc

private var batchLoaded = true

protected override def handleEndOfDataSection(): Unit = {
handleMetadataAfterExec(stream)
super.handleEndOfDataSection()
}

protected override def read(): ColumnarBatch = {
if (writerThread.exception.isDefined) {
throw writerThread.exception.get
