Skip to content

Commit 6f9fd18

Browse files
committed
Allow payload mutation
1 parent 681b54a commit 6f9fd18

File tree

4 files changed

+399
-333
lines changed

4 files changed

+399
-333
lines changed

scripts/gen_visitors.py

Lines changed: 135 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import subprocess
22
import sys
33
from pathlib import Path
4+
from typing import Optional, Tuple
45

56
from google.protobuf.descriptor import Descriptor, FieldDescriptor
67

@@ -15,75 +16,98 @@
1516
base_dir = Path(__file__).parent.parent
1617

1718

18-
def gen_workflow_activation_payload_visitor_code() -> str:
19-
"""
20-
Generate Python source code that, given a function f(Payload) -> Payload,
21-
applies it to every Payload contained within a WorkflowActivation tree.
22-
23-
The generated code defines async visitor functions for each reachable
24-
protobuf message type starting from WorkflowActivation, including support
25-
for repeated fields and map entries, and a convenience entrypoint
26-
function `visit`.
27-
"""
28-
29-
def name_for(desc: Descriptor) -> str:
30-
# Use fully-qualified name to avoid collisions; replace dots with underscores
31-
return desc.full_name.replace(".", "_")
32-
33-
def emit_loop(
34-
field_name: str,
35-
iter_expr: str,
36-
var_name: str,
37-
child_method: str,
38-
) -> str:
39-
# Helper to emit a for-loop over a collection with optional headers guard
19+
def name_for(desc: Descriptor) -> str:
20+
# Use fully-qualified name to avoid collisions; replace dots with underscores
21+
return desc.full_name.replace(".", "_")
22+
23+
24+
def emit_loop(
25+
field_name: str,
26+
iter_expr: str,
27+
child_method: str,
28+
) -> str:
29+
# Helper to emit a for-loop over a collection with optional headers guard
30+
if field_name == "headers":
31+
return f"""\
32+
if not self.skip_headers:
33+
for v in {iter_expr}:
34+
await self._visit_{child_method}(fs, v)"""
35+
else:
36+
return f"""\
37+
for v in {iter_expr}:
38+
await self._visit_{child_method}(fs, v)"""
39+
40+
41+
def emit_singular(
42+
field_name: str, access_expr: str, child_method: str, check_presence: bool
43+
) -> str:
44+
# Helper to emit a singular field visit with an optional presence check and optional headers guard
45+
if check_presence:
4046
if field_name == "headers":
4147
return f"""\
4248
if not self.skip_headers:
43-
for {var_name} in {iter_expr}:
44-
await self._visit_{child_method}(f, {var_name})"""
49+
if o.HasField("{field_name}"):
50+
await self._visit_{child_method}(fs, {access_expr})"""
4551
else:
4652
return f"""\
47-
for {var_name} in {iter_expr}:
48-
await self._visit_{child_method}(f, {var_name})"""
49-
50-
def emit_singular(field_name: str, access_expr: str, child_method: str) -> str:
51-
# Helper to emit a singular field visit with presence check and optional headers guard
53+
if o.HasField("{field_name}"):
54+
await self._visit_{child_method}(fs, {access_expr})"""
55+
else:
5256
if field_name == "headers":
5357
return f"""\
5458
if not self.skip_headers:
55-
if o.HasField("{field_name}"):
56-
await self._visit_{child_method}(f, {access_expr})"""
59+
await self._visit_{child_method}(fs, {access_expr})"""
5760
else:
5861
return f"""\
59-
if o.HasField("{field_name}"):
60-
await self._visit_{child_method}(f, {access_expr})"""
61-
62-
# Track which message descriptors have visitor methods generated
63-
generated: dict[str, bool] = {}
64-
in_progress: set[str] = set()
65-
methods: list[str] = []
62+
await self._visit_{child_method}(fs, {access_expr})"""
63+
64+
65+
class VisitorGenerator:
66+
def __init__(self):
67+
# Track which message descriptors have visitor methods generated
68+
self.generated: dict[str, bool] = {
69+
Payload.DESCRIPTOR.full_name: True,
70+
Payloads.DESCRIPTOR.full_name: True,
71+
}
72+
self.in_progress: set[str] = set()
73+
self.methods: list[str] = [
74+
""" async def _visit_temporal_api_common_v1_Payload(self, fs, o):
75+
await fs.visit_payload(o)
76+
""",
77+
""" async def _visit_temporal_api_common_v1_Payloads(self, fs, o):
78+
await fs.visit_payloads(o.payloads)
79+
""",
80+
""" async def _visit_payload_container(self, fs, o):
81+
await fs.visit_payloads(o)
82+
""",
83+
]
84+
85+
def check_repeated(self, child_desc, field, iter_expr) -> Optional[str]:
86+
# Special case for repeated payloads, handle them directly
87+
if child_desc.full_name == Payload.DESCRIPTOR.full_name:
88+
return emit_singular(field.name, iter_expr, "payload_container", False)
89+
else:
90+
child_needed = self.walk(child_desc)
91+
if child_needed:
92+
return emit_loop(
93+
field.name,
94+
iter_expr,
95+
name_for(child_desc),
96+
)
97+
else:
98+
return None
6699

