Skip to content

Commit

Permalink
Add tryAcquire to GpuSemaphore (NVIDIA#10330)
Browse files Browse the repository at this point in the history
Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
  • Loading branch information
revans2 authored and sperlingxx committed Feb 2, 2024
1 parent b09ea63 commit 7b9d103
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,6 +30,24 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.GpuTaskMetrics

/**
* The result of trying to acquire a semaphore could be
* `SemaphoreAcquired` or `AcquireFailed`.
*/
sealed trait TryAcquireResult

/**
* The Semaphore was successfully acquired.
*/
case object SemaphoreAcquired extends TryAcquireResult

/**
* To acquire the semaphore this thread would have to block.
* @param numWaitingTasks the number of tasks waiting at the time the request was made.
* Note that this can change very quickly.
*/
case class AcquireFailed(numWaitingTasks: Int) extends TryAcquireResult

object GpuSemaphore {
// DO NOT ACCESS DIRECTLY! Use `getInstance` instead.
@volatile private var instance: GpuSemaphore = _
Expand Down Expand Up @@ -58,6 +76,20 @@ object GpuSemaphore {
instance = new GpuSemaphore()
}

/**
* A thread may try to acquire the semaphore without blocking on it. NOTE: A task completion
* listener will automatically be installed to ensure the semaphore is always released by the
* time the task completes.
*/
def tryAcquire(context: TaskContext): TryAcquireResult = {
if (context != null) {
getInstance.tryAcquire(context)
} else {
// For unit tests that might try with no context
SemaphoreAcquired
}
}

/**
* Tasks must call this when they begin to use the GPU.
* If the task has not already acquired the GPU semaphore then it is acquired,
Expand All @@ -71,14 +103,6 @@ object GpuSemaphore {
}
}

def mayBeAvailable(context: TaskContext): Boolean = {
if (context != null) {
getInstance.anyResourceRemain
} else {
false
}
}

/**
* Tasks must call this when they are finished using the GPU.
*/
Expand Down Expand Up @@ -253,6 +277,28 @@ private final class SemaphoreTaskInfo() extends Logging {
}
}

def tryAcquire(semaphore: Semaphore): Boolean = synchronized {
val t = Thread.currentThread()
if (hasSemaphore) {
activeThreads.add(t)
true
} else {
if (blockedThreads.size() == 0) {
// No other threads for this task are waiting, so we might be able to grab this directly
val ret = semaphore.tryAcquire(numPermits)
if (ret) {
hasSemaphore = true
activeThreads.add(t)
// no need to notify because there are no other threads and we are holding the lock
// to ensure that.
}
ret
} else {
false
}
}
}

def releaseSemaphore(semaphore: Semaphore): Unit = synchronized {
val t = Thread.currentThread()
activeThreads.remove(t)
Expand All @@ -275,7 +321,28 @@ private final class GpuSemaphore() extends Logging {
// Keep track of all tasks that are both active on the GPU and blocked waiting on the GPU
private val tasks = new ConcurrentHashMap[Long, SemaphoreTaskInfo]

private def anyResourceRemain: Boolean = semaphore.availablePermits() > 0
def tryAcquire(context: TaskContext): TryAcquireResult = {
// Make sure that the thread/task is registered before we try and block
TaskRegistryTracker.registerThreadForRetry()
val taskAttemptId = context.taskAttemptId()
val taskInfo = tasks.computeIfAbsent(taskAttemptId, _ => {
onTaskCompletion(context, completeTask)
new SemaphoreTaskInfo()
})
if (taskInfo.tryAcquire(semaphore)) {
GpuDeviceManager.initializeFromTask()
SemaphoreAcquired
} else {
// We need to get the number of tasks that are still waiting
var numWaiting = 0
tasks.values().forEach { ti =>
if (ti.isHoldingSemaphore) {
numWaiting += 1
}
}
AcquireFailed(numWaiting)
}
}

def acquireIfNecessary(context: TaskContext): Unit = {
// Make sure that the thread/task is registered before we try and block
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,19 +29,23 @@ import org.apache.spark.sql.SparkSession

class GpuSemaphoreSuite extends AnyFunSuite
with BeforeAndAfterEach with MockitoSugar with TimeLimits with TimeLimitedTests {
val timeLimit = Span(10, Seconds)
val timeLimit: Span = Span(10, Seconds)

override def beforeEach(): Unit = {
ScalableTaskCompletion.reset()
GpuSemaphore.shutdown()
// semaphore tests depend on a SparkEnv being available
val activeSession = SparkSession.getActiveSession
if (activeSession.isEmpty) {
SparkSession.builder
if (activeSession.isDefined) {
SparkSession.getActiveSession.foreach(_.stop())
SparkSession.clearActiveSession()
}
SparkSession.builder
.appName("semaphoreTests")
.master("local[1]")
// only 1 task at a time so we can verify what blocks and what does not block
.config("spark.rapids.sql.concurrentGpuTasks", "1")
.getOrCreate()
}
}

override def afterEach(): Unit = {
Expand Down Expand Up @@ -79,4 +83,42 @@ class GpuSemaphoreSuite extends AnyFunSuite
GpuSemaphore.acquireIfNecessary(context)
verify(context, times(1)).addTaskCompletionListener[Unit](any())
}

def assertAcquired(result: TryAcquireResult): Unit = result match {
case SemaphoreAcquired => // NOOP
case AcquireFailed(_) =>
fail("The Semaphore was not acquired")
}

def assertNotAcquired(numExpectedWaiting: Int, result: TryAcquireResult): Unit = result match {
case SemaphoreAcquired =>
fail("The Semaphore was acquired when we didn't expect it")
case AcquireFailed(numWaiting) =>
assert(numWaiting == numExpectedWaiting, "The number of waiting tasks didn't match")
}

test("multi tryAcquire") {
GpuDeviceManager.setRmmTaskInitEnabled(false)
val context = mockContext(1)
try {
assertAcquired(GpuSemaphore.tryAcquire(context))
assertAcquired(GpuSemaphore.tryAcquire(context))
} finally {
GpuSemaphore.releaseIfNecessary(context)
}
}

test("tryAcquire non-blocking") {
GpuDeviceManager.setRmmTaskInitEnabled(false)
val context1 = mockContext(1)
val context2 = mockContext(2)
try {
GpuSemaphore.acquireIfNecessary(context1)
assertNotAcquired(1, GpuSemaphore.tryAcquire(context2))
assertNotAcquired(1, GpuSemaphore.tryAcquire(context2))
} finally {
GpuSemaphore.releaseIfNecessary(context1)
GpuSemaphore.releaseIfNecessary(context2)
}
}
}

0 comments on commit 7b9d103

Please # to comment.