Skip to content

Commit 1a9a230

Browse files
committed
PR comments
1 parent f8398ee commit 1a9a230

File tree

3 files changed

+100
-88
lines changed

3 files changed

+100
-88
lines changed
Lines changed: 91 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,18 @@ def emit_loop(
3939

4040

4141
def emit_singular(
42-
field_name: str, access_expr: str, child_method: str, check_presence: bool
42+
field_name: str, access_expr: str, child_method: str, presence_word: Optional[str]
4343
) -> str:
4444
# Helper to emit a singular field visit with presence check and optional headers guard
45-
if check_presence:
45+
if presence_word:
4646
if field_name == "headers":
4747
return f"""\
4848
if not self.skip_headers:
49-
if o.HasField("{field_name}"):
49+
{presence_word} o.HasField("{field_name}"):
5050
await self._visit_{child_method}(fs, {access_expr})"""
5151
else:
5252
return f"""\
53-
if o.HasField("{field_name}"):
53+
{presence_word} o.HasField("{field_name}"):
5454
await self._visit_{child_method}(fs, {access_expr})"""
5555
else:
5656
if field_name == "headers":
@@ -63,6 +63,67 @@ def emit_singular(
6363

6464

6565
class VisitorGenerator:
66+
def generate(self, roots: list[Descriptor]) -> str:
67+
"""
68+
Generate Python source code that, given a function f(Payload) -> Payload,
69+
applies it to every Payload contained within a WorkflowActivation tree.
70+
71+
The generated code defines async visitor functions for each reachable
72+
protobuf message type starting from WorkflowActivation, including support
73+
for repeated fields and map entries, and a convenience entrypoint
74+
function `visit`.
75+
"""
76+
77+
for r in roots:
78+
self.walk(r)
79+
80+
header = """
81+
# This file is generated by gen_payload_visitor.py. Changes should be made there.
82+
import abc
83+
from typing import Any, MutableSequence
84+
85+
from temporalio.api.common.v1.message_pb2 import Payload
86+
87+
class VisitorFunctions(abc.ABC):
88+
\"\"\"Set of functions which can be called by the visitor.
89+
Allows handling payloads as a sequence.
90+
\"\"\"
91+
@abc.abstractmethod
92+
async def visit_payload(self, payload: Payload) -> None:
93+
\"\"\"Called when encountering a single payload.\"\"\"
94+
raise NotImplementedError()
95+
96+
@abc.abstractmethod
97+
async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
98+
\"\"\"Called when encountering multiple payloads together.\"\"\"
99+
raise NotImplementedError()
100+
101+
class PayloadVisitor:
102+
\"\"\"A visitor for payloads.
103+
Applies a function to every payload in a tree of messages.
104+
\"\"\"
105+
def __init__(
106+
self, *, skip_search_attributes: bool = False, skip_headers: bool = False
107+
):
108+
\"\"\"Creates a new payload visitor.\"\"\"
109+
self.skip_search_attributes = skip_search_attributes
110+
self.skip_headers = skip_headers
111+
112+
async def visit(
113+
self, fs: VisitorFunctions, root: Any
114+
) -> None:
115+
\"\"\"Visits the given root message with the given function.\"\"\"
116+
method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_")
117+
method = getattr(self, method_name, None)
118+
if method is not None:
119+
await method(fs, root)
120+
else:
121+
raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}")
122+
123+
"""
124+
125+
return header + "\n".join(self.methods)
126+
66127
def __init__(self):
67128
# Track which message descriptors have visitor methods generated
68129
self.generated: dict[str, bool] = {
@@ -71,21 +132,24 @@ def __init__(self):
71132
}
72133
self.in_progress: set[str] = set()
73134
self.methods: list[str] = [
74-
""" async def _visit_temporal_api_common_v1_Payload(self, fs, o):
75-
await fs.visit_payload(o)
135+
"""\
136+
async def _visit_temporal_api_common_v1_Payload(self, fs, o):
137+
await fs.visit_payload(o)
76138
""",
77-
""" async def _visit_temporal_api_common_v1_Payloads(self, fs, o):
78-
await fs.visit_payloads(o.payloads)
139+
"""\
140+
async def _visit_temporal_api_common_v1_Payloads(self, fs, o):
141+
await fs.visit_payloads(o.payloads)
79142
""",
80-
""" async def _visit_payload_container(self, fs, o):
81-
await fs.visit_payloads(o)
143+
"""\
144+
async def _visit_payload_container(self, fs, o):
145+
await fs.visit_payloads(o)
82146
""",
83147
]
84148

85149
def check_repeated(self, child_desc, field, iter_expr) -> Optional[str]:
86150
# Special case for repeated payloads, handle them directly
87151
if child_desc.full_name == Payload.DESCRIPTOR.full_name:
88-
return emit_singular(field.name, iter_expr, "payload_container", False)
152+
return emit_singular(field.name, iter_expr, "payload_container", None)
89153
else:
90154
child_needed = self.walk(child_desc)
91155
if child_needed:
@@ -105,7 +169,7 @@ def walk(self, desc: Descriptor) -> bool:
105169
# Break cycles; if another path proves this node needed, we'll revisit
106170
return False
107171

108-
needed = False
172+
has_payload = False
109173
self.in_progress.add(key)
110174
lines: list[str] = [f" async def _visit_{name_for(desc)}(self, fs, o):"]
111175
# If this is the SearchAttributes message, allow skipping
@@ -146,7 +210,7 @@ def walk(self, desc: Descriptor) -> bool:
146210
child_desc = val_fd.message_type
147211
child_needed = self.walk(child_desc)
148212
if child_needed:
149-
needed = True
213+
has_payload = True
150214
lines.append(
151215
emit_loop(
152216
field.name,
@@ -163,7 +227,7 @@ def walk(self, desc: Descriptor) -> bool:
163227
child_desc = key_fd.message_type
164228
child_needed = self.walk(child_desc)
165229
if child_needed:
166-
needed = True
230+
has_payload = True
167231
lines.append(
168232
emit_loop(
169233
field.name,
@@ -176,16 +240,16 @@ def walk(self, desc: Descriptor) -> bool:
176240
field.message_type, field, f"o.{field.name}"
177241
)
178242
if child is not None:
179-
needed = True
243+
has_payload = True
180244
lines.append(child)
181245
else:
182246
child_desc = field.message_type
183-
child_needed = self.walk(child_desc)
184-
needed |= child_needed
185-
if child_needed:
247+
child_has_payload = self.walk(child_desc)
248+
has_payload |= child_has_payload
249+
if child_has_payload:
186250
lines.append(
187251
emit_singular(
188-
field.name, f"o.{field.name}", name_for(child_desc), True
252+
field.name, f"o.{field.name}", name_for(child_desc), "if"
189253
)
190254
)
191255

@@ -195,81 +259,23 @@ def walk(self, desc: Descriptor) -> bool:
195259
first = True
196260
for field in fields:
197261
child_desc = field.message_type
198-
child_needed = self.walk(child_desc)
199-
needed |= child_needed
200-
if child_needed:
262+
child_has_payload = self.walk(child_desc)
263+
has_payload |= child_has_payload
264+
if child_has_payload:
201265
if_word = "if" if first else "elif"
202266
first = False
203267
line = emit_singular(
204-
field.name, f"o.{field.name}", name_for(child_desc), True
205-
).replace(" if", f" {if_word}", 1)
268+
field.name, f"o.{field.name}", name_for(child_desc), if_word
269+
)
206270
oneof_lines.append(line)
207271
if oneof_lines:
208272
lines.extend(oneof_lines)
209273

210-
self.generated[key] = needed
274+
self.generated[key] = has_payload
211275
self.in_progress.discard(key)
212-
if needed:
276+
if has_payload:
213277
self.methods.append("\n".join(lines) + "\n")
214-
return needed
215-
216-
def generate(self, roots: list[Descriptor]) -> str:
217-
"""
218-
Generate Python source code that, given a function f(Payload) -> Payload,
219-
applies it to every Payload contained within a WorkflowActivation tree.
220-
221-
The generated code defines async visitor functions for each reachable
222-
protobuf message type starting from WorkflowActivation, including support
223-
for repeated fields and map entries, and a convenience entrypoint
224-
function `visit`.
225-
"""
226-
227-
# We avoid importing google.api deps in service protos; expand by walking from
228-
# WorkflowActivationCompletion root which references many command messages.
229-
for r in roots:
230-
self.walk(r)
231-
232-
header = """
233-
import abc
234-
from typing import Any, MutableSequence
235-
236-
from temporalio.api.common.v1.message_pb2 import Payload
237-
238-
class VisitorFunctions(abc.ABC):
239-
\"\"\"Set of functions which can be called by the visitor. Allows handling payloads as a sequence.\"\"\"
240-
@abc.abstractmethod
241-
async def visit_payload(self, payload: Payload) -> None:
242-
\"\"\"Called when encountering a single payload.\"\"\"
243-
raise NotImplementedError()
244-
245-
@abc.abstractmethod
246-
async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
247-
\"\"\"Called when encountering multiple payloads together.\"\"\"
248-
raise NotImplementedError()
249-
250-
class PayloadVisitor:
251-
\"\"\"A visitor for payloads. Applies a function to every payload in a tree of messages.\"\"\"
252-
def __init__(
253-
self, *, skip_search_attributes: bool = False, skip_headers: bool = False
254-
):
255-
\"\"\"Creates a new payload visitor.\"\"\"
256-
self.skip_search_attributes = skip_search_attributes
257-
self.skip_headers = skip_headers
258-
259-
async def visit(
260-
self, fs: VisitorFunctions, root: Any
261-
) -> None:
262-
\"\"\"Visits the given root message with the given function.\"\"\"
263-
method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_")
264-
method = getattr(self, method_name, None)
265-
if method is not None:
266-
await method(fs, root)
267-
else:
268-
raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}")
269-
270-
"""
271-
272-
return header + "\n".join(self.methods)
278+
return has_payload
273279

274280

275281
def write_generated_visitors_into_visitor_generated_py() -> None:

scripts/gen_protos_docker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,6 @@
3333
subprocess.run(["uv", "run", "poe", "format"], check=True)
3434

3535
subprocess.run(
36-
["uv", "run", os.path.join(os.getcwd(), "scripts", "gen_visitors.py")], check=True
36+
["uv", "run", os.path.join(os.getcwd(), "scripts", "gen_payload_visitor.py")],
37+
check=True,
3738
)

temporalio/bridge/_visitor.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
# This file is generated by gen_payload_visitor.py. Changes should be made there.
12
import abc
23
from typing import Any, MutableSequence
34

45
from temporalio.api.common.v1.message_pb2 import Payload
56

67

78
class VisitorFunctions(abc.ABC):
8-
"""Set of functions which can be called by the visitor. Allows handling payloads as a sequence."""
9+
"""Set of functions which can be called by the visitor.
10+
Allows handling payloads as a sequence.
11+
"""
912

1013
@abc.abstractmethod
1114
async def visit_payload(self, payload: Payload) -> None:
@@ -19,7 +22,9 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
1922

2023

2124
class PayloadVisitor:
22-
"""A visitor for payloads. Applies a function to every payload in a tree of messages."""
25+
"""A visitor for payloads.
26+
Applies a function to every payload in a tree of messages.
27+
"""
2328

2429
def __init__(
2530
self, *, skip_search_attributes: bool = False, skip_headers: bool = False

0 commit comments

Comments
 (0)