Skip to content

Commit 379ab27

Browse files
committed
Cleanup
1 parent 00ef528 commit 379ab27

File tree

9 files changed

+507
-522
lines changed

9 files changed

+507
-522
lines changed

scripts/gen_visitors.py

Lines changed: 73 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import subprocess
12
import sys
23
from pathlib import Path
34

@@ -30,38 +31,33 @@ def name_for(desc: Descriptor) -> str:
3031
return desc.full_name.replace(".", "_")
3132

3233
def emit_loop(
33-
lines: list[str],
3434
field_name: str,
3535
iter_expr: str,
3636
var_name: str,
3737
child_method: str,
38-
) -> None:
38+
) -> str:
3939
# Helper to emit a for-loop over a collection with optional headers guard
4040
if field_name == "headers":
41-
lines.append(" if not self.skip_headers:")
42-
lines.append(f" for {var_name} in {iter_expr}:")
43-
lines.append(
44-
f" await self.visit_{child_method}(f, {var_name})"
45-
)
41+
return f"""\
42+
if not self.skip_headers:
43+
for {var_name} in {iter_expr}:
44+
await self._visit_{child_method}(f, {var_name})"""
4645
else:
47-
lines.append(f" for {var_name} in {iter_expr}:")
48-
lines.append(f" await self.visit_{child_method}(f, {var_name})")
46+
return f"""\
47+
for {var_name} in {iter_expr}:
48+
await self._visit_{child_method}(f, {var_name})"""
4949

50-
def emit_singular(
51-
lines: list[str], field_name: str, access_expr: str, child_method: str
52-
) -> None:
50+
def emit_singular(field_name: str, access_expr: str, child_method: str) -> str:
5351
# Helper to emit a singular field visit with presence check and optional headers guard
5452
if field_name == "headers":
55-
lines.append(" if not self.skip_headers:")
56-
lines.append(f" if o.HasField('{field_name}'):")
57-
lines.append(
58-
f" await self.visit_{child_method}(f, {access_expr})"
59-
)
53+
return f"""\
54+
if not self.skip_headers:
55+
if o.HasField("{field_name}"):
56+
await self._visit_{child_method}(f, {access_expr})"""
6057
else:
61-
lines.append(f" if o.HasField('{field_name}'):")
62-
lines.append(
63-
f" await self.visit_{child_method}(f, {access_expr})"
64-
)
58+
return f"""\
59+
if o.HasField("{field_name}"):
60+
await self._visit_{child_method}(f, {access_expr})"""
6561

