Skip to content

Commit 0afe14c

Browse files
committed
Reimplemented WorkflowRunLockManager to fix design flaw with an unsafe unlock
Issue #1146
1 parent 145c0ea commit 0afe14c

File tree

3 files changed

+187
-148
lines changed

3 files changed

+187
-148
lines changed

temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowRunLockManager.java

+63-51
Original file line numberDiff line numberDiff line change
@@ -19,89 +19,101 @@
1919

2020
package io.temporal.internal.worker;
2121

22+
import com.google.common.annotations.VisibleForTesting;
23+
import java.util.Date;
2224
import java.util.HashMap;
23-
import java.util.concurrent.locks.Lock;
25+
import java.util.Map;
26+
import java.util.concurrent.TimeUnit;
27+
import java.util.concurrent.locks.Condition;
2428
import java.util.concurrent.locks.ReentrantLock;
2529

2630
final class WorkflowRunLockManager {
27-
28-
private static class CountableLock {
29-
private final Lock lock = new ReentrantLock();
30-
private int count = 1;
31-
32-
void incrementCount() {
33-
count++;
34-
}
35-
36-
void decrementCount() {
37-
count--;
38-
}
39-
40-
int getCount() {
41-
return count;
42-
}
43-
44-
Lock getLock() {
45-
return lock;
46-
}
47-
}
48-
49-
private final Lock mapLock = new ReentrantLock();
50-
private final HashMap<String, CountableLock> perRunLock = new HashMap<>();
31+
// This is a single lock for the whole worker.
32+
// It's ok, because it's acquired for a very short time.
33+
// If this ever becomes a bottleneck, consider:
34+
// - rework this lock to use com.google.common.util.concurrent.Striped and
35+
// - make `perRunLock` a ConcurrentHashMap to increase parallelism of this code
36+
private final ReentrantLock mapLock = new ReentrantLock();
37+
private final Map<String, LockData> perRunLock = new HashMap<>();
5138

5239
/**
53-
* This method returns a lock that can be used to serialize workflow task processing for a
54-
* particular workflow run. This is used to make sure that query tasks and real workflow tasks are
55-
* serialized when sticky is on.
56-
*
57-
* @param runId
58-
* @return a lock to be used during workflow task processing
40+
* @param runId to take a lock for
41+
* @param time the maximum time to wait for the lock
42+
* @param unit the time unit of the time argument
43+
* @return true if the lock is taken
5944
*/
60-
Lock getLockForLocking(String runId) {
45+
boolean tryLock(String runId, long time, TimeUnit unit) throws InterruptedException {
6146
mapLock.lock();
62-
6347
try {
64-
CountableLock cl = perRunLock.get(runId);
65-
if (cl == null) {
66-
cl = new CountableLock();
67-
perRunLock.put(runId, cl);
68-
} else {
69-
cl.incrementCount();
48+
Date deadline = new Date(System.currentTimeMillis() + unit.toMillis(time));
49+
boolean stillWaiting = true;
50+
LockData lockData = tryPersistLockByCurrentThreadLocked(runId);
51+
while (Thread.currentThread() != lockData.thread) {
52+
if (!stillWaiting) {
53+
return false;
54+
}
55+
stillWaiting = lockData.condition.awaitUntil(deadline);
56+
lockData = tryPersistLockByCurrentThreadLocked(runId);
7057
}
7158

72-
return cl.getLock();
59+
lockData.count++;
60+
return true;
7361
} finally {
7462
mapLock.unlock();
7563
}
7664
}
7765

66+
private LockData tryPersistLockByCurrentThreadLocked(String runId) {
67+
return perRunLock.computeIfAbsent(
68+
runId, id -> new LockData(Thread.currentThread(), mapLock.newCondition()));
69+
}
70+
7871
void unlock(String runId) {
7972
mapLock.lock();
80-
8173
try {
82-
CountableLock cl = perRunLock.get(runId);
83-
if (cl == null) {
84-
throw new RuntimeException("lock for run " + runId + " does not exist.");
74+
LockData lockData = perRunLock.get(runId);
75+
if (lockData == null) {
76+
throw new IllegalStateException("Lock for " + runId + " is not taken");
8577
}
86-
87-
cl.decrementCount();
88-
if (cl.getCount() == 0) {
89-
perRunLock.remove(runId);
78+
if (Thread.currentThread() == lockData.thread) {
79+
lockData.count--;
80+
if (lockData.count == 0) {
81+
perRunLock.remove(runId);
82+
// it's important to signal all threads,
83+
// otherwise n-1 of them will stuck waiting on a condition that is not in the map already
84+
lockData.condition.signalAll();
85+
}
86+
} else {
87+
throw new IllegalStateException(
88+
"Lock for "
89+
+ runId
90+
+ " is not acquired by the current thread "
91+
+ Thread.currentThread().getName());
9092
}
91-
92-
cl.getLock().unlock();
9393
} finally {
9494
mapLock.unlock();
9595
}
9696
}
9797

98+
@VisibleForTesting
9899
int totalLocks() {
99100
mapLock.lock();
100-
101101
try {
102102
return perRunLock.size();
103103
} finally {
104104
mapLock.unlock();
105105
}
106106
}
107+
108+
private static class LockData {
109+
final Thread thread;
110+
final Condition condition;
111+
// to make lock reentrant
112+
int count = 0;
113+
114+
public LockData(Thread thread, Condition condition) {
115+
this.thread = thread;
116+
this.condition = condition;
117+
}
118+
}
107119
}

0 commit comments

Comments
 (0)