Skip to content

Commit d725e21

Browse files
committed
Add command seq IDs to async context during payload visitor traversal
1 parent 9021f95 commit d725e21

File tree

2 files changed

+87
-19
lines changed

2 files changed

+87
-19
lines changed

scripts/gen_payload_visitor.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import subprocess
22
import sys
33
from pathlib import Path
4-
from typing import Optional, Tuple
4+
from typing import Optional
55

66
from google.protobuf.descriptor import Descriptor, FieldDescriptor
77

@@ -62,6 +62,20 @@ def emit_singular(
6262
await self._visit_{child_method}(fs, {access_expr})"""
6363

6464

65+
def emit_singular_with_seq(
66+
field_name: str, access_expr: str, child_method: str, presence_word: str
67+
) -> str:
68+
# Helper to emit a singular field visit that sets the seq contextvar, with presence check but
69+
# without headers guard since this is used for commands only.
70+
return f"""\
71+
{presence_word} o.HasField("{field_name}"):
72+
token = current_command_seq.set({access_expr}.seq)
73+
try:
74+
await self._visit_{child_method}(fs, {access_expr})
75+
finally:
76+
current_command_seq.reset(token)"""
77+
78+
6579
class VisitorGenerator:
6680
def generate(self, roots: list[Descriptor]) -> str:
6781
"""
@@ -80,10 +94,16 @@ def generate(self, roots: list[Descriptor]) -> str:
8094
header = """
8195
# This file is generated by gen_payload_visitor.py. Changes should be made there.
8296
import abc
97+
import contextvars
8398
from typing import Any, MutableSequence
8499
85100
from temporalio.api.common.v1.message_pb2 import Payload
86101
102+
# Context variable for tracking the current workflow command sequence number
103+
current_command_seq: contextvars.ContextVar[int | None] = contextvars.ContextVar(
104+
"current_command_seq", default=None
105+
)
106+
87107
class VisitorFunctions(abc.ABC):
88108
\"\"\"Set of functions which can be called by the visitor.
89109
Allows handling payloads as a sequence.
@@ -253,6 +273,20 @@ def walk(self, desc: Descriptor) -> bool:
253273
)
254274
)
255275

276+
commands_with_seq = {
277+
"start_timer",
278+
"cancel_timer",
279+
"schedule_activity",
280+
"schedule_local_activity",
281+
"request_cancel_activity",
282+
"request_cancel_local_activity",
283+
"start_child_workflow_execution",
284+
"request_cancel_external_workflow_execution",
285+
"signal_external_workflow_execution",
286+
"cancel_signal_workflow",
287+
"schedule_nexus_operation",
288+
"request_cancel_nexus_operation",
289+
}
256290
# Process oneof fields as if/elif chains
257291
for oneof_idx, fields in oneof_fields.items():
258292
oneof_lines = []
@@ -264,9 +298,17 @@ def walk(self, desc: Descriptor) -> bool:
264298
if child_has_payload:
265299
if_word = "if" if first else "elif"
266300
first = False
267-
line = emit_singular(
268-
field.name, f"o.{field.name}", name_for(child_desc), if_word
269-
)
301+
if (
302+
desc.full_name == "coresdk.workflow_commands.WorkflowCommand"
303+
and field.name in commands_with_seq
304+
):
305+
line = emit_singular_with_seq(
306+
field.name, f"o.{field.name}", name_for(child_desc), if_word
307+
)
308+
else:
309+
line = emit_singular(
310+
field.name, f"o.{field.name}", name_for(child_desc), if_word
311+
)
270312
oneof_lines.append(line)
271313
if oneof_lines:
272314
lines.extend(oneof_lines)

temporalio/bridge/_visitor.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
# This file is generated by gen_payload_visitor.py. Changes should be made there.
22
import abc
3+
import contextvars
34
from typing import Any, MutableSequence
45

56
from temporalio.api.common.v1.message_pb2 import Payload
67

8+
# Context variable for tracking the current workflow command sequence number
9+
current_command_seq: contextvars.ContextVar[int | None] = contextvars.ContextVar(
10+
"current_command_seq", default=None
11+
)
12+
713

814
class VisitorFunctions(abc.ABC):
915
"""Set of functions which can be called by the visitor.
@@ -371,9 +377,13 @@ async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o):
371377
if o.HasField("user_metadata"):
372378
await self._visit_temporal_api_sdk_v1_UserMetadata(fs, o.user_metadata)
373379
if o.HasField("schedule_activity"):
374-
await self._visit_coresdk_workflow_commands_ScheduleActivity(
375-
fs, o.schedule_activity
376-
)
380+
token = current_command_seq.set(o.schedule_activity.seq)
381+
try:
382+
await self._visit_coresdk_workflow_commands_ScheduleActivity(
383+
fs, o.schedule_activity
384+
)
385+
finally:
386+
current_command_seq.reset(token)
377387
elif o.HasField("respond_to_query"):
378388
await self._visit_coresdk_workflow_commands_QueryResult(
379389
fs, o.respond_to_query
@@ -391,17 +401,29 @@ async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o):
391401
fs, o.continue_as_new_workflow_execution
392402
)
393403
elif o.HasField("start_child_workflow_execution"):
394-
await self._visit_coresdk_workflow_commands_StartChildWorkflowExecution(
395-
fs, o.start_child_workflow_execution
396-
)
404+
token = current_command_seq.set(o.start_child_workflow_execution.seq)
405+
try:
406+
await self._visit_coresdk_workflow_commands_StartChildWorkflowExecution(
407+
fs, o.start_child_workflow_execution
408+
)
409+
finally:
410+
current_command_seq.reset(token)
397411
elif o.HasField("signal_external_workflow_execution"):
398-
await self._visit_coresdk_workflow_commands_SignalExternalWorkflowExecution(
399-
fs, o.signal_external_workflow_execution
400-
)
412+
token = current_command_seq.set(o.signal_external_workflow_execution.seq)
413+
try:
414+
await self._visit_coresdk_workflow_commands_SignalExternalWorkflowExecution(
415+
fs, o.signal_external_workflow_execution
416+
)
417+
finally:
418+
current_command_seq.reset(token)
401419
elif o.HasField("schedule_local_activity"):
402-
await self._visit_coresdk_workflow_commands_ScheduleLocalActivity(
403-
fs, o.schedule_local_activity
404-
)
420+
token = current_command_seq.set(o.schedule_local_activity.seq)
421+
try:
422+
await self._visit_coresdk_workflow_commands_ScheduleLocalActivity(
423+
fs, o.schedule_local_activity
424+
)
425+
finally:
426+
current_command_seq.reset(token)
405427
elif o.HasField("upsert_workflow_search_attributes"):
406428
await self._visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes(
407429
fs, o.upsert_workflow_search_attributes
@@ -415,9 +437,13 @@ async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o):
415437
fs, o.update_response
416438
)
417439
elif o.HasField("schedule_nexus_operation"):
418-
await self._visit_coresdk_workflow_commands_ScheduleNexusOperation(
419-
fs, o.schedule_nexus_operation
420-
)
440+
token = current_command_seq.set(o.schedule_nexus_operation.seq)
441+
try:
442+
await self._visit_coresdk_workflow_commands_ScheduleNexusOperation(
443+
fs, o.schedule_nexus_operation
444+
)
445+
finally:
446+
current_command_seq.reset(token)
421447

422448
async def _visit_coresdk_workflow_completion_Success(self, fs, o):
423449
for v in o.commands:

0 commit comments

Comments
 (0)