|
26 | 26 | WorkflowSerializationContext, |
27 | 27 | ) |
28 | 28 | from temporalio.worker import Worker |
| 29 | +from temporalio.worker._workflow_instance import UnsandboxedWorkflowRunner |
29 | 30 |
|
30 | 31 |
|
31 | 32 | @dataclass |
@@ -60,14 +61,14 @@ async def passthrough_activity(input: TraceData) -> TraceData: |
60 | 61 | return input |
61 | 62 |
|
62 | 63 |
|
63 | | -@workflow.defn(sandboxed=False) |
| 64 | +@workflow.defn |
64 | 65 | class EchoWorkflow: |
65 | 66 | @workflow.run |
66 | 67 | async def run(self, data: TraceData) -> TraceData: |
67 | 68 | return data |
68 | 69 |
|
69 | 70 |
|
70 | | -@workflow.defn(sandboxed=False) # we want to use isinstance |
| 71 | +@workflow.defn |
71 | 72 | class SerializationContextTestWorkflow: |
72 | 73 | @workflow.run |
73 | 74 | async def run(self, data: TraceData) -> TraceData: |
@@ -192,6 +193,7 @@ async def test_workflow_payload_conversion_can_be_given_access_to_serialization_ |
192 | 193 | task_queue=task_queue, |
193 | 194 | workflows=[SerializationContextTestWorkflow, EchoWorkflow], |
194 | 195 | activities=[passthrough_activity], |
| 196 | + workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance |
195 | 197 | ): |
196 | 198 | result = await client.execute_workflow( |
197 | 199 | SerializationContextTestWorkflow.run, |
@@ -310,56 +312,6 @@ async def test_workflow_payload_conversion_can_be_given_access_to_serialization_ |
310 | 312 | pprint(result.items) |
311 | 313 |
|
312 | 314 |
|
313 | | -def assert_trace(trace: list[TraceItem], expected: list[TraceItem]): |
314 | | - history = [] |
315 | | - for item, expected_item in zip_longest(trace, expected): |
316 | | - if item is None: |
317 | | - raise AssertionError("Fewer items in trace than expected") |
318 | | - if expected_item is None: |
319 | | - raise AssertionError("More items in trace than expected") |
320 | | - if item != expected_item: |
321 | | - raise AssertionError( |
322 | | - f"Item:\n{pformat(item)}\n\ndoes not match expected:\n\n {pformat(expected_item)}.\n\n History:\n{chr(10).join(history)}" |
323 | | - ) |
324 | | - history.append(f"{item.context_type} {item.method}") |
325 | | - |
326 | | - |
327 | | -def get_caller_location() -> list[str]: |
328 | | - """Get 3 stack frames starting from the first that's not in test_serialization_context.py or temporalio/converter.py.""" |
329 | | - frame = inspect.currentframe() |
330 | | - result = [] |
331 | | - found_first = False |
332 | | - |
333 | | - # Walk up the stack |
334 | | - while frame and len(result) < 3: |
335 | | - frame = frame.f_back |
336 | | - if not frame: |
337 | | - break |
338 | | - |
339 | | - file_path = frame.f_code.co_filename |
340 | | - |
341 | | - # Skip frames from test file and converter.py until we find the first one |
342 | | - if not found_first: |
343 | | - if "test_serialization_context.py" in file_path: |
344 | | - continue |
345 | | - if file_path.endswith("temporalio/converter.py"): |
346 | | - continue |
347 | | - found_first = True |
348 | | - |
349 | | - # Format and add this frame |
350 | | - line_number = frame.f_lineno |
351 | | - display_path = file_path |
352 | | - if "/sdk-python/" in display_path: |
353 | | - display_path = display_path.split("/sdk-python/")[-1] |
354 | | - result.append(f"{display_path}:{line_number}") |
355 | | - |
356 | | - # Pad with "unknown:0" if we didn't get 3 frames |
357 | | - while len(result) < 3: |
358 | | - result.append("unknown:0") |
359 | | - |
360 | | - return result |
361 | | - |
362 | | - |
363 | 315 | # Signal test |
364 | 316 |
|
365 | 317 |
|
@@ -642,3 +594,55 @@ async def test_query_payload_conversion_can_be_given_access_to_serialization_con |
642 | 594 |
|
643 | 595 | # Cancel the workflow to clean up |
644 | 596 | await handle.cancel() |
| 597 | + |
| 598 | + |
| 599 | +# Utilities |
| 600 | + |
| 601 | + |
| 602 | +def assert_trace(trace: list[TraceItem], expected: list[TraceItem]): |
| 603 | + if len(trace) != len(expected): |
| 604 | + raise AssertionError( |
| 605 | + f"expected {len(expected)} trace items but received {len(trace)}" |
| 606 | + ) |
| 607 | + history = [] |
| 608 | + for item, expected_item in zip_longest(trace, expected): |
| 609 | + if item is None: |
| 610 | + raise AssertionError("Fewer items in trace than expected") |
| 611 | + if expected_item is None: |
| 612 | + raise AssertionError("More items in trace than expected") |
| 613 | + if item != expected_item: |
| 614 | + raise AssertionError( |
| 615 | + f"Item:\n{pformat(item)}\n\ndoes not match expected:\n\n {pformat(expected_item)}.\n\n History:\n{chr(10).join(history)}" |
| 616 | + ) |
| 617 | + history.append(f"{item.context_type} {item.method}") |
| 618 | + |
| 619 | + |
| 620 | +def get_caller_location() -> list[str]: |
| 621 | + """Get 3 stack frames starting from the first that's not in test_serialization_context.py or temporalio/converter.py.""" |
| 622 | + frame = inspect.currentframe() |
| 623 | + result = [] |
| 624 | + found_first = False |
| 625 | + |
| 626 | + # Walk up the stack |
| 627 | + while frame and len(result) < 3: |
| 628 | + frame = frame.f_back |
| 629 | + if not frame: |
| 630 | + break |
| 631 | + |
| 632 | + file_path = frame.f_code.co_filename |
| 633 | + |
| 634 | + # Skip frames from test file and converter.py until we find the first one |
| 635 | + if not found_first: |
| 636 | + if "test_serialization_context.py" in file_path: |
| 637 | + continue |
| 638 | + if file_path.endswith("temporalio/converter.py"): |
| 639 | + continue |
| 640 | + found_first = True |
| 641 | + |
| 642 | + result.append(f"{file_path}:{frame.f_lineno}") |
| 643 | + |
| 644 | + # Pad with "unknown:0" if we didn't get 3 frames |
| 645 | + while len(result) < 3: |
| 646 | + result.append("unknown:0") |
| 647 | + |
| 648 | + return result |
0 commit comments