67-
def walk(desc: Descriptor) -> bool:
100+
def walk(self, desc: Descriptor) -> bool:
68101
key = desc.full_name
69-
if key in generated:
70-
return generated[key]
71-
if key in in_progress:
102+
if key in self.generated:
103+
return self.generated[key]
104+
if key in self.in_progress:
72105
# Break cycles; if another path proves this node needed, we'll revisit
73106
return False
74107

75-
if desc.full_name == Payload.DESCRIPTOR.full_name:
76-
generated[key] = True
77-
methods.append(
78-
""" async def _visit_temporal_api_common_v1_Payload(self, f, o):
79-
o.CopyFrom(await f(o))
80-
"""
81-
)
82-
return True
83-
84108
needed = False
85-
in_progress.add(key)
86-
lines: list[str] = [f" async def _visit_{name_for(desc)}(self, f, o):"]
109+
self.in_progress.add(key)
110+
lines: list[str] = [f" async def _visit_{name_for(desc)}(self, fs, o):"]
87111
# If this is the SearchAttributes message, allow skipping
88112
if desc.full_name == SearchAttributes.DESCRIPTOR.full_name:
89113
lines.append(" if self.skip_search_attributes:")
@@ -99,91 +123,96 @@ def walk(desc: Descriptor) -> bool:
99123
field.message_type is not None
100124
and field.message_type.GetOptions().map_entry
101125
):
102-
entry_desc = field.message_type
103-
key_fd = entry_desc.fields_by_name.get("key")
104-
val_fd = entry_desc.fields_by_name.get("value")
105-
126+
val_fd = field.message_type.fields_by_name.get("value")
106127
if (
107128
val_fd is not None
108129
and val_fd.type == FieldDescriptor.TYPE_MESSAGE
109130
):
110131
child_desc = val_fd.message_type
111-
child_needed = walk(child_desc)
112-
needed |= child_needed
132+
child_needed = self.walk(child_desc)
113133
if child_needed:
134+
needed = True
114135
lines.append(
115136
emit_loop(
116137
field.name,
117138
f"o.{field.name}.values()",
118-
"v",
119139
name_for(child_desc),
120140
)
121141
)
122142

143+
key_fd = field.message_type.fields_by_name.get("key")
123144
if (
124145
key_fd is not None
125146
and key_fd.type == FieldDescriptor.TYPE_MESSAGE
126147
):
127-
key_desc = key_fd.message_type
128-
child_needed = walk(key_desc)
129-
needed |= child_needed
148+
child_desc = key_fd.message_type
149+
child_needed = self.walk(child_desc)
130150
if child_needed:
151+
needed = True
131152
lines.append(
132153
emit_loop(
133154
field.name,
134155
f"o.{field.name}.keys()",
135-
"k",
136-
name_for(key_desc),
156+
name_for(child_desc),
137157
)
138158
)
139159
else:
140-
child_desc = field.message_type
141-
child_needed = walk(child_desc)
142-
needed |= child_needed
143-
if child_needed:
144-
lines.append(
145-
emit_loop(
146-
field.name,
147-
f"o.{field.name}",
148-
"v",
149-
name_for(child_desc),
150-
)
151-
)
160+
child = self.check_repeated(
161+
field.message_type, field, f"o.{field.name}"
162+
)
163+
if child is not None:
164+
needed = True
165+
lines.append(child)
152166
else:
153167
child_desc = field.message_type
154-
child_needed = walk(child_desc)
168+
child_needed = self.walk(child_desc)
155169
needed |= child_needed
156170
if child_needed:
157171
lines.append(
158172
emit_singular(
159-
field.name, f"o.{field.name}", name_for(child_desc)
173+
field.name, f"o.{field.name}", name_for(child_desc), True
160174
)
161175
)
162176

