|
8 | 8 | from datetime import timedelta |
9 | 9 | from itertools import zip_longest |
10 | 10 | from pprint import pformat, pprint |
11 | | -from typing import Any, Literal, Optional, Type |
| 11 | +from typing import Any, List, Literal, Optional, Sequence, Type |
12 | 12 | from warnings import warn |
13 | 13 |
|
14 | 14 | import pytest |
| 15 | +from pydantic import BaseModel |
15 | 16 |
|
16 | 17 | from temporalio import activity, workflow |
17 | 18 | from temporalio.api.common.v1 import Payload |
| 19 | +from temporalio.api.failure.v1 import Failure |
18 | 20 | from temporalio.client import Client, WorkflowUpdateFailedError |
19 | 21 | from temporalio.common import RetryPolicy |
20 | 22 | from temporalio.converter import ( |
21 | 23 | ActivitySerializationContext, |
22 | 24 | CompositePayloadConverter, |
23 | 25 | DataConverter, |
24 | 26 | DefaultPayloadConverter, |
| 27 | + DefaultFailureConverter, |
25 | 28 | EncodingPayloadConverter, |
26 | 29 | JSONPlainPayloadConverter, |
| 30 | + PayloadCodec, |
| 31 | + PayloadConverter, |
27 | 32 | SerializationContext, |
28 | 33 | WithSerializationContext, |
29 | 34 | WorkflowSerializationContext, |
30 | 35 | ) |
| 36 | +from temporalio.contrib.pydantic import PydanticJSONPlainPayloadConverter |
| 37 | +from temporalio.exceptions import ApplicationError, ActivityError |
31 | 38 | from temporalio.worker import Worker |
32 | 39 | from temporalio.worker._workflow_instance import UnsandboxedWorkflowRunner |
33 | 40 |
|
@@ -968,3 +975,238 @@ def get_caller_location() -> list[str]: |
968 | 975 | result.append("unknown:0") |
969 | 976 |
|
970 | 977 | return result |
| 978 | + |
| 979 | + |
| 980 | + |
| 981 | + |
| 982 | +@activity.defn |
| 983 | +async def failing_activity() -> TraceData: |
| 984 | + raise ApplicationError("test error", TraceData()) |
| 985 | + |
| 986 | + |
| 987 | +@workflow.defn |
| 988 | +class FailureContextWorkflow: |
| 989 | + @workflow.run |
| 990 | + async def run(self) -> TraceData: |
| 991 | + try: |
| 992 | + await workflow.execute_activity( |
| 993 | + failing_activity, |
| 994 | + start_to_close_timeout=timedelta(seconds=10), |
| 995 | + retry_policy=RetryPolicy(maximum_attempts=1), |
| 996 | + ) |
| 997 | + except ActivityError as e: |
| 998 | + if isinstance(e.cause, ApplicationError) and e.cause.details: |
| 999 | + return e.cause.details[0] |
| 1000 | + return TraceData() |
| 1001 | + |
| 1002 | + |
| 1003 | +class ContextFailureConverter(DefaultFailureConverter, WithSerializationContext): |
| 1004 | + def __init__(self): |
| 1005 | + super().__init__(encode_common_attributes=False) |
| 1006 | + self.context: Optional[SerializationContext] = None |
| 1007 | + |
| 1008 | + def with_context(self, context: Optional[SerializationContext]) -> "ContextFailureConverter": |
| 1009 | + converter = ContextFailureConverter() |
| 1010 | + converter.context = context |
| 1011 | + return converter |
| 1012 | + |
| 1013 | + def to_failure( |
| 1014 | + self, exception: BaseException, payload_converter: PayloadConverter, failure: Failure |
| 1015 | + ) -> None: |
| 1016 | + super().to_failure(exception, payload_converter, failure) |
| 1017 | + if isinstance(exception, ApplicationError) and exception.details: |
| 1018 | + for detail in exception.details: |
| 1019 | + if isinstance(detail, TraceData) and self.context: |
| 1020 | + if isinstance(self.context, ActivitySerializationContext): |
| 1021 | + detail.items.append( |
| 1022 | + TraceItem( |
| 1023 | + context_type="activity", |
| 1024 | + in_workflow=False, |
| 1025 | + method="to_payload", |
| 1026 | + context=dataclasses.asdict(self.context), |
| 1027 | + ) |
| 1028 | + ) |
| 1029 | + |
| 1030 | + def from_failure( |
| 1031 | + self, failure: Failure, payload_converter: PayloadConverter |
| 1032 | + ) -> BaseException: |
| 1033 | + # Let the base class create the exception |
| 1034 | + exception = super().from_failure(failure, payload_converter) |
| 1035 | + # The context tracing is already in the payloads that will be decoded with context |
| 1036 | + return exception |
| 1037 | + |
| 1038 | + |
| 1039 | +async def test_failure_conversion_with_context(client: Client): |
| 1040 | + task_queue = str(uuid.uuid4()) |
| 1041 | + test_client = Client( |
| 1042 | + client.service_client, |
| 1043 | + namespace=client.namespace, |
| 1044 | + data_converter=DataConverter( |
| 1045 | + payload_converter_class=SerializationContextTestPayloadConverter, |
| 1046 | + failure_converter_class=ContextFailureConverter, |
| 1047 | + ), |
| 1048 | + ) |
| 1049 | + async with Worker( |
| 1050 | + test_client, |
| 1051 | + task_queue=task_queue, |
| 1052 | + workflows=[FailureContextWorkflow], |
| 1053 | + activities=[failing_activity], |
| 1054 | + workflow_runner=UnsandboxedWorkflowRunner(), |
| 1055 | + ): |
| 1056 | + result = await test_client.execute_workflow( |
| 1057 | + FailureContextWorkflow.run, |
| 1058 | + id=str(uuid.uuid4()), |
| 1059 | + task_queue=task_queue, |
| 1060 | + ) |
| 1061 | + assert any( |
| 1062 | + item.context_type == "activity" and item.method == "to_payload" |
| 1063 | + for item in result.items |
| 1064 | + ) |
| 1065 | + |
| 1066 | + |
| 1067 | +class ContextCodec(PayloadCodec, WithSerializationContext): |
| 1068 | + def __init__(self): |
| 1069 | + self.context: Optional[SerializationContext] = None |
| 1070 | + self.encode_called_with_context = False |
| 1071 | + self.decode_called_with_context = False |
| 1072 | + |
| 1073 | + def with_context(self, context: Optional[SerializationContext]) -> "ContextCodec": |
| 1074 | + codec = ContextCodec() |
| 1075 | + codec.context = context |
| 1076 | + return codec |
| 1077 | + |
| 1078 | + async def encode(self, payloads: Sequence[Payload]) -> List[Payload]: |
| 1079 | + result = [] |
| 1080 | + for p in payloads: |
| 1081 | + new_p = Payload() |
| 1082 | + new_p.CopyFrom(p) |
| 1083 | + if self.context: |
| 1084 | + self.encode_called_with_context = True |
| 1085 | + # Just add a marker that we encoded with context |
| 1086 | + new_p.metadata[b"has_context"] = b"true" |
| 1087 | + result.append(new_p) |
| 1088 | + return result |
| 1089 | + |
| 1090 | + async def decode(self, payloads: Sequence[Payload]) -> List[Payload]: |
| 1091 | + result = [] |
| 1092 | + for p in payloads: |
| 1093 | + new_p = Payload() |
| 1094 | + new_p.CopyFrom(p) |
| 1095 | + if self.context and new_p.metadata.get(b"has_context") == b"true": |
| 1096 | + self.decode_called_with_context = True |
| 1097 | + # Remove the marker |
| 1098 | + del new_p.metadata[b"has_context"] |
| 1099 | + result.append(new_p) |
| 1100 | + return result |
| 1101 | + |
| 1102 | + |
| 1103 | +@workflow.defn |
| 1104 | +class CodecTestWorkflow: |
| 1105 | + @workflow.run |
| 1106 | + async def run(self, data: str) -> str: |
| 1107 | + return data + "_processed" |
| 1108 | + |
| 1109 | + |
| 1110 | +async def test_codec_with_context(client: Client): |
| 1111 | + wf_id = str(uuid.uuid4()) |
| 1112 | + task_queue = str(uuid.uuid4()) |
| 1113 | + test_client = Client( |
| 1114 | + client.service_client, |
| 1115 | + namespace=client.namespace, |
| 1116 | + data_converter=DataConverter(payload_codec=ContextCodec()), |
| 1117 | + ) |
| 1118 | + async with Worker( |
| 1119 | + test_client, |
| 1120 | + task_queue=task_queue, |
| 1121 | + workflows=[CodecTestWorkflow], |
| 1122 | + ): |
| 1123 | + result = await test_client.execute_workflow( |
| 1124 | + CodecTestWorkflow.run, |
| 1125 | + "test", |
| 1126 | + id=wf_id, |
| 1127 | + task_queue=task_queue, |
| 1128 | + ) |
| 1129 | + assert result == "test_processed" |
| 1130 | + |
| 1131 | + |
| 1132 | +class PydanticData(BaseModel): |
| 1133 | + value: str |
| 1134 | + trace: List[str] = [] |
| 1135 | + |
| 1136 | + |
| 1137 | +class ContextPydanticJSONConverter(PydanticJSONPlainPayloadConverter, WithSerializationContext): |
| 1138 | + def __init__(self): |
| 1139 | + super().__init__() |
| 1140 | + self.context: Optional[SerializationContext] = None |
| 1141 | + |
| 1142 | + def with_context(self, context: Optional[SerializationContext]) -> "ContextPydanticJSONConverter": |
| 1143 | + converter = ContextPydanticJSONConverter() |
| 1144 | + converter.context = context |
| 1145 | + return converter |
| 1146 | + |
| 1147 | + def to_payload(self, value: Any) -> Optional[Payload]: |
| 1148 | + if isinstance(value, PydanticData) and self.context: |
| 1149 | + if isinstance(self.context, WorkflowSerializationContext): |
| 1150 | + value.trace.append(f"wf_{self.context.workflow_id}") |
| 1151 | + return super().to_payload(value) |
| 1152 | + |
| 1153 | + |
| 1154 | +class ContextPydanticConverter(CompositePayloadConverter, WithSerializationContext): |
| 1155 | + def __init__(self): |
| 1156 | + self.json_converter = ContextPydanticJSONConverter() |
| 1157 | + super().__init__( |
| 1158 | + *( |
| 1159 | + c |
| 1160 | + if not isinstance(c, JSONPlainPayloadConverter) |
| 1161 | + else self.json_converter |
| 1162 | + for c in DefaultPayloadConverter.default_encoding_payload_converters |
| 1163 | + ) |
| 1164 | + ) |
| 1165 | + self.context: Optional[SerializationContext] = None |
| 1166 | + |
| 1167 | + def with_context(self, context: Optional[SerializationContext]) -> "ContextPydanticConverter": |
| 1168 | + converter = ContextPydanticConverter() |
| 1169 | + converter.context = context |
| 1170 | + # Also set context on all sub-converters |
| 1171 | + converters = [] |
| 1172 | + for c in self.converters.values(): |
| 1173 | + if isinstance(c, WithSerializationContext): |
| 1174 | + converters.append(c.with_context(context)) |
| 1175 | + else: |
| 1176 | + converters.append(c) |
| 1177 | + CompositePayloadConverter.__init__(converter, *converters) |
| 1178 | + return converter |
| 1179 | + |
| 1180 | + |
| 1181 | +@workflow.defn |
| 1182 | +class PydanticContextWorkflow: |
| 1183 | + @workflow.run |
| 1184 | + async def run(self, data: PydanticData) -> PydanticData: |
| 1185 | + data.value += "_processed" |
| 1186 | + return data |
| 1187 | + |
| 1188 | + |
| 1189 | +async def test_pydantic_converter_with_context(client: Client): |
| 1190 | + wf_id = str(uuid.uuid4()) |
| 1191 | + task_queue = str(uuid.uuid4()) |
| 1192 | + |
| 1193 | + test_client = Client( |
| 1194 | + client.service_client, |
| 1195 | + namespace=client.namespace, |
| 1196 | + data_converter=DataConverter( |
| 1197 | + payload_converter_class=ContextPydanticConverter, |
| 1198 | + ), |
| 1199 | + ) |
| 1200 | + async with Worker( |
| 1201 | + test_client, |
| 1202 | + task_queue=task_queue, |
| 1203 | + workflows=[PydanticContextWorkflow], |
| 1204 | + ): |
| 1205 | + result = await test_client.execute_workflow( |
| 1206 | + PydanticContextWorkflow.run, |
| 1207 | + PydanticData(value="test"), |
| 1208 | + id=wf_id, |
| 1209 | + task_queue=task_queue, |
| 1210 | + ) |
| 1211 | + assert result.value == "test_processed" |
| 1212 | + assert f"wf_{wf_id}" in result.trace |
0 commit comments