|
16 | 16 | from temporalio import activity, workflow |
17 | 17 | from temporalio.api.common.v1 import Payload |
18 | 18 | from temporalio.client import Client, WorkflowUpdateFailedError |
| 19 | +from temporalio.common import RetryPolicy |
19 | 20 | from temporalio.converter import ( |
20 | 21 | ActivitySerializationContext, |
21 | 22 | CompositePayloadConverter, |
@@ -104,7 +105,6 @@ def with_context( |
104 | 105 | return converter |
105 | 106 |
|
106 | 107 | def to_payload(self, value: Any) -> Optional[Payload]: |
107 | | - print(f"🌈 to_payload({isinstance(value, TraceData)}): {value}") |
108 | 108 | if not isinstance(value, TraceData): |
109 | 109 | return None |
110 | 110 | if not self.context: |
@@ -137,8 +137,8 @@ def to_payload(self, value: Any) -> Optional[Payload]: |
137 | 137 | return payload |
138 | 138 |
|
139 | 139 | def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any: |
140 | | - value = JSONPlainPayloadConverter().from_payload(payload, type_hint) |
141 | | - print(f"🌈 from_payload({isinstance(value, TraceData)}): {value}") |
| 140 | + # Always deserialize as TraceData since that's what this converter handles |
| 141 | + value = JSONPlainPayloadConverter().from_payload(payload, TraceData) |
142 | 142 | assert isinstance(value, TraceData) |
143 | 143 | if not self.context: |
144 | 144 | raise Exception("Context is None") |
@@ -320,6 +320,94 @@ async def test_workflow_payload_conversion( |
320 | 320 | async_activity_started = asyncio.Event() |
321 | 321 |
|
322 | 322 |
|
| 323 | +# Activity with heartbeat details test |
| 324 | +@activity.defn |
| 325 | +async def activity_with_heartbeat_details() -> TraceData: |
| 326 | + """Activity that checks heartbeat details are decoded with proper context.""" |
| 327 | + info = activity.info() |
| 328 | + |
| 329 | + # If we have heartbeat details, it means we're resuming from a previous attempt |
| 330 | + if info.heartbeat_details: |
| 331 | + # The heartbeat details should be a TraceData that was decoded with activity context |
| 332 | + assert len(info.heartbeat_details) == 1 |
| 333 | + heartbeat_data = info.heartbeat_details[0] |
| 334 | + assert isinstance(heartbeat_data, TraceData) |
| 335 | + # Return the heartbeat data which should contain the decode trace |
| 336 | + return heartbeat_data |
| 337 | + |
| 338 | + # First attempt - heartbeat and then fail |
| 339 | + data = TraceData() |
| 340 | + activity.heartbeat(data) |
| 341 | + # Wait a bit to ensure heartbeat is recorded |
| 342 | + await asyncio.sleep(0.1) |
| 343 | + # Fail to trigger retry with heartbeat details |
| 344 | + raise Exception("Intentional failure to test heartbeat details") |
| 345 | + |
| 346 | + |
| 347 | +@workflow.defn |
| 348 | +class HeartbeatDetailsSerializationContextTestWorkflow: |
| 349 | + @workflow.run |
| 350 | + async def run(self) -> TraceData: |
| 351 | + return await workflow.execute_activity( |
| 352 | + activity_with_heartbeat_details, |
| 353 | + start_to_close_timeout=timedelta(seconds=10), |
| 354 | + retry_policy=RetryPolicy( |
| 355 | + initial_interval=timedelta(milliseconds=100), |
| 356 | + maximum_attempts=2, |
| 357 | + ), |
| 358 | + ) |
| 359 | + |
| 360 | + |
| 361 | +async def test_heartbeat_details_payload_conversion(client: Client): |
| 362 | + """Test that heartbeat details are decoded with activity context.""" |
| 363 | + workflow_id = str(uuid.uuid4()) |
| 364 | + task_queue = str(uuid.uuid4()) |
| 365 | + |
| 366 | + config = client.config() |
| 367 | + config["data_converter"] = data_converter |
| 368 | + client = Client(**config) |
| 369 | + |
| 370 | + async with Worker( |
| 371 | + client, |
| 372 | + task_queue=task_queue, |
| 373 | + workflows=[HeartbeatDetailsSerializationContextTestWorkflow], |
| 374 | + activities=[activity_with_heartbeat_details], |
| 375 | + workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance |
| 376 | + ): |
| 377 | + result = await client.execute_workflow( |
| 378 | + HeartbeatDetailsSerializationContextTestWorkflow.run, |
| 379 | + id=workflow_id, |
| 380 | + task_queue=task_queue, |
| 381 | + ) |
| 382 | + |
| 383 | + activity_context = dataclasses.asdict( |
| 384 | + ActivitySerializationContext( |
| 385 | + namespace="default", |
| 386 | + workflow_id=workflow_id, |
| 387 | + workflow_type="HeartbeatDetailsSerializationContextTestWorkflow", |
| 388 | + activity_type="activity_with_heartbeat_details", |
| 389 | + activity_task_queue=task_queue, |
| 390 | + is_local=False, |
| 391 | + ) |
| 392 | + ) |
| 393 | + |
| 394 | + # The result should contain the heartbeat data that was decoded with activity context |
| 395 | + # We expect to see the from_payload trace item for the heartbeat details |
| 396 | + # This test will FAIL until the bug is fixed |
| 397 | + found_heartbeat_decode = False |
| 398 | + for item in result.items: |
| 399 | + if ( |
| 400 | + item.context_type == "activity" |
| 401 | + and item.method == "from_payload" |
| 402 | + and item.in_workflow == False |
| 403 | + and item.context == activity_context |
| 404 | + ): |
| 405 | + found_heartbeat_decode = True |
| 406 | + break |
| 407 | + |
| 408 | + assert found_heartbeat_decode, "Heartbeat details should be decoded with activity context" |
| 409 | + |
| 410 | + |
323 | 411 | # Async activity completion test |
324 | 412 | @activity.defn |
325 | 413 | async def async_activity() -> TraceData: |
|
0 commit comments