Skip to content

Commit 8701df9

Browse files
committed
Update query test
1 parent b05abf9 commit 8701df9

File tree

2 files changed

+42
-102
lines changed

2 files changed

+42
-102
lines changed

temporalio/worker/_workflow_instance.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,7 @@ async def run_query() -> None:
702702
namespace=self._info.namespace,
703703
workflow_id=self._info.workflow_id,
704704
)
705+
# TODO: why do we deserialize query input in workflow but not signal?
705706
args = self._process_handler_args(
706707
job.query_type,
707708
job.arguments,

tests/test_serialization_context.py

Lines changed: 41 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import asyncio
44
import dataclasses
55
import inspect
6-
import json
76
import uuid
87
from dataclasses import dataclass, field
98
from datetime import timedelta
@@ -402,106 +401,28 @@ async def test_signal_payload_conversion_can_be_given_access_to_serialization_co
402401
# Query test
403402

404403

405-
@dataclass
406-
class QueryData:
407-
query_context: Optional[WorkflowSerializationContext] = None
408-
value: str = ""
409-
410-
411404
@workflow.defn
412405
class QuerySerializationContextTestWorkflow:
413-
def __init__(self) -> None:
414-
self.state = QueryData(value="workflow-state")
415-
416406
@workflow.run
417407
async def run(self) -> None:
418-
# Keep workflow running
419-
await workflow.wait_condition(lambda: False)
408+
await asyncio.Event().wait()
420409

421410
@workflow.query
422-
def my_query(self) -> QueryData:
423-
return self.state
424-
425-
426-
class QuerySerializationContextTestEncodingPayloadConverter(
427-
EncodingPayloadConverter, WithSerializationContext
428-
):
429-
def __init__(self, context: Optional[SerializationContext] = None):
430-
self.context = context
431-
432-
@property
433-
def encoding(self) -> str:
434-
return "test-query-serialization-context"
435-
436-
def with_context(
437-
self, context: Optional[SerializationContext]
438-
) -> QuerySerializationContextTestEncodingPayloadConverter:
439-
return QuerySerializationContextTestEncodingPayloadConverter(context)
440-
441-
trace = [] # Class variable to capture serialization events
442-
443-
def to_payload(self, value: Any) -> Optional[Payload]:
444-
# Only handle QueryData objects
445-
if type(value).__name__ != "QueryData":
446-
return None
447-
448-
# Capture the context during serialization
449-
if self.context and isinstance(self.context, WorkflowSerializationContext):
450-
self.__class__.trace.append(
451-
{
452-
"operation": "query_result_serialization",
453-
"context": self.context,
454-
"value": value.value,
455-
}
456-
)
457-
458-
# Serialize as JSON
459-
data = {"value": value.value}
460-
return Payload(
461-
metadata={"encoding": self.encoding.encode()},
462-
data=json.dumps(data).encode(),
463-
)
464-
465-
def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any:
466-
data = json.loads(payload.data.decode())
467-
return QueryData(
468-
query_context=None, # Context is not transmitted, only captured during serialization
469-
value=data.get("value", ""),
470-
)
471-
472-
473-
class QuerySerializationContextTestPayloadConverter(
474-
CompositePayloadConverter, WithSerializationContext
475-
):
476-
def __init__(self, context: Optional[SerializationContext] = None):
477-
# Create converters with context
478-
converters = [
479-
QuerySerializationContextTestEncodingPayloadConverter(context),
480-
*DefaultPayloadConverter.default_encoding_payload_converters,
481-
]
482-
super().__init__(*converters)
483-
self.context = context
484-
485-
def with_context(
486-
self, context: Optional[SerializationContext]
487-
) -> QuerySerializationContextTestPayloadConverter:
488-
return QuerySerializationContextTestPayloadConverter(context)
411+
def my_query(self, input: TraceData) -> TraceData:
412+
return input
489413

490414

491415
async def test_query_payload_conversion_can_be_given_access_to_serialization_context(
492416
client: Client,
493417
):
494418
workflow_id = str(uuid.uuid4())
495419
task_queue = str(uuid.uuid4())
496-
# Clear the trace before starting QuerySerializationContextTestEncodingPayloadConverter.trace = []
497420

498-
# Create client with our custom data converter
499421
data_converter = dataclasses.replace(
500422
DataConverter.default,
501-
payload_converter_class=QuerySerializationContextTestPayloadConverter,
423+
payload_converter_class=SerializationContextTestPayloadConverter,
502424
)
503425

504-
# Create a new client with the custom data converter
505426
config = client.config()
506427
config["data_converter"] = data_converter
507428
custom_client = Client(**config)
@@ -511,34 +432,52 @@ async def test_query_payload_conversion_can_be_given_access_to_serialization_con
511432
task_queue=task_queue,
512433
workflows=[QuerySerializationContextTestWorkflow],
513434
activities=[],
435+
workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance
514436
):
515-
# Start the workflow
516437
handle = await custom_client.start_workflow(
517438
QuerySerializationContextTestWorkflow.run,
518439
id=workflow_id,
519440
task_queue=task_queue,
520441
)
442+
result = await handle.query(
443+
QuerySerializationContextTestWorkflow.my_query, TraceData()
444+
)
521445

522-
# Query the workflow
523-
result = await handle.query(QuerySerializationContextTestWorkflow.my_query)
524-
525-
# Verify the result value
526-
assert result.value == "workflow-state"
527-
528-
print(
529-
f"DEBUG: trace length = {len(QuerySerializationContextTestEncodingPayloadConverter.trace)}"
446+
workflow_context = dataclasses.asdict(
447+
WorkflowSerializationContext(
448+
namespace="default",
449+
workflow_id=workflow_id,
450+
)
530451
)
531-
assert len(QuerySerializationContextTestEncodingPayloadConverter.trace) > 0
532-
trace_entry = QuerySerializationContextTestEncodingPayloadConverter.trace[-1]
533-
assert trace_entry["operation"] == "query_result_serialization"
534-
assert trace_entry["context"] == WorkflowSerializationContext(
535-
namespace="default",
536-
workflow_id=workflow_id,
452+
assert_trace(
453+
result.items,
454+
[
455+
TraceItem(
456+
context_type="workflow",
457+
in_workflow=False,
458+
method="to_payload",
459+
context=workflow_context, # Outbound query input
460+
),
461+
TraceItem(
462+
context_type="workflow",
463+
in_workflow=True,
464+
method="from_payload",
465+
context=workflow_context, # Inbound query input
466+
),
467+
TraceItem(
468+
context_type="workflow",
469+
in_workflow=True,
470+
method="to_payload",
471+
context=workflow_context, # Outbound query result
472+
),
473+
TraceItem(
474+
context_type="workflow",
475+
in_workflow=False,
476+
method="from_payload",
477+
context=workflow_context, # Inbound query result
478+
),
479+
],
537480
)
538-
assert trace_entry["value"] == "workflow-state"
539-
540-
# Cancel the workflow to clean up
541-
await handle.cancel()
542481

543482

544483
# Utilities

0 commit comments

Comments
 (0)