Skip to content

Commit 7cb0597

Browse files
committed
Failing test: USE_EXISTING conflict policy
1 parent 607641b commit 7cb0597

File tree

2 files changed

+111
-8
lines changed

2 files changed

+111
-8
lines changed

temporalio/nexus/_operation_context.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -404,14 +404,6 @@ async def start_workflow(
404404
# attachRequestId: true,
405405
# };
406406
# }
407-
if (
408-
id_conflict_policy
409-
== temporalio.common.WorkflowIDConflictPolicy.USE_EXISTING
410-
):
411-
raise RuntimeError(
412-
"WorkflowIDConflictPolicy.USE_EXISTING is not yet supported when starting a workflow "
413-
"that backs a Nexus operation (Python SDK Nexus support is at Pre-release stage)."
414-
)
415407

416408
# We must pass nexus_completion_callbacks, workflow_event_links, and request_id,
417409
# but these are deliberately not exposed in overloads, hence the type-check
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import uuid
5+
from dataclasses import dataclass
6+
from typing import Optional
7+
8+
import pytest
9+
from nexusrpc.handler import service_handler
10+
11+
from temporalio import nexus, workflow
12+
from temporalio.client import Client
13+
from temporalio.common import WorkflowIDConflictPolicy
14+
from temporalio.testing import WorkflowEnvironment
15+
from temporalio.worker import Worker
16+
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name
17+
18+
19+
@dataclass
20+
class OpInput:
21+
workflow_id: str
22+
conflict_policy: WorkflowIDConflictPolicy
23+
24+
25+
@workflow.defn
26+
class HandlerWorkflow:
27+
def __init__(self) -> None:
28+
self.result: Optional[str] = None
29+
30+
@workflow.run
31+
async def run(self) -> str:
32+
await workflow.wait_condition(lambda: self.result is not None)
33+
assert self.result
34+
return self.result
35+
36+
@workflow.signal
37+
def complete(self, result: str) -> None:
38+
self.result = result
39+
40+
41+
@service_handler
42+
class NexusService:
43+
@nexus.workflow_run_operation
44+
async def workflow_backed_operation(
45+
self, ctx: nexus.WorkflowRunOperationContext, input: OpInput
46+
) -> nexus.WorkflowHandle[str]:
47+
return await ctx.start_workflow(
48+
HandlerWorkflow.run,
49+
id=input.workflow_id,
50+
id_conflict_policy=input.conflict_policy,
51+
)
52+
53+
54+
@workflow.defn
55+
class CallerWorkflow:
56+
def __init__(self) -> None:
57+
self._nexus_operations_have_started = asyncio.Event()
58+
59+
@workflow.run
60+
async def run(self, workflow_id: str, task_queue: str) -> list[str]:
61+
nexus_client = workflow.create_nexus_client(
62+
service=NexusService, endpoint=make_nexus_endpoint_name(task_queue)
63+
)
64+
65+
op_input = OpInput(
66+
workflow_id=workflow_id,
67+
conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING,
68+
)
69+
70+
handles = []
71+
for _ in range(5):
72+
handles.append(
73+
await nexus_client.start_operation(
74+
NexusService.workflow_backed_operation, op_input
75+
)
76+
)
77+
self._nexus_operations_have_started.set()
78+
return await asyncio.gather(*handles)
79+
80+
@workflow.update
81+
async def nexus_operations_have_started(self) -> None:
82+
await self._nexus_operations_have_started.wait()
83+
84+
85+
async def test_multiple_operation_invocations_can_connect_to_same_handler_workflow(
86+
client: Client, env: WorkflowEnvironment
87+
):
88+
if env.supports_time_skipping:
89+
pytest.skip("Nexus tests don't work with time-skipping server")
90+
91+
task_queue = str(uuid.uuid4())
92+
workflow_id = str(uuid.uuid4())
93+
94+
async with Worker(
95+
client,
96+
nexus_service_handlers=[NexusService()],
97+
workflows=[CallerWorkflow, HandlerWorkflow],
98+
task_queue=task_queue,
99+
):
100+
await create_nexus_endpoint(task_queue, client)
101+
caller_handle = await client.start_workflow(
102+
CallerWorkflow.run,
103+
args=[workflow_id, task_queue],
104+
id=str(uuid.uuid4()),
105+
task_queue=task_queue,
106+
)
107+
await caller_handle.execute_update(CallerWorkflow.nexus_operations_have_started)
108+
await client.get_workflow_handle(workflow_id).signal(
109+
HandlerWorkflow.complete, "test-result"
110+
)
111+
assert await caller_handle.result() == ["test-result"] * 5

0 commit comments

Comments
 (0)