33import asyncio
44import dataclasses
55import inspect
6- import json
76import uuid
87from dataclasses import dataclass , field
98from datetime import timedelta
@@ -402,106 +401,28 @@ async def test_signal_payload_conversion_can_be_given_access_to_serialization_co
402401# Query test
403402
404403
405- @dataclass
406- class QueryData :
407- query_context : Optional [WorkflowSerializationContext ] = None
408- value : str = ""
409-
410-
411404@workflow .defn
412405class QuerySerializationContextTestWorkflow :
413- def __init__ (self ) -> None :
414- self .state = QueryData (value = "workflow-state" )
415-
416406 @workflow .run
417407 async def run (self ) -> None :
418- # Keep workflow running
419- await workflow .wait_condition (lambda : False )
408+ await asyncio .Event ().wait ()
420409
421410 @workflow .query
422- def my_query (self ) -> QueryData :
423- return self .state
424-
425-
426- class QuerySerializationContextTestEncodingPayloadConverter (
427- EncodingPayloadConverter , WithSerializationContext
428- ):
429- def __init__ (self , context : Optional [SerializationContext ] = None ):
430- self .context = context
431-
432- @property
433- def encoding (self ) -> str :
434- return "test-query-serialization-context"
435-
436- def with_context (
437- self , context : Optional [SerializationContext ]
438- ) -> QuerySerializationContextTestEncodingPayloadConverter :
439- return QuerySerializationContextTestEncodingPayloadConverter (context )
440-
441- trace = [] # Class variable to capture serialization events
442-
443- def to_payload (self , value : Any ) -> Optional [Payload ]:
444- # Only handle QueryData objects
445- if type (value ).__name__ != "QueryData" :
446- return None
447-
448- # Capture the context during serialization
449- if self .context and isinstance (self .context , WorkflowSerializationContext ):
450- self .__class__ .trace .append (
451- {
452- "operation" : "query_result_serialization" ,
453- "context" : self .context ,
454- "value" : value .value ,
455- }
456- )
457-
458- # Serialize as JSON
459- data = {"value" : value .value }
460- return Payload (
461- metadata = {"encoding" : self .encoding .encode ()},
462- data = json .dumps (data ).encode (),
463- )
464-
465- def from_payload (self , payload : Payload , type_hint : Optional [Type ] = None ) -> Any :
466- data = json .loads (payload .data .decode ())
467- return QueryData (
468- query_context = None , # Context is not transmitted, only captured during serialization
469- value = data .get ("value" , "" ),
470- )
471-
472-
473- class QuerySerializationContextTestPayloadConverter (
474- CompositePayloadConverter , WithSerializationContext
475- ):
476- def __init__ (self , context : Optional [SerializationContext ] = None ):
477- # Create converters with context
478- converters = [
479- QuerySerializationContextTestEncodingPayloadConverter (context ),
480- * DefaultPayloadConverter .default_encoding_payload_converters ,
481- ]
482- super ().__init__ (* converters )
483- self .context = context
484-
485- def with_context (
486- self , context : Optional [SerializationContext ]
487- ) -> QuerySerializationContextTestPayloadConverter :
488- return QuerySerializationContextTestPayloadConverter (context )
411+ def my_query (self , input : TraceData ) -> TraceData :
412+ return input
489413
490414
491415async def test_query_payload_conversion_can_be_given_access_to_serialization_context (
492416 client : Client ,
493417):
494418 workflow_id = str (uuid .uuid4 ())
495419 task_queue = str (uuid .uuid4 ())
496- # Clear the trace before starting QuerySerializationContextTestEncodingPayloadConverter.trace = []
497420
498- # Create client with our custom data converter
499421 data_converter = dataclasses .replace (
500422 DataConverter .default ,
501- payload_converter_class = QuerySerializationContextTestPayloadConverter ,
423+ payload_converter_class = SerializationContextTestPayloadConverter ,
502424 )
503425
504- # Create a new client with the custom data converter
505426 config = client .config ()
506427 config ["data_converter" ] = data_converter
507428 custom_client = Client (** config )
@@ -511,34 +432,52 @@ async def test_query_payload_conversion_can_be_given_access_to_serialization_con
511432 task_queue = task_queue ,
512433 workflows = [QuerySerializationContextTestWorkflow ],
513434 activities = [],
435+ workflow_runner = UnsandboxedWorkflowRunner (), # so that we can use isinstance
514436 ):
515- # Start the workflow
516437 handle = await custom_client .start_workflow (
517438 QuerySerializationContextTestWorkflow .run ,
518439 id = workflow_id ,
519440 task_queue = task_queue ,
520441 )
442+ result = await handle .query (
443+ QuerySerializationContextTestWorkflow .my_query , TraceData ()
444+ )
521445
522- # Query the workflow
523- result = await handle .query (QuerySerializationContextTestWorkflow .my_query )
524-
525- # Verify the result value
526- assert result .value == "workflow-state"
527-
528- print (
529- f"DEBUG: trace length = { len (QuerySerializationContextTestEncodingPayloadConverter .trace )} "
446+ workflow_context = dataclasses .asdict (
447+ WorkflowSerializationContext (
448+ namespace = "default" ,
449+ workflow_id = workflow_id ,
450+ )
530451 )
531- assert len (QuerySerializationContextTestEncodingPayloadConverter .trace ) > 0
532- trace_entry = QuerySerializationContextTestEncodingPayloadConverter .trace [- 1 ]
533- assert trace_entry ["operation" ] == "query_result_serialization"
534- assert trace_entry ["context" ] == WorkflowSerializationContext (
535- namespace = "default" ,
536- workflow_id = workflow_id ,
452+ assert_trace (
453+ result .items ,
454+ [
455+ TraceItem (
456+ context_type = "workflow" ,
457+ in_workflow = False ,
458+ method = "to_payload" ,
459+ context = workflow_context , # Outbound query input
460+ ),
461+ TraceItem (
462+ context_type = "workflow" ,
463+ in_workflow = True ,
464+ method = "from_payload" ,
465+ context = workflow_context , # Inbound query input
466+ ),
467+ TraceItem (
468+ context_type = "workflow" ,
469+ in_workflow = True ,
470+ method = "to_payload" ,
471+ context = workflow_context , # Outbound query result
472+ ),
473+ TraceItem (
474+ context_type = "workflow" ,
475+ in_workflow = False ,
476+ method = "from_payload" ,
477+ context = workflow_context , # Inbound query result
478+ ),
479+ ],
537480 )
538- assert trace_entry ["value" ] == "workflow-state"
539-
540- # Cancel the workflow to clean up
541- await handle .cancel ()
542481
543482
544483# Utilities
0 commit comments