@@ -411,7 +411,125 @@ async def test_heartbeat_details_payload_conversion(client: Client):
411411 ), "Heartbeat details should be decoded with activity context"
412412
413413
414+ # Local activity test
415+
416+
417+ @activity .defn
418+ async def local_activity (input : TraceData ) -> TraceData :
419+ return input
420+
421+
422+ @workflow .defn
423+ class LocalActivityWorkflow :
424+ @workflow .run
425+ async def run (self , data : TraceData ) -> TraceData :
426+ return await workflow .execute_local_activity (
427+ local_activity ,
428+ data ,
429+ start_to_close_timeout = timedelta (seconds = 10 ),
430+ )
431+
432+
433+ async def test_local_activity_payload_conversion (client : Client ):
434+ workflow_id = str (uuid .uuid4 ())
435+ task_queue = str (uuid .uuid4 ())
436+
437+ config = client .config ()
438+ config ["data_converter" ] = dataclasses .replace (
439+ DataConverter .default ,
440+ payload_converter_class = SerializationContextCompositePayloadConverter ,
441+ )
442+ client = Client (** config )
443+
444+ async with Worker (
445+ client ,
446+ task_queue = task_queue ,
447+ workflows = [LocalActivityWorkflow ],
448+ activities = [local_activity ],
449+ workflow_runner = UnsandboxedWorkflowRunner (), # so that we can use isinstance
450+ ):
451+ result = await client .execute_workflow (
452+ LocalActivityWorkflow .run ,
453+ TraceData (),
454+ id = workflow_id ,
455+ task_queue = task_queue ,
456+ )
457+
458+ workflow_context = dataclasses .asdict (
459+ WorkflowSerializationContext (
460+ namespace = "default" ,
461+ workflow_id = workflow_id ,
462+ )
463+ )
464+ local_activity_context = dataclasses .asdict (
465+ ActivitySerializationContext (
466+ namespace = "default" ,
467+ workflow_id = workflow_id ,
468+ workflow_type = LocalActivityWorkflow .__name__ ,
469+ activity_type = local_activity .__name__ ,
470+ activity_task_queue = task_queue ,
471+ is_local = True ,
472+ )
473+ )
474+
475+ assert (
476+ result .items
477+ == [
478+ TraceItem (
479+ context_type = "workflow" ,
480+ in_workflow = False ,
481+ method = "to_payload" ,
482+ context = workflow_context , # Outbound workflow input
483+ ),
484+ TraceItem (
485+ context_type = "workflow" ,
486+ in_workflow = False ,
487+ method = "from_payload" ,
488+ context = workflow_context , # Inbound workflow input
489+ ),
490+ TraceItem (
491+ context_type = "activity" ,
492+ in_workflow = True ,
493+ method = "to_payload" ,
494+ context = local_activity_context , # Outbound local activity input (is_local=True)
495+ ),
496+ TraceItem (
497+ context_type = "activity" ,
498+ in_workflow = False ,
499+ method = "from_payload" ,
500+ context = local_activity_context , # Inbound local activity input (is_local=True)
501+ ),
502+ TraceItem (
503+ context_type = "activity" ,
504+ in_workflow = False ,
505+ method = "to_payload" ,
506+ context = local_activity_context , # Outbound local activity result (is_local=True)
507+ ),
508+ TraceItem (
509+ context_type = "activity" ,
510+ in_workflow = False ,
511+ method = "from_payload" ,
512+ context = local_activity_context , # Inbound local activity result (is_local=True)
513+ ),
514+ TraceItem (
515+ context_type = "workflow" ,
516+ in_workflow = True ,
517+ method = "to_payload" ,
518+ context = workflow_context , # Outbound workflow result
519+ ),
520+ TraceItem (
521+ context_type = "workflow" ,
522+ in_workflow = False ,
523+ method = "from_payload" ,
524+ context = workflow_context , # Inbound workflow result
525+ ),
526+ ]
527+ )
528+
529+
414530# Async activity completion test
531+
532+
415533@activity .defn
416534async def async_activity () -> TraceData :
417535 # Signal that activity has started via heartbeat
@@ -1107,28 +1225,48 @@ def with_context(
11071225
11081226 async def encode (self , payloads : Sequence [Payload ]) -> List [Payload ]:
11091227 assert self .context
1110- assert isinstance (self .context , WorkflowSerializationContext )
1111- test_traces [self .context .workflow_id ].append (
1112- TraceItem (
1113- context_type = "workflow" ,
1114- context = dataclasses .asdict (self .context ),
1115- method = "encode" ,
1116- in_workflow = workflow .in_workflow (),
1228+ if isinstance (self .context , ActivitySerializationContext ):
1229+ test_traces [self .context .workflow_id ].append (
1230+ TraceItem (
1231+ context_type = "activity" ,
1232+ context = dataclasses .asdict (self .context ),
1233+ method = "encode" ,
1234+ in_workflow = workflow .in_workflow (),
1235+ )
1236+ )
1237+ else :
1238+ assert isinstance (self .context , WorkflowSerializationContext )
1239+ test_traces [self .context .workflow_id ].append (
1240+ TraceItem (
1241+ context_type = "workflow" ,
1242+ context = dataclasses .asdict (self .context ),
1243+ method = "encode" ,
1244+ in_workflow = workflow .in_workflow (),
1245+ )
11171246 )
1118- )
11191247 return list (payloads )
11201248
11211249 async def decode (self , payloads : Sequence [Payload ]) -> List [Payload ]:
11221250 assert self .context
1123- assert isinstance (self .context , WorkflowSerializationContext )
1124- test_traces [self .context .workflow_id ].append (
1125- TraceItem (
1126- context_type = "workflow" ,
1127- context = dataclasses .asdict (self .context ),
1128- method = "decode" ,
1129- in_workflow = workflow .in_workflow (),
1251+ if isinstance (self .context , ActivitySerializationContext ):
1252+ test_traces [self .context .workflow_id ].append (
1253+ TraceItem (
1254+ context_type = "activity" ,
1255+ context = dataclasses .asdict (self .context ),
1256+ method = "decode" ,
1257+ in_workflow = workflow .in_workflow (),
1258+ )
1259+ )
1260+ else :
1261+ assert isinstance (self .context , WorkflowSerializationContext )
1262+ test_traces [self .context .workflow_id ].append (
1263+ TraceItem (
1264+ context_type = "workflow" ,
1265+ context = dataclasses .asdict (self .context ),
1266+ method = "decode" ,
1267+ in_workflow = workflow .in_workflow (),
1268+ )
11301269 )
1131- )
11321270 return list (payloads )
11331271
11341272
@@ -1194,6 +1332,124 @@ async def test_codec_with_context(client: Client):
11941332 del test_traces [workflow_id ]
11951333
11961334
1335+ @activity .defn
1336+ async def codec_test_local_activity (data : str ) -> str :
1337+ return data
1338+
1339+
1340+ @workflow .defn
1341+ class LocalActivityCodecTestWorkflow :
1342+ @workflow .run
1343+ async def run (self , data : str ) -> str :
1344+ return await workflow .execute_local_activity (
1345+ codec_test_local_activity ,
1346+ data ,
1347+ start_to_close_timeout = timedelta (seconds = 10 ),
1348+ )
1349+
1350+
1351+ async def test_local_activity_codec_with_context (client : Client ):
1352+ """Test that codec gets correct context with is_local=True for local activities."""
1353+ workflow_id = str (uuid .uuid4 ())
1354+ task_queue = str (uuid .uuid4 ())
1355+
1356+ client_config = client .config ()
1357+ client_config ["data_converter" ] = dataclasses .replace (
1358+ DataConverter .default , payload_codec = PayloadCodecWithContext ()
1359+ )
1360+ client = Client (** client_config )
1361+ async with Worker (
1362+ client ,
1363+ task_queue = task_queue ,
1364+ workflows = [LocalActivityCodecTestWorkflow ],
1365+ activities = [codec_test_local_activity ],
1366+ ):
1367+ await client .execute_workflow (
1368+ LocalActivityCodecTestWorkflow .run ,
1369+ "data" ,
1370+ id = workflow_id ,
1371+ task_queue = task_queue ,
1372+ )
1373+
1374+ workflow_context = dataclasses .asdict (
1375+ WorkflowSerializationContext (
1376+ namespace = client .namespace ,
1377+ workflow_id = workflow_id ,
1378+ )
1379+ )
1380+ local_activity_context = dataclasses .asdict (
1381+ ActivitySerializationContext (
1382+ namespace = client .namespace ,
1383+ workflow_id = workflow_id ,
1384+ workflow_type = LocalActivityCodecTestWorkflow .__name__ ,
1385+ activity_type = codec_test_local_activity .__name__ ,
1386+ activity_task_queue = task_queue ,
1387+ is_local = True , # Should be True for local activities
1388+ )
1389+ )
1390+
1391+ # Note: Local activities have partial activity context support through codec
1392+ # The input encode uses workflow context, but the decode uses activity context
1393+ # The result encode uses activity context, but the decode uses workflow context
1394+ assert test_traces [workflow_id ] == [
1395+ # Workflow input
1396+ TraceItem (
1397+ context_type = "workflow" ,
1398+ context = workflow_context ,
1399+ method = "encode" ,
1400+ in_workflow = False ,
1401+ ),
1402+ TraceItem (
1403+ context_type = "workflow" ,
1404+ context = workflow_context ,
1405+ method = "decode" ,
1406+ in_workflow = False ,
1407+ ),
1408+ # Local activity input - encode uses workflow context
1409+ TraceItem (
1410+ context_type = "workflow" ,
1411+ context = workflow_context ,
1412+ method = "encode" ,
1413+ in_workflow = False ,
1414+ ),
1415+ # Local activity input - decode uses activity context with is_local=True
1416+ TraceItem (
1417+ context_type = "activity" ,
1418+ context = local_activity_context ,
1419+ method = "decode" ,
1420+ in_workflow = False ,
1421+ ),
1422+ # Local activity result - encode uses activity context with is_local=True
1423+ TraceItem (
1424+ context_type = "activity" ,
1425+ context = local_activity_context ,
1426+ method = "encode" ,
1427+ in_workflow = False ,
1428+ ),
1429+ # Local activity result - decode uses workflow context
1430+ TraceItem (
1431+ context_type = "workflow" ,
1432+ context = workflow_context ,
1433+ method = "decode" ,
1434+ in_workflow = False ,
1435+ ),
1436+ # Workflow result
1437+ TraceItem (
1438+ context_type = "workflow" ,
1439+ context = workflow_context ,
1440+ method = "encode" ,
1441+ in_workflow = False ,
1442+ ),
1443+ TraceItem (
1444+ context_type = "workflow" ,
1445+ context = workflow_context ,
1446+ method = "decode" ,
1447+ in_workflow = False ,
1448+ ),
1449+ ]
1450+ del test_traces [workflow_id ]
1451+
1452+
11971453# Pydantic
11981454
11991455
0 commit comments