Skip to content

Commit b05abf9

Browse files
committed
Update signal test
1 parent 75f558c commit b05abf9

File tree

1 file changed

+43
-98
lines changed

1 file changed

+43
-98
lines changed

tests/test_serialization_context.py

Lines changed: 43 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ def with_context(
102102
return converter
103103

104104
def to_payload(self, value: Any) -> Optional[Payload]:
105-
assert isinstance(value, TraceData)
105+
print(f"🌈 to_payload({isinstance(value, TraceData)}): {value}")
106+
if not isinstance(value, TraceData):
107+
return None
106108
if not self.context:
107109
raise Exception("Context is None")
108110
if isinstance(self.context, WorkflowSerializationContext):
@@ -134,6 +136,7 @@ def to_payload(self, value: Any) -> Optional[Payload]:
134136

135137
def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any:
136138
value = JSONPlainPayloadConverter().from_payload(payload, type_hint)
139+
print(f"🌈 from_payload({isinstance(value, TraceData)}): {value}")
137140
assert isinstance(value, TraceData)
138141
if not self.context:
139142
raise Exception("Context is None")
@@ -315,111 +318,28 @@ async def test_workflow_payload_conversion_can_be_given_access_to_serialization_
315318
# Signal test
316319

317320

318-
@dataclass
319-
class SignalData:
320-
signal_context: Optional[WorkflowSerializationContext] = None
321-
value: str = ""
322-
323-
324-
@workflow.defn
321+
@workflow.defn(sandboxed=False) # so that we can use isinstance
325322
class SignalSerializationContextTestWorkflow:
326323
def __init__(self) -> None:
327324
self.signal_received = None
328325

329326
@workflow.run
330-
async def run(self) -> SignalData:
327+
async def run(self) -> TraceData:
331328
await workflow.wait_condition(lambda: self.signal_received is not None)
332329
assert self.signal_received is not None
333330
return self.signal_received
334331

335332
@workflow.signal
336-
async def my_signal(self, data: SignalData) -> None:
333+
async def my_signal(self, data: TraceData) -> None:
337334
self.signal_received = data
338335

339336

340-
class SignalSerializationContextTestEncodingPayloadConverter(
341-
EncodingPayloadConverter, WithSerializationContext
342-
):
343-
def __init__(self, context: Optional[SerializationContext] = None):
344-
self.context = context
345-
346-
@property
347-
def encoding(self) -> str:
348-
return "test-signal-serialization-context"
349-
350-
def with_context(
351-
self, context: Optional[SerializationContext]
352-
) -> SignalSerializationContextTestEncodingPayloadConverter:
353-
return SignalSerializationContextTestEncodingPayloadConverter(context)
354-
355-
def to_payload(self, value: Any) -> Optional[Payload]:
356-
# Only handle SignalData objects
357-
if not isinstance(value, SignalData):
358-
return None
359-
360-
# Inject the context if it's a workflow context
361-
if isinstance(self.context, WorkflowSerializationContext):
362-
value.signal_context = self.context
363-
364-
# Serialize as JSON
365-
data = {
366-
"signal_context": (
367-
{
368-
"namespace": value.signal_context.namespace,
369-
"workflow_id": value.signal_context.workflow_id,
370-
}
371-
if value.signal_context
372-
else None
373-
),
374-
"value": value.value,
375-
}
376-
return Payload(
377-
metadata={"encoding": self.encoding.encode()},
378-
data=json.dumps(data).encode(),
379-
)
380-
381-
def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any:
382-
data = json.loads(payload.data.decode())
383-
ctx_data = data.get("signal_context")
384-
return SignalData(
385-
signal_context=(
386-
WorkflowSerializationContext(**ctx_data) if ctx_data else None
387-
),
388-
value=data.get("value", ""),
389-
)
390-
391-
392-
class SignalSerializationContextTestPayloadConverter(
393-
CompositePayloadConverter, WithSerializationContext
394-
):
395-
def __init__(self, context: Optional[SerializationContext] = None):
396-
# Create converters with context
397-
converters = [
398-
SignalSerializationContextTestEncodingPayloadConverter(context),
399-
*DefaultPayloadConverter.default_encoding_payload_converters,
400-
]
401-
super().__init__(*converters)
402-
self.context = context
403-
404-
def with_context(
405-
self, context: Optional[SerializationContext]
406-
) -> SignalSerializationContextTestPayloadConverter:
407-
return SignalSerializationContextTestPayloadConverter(context)
408-
409-
410337
async def test_signal_payload_conversion_can_be_given_access_to_serialization_context(
411338
client: Client,
412339
):
413340
workflow_id = str(uuid.uuid4())
414341
task_queue = str(uuid.uuid4())
415342

416-
# Create client with our custom data converter
417-
data_converter = dataclasses.replace(
418-
DataConverter.default,
419-
payload_converter_class=SignalSerializationContextTestPayloadConverter,
420-
)
421-
422-
# Create a new client with the custom data converter
423343
config = client.config()
424344
config["data_converter"] = data_converter
425345
custom_client = Client(**config)
@@ -429,29 +349,54 @@ async def test_signal_payload_conversion_can_be_given_access_to_serialization_co
429349
task_queue=task_queue,
430350
workflows=[SignalSerializationContextTestWorkflow],
431351
activities=[],
352+
workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance
432353
):
433-
# Start the workflow
434354
handle = await custom_client.start_workflow(
435355
SignalSerializationContextTestWorkflow.run,
436356
id=workflow_id,
437357
task_queue=task_queue,
438358
)
439-
440-
# Send a signal
441359
await handle.signal(
442360
SignalSerializationContextTestWorkflow.my_signal,
443-
SignalData(value="test-signal"),
361+
TraceData(),
444362
)
445-
446-
# Get the result
447363
result = await handle.result()
448364

449-
# Verify the signal context was injected
450-
assert result.signal_context == WorkflowSerializationContext(
451-
namespace="default",
452-
workflow_id=workflow_id,
365+
workflow_context = dataclasses.asdict(
366+
WorkflowSerializationContext(
367+
namespace="default",
368+
workflow_id=workflow_id,
369+
)
370+
)
371+
assert_trace(
372+
result.items,
373+
[
374+
TraceItem(
375+
context_type="workflow",
376+
in_workflow=False,
377+
method="to_payload",
378+
context=workflow_context, # Outbound signal input
379+
),
380+
TraceItem(
381+
context_type="workflow",
382+
in_workflow=False,
383+
method="from_payload",
384+
context=workflow_context, # Inbound signal input
385+
),
386+
TraceItem(
387+
context_type="workflow",
388+
in_workflow=True,
389+
method="to_payload",
390+
context=workflow_context, # Outbound workflow result
391+
),
392+
TraceItem(
393+
context_type="workflow",
394+
in_workflow=False,
395+
method="from_payload",
396+
context=workflow_context, # Inbound workflow result
397+
),
398+
],
453399
)
454-
assert result.value == "test-signal"
455400

456401

457402
# Query test

0 commit comments

Comments
 (0)