163-
generated[key] = needed
164-
in_progress.discard(key)
177+
self.generated[key] = needed
178+
self.in_progress.discard(key)
165179
if needed:
166-
methods.append("\n".join(lines) + "\n")
180+
self.methods.append("\n".join(lines) + "\n")
167181
return needed
168182

169-
# Build root descriptors: WorkflowActivation, WorkflowActivationCompletion,
170-
# and all messages from selected API modules
171-
roots: list[Descriptor] = [
172-
WorkflowActivation.DESCRIPTOR,
173-
WorkflowActivationCompletion.DESCRIPTOR,
174-
]
183+
def generate(self, roots: list[Descriptor]) -> str:
184+
"""
185+
Generate Python source code for a PayloadVisitor that, given a VisitorFunctions
185+
implementation, invokes it on the Payloads reachable from the given root descriptors.
175187
176-
# We avoid importing google.api deps in service protos; expand by walking from
177-
# WorkflowActivationCompletion root which references many command messages.
188+
The generated code defines async visitor functions for each reachable
189+
protobuf message type starting from WorkflowActivation, including support
190+
for repeated fields and map entries, and a convenience entrypoint
191+
function `visit`.
192+
"""
178193

179-
for r in roots:
180-
walk(r)
194+
# We avoid importing google.api deps in service protos; expand by walking from
195+
# WorkflowActivationCompletion root which references many command messages.
196+
for r in roots:
197+
self.walk(r)
181198

182-
header = """
183-
from typing import Any, Awaitable, Callable
199+
header = """
200+
import abc
201+
from typing import Any, MutableSequence
184202
185203
from temporalio.api.common.v1.message_pb2 import Payload
186204
205+
class VisitorFunctions(abc.ABC):
206+
\"\"\"Set of functions which can be called by the visitor. Allows handling payloads as a sequence.\"\"\"
207+
@abc.abstractmethod
208+
async def visit_payload(self, payload: Payload) -> None:
209+
\"\"\"Called when encountering a single payload.\"\"\"
210+
raise NotImplementedError()
211+
212+
@abc.abstractmethod
213+
async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
214+
\"\"\"Called when encountering multiple payloads together.\"\"\"
215+
raise NotImplementedError()
187216
188217
class PayloadVisitor:
189218
\"\"\"A visitor for payloads. Applies a function to every payload in a tree of messages.\"\"\"
@@ -195,25 +224,33 @@ def __init__(
195224
self.skip_headers = skip_headers
196225
197226
async def visit(
198-
self, f: Callable[[Payload], Awaitable[Payload]], root: Any
227+
self, fs: VisitorFunctions, root: Any
199228
) -> None:
200229
\"\"\"Visits the given root message with the given function.\"\"\"
201230
method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_")
202231
method = getattr(self, method_name, None)
203232
if method is not None:
204-
await method(f, root)
233+
await method(fs, root)
205234
else:
206235
raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}")
207236
208237
"""
209238

210-
return header + "\n".join(methods)
239+
return header + "\n".join(self.methods)
211240

212241

213242
def write_generated_visitors_into_visitor_generated_py() -> None:
214243
"""Write the generated visitor code into _visitor.py."""
215244
out_path = base_dir / "temporalio" / "bridge" / "_visitor.py"
216-
code = gen_workflow_activation_payload_visitor_code()
245+
246+
# Build root descriptors: WorkflowActivation and WorkflowActivationCompletion;
247+
# other message types are reached by walking fields from these roots
248+
roots: list[Descriptor] = [
249+
WorkflowActivation.DESCRIPTOR,
250+
WorkflowActivationCompletion.DESCRIPTOR,
251+
]
252+
253+
code = VisitorGenerator().generate(roots)
217254
out_path.write_text(code)
218255

219256

0 commit comments

Comments
 (0)