Skip to content

Commit a73e9d9

Browse files
authored
Issue 1761 fix supplier reuse (#1779)
For WorkflowLocals and WorkflowThreadLocals, disambiguate between variables that are null and those that have not been set. Invoke the supplier for local values at most once.
1 parent 18120b1 commit a73e9d9

File tree

7 files changed

+185
-10
lines changed

7 files changed

+185
-10
lines changed

temporal-sdk/src/main/java/io/temporal/internal/sync/DeterministicRunnerImpl.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -589,12 +589,22 @@ private boolean areThreadsToBeExecuted() {
589589
|| !toExecuteInWorkflowThread.isEmpty();
590590
}
591591

592+
/**
593+
* Retrieve data from runner locals. Returns 1. not found (an empty Optional) 2. found but null
594+
* (an Optional of an empty Optional) 3. found and non-null (an Optional of an Optional of a
595+
* value). The type nesting is because Java Optionals cannot understand "Some null" vs "None",
596+
* which is exactly what we need here.
597+
*
598+
* @param key
599+
* @return one of three cases
600+
* @param <T>
601+
*/
592602
@SuppressWarnings("unchecked")
593-
<T> Optional<T> getRunnerLocal(RunnerLocalInternal<T> key) {
603+
<T> Optional<Optional<T>> getRunnerLocal(RunnerLocalInternal<T> key) {
594604
if (!runnerLocalMap.containsKey(key)) {
595605
return Optional.empty();
596606
}
597-
return Optional.of((T) runnerLocalMap.get(key));
607+
return Optional.of(Optional.ofNullable((T) runnerLocalMap.get(key)));
598608
}
599609

600610
<T> void setRunnerLocal(RunnerLocalInternal<T> key, T value) {

temporal-sdk/src/main/java/io/temporal/internal/sync/RunnerLocalInternal.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,24 @@
2424
import java.util.function.Supplier;
2525

2626
public final class RunnerLocalInternal<T> {
27+
private T supplierResult = null;
28+
private boolean supplierCalled = false;
29+
30+
Optional<T> invokeSupplier(Supplier<? extends T> supplier) {
31+
if (!supplierCalled) {
32+
T result = supplier.get();
33+
supplierCalled = true;
34+
supplierResult = result;
35+
return Optional.ofNullable(result);
36+
} else {
37+
return Optional.ofNullable(supplierResult);
38+
}
39+
}
2740

2841
public T get(Supplier<? extends T> supplier) {
29-
Optional<T> result =
42+
Optional<Optional<T>> result =
3043
DeterministicRunnerImpl.currentThreadInternal().getRunner().getRunnerLocal(this);
31-
return result.orElse(supplier.get());
44+
return result.orElseGet(() -> invokeSupplier(supplier)).orElse(null);
3245
}
3346

3447
public void set(T value) {

temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThread.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ static void exit() {
120120

121121
<T> void setThreadLocal(WorkflowThreadLocalInternal<T> key, T value);
122122

123-
<T> Optional<T> getThreadLocal(WorkflowThreadLocalInternal<T> key);
123+
<T> Optional<Optional<T>> getThreadLocal(WorkflowThreadLocalInternal<T> key);
124124

125125
WorkflowThreadContext getWorkflowThreadContext();
126126
}

temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThreadImpl.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -404,13 +404,22 @@ public <T> void setThreadLocal(WorkflowThreadLocalInternal<T> key, T value) {
404404
threadLocalMap.put(key, value);
405405
}
406406

407+
/**
408+
* Retrieve data from thread locals. Returns 1. not found (an empty Optional) 2. found but null
409+
* (an Optional of an empty Optional) 3. found and non-null (an Optional of an Optional of a
410+
* value). The type nesting is because Java Optionals cannot understand "Some null" vs "None",
411+
* which is exactly what we need here.
412+
*
413+
* @param key
414+
* @return one of three cases
415+
* @param <T>
416+
*/
407417
@SuppressWarnings("unchecked")
408-
@Override
409-
public <T> Optional<T> getThreadLocal(WorkflowThreadLocalInternal<T> key) {
418+
public <T> Optional<Optional<T>> getThreadLocal(WorkflowThreadLocalInternal<T> key) {
410419
if (!threadLocalMap.containsKey(key)) {
411420
return Optional.empty();
412421
}
413-
return Optional.of((T) threadLocalMap.get(key));
422+
return Optional.of(Optional.ofNullable((T) threadLocalMap.get(key)));
414423
}
415424

416425
/**

temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThreadLocalInternal.java

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,24 @@
2525

2626
public final class WorkflowThreadLocalInternal<T> {
2727

28+
private T supplierResult = null;
29+
private boolean supplierCalled = false;
30+
31+
Optional<T> invokeSupplier(Supplier<? extends T> supplier) {
32+
if (!supplierCalled) {
33+
T result = supplier.get();
34+
supplierCalled = true;
35+
supplierResult = result;
36+
return Optional.ofNullable(result);
37+
} else {
38+
return Optional.ofNullable(supplierResult);
39+
}
40+
}
41+
2842
public T get(Supplier<? extends T> supplier) {
29-
Optional<T> result = DeterministicRunnerImpl.currentThreadInternal().getThreadLocal(this);
30-
return result.orElse(supplier.get());
43+
Optional<Optional<T>> result =
44+
DeterministicRunnerImpl.currentThreadInternal().getThreadLocal(this);
45+
return result.orElseGet(() -> invokeSupplier(supplier)).orElse(null);
3146
}
3247

3348
public void set(T value) {

temporal-sdk/src/test/java/io/temporal/internal/sync/DeterministicRunnerTest.java

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@
5959
import java.util.concurrent.ThreadPoolExecutor;
6060
import java.util.concurrent.TimeUnit;
6161
import java.util.concurrent.atomic.AtomicBoolean;
62+
import java.util.concurrent.atomic.AtomicInteger;
6263
import java.util.concurrent.atomic.AtomicReference;
64+
import java.util.function.Supplier;
6365
import org.junit.*;
6466

6567
public class DeterministicRunnerTest {
@@ -890,4 +892,75 @@ public void testCloseBlockedUntilDone() throws InterruptedException {
890892
d.close();
891893
assertTrue("Close should return only after the thread finished", threadFinished.get());
892894
}
895+
896+
@Test
897+
public void testGetRunnerLocalAbsent() {
898+
DeterministicRunnerImpl d =
899+
new DeterministicRunnerImpl(
900+
threadPool::submit,
901+
DummySyncWorkflowContext.newDummySyncWorkflowContext(),
902+
() -> {
903+
RunnerLocalInternal<String> runnerLocalInternal = new RunnerLocalInternal<>();
904+
assertEquals(
905+
"supplier default value",
906+
runnerLocalInternal.get(() -> "supplier default value"));
907+
});
908+
d.runUntilAllBlocked(DeterministicRunner.DEFAULT_DEADLOCK_DETECTION_TIMEOUT_MS);
909+
}
910+
911+
@Test
912+
public void testGetRunnerLocalPresentAndNull() {
913+
DeterministicRunnerImpl d =
914+
new DeterministicRunnerImpl(
915+
threadPool::submit,
916+
DummySyncWorkflowContext.newDummySyncWorkflowContext(),
917+
() -> {
918+
RunnerLocalInternal<String> runnerLocalInternal = new RunnerLocalInternal<>();
919+
runnerLocalInternal.set(null);
920+
assertNull(runnerLocalInternal.get(() -> "supplier default value"));
921+
});
922+
d.runUntilAllBlocked(DeterministicRunner.DEFAULT_DEADLOCK_DETECTION_TIMEOUT_MS);
923+
}
924+
925+
@Test
926+
public void testGetRunnerLocalPresentAndNonNull() {
927+
DeterministicRunnerImpl d =
928+
new DeterministicRunnerImpl(
929+
threadPool::submit,
930+
DummySyncWorkflowContext.newDummySyncWorkflowContext(),
931+
() -> {
932+
RunnerLocalInternal<String> runnerLocalInternal = new RunnerLocalInternal<>();
933+
runnerLocalInternal.set("explicitly set value");
934+
assertEquals(
935+
"explicitly set value", runnerLocalInternal.get(() -> "supplier default value"));
936+
});
937+
d.runUntilAllBlocked(DeterministicRunner.DEFAULT_DEADLOCK_DETECTION_TIMEOUT_MS);
938+
}
939+
940+
private static Supplier<String> getStringSupplier(AtomicInteger supplierCalls) {
941+
return () -> {
942+
supplierCalls.addAndGet(1);
943+
return "supplier default value";
944+
};
945+
}
946+
947+
@Test
948+
public void testSupplierCalledOnce() {
949+
AtomicInteger supplierCalls = new AtomicInteger();
950+
DeterministicRunnerImpl d =
951+
new DeterministicRunnerImpl(
952+
threadPool::submit,
953+
DummySyncWorkflowContext.newDummySyncWorkflowContext(),
954+
() -> {
955+
RunnerLocalInternal<String> runnerLocalInternal = new RunnerLocalInternal<>();
956+
runnerLocalInternal.get(getStringSupplier(supplierCalls));
957+
runnerLocalInternal.get(getStringSupplier(supplierCalls));
958+
runnerLocalInternal.get(getStringSupplier(supplierCalls));
959+
assertEquals(
960+
"supplier default value",
961+
runnerLocalInternal.get(getStringSupplier(supplierCalls)));
962+
assertEquals(1, supplierCalls.get());
963+
});
964+
d.runUntilAllBlocked(DeterministicRunner.DEFAULT_DEADLOCK_DETECTION_TIMEOUT_MS);
965+
}
893966
}

temporal-sdk/src/test/java/io/temporal/workflow/WorkflowLocalsTest.java

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121
package io.temporal.workflow;
2222

2323
import static org.junit.Assert.assertEquals;
24+
import static org.junit.Assert.assertNull;
2425

2526
import io.temporal.testing.internal.SDKTestWorkflowRule;
2627
import io.temporal.workflow.shared.TestWorkflows.TestWorkflow1;
2728
import java.time.Duration;
29+
import java.util.concurrent.atomic.AtomicInteger;
2830
import org.junit.Assert;
2931
import org.junit.Rule;
3032
import org.junit.Test;
@@ -76,4 +78,57 @@ public String execute(String taskQueue) {
7678
return "result=" + threadLocal.get() + ", " + workflowLocal.get();
7779
}
7880
}
81+
82+
public static class TestWorkflowLocalsSupplierReuse implements TestWorkflow1 {
83+
84+
private final AtomicInteger localCalls = new AtomicInteger(0);
85+
private final AtomicInteger threadLocalCalls = new AtomicInteger(0);
86+
87+
private final WorkflowThreadLocal<Integer> workflowThreadLocal =
88+
WorkflowThreadLocal.withInitial(
89+
() -> {
90+
threadLocalCalls.addAndGet(1);
91+
return null;
92+
});
93+
private final WorkflowLocal<Integer> workflowLocal =
94+
WorkflowLocal.withInitial(
95+
() -> {
96+
localCalls.addAndGet(1);
97+
return null;
98+
});
99+
100+
@Override
101+
public String execute(String taskQueue) {
102+
assertNull(workflowThreadLocal.get());
103+
workflowThreadLocal.set(null);
104+
assertNull(workflowThreadLocal.get());
105+
assertNull(workflowThreadLocal.get());
106+
workflowThreadLocal.set(55);
107+
assertEquals((long) workflowThreadLocal.get(), 55);
108+
assertEquals(threadLocalCalls.get(), 1);
109+
110+
assertNull(workflowLocal.get());
111+
workflowLocal.set(null);
112+
assertNull(workflowLocal.get());
113+
assertNull(workflowLocal.get());
114+
workflowLocal.set(58);
115+
assertEquals((long) workflowLocal.get(), 58);
116+
assertEquals(localCalls.get(), 1);
117+
return "ok";
118+
}
119+
}
120+
121+
@Rule
122+
public SDKTestWorkflowRule testWorkflowRuleSupplierReuse =
123+
SDKTestWorkflowRule.newBuilder()
124+
.setWorkflowTypes(TestWorkflowLocalsSupplierReuse.class)
125+
.build();
126+
127+
@Test
128+
public void testWorkflowLocalsSupplierReuse() {
129+
TestWorkflow1 workflowStub =
130+
testWorkflowRuleSupplierReuse.newWorkflowStubTimeoutOptions(TestWorkflow1.class);
131+
String result = workflowStub.execute(testWorkflowRule.getTaskQueue());
132+
Assert.assertEquals("ok", result);
133+
}
79134
}

0 commit comments

Comments
 (0)