From 71c81667d9688ee132f43673468390f139ca1a59 Mon Sep 17 00:00:00 2001
From: Firestarman
Date: Fri, 26 Apr 2024 16:52:15 +0800
Subject: [PATCH] Semaphore optimization in scan

Signed-off-by: Firestarman
---
 .../scala/com/nvidia/spark/rapids/GpuMultiFileReader.scala | 7 ++++++-
 .../scala/com/nvidia/spark/rapids/GpuParquetScan.scala     | 7 +++++--
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuMultiFileReader.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuMultiFileReader.scala
index f64ed1097b0..73e34c194b1 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuMultiFileReader.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuMultiFileReader.scala
@@ -639,6 +639,8 @@ abstract class MultiFileCloudPartitionReaderBase(
       return true
     }
 
+    // The read starts with IO operations, so release the GPU semaphore for now.
+    GpuSemaphore.releaseIfNecessary(TaskContext.get())
     // Temporary until we get more to read
     batchIter = EmptyGpuColumnarBatchIterator
     // if we have batch left from the last file read return it
@@ -1031,6 +1033,9 @@ abstract class MultiFileCoalescingPartitionReaderBase(
   def startNewBufferRetry: Unit = ()
 
   private def readBatch(): Iterator[ColumnarBatch] = {
+    val taskContext = TaskContext.get()
+    // The read begins with IO operations, so release the GPU semaphore for now.
+    GpuSemaphore.releaseIfNecessary(taskContext)
     withResource(new NvtxRange(s"$getFileFormatShortName readBatch", NvtxColor.GREEN)) { _ =>
       val currentChunkMeta = populateCurrentBlockChunk()
       val batchIter = if (currentChunkMeta.clippedSchema.isEmpty) {
@@ -1040,7 +1045,7 @@ abstract class MultiFileCoalescingPartitionReaderBase(
       } else {
         val rows = currentChunkMeta.numTotalRows.toInt
         // Someone is going to process this data, even if it is just a row count
-        GpuSemaphore.acquireIfNecessary(TaskContext.get())
+        GpuSemaphore.acquireIfNecessary(taskContext)
         val nullColumns = currentChunkMeta.readSchema.safeMap(f =>
           GpuColumnVector.fromNull(rows, f.dataType).asInstanceOf[SparkVector])
         val emptyBatch = new ColumnarBatch(nullColumns.toArray, rows)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
index 4f140f27bf3..3c089c3f7e2 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
@@ -2783,6 +2783,9 @@ class ParquetPartitionReader(
   }
 
   private def readBatches(): Iterator[ColumnarBatch] = {
+    val taskContext = TaskContext.get()
+    // The read starts with IO operations, so release the GPU semaphore for now.
+    GpuSemaphore.releaseIfNecessary(taskContext)
     withResource(new NvtxRange("Parquet readBatch", NvtxColor.GREEN)) { _ =>
       val currentChunkedBlocks = populateCurrentBlockChunk(blockIterator,
         maxReadBatchSizeRows, maxReadBatchSizeBytes, readDataSchema)
@@ -2793,7 +2796,7 @@ class ParquetPartitionReader(
         EmptyGpuColumnarBatchIterator
       } else {
         // Someone is going to process this data, even if it is just a row count
-        GpuSemaphore.acquireIfNecessary(TaskContext.get())
+        GpuSemaphore.acquireIfNecessary(taskContext)
         val nullColumns = readDataSchema.safeMap(f =>
           GpuColumnVector.fromNull(numRows, f.dataType).asInstanceOf[SparkVector])
         new SingleGpuColumnarBatchIterator(new ColumnarBatch(nullColumns.toArray, numRows))
@@ -2812,7 +2815,7 @@ class ParquetPartitionReader(
         CachedGpuBatchIterator(EmptyTableReader, colTypes)
       } else {
         // about to start using the GPU
-        GpuSemaphore.acquireIfNecessary(TaskContext.get())
+        GpuSemaphore.acquireIfNecessary(taskContext)
         RmmRapidsRetryIterator.withRetryNoSplit(dataBuffer) { _ =>
           // Inc the ref count because MakeParquetTableProducer will try to close the dataBuffer
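
For context, the change applies the same pattern in all three readers: release the
GPU semaphore before the I/O-heavy part of a read, and re-acquire it only when the
data is about to be processed on the GPU, so that tasks waiting on disk or network
do not hold the GPU hostage. A minimal sketch of that pattern follows, assuming
only the GpuSemaphore and TaskContext calls visible in the diff above;
readHostData and decodeOnGpu are hypothetical placeholders, not plugin code:

    import org.apache.spark.TaskContext
    // GpuSemaphore is assumed to be in scope, as in the plugin's
    // com.nvidia.spark.rapids package.

    object SemaphorePatternSketch {
      // Hypothetical stand-ins for the real I/O and GPU decode steps.
      private def readHostData(): Array[Byte] = Array.emptyByteArray
      private def decodeOnGpu(buf: Array[Byte]): Unit = ()

      def readBatchSketch(): Unit = {
        val taskContext = TaskContext.get()
        // The read starts with IO operations, so let other tasks use the GPU.
        GpuSemaphore.releaseIfNecessary(taskContext)
        val hostBuffer = readHostData()
        // Someone is about to process this data on the GPU, so take the
        // semaphore back before touching device memory.
        GpuSemaphore.acquireIfNecessary(taskContext)
        decodeOnGpu(hostBuffer)
      }
    }

Caching TaskContext.get() in a local taskContext, as the patch does in readBatch
and readBatches, makes the release and the later acquire operate on the same
context and avoids repeated thread-local lookups.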