|
3 | 3 | import asyncio |
4 | 4 | import dataclasses |
5 | 5 | import inspect |
| 6 | +import json |
6 | 7 | import uuid |
7 | 8 | from dataclasses import dataclass, field |
8 | 9 | from datetime import timedelta |
@@ -318,7 +319,7 @@ def assert_trace(trace: list[TraceItem], expected: list[TraceItem]): |
318 | 319 | raise AssertionError("More items in trace than expected") |
319 | 320 | if item != expected_item: |
320 | 321 | raise AssertionError( |
321 | | - f"Item:\n{pformat(item)}\n\ndoes not match expected:\n\n {pformat(expected_item)}.\n\n History:\n{'\n'.join(history)}" |
| 322 | + f"Item:\n{pformat(item)}\n\ndoes not match expected:\n\n {pformat(expected_item)}.\n\n History:\n{chr(10).join(history)}" |
322 | 323 | ) |
323 | 324 | history.append(f"{item.context_type} {item.method}") |
324 | 325 |
|
@@ -357,3 +358,145 @@ def get_caller_location() -> list[str]: |
357 | 358 | result.append("unknown:0") |
358 | 359 |
|
359 | 360 | return result |
| 361 | + |
| 362 | + |
| 363 | +# Signal test |
| 364 | + |
| 365 | + |
| 366 | +@dataclass |
| 367 | +class SignalData: |
| 368 | + signal_context: Optional[WorkflowSerializationContext] = None |
| 369 | + value: str = "" |
| 370 | + |
| 371 | + |
| 372 | +@workflow.defn |
| 373 | +class SignalSerializationContextTestWorkflow: |
| 374 | + def __init__(self) -> None: |
| 375 | + self.signal_received = None |
| 376 | + |
| 377 | + @workflow.run |
| 378 | + async def run(self) -> SignalData: |
| 379 | + await workflow.wait_condition(lambda: self.signal_received is not None) |
| 380 | + assert self.signal_received is not None |
| 381 | + return self.signal_received |
| 382 | + |
| 383 | + @workflow.signal |
| 384 | + async def my_signal(self, data: SignalData) -> None: |
| 385 | + self.signal_received = data |
| 386 | + |
| 387 | + |
| 388 | +class SignalSerializationContextTestEncodingPayloadConverter( |
| 389 | + EncodingPayloadConverter, WithSerializationContext |
| 390 | +): |
| 391 | + def __init__(self, context: Optional[SerializationContext] = None): |
| 392 | + self.context = context |
| 393 | + |
| 394 | + @property |
| 395 | + def encoding(self) -> str: |
| 396 | + return "test-signal-serialization-context" |
| 397 | + |
| 398 | + def with_context( |
| 399 | + self, context: Optional[SerializationContext] |
| 400 | + ) -> SignalSerializationContextTestEncodingPayloadConverter: |
| 401 | + return SignalSerializationContextTestEncodingPayloadConverter(context) |
| 402 | + |
| 403 | + def to_payload(self, value: Any) -> Optional[Payload]: |
| 404 | + # Only handle SignalData objects |
| 405 | + if not isinstance(value, SignalData): |
| 406 | + return None |
| 407 | + |
| 408 | + # Inject the context if it's a workflow context |
| 409 | + if isinstance(self.context, WorkflowSerializationContext): |
| 410 | + value.signal_context = self.context |
| 411 | + |
| 412 | + # Serialize as JSON |
| 413 | + data = { |
| 414 | + "signal_context": ( |
| 415 | + { |
| 416 | + "namespace": value.signal_context.namespace, |
| 417 | + "workflow_id": value.signal_context.workflow_id, |
| 418 | + } |
| 419 | + if value.signal_context |
| 420 | + else None |
| 421 | + ), |
| 422 | + "value": value.value, |
| 423 | + } |
| 424 | + return Payload( |
| 425 | + metadata={"encoding": self.encoding.encode()}, |
| 426 | + data=json.dumps(data).encode(), |
| 427 | + ) |
| 428 | + |
| 429 | + def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any: |
| 430 | + data = json.loads(payload.data.decode()) |
| 431 | + ctx_data = data.get("signal_context") |
| 432 | + return SignalData( |
| 433 | + signal_context=( |
| 434 | + WorkflowSerializationContext(**ctx_data) if ctx_data else None |
| 435 | + ), |
| 436 | + value=data.get("value", ""), |
| 437 | + ) |
| 438 | + |
| 439 | + |
| 440 | +class SignalSerializationContextTestPayloadConverter( |
| 441 | + CompositePayloadConverter, WithSerializationContext |
| 442 | +): |
| 443 | + def __init__(self, context: Optional[SerializationContext] = None): |
| 444 | + # Create converters with context |
| 445 | + converters = [ |
| 446 | + SignalSerializationContextTestEncodingPayloadConverter(context), |
| 447 | + *DefaultPayloadConverter.default_encoding_payload_converters, |
| 448 | + ] |
| 449 | + super().__init__(*converters) |
| 450 | + self.context = context |
| 451 | + |
| 452 | + def with_context( |
| 453 | + self, context: Optional[SerializationContext] |
| 454 | + ) -> SignalSerializationContextTestPayloadConverter: |
| 455 | + return SignalSerializationContextTestPayloadConverter(context) |
| 456 | + |
| 457 | + |
| 458 | +async def test_signal_payload_conversion_can_be_given_access_to_serialization_context( |
| 459 | + client: Client, |
| 460 | +): |
| 461 | + workflow_id = str(uuid.uuid4()) |
| 462 | + task_queue = str(uuid.uuid4()) |
| 463 | + |
| 464 | + # Create client with our custom data converter |
| 465 | + data_converter = dataclasses.replace( |
| 466 | + DataConverter.default, |
| 467 | + payload_converter_class=SignalSerializationContextTestPayloadConverter, |
| 468 | + ) |
| 469 | + |
| 470 | + # Create a new client with the custom data converter |
| 471 | + config = client.config() |
| 472 | + config["data_converter"] = data_converter |
| 473 | + custom_client = Client(**config) |
| 474 | + |
| 475 | + async with Worker( |
| 476 | + custom_client, |
| 477 | + task_queue=task_queue, |
| 478 | + workflows=[SignalSerializationContextTestWorkflow], |
| 479 | + activities=[], |
| 480 | + ): |
| 481 | + # Start the workflow |
| 482 | + handle = await custom_client.start_workflow( |
| 483 | + SignalSerializationContextTestWorkflow.run, |
| 484 | + id=workflow_id, |
| 485 | + task_queue=task_queue, |
| 486 | + ) |
| 487 | + |
| 488 | + # Send a signal |
| 489 | + await handle.signal( |
| 490 | + SignalSerializationContextTestWorkflow.my_signal, |
| 491 | + SignalData(value="test-signal"), |
| 492 | + ) |
| 493 | + |
| 494 | + # Get the result |
| 495 | + result = await handle.result() |
| 496 | + |
| 497 | + # Verify the signal context was injected |
| 498 | + assert result.signal_context == WorkflowSerializationContext( |
| 499 | + namespace="default", |
| 500 | + workflow_id=workflow_id, |
| 501 | + ) |
| 502 | + assert result.value == "test-signal" |
0 commit comments