@@ -500,3 +500,145 @@ async def test_signal_payload_conversion_can_be_given_access_to_serialization_co
500500 workflow_id = workflow_id ,
501501 )
502502 assert result .value == "test-signal"
503+
504+
505+ # Query test
506+
507+
508+ @dataclass
509+ class QueryData :
510+ query_context : Optional [WorkflowSerializationContext ] = None
511+ value : str = ""
512+
513+
514+ @workflow .defn
515+ class QuerySerializationContextTestWorkflow :
516+ def __init__ (self ) -> None :
517+ self .state = QueryData (value = "workflow-state" )
518+
519+ @workflow .run
520+ async def run (self ) -> None :
521+ # Keep workflow running
522+ await workflow .wait_condition (lambda : False )
523+
524+ @workflow .query
525+ def my_query (self ) -> QueryData :
526+ return self .state
527+
528+
529+ class QuerySerializationContextTestEncodingPayloadConverter (
530+ EncodingPayloadConverter , WithSerializationContext
531+ ):
532+ def __init__ (self , context : Optional [SerializationContext ] = None ):
533+ self .context = context
534+
535+ @property
536+ def encoding (self ) -> str :
537+ return "test-query-serialization-context"
538+
539+ def with_context (
540+ self , context : Optional [SerializationContext ]
541+ ) -> QuerySerializationContextTestEncodingPayloadConverter :
542+ return QuerySerializationContextTestEncodingPayloadConverter (context )
543+
544+ trace = [] # Class variable to capture serialization events
545+
546+ def to_payload (self , value : Any ) -> Optional [Payload ]:
547+ # Only handle QueryData objects
548+ if type (value ).__name__ != "QueryData" :
549+ return None
550+
551+ # Capture the context during serialization
552+ if self .context and isinstance (self .context , WorkflowSerializationContext ):
553+ self .__class__ .trace .append (
554+ {
555+ "operation" : "query_result_serialization" ,
556+ "context" : self .context ,
557+ "value" : value .value ,
558+ }
559+ )
560+
561+ # Serialize as JSON
562+ data = {"value" : value .value }
563+ return Payload (
564+ metadata = {"encoding" : self .encoding .encode ()},
565+ data = json .dumps (data ).encode (),
566+ )
567+
568+ def from_payload (self , payload : Payload , type_hint : Optional [Type ] = None ) -> Any :
569+ data = json .loads (payload .data .decode ())
570+ return QueryData (
571+ query_context = None , # Context is not transmitted, only captured during serialization
572+ value = data .get ("value" , "" ),
573+ )
574+
575+
576+ class QuerySerializationContextTestPayloadConverter (
577+ CompositePayloadConverter , WithSerializationContext
578+ ):
579+ def __init__ (self , context : Optional [SerializationContext ] = None ):
580+ # Create converters with context
581+ converters = [
582+ QuerySerializationContextTestEncodingPayloadConverter (context ),
583+ * DefaultPayloadConverter .default_encoding_payload_converters ,
584+ ]
585+ super ().__init__ (* converters )
586+ self .context = context
587+
588+ def with_context (
589+ self , context : Optional [SerializationContext ]
590+ ) -> QuerySerializationContextTestPayloadConverter :
591+ return QuerySerializationContextTestPayloadConverter (context )
592+
593+
594+ async def test_query_payload_conversion_can_be_given_access_to_serialization_context (
595+ client : Client ,
596+ ):
597+ workflow_id = str (uuid .uuid4 ())
598+ task_queue = str (uuid .uuid4 ())
599+ # Clear the trace before starting QuerySerializationContextTestEncodingPayloadConverter.trace = []
600+
601+ # Create client with our custom data converter
602+ data_converter = dataclasses .replace (
603+ DataConverter .default ,
604+ payload_converter_class = QuerySerializationContextTestPayloadConverter ,
605+ )
606+
607+ # Create a new client with the custom data converter
608+ config = client .config ()
609+ config ["data_converter" ] = data_converter
610+ custom_client = Client (** config )
611+
612+ async with Worker (
613+ custom_client ,
614+ task_queue = task_queue ,
615+ workflows = [QuerySerializationContextTestWorkflow ],
616+ activities = [],
617+ ):
618+ # Start the workflow
619+ handle = await custom_client .start_workflow (
620+ QuerySerializationContextTestWorkflow .run ,
621+ id = workflow_id ,
622+ task_queue = task_queue ,
623+ )
624+
625+ # Query the workflow
626+ result = await handle .query (QuerySerializationContextTestWorkflow .my_query )
627+
628+ # Verify the result value
629+ assert result .value == "workflow-state"
630+
631+ print (
632+ f"DEBUG: trace length = { len (QuerySerializationContextTestEncodingPayloadConverter .trace )} "
633+ )
634+ assert len (QuerySerializationContextTestEncodingPayloadConverter .trace ) > 0
635+ trace_entry = QuerySerializationContextTestEncodingPayloadConverter .trace [- 1 ]
636+ assert trace_entry ["operation" ] == "query_result_serialization"
637+ assert trace_entry ["context" ] == WorkflowSerializationContext (
638+ namespace = "default" ,
639+ workflow_id = workflow_id ,
640+ )
641+ assert trace_entry ["value" ] == "workflow-state"
642+
643+ # Cancel the workflow to clean up
644+ await handle .cancel ()
0 commit comments