Skip to content

Commit 189effe

Browse files
committed
Add nexus to Custom SlotSupplier test
1 parent a083d2f commit 189effe

File tree

1 file changed

+48
-4
lines changed

1 file changed

+48
-4
lines changed

tests/worker/test_worker.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
CustomSlotSupplier,
3636
FixedSizeSlotSupplier,
3737
LocalActivitySlotInfo,
38+
NexusSlotInfo,
3839
PollerBehaviorAutoscaling,
3940
ResourceBasedSlotConfig,
4041
ResourceBasedSlotSupplier,
@@ -56,6 +57,7 @@
5657
new_worker,
5758
worker_versioning_enabled,
5859
)
60+
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name
5961

6062
# Passing through because Python 3.9 has an import bug at
6163
# https://github.com/python/cpython/issues/91351
@@ -404,6 +406,44 @@ async def test_warns_when_workers_too_low(client: Client, env: WorkflowEnvironme
404406
pass
405407

406408

409+
@nexusrpc.handler.service_handler
410+
class SayHelloService:
411+
@nexusrpc.handler.sync_operation
412+
async def say_hello(
413+
self, _ctx: nexusrpc.handler.StartOperationContext, name: str
414+
) -> str:
415+
return f"Hello, {name}!"
416+
417+
418+
@workflow.defn
419+
class CustomSlotSupplierWorkflow:
420+
def __init__(self) -> None:
421+
self._last_signal = "<none>"
422+
423+
@workflow.run
424+
async def run(self) -> None:
425+
await workflow.wait_condition(lambda: self._last_signal == "finish")
426+
await workflow.execute_activity(
427+
say_hello,
428+
"hi",
429+
versioning_intent=VersioningIntent.DEFAULT,
430+
start_to_close_timeout=timedelta(seconds=5),
431+
)
432+
nexus_client = workflow.create_nexus_client(
433+
endpoint=make_nexus_endpoint_name(workflow.info().task_queue),
434+
service=SayHelloService,
435+
)
436+
await nexus_client.execute_operation(
437+
SayHelloService.say_hello,
438+
"hi",
439+
)
440+
441+
@workflow.signal
442+
def my_signal(self, value: str) -> None:
443+
self._last_signal = value
444+
workflow.logger.info(f"Signal: {value}")
445+
446+
407447
async def test_custom_slot_supplier(client: Client, env: WorkflowEnvironment):
408448
class MyPermit(SlotPermit):
409449
def __init__(self, pnum: int):
@@ -443,6 +483,8 @@ def mark_slot_used(self, ctx: SlotMarkUsedContext) -> None:
443483
self.seen_used_slot_kinds.add("a")
444484
elif isinstance(ctx.slot_info, LocalActivitySlotInfo):
445485
self.seen_used_slot_kinds.add("la")
486+
elif isinstance(ctx.slot_info, NexusSlotInfo):
487+
self.seen_used_slot_kinds.add("nx")
446488
self.used += 1
447489

448490
def release_slot(self, ctx: SlotReleaseContext) -> None:
@@ -476,17 +518,19 @@ def reserve_asserts(self, ctx: SlotReserveContext) -> None:
476518
)
477519
async with new_worker(
478520
client,
479-
WaitOnSignalWorkflow,
521+
CustomSlotSupplierWorkflow,
480522
activities=[say_hello],
523+
nexus_service_handlers=[SayHelloService()],
481524
tuner=tuner,
482525
identity="myworker",
483526
) as w:
527+
await create_nexus_endpoint(w.task_queue, client)
484528
wf1 = await client.start_workflow(
485-
WaitOnSignalWorkflow.run,
529+
CustomSlotSupplierWorkflow.run,
486530
id=f"custom-slot-supplier-{uuid.uuid4()}",
487531
task_queue=w.task_queue,
488532
)
489-
await wf1.signal(WaitOnSignalWorkflow.my_signal, "finish")
533+
await wf1.signal(CustomSlotSupplierWorkflow.my_signal, "finish")
490534
await wf1.result()
491535

492536
# We can't use reserve number directly because there is a technically possible race
@@ -495,7 +539,7 @@ def reserve_asserts(self, ctx: SlotReserveContext) -> None:
495539
# that the permits passed to release line up.
496540
assert ss.highest_seen_reserve_on_release == ss.releases
497541
# Two workflow tasks, one activity
498-
assert ss.used == 3
542+
assert ss.used == 4
499543
assert ss.seen_sticky_kinds == {True, False}
500544
assert ss.seen_slot_kinds == {"workflow", "activity", "local-activity"}
501545
assert ss.seen_used_slot_kinds == {"wf", "a"}

0 commit comments

Comments
 (0)