|
19 | 19 |
|
20 | 20 | package io.temporal.internal.worker;
|
21 | 21 |
|
| 22 | +import com.google.common.annotations.VisibleForTesting; |
| 23 | +import java.util.Date; |
22 | 24 | import java.util.HashMap;
|
23 |
| -import java.util.concurrent.locks.Lock; |
| 25 | +import java.util.Map; |
| 26 | +import java.util.concurrent.TimeUnit; |
| 27 | +import java.util.concurrent.locks.Condition; |
24 | 28 | import java.util.concurrent.locks.ReentrantLock;
|
25 | 29 |
|
26 | 30 | final class WorkflowRunLockManager {
|
27 |
| - |
28 |
| - private static class CountableLock { |
29 |
| - private final Lock lock = new ReentrantLock(); |
30 |
| - private int count = 1; |
31 |
| - |
32 |
| - void incrementCount() { |
33 |
| - count++; |
34 |
| - } |
35 |
| - |
36 |
| - void decrementCount() { |
37 |
| - count--; |
38 |
| - } |
39 |
| - |
40 |
| - int getCount() { |
41 |
| - return count; |
42 |
| - } |
43 |
| - |
44 |
| - Lock getLock() { |
45 |
| - return lock; |
46 |
| - } |
47 |
| - } |
48 |
| - |
49 |
| - private final Lock mapLock = new ReentrantLock(); |
50 |
| - private final HashMap<String, CountableLock> perRunLock = new HashMap<>(); |
| 31 | + // This is a single lock for the whole worker. |
| 32 | + // It's ok, because it's acquired for a very short time. |
| 33 | + // If this ever becomes a bottleneck, consider: |
| 34 | + // - rework this lock to use com.google.common.util.concurrent.Striped and |
| 35 | + // - make `perRunLock` a ConcurrentHashMap to increase parallelism of this code |
| 36 | + private final ReentrantLock mapLock = new ReentrantLock(); |
| 37 | + private final Map<String, LockData> perRunLock = new HashMap<>(); |
51 | 38 |
|
52 | 39 | /**
|
53 |
| - * This method returns a lock that can be used to serialize workflow task processing for a |
54 |
| - * particular workflow run. This is used to make sure that query tasks and real workflow tasks are |
55 |
| - * serialized when sticky is on. |
56 |
| - * |
57 |
| - * @param runId |
58 |
| - * @return a lock to be used during workflow task processing |
| 40 | + * @param runId to take a lock for |
| 41 | + * @param time the maximum time to wait for the lock |
| 42 | + * @param unit the time unit of the time argument |
| 43 | + * @return true if the lock is taken |
59 | 44 | */
|
60 |
| - Lock getLockForLocking(String runId) { |
| 45 | + boolean tryLock(String runId, long time, TimeUnit unit) throws InterruptedException { |
61 | 46 | mapLock.lock();
|
62 |
| - |
63 | 47 | try {
|
64 |
| - CountableLock cl = perRunLock.get(runId); |
65 |
| - if (cl == null) { |
66 |
| - cl = new CountableLock(); |
67 |
| - perRunLock.put(runId, cl); |
68 |
| - } else { |
69 |
| - cl.incrementCount(); |
| 48 | + Date deadline = new Date(System.currentTimeMillis() + unit.toMillis(time)); |
| 49 | + boolean stillWaiting = true; |
| 50 | + LockData lockData = tryPersistLockByCurrentThreadLocked(runId); |
| 51 | + while (Thread.currentThread() != lockData.thread) { |
| 52 | + if (!stillWaiting) { |
| 53 | + return false; |
| 54 | + } |
| 55 | + stillWaiting = lockData.condition.awaitUntil(deadline); |
| 56 | + lockData = tryPersistLockByCurrentThreadLocked(runId); |
70 | 57 | }
|
71 | 58 |
|
72 |
| - return cl.getLock(); |
| 59 | + lockData.count++; |
| 60 | + return true; |
73 | 61 | } finally {
|
74 | 62 | mapLock.unlock();
|
75 | 63 | }
|
76 | 64 | }
|
77 | 65 |
|
| 66 | + private LockData tryPersistLockByCurrentThreadLocked(String runId) { |
| 67 | + return perRunLock.computeIfAbsent( |
| 68 | + runId, id -> new LockData(Thread.currentThread(), mapLock.newCondition())); |
| 69 | + } |
| 70 | + |
78 | 71 | void unlock(String runId) {
|
79 | 72 | mapLock.lock();
|
80 |
| - |
81 | 73 | try {
|
82 |
| - CountableLock cl = perRunLock.get(runId); |
83 |
| - if (cl == null) { |
84 |
| - throw new RuntimeException("lock for run " + runId + " does not exist."); |
| 74 | + LockData lockData = perRunLock.get(runId); |
| 75 | + if (lockData == null) { |
| 76 | + throw new IllegalStateException("Lock for " + runId + " is not taken"); |
85 | 77 | }
|
86 |
| - |
87 |
| - cl.decrementCount(); |
88 |
| - if (cl.getCount() == 0) { |
89 |
| - perRunLock.remove(runId); |
| 78 | + if (Thread.currentThread() == lockData.thread) { |
| 79 | + lockData.count--; |
| 80 | + if (lockData.count == 0) { |
| 81 | + perRunLock.remove(runId); |
| 82 | + // it's important to signal all threads, |
| 83 | + // otherwise n-1 of them will stuck waiting on a condition that is not in the map already |
| 84 | + lockData.condition.signalAll(); |
| 85 | + } |
| 86 | + } else { |
| 87 | + throw new IllegalStateException( |
| 88 | + "Lock for " |
| 89 | + + runId |
| 90 | + + " is not acquired by the current thread " |
| 91 | + + Thread.currentThread().getName()); |
90 | 92 | }
|
91 |
| - |
92 |
| - cl.getLock().unlock(); |
93 | 93 | } finally {
|
94 | 94 | mapLock.unlock();
|
95 | 95 | }
|
96 | 96 | }
|
97 | 97 |
|
| 98 | + @VisibleForTesting |
98 | 99 | int totalLocks() {
|
99 | 100 | mapLock.lock();
|
100 |
| - |
101 | 101 | try {
|
102 | 102 | return perRunLock.size();
|
103 | 103 | } finally {
|
104 | 104 | mapLock.unlock();
|
105 | 105 | }
|
106 | 106 | }
|
| 107 | + |
| 108 | + private static class LockData { |
| 109 | + final Thread thread; |
| 110 | + final Condition condition; |
| 111 | + // to make lock reentrant |
| 112 | + int count = 0; |
| 113 | + |
| 114 | + public LockData(Thread thread, Condition condition) { |
| 115 | + this.thread = thread; |
| 116 | + this.condition = condition; |
| 117 | + } |
| 118 | + } |
107 | 119 | }
|
0 commit comments