
Commit 2e1467f

[SPARK-39840][SQL][PYTHON] Factor PythonArrowInput out as a symmetry to PythonArrowOutput
### What changes were proposed in this pull request?

This PR factors the Arrow input code path out as `PythonArrowInput`, as a symmetry to `PythonArrowOutput`. The current hierarchy is not affected:

```
└── BasePythonRunner
    ├── ArrowPythonRunner with PythonArrowOutput with PythonArrowInput
    ├── CoGroupedArrowPythonRunner with PythonArrowOutput
    ├── PythonRunner
    └── PythonUDFRunner
```

In addition, this PR factors out `handleMetadataAfterExec` and `handleMetadataBeforeExec`, which contain the logic to send and receive metadata, such as the runtime configurations specific to Arrow in/out.

### Why are the changes needed?

40485f4 factored `PythonArrowOutput` out. It is better to factor `PythonArrowInput` out too, for consistency.

### Does this PR introduce _any_ user-facing change?

No, this is a refactoring.

### How was this patch tested?

Existing test cases should cover this.

Closes #37253 from HyukjinKwon/pyarrow-output-trait.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 1b6cdf1 commit 2e1467f
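
To make the refactored composition concrete, here is a minimal sketch of how an Arrow-based runner mixes in both traits after this change. The subclass name `MyArrowRunner` is hypothetical, and the sketch assumes it lives in the `org.apache.spark.sql.execution.python` package (the traits are `private[python]`); everything else matches the diffs below.

```scala
import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.vectorized.ColumnarBatch

// Illustrative only: a hypothetical Arrow-based runner. It stays abstract because
// PythonArrowInput still requires `schema`, `timeZoneId` and `workerConf` to be defined.
abstract class MyArrowRunner(
    funcs: Seq[ChainedPythonFunctions],
    evalType: Int,
    argOffsets: Array[Array[Int]])
  extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets)
  with PythonArrowInput   // JVM -> Python: newWriterThread plus handleMetadataBeforeExec
  with PythonArrowOutput  // Python -> JVM: newReaderIterator plus handleMetadataAfterExec
```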

File tree

4 files changed: +123 -81 lines


core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -84,9 +84,9 @@ private object BasePythonRunner {
  * functions (from bottom to top).
  */
 private[spark] abstract class BasePythonRunner[IN, OUT](
-    funcs: Seq[ChainedPythonFunctions],
-    evalType: Int,
-    argOffsets: Array[Array[Int]])
+    protected val funcs: Seq[ChainedPythonFunctions],
+    protected val evalType: Int,
+    protected val argOffsets: Array[Array[Int]])
   extends Logging {
 
   require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
```

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala

Lines changed: 4 additions & 78 deletions
```diff
@@ -17,21 +17,11 @@
 
 package org.apache.spark.sql.execution.python
 
-import java.io._
-import java.net._
-
-import org.apache.arrow.vector.VectorSchemaRoot
-import org.apache.arrow.vector.ipc.ArrowStreamWriter
-
-import org.apache.spark._
 import org.apache.spark.api.python._
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.arrow.ArrowWriter
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.ColumnarBatch
-import org.apache.spark.util.Utils
 
 /**
  * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream.
@@ -40,10 +30,11 @@ class ArrowPythonRunner(
     funcs: Seq[ChainedPythonFunctions],
     evalType: Int,
     argOffsets: Array[Array[Int]],
-    schema: StructType,
-    timeZoneId: String,
-    conf: Map[String, String])
+    protected override val schema: StructType,
+    protected override val timeZoneId: String,
+    protected override val workerConf: Map[String, String])
   extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets)
+  with PythonArrowInput
   with PythonArrowOutput {
 
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
@@ -53,69 +44,4 @@ class ArrowPythonRunner(
     bufferSize >= 4,
     "Pandas execution requires more than 4 bytes. Please set higher buffer. " +
       s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.")
-
-  protected override def newWriterThread(
-      env: SparkEnv,
-      worker: Socket,
-      inputIterator: Iterator[Iterator[InternalRow]],
-      partitionIndex: Int,
-      context: TaskContext): WriterThread = {
-    new WriterThread(env, worker, inputIterator, partitionIndex, context) {
-
-      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-
-        // Write config for the worker as a number of key -> value pairs of strings
-        dataOut.writeInt(conf.size)
-        for ((k, v) <- conf) {
-          PythonRDD.writeUTF(k, dataOut)
-          PythonRDD.writeUTF(v, dataOut)
-        }
-
-        PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
-      }
-
-      protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
-        val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
-        val allocator = ArrowUtils.rootAllocator.newChildAllocator(
-          s"stdout writer for $pythonExec", 0, Long.MaxValue)
-        val root = VectorSchemaRoot.create(arrowSchema, allocator)
-
-        Utils.tryWithSafeFinally {
-          val arrowWriter = ArrowWriter.create(root)
-          val writer = new ArrowStreamWriter(root, null, dataOut)
-          writer.start()
-
-          while (inputIterator.hasNext) {
-            val nextBatch = inputIterator.next()
-
-            while (nextBatch.hasNext) {
-              arrowWriter.write(nextBatch.next())
-            }
-
-            arrowWriter.finish()
-            writer.writeBatch()
-            arrowWriter.reset()
-          }
-          // end writes footer to the output stream and doesn't clean any resources.
-          // It could throw exception if the output stream is closed, so it should be
-          // in the try block.
-          writer.end()
-        } {
-          // If we close root and allocator in TaskCompletionListener, there could be a race
-          // condition where the writer thread keeps writing to the VectorSchemaRoot while
-          // it's being closed by the TaskCompletion listener.
-          // Closing root and allocator here is cleaner because root and allocator is owned
-          // by the writer thread and is only visible to the writer thread.
-          //
-          // If the writer thread is interrupted by TaskCompletionListener, it should either
-          // (1) in the try block, in which case it will get an InterruptedException when
-          // performing io, and goes into the finally block or (2) in the finally block,
-          // in which case it will ignore the interruption and close the resources.
-          root.close()
-          allocator.close()
-        }
-      }
-    }
-  }
-
 }
```
sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala

Lines changed: 109 additions & 0 deletions

```diff
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import java.io.DataOutputStream
+import java.net.Socket
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.api.python.{BasePythonRunner, PythonRDD}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.util.Utils
+
+/**
+ * A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from
+ * JVM (an iterator of internal rows) to Python (Arrow).
+ */
+private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[InternalRow], _] =>
+  protected val workerConf: Map[String, String]
+
+  protected val schema: StructType
+
+  protected val timeZoneId: String
+
+  protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
+    // Write config for the worker as a number of key -> value pairs of strings
+    stream.writeInt(workerConf.size)
+    for ((k, v) <- workerConf) {
+      PythonRDD.writeUTF(k, stream)
+      PythonRDD.writeUTF(v, stream)
+    }
+  }
+
+  protected override def newWriterThread(
+      env: SparkEnv,
+      worker: Socket,
+      inputIterator: Iterator[Iterator[InternalRow]],
+      partitionIndex: Int,
+      context: TaskContext): WriterThread = {
+    new WriterThread(env, worker, inputIterator, partitionIndex, context) {
+
+      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
+        handleMetadataBeforeExec(dataOut)
+        PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+      }
+
+      protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
+        val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+        val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+          s"stdout writer for $pythonExec", 0, Long.MaxValue)
+        val root = VectorSchemaRoot.create(arrowSchema, allocator)
+
+        Utils.tryWithSafeFinally {
+          val arrowWriter = ArrowWriter.create(root)
+          val writer = new ArrowStreamWriter(root, null, dataOut)
+          writer.start()
+
+          while (inputIterator.hasNext) {
+            val nextBatch = inputIterator.next()
+
+            while (nextBatch.hasNext) {
+              arrowWriter.write(nextBatch.next())
+            }
+
+            arrowWriter.finish()
+            writer.writeBatch()
+            arrowWriter.reset()
+          }
+          // end writes footer to the output stream and doesn't clean any resources.
+          // It could throw exception if the output stream is closed, so it should be
+          // in the try block.
+          writer.end()
+        } {
+          // If we close root and allocator in TaskCompletionListener, there could be a race
+          // condition where the writer thread keeps writing to the VectorSchemaRoot while
+          // it's being closed by the TaskCompletion listener.
+          // Closing root and allocator here is cleaner because root and allocator is owned
+          // by the writer thread and is only visible to the writer thread.
+          //
+          // If the writer thread is interrupted by TaskCompletionListener, it should either
+          // (1) in the try block, in which case it will get an InterruptedException when
+          // performing io, and goes into the finally block or (2) in the finally block,
+          // in which case it will ignore the interruption and close the resources.
+          root.close()
+          allocator.close()
+        }
+      }
+    }
+  }
+}
```

sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala

Lines changed: 7 additions & 0 deletions
```diff
@@ -37,6 +37,8 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
  */
 private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatch] =>
 
+  protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { }
+
   protected def newReaderIterator(
       stream: DataInputStream,
       writerThread: WriterThread,
@@ -67,6 +69,11 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatch] =>
 
       private var batchLoaded = true
 
+      protected override def handleEndOfDataSection(): Unit = {
+        handleMetadataAfterExec(stream)
+        super.handleEndOfDataSection()
+      }
+
       protected override def read(): ColumnarBatch = {
         if (writerThread.exception.isDefined) {
           throw writerThread.exception.get
```
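
As a usage illustration, here is a hedged sketch of how a runner built on both traits could override the two hooks factored out in this commit to exchange extra metadata with the Python worker. Only the hook signatures and their default behaviour (worker conf written before execution, no-op after the data section) come from this commit; the runner name and the extra boolean values are made up.

```scala
package org.apache.spark.sql.execution.python  // assumed: the traits are private[python]

import java.io.{DataInputStream, DataOutputStream}

import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical runner, shown only to illustrate where the hooks sit in the protocol.
abstract class MetadataAwareArrowRunner(
    funcs: Seq[ChainedPythonFunctions],
    evalType: Int,
    argOffsets: Array[Array[Int]])
  extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets)
  with PythonArrowInput
  with PythonArrowOutput {

  // Called from writeCommand, before the serialized UDFs are sent to the worker.
  protected override def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
    super.handleMetadataBeforeExec(stream)  // writes workerConf as key/value string pairs
    stream.writeBoolean(true)               // hypothetical extra flag for the worker
  }

  // Called from handleEndOfDataSection, after the Arrow data section has been consumed.
  protected override def handleMetadataAfterExec(stream: DataInputStream): Unit = {
    val workerAck = stream.readBoolean()    // hypothetical value the worker would write back
    super.handleMetadataAfterExec(stream)   // the default implementation is a no-op
  }
}
```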
