Skip to content

Commit 1f40869

Browse files
committed
Rearrange
1 parent 71ca54c commit 1f40869

File tree

1 file changed

+82
-67
lines changed

1 file changed

+82
-67
lines changed

tests/test_serialization_context.py

Lines changed: 82 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -184,20 +184,17 @@ def __init__(self):
184184
)
185185

186186

187-
data_converter = dataclasses.replace(
188-
DataConverter.default,
189-
payload_converter_class=SerializationContextTestPayloadConverter,
190-
)
191-
192-
193187
async def test_workflow_payload_conversion(
194188
client: Client,
195189
):
196190
workflow_id = str(uuid.uuid4())
197191
task_queue = str(uuid.uuid4())
198192

199193
config = client.config()
200-
config["data_converter"] = data_converter
194+
config["data_converter"] = dataclasses.replace(
195+
DataConverter.default,
196+
payload_converter_class=SerializationContextTestPayloadConverter,
197+
)
201198
client = Client(**config)
202199

203200
async with Worker(
@@ -368,7 +365,11 @@ async def test_heartbeat_details_payload_conversion(client: Client):
368365
task_queue = str(uuid.uuid4())
369366

370367
config = client.config()
371-
config["data_converter"] = data_converter
368+
config["data_converter"] = dataclasses.replace(
369+
DataConverter.default,
370+
payload_converter_class=SerializationContextTestPayloadConverter,
371+
)
372+
372373
client = Client(**config)
373374

374375
async with Worker(
@@ -440,7 +441,11 @@ async def test_async_activity_completion_payload_conversion(
440441
task_queue = str(uuid.uuid4())
441442

442443
config = client.config()
443-
config["data_converter"] = data_converter
444+
config["data_converter"] = dataclasses.replace(
445+
DataConverter.default,
446+
payload_converter_class=SerializationContextTestPayloadConverter,
447+
)
448+
444449
client = Client(**config)
445450

446451
async with Worker(
@@ -531,7 +536,11 @@ async def test_signal_payload_conversion(
531536
task_queue = str(uuid.uuid4())
532537

533538
config = client.config()
534-
config["data_converter"] = data_converter
539+
config["data_converter"] = dataclasses.replace(
540+
DataConverter.default,
541+
payload_converter_class=SerializationContextTestPayloadConverter,
542+
)
543+
535544
custom_client = Client(**config)
536545

537546
async with Worker(
@@ -615,7 +624,10 @@ async def test_query_payload_conversion(
615624
)
616625

617626
config = client.config()
618-
config["data_converter"] = data_converter
627+
config["data_converter"] = dataclasses.replace(
628+
DataConverter.default,
629+
payload_converter_class=SerializationContextTestPayloadConverter,
630+
)
619631
custom_client = Client(**config)
620632

621633
async with Worker(
@@ -712,7 +724,10 @@ async def test_update_payload_conversion(
712724
)
713725

714726
config = client.config()
715-
config["data_converter"] = data_converter
727+
config["data_converter"] = dataclasses.replace(
728+
DataConverter.default,
729+
payload_converter_class=SerializationContextTestPayloadConverter,
730+
)
716731
custom_client = Client(**config)
717732

718733
async with Worker(
@@ -837,7 +852,10 @@ async def test_external_workflow_signal_and_cancel_payload_conversion(
837852
)
838853

839854
config = client.config()
840-
config["data_converter"] = data_converter
855+
config["data_converter"] = dataclasses.replace(
856+
DataConverter.default,
857+
payload_converter_class=SerializationContextTestPayloadConverter,
858+
)
841859
custom_client = Client(**config)
842860

843861
async with Worker(
@@ -925,56 +943,6 @@ async def test_external_workflow_signal_and_cancel_payload_conversion(
925943
# The cancel context would only be used for failure deserialization
926944

927945

928-
# Utilities
929-
930-
931-
def assert_trace(trace: list[TraceItem], expected: list[TraceItem]):
932-
if len(trace) != len(expected):
933-
warn(f"expected {len(expected)} trace items but received {len(trace)}")
934-
history: list[str] = []
935-
for item, expected_item in zip_longest(trace, expected):
936-
if item is None:
937-
raise AssertionError("Fewer items in trace than expected")
938-
if expected_item is None:
939-
raise AssertionError("More items in trace than expected")
940-
if item != expected_item:
941-
raise AssertionError(
942-
f"Item:\n{pformat(item)}\n\ndoes not match expected:\n\n {pformat(expected_item)}.\n\n History:\n{chr(10).join(history)}"
943-
)
944-
history.append(f"{item.context_type} {item.method}")
945-
946-
947-
def get_caller_location() -> list[str]:
948-
"""Get 3 stack frames starting from the first that's not in test_serialization_context.py or temporalio/converter.py."""
949-
frame = inspect.currentframe()
950-
result: list[str] = []
951-
found_first = False
952-
953-
# Walk up the stack
954-
while frame and len(result) < 3:
955-
frame = frame.f_back
956-
if not frame:
957-
break
958-
959-
file_path = frame.f_code.co_filename
960-
961-
# Skip frames from test file and converter.py until we find the first one
962-
if not found_first:
963-
if "test_serialization_context.py" in file_path:
964-
continue
965-
if file_path.endswith("temporalio/converter.py"):
966-
continue
967-
found_first = True
968-
969-
result.append(f"{file_path}:{frame.f_lineno}")
970-
971-
# Pad with "unknown:0" if we didn't get 3 frames
972-
while len(result) < 3:
973-
result.append("unknown:0")
974-
975-
return result
976-
977-
978946
@activity.defn
979947
async def failing_activity() -> TraceData:
980948
raise ApplicationError("test error", TraceData())
@@ -1059,10 +1027,7 @@ async def test_failure_conversion_with_context(client: Client):
10591027
id=str(uuid.uuid4()),
10601028
task_queue=task_queue,
10611029
)
1062-
assert any(
1063-
item.context_type == "activity" and item.method == "to_payload"
1064-
for item in result.items
1065-
)
1030+
pprint(result.items)
10661031

10671032

10681033
class ContextCodec(PayloadCodec, WithSerializationContext):
@@ -1217,3 +1182,53 @@ async def test_pydantic_converter_with_context(client: Client):
12171182
)
12181183
assert result.value == "test_processed"
12191184
assert f"wf_{wf_id}" in result.trace
1185+
1186+
1187+
# Utilities
1188+
1189+
1190+
def assert_trace(trace: list[TraceItem], expected: list[TraceItem]):
1191+
if len(trace) != len(expected):
1192+
warn(f"expected {len(expected)} trace items but received {len(trace)}")
1193+
history: list[str] = []
1194+
for item, expected_item in zip_longest(trace, expected):
1195+
if item is None:
1196+
raise AssertionError("Fewer items in trace than expected")
1197+
if expected_item is None:
1198+
raise AssertionError("More items in trace than expected")
1199+
if item != expected_item:
1200+
raise AssertionError(
1201+
f"Item:\n{pformat(item)}\n\ndoes not match expected:\n\n {pformat(expected_item)}.\n\n History:\n{chr(10).join(history)}"
1202+
)
1203+
history.append(f"{item.context_type} {item.method}")
1204+
1205+
1206+
def get_caller_location() -> list[str]:
1207+
"""Get 3 stack frames starting from the first that's not in test_serialization_context.py or temporalio/converter.py."""
1208+
frame = inspect.currentframe()
1209+
result: list[str] = []
1210+
found_first = False
1211+
1212+
# Walk up the stack
1213+
while frame and len(result) < 3:
1214+
frame = frame.f_back
1215+
if not frame:
1216+
break
1217+
1218+
file_path = frame.f_code.co_filename
1219+
1220+
# Skip frames from test file and converter.py until we find the first one
1221+
if not found_first:
1222+
if "test_serialization_context.py" in file_path:
1223+
continue
1224+
if file_path.endswith("temporalio/converter.py"):
1225+
continue
1226+
found_first = True
1227+
1228+
result.append(f"{file_path}:{frame.f_lineno}")
1229+
1230+
# Pad with "unknown:0" if we didn't get 3 frames
1231+
while len(result) < 3:
1232+
result.append("unknown:0")
1233+
1234+
return result

0 commit comments

Comments
 (0)