Skip to content

Commit accdc29

Browse files
committed
Refactor test utility
1 parent bdeb3d1 commit accdc29

File tree

2 files changed

+36
-22
lines changed

2 files changed

+36
-22
lines changed

tests/helpers/__init__.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import time
44
import uuid
55
from contextlib import closing
6+
from dataclasses import dataclass
67
from datetime import datetime, timedelta, timezone
78
from typing import Any, Awaitable, Callable, Optional, Sequence, Type, TypeVar, Union
89

@@ -300,6 +301,14 @@ async def print_history(handle: WorkflowHandle):
300301
i += 1
301302

302303

304+
@dataclass
305+
class InterleavedHistoryEvent:
306+
handle: WorkflowHandle
307+
event: Union[HistoryEvent, str]
308+
number: Optional[int]
309+
time: datetime
310+
311+
303312
async def print_interleaved_histories(
304313
handles: list[WorkflowHandle],
305314
extra_events: Optional[list[tuple[WorkflowHandle, str, datetime]]] = None,
@@ -313,32 +322,34 @@ async def print_interleaved_histories(
313322
314323
where <elapsed_ms> is the number of milliseconds since the first event in any of the workflows.
315324
"""
316-
all_events: list[
317-
tuple[WorkflowHandle, Union[HistoryEvent, str], Optional[int], datetime]
318-
] = []
325+
all_events: list[InterleavedHistoryEvent] = []
319326
workflow_start_times: dict[WorkflowHandle, datetime] = {}
320327

321328
for handle in handles:
322329
event_num = 1
323330
first_event = True
324-
async for event in handle.fetch_history_events():
325-
event_time = event.event_time.ToDatetime()
331+
async for history_event in handle.fetch_history_events():
332+
event_time = history_event.event_time.ToDatetime()
326333
if first_event:
327334
workflow_start_times[handle] = event_time
328335
first_event = False
329-
all_events.append((handle, event, event_num, event_time))
336+
all_events.append(
337+
InterleavedHistoryEvent(handle, history_event, event_num, event_time)
338+
)
330339
event_num += 1
331340

332341
if extra_events:
333342
for handle, event_str, event_time in extra_events:
334343
# Ensure timezone-naive
335344
if event_time.tzinfo is not None:
336345
event_time = event_time.astimezone(timezone.utc).replace(tzinfo=None)
337-
all_events.append((handle, event_str, None, event_time))
346+
all_events.append(
347+
InterleavedHistoryEvent(handle, event_str, None, event_time)
348+
)
338349

339350
zero_time = min(workflow_start_times.values())
340351

341-
all_events.sort(key=lambda item: item[3])
352+
all_events.sort(key=lambda item: item.time)
342353
col_width = 50
343354

344355
def _format_row(items: list[str], truncate: bool = False) -> str:
@@ -350,30 +361,32 @@ def _format_row(items: list[str], truncate: bool = False) -> str:
350361
print("\n" + _format_row(headers, truncate=True))
351362
print("-" * (col_width * len(handles) + len(handles) - 1))
352363

353-
for handle, event, event_num, event_time in all_events:
354-
elapsed_ms = int((event_time - zero_time).total_seconds() * 1000)
364+
for event in all_events:
365+
elapsed_ms = int((event.time - zero_time).total_seconds() * 1000)
355366

356-
if isinstance(event, str):
357-
event_desc = f" *: {elapsed_ms:>4} {event}"
367+
if isinstance(event.event, str):
368+
event_desc = f" *: {elapsed_ms:>4} {event.event}"
358369
summary = None
359370
else:
360-
event_type = EventType.Name(event.event_type).removeprefix("EVENT_TYPE_")
361-
event_desc = f"{event_num:2}: {elapsed_ms:>4} {event_type}"
371+
event_type = EventType.Name(event.event.event_type).removeprefix(
372+
"EVENT_TYPE_"
373+
)
374+
event_desc = f"{event.number:2}: {elapsed_ms:>4} {event_type}"
362375

363376
# Extract summary from user_metadata if present
364377
summary = None
365-
if event.HasField("user_metadata") and event.user_metadata.HasField(
366-
"summary"
367-
):
378+
if event.event.HasField(
379+
"user_metadata"
380+
) and event.event.user_metadata.HasField("summary"):
368381
try:
369382
summary = DataConverter.default.payload_converter.from_payload(
370-
event.user_metadata.summary
383+
event.event.user_metadata.summary
371384
)
372385
except Exception:
373386
pass # Ignore decoding errors
374387

375388
row = [""] * len(handles)
376-
col_idx = handles.index(handle)
389+
col_idx = handles.index(event.handle)
377390
row[col_idx] = event_desc[: col_width - 3]
378391
print(_format_row(row))
379392

@@ -382,8 +395,8 @@ def _format_row(items: list[str], truncate: bool = False) -> str:
382395
summary_row = [""] * len(handles)
383396
# Left-align with event type name (after "<event_num>: <elapsed_ms> ")
384397
# Calculate the padding needed
385-
if event_num is not None:
386-
padding = len(f"{event_num:2}: {elapsed_ms:>4} ")
398+
if event.number is not None:
399+
padding = len(f"{event.number:2}: {elapsed_ms:>4} ")
387400
else:
388401
padding = len(f" *: {elapsed_ms:>4} ")
389402
summary_row[col_idx] = f"{' ' * padding}[{summary}]"[: col_width - 3]

tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ async def start(
8181
async def cancel(
8282
self, ctx: nexusrpc.handler.CancelOperationContext, token: str
8383
) -> None:
84-
handler_wf = nexus.client().get_workflow_handle_for(
84+
client = nexus.client()
85+
handler_wf = client.get_workflow_handle_for(
8586
HandlerWorkflow.run,
8687
workflow_id=nexus.WorkflowHandle[None].from_token(token).workflow_id,
8788
)

0 commit comments

Comments
 (0)