Skip to content

Commit f87f5b3

Browse files
committed
Test async activity completion
1 parent b5741df commit f87f5b3

File tree

1 file changed

+94
-3
lines changed

1 file changed

+94
-3
lines changed

tests/test_serialization_context.py

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from itertools import zip_longest
1010
from pprint import pformat, pprint
1111
from typing import Any, Literal, Optional, Type
12+
from warnings import warn
1213

1314
import pytest
1415

@@ -316,6 +317,98 @@ async def test_workflow_payload_conversion(
316317
pprint(result.items)
317318

318319

320+
async_activity_started = asyncio.Event()
321+
322+
323+
# Async activity completion test
324+
@activity.defn
325+
async def async_activity() -> TraceData:
326+
async_activity_started.set()
327+
activity.raise_complete_async()
328+
329+
330+
@workflow.defn
331+
class AsyncActivityCompletionSerializationContextTestWorkflow:
332+
@workflow.run
333+
async def run(self) -> TraceData:
334+
return await workflow.execute_activity(
335+
async_activity,
336+
start_to_close_timeout=timedelta(seconds=10),
337+
activity_id="async-activity-id",
338+
)
339+
340+
341+
async def test_async_activity_completion_payload_conversion(
342+
client: Client,
343+
):
344+
workflow_id = str(uuid.uuid4())
345+
task_queue = str(uuid.uuid4())
346+
347+
config = client.config()
348+
config["data_converter"] = data_converter
349+
client = Client(**config)
350+
351+
async with Worker(
352+
client,
353+
task_queue=task_queue,
354+
workflows=[AsyncActivityCompletionSerializationContextTestWorkflow],
355+
activities=[async_activity],
356+
workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance
357+
):
358+
wf_handle = await client.start_workflow(
359+
AsyncActivityCompletionSerializationContextTestWorkflow.run,
360+
id=workflow_id,
361+
task_queue=task_queue,
362+
)
363+
activity_handle = client.get_async_activity_handle(
364+
workflow_id=workflow_id,
365+
run_id=wf_handle.first_execution_run_id,
366+
activity_id="async-activity-id",
367+
)
368+
await async_activity_started.wait()
369+
data = TraceData()
370+
await activity_handle.heartbeat(data)
371+
await activity_handle.complete(data)
372+
result = await wf_handle.result()
373+
374+
# project down since activity completion by a client does not have access to most activity
375+
# context fields
376+
def project(trace_item: TraceItem) -> tuple[str, bool, str]:
377+
return (
378+
trace_item.context_type,
379+
trace_item.in_workflow,
380+
trace_item.method,
381+
)
382+
383+
assert [project(item) for item in result.items] == [
384+
(
385+
"activity",
386+
False,
387+
"to_payload", # Outbound activity input
388+
),
389+
(
390+
"activity",
391+
False,
392+
"to_payload", # Outbound activity heartbeat data
393+
),
394+
(
395+
"activity",
396+
False,
397+
"from_payload", # Inbound activity result
398+
),
399+
(
400+
"workflow",
401+
True,
402+
"to_payload", # Outbound workflow result
403+
),
404+
(
405+
"workflow",
406+
False,
407+
"from_payload", # Inbound workflow result
408+
),
409+
]
410+
411+
319412
# Signal test
320413

321414

@@ -596,9 +689,7 @@ async def test_update_payload_conversion(
596689

597690
def assert_trace(trace: list[TraceItem], expected: list[TraceItem]):
598691
if len(trace) != len(expected):
599-
raise AssertionError(
600-
f"expected {len(expected)} trace items but received {len(trace)}"
601-
)
692+
warn(f"expected {len(expected)} trace items but received {len(trace)}")
602693
history = []
603694
for item, expected_item in zip_longest(trace, expected):
604695
if item is None:

0 commit comments

Comments
 (0)