Skip to content

Commit 3af3c19

Browse files
committed
Test async activity completion
1 parent 553efca commit 3af3c19

File tree

1 file changed

+107
-3
lines changed

1 file changed

+107
-3
lines changed

tests/test_serialization_context.py

Lines changed: 107 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from itertools import zip_longest
1010
from pprint import pformat, pprint
1111
from typing import Any, Literal, Optional, Type
12+
from warnings import warn
1213

1314
from temporalio import activity, workflow
1415
from temporalio.api.common.v1 import Payload
@@ -314,6 +315,111 @@ async def test_workflow_payload_conversion(
314315
pprint(result.items)
315316

316317

318+
async_activity_started = asyncio.Event()
319+
320+
321+
# Async activity completion test
322+
@activity.defn
323+
async def async_activity() -> TraceData:
324+
async_activity_started.set()
325+
activity.raise_complete_async()
326+
327+
328+
@workflow.defn
329+
class AsyncActivityCompletionSerializationContextTestWorkflow:
330+
@workflow.run
331+
async def run(self) -> TraceData:
332+
return await workflow.execute_activity(
333+
async_activity,
334+
start_to_close_timeout=timedelta(seconds=10),
335+
activity_id="async-activity-id",
336+
)
337+
338+
339+
async def test_async_activity_completion_payload_conversion(
340+
client: Client,
341+
):
342+
workflow_id = str(uuid.uuid4())
343+
task_queue = str(uuid.uuid4())
344+
345+
config = client.config()
346+
config["data_converter"] = data_converter
347+
client = Client(**config)
348+
349+
async with Worker(
350+
client,
351+
task_queue=task_queue,
352+
workflows=[AsyncActivityCompletionSerializationContextTestWorkflow],
353+
activities=[async_activity],
354+
workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance
355+
):
356+
wf_handle = await client.start_workflow(
357+
AsyncActivityCompletionSerializationContextTestWorkflow.run,
358+
id=workflow_id,
359+
task_queue=task_queue,
360+
)
361+
activity_handle = client.get_async_activity_handle(
362+
workflow_id=workflow_id,
363+
run_id=wf_handle.first_execution_run_id,
364+
activity_id="async-activity-id",
365+
)
366+
await async_activity_started.wait()
367+
await activity_handle.complete(TraceData())
368+
result = await wf_handle.result()
369+
370+
workflow_context = dataclasses.asdict(
371+
WorkflowSerializationContext(
372+
namespace="default",
373+
workflow_id=workflow_id,
374+
)
375+
)
376+
activity_context = dataclasses.asdict(
377+
ActivitySerializationContext(
378+
namespace="default",
379+
workflow_id=workflow_id,
380+
workflow_type="AsyncActivityCompletionSerializationContextTestWorkflow",
381+
activity_type="async_activity",
382+
activity_task_queue=task_queue,
383+
is_local=False,
384+
)
385+
)
386+
assert_trace(
387+
result.items,
388+
[
389+
TraceItem(
390+
context_type="activity",
391+
in_workflow=False,
392+
method="to_payload",
393+
context=activity_context, # Outbound activity input
394+
),
395+
TraceItem(
396+
context_type="activity",
397+
in_workflow=False,
398+
method="from_payload",
399+
context=activity_context, # Inbound activity input
400+
),
401+
TraceItem(
402+
context_type="activity",
403+
in_workflow=False,
404+
method="to_payload",
405+
context=activity_context, # Outbound activity result
406+
),
407+
TraceItem(
408+
context_type="activity",
409+
in_workflow=False,
410+
method="from_payload",
411+
context=activity_context, # Inbound activity result
412+
),
413+
TraceItem(
414+
context_type="workflow",
415+
in_workflow=True,
416+
method="to_payload",
417+
context=workflow_context, # Inbound activity result
418+
),
419+
],
420+
)
421+
422+
317423
# Signal test
318424

319425

@@ -567,9 +673,7 @@ async def test_update_payload_conversion(
567673

568674
def assert_trace(trace: list[TraceItem], expected: list[TraceItem]):
569675
if len(trace) != len(expected):
570-
raise AssertionError(
571-
f"expected {len(expected)} trace items but received {len(trace)}"
572-
)
676+
warn(f"expected {len(expected)} trace items but received {len(trace)}")
573677
history = []
574678
for item, expected_item in zip_longest(trace, expected):
575679
if item is None:

0 commit comments

Comments
 (0)