Skip to content

Commit eb46cab

Browse files
committed
Test local activities
1 parent 96f76d8 commit eb46cab

File tree

1 file changed

+272
-16
lines changed

1 file changed

+272
-16
lines changed

tests/test_serialization_context.py

Lines changed: 272 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,125 @@ async def test_heartbeat_details_payload_conversion(client: Client):
411411
), "Heartbeat details should be decoded with activity context"
412412

413413

414+
# Local activity test
415+
416+
417+
@activity.defn
418+
async def local_activity(input: TraceData) -> TraceData:
419+
return input
420+
421+
422+
@workflow.defn
423+
class LocalActivityWorkflow:
424+
@workflow.run
425+
async def run(self, data: TraceData) -> TraceData:
426+
return await workflow.execute_local_activity(
427+
local_activity,
428+
data,
429+
start_to_close_timeout=timedelta(seconds=10),
430+
)
431+
432+
433+
async def test_local_activity_payload_conversion(client: Client):
434+
workflow_id = str(uuid.uuid4())
435+
task_queue = str(uuid.uuid4())
436+
437+
config = client.config()
438+
config["data_converter"] = dataclasses.replace(
439+
DataConverter.default,
440+
payload_converter_class=SerializationContextCompositePayloadConverter,
441+
)
442+
client = Client(**config)
443+
444+
async with Worker(
445+
client,
446+
task_queue=task_queue,
447+
workflows=[LocalActivityWorkflow],
448+
activities=[local_activity],
449+
workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance
450+
):
451+
result = await client.execute_workflow(
452+
LocalActivityWorkflow.run,
453+
TraceData(),
454+
id=workflow_id,
455+
task_queue=task_queue,
456+
)
457+
458+
workflow_context = dataclasses.asdict(
459+
WorkflowSerializationContext(
460+
namespace="default",
461+
workflow_id=workflow_id,
462+
)
463+
)
464+
local_activity_context = dataclasses.asdict(
465+
ActivitySerializationContext(
466+
namespace="default",
467+
workflow_id=workflow_id,
468+
workflow_type=LocalActivityWorkflow.__name__,
469+
activity_type=local_activity.__name__,
470+
activity_task_queue=task_queue,
471+
is_local=True,
472+
)
473+
)
474+
475+
assert (
476+
result.items
477+
== [
478+
TraceItem(
479+
context_type="workflow",
480+
in_workflow=False,
481+
method="to_payload",
482+
context=workflow_context, # Outbound workflow input
483+
),
484+
TraceItem(
485+
context_type="workflow",
486+
in_workflow=False,
487+
method="from_payload",
488+
context=workflow_context, # Inbound workflow input
489+
),
490+
TraceItem(
491+
context_type="activity",
492+
in_workflow=True,
493+
method="to_payload",
494+
context=local_activity_context, # Outbound local activity input (is_local=True)
495+
),
496+
TraceItem(
497+
context_type="activity",
498+
in_workflow=False,
499+
method="from_payload",
500+
context=local_activity_context, # Inbound local activity input (is_local=True)
501+
),
502+
TraceItem(
503+
context_type="activity",
504+
in_workflow=False,
505+
method="to_payload",
506+
context=local_activity_context, # Outbound local activity result (is_local=True)
507+
),
508+
TraceItem(
509+
context_type="activity",
510+
in_workflow=False,
511+
method="from_payload",
512+
context=local_activity_context, # Inbound local activity result (is_local=True)
513+
),
514+
TraceItem(
515+
context_type="workflow",
516+
in_workflow=True,
517+
method="to_payload",
518+
context=workflow_context, # Outbound workflow result
519+
),
520+
TraceItem(
521+
context_type="workflow",
522+
in_workflow=False,
523+
method="from_payload",
524+
context=workflow_context, # Inbound workflow result
525+
),
526+
]
527+
)
528+
529+
414530
# Async activity completion test
531+
532+
415533
@activity.defn
416534
async def async_activity() -> TraceData:
417535
# Signal that activity has started via heartbeat
@@ -1107,28 +1225,48 @@ def with_context(
11071225

11081226
async def encode(self, payloads: Sequence[Payload]) -> List[Payload]:
11091227
assert self.context
1110-
assert isinstance(self.context, WorkflowSerializationContext)
1111-
test_traces[self.context.workflow_id].append(
1112-
TraceItem(
1113-
context_type="workflow",
1114-
context=dataclasses.asdict(self.context),
1115-
method="encode",
1116-
in_workflow=workflow.in_workflow(),
1228+
if isinstance(self.context, ActivitySerializationContext):
1229+
test_traces[self.context.workflow_id].append(
1230+
TraceItem(
1231+
context_type="activity",
1232+
context=dataclasses.asdict(self.context),
1233+
method="encode",
1234+
in_workflow=workflow.in_workflow(),
1235+
)
1236+
)
1237+
else:
1238+
assert isinstance(self.context, WorkflowSerializationContext)
1239+
test_traces[self.context.workflow_id].append(
1240+
TraceItem(
1241+
context_type="workflow",
1242+
context=dataclasses.asdict(self.context),
1243+
method="encode",
1244+
in_workflow=workflow.in_workflow(),
1245+
)
11171246
)
1118-
)
11191247
return list(payloads)
11201248

11211249
async def decode(self, payloads: Sequence[Payload]) -> List[Payload]:
11221250
assert self.context
1123-
assert isinstance(self.context, WorkflowSerializationContext)
1124-
test_traces[self.context.workflow_id].append(
1125-
TraceItem(
1126-
context_type="workflow",
1127-
context=dataclasses.asdict(self.context),
1128-
method="decode",
1129-
in_workflow=workflow.in_workflow(),
1251+
if isinstance(self.context, ActivitySerializationContext):
1252+
test_traces[self.context.workflow_id].append(
1253+
TraceItem(
1254+
context_type="activity",
1255+
context=dataclasses.asdict(self.context),
1256+
method="decode",
1257+
in_workflow=workflow.in_workflow(),
1258+
)
1259+
)
1260+
else:
1261+
assert isinstance(self.context, WorkflowSerializationContext)
1262+
test_traces[self.context.workflow_id].append(
1263+
TraceItem(
1264+
context_type="workflow",
1265+
context=dataclasses.asdict(self.context),
1266+
method="decode",
1267+
in_workflow=workflow.in_workflow(),
1268+
)
11301269
)
1131-
)
11321270
return list(payloads)
11331271

11341272

@@ -1194,6 +1332,124 @@ async def test_codec_with_context(client: Client):
11941332
del test_traces[workflow_id]
11951333

11961334

1335+
@activity.defn
1336+
async def codec_test_local_activity(data: str) -> str:
1337+
return data
1338+
1339+
1340+
@workflow.defn
1341+
class LocalActivityCodecTestWorkflow:
1342+
@workflow.run
1343+
async def run(self, data: str) -> str:
1344+
return await workflow.execute_local_activity(
1345+
codec_test_local_activity,
1346+
data,
1347+
start_to_close_timeout=timedelta(seconds=10),
1348+
)
1349+
1350+
1351+
async def test_local_activity_codec_with_context(client: Client):
1352+
"""Test that codec gets correct context with is_local=True for local activities."""
1353+
workflow_id = str(uuid.uuid4())
1354+
task_queue = str(uuid.uuid4())
1355+
1356+
client_config = client.config()
1357+
client_config["data_converter"] = dataclasses.replace(
1358+
DataConverter.default, payload_codec=PayloadCodecWithContext()
1359+
)
1360+
client = Client(**client_config)
1361+
async with Worker(
1362+
client,
1363+
task_queue=task_queue,
1364+
workflows=[LocalActivityCodecTestWorkflow],
1365+
activities=[codec_test_local_activity],
1366+
):
1367+
await client.execute_workflow(
1368+
LocalActivityCodecTestWorkflow.run,
1369+
"data",
1370+
id=workflow_id,
1371+
task_queue=task_queue,
1372+
)
1373+
1374+
workflow_context = dataclasses.asdict(
1375+
WorkflowSerializationContext(
1376+
namespace=client.namespace,
1377+
workflow_id=workflow_id,
1378+
)
1379+
)
1380+
local_activity_context = dataclasses.asdict(
1381+
ActivitySerializationContext(
1382+
namespace=client.namespace,
1383+
workflow_id=workflow_id,
1384+
workflow_type=LocalActivityCodecTestWorkflow.__name__,
1385+
activity_type=codec_test_local_activity.__name__,
1386+
activity_task_queue=task_queue,
1387+
is_local=True, # Should be True for local activities
1388+
)
1389+
)
1390+
1391+
# Note: Local activities have partial activity context support through codec
1392+
# The input encode uses workflow context, but the decode uses activity context
1393+
# The result encode uses activity context, but the decode uses workflow context
1394+
assert test_traces[workflow_id] == [
1395+
# Workflow input
1396+
TraceItem(
1397+
context_type="workflow",
1398+
context=workflow_context,
1399+
method="encode",
1400+
in_workflow=False,
1401+
),
1402+
TraceItem(
1403+
context_type="workflow",
1404+
context=workflow_context,
1405+
method="decode",
1406+
in_workflow=False,
1407+
),
1408+
# Local activity input - encode uses workflow context
1409+
TraceItem(
1410+
context_type="workflow",
1411+
context=workflow_context,
1412+
method="encode",
1413+
in_workflow=False,
1414+
),
1415+
# Local activity input - decode uses activity context with is_local=True
1416+
TraceItem(
1417+
context_type="activity",
1418+
context=local_activity_context,
1419+
method="decode",
1420+
in_workflow=False,
1421+
),
1422+
# Local activity result - encode uses activity context with is_local=True
1423+
TraceItem(
1424+
context_type="activity",
1425+
context=local_activity_context,
1426+
method="encode",
1427+
in_workflow=False,
1428+
),
1429+
# Local activity result - decode uses workflow context
1430+
TraceItem(
1431+
context_type="workflow",
1432+
context=workflow_context,
1433+
method="decode",
1434+
in_workflow=False,
1435+
),
1436+
# Workflow result
1437+
TraceItem(
1438+
context_type="workflow",
1439+
context=workflow_context,
1440+
method="encode",
1441+
in_workflow=False,
1442+
),
1443+
TraceItem(
1444+
context_type="workflow",
1445+
context=workflow_context,
1446+
method="decode",
1447+
in_workflow=False,
1448+
),
1449+
]
1450+
del test_traces[workflow_id]
1451+
1452+
11971453
# Pydantic
11981454

11991455

0 commit comments

Comments
 (0)