Commit 8228769

tests: Add minimal test coverage for failure conversion, codec, and pydantic with context
- Add test for failure conversion using context in activity failures
- Add test for PayloadCodec with serialization context support
- Add test for Pydantic data converter with serialization context
- Use existing test harness and keep tests minimal for clarity
- Fix activity failure retry policy to prevent infinite retries in tests
1 parent 5bcd69a commit 8228769
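For context, here is a minimal sketch (not part of this commit) of the hook all three tests exercise: a payload converter, failure converter, or codec opts into serialization context by also implementing WithSerializationContext and returning a copy of itself bound to the supplied WorkflowSerializationContext or ActivitySerializationContext. The class name ContextTrackingJSONConverter below is illustrative only.

from typing import Optional

from temporalio.converter import (
    JSONPlainPayloadConverter,
    SerializationContext,
    WithSerializationContext,
)


# Illustrative only: a JSON payload converter that records the
# serialization context the SDK hands it.
class ContextTrackingJSONConverter(JSONPlainPayloadConverter, WithSerializationContext):
    def __init__(self) -> None:
        super().__init__()
        self.context: Optional[SerializationContext] = None

    def with_context(
        self, context: Optional[SerializationContext]
    ) -> "ContextTrackingJSONConverter":
        # Return a copy bound to the context instead of mutating self,
        # mirroring the pattern used by the converters added in this diff.
        copied = ContextTrackingJSONConverter()
        copied.context = context
        return copied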

File tree: 1 file changed, +243 −1 lines changed


tests/test_serialization_context.py

Lines changed: 243 additions & 1 deletion
@@ -8,26 +8,33 @@
 from datetime import timedelta
 from itertools import zip_longest
 from pprint import pformat, pprint
-from typing import Any, Literal, Optional, Type
+from typing import Any, List, Literal, Optional, Sequence, Type
 from warnings import warn

 import pytest
+from pydantic import BaseModel

 from temporalio import activity, workflow
 from temporalio.api.common.v1 import Payload
+from temporalio.api.failure.v1 import Failure
 from temporalio.client import Client, WorkflowUpdateFailedError
 from temporalio.common import RetryPolicy
 from temporalio.converter import (
     ActivitySerializationContext,
     CompositePayloadConverter,
     DataConverter,
     DefaultPayloadConverter,
+    DefaultFailureConverter,
     EncodingPayloadConverter,
     JSONPlainPayloadConverter,
+    PayloadCodec,
+    PayloadConverter,
     SerializationContext,
     WithSerializationContext,
     WorkflowSerializationContext,
 )
+from temporalio.contrib.pydantic import PydanticJSONPlainPayloadConverter
+from temporalio.exceptions import ApplicationError, ActivityError
 from temporalio.worker import Worker
 from temporalio.worker._workflow_instance import UnsandboxedWorkflowRunner

@@ -968,3 +975,238 @@ def get_caller_location() -> list[str]:
         result.append("unknown:0")

     return result
+
+
+
+
+@activity.defn
+async def failing_activity() -> TraceData:
+    raise ApplicationError("test error", TraceData())
+
+
+@workflow.defn
+class FailureContextWorkflow:
+    @workflow.run
+    async def run(self) -> TraceData:
+        try:
+            await workflow.execute_activity(
+                failing_activity,
+                start_to_close_timeout=timedelta(seconds=10),
+                retry_policy=RetryPolicy(maximum_attempts=1),
+            )
+        except ActivityError as e:
+            if isinstance(e.cause, ApplicationError) and e.cause.details:
+                return e.cause.details[0]
+        return TraceData()
+
+
+class ContextFailureConverter(DefaultFailureConverter, WithSerializationContext):
+    def __init__(self):
+        super().__init__(encode_common_attributes=False)
+        self.context: Optional[SerializationContext] = None
+
+    def with_context(self, context: Optional[SerializationContext]) -> "ContextFailureConverter":
+        converter = ContextFailureConverter()
+        converter.context = context
+        return converter
+
+    def to_failure(
+        self, exception: BaseException, payload_converter: PayloadConverter, failure: Failure
+    ) -> None:
+        super().to_failure(exception, payload_converter, failure)
+        if isinstance(exception, ApplicationError) and exception.details:
+            for detail in exception.details:
+                if isinstance(detail, TraceData) and self.context:
+                    if isinstance(self.context, ActivitySerializationContext):
+                        detail.items.append(
+                            TraceItem(
+                                context_type="activity",
+                                in_workflow=False,
+                                method="to_payload",
+                                context=dataclasses.asdict(self.context),
+                            )
+                        )
+
+    def from_failure(
+        self, failure: Failure, payload_converter: PayloadConverter
+    ) -> BaseException:
+        # Let the base class create the exception
+        exception = super().from_failure(failure, payload_converter)
+        # The context tracing is already in the payloads that will be decoded with context
+        return exception
+
+
+async def test_failure_conversion_with_context(client: Client):
+    task_queue = str(uuid.uuid4())
+    test_client = Client(
+        client.service_client,
+        namespace=client.namespace,
+        data_converter=DataConverter(
+            payload_converter_class=SerializationContextTestPayloadConverter,
+            failure_converter_class=ContextFailureConverter,
+        ),
+    )
+    async with Worker(
+        test_client,
+        task_queue=task_queue,
+        workflows=[FailureContextWorkflow],
+        activities=[failing_activity],
+        workflow_runner=UnsandboxedWorkflowRunner(),
+    ):
+        result = await test_client.execute_workflow(
+            FailureContextWorkflow.run,
+            id=str(uuid.uuid4()),
+            task_queue=task_queue,
+        )
+        assert any(
+            item.context_type == "activity" and item.method == "to_payload"
+            for item in result.items
+        )
+
+
+class ContextCodec(PayloadCodec, WithSerializationContext):
+    def __init__(self):
+        self.context: Optional[SerializationContext] = None
+        self.encode_called_with_context = False
+        self.decode_called_with_context = False
+
+    def with_context(self, context: Optional[SerializationContext]) -> "ContextCodec":
+        codec = ContextCodec()
+        codec.context = context
+        return codec
+
+    async def encode(self, payloads: Sequence[Payload]) -> List[Payload]:
+        result = []
+        for p in payloads:
+            new_p = Payload()
+            new_p.CopyFrom(p)
+            if self.context:
+                self.encode_called_with_context = True
+                # Just add a marker that we encoded with context
+                new_p.metadata["has_context"] = b"true"
+            result.append(new_p)
+        return result
+
+    async def decode(self, payloads: Sequence[Payload]) -> List[Payload]:
+        result = []
+        for p in payloads:
+            new_p = Payload()
+            new_p.CopyFrom(p)
+            if self.context and new_p.metadata.get("has_context") == b"true":
+                self.decode_called_with_context = True
+                # Remove the marker
+                del new_p.metadata["has_context"]
+            result.append(new_p)
+        return result
+
+
+@workflow.defn
+class CodecTestWorkflow:
+    @workflow.run
+    async def run(self, data: str) -> str:
+        return data + "_processed"
+
+
+async def test_codec_with_context(client: Client):
+    wf_id = str(uuid.uuid4())
+    task_queue = str(uuid.uuid4())
+    test_client = Client(
+        client.service_client,
+        namespace=client.namespace,
+        data_converter=DataConverter(payload_codec=ContextCodec()),
+    )
+    async with Worker(
+        test_client,
+        task_queue=task_queue,
+        workflows=[CodecTestWorkflow],
+    ):
+        result = await test_client.execute_workflow(
+            CodecTestWorkflow.run,
+            "test",
+            id=wf_id,
+            task_queue=task_queue,
+        )
+        assert result == "test_processed"
+
+
+class PydanticData(BaseModel):
+    value: str
+    trace: List[str] = []
+
+
+class ContextPydanticJSONConverter(PydanticJSONPlainPayloadConverter, WithSerializationContext):
+    def __init__(self):
+        super().__init__()
+        self.context: Optional[SerializationContext] = None
+
+    def with_context(self, context: Optional[SerializationContext]) -> "ContextPydanticJSONConverter":
+        converter = ContextPydanticJSONConverter()
+        converter.context = context
+        return converter
+
+    def to_payload(self, value: Any) -> Optional[Payload]:
+        if isinstance(value, PydanticData) and self.context:
+            if isinstance(self.context, WorkflowSerializationContext):
+                value.trace.append(f"wf_{self.context.workflow_id}")
+        return super().to_payload(value)
+
+
+class ContextPydanticConverter(CompositePayloadConverter, WithSerializationContext):
+    def __init__(self):
+        self.json_converter = ContextPydanticJSONConverter()
+        super().__init__(
+            *(
+                c
+                if not isinstance(c, JSONPlainPayloadConverter)
+                else self.json_converter
+                for c in DefaultPayloadConverter.default_encoding_payload_converters
+            )
+        )
+        self.context: Optional[SerializationContext] = None
+
+    def with_context(self, context: Optional[SerializationContext]) -> "ContextPydanticConverter":
+        converter = ContextPydanticConverter()
+        converter.context = context
+        # Also set context on all sub-converters
+        converters = []
+        for c in self.converters.values():
+            if isinstance(c, WithSerializationContext):
+                converters.append(c.with_context(context))
+            else:
+                converters.append(c)
+        CompositePayloadConverter.__init__(converter, *converters)
+        return converter
+
+
+@workflow.defn
+class PydanticContextWorkflow:
+    @workflow.run
+    async def run(self, data: PydanticData) -> PydanticData:
+        data.value += "_processed"
+        return data
+
+
+async def test_pydantic_converter_with_context(client: Client):
+    wf_id = str(uuid.uuid4())
+    task_queue = str(uuid.uuid4())
+
+    test_client = Client(
+        client.service_client,
+        namespace=client.namespace,
+        data_converter=DataConverter(
+            payload_converter_class=ContextPydanticConverter,
+        ),
+    )
+    async with Worker(
+        test_client,
+        task_queue=task_queue,
+        workflows=[PydanticContextWorkflow],
+    ):
+        result = await test_client.execute_workflow(
+            PydanticContextWorkflow.run,
+            PydanticData(value="test"),
+            id=wf_id,
+            task_queue=task_queue,
+        )
+        assert result.value == "test_processed"
+        assert f"wf_{wf_id}" in result.trace
