Skip to content

Commit 12b16a3

Browse files
committed
Add failing test for query serialization context
1 parent 48fb29d commit 12b16a3

File tree

1 file changed

+142
-0
lines changed

1 file changed

+142
-0
lines changed

tests/test_serialization_context.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,3 +500,145 @@ async def test_signal_payload_conversion_can_be_given_access_to_serialization_co
500500
workflow_id=workflow_id,
501501
)
502502
assert result.value == "test-signal"
503+
504+
505+
# Query test
506+
507+
508+
@dataclass
509+
class QueryData:
510+
query_context: Optional[WorkflowSerializationContext] = None
511+
value: str = ""
512+
513+
514+
@workflow.defn
515+
class QuerySerializationContextTestWorkflow:
516+
def __init__(self) -> None:
517+
self.state = QueryData(value="workflow-state")
518+
519+
@workflow.run
520+
async def run(self) -> None:
521+
# Keep workflow running
522+
await workflow.wait_condition(lambda: False)
523+
524+
@workflow.query
525+
def my_query(self) -> QueryData:
526+
return self.state
527+
528+
529+
class QuerySerializationContextTestEncodingPayloadConverter(
530+
EncodingPayloadConverter, WithSerializationContext
531+
):
532+
def __init__(self, context: Optional[SerializationContext] = None):
533+
self.context = context
534+
535+
@property
536+
def encoding(self) -> str:
537+
return "test-query-serialization-context"
538+
539+
def with_context(
540+
self, context: Optional[SerializationContext]
541+
) -> QuerySerializationContextTestEncodingPayloadConverter:
542+
return QuerySerializationContextTestEncodingPayloadConverter(context)
543+
544+
trace = [] # Class variable to capture serialization events
545+
546+
def to_payload(self, value: Any) -> Optional[Payload]:
547+
# Only handle QueryData objects
548+
if type(value).__name__ != "QueryData":
549+
return None
550+
551+
# Capture the context during serialization
552+
if self.context and isinstance(self.context, WorkflowSerializationContext):
553+
self.__class__.trace.append(
554+
{
555+
"operation": "query_result_serialization",
556+
"context": self.context,
557+
"value": value.value,
558+
}
559+
)
560+
561+
# Serialize as JSON
562+
data = {"value": value.value}
563+
return Payload(
564+
metadata={"encoding": self.encoding.encode()},
565+
data=json.dumps(data).encode(),
566+
)
567+
568+
def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any:
569+
data = json.loads(payload.data.decode())
570+
return QueryData(
571+
query_context=None, # Context is not transmitted, only captured during serialization
572+
value=data.get("value", ""),
573+
)
574+
575+
576+
class QuerySerializationContextTestPayloadConverter(
577+
CompositePayloadConverter, WithSerializationContext
578+
):
579+
def __init__(self, context: Optional[SerializationContext] = None):
580+
# Create converters with context
581+
converters = [
582+
QuerySerializationContextTestEncodingPayloadConverter(context),
583+
*DefaultPayloadConverter.default_encoding_payload_converters,
584+
]
585+
super().__init__(*converters)
586+
self.context = context
587+
588+
def with_context(
589+
self, context: Optional[SerializationContext]
590+
) -> QuerySerializationContextTestPayloadConverter:
591+
return QuerySerializationContextTestPayloadConverter(context)
592+
593+
594+
async def test_query_payload_conversion_can_be_given_access_to_serialization_context(
595+
client: Client,
596+
):
597+
workflow_id = str(uuid.uuid4())
598+
task_queue = str(uuid.uuid4())
599+
# Clear the trace before starting QuerySerializationContextTestEncodingPayloadConverter.trace = []
600+
601+
# Create client with our custom data converter
602+
data_converter = dataclasses.replace(
603+
DataConverter.default,
604+
payload_converter_class=QuerySerializationContextTestPayloadConverter,
605+
)
606+
607+
# Create a new client with the custom data converter
608+
config = client.config()
609+
config["data_converter"] = data_converter
610+
custom_client = Client(**config)
611+
612+
async with Worker(
613+
custom_client,
614+
task_queue=task_queue,
615+
workflows=[QuerySerializationContextTestWorkflow],
616+
activities=[],
617+
):
618+
# Start the workflow
619+
handle = await custom_client.start_workflow(
620+
QuerySerializationContextTestWorkflow.run,
621+
id=workflow_id,
622+
task_queue=task_queue,
623+
)
624+
625+
# Query the workflow
626+
result = await handle.query(QuerySerializationContextTestWorkflow.my_query)
627+
628+
# Verify the result value
629+
assert result.value == "workflow-state"
630+
631+
print(
632+
f"DEBUG: trace length = {len(QuerySerializationContextTestEncodingPayloadConverter.trace)}"
633+
)
634+
assert len(QuerySerializationContextTestEncodingPayloadConverter.trace) > 0
635+
trace_entry = QuerySerializationContextTestEncodingPayloadConverter.trace[-1]
636+
assert trace_entry["operation"] == "query_result_serialization"
637+
assert trace_entry["context"] == WorkflowSerializationContext(
638+
namespace="default",
639+
workflow_id=workflow_id,
640+
)
641+
assert trace_entry["value"] == "workflow-state"
642+
643+
# Cancel the workflow to clean up
644+
await handle.cancel()

0 commit comments

Comments
 (0)