11import subprocess
22import sys
33from pathlib import Path
4+ from typing import Optional , Tuple
45
56from google .protobuf .descriptor import Descriptor , FieldDescriptor
67
1516base_dir = Path (__file__ ).parent .parent
1617
1718
18- def gen_workflow_activation_payload_visitor_code () -> str :
19- """
20- Generate Python source code that, given a function f(Payload) -> Payload,
21- applies it to every Payload contained within a WorkflowActivation tree.
22-
23- The generated code defines async visitor functions for each reachable
24- protobuf message type starting from WorkflowActivation, including support
25- for repeated fields and map entries, and a convenience entrypoint
26- function `visit`.
27- """
28-
29- def name_for (desc : Descriptor ) -> str :
30- # Use fully-qualified name to avoid collisions; replace dots with underscores
31- return desc .full_name .replace ("." , "_" )
32-
33- def emit_loop (
34- field_name : str ,
35- iter_expr : str ,
36- var_name : str ,
37- child_method : str ,
38- ) -> str :
39- # Helper to emit a for-loop over a collection with optional headers guard
19+ def name_for (desc : Descriptor ) -> str :
20+ # Use fully-qualified name to avoid collisions; replace dots with underscores
21+ return desc .full_name .replace ("." , "_" )
22+
23+
24+ def emit_loop (
25+ field_name : str ,
26+ iter_expr : str ,
27+ child_method : str ,
28+ ) -> str :
29+ # Helper to emit a for-loop over a collection with optional headers guard
30+ if field_name == "headers" :
31+ return f"""\
32+ if not self.skip_headers:
33+ for v in { iter_expr } :
34+ await self._visit_{ child_method } (fs, v)"""
35+ else :
36+ return f"""\
37+ for v in { iter_expr } :
38+ await self._visit_{ child_method } (fs, v)"""
39+
40+
41+ def emit_singular (
42+ field_name : str , access_expr : str , child_method : str , check_presence : bool
43+ ) -> str :
44+ # Helper to emit a singular field visit with presence check and optional headers guard
45+ if check_presence :
4046 if field_name == "headers" :
4147 return f"""\
4248 if not self.skip_headers:
43- for { var_name } in { iter_expr } :
44- await self._visit_{ child_method } (f , { var_name } )"""
49+ if o.HasField(" { field_name } ") :
50+ await self._visit_{ child_method } (fs , { access_expr } )"""
4551 else :
4652 return f"""\
47- for { var_name } in { iter_expr } :
48- await self._visit_{ child_method } (f, { var_name } )"""
49-
50- def emit_singular (field_name : str , access_expr : str , child_method : str ) -> str :
51- # Helper to emit a singular field visit with presence check and optional headers guard
53+ if o.HasField("{ field_name } "):
54+ await self._visit_{ child_method } (fs, { access_expr } )"""
55+ else :
5256 if field_name == "headers" :
5357 return f"""\
5458 if not self.skip_headers:
55- if o.HasField("{ field_name } "):
56- await self._visit_{ child_method } (f, { access_expr } )"""
59+ await self._visit_{ child_method } (fs, { access_expr } )"""
5760 else :
5861 return f"""\
59- if o.HasField("{ field_name } "):
60- await self._visit_{ child_method } (f, { access_expr } )"""
61-
62- # Track which message descriptors have visitor methods generated
63- generated : dict [str , bool ] = {}
64- in_progress : set [str ] = set ()
65- methods : list [str ] = []
62+ await self._visit_{ child_method } (fs, { access_expr } )"""
63+
64+
65+ class VisitorGenerator :
66+ def __init__ (self ):
67+ # Track which message descriptors have visitor methods generated
68+ self .generated : dict [str , bool ] = {
69+ Payload .DESCRIPTOR .full_name : True ,
70+ Payloads .DESCRIPTOR .full_name : True ,
71+ }
72+ self .in_progress : set [str ] = set ()
73+ self .methods : list [str ] = [
74+ """ async def _visit_temporal_api_common_v1_Payload(self, fs, o):
75+ await fs.visit_payload(o)
76+ """ ,
77+ """ async def _visit_temporal_api_common_v1_Payloads(self, fs, o):
78+ await fs.visit_payloads(o.payloads)
79+ """ ,
80+ """ async def _visit_payload_container(self, fs, o):
81+ await fs.visit_payloads(o)
82+ """ ,
83+ ]
84+
85+ def check_repeated (self , child_desc , field , iter_expr ) -> Optional [str ]:
86+ # Special case for repeated payloads, handle them directly
87+ if child_desc .full_name == Payload .DESCRIPTOR .full_name :
88+ return emit_singular (field .name , iter_expr , "payload_container" , False )
89+ else :
90+ child_needed = self .walk (child_desc )
91+ if child_needed :
92+ return emit_loop (
93+ field .name ,
94+ iter_expr ,
95+ name_for (child_desc ),
96+ )
97+ else :
98+ return None
6699
67- def walk (desc : Descriptor ) -> bool :
100+ def walk (self , desc : Descriptor ) -> bool :
68101 key = desc .full_name
69- if key in generated :
70- return generated [key ]
71- if key in in_progress :
102+ if key in self . generated :
103+ return self . generated [key ]
104+ if key in self . in_progress :
72105 # Break cycles; if another path proves this node needed, we'll revisit
73106 return False
74107
75- if desc .full_name == Payload .DESCRIPTOR .full_name :
76- generated [key ] = True
77- methods .append (
78- """ async def _visit_temporal_api_common_v1_Payload(self, f, o):
79- o.CopyFrom(await f(o))
80- """
81- )
82- return True
83-
84108 needed = False
85- in_progress .add (key )
86- lines : list [str ] = [f" async def _visit_{ name_for (desc )} (self, f , o):" ]
109+ self . in_progress .add (key )
110+ lines : list [str ] = [f" async def _visit_{ name_for (desc )} (self, fs , o):" ]
87111 # If this is the SearchAttributes message, allow skipping
88112 if desc .full_name == SearchAttributes .DESCRIPTOR .full_name :
89113 lines .append (" if self.skip_search_attributes:" )
@@ -99,91 +123,96 @@ def walk(desc: Descriptor) -> bool:
99123 field .message_type is not None
100124 and field .message_type .GetOptions ().map_entry
101125 ):
102- entry_desc = field .message_type
103- key_fd = entry_desc .fields_by_name .get ("key" )
104- val_fd = entry_desc .fields_by_name .get ("value" )
105-
126+ val_fd = field .message_type .fields_by_name .get ("value" )
106127 if (
107128 val_fd is not None
108129 and val_fd .type == FieldDescriptor .TYPE_MESSAGE
109130 ):
110131 child_desc = val_fd .message_type
111- child_needed = walk (child_desc )
112- needed |= child_needed
132+ child_needed = self .walk (child_desc )
113133 if child_needed :
134+ needed = True
114135 lines .append (
115136 emit_loop (
116137 field .name ,
117138 f"o.{ field .name } .values()" ,
118- "v" ,
119139 name_for (child_desc ),
120140 )
121141 )
122142
143+ key_fd = field .message_type .fields_by_name .get ("key" )
123144 if (
124145 key_fd is not None
125146 and key_fd .type == FieldDescriptor .TYPE_MESSAGE
126147 ):
127- key_desc = key_fd .message_type
128- child_needed = walk (key_desc )
129- needed |= child_needed
148+ child_desc = key_fd .message_type
149+ child_needed = self .walk (child_desc )
130150 if child_needed :
151+ needed = True
131152 lines .append (
132153 emit_loop (
133154 field .name ,
134155 f"o.{ field .name } .keys()" ,
135- "k" ,
136- name_for (key_desc ),
156+ name_for (child_desc ),
137157 )
138158 )
139159 else :
140- child_desc = field .message_type
141- child_needed = walk (child_desc )
142- needed |= child_needed
143- if child_needed :
144- lines .append (
145- emit_loop (
146- field .name ,
147- f"o.{ field .name } " ,
148- "v" ,
149- name_for (child_desc ),
150- )
151- )
160+ child = self .check_repeated (
161+ field .message_type , field , f"o.{ field .name } "
162+ )
163+ if child is not None :
164+ needed = True
165+ lines .append (child )
152166 else :
153167 child_desc = field .message_type
154- child_needed = walk (child_desc )
168+ child_needed = self . walk (child_desc )
155169 needed |= child_needed
156170 if child_needed :
157171 lines .append (
158172 emit_singular (
159- field .name , f"o.{ field .name } " , name_for (child_desc )
173+ field .name , f"o.{ field .name } " , name_for (child_desc ), True
160174 )
161175 )
162176
163- generated [key ] = needed
164- in_progress .discard (key )
177+ self . generated [key ] = needed
178+ self . in_progress .discard (key )
165179 if needed :
166- methods .append ("\n " .join (lines ) + "\n " )
180+ self . methods .append ("\n " .join (lines ) + "\n " )
167181 return needed
168182
169- # Build root descriptors: WorkflowActivation, WorkflowActivationCompletion,
170- # and all messages from selected API modules
171- roots : list [Descriptor ] = [
172- WorkflowActivation .DESCRIPTOR ,
173- WorkflowActivationCompletion .DESCRIPTOR ,
174- ]
183+ def generate (self , roots : list [Descriptor ]) -> str :
184+ """
185+ Generate Python source code that, given a function f(Payload) -> Payload,
186+ applies it to every Payload contained within a WorkflowActivation tree.
175187
176- # We avoid importing google.api deps in service protos; expand by walking from
177- # WorkflowActivationCompletion root which references many command messages.
188+ The generated code defines async visitor functions for each reachable
189+ protobuf message type starting from WorkflowActivation, including support
190+ for repeated fields and map entries, and a convenience entrypoint
191+ function `visit`.
192+ """
178193
179- for r in roots :
180- walk (r )
194+ # We avoid importing google.api deps in service protos; expand by walking from
195+ # WorkflowActivationCompletion root which references many command messages.
196+ for r in roots :
197+ self .walk (r )
181198
182- header = """
183- from typing import Any, Awaitable, Callable
199+ header = """
200+ import abc
201+ from typing import Any, MutableSequence
184202
185203from temporalio.api.common.v1.message_pb2 import Payload
186204
205+ class VisitorFunctions(abc.ABC):
206+ \" \" \" Set of functions which can be called by the visitor. Allows handling payloads as a sequence.\" \" \"
207+ @abc.abstractmethod
208+ async def visit_payload(self, payload: Payload) -> None:
209+ \" \" \" Called when encountering a single payload.\" \" \"
210+ raise NotImplementedError()
211+
212+ @abc.abstractmethod
213+ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
214+ \" \" \" Called when encountering multiple payloads together.\" \" \"
215+ raise NotImplementedError()
187216
188217class PayloadVisitor:
189218 \" \" \" A visitor for payloads. Applies a function to every payload in a tree of messages.\" \" \"
@@ -195,25 +224,33 @@ def __init__(
195224 self.skip_headers = skip_headers
196225
197226 async def visit(
198- self, f: Callable[[Payload], Awaitable[Payload]] , root: Any
227+ self, fs: VisitorFunctions , root: Any
199228 ) -> None:
200229 \" \" \" Visits the given root message with the given function.\" \" \"
201230 method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_")
202231 method = getattr(self, method_name, None)
203232 if method is not None:
204- await method(f , root)
233+ await method(fs , root)
205234 else:
206235 raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}")
207236
208237"""
209238
210- return header + "\n " .join (methods )
239+ return header + "\n " .join (self . methods )
211240
212241
213242def write_generated_visitors_into_visitor_generated_py () -> None :
214243 """Write the generated visitor code into _visitor.py."""
215244 out_path = base_dir / "temporalio" / "bridge" / "_visitor.py"
216- code = gen_workflow_activation_payload_visitor_code ()
245+
246+ # Build root descriptors: WorkflowActivation, WorkflowActivationCompletion,
247+ # and all messages from selected API modules
248+ roots : list [Descriptor ] = [
249+ WorkflowActivation .DESCRIPTOR ,
250+ WorkflowActivationCompletion .DESCRIPTOR ,
251+ ]
252+
253+ code = VisitorGenerator ().generate (roots )
217254 out_path .write_text (code )
218255
219256
0 commit comments