Skip to content

Commit cea73d3

Browse files
Add support for calling Workflow.getInfo from query handler (#2541)
Support Workflow.getInfo from query method body
1 parent 44d9abe commit cea73d3

File tree

8 files changed

+125
-13
lines changed

8 files changed

+125
-13
lines changed

temporal-sdk/src/main/java/io/temporal/internal/sync/DeterministicRunner.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import io.temporal.internal.worker.WorkflowExecutorCache;
44
import io.temporal.workflow.CancellationScope;
5+
import java.util.Optional;
56
import javax.annotation.Nonnull;
67
import javax.annotation.Nullable;
78

@@ -90,4 +91,17 @@ static DeterministicRunner newRunner(
9091
/** Creates a new instance of a workflow callback thread. */
9192
@Nonnull
9293
WorkflowThread newCallbackThread(Runnable runnable, @Nullable String name);
94+
95+
/**
96+
* Retrieve data from runner locals. Returns 1. not found (an empty Optional) 2. found but null
97+
* (an Optional of an empty Optional) 3. found and non-null (an Optional of an Optional of a
98+
* value). The type nesting is because Java Optionals cannot understand "Some null" vs "None",
99+
* which is exactly what we need here.
100+
*
101+
* @param key
102+
* @return one of three cases
103+
* @param <T>
104+
*/
105+
@SuppressWarnings("unchecked")
106+
<T> Optional<Optional<T>> getRunnerLocal(RunnerLocalInternal<T> key);
93107
}

temporal-sdk/src/main/java/io/temporal/internal/sync/DeterministicRunnerImpl.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ private boolean areThreadsToBeExecuted() {
586586
* @param <T>
587587
*/
588588
@SuppressWarnings("unchecked")
589-
<T> Optional<Optional<T>> getRunnerLocal(RunnerLocalInternal<T> key) {
589+
public <T> Optional<Optional<T>> getRunnerLocal(RunnerLocalInternal<T> key) {
590590
if (!runnerLocalMap.containsKey(key)) {
591591
return Optional.empty();
592592
}

temporal-sdk/src/main/java/io/temporal/internal/sync/QueryDispatcher.java

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,33 @@ class QueryDispatcher {
2323

2424
private DynamicQueryHandler dynamicQueryHandler;
2525
private WorkflowInboundCallsInterceptor inboundCallsInterceptor;
26+
private static final ThreadLocal<SyncWorkflowContext> queryHandlerWorkflowContext =
27+
new ThreadLocal<>();
2628

2729
public QueryDispatcher(DataConverter dataConverterWithWorkflowContext) {
2830
this.dataConverterWithWorkflowContext = dataConverterWithWorkflowContext;
2931
}
3032

33+
/**
34+
* @return True if the current thread is executing a query handler.
35+
*/
36+
public static boolean isQueryHandler() {
37+
SyncWorkflowContext value = queryHandlerWorkflowContext.get();
38+
return value != null;
39+
}
40+
41+
/**
42+
* @return The current workflow context if the current thread is executing a query handler.
43+
* @throws IllegalStateException if not in a query handler.
44+
*/
45+
public static SyncWorkflowContext getWorkflowContext() {
46+
SyncWorkflowContext value = queryHandlerWorkflowContext.get();
47+
if (value == null) {
48+
throw new IllegalStateException("Not in a query handler");
49+
}
50+
return value;
51+
}
52+
3153
public void setInboundCallsInterceptor(WorkflowInboundCallsInterceptor inboundCallsInterceptor) {
3254
this.inboundCallsInterceptor = inboundCallsInterceptor;
3355
}
@@ -51,7 +73,11 @@ public WorkflowInboundCallsInterceptor.QueryOutput handleInterceptedQuery(
5173
return new WorkflowInboundCallsInterceptor.QueryOutput(result);
5274
}
5375

54-
public Optional<Payloads> handleQuery(String queryName, Header header, Optional<Payloads> input) {
76+
public Optional<Payloads> handleQuery(
77+
SyncWorkflowContext replayContext,
78+
String queryName,
79+
Header header,
80+
Optional<Payloads> input) {
5581
WorkflowOutboundCallsInterceptor.RegisterQueryInput handler = queryCallbacks.get(queryName);
5682
Object[] args;
5783
if (queryName.startsWith(TEMPORAL_RESERVED_PREFIX)) {
@@ -69,11 +95,18 @@ public Optional<Payloads> handleQuery(String queryName, Header header, Optional<
6995
dataConverterWithWorkflowContext.fromPayloads(
7096
input, handler.getArgTypes(), handler.getGenericArgTypes());
7197
}
72-
Object result =
73-
inboundCallsInterceptor
74-
.handleQuery(new WorkflowInboundCallsInterceptor.QueryInput(queryName, header, args))
75-
.getResult();
76-
return dataConverterWithWorkflowContext.toPayloads(result);
98+
try {
99+
replayContext.setReadOnly(true);
100+
queryHandlerWorkflowContext.set(replayContext);
101+
Object result =
102+
inboundCallsInterceptor
103+
.handleQuery(new WorkflowInboundCallsInterceptor.QueryInput(queryName, header, args))
104+
.getResult();
105+
return dataConverterWithWorkflowContext.toPayloads(result);
106+
} finally {
107+
replayContext.setReadOnly(false);
108+
queryHandlerWorkflowContext.set(null);
109+
}
77110
}
78111

79112
public void registerQueryHandlers(WorkflowOutboundCallsInterceptor.RegisterQueryInput request) {

temporal-sdk/src/main/java/io/temporal/internal/sync/RunnerLocalInternal.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,17 @@ public RunnerLocalInternal(boolean useCaching) {
1616
}
1717

1818
public T get(Supplier<? extends T> supplier) {
19-
Optional<Optional<T>> result =
20-
DeterministicRunnerImpl.currentThreadInternal().getRunner().getRunnerLocal(this);
19+
Optional<Optional<T>> result;
20+
// Query handlers are special in that they are executing in a different context
21+
// than the main workflow execution threads. We need to fetch the runner local from the
22+
// correct context based on whether we are in a query handler or not.
23+
if (QueryDispatcher.isQueryHandler()) {
24+
result = QueryDispatcher.getWorkflowContext().getRunner().getRunnerLocal(this);
25+
} else {
26+
result = DeterministicRunnerImpl.currentThreadInternal().getRunner().getRunnerLocal(this);
27+
}
2128
T out = result.orElseGet(() -> Optional.ofNullable(supplier.get())).orElse(null);
22-
if (!result.isPresent() && useCaching) {
29+
if (!result.isPresent() && useCaching && !QueryDispatcher.isQueryHandler()) {
2330
// This is the first time we've tried fetching this, and caching is enabled. Store it.
2431
set(out);
2532
}

temporal-sdk/src/main/java/io/temporal/internal/sync/SyncWorkflowContext.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ public WorkflowInboundCallsInterceptor.QueryOutput handleInterceptedQuery(
345345
}
346346

347347
public Optional<Payloads> handleQuery(String queryName, Header header, Optional<Payloads> input) {
348-
return queryDispatcher.handleQuery(queryName, header, input);
348+
return queryDispatcher.handleQuery(this, queryName, header, input);
349349
}
350350

351351
public boolean isEveryHandlerFinished() {

temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowInternal.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -843,6 +843,13 @@ static WorkflowOutboundCallsInterceptor getWorkflowOutboundInterceptor() {
843843
}
844844

845845
static SyncWorkflowContext getRootWorkflowContext() {
846+
// If we are in a query handler, we need to get the workflow context from the
847+
// QueryDispatcher, otherwise we get it from the current thread's internal context.
848+
// This is necessary because query handlers run in a different context than the main workflow
849+
// threads.
850+
if (QueryDispatcher.isQueryHandler()) {
851+
return QueryDispatcher.getWorkflowContext();
852+
}
846853
return DeterministicRunnerImpl.currentThreadInternal().getWorkflowContext();
847854
}
848855

temporal-sdk/src/test/java/io/temporal/internal/sync/QueryDispatcherTest.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ public void testQuerySuccess() {
4949

5050
// Invoke functionality under test, expect no exceptions for an existing query.
5151
Optional<Payloads> queryResult =
52-
dispatcher.handleQuery("QueryB", Header.empty(), Optional.empty());
52+
dispatcher.handleQuery(
53+
mock(SyncWorkflowContext.class), "QueryB", Header.empty(), Optional.empty());
5354
assertTrue(queryResult.isPresent());
5455
}
5556

@@ -61,7 +62,8 @@ public void testQueryDispatcherException() {
6162
assertThrows(
6263
IllegalArgumentException.class,
6364
() -> {
64-
dispatcher.handleQuery("QueryC", Header.empty(), null);
65+
dispatcher.handleQuery(
66+
mock(SyncWorkflowContext.class), "QueryC", Header.empty(), null);
6567
});
6668
assertEquals("Unknown query type: QueryC, knownTypes=[QueryA, QueryB]", exception.getMessage());
6769
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package io.temporal.workflow.queryTests;
2+
3+
import static org.junit.Assert.assertEquals;
4+
5+
import io.temporal.client.WorkflowClient;
6+
import io.temporal.client.WorkflowStub;
7+
import io.temporal.testing.internal.SDKTestWorkflowRule;
8+
import io.temporal.workflow.Workflow;
9+
import io.temporal.workflow.WorkflowInfo;
10+
import io.temporal.workflow.WorkflowLocal;
11+
import io.temporal.workflow.shared.TestWorkflows;
12+
import java.time.Duration;
13+
import org.junit.Rule;
14+
import org.junit.Test;
15+
16+
public class WorkflowInfoAndLocalInQueryTest {
17+
18+
@Rule
19+
public SDKTestWorkflowRule testWorkflowRule =
20+
SDKTestWorkflowRule.newBuilder().setWorkflowTypes(TestWorkflow.class).build();
21+
22+
@Test
23+
public void queryReturnsInfoAndLocal() {
24+
TestWorkflows.TestWorkflowWithQuery workflowStub =
25+
testWorkflowRule.newWorkflowStub(TestWorkflows.TestWorkflowWithQuery.class);
26+
WorkflowClient.start(workflowStub::execute);
27+
28+
assertEquals("attempt=1 local=42", workflowStub.query());
29+
assertEquals("done", WorkflowStub.fromTyped(workflowStub).getResult(String.class));
30+
}
31+
32+
public static class TestWorkflow implements TestWorkflows.TestWorkflowWithQuery {
33+
34+
private final WorkflowLocal<Integer> local = WorkflowLocal.withCachedInitial(() -> 0);
35+
36+
@Override
37+
public String execute() {
38+
local.set(42);
39+
Workflow.sleep(Duration.ofSeconds(1));
40+
return "done";
41+
}
42+
43+
@Override
44+
public String query() {
45+
WorkflowInfo info = Workflow.getInfo();
46+
return "attempt=" + info.getAttempt() + " local=" + local.get();
47+
}
48+
}
49+
}

0 commit comments

Comments
 (0)