11from __future__ import annotations
22
33import dataclasses
4+ import inspect
45import uuid
56from dataclasses import dataclass , field
67from typing import Any , Literal , Optional , Type
@@ -27,6 +28,18 @@ class TraceItem:
2728 context_type : Literal ["workflow" , "activity" ]
2829 method : Literal ["to_payload" , "from_payload" ]
2930 context : WorkflowSerializationContext | ActivitySerializationContext
31+ in_workflow : bool
32+ caller_location : list [str ] = field (default_factory = list )
33+
34+ def __eq__ (self , other : object ) -> bool :
35+ if not isinstance (other , TraceItem ):
36+ return False
37+ return (
38+ self .context_type == other .context_type
39+ and self .method == other .method
40+ and self .context == other .context
41+ and self .in_workflow == other .in_workflow
42+ )
3043
3144
3245@dataclass
@@ -35,12 +48,48 @@ class TraceData:
3548
3649
3750@workflow .defn (sandboxed = False ) # we want to use isinstance
38- class PassThroughWorkflow :
51+ class SerializationContextTestWorkflow :
3952 @workflow .run
4053 async def run (self , input : TraceData ) -> TraceData :
4154 return input
4255
4356
57+ def get_caller_location () -> list [str ]:
58+ """Get 3 stack frames starting from the first that's not in test_serialization_context.py or temporalio/converter.py."""
59+ frame = inspect .currentframe ()
60+ result = []
61+ found_first = False
62+
63+ # Walk up the stack
64+ while frame and len (result ) < 3 :
65+ frame = frame .f_back
66+ if not frame :
67+ break
68+
69+ file_path = frame .f_code .co_filename
70+
71+ # Skip frames from test file and converter.py until we find the first one
72+ if not found_first :
73+ if "test_serialization_context.py" in file_path :
74+ continue
75+ if file_path .endswith ("temporalio/converter.py" ):
76+ continue
77+ found_first = True
78+
79+ # Format and add this frame
80+ line_number = frame .f_lineno
81+ display_path = file_path
82+ if "/sdk-python/" in display_path :
83+ display_path = display_path .split ("/sdk-python/" )[- 1 ]
84+ result .append (f"{ display_path } :{ line_number } " )
85+
86+ # Pad with "unknown:0" if we didn't get 3 frames
87+ while len (result ) < 3 :
88+ result .append ("unknown:0" )
89+
90+ return result
91+
92+
4493class SerializationContextTestEncodingPayloadConverter (
4594 EncodingPayloadConverter , WithSerializationContext
4695):
@@ -69,7 +118,11 @@ def to_payload(self, value: Any) -> Optional[Payload]:
69118 assert isinstance (self .context , WorkflowSerializationContext )
70119 value .items .append (
71120 TraceItem (
72- context_type = "workflow" , method = "to_payload" , context = self .context
121+ context_type = "workflow" ,
122+ in_workflow = workflow .in_workflow (),
123+ method = "to_payload" ,
124+ context = self .context ,
125+ caller_location = get_caller_location (),
73126 )
74127 )
75128 payload = JSONPlainPayloadConverter ().to_payload (value )
@@ -86,7 +139,11 @@ def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> An
86139 assert isinstance (self .context , WorkflowSerializationContext )
87140 value .items .append (
88141 TraceItem (
89- context_type = "workflow" , method = "from_payload" , context = self .context
142+ context_type = "workflow" ,
143+ in_workflow = workflow .in_workflow (),
144+ method = "from_payload" ,
145+ context = self .context ,
146+ caller_location = get_caller_location (),
90147 )
91148 )
92149 return value
@@ -122,11 +179,11 @@ async def test_workflow_payload_conversion_can_be_given_access_to_serialization_
122179 async with Worker (
123180 client ,
124181 task_queue = task_queue ,
125- workflows = [PassThroughWorkflow ],
182+ workflows = [SerializationContextTestWorkflow ],
126183 activities = [],
127184 ):
128185 result = await client .execute_workflow (
129- PassThroughWorkflow .run ,
186+ SerializationContextTestWorkflow .run ,
130187 TraceData (),
131188 id = workflow_id ,
132189 task_queue = task_queue ,
@@ -136,5 +193,29 @@ async def test_workflow_payload_conversion_can_be_given_access_to_serialization_
136193 namespace = "default" ,
137194 workflow_id = workflow_id ,
138195 )
139- for item in result .items :
140- print (item )
196+ assert result .items == [
197+ TraceItem (
198+ context_type = "workflow" ,
199+ in_workflow = False ,
200+ method = "to_payload" ,
201+ context = workflow_context ,
202+ ),
203+ TraceItem (
204+ context_type = "workflow" ,
205+ in_workflow = False ,
206+ method = "from_payload" ,
207+ context = workflow_context ,
208+ ),
209+ TraceItem (
210+ context_type = "workflow" ,
211+ in_workflow = True ,
212+ method = "to_payload" ,
213+ context = workflow_context ,
214+ ),
215+ TraceItem (
216+ context_type = "workflow" ,
217+ in_workflow = False ,
218+ method = "from_payload" ,
219+ context = workflow_context ,
220+ ),
221+ ]
0 commit comments