Skip to content

Commit 3235a40

Browse files
committed
Add failing test for signal serialization context
1 parent 5a9eb2b commit 3235a40

File tree

1 file changed

+144
-1
lines changed

1 file changed

+144
-1
lines changed

tests/test_serialization_context.py

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import asyncio
44
import dataclasses
55
import inspect
6+
import json
67
import uuid
78
from dataclasses import dataclass, field
89
from datetime import timedelta
@@ -318,7 +319,7 @@ def assert_trace(trace: list[TraceItem], expected: list[TraceItem]):
318319
raise AssertionError("More items in trace than expected")
319320
if item != expected_item:
320321
raise AssertionError(
321-
f"Item:\n{pformat(item)}\n\ndoes not match expected:\n\n {pformat(expected_item)}.\n\n History:\n{'\n'.join(history)}"
322+
f"Item:\n{pformat(item)}\n\ndoes not match expected:\n\n {pformat(expected_item)}.\n\n History:\n{chr(10).join(history)}"
322323
)
323324
history.append(f"{item.context_type} {item.method}")
324325

@@ -357,3 +358,145 @@ def get_caller_location() -> list[str]:
357358
result.append("unknown:0")
358359

359360
return result
361+
362+
363+
# Signal test
364+
365+
366+
@dataclass
367+
class SignalData:
368+
signal_context: Optional[WorkflowSerializationContext] = None
369+
value: str = ""
370+
371+
372+
@workflow.defn
373+
class SignalSerializationContextTestWorkflow:
374+
def __init__(self) -> None:
375+
self.signal_received = None
376+
377+
@workflow.run
378+
async def run(self) -> SignalData:
379+
await workflow.wait_condition(lambda: self.signal_received is not None)
380+
assert self.signal_received is not None
381+
return self.signal_received
382+
383+
@workflow.signal
384+
async def my_signal(self, data: SignalData) -> None:
385+
self.signal_received = data
386+
387+
388+
class SignalSerializationContextTestEncodingPayloadConverter(
389+
EncodingPayloadConverter, WithSerializationContext
390+
):
391+
def __init__(self, context: Optional[SerializationContext] = None):
392+
self.context = context
393+
394+
@property
395+
def encoding(self) -> str:
396+
return "test-signal-serialization-context"
397+
398+
def with_context(
399+
self, context: Optional[SerializationContext]
400+
) -> SignalSerializationContextTestEncodingPayloadConverter:
401+
return SignalSerializationContextTestEncodingPayloadConverter(context)
402+
403+
def to_payload(self, value: Any) -> Optional[Payload]:
404+
# Only handle SignalData objects
405+
if not isinstance(value, SignalData):
406+
return None
407+
408+
# Inject the context if it's a workflow context
409+
if isinstance(self.context, WorkflowSerializationContext):
410+
value.signal_context = self.context
411+
412+
# Serialize as JSON
413+
data = {
414+
"signal_context": (
415+
{
416+
"namespace": value.signal_context.namespace,
417+
"workflow_id": value.signal_context.workflow_id,
418+
}
419+
if value.signal_context
420+
else None
421+
),
422+
"value": value.value,
423+
}
424+
return Payload(
425+
metadata={"encoding": self.encoding.encode()},
426+
data=json.dumps(data).encode(),
427+
)
428+
429+
def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any:
430+
data = json.loads(payload.data.decode())
431+
ctx_data = data.get("signal_context")
432+
return SignalData(
433+
signal_context=(
434+
WorkflowSerializationContext(**ctx_data) if ctx_data else None
435+
),
436+
value=data.get("value", ""),
437+
)
438+
439+
440+
class SignalSerializationContextTestPayloadConverter(
441+
CompositePayloadConverter, WithSerializationContext
442+
):
443+
def __init__(self, context: Optional[SerializationContext] = None):
444+
# Create converters with context
445+
converters = [
446+
SignalSerializationContextTestEncodingPayloadConverter(context),
447+
*DefaultPayloadConverter.default_encoding_payload_converters,
448+
]
449+
super().__init__(*converters)
450+
self.context = context
451+
452+
def with_context(
453+
self, context: Optional[SerializationContext]
454+
) -> SignalSerializationContextTestPayloadConverter:
455+
return SignalSerializationContextTestPayloadConverter(context)
456+
457+
458+
async def test_signal_payload_conversion_can_be_given_access_to_serialization_context(
459+
client: Client,
460+
):
461+
workflow_id = str(uuid.uuid4())
462+
task_queue = str(uuid.uuid4())
463+
464+
# Create client with our custom data converter
465+
data_converter = dataclasses.replace(
466+
DataConverter.default,
467+
payload_converter_class=SignalSerializationContextTestPayloadConverter,
468+
)
469+
470+
# Create a new client with the custom data converter
471+
config = client.config()
472+
config["data_converter"] = data_converter
473+
custom_client = Client(**config)
474+
475+
async with Worker(
476+
custom_client,
477+
task_queue=task_queue,
478+
workflows=[SignalSerializationContextTestWorkflow],
479+
activities=[],
480+
):
481+
# Start the workflow
482+
handle = await custom_client.start_workflow(
483+
SignalSerializationContextTestWorkflow.run,
484+
id=workflow_id,
485+
task_queue=task_queue,
486+
)
487+
488+
# Send a signal
489+
await handle.signal(
490+
SignalSerializationContextTestWorkflow.my_signal,
491+
SignalData(value="test-signal"),
492+
)
493+
494+
# Get the result
495+
result = await handle.result()
496+
497+
# Verify the signal context was injected
498+
assert result.signal_context == WorkflowSerializationContext(
499+
namespace="default",
500+
workflow_id=workflow_id,
501+
)
502+
assert result.value == "test-signal"

0 commit comments

Comments
 (0)