|
1 | 1 | import inspect |
2 | 2 | import os |
3 | 3 | import re |
| 4 | +from datetime import timedelta |
| 5 | +from importlib import import_module |
4 | 6 | from typing import Any, Callable, Dict, Mapping, Tuple, Type |
5 | 7 |
|
6 | 8 | import google.protobuf.empty_pb2 |
7 | 9 | import google.protobuf.message |
| 10 | +import google.protobuf.symbol_database |
8 | 11 | import grpc |
9 | 12 | import pytest |
| 13 | +from google.protobuf.descriptor import MethodDescriptor |
10 | 14 |
|
11 | 15 | import temporalio |
12 | 16 | import temporalio.api.cloud.cloudservice.v1 |
|
19 | 23 | from temporalio.testing import WorkflowEnvironment |
20 | 24 |
|
21 | 25 |
|
| 26 | +def _camel_to_snake(name: str) -> str: |
| 27 | + return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower() |
| 28 | + |
| 29 | + |
22 | 30 | def test_all_grpc_calls_present(client: Client): |
23 | 31 | def assert_all_calls_present( |
24 | 32 | service: Any, |
@@ -111,7 +119,7 @@ def unary_unary(self, method, request_serializer, response_deserializer): |
111 | 119 | getattr(self.package, name + "Response"), |
112 | 120 | ) |
113 | 121 | # Camel to snake case |
114 | | - name = re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower() |
| 122 | + name = _camel_to_snake(name) |
115 | 123 | self.calls[name] = req_resp |
116 | 124 |
|
117 | 125 |
|
@@ -157,3 +165,74 @@ async def test_grpc_status(client: Client, env: WorkflowEnvironment): |
157 | 165 | assert err.value.grpc_status.details[0].Is( |
158 | 166 | temporalio.api.errordetails.v1.NamespaceNotFoundFailure.DESCRIPTOR |
159 | 167 | ) |
| 168 | + |
| 169 | + |
| 170 | +async def test_rpc_execution_not_unknown(client: Client): |
| 171 | + """ |
| 172 | + Execute each rpc method and expect a failure, but ensure the failure is not that the rpc method is unknown |
| 173 | + """ |
| 174 | + sym_db = google.protobuf.symbol_database.Default() |
| 175 | + service_client = client.service_client |
| 176 | + |
| 177 | + async def test_method( |
| 178 | + target_service_name: str, method_descriptor: MethodDescriptor |
| 179 | + ): |
| 180 | + if method_descriptor.client_streaming or method_descriptor.server_streaming: |
| 181 | + # skip streaming calls |
| 182 | + return |
| 183 | + |
| 184 | + method_name = _camel_to_snake(method_descriptor.name) |
| 185 | + |
| 186 | + # get request type and instantiate an empty request |
| 187 | + request_type = sym_db.GetSymbol(method_descriptor.input_type.full_name) |
| 188 | + request = request_type() |
| 189 | + |
| 190 | + # get the appropriate temporal service from the service_client |
| 191 | + target_service = getattr(service_client, target_service_name) |
| 192 | + |
| 193 | + # execute rpc and ensure that any exception that occurs is not the |
| 194 | + # "Unknown RPC call" error which indicates the python and rust rpc components |
| 195 | + # should be regenerated |
| 196 | + rpc_call = getattr(target_service, method_name) |
| 197 | + try: |
| 198 | + await rpc_call(request, timeout=timedelta(milliseconds=1)) |
| 199 | + except Exception as err: |
| 200 | + assert ( |
| 201 | + "Unknown RPC call" not in str(err) |
| 202 | + ), f"Unexpected unknown-RPC error for {target_service_name}.{method_name}: {err}" |
| 203 | + |
| 204 | + async def test_service( |
| 205 | + *, proto_module: str, proto_service: str, target_service_name: str |
| 206 | + ): |
| 207 | + # load the module and test each method of the specified service |
| 208 | + module = import_module(proto_module) |
| 209 | + service_descriptor = module.DESCRIPTOR.services_by_name[proto_service] |
| 210 | + |
| 211 | + for method_descriptor in service_descriptor.methods: |
| 212 | + await test_method(target_service_name, method_descriptor) |
| 213 | + |
| 214 | + await test_service( |
| 215 | + proto_module="temporalio.api.workflowservice.v1.service_pb2", |
| 216 | + proto_service="WorkflowService", |
| 217 | + target_service_name="workflow_service", |
| 218 | + ) |
| 219 | + await test_service( |
| 220 | + proto_module="temporalio.api.operatorservice.v1.service_pb2", |
| 221 | + proto_service="OperatorService", |
| 222 | + target_service_name="operator_service", |
| 223 | + ) |
| 224 | + await test_service( |
| 225 | + proto_module="temporalio.api.cloud.cloudservice.v1.service_pb2", |
| 226 | + proto_service="CloudService", |
| 227 | + target_service_name="cloud_service", |
| 228 | + ) |
| 229 | + await test_service( |
| 230 | + proto_module="temporalio.api.testservice.v1.service_pb2", |
| 231 | + proto_service="TestService", |
| 232 | + target_service_name="test_service", |
| 233 | + ) |
| 234 | + await test_service( |
| 235 | + proto_module="temporalio.bridge.proto.health.v1.health_pb2", |
| 236 | + proto_service="Health", |
| 237 | + target_service_name="health_service", |
| 238 | + ) |
0 commit comments