1+ from functools import partial
12import re
23from string import Template
34
1415import temporalio .bridge .proto .health .v1 .health_pb2 as health_service
1516
1617
17- def generate_client_impl (
18+ def generate_python_services (
19+ file_descriptors : list [FileDescriptor ],
20+ output_file : str = "temporalio/bridge/services_generated.py" ,
21+ ):
22+ print ("generating python services" )
23+
24+ services_template = Template ("""# Generated file. DO NOT EDIT
25+
26+ from __future__ import annotations
27+
28+ from datetime import timedelta
29+ from typing import Mapping, Optional, Union, TYPE_CHECKING
30+ import google.protobuf.empty_pb2
31+
32+ $service_imports
33+
34+
35+ if TYPE_CHECKING:
36+ from temporalio.service import ServiceClient
37+
38+ $service_defns
39+ """ )
40+
41+ def service_name (s ):
42+ return f"import { sanitize_proto_name (s .full_name )[:- len (s .name )- 1 ]} "
43+
44+ service_imports = [
45+ service_name (service_descriptor )
46+ for file_descriptor in file_descriptors
47+ for service_descriptor in file_descriptor .services_by_name .values ()
48+ ]
49+
50+ service_defns = [
51+ generate_python_service (service_descriptor )
52+ for file_descriptor in file_descriptors
53+ for service_descriptor in file_descriptor .services_by_name .values ()
54+ ]
55+
56+ with open (output_file , "w" ) as f :
57+ f .write (
58+ services_template .substitute (
59+ service_imports = "\n " .join (service_imports ),
60+ service_defns = "\n " .join (service_defns ),
61+ )
62+ )
63+
64+ print (f"successfully generated client at { output_file } " )
65+
66+
67+ def generate_python_service (service_descriptor : ServiceDescriptor ) -> str :
68+ service_template = Template ("""
69+ class $service_name:
70+ def __init__(self, client: ServiceClient):
71+ self.client = client
72+ self.service = "$rpc_service_name"
73+ $method_calls
74+ """ )
75+
76+ sanitized_service_name : str = service_descriptor .name
77+ # The health service doesn't end in "Service" in the proto definition
78+ # this check ensures that the proto descriptor name will match the format in core
79+ if not sanitized_service_name .endswith ("Service" ):
80+ sanitized_service_name += "Service"
81+
82+ # remove "Service" and lowercase
83+ rpc_name = sanitized_service_name [:- 7 ].lower ()
84+
85+ # remove any streaming methods b/c we don't support them at the moment
86+ methods = [
87+ method
88+ for method in service_descriptor .methods
89+ if not method .client_streaming and not method .server_streaming
90+ ]
91+
92+ method_calls = [
93+ generate_python_method_call (method )
94+ for method in sorted (methods , key = lambda m : m .name )
95+ ]
96+
97+ return service_template .substitute (
98+ service_name = sanitized_service_name ,
99+ rpc_service_name = pascal_to_snake (rpc_name ),
100+ method_calls = "\n " .join (method_calls ),
101+ )
102+
103+
104+ def generate_python_method_call (method_descriptor : MethodDescriptor ) -> str :
105+ method_template = Template ("""
106+ async def $method_name(
107+ self,
108+ req: $request_type,
109+ retry: bool = False,
110+ metadata: Mapping[str, Union[str, bytes]] = {},
111+ timeout: Optional[timedelta] = None,
112+ ) -> $response_type:
113+ print("sup from $method_name")
114+ return await self.client._rpc_call(
115+ rpc="$method_name",
116+ req=req,
117+ service=self.service,
118+ resp_type=$response_type,
119+ retry=retry,
120+ metadata=metadata,
121+ timeout=timeout,
122+ )
123+ """ )
124+
125+ return method_template .substitute (
126+ method_name = pascal_to_snake (method_descriptor .name ),
127+ request_type = sanitize_proto_name (method_descriptor .input_type .full_name ),
128+ response_type = sanitize_proto_name (method_descriptor .output_type .full_name ),
129+ )
130+
131+
132+ def generate_rust_client_impl (
18133 file_descriptors : list [FileDescriptor ],
19134 output_file : str = "temporalio/bridge/src/client_rpc_generated.rs" ,
20135):
21136 print ("generating bridge rpc calls" )
22137
23138 service_calls = [
24- generate_service_call (service_descriptor )
139+ generate_rust_service_call (service_descriptor )
25140 for file_descriptor in file_descriptors
26141 for service_descriptor in file_descriptor .services_by_name .values ()
27142 ]
@@ -47,7 +162,7 @@ def generate_client_impl(
47162 print (f"successfully generated client at { output_file } " )
48163
49164
50- def generate_service_call (service_descriptor : ServiceDescriptor ) -> str :
165+ def generate_rust_service_call (service_descriptor : ServiceDescriptor ) -> str :
51166 print (f"generating rpc call wrapper for { service_descriptor .full_name } " )
52167
53168 call_template = Template ("""
@@ -86,7 +201,7 @@ def generate_service_call(service_descriptor: ServiceDescriptor) -> str:
86201 ]
87202
88203 match_arms = [
89- generate_match_arm (sanitized_service_name , method )
204+ generate_rust_match_arm (sanitized_service_name , method )
90205 for method in sorted (methods , key = lambda m : m .name )
91206 ]
92207
@@ -97,7 +212,7 @@ def generate_service_call(service_descriptor: ServiceDescriptor) -> str:
97212 )
98213
99214
100- def generate_match_arm (trait_name : str , method : MethodDescriptor ) -> str :
215+ def generate_rust_match_arm (trait_name : str , method : MethodDescriptor ) -> str :
101216 match_template = Template ("""\
102217 "$method_name" => {
103218 rpc_call!(retry_client, call, $trait_name, $method_name)
@@ -112,8 +227,36 @@ def pascal_to_snake(input: str) -> str:
112227 return re .sub (r"([a-z0-9])([A-Z])" , r"\1_\2" , input ).lower ()
113228
114229
230+ sanitize_import_fixes = [
231+ partial (re .compile (r"temporal\.api\." ).sub , r"temporalio.api." ),
232+ partial (
233+ re .compile (r"temporal\.grpc.health\." ).sub , r"temporalio.bridge.proto.health."
234+ ),
235+ partial (
236+ re .compile (r"google\.protobuf\.Empty" ).sub , r"google.protobuf.empty_pb2.Empty"
237+ ),
238+ ]
239+
240+
241+ def sanitize_proto_name (input : str ) -> str :
242+ content = input
243+ for fix in sanitize_import_fixes :
244+ content = fix (content )
245+ return content
246+
247+
115248if __name__ == "__main__" :
116- generate_client_impl (
249+ generate_rust_client_impl (
250+ [
251+ workflow_service .DESCRIPTOR ,
252+ operator_service .DESCRIPTOR ,
253+ cloud_service .DESCRIPTOR ,
254+ test_service .DESCRIPTOR ,
255+ health_service .DESCRIPTOR ,
256+ ]
257+ )
258+
259+ generate_python_services (
117260 [
118261 workflow_service .DESCRIPTOR ,
119262 operator_service .DESCRIPTOR ,
0 commit comments