Skip to content

Commit a9c71aa

Browse files
authored
Support multiple nexus callers attaching to same workflow (#1051)
1 parent 607641b commit a9c71aa

File tree

4 files changed

+192
-54
lines changed

4 files changed

+192
-54
lines changed

temporalio/client.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
import temporalio.converter
5858
import temporalio.exceptions
5959
import temporalio.nexus
60+
import temporalio.nexus._operation_context
6061
import temporalio.runtime
6162
import temporalio.service
6263
import temporalio.workflow
@@ -5877,6 +5878,12 @@ async def _build_start_workflow_execution_request(
58775878
)
58785879
# Links are duplicated on request for compatibility with older server versions.
58795880
req.links.extend(links)
5881+
5882+
if temporalio.nexus._operation_context._in_nexus_backing_workflow_start_context():
5883+
req.on_conflict_options.attach_request_id = True
5884+
req.on_conflict_options.attach_completion_callbacks = True
5885+
req.on_conflict_options.attach_links = True
5886+
58805887
return req
58815888

58825889
async def _build_signal_with_start_workflow_execution_request(
@@ -5932,6 +5939,7 @@ async def _populate_start_workflow_execution_request(
59325939
"temporalio.api.enums.v1.WorkflowIdConflictPolicy.ValueType",
59335940
int(input.id_conflict_policy),
59345941
)
5942+
59355943
if input.retry_policy is not None:
59365944
input.retry_policy.apply_to_proto(req.retry_policy)
59375945
req.cron_schedule = input.cron_schedule

temporalio/nexus/_operation_context.py

Lines changed: 60 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
import dataclasses
44
import logging
55
from collections.abc import Awaitable, Mapping, MutableMapping, Sequence
6+
from contextlib import contextmanager
67
from contextvars import ContextVar
78
from dataclasses import dataclass
89
from datetime import timedelta
910
from typing import (
1011
TYPE_CHECKING,
1112
Any,
1213
Callable,
14+
Generator,
1315
Optional,
1416
Union,
1517
overload,
@@ -47,6 +49,14 @@
4749
ContextVar("temporal-cancel-operation-context")
4850
)
4951

52+
# A Nexus start handler might start zero or more workflows as usual using a Temporal client. In
53+
# addition, it may start one "nexus-backing" workflow, using
54+
# WorkflowRunOperationContext.start_workflow. This context is active while the latter is being done.
55+
# It is thus a narrower context than _temporal_start_operation_context.
56+
_temporal_nexus_backing_workflow_start_context: ContextVar[bool] = ContextVar(
57+
"temporal-nexus-backing-workflow-start-context"
58+
)
59+
5060

5161
@dataclass(frozen=True)
5262
class Info:
@@ -96,6 +106,19 @@ def _try_temporal_context() -> (
96106
return start_ctx or cancel_ctx
97107

98108

109+
@contextmanager
110+
def _nexus_backing_workflow_start_context() -> Generator[None, None, None]:
111+
token = _temporal_nexus_backing_workflow_start_context.set(True)
112+
try:
113+
yield
114+
finally:
115+
_temporal_nexus_backing_workflow_start_context.reset(token)
116+
117+
118+
def _in_nexus_backing_workflow_start_context() -> bool:
119+
return _temporal_nexus_backing_workflow_start_context.get(False)
120+
121+
99122
@dataclass
100123
class _TemporalStartOperationContext:
101124
"""Context for a Nexus start operation being handled by a Temporal Nexus Worker."""
@@ -396,56 +419,46 @@ async def start_workflow(
396419
Nexus caller is itself a workflow, this means that the workflow in the caller
397420
namespace web UI will contain links to the started workflow, and vice versa.
398421
"""
399-
# TODO(nexus-preview): When sdk-python supports on_conflict_options, Typescript does this:
400-
# if (workflowOptions.workflowIdConflictPolicy === 'USE_EXISTING') {
401-
# internalOptions.onConflictOptions = {
402-
# attachLinks: true,
403-
# attachCompletionCallbacks: true,
404-
# attachRequestId: true,
405-
# };
406-
# }
407-
if (
408-
id_conflict_policy
409-
== temporalio.common.WorkflowIDConflictPolicy.USE_EXISTING
410-
):
411-
raise RuntimeError(
412-
"WorkflowIDConflictPolicy.USE_EXISTING is not yet supported when starting a workflow "
413-
"that backs a Nexus operation (Python SDK Nexus support is at Pre-release stage)."
414-
)
415-
416422
# We must pass nexus_completion_callbacks, workflow_event_links, and request_id,
417423
# but these are deliberately not exposed in overloads, hence the type-check
418424
# violation.
419-
wf_handle = await self._temporal_context.client.start_workflow( # type: ignore
420-
workflow=workflow,
421-
arg=arg,
422-
args=args,
423-
id=id,
424-
task_queue=task_queue or self._temporal_context.info().task_queue,
425-
result_type=result_type,
426-
execution_timeout=execution_timeout,
427-
run_timeout=run_timeout,
428-
task_timeout=task_timeout,
429-
id_reuse_policy=id_reuse_policy,
430-
id_conflict_policy=id_conflict_policy,
431-
retry_policy=retry_policy,
432-
cron_schedule=cron_schedule,
433-
memo=memo,
434-
search_attributes=search_attributes,
435-
static_summary=static_summary,
436-
static_details=static_details,
437-
start_delay=start_delay,
438-
start_signal=start_signal,
439-
start_signal_args=start_signal_args,
440-
rpc_metadata=rpc_metadata,
441-
rpc_timeout=rpc_timeout,
442-
request_eager_start=request_eager_start,
443-
priority=priority,
444-
versioning_override=versioning_override,
445-
callbacks=self._temporal_context._get_callbacks(),
446-
workflow_event_links=self._temporal_context._get_workflow_event_links(),
447-
request_id=self._temporal_context.nexus_context.request_id,
448-
)
425+
426+
# Here we are starting a "nexus-backing" workflow. That means that the StartWorkflow request
427+
# contains nexus-specific data such as a completion callback (used by the handler server
428+
# namespace to deliver the result to the caller namespace when the workflow reaches a
429+
# terminal state) and inbound links to the caller workflow (attached to history events of
430+
# the workflow started in the handler namespace, and displayed in the UI).
431+
with _nexus_backing_workflow_start_context():
432+
wf_handle = await self._temporal_context.client.start_workflow( # type: ignore
433+
workflow=workflow,
434+
arg=arg,
435+
args=args,
436+
id=id,
437+
task_queue=task_queue or self._temporal_context.info().task_queue,
438+
result_type=result_type,
439+
execution_timeout=execution_timeout,
440+
run_timeout=run_timeout,
441+
task_timeout=task_timeout,
442+
id_reuse_policy=id_reuse_policy,
443+
id_conflict_policy=id_conflict_policy,
444+
retry_policy=retry_policy,
445+
cron_schedule=cron_schedule,
446+
memo=memo,
447+
search_attributes=search_attributes,
448+
static_summary=static_summary,
449+
static_details=static_details,
450+
start_delay=start_delay,
451+
start_signal=start_signal,
452+
start_signal_args=start_signal_args,
453+
rpc_metadata=rpc_metadata,
454+
rpc_timeout=rpc_timeout,
455+
request_eager_start=request_eager_start,
456+
priority=priority,
457+
versioning_override=versioning_override,
458+
callbacks=self._temporal_context._get_callbacks(),
459+
workflow_event_links=self._temporal_context._get_workflow_event_links(),
460+
request_id=self._temporal_context.nexus_context.request_id,
461+
)
449462

450463
self._temporal_context._add_outbound_links(wf_handle)
451464

temporalio/worker/_workflow_instance.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3029,13 +3029,6 @@ def operation_token(self) -> Optional[str]:
30293029
def __await__(self) -> Generator[Any, Any, OutputT]:
30303030
return self._task.__await__()
30313031

3032-
def __repr__(self) -> str:
3033-
return (
3034-
f"{self._start_fut} "
3035-
f"{self._result_fut} "
3036-
f"Task[{self._task._state}] fut_waiter = {self._task._fut_waiter}) ({self._task._must_cancel})" # type: ignore
3037-
)
3038-
30393032
def cancel(self) -> bool:
30403033
return self._task.cancel()
30413034

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import uuid
5+
from dataclasses import dataclass
6+
from typing import Optional
7+
8+
import pytest
9+
from nexusrpc.handler import service_handler
10+
11+
from temporalio import nexus, workflow
12+
from temporalio.client import Client
13+
from temporalio.common import WorkflowIDConflictPolicy
14+
from temporalio.testing import WorkflowEnvironment
15+
from temporalio.worker import Worker
16+
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name
17+
18+
19+
@dataclass
20+
class OpInput:
21+
workflow_id: str
22+
conflict_policy: WorkflowIDConflictPolicy
23+
24+
25+
@workflow.defn
26+
class HandlerWorkflow:
27+
def __init__(self) -> None:
28+
self.result: Optional[str] = None
29+
30+
@workflow.run
31+
async def run(self) -> str:
32+
await workflow.wait_condition(lambda: self.result is not None)
33+
assert self.result
34+
return self.result
35+
36+
@workflow.signal
37+
def complete(self, result: str) -> None:
38+
self.result = result
39+
40+
41+
@service_handler
42+
class NexusService:
43+
@nexus.workflow_run_operation
44+
async def workflow_backed_operation(
45+
self, ctx: nexus.WorkflowRunOperationContext, input: OpInput
46+
) -> nexus.WorkflowHandle[str]:
47+
return await ctx.start_workflow(
48+
HandlerWorkflow.run,
49+
id=input.workflow_id,
50+
id_conflict_policy=input.conflict_policy,
51+
)
52+
53+
54+
@dataclass
55+
class CallerWorkflowInput:
56+
workflow_id: str
57+
task_queue: str
58+
num_operations: int
59+
60+
61+
@workflow.defn
62+
class CallerWorkflow:
63+
def __init__(self) -> None:
64+
self._nexus_operations_have_started = asyncio.Event()
65+
66+
@workflow.run
67+
async def run(self, input: CallerWorkflowInput) -> list[str]:
68+
nexus_client = workflow.create_nexus_client(
69+
service=NexusService, endpoint=make_nexus_endpoint_name(input.task_queue)
70+
)
71+
72+
op_input = OpInput(
73+
workflow_id=input.workflow_id,
74+
conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING,
75+
)
76+
77+
handles = []
78+
for _ in range(input.num_operations):
79+
handles.append(
80+
await nexus_client.start_operation(
81+
NexusService.workflow_backed_operation, op_input
82+
)
83+
)
84+
self._nexus_operations_have_started.set()
85+
return await asyncio.gather(*handles)
86+
87+
@workflow.update
88+
async def nexus_operations_have_started(self) -> None:
89+
await self._nexus_operations_have_started.wait()
90+
91+
92+
async def test_multiple_operation_invocations_can_connect_to_same_handler_workflow(
93+
client: Client, env: WorkflowEnvironment
94+
):
95+
if env.supports_time_skipping:
96+
pytest.skip("Nexus tests don't work with time-skipping server")
97+
98+
task_queue = str(uuid.uuid4())
99+
workflow_id = str(uuid.uuid4())
100+
101+
async with Worker(
102+
client,
103+
nexus_service_handlers=[NexusService()],
104+
workflows=[CallerWorkflow, HandlerWorkflow],
105+
task_queue=task_queue,
106+
):
107+
await create_nexus_endpoint(task_queue, client)
108+
caller_handle = await client.start_workflow(
109+
CallerWorkflow.run,
110+
args=[
111+
CallerWorkflowInput(
112+
workflow_id=workflow_id,
113+
task_queue=task_queue,
114+
num_operations=5,
115+
)
116+
],
117+
id=str(uuid.uuid4()),
118+
task_queue=task_queue,
119+
)
120+
await caller_handle.execute_update(CallerWorkflow.nexus_operations_have_started)
121+
await client.get_workflow_handle(workflow_id).signal(
122+
HandlerWorkflow.complete, "test-result"
123+
)
124+
assert await caller_handle.result() == ["test-result"] * 5

0 commit comments

Comments
 (0)