@@ -19,7 +19,8 @@ from google.protobuf import json_format
19
19
{% endif %}
20
20
from requests import __version__ as requests_version
21
21
import dataclasses
22
- from typing import Callable, Dict, Optional, Sequence, Tuple, Union
22
+ import re
23
+ from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
23
24
import warnings
24
25
25
26
try:
@@ -65,7 +66,7 @@ class {{ service.name }}RestInterceptor:
65
66
66
67
.. code-block:
67
68
class MyCustom{{ service.name }}Interceptor({{ service.name }}RestInterceptor):
68
- {% for _ , method in service .methods |dictsort if not (method .server_streaming or method .client_streaming ) %}
69
+ {% for _ , method in service .methods |dictsort if not (method .server_streaming or method .client_streaming ) %}
69
70
def pre_{{ method.name|snake_case }}(request, metadata):
70
71
logging.log(f"Received request: {request}")
71
72
return request, metadata
@@ -81,7 +82,7 @@ class {{ service.name }}RestInterceptor:
81
82
82
83
83
84
"""
84
- {% for method in service .methods .values ()|sort (attribute ="name" ) if not (method .server_streaming or method .client_streaming ) %}
85
+ {% for method in service .methods .values ()|sort (attribute ="name" ) if not (method .server_streaming or method .client_streaming ) %}
85
86
def pre_{{ method.name|snake_case }}(self, request: {{method.input.ident}}, metadata: Sequence[Tuple[str, str]]) -> Tuple[{{method.input.ident}}, Sequence[Tuple[str, str]]]:
86
87
"""Pre-rpc interceptor for {{ method.name|snake_case }}
87
88
@@ -175,6 +176,14 @@ class {{service.name}}RestTransport({{service.name}}Transport):
175
176
# TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc.
176
177
# TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the
177
178
# credentials object
179
+ maybe_url_match = re.match("^(?P<scheme >http(?:s)?://)?(?P<host >.*)$", host)
180
+ if maybe_url_match is None:
181
+ raise ValueError(f"Unexpected hostname structure: {host}") # pragma: NO COVER
182
+
183
+ url_match_items = maybe_url_match.groupdict()
184
+
185
+ host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host
186
+
178
187
super().__init__(
179
188
host=host,
180
189
credentials=credentials,
@@ -184,7 +193,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
184
193
self._session = AuthorizedSession(
185
194
self._credentials, default_host=self.DEFAULT_HOST)
186
195
{% if service .has_lro %}
187
- self._operations_client = None
196
+ self._operations_client: Optional[operations_v1.AbstractOperationsClient] = None
188
197
{% endif %}
189
198
if client_cert_source_for_mtls:
190
199
self._session.configure_mtls_channel(client_cert_source_for_mtls)
@@ -202,7 +211,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
202
211
"""
203
212
# Only create a new client if we do not already have one.
204
213
if self._operations_client is None:
205
- http_options = {
214
+ http_options: Dict[str, List[Dict[str, str]]] = {
206
215
{% for selector , rules in api .http_options .items () %}
207
216
{% if selector .startswith ('google.longrunning.Operations' ) %}
208
217
'{{ selector }}': [
@@ -238,9 +247,10 @@ class {{service.name}}RestTransport({{service.name}}Transport):
238
247
def __hash__(self):
239
248
return hash("{{method.name}}")
240
249
250
+
241
251
{% if not (method .server_streaming or method .client_streaming ) %}
242
252
{% if method .input .required_fields %}
243
- __REQUIRED_FIELDS_DEFAULT_VALUES = {
253
+ __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {
244
254
{% for req_field in method .input .required_fields if req_field .is_primitive and req_field .name in method .query_params %}
245
255
"{{ req_field.name | camel_case }}" : {% if req_field .field_pb .type == 9 %} "{{req_field.field_pb.default_value }}"{% else %} {{ req_field.type.python_type(req_field.field_pb.default_value or 0) }}{% endif %} ,{# default is str #}
246
256
{% endfor %}
@@ -258,7 +268,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
258
268
retry: OptionalRetry=gapic_v1.method.DEFAULT,
259
269
timeout: float=None,
260
270
metadata: Sequence[Tuple[str, str]]=(),
261
- ) -> {{method.output.ident}}:
271
+ ){% if not method . void %} -> {{method.output.ident}} {% endif % } :
262
272
{% if method .http_options and not (method .server_streaming or method .client_streaming ) %}
263
273
r"""Call the {{- ' ' -}}
264
274
{{ (method.name|snake_case).replace('_',' ')|wrap(
@@ -282,7 +292,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
282
292
{% endif %}
283
293
"""
284
294
285
- http_options = [
295
+ http_options: List[Dict[str, str]] = [
286
296
{% - for rule in method .http_options %} {
287
297
'method': '{{ rule.method }}',
288
298
'uri': '{{ rule.uri }}',
@@ -330,8 +340,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
330
340
headers = dict(metadata)
331
341
headers['Content-Type'] = 'application/json'
332
342
response = getattr(self._session, method)(
333
- # Replace with proper schema configuration (http/https) logic
334
- "https://{host}{uri}".format(host=self._host, uri=uri),
343
+ "{host}{uri}".format(host=self._host, uri=uri),
335
344
timeout=timeout,
336
345
headers=headers,
337
346
params=rest_helpers.flatten_query_params(query_params),
@@ -344,6 +353,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
344
353
# subclass.
345
354
if response.status_code >= 400:
346
355
raise core_exceptions.from_http_response(response)
356
+
347
357
{% if not method .void %}
348
358
# Return the response
349
359
{% if method .lro %}
@@ -357,6 +367,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
357
367
{% endif %} {# method.lro #}
358
368
resp = self._interceptor.post_{{ method.name|snake_case }}(resp)
359
369
return resp
370
+
360
371
{% endif %} {# method.void #}
361
372
{% else %} {# method.http_options and not (method.server_streaming or method.client_streaming) #}
362
373
{% if not method .http_options %}
@@ -384,7 +395,9 @@ class {{service.name}}RestTransport({{service.name}}Transport):
384
395
if not stub:
385
396
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)
386
397
387
- return stub
398
+ # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here.
399
+ # In C++ this would require a dynamic_cast
400
+ return stub # type: ignore
388
401
389
402
{% endfor %}
390
403
0 commit comments