Skip to content

Commit 428fe03

Browse files
committed
WIP: failure test
1 parent 1f40869 commit 428fe03

File tree

1 file changed

+151
-38
lines changed

1 file changed

+151
-38
lines changed

tests/test_serialization_context.py

Lines changed: 151 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from datetime import timedelta
99
from itertools import zip_longest
1010
from pprint import pformat, pprint
11-
from typing import Any, List, Literal, Optional, Sequence, Type
11+
from typing import Any, List, Literal, Never, Optional, Sequence, Type
1212
from warnings import warn
1313

1414
import pytest
@@ -17,7 +17,7 @@
1717
from temporalio import activity, workflow
1818
from temporalio.api.common.v1 import Payload
1919
from temporalio.api.failure.v1 import Failure
20-
from temporalio.client import Client, WorkflowUpdateFailedError
20+
from temporalio.client import Client, WorkflowFailureError, WorkflowUpdateFailedError
2121
from temporalio.common import RetryPolicy
2222
from temporalio.contrib.pydantic import PydanticJSONPlainPayloadConverter
2323
from temporalio.converter import (
@@ -42,7 +42,12 @@
4242
@dataclass
4343
class TraceItem:
4444
context_type: Literal["workflow", "activity"]
45-
method: Literal["to_payload", "from_payload"]
45+
method: Literal[
46+
"to_payload",
47+
"from_payload",
48+
"to_failure",
49+
"from_failure",
50+
]
4651
context: dict[str, Any]
4752
in_workflow: bool
4853
caller_location: list[str] = field(default_factory=list)
@@ -943,24 +948,23 @@ async def test_external_workflow_signal_and_cancel_payload_conversion(
943948
# The cancel context would only be used for failure deserialization
944949

945950

951+
# Failure conversion
952+
953+
946954
@activity.defn
947-
async def failing_activity() -> TraceData:
948-
raise ApplicationError("test error", TraceData())
955+
async def failing_activity() -> Never:
956+
raise ApplicationError("test error", dataclasses.asdict(TraceData()))
949957

950958

951959
@workflow.defn
952960
class FailureContextWorkflow:
953961
@workflow.run
954-
async def run(self) -> TraceData:
955-
try:
956-
await workflow.execute_activity(
957-
failing_activity,
958-
start_to_close_timeout=timedelta(seconds=10),
959-
retry_policy=RetryPolicy(maximum_attempts=1),
960-
)
961-
except ActivityError as e:
962-
assert isinstance(e.cause, ApplicationError) and e.cause.details
963-
return e.cause.details[0]
962+
async def run(self) -> Never:
963+
await workflow.execute_activity(
964+
failing_activity,
965+
start_to_close_timeout=timedelta(seconds=10),
966+
retry_policy=RetryPolicy(maximum_attempts=1),
967+
)
964968
raise Exception("Unreachable")
965969

966970

@@ -982,36 +986,80 @@ def to_failure(
982986
payload_converter: PayloadConverter,
983987
failure: Failure,
984988
) -> None:
985-
super().to_failure(exception, payload_converter, failure)
989+
print("🌈 to_failure")
986990
if isinstance(exception, ApplicationError) and exception.details:
987-
for detail in exception.details:
988-
if isinstance(detail, TraceData) and self.context:
989-
if isinstance(self.context, ActivitySerializationContext):
990-
detail.items.append(
991-
TraceItem(
992-
context_type="activity",
993-
in_workflow=False,
994-
method="to_payload",
995-
context=dataclasses.asdict(self.context),
996-
)
991+
if isinstance(
992+
self.context,
993+
(WorkflowSerializationContext, ActivitySerializationContext),
994+
):
995+
context_type = (
996+
"workflow"
997+
if isinstance(self.context, WorkflowSerializationContext)
998+
else "activity"
999+
)
1000+
print(
1001+
f" 🌈 to_failure appending {context_type}: {exception.details}"
1002+
)
1003+
exception.details[0]["items"].append(
1004+
dataclasses.asdict(
1005+
TraceItem(
1006+
context_type=context_type,
1007+
in_workflow=workflow.in_workflow(),
1008+
method="to_failure",
1009+
context=dataclasses.asdict(self.context),
9971010
)
1011+
)
1012+
)
1013+
else:
1014+
raise TypeError(f"self.context is {type(self.context)}")
1015+
1016+
super().to_failure(exception, payload_converter, failure)
9981017

9991018
def from_failure(
10001019
self, failure: Failure, payload_converter: PayloadConverter
10011020
) -> BaseException:
1021+
print("🌈 from_failure")
10021022
# Let the base class create the exception
10031023
exception = super().from_failure(failure, payload_converter)
1004-
# The context tracing is already in the payloads that will be decoded with context
1024+
print(f" 🌈 {exception.__class__}")
1025+
if isinstance(exception, ApplicationError) and exception.details:
1026+
if isinstance(
1027+
self.context,
1028+
(WorkflowSerializationContext, ActivitySerializationContext),
1029+
):
1030+
context_type = (
1031+
"workflow"
1032+
if isinstance(self.context, WorkflowSerializationContext)
1033+
else "activity"
1034+
)
1035+
print(
1036+
f" 🌈 from_failure appending {context_type}: {exception.details}"
1037+
)
1038+
exception.details[0]["items"].append(
1039+
dataclasses.asdict(
1040+
TraceItem(
1041+
context_type=context_type,
1042+
in_workflow=workflow.in_workflow(),
1043+
method="from_failure",
1044+
context=dataclasses.asdict(self.context),
1045+
)
1046+
)
1047+
)
1048+
else:
1049+
raise TypeError(f"self.context is {type(self.context)}")
10051050
return exception
10061051

10071052

10081053
async def test_failure_conversion_with_context(client: Client):
1054+
print()
1055+
workflow_id = str(uuid.uuid4())
10091056
task_queue = str(uuid.uuid4())
1057+
10101058
test_client = Client(
10111059
client.service_client,
10121060
namespace=client.namespace,
1013-
data_converter=DataConverter(
1014-
payload_converter_class=SerializationContextTestPayloadConverter,
1061+
data_converter=dataclasses.replace(
1062+
DataConverter.default,
10151063
failure_converter_class=ContextFailureConverter,
10161064
),
10171065
)
@@ -1022,12 +1070,73 @@ async def test_failure_conversion_with_context(client: Client):
10221070
activities=[failing_activity],
10231071
workflow_runner=UnsandboxedWorkflowRunner(),
10241072
):
1025-
result = await test_client.execute_workflow(
1026-
FailureContextWorkflow.run,
1027-
id=str(uuid.uuid4()),
1028-
task_queue=task_queue,
1029-
)
1030-
pprint(result.items)
1073+
try:
1074+
await test_client.execute_workflow(
1075+
FailureContextWorkflow.run,
1076+
id=workflow_id,
1077+
task_queue=task_queue,
1078+
)
1079+
except Exception as err:
1080+
assert isinstance(err, WorkflowFailureError)
1081+
assert isinstance(err.cause, ActivityError)
1082+
assert isinstance(err.cause.cause, ApplicationError)
1083+
pprint(err.cause.cause.details)
1084+
1085+
workflow_context = dataclasses.asdict(
1086+
WorkflowSerializationContext(
1087+
namespace="default",
1088+
workflow_id=workflow_id,
1089+
)
1090+
)
1091+
activity_context = dataclasses.asdict(
1092+
ActivitySerializationContext(
1093+
namespace="default",
1094+
workflow_id=workflow_id,
1095+
workflow_type="FailureContextWorkflow",
1096+
activity_type="failing_activity",
1097+
activity_task_queue=task_queue,
1098+
is_local=False,
1099+
)
1100+
)
1101+
# 1. Exception raised in activity
1102+
# 2. outbound activity result to_failure(act, activity_ctx) appends and serializes
1103+
# 3. -> server -> WFT -> WF
1104+
# 4. inbound activity result from_failure(wf, activity_ctx, in_wf=False) deserializes and appends
1105+
# 5. outbound wf result to_failure(wf, in_wf=True) appends and serializes
1106+
# 6. inbound wf result from_failure(client, wf_context, in_wf=False)
1107+
if False:
1108+
assert_trace(
1109+
err.cause.cause.details,
1110+
[
1111+
dataclasses.asdict(d)
1112+
for d in [
1113+
TraceItem(
1114+
context_type="activity",
1115+
context=activity_context,
1116+
in_workflow=False,
1117+
method="to_failure", # outbound activity result
1118+
),
1119+
TraceItem(
1120+
context_type="activity",
1121+
context=activity_context,
1122+
in_workflow=False,
1123+
method="from_failure", # inbound activity result
1124+
),
1125+
TraceItem(
1126+
context_type="workflow",
1127+
context=workflow_context,
1128+
in_workflow=True,
1129+
method="to_failure", # outbound workflow result
1130+
),
1131+
TraceItem(
1132+
context_type="workflow",
1133+
context=workflow_context,
1134+
in_workflow=False,
1135+
method="from_failure", # inbound workflow result
1136+
),
1137+
]
1138+
],
1139+
)
10311140

10321141

10331142
class ContextCodec(PayloadCodec, WithSerializationContext):
@@ -1193,12 +1302,16 @@ def assert_trace(trace: list[TraceItem], expected: list[TraceItem]):
11931302
history: list[str] = []
11941303
for item, expected_item in zip_longest(trace, expected):
11951304
if item is None:
1196-
raise AssertionError("Fewer items in trace than expected")
1305+
raise AssertionError(
1306+
f"Fewer items in trace than expected.\n\n History:\n{'\n'.join(history)}"
1307+
)
11971308
if expected_item is None:
1198-
raise AssertionError("More items in trace than expected")
1309+
raise AssertionError(
1310+
f"More items in trace than expected.\n\n History:\n{'\n'.join(history)}"
1311+
)
11991312
if item != expected_item:
12001313
raise AssertionError(
1201-
f"Item:\n{pformat(item)}\n\ndoes not match expected:\n\n {pformat(expected_item)}.\n\n History:\n{chr(10).join(history)}"
1314+
f"Item:\n{pformat(item)}\n\ndoes not match expected:\n\n {pformat(expected_item)}.\n\n History:\n{'\n'.join(history)}"
12021315
)
12031316
history.append(f"{item.context_type} {item.method}")
12041317

0 commit comments

Comments
 (0)