Skip to content

Commit 00e7e37

Browse files
committed
Test max_concurrent
1 parent a6e7852 commit 00e7e37

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed

tests/nexus/test_worker.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import uuid
5+
from datetime import timedelta
6+
7+
import nexusrpc.handler
8+
import pytest
9+
10+
from temporalio import workflow
11+
from temporalio.client import Client
12+
from tests.helpers import new_worker
13+
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name
14+
15+
16+
@workflow.defn
17+
class NexusCallerWorkflow:
18+
"""Workflow that calls a Nexus operation."""
19+
20+
@workflow.run
21+
async def run(self, id: int) -> None:
22+
nexus_client = workflow.create_nexus_client(
23+
endpoint=make_nexus_endpoint_name(workflow.info().task_queue),
24+
service="MaxConcurrentTestService",
25+
)
26+
27+
await nexus_client.execute_operation(
28+
"op",
29+
id,
30+
schedule_to_close_timeout=timedelta(seconds=60),
31+
)
32+
33+
34+
@pytest.mark.parametrize(
35+
["num_nexus_operations", "max_concurrent_nexus_tasks", "expected_num_executed"],
36+
[(1, 1, 1), (2, 1, 1), (43, 42, 42), (43, 44, 43)],
37+
)
38+
async def test_max_concurrent_nexus_tasks(
39+
client: Client,
40+
max_concurrent_nexus_tasks: int,
41+
num_nexus_operations: int,
42+
expected_num_executed: int,
43+
):
44+
ids = []
45+
event = asyncio.Event()
46+
47+
@nexusrpc.handler.service_handler
48+
class MaxConcurrentTestService:
49+
@nexusrpc.handler.sync_operation
50+
async def op(
51+
self, _ctx: nexusrpc.handler.StartOperationContext, id: int
52+
) -> None:
53+
ids.append(id)
54+
await event.wait()
55+
56+
async with new_worker(
57+
client,
58+
NexusCallerWorkflow,
59+
nexus_service_handlers=[MaxConcurrentTestService()],
60+
max_concurrent_nexus_tasks=max_concurrent_nexus_tasks,
61+
) as worker:
62+
await create_nexus_endpoint(worker.task_queue, client)
63+
64+
coros = [
65+
client.execute_workflow(
66+
NexusCallerWorkflow.run,
67+
i,
68+
id=str(uuid.uuid4()),
69+
task_queue=worker.task_queue,
70+
)
71+
for i in range(num_nexus_operations)
72+
]
73+
try:
74+
await asyncio.wait_for(asyncio.gather(*coros), timeout=3)
75+
except asyncio.TimeoutError:
76+
pass
77+
event.set()
78+
assert len(set(ids)) == len(ids)
79+
assert len(ids) == expected_num_executed

0 commit comments

Comments
 (0)