Commit 295dd57

[SPARK-40235][CORE] Use interruptible lock instead of synchronized in Executor.updateDependencies()
### What changes were proposed in this pull request?

This patch modifies the synchronization in `Executor.updateDependencies()` so that tasks can be interrupted while they are blocked waiting on other tasks to finish downloading dependencies.

This synchronization was added years ago in mesos/spark@7b9e96c in order to prevent concurrently-launching tasks from performing concurrent dependency updates. If one task is downloading dependencies, all other newly-launched tasks block until the original dependency download is complete.

Suppose a Spark task launches, becomes blocked on an `updateDependencies()` call, and is then cancelled while it is blocked. Although Spark sends a `Thread.interrupt()` to the cancelled task, the task keeps waiting: a thread blocked entering a `synchronized` block does not throw an InterruptedException in response to the interrupt. As a result, the blocked thread continues to wait until the other thread exits the synchronized block.

This PR fixes the problem by replacing the `synchronized` with a `ReentrantLock`, which has a `lockInterruptibly` method.

### Why are the changes needed?

In a real-world scenario, we hit a case where a task was cancelled right after being launched while another task was blocked in a slow library download. The slow library download took so long that the TaskReaper killed the executor because the cancelled task could not exit in a timely fashion. This patch's fix prevents that issue.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New unit test case.

Closes #37681 from JoshRosen/SPARK-40235-update-dependencies-lock.

Authored-by: Josh Rosen <joshrosen@databricks.com>
Signed-off-by: Josh Rosen <joshrosen@databricks.com>
1 parent c95ed82 · commit 295dd57
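To make the behavior described above concrete, here is a minimal standalone sketch, not part of this commit (the object name `InterruptibleLockDemo` is invented for illustration). A thread parked on `ReentrantLock.lockInterruptibly()` responds to `Thread.interrupt()` by throwing `InterruptedException`, whereas a thread blocked entering a `synchronized` block would ignore the interrupt until the monitor is released:

```scala
import java.util.concurrent.locks.ReentrantLock

object InterruptibleLockDemo {
  def main(args: Array[String]): Unit = {
    val lock = new ReentrantLock()
    lock.lock() // the main thread holds the lock, standing in for a slow dependency download

    val waiter = new Thread(() => {
      try {
        // Unlike entering a synchronized block, this wait can be interrupted:
        lock.lockInterruptibly()
        try {
          // ... critical section (dependency download) ...
        } finally {
          lock.unlock()
        }
      } catch {
        case _: InterruptedException =>
          println("Interrupted while waiting for the lock; exiting promptly")
      }
    })
    waiter.start()

    Thread.sleep(100)  // let the waiter park on the lock
    waiter.interrupt() // a thread stuck on `synchronized` would ignore this
    waiter.join()      // returns quickly because the wait was interruptible
    lock.unlock()
  }
}
```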

2 files changed: +72 −3 lines changed

core/src/main/scala/org/apache/spark/executor/Executor.scala

+19 −3

```diff
@@ -25,6 +25,7 @@ import java.nio.ByteBuffer
 import java.util.{Locale, Properties}
 import java.util.concurrent._
 import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.locks.ReentrantLock
 import javax.annotation.concurrent.GuardedBy
 import javax.ws.rs.core.UriBuilder
 
@@ -85,6 +86,11 @@ private[spark] class Executor(
 
   private[executor] val conf = env.conf
 
+  // SPARK-40235: updateDependencies() uses a ReentrantLock instead of the `synchronized` keyword
+  // so that tasks can exit quickly if they are interrupted while waiting on another task to
+  // finish downloading dependencies.
+  private val updateDependenciesLock = new ReentrantLock()
+
   // No ip or host:port - just hostname
   Utils.checkHost(executorHostname)
   // must not have port specified.
@@ -978,13 +984,19 @@
   /**
    * Download any missing dependencies if we receive a new set of files and JARs from the
    * SparkContext. Also adds any new JARs we fetched to the class loader.
+   * Visible for testing.
    */
-  private def updateDependencies(
+  private[executor] def updateDependencies(
       newFiles: Map[String, Long],
       newJars: Map[String, Long],
-      newArchives: Map[String, Long]): Unit = {
+      newArchives: Map[String, Long],
+      testStartLatch: Option[CountDownLatch] = None,
+      testEndLatch: Option[CountDownLatch] = None): Unit = {
     lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
-    synchronized {
+    updateDependenciesLock.lockInterruptibly()
+    try {
+      // For testing, so we can simulate a slow file download:
+      testStartLatch.foreach(_.countDown())
       // Fetch missing dependencies
       for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
         logInfo(s"Fetching $name with timestamp $timestamp")
@@ -1027,6 +1039,10 @@
           }
         }
       }
+      // For testing, so we can simulate a slow file download:
+      testEndLatch.foreach(_.await())
+    } finally {
+      updateDependenciesLock.unlock()
     }
   }
```
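One detail worth noting in the hunk above: `lockInterruptibly()` is invoked before the `try` block, so if the wait is interrupted, the `InterruptedException` propagates before the `try` is entered and the `finally` never calls `unlock()` on a lock the thread does not hold. A sketch of that general shape (the helper `withInterruptibleLock` is hypothetical, not part of the commit):

```scala
import java.util.concurrent.locks.ReentrantLock

object LockUtil {
  // Hypothetical helper illustrating the acquire/try/finally shape above:
  // acquiring *outside* the try means a failed (interrupted) acquire never
  // reaches the finally, so we never unlock a lock we do not hold.
  def withInterruptibleLock[T](lock: ReentrantLock)(body: => T): T = {
    lock.lockInterruptibly()
    try body
    finally lock.unlock()
  }
}
```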

core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala

+53
```diff
@@ -514,6 +514,59 @@
     }
   }
 
+  test("SPARK-40235: updateDependencies is interruptible when waiting on lock") {
+    val conf = new SparkConf
+    val serializer = new JavaSerializer(conf)
+    val env = createMockEnv(conf, serializer)
+    withExecutor("id", "localhost", env) { executor =>
+      val startLatch = new CountDownLatch(1)
+      val endLatch = new CountDownLatch(1)
+
+      // Start a thread to simulate a task that begins executing updateDependencies()
+      // and takes a long time to finish because file download is slow:
+      val slowLibraryDownloadThread = new Thread(() => {
+        executor.updateDependencies(
+          Map.empty,
+          Map.empty,
+          Map.empty,
+          Some(startLatch),
+          Some(endLatch))
+      })
+      slowLibraryDownloadThread.start()
+
+      // Wait for that thread to acquire the lock:
+      startLatch.await()
+
+      // Start a second thread to simulate a task that blocks on the other task's
+      // dependency update:
+      val blockedLibraryDownloadThread = new Thread(() => {
+        executor.updateDependencies(
+          Map.empty,
+          Map.empty,
+          Map.empty)
+      })
+      blockedLibraryDownloadThread.start()
+      eventually(timeout(10.seconds), interval(100.millis)) {
+        val threadState = blockedLibraryDownloadThread.getState
+        assert(Set(Thread.State.BLOCKED, Thread.State.WAITING).contains(threadState))
+      }
+
+      // Interrupt the blocked thread:
+      blockedLibraryDownloadThread.interrupt()
+
+      // The thread should exit:
+      eventually(timeout(10.seconds), interval(100.millis)) {
+        assert(blockedLibraryDownloadThread.getState == Thread.State.TERMINATED)
+      }
+
+      // Allow the first thread to finish and exit:
+      endLatch.countDown()
+      eventually(timeout(10.seconds), interval(100.millis)) {
+        assert(slowLibraryDownloadThread.getState == Thread.State.TERMINATED)
+      }
+    }
+  }
+
   private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
     val mockEnv = mock[SparkEnv]
     val mockRpcEnv = mock[RpcEnv]
```
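One detail worth noting in the test above: the first assertion accepts either `Thread.State.BLOCKED` or `Thread.State.WAITING`. A thread contending on a `synchronized` monitor reports `BLOCKED`, while a thread parked inside a `ReentrantLock` (which waits via `LockSupport.park`) reports `WAITING`. A minimal standalone sketch showing the `WAITING` state (not part of the commit; `ThreadStateDemo` is an invented name, and the 200ms sleep is an assumption that gives the waiter time to park):

```scala
import java.util.concurrent.locks.ReentrantLock

object ThreadStateDemo {
  def main(args: Array[String]): Unit = {
    val lock = new ReentrantLock()
    lock.lock() // the main thread holds the lock

    val waiter = new Thread(() => {
      try lock.lockInterruptibly()
      catch { case _: InterruptedException => () } // expected once interrupted below
    })
    waiter.start()

    Thread.sleep(200) // give the waiter time to park on the lock
    // A thread parked inside ReentrantLock reports WAITING (it parks via
    // LockSupport); a thread contending on a `synchronized` monitor would
    // report BLOCKED instead.
    println(waiter.getState) // expected: WAITING

    waiter.interrupt()
    waiter.join()
    lock.unlock()
  }
}
```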
