Skip to content

Commit 224d590

Browse files
committed
Test passes
1 parent e3a0a3b commit 224d590

File tree

1 file changed

+88
-7
lines changed

1 file changed

+88
-7
lines changed

tests/test_serialization_context.py

Lines changed: 88 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import dataclasses
4+
import inspect
45
import uuid
56
from dataclasses import dataclass, field
67
from typing import Any, Literal, Optional, Type
@@ -27,6 +28,18 @@ class TraceItem:
2728
context_type: Literal["workflow", "activity"]
2829
method: Literal["to_payload", "from_payload"]
2930
context: WorkflowSerializationContext | ActivitySerializationContext
31+
in_workflow: bool
32+
caller_location: list[str] = field(default_factory=list)
33+
34+
def __eq__(self, other: object) -> bool:
35+
if not isinstance(other, TraceItem):
36+
return False
37+
return (
38+
self.context_type == other.context_type
39+
and self.method == other.method
40+
and self.context == other.context
41+
and self.in_workflow == other.in_workflow
42+
)
3043

3144

3245
@dataclass
@@ -35,12 +48,48 @@ class TraceData:
3548

3649

3750
@workflow.defn(sandboxed=False) # we want to use isinstance
38-
class PassThroughWorkflow:
51+
class SerializationContextTestWorkflow:
3952
@workflow.run
4053
async def run(self, input: TraceData) -> TraceData:
4154
return input
4255

4356

57+
def get_caller_location() -> list[str]:
58+
"""Get 3 stack frames starting from the first that's not in test_serialization_context.py or temporalio/converter.py."""
59+
frame = inspect.currentframe()
60+
result = []
61+
found_first = False
62+
63+
# Walk up the stack
64+
while frame and len(result) < 3:
65+
frame = frame.f_back
66+
if not frame:
67+
break
68+
69+
file_path = frame.f_code.co_filename
70+
71+
# Skip frames from test file and converter.py until we find the first one
72+
if not found_first:
73+
if "test_serialization_context.py" in file_path:
74+
continue
75+
if file_path.endswith("temporalio/converter.py"):
76+
continue
77+
found_first = True
78+
79+
# Format and add this frame
80+
line_number = frame.f_lineno
81+
display_path = file_path
82+
if "/sdk-python/" in display_path:
83+
display_path = display_path.split("/sdk-python/")[-1]
84+
result.append(f"{display_path}:{line_number}")
85+
86+
# Pad with "unknown:0" if we didn't get 3 frames
87+
while len(result) < 3:
88+
result.append("unknown:0")
89+
90+
return result
91+
92+
4493
class SerializationContextTestEncodingPayloadConverter(
4594
EncodingPayloadConverter, WithSerializationContext
4695
):
@@ -69,7 +118,11 @@ def to_payload(self, value: Any) -> Optional[Payload]:
69118
assert isinstance(self.context, WorkflowSerializationContext)
70119
value.items.append(
71120
TraceItem(
72-
context_type="workflow", method="to_payload", context=self.context
121+
context_type="workflow",
122+
in_workflow=workflow.in_workflow(),
123+
method="to_payload",
124+
context=self.context,
125+
caller_location=get_caller_location(),
73126
)
74127
)
75128
payload = JSONPlainPayloadConverter().to_payload(value)
@@ -86,7 +139,11 @@ def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> An
86139
assert isinstance(self.context, WorkflowSerializationContext)
87140
value.items.append(
88141
TraceItem(
89-
context_type="workflow", method="from_payload", context=self.context
142+
context_type="workflow",
143+
in_workflow=workflow.in_workflow(),
144+
method="from_payload",
145+
context=self.context,
146+
caller_location=get_caller_location(),
90147
)
91148
)
92149
return value
@@ -122,11 +179,11 @@ async def test_workflow_payload_conversion_can_be_given_access_to_serialization_
122179
async with Worker(
123180
client,
124181
task_queue=task_queue,
125-
workflows=[PassThroughWorkflow],
182+
workflows=[SerializationContextTestWorkflow],
126183
activities=[],
127184
):
128185
result = await client.execute_workflow(
129-
PassThroughWorkflow.run,
186+
SerializationContextTestWorkflow.run,
130187
TraceData(),
131188
id=workflow_id,
132189
task_queue=task_queue,
@@ -136,5 +193,29 @@ async def test_workflow_payload_conversion_can_be_given_access_to_serialization_
136193
namespace="default",
137194
workflow_id=workflow_id,
138195
)
139-
for item in result.items:
140-
print(item)
196+
assert result.items == [
197+
TraceItem(
198+
context_type="workflow",
199+
in_workflow=False,
200+
method="to_payload",
201+
context=workflow_context,
202+
),
203+
TraceItem(
204+
context_type="workflow",
205+
in_workflow=False,
206+
method="from_payload",
207+
context=workflow_context,
208+
),
209+
TraceItem(
210+
context_type="workflow",
211+
in_workflow=True,
212+
method="to_payload",
213+
context=workflow_context,
214+
),
215+
TraceItem(
216+
context_type="workflow",
217+
in_workflow=False,
218+
method="from_payload",
219+
context=workflow_context,
220+
),
221+
]

0 commit comments

Comments
 (0)