Skip to content

Commit e402c7e

Browse files
committed
Revert "Delete test"
This reverts commit 16bf494.
1 parent 1361e49 commit e402c7e

File tree

2 files changed

+112
-1
lines changed

2 files changed

+112
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ lint-types = [
8282
{ cmd = "uv run mypy --namespace-packages --check-untyped-defs ."},
8383
]
8484
run-bench = "uv run python scripts/run_bench.py"
85-
test = "uv run pytest"
85+
test = "uv run pytest tests/nexus/test_worker.py"
8686

8787

8888
[tool.pytest.ini_options]

tests/nexus/test_worker.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import uuid
5+
from datetime import timedelta
6+
from typing import Any
7+
8+
import nexusrpc.handler
9+
import pytest
10+
11+
from temporalio import workflow
12+
from temporalio.testing import WorkflowEnvironment
13+
from tests.helpers import new_worker
14+
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name
15+
16+
17+
@workflow.defn
18+
class NexusCallerWorkflow:
19+
"""Workflow that calls a Nexus operation."""
20+
21+
@workflow.run
22+
async def run(self, n: int) -> None:
23+
nexus_client = workflow.create_nexus_client(
24+
endpoint=make_nexus_endpoint_name(workflow.info().task_queue),
25+
service="MaxConcurrentTestService",
26+
)
27+
28+
coros: list[Any] = [
29+
nexus_client.execute_operation(
30+
"op",
31+
i,
32+
schedule_to_close_timeout=timedelta(seconds=60),
33+
)
34+
for i in range(n)
35+
]
36+
await asyncio.gather(*coros)
37+
38+
39+
@pytest.mark.parametrize(
40+
["num_nexus_operations", "max_concurrent_nexus_tasks"],
41+
[
42+
(1, 1),
43+
(3, 3),
44+
(4, 3),
45+
],
46+
)
47+
async def test_max_concurrent_nexus_tasks(
48+
env: WorkflowEnvironment,
49+
max_concurrent_nexus_tasks: int,
50+
num_nexus_operations: int,
51+
):
52+
if env.supports_time_skipping:
53+
pytest.skip("Nexus tests don't work with Javas test server")
54+
55+
barrier = Barrier(num_nexus_operations)
56+
57+
@nexusrpc.handler.service_handler
58+
class MaxConcurrentTestService:
59+
@nexusrpc.handler.sync_operation
60+
async def op(
61+
self, _ctx: nexusrpc.handler.StartOperationContext, id: int
62+
) -> None:
63+
await barrier.wait()
64+
65+
async with new_worker(
66+
env.client,
67+
NexusCallerWorkflow,
68+
nexus_service_handlers=[MaxConcurrentTestService()],
69+
max_concurrent_nexus_tasks=max_concurrent_nexus_tasks,
70+
) as worker:
71+
await create_nexus_endpoint(worker.task_queue, env.client)
72+
73+
execute_operations_concurrently = env.client.execute_workflow(
74+
NexusCallerWorkflow.run,
75+
num_nexus_operations,
76+
id=str(uuid.uuid4()),
77+
task_queue=worker.task_queue,
78+
)
79+
if num_nexus_operations <= max_concurrent_nexus_tasks:
80+
await execute_operations_concurrently
81+
else:
82+
try:
83+
await asyncio.wait_for(execute_operations_concurrently, timeout=10)
84+
except TimeoutError:
85+
pass
86+
else:
87+
pytest.fail(
88+
f"Expected timeout: "
89+
f"max_concurrent_nexus_tasks={max_concurrent_nexus_tasks}, "
90+
f"num_nexus_operations={num_nexus_operations}"
91+
)
92+
93+
94+
# Minimal implementation of asyncio.Barrier for Python 3.9+ compatibility
95+
96+
97+
class Barrier:
98+
def __init__(self, parties: int):
99+
"""Create a barrier for 'parties' tasks."""
100+
if parties < 1:
101+
raise ValueError("parties must be > 0")
102+
self._parties = parties
103+
self._count = 0
104+
self._event = asyncio.Event()
105+
106+
async def wait(self) -> None:
107+
"""Wait for all parties to reach the barrier."""
108+
self._count += 1
109+
if self._count == self._parties:
110+
self._event.set()
111+
await self._event.wait()

0 commit comments

Comments
 (0)