1313
1414base_dir = Path (__file__ ).parent .parent
1515
16+
1617def gen_workflow_activation_payload_visitor_code () -> str :
1718 """
1819 Generate Python source code that, given a function f(Payload) -> Payload,
@@ -23,29 +24,44 @@ def gen_workflow_activation_payload_visitor_code() -> str:
2324 for repeated fields and map entries, and a convenience entrypoint
2425 function `visit_workflow_activation_payloads`.
2526 """
27+
2628 def name_for (desc : Descriptor ) -> str :
2729 # Use fully-qualified name to avoid collisions; replace dots with underscores
28- return desc .full_name .replace ('.' , '_' )
29-
30- def emit_loop (lines : list [str ], field_name : str , iter_expr : str , var_name : str , child_method : str ) -> None :
30+ return desc .full_name .replace ("." , "_" )
31+
32+ def emit_loop (
33+ lines : list [str ],
34+ field_name : str ,
35+ iter_expr : str ,
36+ var_name : str ,
37+ child_method : str ,
38+ ) -> None :
3139 # Helper to emit a for-loop over a collection with optional headers guard
3240 if field_name == "headers" :
3341 lines .append (" if not self.skip_headers:" )
3442 lines .append (f" for { var_name } in { iter_expr } :" )
35- lines .append (f" await self.visit_{ child_method } (f, { var_name } )" )
43+ lines .append (
44+ f" await self.visit_{ child_method } (f, { var_name } )"
45+ )
3646 else :
3747 lines .append (f" for { var_name } in { iter_expr } :" )
3848 lines .append (f" await self.visit_{ child_method } (f, { var_name } )" )
3949
40- def emit_singular (lines : list [str ], field_name : str , access_expr : str , child_method : str ) -> None :
50+ def emit_singular (
51+ lines : list [str ], field_name : str , access_expr : str , child_method : str
52+ ) -> None :
4153 # Helper to emit a singular field visit with presence check and optional headers guard
4254 if field_name == "headers" :
4355 lines .append (" if not self.skip_headers:" )
4456 lines .append (f" if o.HasField('{ field_name } '):" )
45- lines .append (f" await self.visit_{ child_method } (f, { access_expr } )" )
57+ lines .append (
58+ f" await self.visit_{ child_method } (f, { access_expr } )"
59+ )
4660 else :
4761 lines .append (f" if o.HasField('{ field_name } '):" )
48- lines .append (f" await self.visit_{ child_method } (f, { access_expr } )" )
62+ lines .append (
63+ f" await self.visit_{ child_method } (f, { access_expr } )"
64+ )
4965
5066 # Track which message descriptors have visitor methods generated
5167 generated : dict [str , bool ] = {}
@@ -83,36 +99,65 @@ def walk(desc: Descriptor) -> bool:
8399
84100 # Repeated fields (including maps which are represented as repeated messages)
85101 if field .label == FieldDescriptor .LABEL_REPEATED :
86- if field .message_type is not None and field .message_type .GetOptions ().map_entry :
102+ if (
103+ field .message_type is not None
104+ and field .message_type .GetOptions ().map_entry
105+ ):
87106 entry_desc = field .message_type
88107 key_fd = entry_desc .fields_by_name .get ("key" )
89108 val_fd = entry_desc .fields_by_name .get ("value" )
90109
91- if val_fd is not None and val_fd .type == FieldDescriptor .TYPE_MESSAGE :
110+ if (
111+ val_fd is not None
112+ and val_fd .type == FieldDescriptor .TYPE_MESSAGE
113+ ):
92114 child_desc = val_fd .message_type
93115 child_needed = walk (child_desc )
94116 needed |= child_needed
95117 if child_needed :
96- emit_loop (lines , field .name , f"o.{ field .name } .values()" , "v" , name_for (child_desc ))
97-
98- if key_fd is not None and key_fd .type == FieldDescriptor .TYPE_MESSAGE :
118+ emit_loop (
119+ lines ,
120+ field .name ,
121+ f"o.{ field .name } .values()" ,
122+ "v" ,
123+ name_for (child_desc ),
124+ )
125+
126+ if (
127+ key_fd is not None
128+ and key_fd .type == FieldDescriptor .TYPE_MESSAGE
129+ ):
99130 key_desc = key_fd .message_type
100131 child_needed = walk (key_desc )
101132 needed |= child_needed
102133 if child_needed :
103- emit_loop (lines , field .name , f"o.{ field .name } .keys()" , "k" , name_for (key_desc ))
134+ emit_loop (
135+ lines ,
136+ field .name ,
137+ f"o.{ field .name } .keys()" ,
138+ "k" ,
139+ name_for (key_desc ),
140+ )
104141 else :
105142 child_desc = field .message_type
106143 child_needed = walk (child_desc )
107144 needed |= child_needed
108145 if child_needed :
109- emit_loop (lines , field .name , f"o.{ field .name } " , "v" , name_for (child_desc ))
146+ emit_loop (
147+ lines ,
148+ field .name ,
149+ f"o.{ field .name } " ,
150+ "v" ,
151+ name_for (child_desc ),
152+ )
110153 else :
111154 child_desc = field .message_type
112155 child_needed = walk (child_desc )
113156 needed |= child_needed
114157 if child_needed :
115- emit_singular (lines , field .name , f"o.{ field .name } " , name_for (child_desc ))
158+ emit_singular (
159+ lines , field .name , f"o.{ field .name } " , name_for (child_desc )
160+ )
116161
117162 generated [key ] = needed
118163 in_progress .discard (key )
@@ -134,7 +179,7 @@ def walk(desc: Descriptor) -> bool:
134179 walk (r )
135180
136181 header = (
137- "from typing import Awaitable, Callable, Any \n \n "
182+ "from typing import Any, Awaitable, Callable \n \n "
138183 "from temporalio.api.common.v1.message_pb2 import Payload\n \n \n "
139184 "class PayloadVisitor:\n "
140185 " def __init__(self, *, skip_search_attributes: bool = False, skip_headers: bool = False):\n "
@@ -156,7 +201,7 @@ def write_generated_visitors_into_visitor_generated_py() -> None:
156201 code = gen_workflow_activation_payload_visitor_code ()
157202 out_path .write_text (code )
158203
204+
159205if __name__ == "__main__" :
160206 print ("Generating temporalio/bridge/visitor_generated.py..." , file = sys .stderr )
161207 write_generated_visitors_into_visitor_generated_py ()
162-
0 commit comments