Skip to content

Commit 634c5ec

Browse files
authored
💥 Add generic payload visitor for WorkflowActivation[Completion] (#1075)
* Reflection based visitor * Code generation work * Linting * Cleanup * Cleanup * Fix generator method name * Update gen_visitors.py * Allow payload mutation * Optimize codec visitor * Add warning about encode_failure * Update to use elif on oneof fields * Generate visitors during proto generation * Use os path join in script * PR comments
1 parent 706c89b commit 634c5ec

File tree

7 files changed

+1030
-239
lines changed

7 files changed

+1030
-239
lines changed

scripts/gen_payload_visitor.py

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
import subprocess
2+
import sys
3+
from pathlib import Path
4+
from typing import Optional, Tuple
5+
6+
from google.protobuf.descriptor import Descriptor, FieldDescriptor
7+
8+
from temporalio.api.common.v1.message_pb2 import Payload, Payloads, SearchAttributes
9+
from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import (
10+
WorkflowActivation,
11+
)
12+
from temporalio.bridge.proto.workflow_completion.workflow_completion_pb2 import (
13+
WorkflowActivationCompletion,
14+
)
15+
16+
base_dir = Path(__file__).parent.parent
17+
18+
19+
def name_for(desc: Descriptor) -> str:
20+
# Use fully-qualified name to avoid collisions; replace dots with underscores
21+
return desc.full_name.replace(".", "_")
22+
23+
24+
def emit_loop(
25+
field_name: str,
26+
iter_expr: str,
27+
child_method: str,
28+
) -> str:
29+
# Helper to emit a for-loop over a collection with optional headers guard
30+
if field_name == "headers":
31+
return f"""\
32+
if not self.skip_headers:
33+
for v in {iter_expr}:
34+
await self._visit_{child_method}(fs, v)"""
35+
else:
36+
return f"""\
37+
for v in {iter_expr}:
38+
await self._visit_{child_method}(fs, v)"""
39+
40+
41+
def emit_singular(
42+
field_name: str, access_expr: str, child_method: str, presence_word: Optional[str]
43+
) -> str:
44+
# Helper to emit a singular field visit with presence check and optional headers guard
45+
if presence_word:
46+
if field_name == "headers":
47+
return f"""\
48+
if not self.skip_headers:
49+
{presence_word} o.HasField("{field_name}"):
50+
await self._visit_{child_method}(fs, {access_expr})"""
51+
else:
52+
return f"""\
53+
{presence_word} o.HasField("{field_name}"):
54+
await self._visit_{child_method}(fs, {access_expr})"""
55+
else:
56+
if field_name == "headers":
57+
return f"""\
58+
if not self.skip_headers:
59+
await self._visit_{child_method}(fs, {access_expr})"""
60+
else:
61+
return f"""\
62+
await self._visit_{child_method}(fs, {access_expr})"""
63+
64+
65+
class VisitorGenerator:
66+
def generate(self, roots: list[Descriptor]) -> str:
67+
"""
68+
Generate Python source code that, given a function f(Payload) -> Payload,
69+
applies it to every Payload contained within a WorkflowActivation tree.
70+
71+
The generated code defines async visitor functions for each reachable
72+
protobuf message type starting from WorkflowActivation, including support
73+
for repeated fields and map entries, and a convenience entrypoint
74+
function `visit`.
75+
"""
76+
77+
for r in roots:
78+
self.walk(r)
79+
80+
header = """
81+
# This file is generated by gen_payload_visitor.py. Changes should be made there.
82+
import abc
83+
from typing import Any, MutableSequence
84+
85+
from temporalio.api.common.v1.message_pb2 import Payload
86+
87+
class VisitorFunctions(abc.ABC):
88+
\"\"\"Set of functions which can be called by the visitor.
89+
Allows handling payloads as a sequence.
90+
\"\"\"
91+
@abc.abstractmethod
92+
async def visit_payload(self, payload: Payload) -> None:
93+
\"\"\"Called when encountering a single payload.\"\"\"
94+
raise NotImplementedError()
95+
96+
@abc.abstractmethod
97+
async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
98+
\"\"\"Called when encountering multiple payloads together.\"\"\"
99+
raise NotImplementedError()
100+
101+
class PayloadVisitor:
102+
\"\"\"A visitor for payloads.
103+
Applies a function to every payload in a tree of messages.
104+
\"\"\"
105+
def __init__(
106+
self, *, skip_search_attributes: bool = False, skip_headers: bool = False
107+
):
108+
\"\"\"Creates a new payload visitor.\"\"\"
109+
self.skip_search_attributes = skip_search_attributes
110+
self.skip_headers = skip_headers
111+
112+
async def visit(
113+
self, fs: VisitorFunctions, root: Any
114+
) -> None:
115+
\"\"\"Visits the given root message with the given function.\"\"\"
116+
method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_")
117+
method = getattr(self, method_name, None)
118+
if method is not None:
119+
await method(fs, root)
120+
else:
121+
raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}")
122+
123+
"""
124+
125+
return header + "\n".join(self.methods)
126+
127+
def __init__(self):
128+
# Track which message descriptors have visitor methods generated
129+
self.generated: dict[str, bool] = {
130+
Payload.DESCRIPTOR.full_name: True,
131+
Payloads.DESCRIPTOR.full_name: True,
132+
}
133+
self.in_progress: set[str] = set()
134+
self.methods: list[str] = [
135+
"""\
136+
async def _visit_temporal_api_common_v1_Payload(self, fs, o):
137+
await fs.visit_payload(o)
138+
""",
139+
"""\
140+
async def _visit_temporal_api_common_v1_Payloads(self, fs, o):
141+
await fs.visit_payloads(o.payloads)
142+
""",
143+
"""\
144+
async def _visit_payload_container(self, fs, o):
145+
await fs.visit_payloads(o)
146+
""",
147+
]
148+
149+
def check_repeated(self, child_desc, field, iter_expr) -> Optional[str]:
150+
# Special case for repeated payloads, handle them directly
151+
if child_desc.full_name == Payload.DESCRIPTOR.full_name:
152+
return emit_singular(field.name, iter_expr, "payload_container", None)
153+
else:
154+
child_needed = self.walk(child_desc)
155+
if child_needed:
156+
return emit_loop(
157+
field.name,
158+
iter_expr,
159+
name_for(child_desc),
160+
)
161+
else:
162+
return None
163+
164+
def walk(self, desc: Descriptor) -> bool:
165+
key = desc.full_name
166+
if key in self.generated:
167+
return self.generated[key]
168+
if key in self.in_progress:
169+
# Break cycles; if another path proves this node needed, we'll revisit
170+
return False
171+
172+
has_payload = False
173+
self.in_progress.add(key)
174+
lines: list[str] = [f" async def _visit_{name_for(desc)}(self, fs, o):"]
175+
# If this is the SearchAttributes message, allow skipping
176+
if desc.full_name == SearchAttributes.DESCRIPTOR.full_name:
177+
lines.append(" if self.skip_search_attributes:")
178+
lines.append(" return")
179+
180+
# Group fields by oneof to generate if/elif chains
181+
oneof_fields: dict[int, list[FieldDescriptor]] = {}
182+
regular_fields: list[FieldDescriptor] = []
183+
184+
for field in desc.fields:
185+
if field.type != FieldDescriptor.TYPE_MESSAGE:
186+
continue
187+
188+
# Skip synthetic oneofs (proto3 optional fields)
189+
if field.containing_oneof is not None:
190+
oneof_idx = field.containing_oneof.index
191+
if oneof_idx not in oneof_fields:
192+
oneof_fields[oneof_idx] = []
193+
oneof_fields[oneof_idx].append(field)
194+
else:
195+
regular_fields.append(field)
196+
197+
# Process regular fields first
198+
for field in regular_fields:
199+
# Repeated fields (including maps which are represented as repeated messages)
200+
if field.label == FieldDescriptor.LABEL_REPEATED:
201+
if (
202+
field.message_type is not None
203+
and field.message_type.GetOptions().map_entry
204+
):
205+
val_fd = field.message_type.fields_by_name.get("value")
206+
if (
207+
val_fd is not None
208+
and val_fd.type == FieldDescriptor.TYPE_MESSAGE
209+
):
210+
child_desc = val_fd.message_type
211+
child_needed = self.walk(child_desc)
212+
if child_needed:
213+
has_payload = True
214+
lines.append(
215+
emit_loop(
216+
field.name,
217+
f"o.{field.name}.values()",
218+
name_for(child_desc),
219+
)
220+
)
221+
222+
key_fd = field.message_type.fields_by_name.get("key")
223+
if (
224+
key_fd is not None
225+
and key_fd.type == FieldDescriptor.TYPE_MESSAGE
226+
):
227+
child_desc = key_fd.message_type
228+
child_needed = self.walk(child_desc)
229+
if child_needed:
230+
has_payload = True
231+
lines.append(
232+
emit_loop(
233+
field.name,
234+
f"o.{field.name}.keys()",
235+
name_for(child_desc),
236+
)
237+
)
238+
else:
239+
child = self.check_repeated(
240+
field.message_type, field, f"o.{field.name}"
241+
)
242+
if child is not None:
243+
has_payload = True
244+
lines.append(child)
245+
else:
246+
child_desc = field.message_type
247+
child_has_payload = self.walk(child_desc)
248+
has_payload |= child_has_payload
249+
if child_has_payload:
250+
lines.append(
251+
emit_singular(
252+
field.name, f"o.{field.name}", name_for(child_desc), "if"
253+
)
254+
)
255+
256+
# Process oneof fields as if/elif chains
257+
for oneof_idx, fields in oneof_fields.items():
258+
oneof_lines = []
259+
first = True
260+
for field in fields:
261+
child_desc = field.message_type
262+
child_has_payload = self.walk(child_desc)
263+
has_payload |= child_has_payload
264+
if child_has_payload:
265+
if_word = "if" if first else "elif"
266+
first = False
267+
line = emit_singular(
268+
field.name, f"o.{field.name}", name_for(child_desc), if_word
269+
)
270+
oneof_lines.append(line)
271+
if oneof_lines:
272+
lines.extend(oneof_lines)
273+
274+
self.generated[key] = has_payload
275+
self.in_progress.discard(key)
276+
if has_payload:
277+
self.methods.append("\n".join(lines) + "\n")
278+
return has_payload
279+
280+
281+
def write_generated_visitors_into_visitor_generated_py() -> None:
282+
"""Write the generated visitor code into _visitor.py."""
283+
out_path = base_dir / "temporalio" / "bridge" / "_visitor.py"
284+
285+
# Build root descriptors: WorkflowActivation, WorkflowActivationCompletion,
286+
# and all messages from selected API modules
287+
roots: list[Descriptor] = [
288+
WorkflowActivation.DESCRIPTOR,
289+
WorkflowActivationCompletion.DESCRIPTOR,
290+
]
291+
292+
code = VisitorGenerator().generate(roots)
293+
out_path.write_text(code)
294+
295+
296+
if __name__ == "__main__":
297+
print("Generating temporalio/bridge/_visitor.py...", file=sys.stderr)
298+
write_generated_visitors_into_visitor_generated_py()
299+
subprocess.run(["uv", "run", "ruff", "format", "temporalio/bridge/_visitor.py"])

scripts/gen_protos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import tempfile
77
from functools import partial
88
from pathlib import Path
9-
from typing import List, Mapping, Optional
9+
from typing import List, Mapping
1010

1111
base_dir = Path(__file__).parent.parent
1212
proto_dir = (

scripts/gen_protos_docker.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,14 @@
33

44
# Build the Docker image and capture its ID
55
result = subprocess.run(
6-
["docker", "build", "-q", "-f", "scripts/_proto/Dockerfile", "."],
6+
[
7+
"docker",
8+
"build",
9+
"-q",
10+
"-f",
11+
os.path.join("scripts", "_proto", "Dockerfile"),
12+
".",
13+
],
714
capture_output=True,
815
text=True,
916
check=True,
@@ -16,11 +23,16 @@
1623
"run",
1724
"--rm",
1825
"-v",
19-
f"{os.getcwd()}/temporalio/api:/api_new",
26+
os.path.join(os.getcwd(), "temporalio", "api") + ":/api_new",
2027
"-v",
21-
f"{os.getcwd()}/temporalio/bridge/proto:/bridge_new",
28+
os.path.join(os.getcwd(), "temporalio", "bridge", "proto") + ":/bridge_new",
2229
image_id,
2330
],
2431
check=True,
2532
)
2633
subprocess.run(["uv", "run", "poe", "format"], check=True)
34+
35+
subprocess.run(
36+
["uv", "run", os.path.join(os.getcwd(), "scripts", "gen_payload_visitor.py")],
37+
check=True,
38+
)

0 commit comments

Comments
 (0)