from datetime import timedelta
from itertools import zip_longest
from pprint import pformat, pprint
- from typing import Any, List, Literal, Optional, Sequence, Type
+ from typing import Any, List, Literal, Never, Optional, Sequence, Type
from warnings import warn

import pytest
from temporalio import activity, workflow
from temporalio.api.common.v1 import Payload
from temporalio.api.failure.v1 import Failure
- from temporalio.client import Client, WorkflowUpdateFailedError
+ from temporalio.client import Client, WorkflowFailureError, WorkflowUpdateFailedError
from temporalio.common import RetryPolicy
from temporalio.contrib.pydantic import PydanticJSONPlainPayloadConverter
from temporalio.converter import (

@dataclass
class TraceItem:
    context_type: Literal["workflow", "activity"]
-     method: Literal["to_payload", "from_payload"]
+     method: Literal[
+         "to_payload",
+         "from_payload",
+         "to_failure",
+         "from_failure",
+     ]
    context: dict[str, Any]
    in_workflow: bool
    caller_location: list[str] = field(default_factory=list)
@@ -943,24 +948,23 @@ async def test_external_workflow_signal_and_cancel_payload_conversion(
    # The cancel context would only be used for failure deserialization


+ # Failure conversion
+
+
@activity.defn
- async def failing_activity() -> TraceData:
-     raise ApplicationError("test error", TraceData())
+ async def failing_activity() -> Never:
+     raise ApplicationError("test error", dataclasses.asdict(TraceData()))


@workflow.defn
class FailureContextWorkflow:
    @workflow.run
-     async def run(self) -> TraceData:
-         try:
-             await workflow.execute_activity(
-                 failing_activity,
-                 start_to_close_timeout=timedelta(seconds=10),
-                 retry_policy=RetryPolicy(maximum_attempts=1),
-             )
-         except ActivityError as e:
-             assert isinstance(e.cause, ApplicationError) and e.cause.details
-             return e.cause.details[0]
+     async def run(self) -> Never:
+         await workflow.execute_activity(
+             failing_activity,
+             start_to_close_timeout=timedelta(seconds=10),
+             retry_policy=RetryPolicy(maximum_attempts=1),
+         )
        raise Exception("Unreachable")

@@ -982,36 +986,80 @@ def to_failure(
        payload_converter: PayloadConverter,
        failure: Failure,
    ) -> None:
-         super().to_failure(exception, payload_converter, failure)
+         print("🌈 to_failure")
        if isinstance(exception, ApplicationError) and exception.details:
-             for detail in exception.details:
-                 if isinstance(detail, TraceData) and self.context:
-                     if isinstance(self.context, ActivitySerializationContext):
-                         detail.items.append(
-                             TraceItem(
-                                 context_type="activity",
-                                 in_workflow=False,
-                                 method="to_payload",
-                                 context=dataclasses.asdict(self.context),
-                             )
+             if isinstance(
+                 self.context,
+                 (WorkflowSerializationContext, ActivitySerializationContext),
+             ):
+                 context_type = (
+                     "workflow"
+                     if isinstance(self.context, WorkflowSerializationContext)
+                     else "activity"
+                 )
+                 print(
+                     f" 🌈 to_failure appending {context_type}: {exception.details}"
+                 )
+                 exception.details[0]["items"].append(
+                     dataclasses.asdict(
+                         TraceItem(
+                             context_type=context_type,
+                             in_workflow=workflow.in_workflow(),
+                             method="to_failure",
+                             context=dataclasses.asdict(self.context),
                        )
+                     )
+                 )
+             else:
+                 raise TypeError(f"self.context is {type(self.context)}")
+
+         super().to_failure(exception, payload_converter, failure)

    def from_failure(
        self, failure: Failure, payload_converter: PayloadConverter
    ) -> BaseException:
+         print("🌈 from_failure")
        # Let the base class create the exception
        exception = super().from_failure(failure, payload_converter)
-         # The context tracing is already in the payloads that will be decoded with context
+         print(f" 🌈 {exception.__class__}")
+         if isinstance(exception, ApplicationError) and exception.details:
+             if isinstance(
+                 self.context,
+                 (WorkflowSerializationContext, ActivitySerializationContext),
+             ):
+                 context_type = (
+                     "workflow"
+                     if isinstance(self.context, WorkflowSerializationContext)
+                     else "activity"
+                 )
+                 print(
+                     f" 🌈 from_failure appending {context_type}: {exception.details}"
+                 )
+                 exception.details[0]["items"].append(
+                     dataclasses.asdict(
+                         TraceItem(
+                             context_type=context_type,
+                             in_workflow=workflow.in_workflow(),
+                             method="from_failure",
+                             context=dataclasses.asdict(self.context),
+                         )
+                     )
+                 )
+             else:
+                 raise TypeError(f"self.context is {type(self.context)}")
        return exception


async def test_failure_conversion_with_context(client: Client):
+     print()
+     workflow_id = str(uuid.uuid4())
    task_queue = str(uuid.uuid4())
+
    test_client = Client(
        client.service_client,
        namespace=client.namespace,
-         data_converter=DataConverter(
-             payload_converter_class=SerializationContextTestPayloadConverter,
+         data_converter=dataclasses.replace(
+             DataConverter.default,
            failure_converter_class=ContextFailureConverter,
        ),
    )
@@ -1022,12 +1070,73 @@ async def test_failure_conversion_with_context(client: Client):
        activities=[failing_activity],
        workflow_runner=UnsandboxedWorkflowRunner(),
    ):
-         result = await test_client.execute_workflow(
-             FailureContextWorkflow.run,
-             id=str(uuid.uuid4()),
-             task_queue=task_queue,
-         )
-         pprint(result.items)
+         try:
+             await test_client.execute_workflow(
+                 FailureContextWorkflow.run,
+                 id=workflow_id,
+                 task_queue=task_queue,
+             )
+         except Exception as err:
+             assert isinstance(err, WorkflowFailureError)
+             assert isinstance(err.cause, ActivityError)
+             assert isinstance(err.cause.cause, ApplicationError)
+             pprint(err.cause.cause.details)
+
+             workflow_context = dataclasses.asdict(
+                 WorkflowSerializationContext(
+                     namespace="default",
+                     workflow_id=workflow_id,
+                 )
+             )
+             activity_context = dataclasses.asdict(
+                 ActivitySerializationContext(
+                     namespace="default",
+                     workflow_id=workflow_id,
+                     workflow_type="FailureContextWorkflow",
+                     activity_type="failing_activity",
+                     activity_task_queue=task_queue,
+                     is_local=False,
+                 )
+             )
+             # 1. Exception raised in activity
+             # 2. outbound activity result to_failure(act, activity_ctx) appends and serializes
+             # 3. -> server -> WFT -> WF
+             # 4. inbound activity result from_failure(wf, activity_ctx, in_wf=False) deserializes and appends
+             # 5. outbound wf result to_failure(wf, in_wf=True) appends and serializes
+             # 6. inbound wf result from_failure(client, wf_context, in_wf=False)
+             if False:
+                 assert_trace(
+                     err.cause.cause.details,
+                     [
+                         dataclasses.asdict(d)
+                         for d in [
+                             TraceItem(
+                                 context_type="activity",
+                                 context=activity_context,
+                                 in_workflow=False,
+                                 method="to_failure",  # outbound activity result
+                             ),
+                             TraceItem(
+                                 context_type="activity",
+                                 context=activity_context,
+                                 in_workflow=False,
+                                 method="from_failure",  # inbound activity result
+                             ),
+                             TraceItem(
+                                 context_type="workflow",
+                                 context=workflow_context,
+                                 in_workflow=True,
+                                 method="to_failure",  # outbound workflow result
+                             ),
+                             TraceItem(
+                                 context_type="workflow",
+                                 context=workflow_context,
+                                 in_workflow=False,
+                                 method="from_failure",  # inbound workflow result
+                             ),
+                         ]
+                     ],
+                 )


class ContextCodec(PayloadCodec, WithSerializationContext):
@@ -1193,12 +1302,16 @@ def assert_trace(trace: list[TraceItem], expected: list[TraceItem]):
    history: list[str] = []
    for item, expected_item in zip_longest(trace, expected):
        if item is None:
-             raise AssertionError("Fewer items in trace than expected")
+             raise AssertionError(
+                 f"Fewer items in trace than expected.\n\nHistory:\n{'\n'.join(history)}"
+             )
        if expected_item is None:
-             raise AssertionError("More items in trace than expected")
+             raise AssertionError(
+                 f"More items in trace than expected.\n\nHistory:\n{'\n'.join(history)}"
+             )
        if item != expected_item:
            raise AssertionError(
-                 f"Item:\n{pformat(item)}\n\ndoes not match expected:\n\n{pformat(expected_item)}.\n\nHistory:\n{chr(10).join(history)}"
+                 f"Item:\n{pformat(item)}\n\ndoes not match expected:\n\n{pformat(expected_item)}.\n\nHistory:\n{'\n'.join(history)}"
            )
        history.append(f"{item.context_type} {item.method}")
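For reference, the pattern this diff exercises can be reduced to a short sketch outside the test: subclass the SDK's failure converter, append extra information to ApplicationError details on the way out (to_failure) and on the way in (from_failure), and plug it into a DataConverter via failure_converter_class. This is a minimal illustration, not part of the commit; it assumes the same temporalio APIs used above (DefaultFailureConverter, DataConverter.default), and the {"items": [...]} detail shape mirrors the test's TraceData rather than any library requirement.

import dataclasses

from temporalio.api.failure.v1 import Failure
from temporalio.converter import (
    DataConverter,
    DefaultFailureConverter,
    PayloadConverter,
)
from temporalio.exceptions import ApplicationError


class NotingFailureConverter(DefaultFailureConverter):
    """Sketch: record each failure-conversion call in the error's details."""

    def to_failure(
        self,
        exception: BaseException,
        payload_converter: PayloadConverter,
        failure: Failure,
    ) -> None:
        # Assumes details[0] is a dict with an "items" list, as in the test above.
        if isinstance(exception, ApplicationError) and exception.details:
            exception.details[0]["items"].append({"method": "to_failure"})
        super().to_failure(exception, payload_converter, failure)

    def from_failure(
        self, failure: Failure, payload_converter: PayloadConverter
    ) -> BaseException:
        exception = super().from_failure(failure, payload_converter)
        if isinstance(exception, ApplicationError) and exception.details:
            exception.details[0]["items"].append({"method": "from_failure"})
        return exception


# Wire it into a client-side data converter, the same way the test does:
data_converter = dataclasses.replace(
    DataConverter.default, failure_converter_class=NotingFailureConverter
)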