6662
# Track which message descriptors have visitor methods generated
6763
generated: dict[str, bool] = {}
@@ -79,15 +75,15 @@ def walk(desc: Descriptor) -> bool:
7975
if desc.full_name == Payload.DESCRIPTOR.full_name:
8076
generated[key] = True
8177
methods.append(
82-
""" async def visit_temporal_api_common_v1_Payload(self, f, o):
78+
""" async def _visit_temporal_api_common_v1_Payload(self, f, o):
8379
o.CopyFrom(await f(o))
8480
"""
8581
)
8682
return True
8783

8884
needed = False
8985
in_progress.add(key)
90-
lines: list[str] = [f" async def visit_{name_for(desc)}(self, f, o):"]
86+
lines: list[str] = [f" async def _visit_{name_for(desc)}(self, f, o):"]
9187
# If this is the SearchAttributes message, allow skipping
9288
if desc.full_name == SearchAttributes.DESCRIPTOR.full_name:
9389
lines.append(" if self.skip_search_attributes:")
@@ -115,12 +111,13 @@ def walk(desc: Descriptor) -> bool:
115111
child_needed = walk(child_desc)
116112
needed |= child_needed
117113
if child_needed:
118-
emit_loop(
119-
lines,
120-
field.name,
121-
f"o.{field.name}.values()",
122-
"v",
123-
name_for(child_desc),
114+
lines.append(
115+
emit_loop(
116+
field.name,
117+
f"o.{field.name}.values()",
118+
"v",
119+
name_for(child_desc),
120+
)
124121
)
125122

126123
if (
@@ -131,32 +128,36 @@ def walk(desc: Descriptor) -> bool:
131128
child_needed = walk(key_desc)
132129
needed |= child_needed
133130
if child_needed:
134-
emit_loop(
135-
lines,
136-
field.name,
137-
f"o.{field.name}.keys()",
138-
"k",
139-
name_for(key_desc),
131+
lines.append(
132+
emit_loop(
133+
field.name,
134+
f"o.{field.name}.keys()",
135+
"k",
136+
name_for(key_desc),
137+
)
140138
)
141139
else:
142140
child_desc = field.message_type
143141
child_needed = walk(child_desc)
144142
needed |= child_needed
145143
if child_needed:
146-
emit_loop(
147-
lines,
148-
field.name,
149-
f"o.{field.name}",
150-
"v",
151-
name_for(child_desc),
144+
lines.append(
145+
emit_loop(
146+
field.name,
147+
f"o.{field.name}",
148+
"v",
149+
name_for(child_desc),
150+
)
152151
)
153152
else:
154153
child_desc = field.message_type
155154
child_needed = walk(child_desc)
156155
needed |= child_needed
157156
if child_needed:
158-
emit_singular(
159-
lines, field.name, f"o.{field.name}", name_for(child_desc)
157+
lines.append(
158+
emit_singular(
159+
field.name, f"o.{field.name}", name_for(child_desc)
160+
)
160161
)
161162

162163
generated[key] = needed
@@ -178,30 +179,43 @@ def walk(desc: Descriptor) -> bool:
178179
for r in roots:
179180
walk(r)
180181

181-
header = (
182-
"from typing import Any, Awaitable, Callable\n\n"
183-
"from temporalio.api.common.v1.message_pb2 import Payload\n\n\n"
184-
"class PayloadVisitor:\n"
185-
" def __init__(self, *, skip_search_attributes: bool = False, skip_headers: bool = False):\n"
186-
" self.skip_search_attributes = skip_search_attributes\n"
187-
" self.skip_headers = skip_headers\n\n"
188-
" async def visit(self, f: Callable[[Payload], Awaitable[Payload]], root: Any) -> None:\n"
189-
" method_name = 'visit_' + root.DESCRIPTOR.full_name.replace('.', '_')\n"
190-
" method = getattr(self, method_name, None)\n"
191-
" if method is not None:\n"
192-
" await method(f, root)\n\n"
193-
)
182+
header = """
183+
from typing import Any, Awaitable, Callable
184+
185+
from temporalio.api.common.v1.message_pb2 import Payload
186+
187+
188+
class PayloadVisitor:
189+
\"\"\"A visitor for payloads. Applies a function to every payload in a tree of messages.\"\"\"
190+
def __init__(
191+
self, *, skip_search_attributes: bool = False, skip_headers: bool = False
192+
):
193+
\"\"\"Creates a new payload visitor.\"\"\"
194+
self.skip_search_attributes = skip_search_attributes
195+
self.skip_headers = skip_headers
196+
197+
async def visit(
198+
self, f: Callable[[Payload], Awaitable[Payload]], root: Any
199+
) -> None:
200+
\"\"\"Visits the given root message with the given function.\"\"\"
201+
method_name = "visit_" + root.DESCRIPTOR.full_name.replace(".", "_")
202+
method = getattr(self, method_name, None)
203+
if method is not None:
204+
await method(f, root)
205+
206+
"""
194207

195208
return header + "\n".join(methods)
196209

197210

198211
def write_generated_visitors_into_visitor_generated_py() -> None:
199-
"""Write the generated visitor code into visitor_generated.py."""
200-
out_path = base_dir / "temporalio" / "bridge" / "visitor_generated.py"
212+
"""Write the generated visitor code into _visitor.py."""
213+
out_path = base_dir / "temporalio" / "bridge" / "_visitor.py"
201214
code = gen_workflow_activation_payload_visitor_code()
202215
out_path.write_text(code)
203216

204217

205218
if __name__ == "__main__":
206-
print("Generating temporalio/bridge/visitor_generated.py...", file=sys.stderr)
219+
print("Generating temporalio/bridge/_visitor.py...", file=sys.stderr)
207220
write_generated_visitors_into_visitor_generated_py()
221+
subprocess.run(["uv", "run", "ruff", "format", "temporalio/bridge/_visitor.py"])

0 commit comments

Comments
 (0)