Skip to content

Commit 381a9fd

Browse files
committed
Refactor test
1 parent 74134bd commit 381a9fd

File tree

1 file changed

+56
-52
lines changed

1 file changed

+56
-52
lines changed

tests/test_serialization_context.py

Lines changed: 56 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
WorkflowSerializationContext,
2727
)
2828
from temporalio.worker import Worker
29+
from temporalio.worker._workflow_instance import UnsandboxedWorkflowRunner
2930

3031

3132
@dataclass
@@ -60,14 +61,14 @@ async def passthrough_activity(input: TraceData) -> TraceData:
6061
return input
6162

6263

63-
@workflow.defn(sandboxed=False)
64+
@workflow.defn
6465
class EchoWorkflow:
6566
@workflow.run
6667
async def run(self, data: TraceData) -> TraceData:
6768
return data
6869

6970

70-
@workflow.defn(sandboxed=False) # we want to use isinstance
71+
@workflow.defn
7172
class SerializationContextTestWorkflow:
7273
@workflow.run
7374
async def run(self, data: TraceData) -> TraceData:
@@ -192,6 +193,7 @@ async def test_workflow_payload_conversion_can_be_given_access_to_serialization_
192193
task_queue=task_queue,
193194
workflows=[SerializationContextTestWorkflow, EchoWorkflow],
194195
activities=[passthrough_activity],
196+
workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance
195197
):
196198
result = await client.execute_workflow(
197199
SerializationContextTestWorkflow.run,
@@ -310,56 +312,6 @@ async def test_workflow_payload_conversion_can_be_given_access_to_serialization_
310312
pprint(result.items)
311313

312314

313-
def assert_trace(trace: list[TraceItem], expected: list[TraceItem]):
314-
history = []
315-
for item, expected_item in zip_longest(trace, expected):
316-
if item is None:
317-
raise AssertionError("Fewer items in trace than expected")
318-
if expected_item is None:
319-
raise AssertionError("More items in trace than expected")
320-
if item != expected_item:
321-
raise AssertionError(
322-
f"Item:\n{pformat(item)}\n\ndoes not match expected:\n\n {pformat(expected_item)}.\n\n History:\n{chr(10).join(history)}"
323-
)
324-
history.append(f"{item.context_type} {item.method}")
325-
326-
327-
def get_caller_location() -> list[str]:
328-
"""Get 3 stack frames starting from the first that's not in test_serialization_context.py or temporalio/converter.py."""
329-
frame = inspect.currentframe()
330-
result = []
331-
found_first = False
332-
333-
# Walk up the stack
334-
while frame and len(result) < 3:
335-
frame = frame.f_back
336-
if not frame:
337-
break
338-
339-
file_path = frame.f_code.co_filename
340-
341-
# Skip frames from test file and converter.py until we find the first one
342-
if not found_first:
343-
if "test_serialization_context.py" in file_path:
344-
continue
345-
if file_path.endswith("temporalio/converter.py"):
346-
continue
347-
found_first = True
348-
349-
# Format and add this frame
350-
line_number = frame.f_lineno
351-
display_path = file_path
352-
if "/sdk-python/" in display_path:
353-
display_path = display_path.split("/sdk-python/")[-1]
354-
result.append(f"{display_path}:{line_number}")
355-
356-
# Pad with "unknown:0" if we didn't get 3 frames
357-
while len(result) < 3:
358-
result.append("unknown:0")
359-
360-
return result
361-
362-
363315
# Signal test
364316

365317

@@ -642,3 +594,55 @@ async def test_query_payload_conversion_can_be_given_access_to_serialization_con
642594

643595
# Cancel the workflow to clean up
644596
await handle.cancel()
597+
598+
599+
# Utilities
600+
601+
602+
def assert_trace(trace: list[TraceItem], expected: list[TraceItem]):
603+
if len(trace) != len(expected):
604+
raise AssertionError(
605+
f"expected {len(expected)} trace items but received {len(trace)}"
606+
)
607+
history = []
608+
for item, expected_item in zip_longest(trace, expected):
609+
if item is None:
610+
raise AssertionError("Fewer items in trace than expected")
611+
if expected_item is None:
612+
raise AssertionError("More items in trace than expected")
613+
if item != expected_item:
614+
raise AssertionError(
615+
f"Item:\n{pformat(item)}\n\ndoes not match expected:\n\n {pformat(expected_item)}.\n\n History:\n{chr(10).join(history)}"
616+
)
617+
history.append(f"{item.context_type} {item.method}")
618+
619+
620+
def get_caller_location() -> list[str]:
621+
"""Get 3 stack frames starting from the first that's not in test_serialization_context.py or temporalio/converter.py."""
622+
frame = inspect.currentframe()
623+
result = []
624+
found_first = False
625+
626+
# Walk up the stack
627+
while frame and len(result) < 3:
628+
frame = frame.f_back
629+
if not frame:
630+
break
631+
632+
file_path = frame.f_code.co_filename
633+
634+
# Skip frames from test file and converter.py until we find the first one
635+
if not found_first:
636+
if "test_serialization_context.py" in file_path:
637+
continue
638+
if file_path.endswith("temporalio/converter.py"):
639+
continue
640+
found_first = True
641+
642+
result.append(f"{file_path}:{frame.f_lineno}")
643+
644+
# Pad with "unknown:0" if we didn't get 3 frames
645+
while len(result) < 3:
646+
result.append("unknown:0")
647+
648+
return result

0 commit comments

Comments
 (0)