Skip to content

Commit 9c0cc84

Browse files
committed
Fix caching in WorkflowLocal/WorkflowThreadLocal (temporalio#1876)
Reverted caching changes made to WorkflowLocal/WorkflowThreadLocal, which broke backwards compatibility and accidentally shared values between Workflows/Threads. Re-implemented caching as an optional feature, and deprecated the factory methods that created non-caching instances.
1 parent 717ee05 commit 9c0cc84

File tree

7 files changed

+221
-36
lines changed

7 files changed

+221
-36
lines changed

temporal-sdk/src/main/java/io/temporal/internal/context/ContextThreadLocal.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
public class ContextThreadLocal {
3333

3434
private static final WorkflowThreadLocal<List<ContextPropagator>> contextPropagators =
35-
WorkflowThreadLocal.withInitial(
35+
WorkflowThreadLocal.withCachedInitial(
3636
new Supplier<List<ContextPropagator>>() {
3737
@Override
3838
public List<ContextPropagator> get() {

temporal-sdk/src/main/java/io/temporal/internal/sync/RunnerLocalInternal.java

+15-13
Original file line numberDiff line numberDiff line change
@@ -24,24 +24,26 @@
2424
import java.util.function.Supplier;
2525

2626
public final class RunnerLocalInternal<T> {
27-
private T supplierResult = null;
28-
private boolean supplierCalled = false;
29-
30-
Optional<T> invokeSupplier(Supplier<? extends T> supplier) {
31-
if (!supplierCalled) {
32-
T result = supplier.get();
33-
supplierCalled = true;
34-
supplierResult = result;
35-
return Optional.ofNullable(result);
36-
} else {
37-
return Optional.ofNullable(supplierResult);
38-
}
27+
28+
private final boolean useCaching;
29+
30+
public RunnerLocalInternal() {
31+
this.useCaching = false;
32+
}
33+
34+
public RunnerLocalInternal(boolean useCaching) {
35+
this.useCaching = useCaching;
3936
}
4037

4138
public T get(Supplier<? extends T> supplier) {
4239
Optional<Optional<T>> result =
4340
DeterministicRunnerImpl.currentThreadInternal().getRunner().getRunnerLocal(this);
44-
return result.orElseGet(() -> invokeSupplier(supplier)).orElse(null);
41+
T out = result.orElseGet(() -> Optional.ofNullable(supplier.get())).orElse(null);
42+
if (!result.isPresent() && useCaching) {
43+
// This is the first time we've tried fetching this, and caching is enabled. Store it.
44+
set(out);
45+
}
46+
return out;
4547
}
4648

4749
public void set(T value) {

temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThreadLocalInternal.java

+14-13
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,25 @@
2525

2626
public final class WorkflowThreadLocalInternal<T> {
2727

28-
private T supplierResult = null;
29-
private boolean supplierCalled = false;
30-
31-
Optional<T> invokeSupplier(Supplier<? extends T> supplier) {
32-
if (!supplierCalled) {
33-
T result = supplier.get();
34-
supplierCalled = true;
35-
supplierResult = result;
36-
return Optional.ofNullable(result);
37-
} else {
38-
return Optional.ofNullable(supplierResult);
39-
}
28+
private final boolean useCaching;
29+
30+
public WorkflowThreadLocalInternal() {
31+
this(false);
32+
}
33+
34+
public WorkflowThreadLocalInternal(boolean useCaching) {
35+
this.useCaching = useCaching;
4036
}
4137

4238
public T get(Supplier<? extends T> supplier) {
4339
Optional<Optional<T>> result =
4440
DeterministicRunnerImpl.currentThreadInternal().getThreadLocal(this);
45-
return result.orElseGet(() -> invokeSupplier(supplier)).orElse(null);
41+
T out = result.orElseGet(() -> Optional.ofNullable(supplier.get())).orElse(null);
42+
if (!result.isPresent() && useCaching) {
43+
// This is the first time we've tried fetching this, and caching is enabled. Store it.
44+
set(out);
45+
}
46+
return out;
4647
}
4748

4849
public void set(T value) {

temporal-sdk/src/main/java/io/temporal/workflow/WorkflowLocal.java

+36-4
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
* <pre>{@code
3434
* public class Workflow {
3535
*
36-
* private static final WorkflowLocal<Boolean> signaled = WorkflowLocal.withInitial(() -> false);
36+
* private static final WorkflowLocal<Boolean> signaled = WorkflowLocal.withCachedInitial(() -> false);
3737
*
3838
* public static boolean isSignaled() {
3939
* return signaled.get();
@@ -49,19 +49,51 @@
4949
*/
5050
public final class WorkflowLocal<T> {
5151

52-
private final RunnerLocalInternal<T> impl = new RunnerLocalInternal<>();
52+
private final RunnerLocalInternal<T> impl;
5353
private final Supplier<? extends T> supplier;
5454

55-
private WorkflowLocal(Supplier<? extends T> supplier) {
55+
private WorkflowLocal(Supplier<? extends T> supplier, boolean useCaching) {
5656
this.supplier = Objects.requireNonNull(supplier);
57+
this.impl = new RunnerLocalInternal<>(useCaching);
5758
}
5859

5960
public WorkflowLocal() {
6061
this.supplier = () -> null;
62+
this.impl = new RunnerLocalInternal<>(false);
6163
}
6264

65+
/**
66+
* Create an instance that returns the value returned by the given {@code Supplier} when {@link
67+
* #set(S)} has not yet been called in the Workflow. Note that the value returned by the {@code
68+
* Supplier} is not stored in the {@code WorkflowLocal} implicitly; repeatedly calling {@link
69+
* #get()} will always re-execute the {@code Supplier} until you call {@link #set(S)} for the
70+
* first time. If you want the value returned by the {@code Supplier} to be stored in the {@code
71+
* WorkflowLocal}, use {@link #withCachedInitial(Supplier)} instead.
72+
*
73+
* @param supplier Callback that will be executed whenever {@link #get()} is called, until {@link
74+
* #set(S)} is called for the first time.
75+
* @return A {@code WorkflowLocal} instance.
76+
* @param <S> The type stored in the {@code WorkflowLocal}.
77+
* @deprecated Because the non-caching behavior of this API is typically not desirable, it's
78+
* recommend to use {@link #withCachedInitial(Supplier)} instead.
79+
*/
80+
@Deprecated
6381
public static <S> WorkflowLocal<S> withInitial(Supplier<? extends S> supplier) {
64-
return new WorkflowLocal<>(supplier);
82+
return new WorkflowLocal<>(supplier, false);
83+
}
84+
85+
/**
86+
* Create an instance that returns the value returned by the given {@code Supplier} when {@link
87+
* #set(S)} has not yet been called in the Workflow, and then stores the returned value inside the
88+
* {@code WorkflowLocal}.
89+
*
90+
* @param supplier Callback that will be executed when {@link #get()} is called for the first
91+
* time, if {@link #set(S)} has not already been called.
92+
* @return A {@code WorkflowLocal} instance.
93+
* @param <S> The type stored in the {@code WorkflowLocal}.
94+
*/
95+
public static <S> WorkflowLocal<S> withCachedInitial(Supplier<? extends S> supplier) {
96+
return new WorkflowLocal<>(supplier, true);
6597
}
6698

6799
public T get() {

temporal-sdk/src/main/java/io/temporal/workflow/WorkflowThreadLocal.java

+36-3
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,52 @@
2727
/** {@link ThreadLocal} analog for workflow code. */
2828
public final class WorkflowThreadLocal<T> {
2929

30-
private final WorkflowThreadLocalInternal<T> impl = new WorkflowThreadLocalInternal<>();
30+
private final WorkflowThreadLocalInternal<T> impl;
3131
private final Supplier<? extends T> supplier;
3232

33-
private WorkflowThreadLocal(Supplier<? extends T> supplier) {
33+
private WorkflowThreadLocal(Supplier<? extends T> supplier, boolean useCaching) {
3434
this.supplier = Objects.requireNonNull(supplier);
35+
this.impl = new WorkflowThreadLocalInternal<>(useCaching);
3536
}
3637

3738
public WorkflowThreadLocal() {
3839
this.supplier = () -> null;
40+
this.impl = new WorkflowThreadLocalInternal<>(false);
3941
}
4042

43+
/**
44+
* Create an instance that returns the value returned by the given {@code Supplier} when {@link
45+
* #set(S)} has not yet been called in the thread. Note that the value returned by the {@code
46+
* Supplier} is not stored in the {@code WorkflowThreadLocal} implicitly; repeatedly calling
47+
* {@link #get()} will always re-execute the {@code Supplier} until you call {@link #set(S)} for
48+
* the first time. This differs from the behavior of {@code ThreadLocal}. If you want the value
49+
* returned by the {@code Supplier} to be stored in the {@code WorkflowThreadLocal}, which matches
50+
* the behavior of {@code ThreadLocal}, use {@link #withCachedInitial(Supplier)} instead.
51+
*
52+
* @param supplier Callback that will be executed whenever {@link #get()} is called, until {@link
53+
* #set(S)} is called for the first time.
54+
* @return A {@code WorkflowThreadLocal} instance.
55+
* @param <S> The type stored in the {@code WorkflowThreadLocal}.
56+
* @deprecated Because the non-caching behavior of this API is typically not desirable, it's
57+
* recommend to use {@link #withCachedInitial(Supplier)} instead.
58+
*/
59+
@Deprecated
4160
public static <S> WorkflowThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
42-
return new WorkflowThreadLocal<>(supplier);
61+
return new WorkflowThreadLocal<>(supplier, false);
62+
}
63+
64+
/**
65+
* Create an instance that returns the value returned by the given {@code Supplier} when {@link
66+
* #set(S)} has not yet been called in the Workflow, and then stores the returned value inside the
67+
* {@code WorkflowThreadLocal}.
68+
*
69+
* @param supplier Callback that will be executed when {@link #get()} is called for the first
70+
* time, if {@link #set(S)} has not already been called.
71+
* @return A {@code WorkflowThreadLocal} instance.
72+
* @param <S> The type stored in the {@code WorkflowThreadLocal}.
73+
*/
74+
public static <S> WorkflowThreadLocal<S> withCachedInitial(Supplier<? extends S> supplier) {
75+
return new WorkflowThreadLocal<>(supplier, true);
4376
}
4477

4578
public T get() {

temporal-sdk/src/test/java/io/temporal/internal/sync/DeterministicRunnerTest.java

+22-2
Original file line numberDiff line numberDiff line change
@@ -945,14 +945,14 @@ private static Supplier<String> getStringSupplier(AtomicInteger supplierCalls) {
945945
}
946946

947947
@Test
948-
public void testSupplierCalledOnce() {
948+
public void testSupplierCalledOnceWithCaching() {
949949
AtomicInteger supplierCalls = new AtomicInteger();
950950
DeterministicRunnerImpl d =
951951
new DeterministicRunnerImpl(
952952
threadPool::submit,
953953
DummySyncWorkflowContext.newDummySyncWorkflowContext(),
954954
() -> {
955-
RunnerLocalInternal<String> runnerLocalInternal = new RunnerLocalInternal<>();
955+
RunnerLocalInternal<String> runnerLocalInternal = new RunnerLocalInternal<>(true);
956956
runnerLocalInternal.get(getStringSupplier(supplierCalls));
957957
runnerLocalInternal.get(getStringSupplier(supplierCalls));
958958
runnerLocalInternal.get(getStringSupplier(supplierCalls));
@@ -963,4 +963,24 @@ public void testSupplierCalledOnce() {
963963
});
964964
d.runUntilAllBlocked(DeterministicRunner.DEFAULT_DEADLOCK_DETECTION_TIMEOUT_MS);
965965
}
966+
967+
@Test
968+
public void testSupplierCalledMultipleWithoutCaching() {
969+
AtomicInteger supplierCalls = new AtomicInteger();
970+
DeterministicRunnerImpl d =
971+
new DeterministicRunnerImpl(
972+
threadPool::submit,
973+
DummySyncWorkflowContext.newDummySyncWorkflowContext(),
974+
() -> {
975+
RunnerLocalInternal<String> runnerLocalInternal = new RunnerLocalInternal<>(false);
976+
runnerLocalInternal.get(getStringSupplier(supplierCalls));
977+
runnerLocalInternal.get(getStringSupplier(supplierCalls));
978+
runnerLocalInternal.get(getStringSupplier(supplierCalls));
979+
assertEquals(
980+
"supplier default value",
981+
runnerLocalInternal.get(getStringSupplier(supplierCalls)));
982+
assertEquals(4, supplierCalls.get());
983+
});
984+
d.runUntilAllBlocked(DeterministicRunner.DEFAULT_DEADLOCK_DETECTION_TIMEOUT_MS);
985+
}
966986
}

temporal-sdk/src/test/java/io/temporal/workflow/WorkflowLocalsTest.java

+97
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,13 @@
2121
package io.temporal.workflow;
2222

2323
import static org.junit.Assert.assertEquals;
24+
import static org.junit.Assert.assertNotSame;
2425
import static org.junit.Assert.assertNull;
26+
import static org.junit.Assert.assertSame;
2527

2628
import io.temporal.testing.internal.SDKTestWorkflowRule;
2729
import io.temporal.workflow.shared.TestWorkflows.TestWorkflow1;
30+
import io.temporal.workflow.shared.TestWorkflows.TestWorkflowReturnString;
2831
import java.time.Duration;
2932
import java.util.concurrent.atomic.AtomicInteger;
3033
import org.junit.Assert;
@@ -47,9 +50,11 @@ public void testWorkflowLocals() {
4750

4851
public static class TestWorkflowLocals implements TestWorkflow1 {
4952

53+
@SuppressWarnings("deprecation")
5054
private final WorkflowThreadLocal<Integer> threadLocal =
5155
WorkflowThreadLocal.withInitial(() -> 2);
5256

57+
@SuppressWarnings("deprecation")
5358
private final WorkflowLocal<Integer> workflowLocal = WorkflowLocal.withInitial(() -> 5);
5459

5560
@Override
@@ -84,12 +89,15 @@ public static class TestWorkflowLocalsSupplierReuse implements TestWorkflow1 {
8489
private final AtomicInteger localCalls = new AtomicInteger(0);
8590
private final AtomicInteger threadLocalCalls = new AtomicInteger(0);
8691

92+
@SuppressWarnings("deprecation")
8793
private final WorkflowThreadLocal<Integer> workflowThreadLocal =
8894
WorkflowThreadLocal.withInitial(
8995
() -> {
9096
threadLocalCalls.addAndGet(1);
9197
return null;
9298
});
99+
100+
@SuppressWarnings("deprecation")
93101
private final WorkflowLocal<Integer> workflowLocal =
94102
WorkflowLocal.withInitial(
95103
() -> {
@@ -131,4 +139,93 @@ public void testWorkflowLocalsSupplierReuse() {
131139
String result = workflowStub.execute(testWorkflowRule.getTaskQueue());
132140
Assert.assertEquals("ok", result);
133141
}
142+
143+
@SuppressWarnings("deprecation")
144+
static final WorkflowThreadLocal<AtomicInteger> threadLocal =
145+
WorkflowThreadLocal.withInitial(() -> new AtomicInteger(2));
146+
147+
@SuppressWarnings("deprecation")
148+
static final WorkflowLocal<AtomicInteger> workflowLocal =
149+
WorkflowLocal.withInitial(() -> new AtomicInteger(5));
150+
151+
static final WorkflowThreadLocal<AtomicInteger> threadLocalCached =
152+
WorkflowThreadLocal.withCachedInitial(() -> new AtomicInteger(2));
153+
154+
static final WorkflowLocal<AtomicInteger> workflowLocalCached =
155+
WorkflowLocal.withCachedInitial(() -> new AtomicInteger(5));
156+
157+
public static class TestInit implements TestWorkflowReturnString {
158+
159+
@Override
160+
public String execute() {
161+
assertEquals(2, threadLocal.get().getAndSet(3));
162+
assertEquals(5, workflowLocal.get().getAndSet(6));
163+
assertEquals(2, threadLocalCached.get().getAndSet(3));
164+
assertEquals(5, workflowLocalCached.get().getAndSet(6));
165+
String out = Workflow.newChildWorkflowStub(TestWorkflow1.class).execute("ign");
166+
assertEquals("ok", out);
167+
return "result="
168+
+ threadLocal.get().get()
169+
+ ", "
170+
+ workflowLocal.get().get()
171+
+ ", "
172+
+ threadLocalCached.get().get()
173+
+ ", "
174+
+ workflowLocalCached.get().get();
175+
}
176+
}
177+
178+
public static class TestChildInit implements TestWorkflow1 {
179+
180+
@Override
181+
public String execute(String arg1) {
182+
assertEquals(2, threadLocal.get().getAndSet(8));
183+
assertEquals(5, workflowLocal.get().getAndSet(0));
184+
return "ok";
185+
}
186+
}
187+
188+
@Rule
189+
public SDKTestWorkflowRule testWorkflowRuleInitialValueNotShared =
190+
SDKTestWorkflowRule.newBuilder()
191+
.setWorkflowTypes(TestInit.class, TestChildInit.class)
192+
.build();
193+
194+
@Test
195+
public void testWorkflowInitialNotShared() {
196+
TestWorkflowReturnString workflowStub =
197+
testWorkflowRuleInitialValueNotShared.newWorkflowStubTimeoutOptions(
198+
TestWorkflowReturnString.class);
199+
String result = workflowStub.execute();
200+
Assert.assertEquals("result=2, 5, 3, 6", result);
201+
}
202+
203+
public static class TestCaching implements TestWorkflow1 {
204+
205+
@Override
206+
public String execute(String arg1) {
207+
assertNotSame(threadLocal.get(), threadLocal.get());
208+
assertNotSame(workflowLocal.get(), workflowLocal.get());
209+
threadLocal.set(threadLocal.get());
210+
workflowLocal.set(workflowLocal.get());
211+
assertSame(threadLocal.get(), threadLocal.get());
212+
assertSame(workflowLocal.get(), workflowLocal.get());
213+
214+
assertSame(threadLocalCached.get(), threadLocalCached.get());
215+
assertSame(workflowLocalCached.get(), workflowLocalCached.get());
216+
return "ok";
217+
}
218+
}
219+
220+
@Rule
221+
public SDKTestWorkflowRule testWorkflowRuleCaching =
222+
SDKTestWorkflowRule.newBuilder().setWorkflowTypes(TestCaching.class).build();
223+
224+
@Test
225+
public void testWorkflowLocalCaching() {
226+
TestWorkflow1 workflowStub =
227+
testWorkflowRuleCaching.newWorkflowStubTimeoutOptions(TestWorkflow1.class);
228+
String out = workflowStub.execute("ign");
229+
assertEquals("ok", out);
230+
}
134231
}

0 commit comments

Comments
 (0)