Skip to content

Commit d2cfb4c

Browse files
committed
Add generation of python services. Update a couple tests to avoid relying on previous behavior but still contain the same assertions
1 parent 8616046 commit d2cfb4c

File tree

6 files changed

+3394
-1018
lines changed

6 files changed

+3394
-1018
lines changed

scripts/gen_bridge_client.py

Lines changed: 149 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from functools import partial
12
import re
23
from string import Template
34

@@ -14,14 +15,128 @@
1415
import temporalio.bridge.proto.health.v1.health_pb2 as health_service
1516

1617

17-
def generate_client_impl(
18+
def generate_python_services(
19+
file_descriptors: list[FileDescriptor],
20+
output_file: str = "temporalio/bridge/services_generated.py",
21+
):
22+
print("generating python services")
23+
24+
services_template = Template("""# Generated file. DO NOT EDIT
25+
26+
from __future__ import annotations
27+
28+
from datetime import timedelta
29+
from typing import Mapping, Optional, Union, TYPE_CHECKING
30+
import google.protobuf.empty_pb2
31+
32+
$service_imports
33+
34+
35+
if TYPE_CHECKING:
36+
from temporalio.service import ServiceClient
37+
38+
$service_defns
39+
""")
40+
41+
def service_name(s):
42+
return f"import {sanitize_proto_name(s.full_name)[:-len(s.name)-1]}"
43+
44+
service_imports = [
45+
service_name(service_descriptor)
46+
for file_descriptor in file_descriptors
47+
for service_descriptor in file_descriptor.services_by_name.values()
48+
]
49+
50+
service_defns = [
51+
generate_python_service(service_descriptor)
52+
for file_descriptor in file_descriptors
53+
for service_descriptor in file_descriptor.services_by_name.values()
54+
]
55+
56+
with open(output_file, "w") as f:
57+
f.write(
58+
services_template.substitute(
59+
service_imports="\n".join(service_imports),
60+
service_defns="\n".join(service_defns),
61+
)
62+
)
63+
64+
print(f"successfully generated client at {output_file}")
65+
66+
67+
def generate_python_service(service_descriptor: ServiceDescriptor) -> str:
68+
service_template = Template("""
69+
class $service_name:
70+
def __init__(self, client: ServiceClient):
71+
self.client = client
72+
self.service = "$rpc_service_name"
73+
$method_calls
74+
""")
75+
76+
sanitized_service_name: str = service_descriptor.name
77+
# The health service doesn't end in "Service" in the proto definition
78+
# this check ensures that the proto descriptor name will match the format in core
79+
if not sanitized_service_name.endswith("Service"):
80+
sanitized_service_name += "Service"
81+
82+
# remove "Service" and lowercase
83+
rpc_name = sanitized_service_name[:-7].lower()
84+
85+
# remove any streaming methods b/c we don't support them at the moment
86+
methods = [
87+
method
88+
for method in service_descriptor.methods
89+
if not method.client_streaming and not method.server_streaming
90+
]
91+
92+
method_calls = [
93+
generate_python_method_call(method)
94+
for method in sorted(methods, key=lambda m: m.name)
95+
]
96+
97+
return service_template.substitute(
98+
service_name=sanitized_service_name,
99+
rpc_service_name=pascal_to_snake(rpc_name),
100+
method_calls="\n".join(method_calls),
101+
)
102+
103+
104+
def generate_python_method_call(method_descriptor: MethodDescriptor) -> str:
105+
method_template = Template("""
106+
async def $method_name(
107+
self,
108+
req: $request_type,
109+
retry: bool = False,
110+
metadata: Mapping[str, Union[str, bytes]] = {},
111+
timeout: Optional[timedelta] = None,
112+
) -> $response_type:
113+
print("sup from $method_name")
114+
return await self.client._rpc_call(
115+
rpc="$method_name",
116+
req=req,
117+
service=self.service,
118+
resp_type=$response_type,
119+
retry=retry,
120+
metadata=metadata,
121+
timeout=timeout,
122+
)
123+
""")
124+
125+
return method_template.substitute(
126+
method_name=pascal_to_snake(method_descriptor.name),
127+
request_type=sanitize_proto_name(method_descriptor.input_type.full_name),
128+
response_type=sanitize_proto_name(method_descriptor.output_type.full_name),
129+
)
130+
131+
132+
def generate_rust_client_impl(
18133
file_descriptors: list[FileDescriptor],
19134
output_file: str = "temporalio/bridge/src/client_rpc_generated.rs",
20135
):
21136
print("generating bridge rpc calls")
22137

23138
service_calls = [
24-
generate_service_call(service_descriptor)
139+
generate_rust_service_call(service_descriptor)
25140
for file_descriptor in file_descriptors
26141
for service_descriptor in file_descriptor.services_by_name.values()
27142
]
@@ -47,7 +162,7 @@ def generate_client_impl(
47162
print(f"successfully generated client at {output_file}")
48163

49164

50-
def generate_service_call(service_descriptor: ServiceDescriptor) -> str:
165+
def generate_rust_service_call(service_descriptor: ServiceDescriptor) -> str:
51166
print(f"generating rpc call wrapper for {service_descriptor.full_name}")
52167

53168
call_template = Template("""
@@ -86,7 +201,7 @@ def generate_service_call(service_descriptor: ServiceDescriptor) -> str:
86201
]
87202

88203
match_arms = [
89-
generate_match_arm(sanitized_service_name, method)
204+
generate_rust_match_arm(sanitized_service_name, method)
90205
for method in sorted(methods, key=lambda m: m.name)
91206
]
92207

@@ -97,7 +212,7 @@ def generate_service_call(service_descriptor: ServiceDescriptor) -> str:
97212
)
98213

99214

100-
def generate_match_arm(trait_name: str, method: MethodDescriptor) -> str:
215+
def generate_rust_match_arm(trait_name: str, method: MethodDescriptor) -> str:
101216
match_template = Template("""\
102217
"$method_name" => {
103218
rpc_call!(retry_client, call, $trait_name, $method_name)
@@ -112,8 +227,36 @@ def pascal_to_snake(input: str) -> str:
112227
return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", input).lower()
113228

114229

230+
sanitize_import_fixes = [
231+
partial(re.compile(r"temporal\.api\.").sub, r"temporalio.api."),
232+
partial(
233+
re.compile(r"temporal\.grpc.health\.").sub, r"temporalio.bridge.proto.health."
234+
),
235+
partial(
236+
re.compile(r"google\.protobuf\.Empty").sub, r"google.protobuf.empty_pb2.Empty"
237+
),
238+
]
239+
240+
241+
def sanitize_proto_name(input: str) -> str:
242+
content = input
243+
for fix in sanitize_import_fixes:
244+
content = fix(content)
245+
return content
246+
247+
115248
if __name__ == "__main__":
116-
generate_client_impl(
249+
generate_rust_client_impl(
250+
[
251+
workflow_service.DESCRIPTOR,
252+
operator_service.DESCRIPTOR,
253+
cloud_service.DESCRIPTOR,
254+
test_service.DESCRIPTOR,
255+
health_service.DESCRIPTOR,
256+
]
257+
)
258+
259+
generate_python_services(
117260
[
118261
workflow_service.DESCRIPTOR,
119262
operator_service.DESCRIPTOR,

0 commit comments

Comments
 (0)