Skip to content

Commit 7ffa586

Browse files
committed
Code generation work
1 parent 2407569 commit 7ffa586

File tree

12 files changed

+732
-75
lines changed

12 files changed

+732
-75
lines changed

scripts/_proto/Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ COPY ./ ./
1010
RUN mkdir -p ./temporalio/api
1111
RUN uv add "protobuf<4"
1212
RUN uv sync --all-extras
13+
RUN poe build-develop
1314
RUN poe gen-protos
1415

1516
CMD cp -r ./temporalio/api/* /api_new && cp -r ./temporalio/bridge/proto/* /bridge_new

scripts/gen_protos.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import tempfile
77
from functools import partial
88
from pathlib import Path
9-
from typing import List, Mapping, Optional
9+
from typing import List, Mapping
1010

1111
base_dir = Path(__file__).parent.parent
1212
proto_dir = (
@@ -201,7 +201,6 @@ def generate_protos(output_dir: Path):
201201
/ v,
202202
)
203203

204-
205204
if __name__ == "__main__":
206205
check_proto_toolchain_versions()
207206
print("Generating protos...", file=sys.stderr)

scripts/gen_visitors.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import sys
2+
from pathlib import Path
3+
4+
from google.protobuf.descriptor import Descriptor, FieldDescriptor
5+
6+
from temporalio.api.common.v1.message_pb2 import Payload, Payloads, SearchAttributes
7+
from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import (
8+
WorkflowActivation,
9+
)
10+
from temporalio.bridge.proto.workflow_completion.workflow_completion_pb2 import (
11+
WorkflowActivationCompletion,
12+
)
13+
14+
base_dir = Path(__file__).parent.parent
15+
16+
def gen_workflow_activation_payload_visitor_code() -> str:
17+
"""
18+
Generate Python source code that, given a function f(Payload) -> Payload,
19+
applies it to every Payload contained within a WorkflowActivation tree.
20+
21+
The generated code defines async visitor functions for each reachable
22+
protobuf message type starting from WorkflowActivation, including support
23+
for repeated fields and map entries, and a convenience entrypoint
24+
function `visit_workflow_activation_payloads`.
25+
"""
26+
def name_for(desc: Descriptor) -> str:
27+
# Use fully-qualified name to avoid collisions; replace dots with underscores
28+
return desc.full_name.replace('.', '_')
29+
30+
def emit_loop(lines: list[str], field_name: str, iter_expr: str, var_name: str, child_method: str) -> None:
31+
# Helper to emit a for-loop over a collection with optional headers guard
32+
if field_name == "headers":
33+
lines.append(" if not self.skip_headers:")
34+
lines.append(f" for {var_name} in {iter_expr}:")
35+
lines.append(f" await self.visit_{child_method}(f, {var_name})")
36+
else:
37+
lines.append(f" for {var_name} in {iter_expr}:")
38+
lines.append(f" await self.visit_{child_method}(f, {var_name})")
39+
40+
def emit_singular(lines: list[str], field_name: str, access_expr: str, child_method: str) -> None:
41+
# Helper to emit a singular field visit with presence check and optional headers guard
42+
if field_name == "headers":
43+
lines.append(" if not self.skip_headers:")
44+
lines.append(f" if o.HasField('{field_name}'):")
45+
lines.append(f" await self.visit_{child_method}(f, {access_expr})")
46+
else:
47+
lines.append(f" if o.HasField('{field_name}'):")
48+
lines.append(f" await self.visit_{child_method}(f, {access_expr})")
49+
50+
# Track which message descriptors have visitor methods generated
51+
generated: dict[str, bool] = {}
52+
in_progress: set[str] = set()
53+
methods: list[str] = []
54+
55+
def walk(desc: Descriptor) -> bool:
56+
key = desc.full_name
57+
if key in generated:
58+
return generated[key]
59+
if key in in_progress:
60+
# Break cycles; if another path proves this node needed, we'll revisit
61+
return False
62+
63+
if desc.full_name == Payload.DESCRIPTOR.full_name:
64+
generated[key] = True
65+
methods.append(
66+
""" async def visit_temporal_api_common_v1_Payload(self, f, o):
67+
o.CopyFrom(await f(o))
68+
"""
69+
)
70+
return True
71+
72+
needed = False
73+
in_progress.add(key)
74+
lines: list[str] = [f" async def visit_{name_for(desc)}(self, f, o):"]
75+
# If this is the SearchAttributes message, allow skipping
76+
if desc.full_name == SearchAttributes.DESCRIPTOR.full_name:
77+
lines.append(" if self.skip_search_attributes:")
78+
lines.append(" return")
79+
80+
for field in desc.fields:
81+
if field.type != FieldDescriptor.TYPE_MESSAGE:
82+
continue
83+
84+
# Repeated fields (including maps which are represented as repeated messages)
85+
if field.label == FieldDescriptor.LABEL_REPEATED:
86+
if field.message_type is not None and field.message_type.GetOptions().map_entry:
87+
entry_desc = field.message_type
88+
key_fd = entry_desc.fields_by_name.get("key")
89+
val_fd = entry_desc.fields_by_name.get("value")
90+
91+
if val_fd is not None and val_fd.type == FieldDescriptor.TYPE_MESSAGE:
92+
child_desc = val_fd.message_type
93+
child_needed = walk(child_desc)
94+
needed |= child_needed
95+
if child_needed:
96+
emit_loop(lines, field.name, f"o.{field.name}.values()", "v", name_for(child_desc))
97+
98+
if key_fd is not None and key_fd.type == FieldDescriptor.TYPE_MESSAGE:
99+
key_desc = key_fd.message_type
100+
child_needed = walk(key_desc)
101+
needed |= child_needed
102+
if child_needed:
103+
emit_loop(lines, field.name, f"o.{field.name}.keys()", "k", name_for(key_desc))
104+
else:
105+
child_desc = field.message_type
106+
child_needed = walk(child_desc)
107+
needed |= child_needed
108+
if child_needed:
109+
emit_loop(lines, field.name, f"o.{field.name}", "v", name_for(child_desc))
110+
else:
111+
child_desc = field.message_type
112+
child_needed = walk(child_desc)
113+
needed |= child_needed
114+
if child_needed:
115+
emit_singular(lines, field.name, f"o.{field.name}", name_for(child_desc))
116+
117+
generated[key] = needed
118+
in_progress.discard(key)
119+
if needed:
120+
methods.append("\n".join(lines) + "\n")
121+
return needed
122+
123+
# Build root descriptors: WorkflowActivation, WorkflowActivationCompletion,
124+
# and all messages from selected API modules
125+
roots: list[Descriptor] = [
126+
WorkflowActivation.DESCRIPTOR,
127+
WorkflowActivationCompletion.DESCRIPTOR,
128+
]
129+
130+
# We avoid importing google.api deps in service protos; expand by walking from
131+
# WorkflowActivationCompletion root which references many command messages.
132+
133+
for r in roots:
134+
walk(r)
135+
136+
header = (
137+
"from typing import Awaitable, Callable, Any\n\n"
138+
"from temporalio.api.common.v1.message_pb2 import Payload\n\n\n"
139+
"class PayloadVisitor:\n"
140+
" def __init__(self, *, skip_search_attributes: bool = False, skip_headers: bool = False):\n"
141+
" self.skip_search_attributes = skip_search_attributes\n"
142+
" self.skip_headers = skip_headers\n\n"
143+
" async def visit(self, f: Callable[[Payload], Awaitable[Payload]], root: Any) -> None:\n"
144+
" method_name = 'visit_' + root.DESCRIPTOR.full_name.replace('.', '_')\n"
145+
" method = getattr(self, method_name, None)\n"
146+
" if method is not None:\n"
147+
" await method(f, root)\n\n"
148+
)
149+
150+
return header + "\n".join(methods)
151+
152+
153+
def write_generated_visitors_into_visitor_generated_py() -> None:
154+
"""Write the generated visitor code into visitor_generated.py."""
155+
out_path = base_dir / "temporalio" / "bridge" / "visitor_generated.py"
156+
code = gen_workflow_activation_payload_visitor_code()
157+
out_path.write_text(code)
158+
159+
if __name__ == "__main__":
160+
print("Generating temporalio/bridge/visitor_generated.py...", file=sys.stderr)
161+
write_generated_visitors_into_visitor_generated_py()
162+

temporalio/bridge/visitor.py

Lines changed: 53 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,58 @@
1-
from typing import Awaitable, Callable, Any
2-
3-
from collections.abc import Mapping as AbcMapping, Sequence as AbcSequence
1+
from collections.abc import Mapping as AbcMapping
2+
from collections.abc import Sequence as AbcSequence
3+
from typing import Any, Awaitable, Callable
44

55
from google.protobuf.descriptor import FieldDescriptor
66
from google.protobuf.message import Message
77

8-
from temporalio.api.common.v1.message_pb2 import Payload
9-
10-
11-
async def visit_payloads(
12-
f: Callable[[Payload], Awaitable[Payload]], root: Any
13-
) -> None:
14-
print("Visiting object: ", type(root))
15-
if isinstance(root, Payload):
16-
print("Applying to payload: ", root)
17-
root.CopyFrom(await f(root))
18-
print("Applied to payload: ", root)
19-
elif isinstance(root, AbcMapping):
20-
for k, v in root.items():
21-
await visit_payloads(f, k)
22-
await visit_payloads(f, v)
23-
elif isinstance(root, AbcSequence) and not isinstance(
24-
root, (bytes, bytearray, str)
25-
):
26-
for o in root:
27-
await visit_payloads(f, o)
28-
elif isinstance(root, Message):
29-
await visit_message(f, root)
30-
31-
32-
async def visit_message(
33-
f: Callable[[Payload], Awaitable[Payload]], root: Message
34-
) -> None:
35-
print("Visiting Message: ", type(root))
36-
for field in root.DESCRIPTOR.fields:
37-
print("Evaluating Field: ", field.name)
38-
39-
# Repeated fields (including maps which are represented as repeated messages)
40-
if field.label == FieldDescriptor.LABEL_REPEATED:
41-
value = getattr(root, field.name)
42-
if field.message_type is not None and field.message_type.GetOptions().map_entry:
43-
for k, v in value.items():
44-
await visit_payloads(f, k)
45-
await visit_payloads(f, v)
46-
else:
47-
for item in value:
48-
await visit_payloads(f, item)
49-
else:
50-
# Only descend into singular message fields if present
51-
if field.type == FieldDescriptor.TYPE_MESSAGE and root.HasField(field.name):
8+
from temporalio.api.common.v1.message_pb2 import Payload, SearchAttributes
9+
10+
11+
class PayloadVisitor:
12+
def __init__(self, *, skip_search_attributes: bool = False, skip_headers: bool = False):
13+
self.skip_search_attributes = skip_search_attributes
14+
self.skip_headers = skip_headers
15+
16+
async def visit_payloads(
17+
self, f: Callable[[Payload], Awaitable[Payload]], root: Any
18+
) -> None:
19+
if self.skip_search_attributes and isinstance(root, SearchAttributes):
20+
return
21+
22+
if isinstance(root, Payload):
23+
root.CopyFrom(await f(root))
24+
elif isinstance(root, AbcMapping):
25+
for k, v in root.items():
26+
await self.visit_payloads(f, k)
27+
await self.visit_payloads(f, v)
28+
elif isinstance(root, AbcSequence) and not isinstance(
29+
root, (bytes, bytearray, str)
30+
):
31+
for o in root:
32+
await self.visit_payloads(f, o)
33+
elif isinstance(root, Message):
34+
await self.visit_message(f, root,)
35+
36+
37+
async def visit_message(
38+
self, f: Callable[[Payload], Awaitable[Payload]], root: Message
39+
) -> None:
40+
for field in root.DESCRIPTOR.fields:
41+
if self.skip_headers and field.name == "headers":
42+
continue
43+
44+
# Repeated fields (including maps which are represented as repeated messages)
45+
if field.label == FieldDescriptor.LABEL_REPEATED:
5246
value = getattr(root, field.name)
53-
await visit_payloads(f, value)
47+
if field.message_type is not None and field.message_type.GetOptions().map_entry:
48+
for k, v in value.items():
49+
await self.visit_payloads(f, k)
50+
await self.visit_payloads(f, v)
51+
else:
52+
for item in value:
53+
await self.visit_payloads(f, item)
54+
else:
55+
# Only descend into singular message fields if present
56+
if field.type == FieldDescriptor.TYPE_MESSAGE and root.HasField(field.name):
57+
value = getattr(root, field.name)
58+
await self.visit_payloads(f, value)

0 commit comments

Comments
 (0)