Skip to content

Commit 4c442cf

Browse files
committed
Test max_concurrent
1 parent a6e7852 commit 4c442cf

File tree

1 file changed

+81
-0
lines changed

1 file changed

+81
-0
lines changed

tests/nexus/test_worker.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""Test Nexus worker configuration, specifically max_concurrent_nexus_tasks."""
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
import uuid
7+
from datetime import timedelta
8+
9+
import nexusrpc.handler
10+
import pytest
11+
12+
from temporalio import workflow
13+
from temporalio.client import Client
14+
from tests.helpers import new_worker
15+
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name
16+
17+
18+
@workflow.defn(sandboxed=False)
19+
class NexusCallerWorkflow:
20+
"""Workflow that calls a Nexus operation."""
21+
22+
@workflow.run
23+
async def run(self, id: int) -> None:
24+
nexus_client = workflow.create_nexus_client(
25+
endpoint=make_nexus_endpoint_name(workflow.info().task_queue),
26+
service="MaxConcurrentTestService",
27+
)
28+
29+
await nexus_client.execute_operation(
30+
"op",
31+
id,
32+
schedule_to_close_timeout=timedelta(seconds=60),
33+
)
34+
35+
36+
@pytest.mark.parametrize(
37+
["num_nexus_operations", "max_concurrent_nexus_tasks", "expected_num_executed"],
38+
[(1, 1, 1), (2, 1, 1), (43, 42, 42), (43, 44, 43)],
39+
)
40+
async def test_max_concurrent_nexus_tasks(
41+
client: Client,
42+
max_concurrent_nexus_tasks: int,
43+
num_nexus_operations: int,
44+
expected_num_executed: int,
45+
):
46+
ids = []
47+
event = asyncio.Event()
48+
49+
@nexusrpc.handler.service_handler
50+
class MaxConcurrentTestService:
51+
@nexusrpc.handler.sync_operation
52+
async def op(
53+
self, _ctx: nexusrpc.handler.StartOperationContext, id: int
54+
) -> None:
55+
ids.append(id)
56+
await event.wait()
57+
58+
async with new_worker(
59+
client,
60+
NexusCallerWorkflow,
61+
nexus_service_handlers=[MaxConcurrentTestService()],
62+
max_concurrent_nexus_tasks=max_concurrent_nexus_tasks,
63+
) as worker:
64+
await create_nexus_endpoint(worker.task_queue, client)
65+
66+
coros = [
67+
client.execute_workflow(
68+
NexusCallerWorkflow.run,
69+
i,
70+
id=str(uuid.uuid4()),
71+
task_queue=worker.task_queue,
72+
)
73+
for i in range(num_nexus_operations)
74+
]
75+
try:
76+
await asyncio.wait_for(asyncio.gather(*coros), timeout=3)
77+
except asyncio.TimeoutError:
78+
pass
79+
event.set()
80+
assert len(set(ids)) == len(ids)
81+
assert len(ids) == expected_num_executed

0 commit comments

Comments
 (0)