Skip to content

Commit ef5981c

Browse files
committed
test codec
1 parent d254c1a commit ef5981c

File tree

1 file changed

+84
-44
lines changed

1 file changed

+84
-44
lines changed

tests/test_serialization_context.py

Lines changed: 84 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class TraceItem:
4848
"from_payload",
4949
"to_failure",
5050
"from_failure",
51+
"encode",
52+
"decode",
5153
]
5254
context: dict[str, Any]
5355
in_workflow: bool
@@ -846,10 +848,10 @@ async def test_external_workflow_signal_and_cancel_payload_conversion(
846848
DataConverter.default,
847849
payload_converter_class=SerializationContextTestPayloadConverter,
848850
)
849-
custom_client = Client(**config)
851+
client = Client(**config)
850852

851853
async with Worker(
852-
custom_client,
854+
client,
853855
task_queue=task_queue,
854856
workflows=[
855857
ExternalWorkflowTarget,
@@ -860,13 +862,13 @@ async def test_external_workflow_signal_and_cancel_payload_conversion(
860862
workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance
861863
):
862864
# Test external signal
863-
target_handle = await custom_client.start_workflow(
865+
target_handle = await client.start_workflow(
864866
ExternalWorkflowTarget.run,
865867
id=target_workflow_id,
866868
task_queue=task_queue,
867869
)
868870

869-
signaler_handle = await custom_client.start_workflow(
871+
signaler_handle = await client.start_workflow(
870872
ExternalWorkflowSignaler.run,
871873
args=[target_workflow_id, TraceData()],
872874
id=signaler_workflow_id,
@@ -953,7 +955,7 @@ async def run(self) -> Never:
953955
raise Exception("Unreachable")
954956

955957

956-
failure_converter_test_trace: dict[str, list[TraceItem]] = defaultdict(list)
958+
test_traces: dict[str, list[TraceItem]] = defaultdict(list)
957959

958960

959961
class FailureConverterWithContext(DefaultFailureConverter, WithSerializationContext):
@@ -981,7 +983,7 @@ def to_failure(
981983
else:
982984
raise TypeError(f"self.context is {type(self.context)}")
983985

984-
failure_converter_test_trace[self.context.workflow_id].append(
986+
test_traces[self.context.workflow_id].append(
985987
TraceItem(
986988
context_type=context_type,
987989
in_workflow=workflow.in_workflow(),
@@ -1002,7 +1004,7 @@ def from_failure(
10021004
else:
10031005
raise TypeError(f"self.context is {type(self.context)}")
10041006

1005-
failure_converter_test_trace[self.context.workflow_id].append(
1007+
test_traces[self.context.workflow_id].append(
10061008
TraceItem(
10071009
context_type=context_type,
10081010
in_workflow=workflow.in_workflow(),
@@ -1022,20 +1024,19 @@ async def test_failure_converter_with_context(client: Client):
10221024
DataConverter.default,
10231025
failure_converter_class=FailureConverterWithContext,
10241026
)
1025-
test_client = Client(
1026-
client.service_client,
1027-
namespace=client.namespace,
1028-
data_converter=data_converter,
1029-
)
1027+
config = client.config()
1028+
config["data_converter"] = data_converter
1029+
client = Client(**config)
1030+
10301031
async with Worker(
1031-
test_client,
1032+
client,
10321033
task_queue=task_queue,
10331034
workflows=[FailureConverterTestWorkflow],
10341035
activities=[failing_activity],
10351036
workflow_runner=UnsandboxedWorkflowRunner(),
10361037
):
10371038
try:
1038-
await test_client.execute_workflow(
1039+
await client.execute_workflow(
10391040
FailureConverterTestWorkflow.run,
10401041
id=workflow_id,
10411042
task_queue=task_queue,
@@ -1063,7 +1064,7 @@ async def test_failure_converter_with_context(client: Client):
10631064
)
10641065
)
10651066
assert_trace(
1066-
failure_converter_test_trace[workflow_id],
1067+
test_traces[workflow_id],
10671068
[
10681069
TraceItem(
10691070
context_type="activity",
@@ -1103,7 +1104,7 @@ async def test_failure_converter_with_context(client: Client):
11031104
* 2 # from_failure deserializes the error and error cause
11041105
),
11051106
)
1106-
del failure_converter_test_trace[workflow_id]
1107+
del test_traces[workflow_id]
11071108

11081109

11091110
class PayloadCodecWithContext(PayloadCodec, WithSerializationContext):
@@ -1120,26 +1121,30 @@ def with_context(
11201121
return codec
11211122

11221123
async def encode(self, payloads: Sequence[Payload]) -> List[Payload]:
1123-
result = []
1124-
for p in payloads:
1125-
new_p = Payload()
1126-
new_p.CopyFrom(p)
1127-
if self.context:
1128-
self.encode_called_with_context = True
1129-
new_p.metadata["has_context"] = b"true"
1130-
result.append(new_p)
1131-
return result
1124+
assert self.context
1125+
assert isinstance(self.context, WorkflowSerializationContext)
1126+
test_traces[self.context.workflow_id].append(
1127+
TraceItem(
1128+
context_type="workflow",
1129+
context=dataclasses.asdict(self.context),
1130+
method="encode",
1131+
in_workflow=workflow.in_workflow(),
1132+
)
1133+
)
1134+
return list(payloads)
11321135

11331136
async def decode(self, payloads: Sequence[Payload]) -> List[Payload]:
1134-
result = []
1135-
for p in payloads:
1136-
new_p = Payload()
1137-
new_p.CopyFrom(p)
1138-
if self.context and new_p.metadata.get("has_context") == b"true":
1139-
self.decode_called_with_context = True
1140-
del new_p.metadata["has_context"]
1141-
result.append(new_p)
1142-
return result
1137+
assert self.context
1138+
assert isinstance(self.context, WorkflowSerializationContext)
1139+
test_traces[self.context.workflow_id].append(
1140+
TraceItem(
1141+
context_type="workflow",
1142+
context=dataclasses.asdict(self.context),
1143+
method="decode",
1144+
in_workflow=workflow.in_workflow(),
1145+
)
1146+
)
1147+
return list(payloads)
11431148

11441149

11451150
@workflow.defn
@@ -1150,26 +1155,61 @@ async def run(self, data: str) -> str:
11501155

11511156

11521157
async def test_codec_with_context(client: Client):
1153-
wf_id = str(uuid.uuid4())
1158+
workflow_id = str(uuid.uuid4())
11541159
task_queue = str(uuid.uuid4())
1155-
test_client = Client(
1156-
client.service_client,
1157-
namespace=client.namespace,
1158-
data_converter=dataclasses.replace(
1159-
DataConverter.default, payload_codec=PayloadCodecWithContext()
1160-
),
1160+
1161+
client_config = client.config()
1162+
client_config["data_converter"] = dataclasses.replace(
1163+
DataConverter.default, payload_codec=PayloadCodecWithContext()
11611164
)
1165+
client = Client(**client_config)
11621166
async with Worker(
1163-
test_client,
1167+
client,
11641168
task_queue=task_queue,
11651169
workflows=[CodecTestWorkflow],
11661170
):
1167-
await test_client.execute_workflow(
1171+
await client.execute_workflow(
11681172
CodecTestWorkflow.run,
11691173
"data",
1170-
id=wf_id,
1174+
id=workflow_id,
11711175
task_queue=task_queue,
11721176
)
1177+
workflow_context = dataclasses.asdict(
1178+
WorkflowSerializationContext(
1179+
namespace=client.namespace,
1180+
workflow_id=workflow_id,
1181+
)
1182+
)
1183+
assert_trace(
1184+
test_traces[workflow_id],
1185+
[
1186+
TraceItem(
1187+
context_type="workflow",
1188+
context=workflow_context,
1189+
method="encode",
1190+
in_workflow=False,
1191+
),
1192+
TraceItem(
1193+
context_type="workflow",
1194+
context=workflow_context,
1195+
method="decode",
1196+
in_workflow=False,
1197+
),
1198+
TraceItem(
1199+
context_type="workflow",
1200+
context=workflow_context,
1201+
method="encode",
1202+
in_workflow=False,
1203+
),
1204+
TraceItem(
1205+
context_type="workflow",
1206+
context=workflow_context,
1207+
method="decode",
1208+
in_workflow=False,
1209+
),
1210+
],
1211+
)
1212+
del test_traces[workflow_id]
11731213

11741214

11751215
# Pydantic

0 commit comments

Comments
 (0)