@@ -48,6 +48,8 @@ class TraceItem:
4848 "from_payload" ,
4949 "to_failure" ,
5050 "from_failure" ,
51+ "encode" ,
52+ "decode" ,
5153 ]
5254 context : dict [str , Any ]
5355 in_workflow : bool
@@ -846,10 +848,10 @@ async def test_external_workflow_signal_and_cancel_payload_conversion(
846848 DataConverter .default ,
847849 payload_converter_class = SerializationContextTestPayloadConverter ,
848850 )
849- custom_client = Client (** config )
851+ client = Client (** config )
850852
851853 async with Worker (
852- custom_client ,
854+ client ,
853855 task_queue = task_queue ,
854856 workflows = [
855857 ExternalWorkflowTarget ,
@@ -860,13 +862,13 @@ async def test_external_workflow_signal_and_cancel_payload_conversion(
860862 workflow_runner = UnsandboxedWorkflowRunner (), # so that we can use isinstance
861863 ):
862864 # Test external signal
863- target_handle = await custom_client .start_workflow (
865+ target_handle = await client .start_workflow (
864866 ExternalWorkflowTarget .run ,
865867 id = target_workflow_id ,
866868 task_queue = task_queue ,
867869 )
868870
869- signaler_handle = await custom_client .start_workflow (
871+ signaler_handle = await client .start_workflow (
870872 ExternalWorkflowSignaler .run ,
871873 args = [target_workflow_id , TraceData ()],
872874 id = signaler_workflow_id ,
@@ -953,7 +955,7 @@ async def run(self) -> Never:
953955 raise Exception ("Unreachable" )
954956
955957
956- failure_converter_test_trace : dict [str , list [TraceItem ]] = defaultdict (list )
958+ test_traces : dict [str , list [TraceItem ]] = defaultdict (list )
957959
958960
959961class FailureConverterWithContext (DefaultFailureConverter , WithSerializationContext ):
@@ -981,7 +983,7 @@ def to_failure(
981983 else :
982984 raise TypeError (f"self.context is { type (self .context )} " )
983985
984- failure_converter_test_trace [self .context .workflow_id ].append (
986+ test_traces [self .context .workflow_id ].append (
985987 TraceItem (
986988 context_type = context_type ,
987989 in_workflow = workflow .in_workflow (),
@@ -1002,7 +1004,7 @@ def from_failure(
10021004 else :
10031005 raise TypeError (f"self.context is { type (self .context )} " )
10041006
1005- failure_converter_test_trace [self .context .workflow_id ].append (
1007+ test_traces [self .context .workflow_id ].append (
10061008 TraceItem (
10071009 context_type = context_type ,
10081010 in_workflow = workflow .in_workflow (),
@@ -1022,20 +1024,19 @@ async def test_failure_converter_with_context(client: Client):
10221024 DataConverter .default ,
10231025 failure_converter_class = FailureConverterWithContext ,
10241026 )
1025- test_client = Client (
1026- client .service_client ,
1027- namespace = client .namespace ,
1028- data_converter = data_converter ,
1029- )
1027+ config = client .config ()
1028+ config ["data_converter" ] = data_converter
1029+ client = Client (** config )
1030+
10301031 async with Worker (
1031- test_client ,
1032+ client ,
10321033 task_queue = task_queue ,
10331034 workflows = [FailureConverterTestWorkflow ],
10341035 activities = [failing_activity ],
10351036 workflow_runner = UnsandboxedWorkflowRunner (),
10361037 ):
10371038 try :
1038- await test_client .execute_workflow (
1039+ await client .execute_workflow (
10391040 FailureConverterTestWorkflow .run ,
10401041 id = workflow_id ,
10411042 task_queue = task_queue ,
@@ -1063,7 +1064,7 @@ async def test_failure_converter_with_context(client: Client):
10631064 )
10641065 )
10651066 assert_trace (
1066- failure_converter_test_trace [workflow_id ],
1067+ test_traces [workflow_id ],
10671068 [
10681069 TraceItem (
10691070 context_type = "activity" ,
@@ -1103,7 +1104,7 @@ async def test_failure_converter_with_context(client: Client):
11031104 * 2 # from_failure deserializes the error and error cause
11041105 ),
11051106 )
1106- del failure_converter_test_trace [workflow_id ]
1107+ del test_traces [workflow_id ]
11071108
11081109
11091110class PayloadCodecWithContext (PayloadCodec , WithSerializationContext ):
@@ -1120,26 +1121,30 @@ def with_context(
11201121 return codec
11211122
11221123 async def encode (self , payloads : Sequence [Payload ]) -> List [Payload ]:
1123- result = []
1124- for p in payloads :
1125- new_p = Payload ()
1126- new_p .CopyFrom (p )
1127- if self .context :
1128- self .encode_called_with_context = True
1129- new_p .metadata ["has_context" ] = b"true"
1130- result .append (new_p )
1131- return result
1124+ assert self .context
1125+ assert isinstance (self .context , WorkflowSerializationContext )
1126+ test_traces [self .context .workflow_id ].append (
1127+ TraceItem (
1128+ context_type = "workflow" ,
1129+ context = dataclasses .asdict (self .context ),
1130+ method = "encode" ,
1131+ in_workflow = workflow .in_workflow (),
1132+ )
1133+ )
1134+ return list (payloads )
11321135
11331136 async def decode (self , payloads : Sequence [Payload ]) -> List [Payload ]:
1134- result = []
1135- for p in payloads :
1136- new_p = Payload ()
1137- new_p .CopyFrom (p )
1138- if self .context and new_p .metadata .get ("has_context" ) == b"true" :
1139- self .decode_called_with_context = True
1140- del new_p .metadata ["has_context" ]
1141- result .append (new_p )
1142- return result
1137+ assert self .context
1138+ assert isinstance (self .context , WorkflowSerializationContext )
1139+ test_traces [self .context .workflow_id ].append (
1140+ TraceItem (
1141+ context_type = "workflow" ,
1142+ context = dataclasses .asdict (self .context ),
1143+ method = "decode" ,
1144+ in_workflow = workflow .in_workflow (),
1145+ )
1146+ )
1147+ return list (payloads )
11431148
11441149
11451150@workflow .defn
@@ -1150,26 +1155,61 @@ async def run(self, data: str) -> str:
11501155
11511156
11521157async def test_codec_with_context (client : Client ):
1153- wf_id = str (uuid .uuid4 ())
1158+ workflow_id = str (uuid .uuid4 ())
11541159 task_queue = str (uuid .uuid4 ())
1155- test_client = Client (
1156- client .service_client ,
1157- namespace = client .namespace ,
1158- data_converter = dataclasses .replace (
1159- DataConverter .default , payload_codec = PayloadCodecWithContext ()
1160- ),
1160+
1161+ client_config = client .config ()
1162+ client_config ["data_converter" ] = dataclasses .replace (
1163+ DataConverter .default , payload_codec = PayloadCodecWithContext ()
11611164 )
1165+ client = Client (** client_config )
11621166 async with Worker (
1163- test_client ,
1167+ client ,
11641168 task_queue = task_queue ,
11651169 workflows = [CodecTestWorkflow ],
11661170 ):
1167- await test_client .execute_workflow (
1171+ await client .execute_workflow (
11681172 CodecTestWorkflow .run ,
11691173 "data" ,
1170- id = wf_id ,
1174+ id = workflow_id ,
11711175 task_queue = task_queue ,
11721176 )
1177+ workflow_context = dataclasses .asdict (
1178+ WorkflowSerializationContext (
1179+ namespace = client .namespace ,
1180+ workflow_id = workflow_id ,
1181+ )
1182+ )
1183+ assert_trace (
1184+ test_traces [workflow_id ],
1185+ [
1186+ TraceItem (
1187+ context_type = "workflow" ,
1188+ context = workflow_context ,
1189+ method = "encode" ,
1190+ in_workflow = False ,
1191+ ),
1192+ TraceItem (
1193+ context_type = "workflow" ,
1194+ context = workflow_context ,
1195+ method = "decode" ,
1196+ in_workflow = False ,
1197+ ),
1198+ TraceItem (
1199+ context_type = "workflow" ,
1200+ context = workflow_context ,
1201+ method = "encode" ,
1202+ in_workflow = False ,
1203+ ),
1204+ TraceItem (
1205+ context_type = "workflow" ,
1206+ context = workflow_context ,
1207+ method = "decode" ,
1208+ in_workflow = False ,
1209+ ),
1210+ ],
1211+ )
1212+ del test_traces [workflow_id ]
11731213
11741214
11751215# Pydantic
0 commit comments