Skip to content

Commit fe57eb2

Browse files
authored
feat: add interceptor-like functionality to REST transport (googleapis#1142)
Interceptors are a gRPC feature that wraps rpcs in continuation-passing-style pre and post method custom functions. These can be used e.g. for logging, local caching, and tweaking metadata. This PR adds interceptor like functionality to the REST transport in generated GAPICs. The REST transport interceptors differ in a few ways: 1) They are not continuations. For each method there is a slot for a "pre"function, and for each method with a non-empty return there is a slot for a "post" function. 2) There is always an interceptor for each method. The default simply does nothing. 3) Existing gRPC interceptors and the new REST interceptors are not composable or interoperable.
1 parent feb7b4f commit fe57eb2

File tree

9 files changed

+255
-22
lines changed

9 files changed

+255
-22
lines changed

gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/__init__.py.j2

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ from .grpc import {{ service.name }}GrpcTransport
1111
{% endif %}
1212
{% if 'rest' in opts.transport %}
1313
from .rest import {{ service.name }}RestTransport
14+
from .rest import {{ service.name }}RestInterceptor
1415
{% endif %}
1516

1617
# Compile a registry of transports.
@@ -29,6 +30,7 @@ __all__ = (
2930
{% endif %}
3031
{% if 'rest' in opts.transport %}
3132
'{{ service.name }}RestTransport',
33+
'{{ service.name }}RestInterceptor',
3234
{% endif %}
3335
)
3436
{% endblock %}

gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/rest.py.j2

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,67 @@ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
4949
rest_version=requests_version,
5050
)
5151

52+
53+
class {{ service.name }}RestInterceptor:
54+
"""Interceptor for {{ service.name }}.
55+
56+
Interceptors are used to manipulate requests, request metadata, and responses
57+
in arbitrary ways.
58+
Example use cases include:
59+
* Logging
60+
* Verifying requests according to service or custom semantics
61+
* Stripping extraneous information from responses
62+
63+
These use cases and more can be enabled by injecting an
64+
instance of a custom subclass when constructing the {{ service.name }}RestTransport.
65+
66+
.. code-block:
67+
class MyCustom{{ service.name }}Interceptor({{ service.name }}RestInterceptor):
68+
{% for _, method in service.methods|dictsort if not (method.server_streaming or method.client_streaming) %}
69+
def pre_{{ method.name|snake_case }}(request, metadata):
70+
logging.log(f"Received request: {request}")
71+
return request, metadata
72+
73+
{% if not method.void %}
74+
def post_{{ method.name|snake_case }}(response):
75+
logging.log(f"Received response: {response}")
76+
{% endif %}
77+
78+
{% endfor %}
79+
transport = {{ service.name }}RestTransport(interceptor=MyCustom{{ service.name }}Interceptor())
80+
client = {{ service.client_name }}(transport=transport)
81+
82+
83+
"""
84+
{% for method in service.methods.values()|sort(attribute="name") if not(method.server_streaming or method.client_streaming) %}
85+
def pre_{{ method.name|snake_case }}(self, request: {{method.input.ident}}, metadata: Sequence[Tuple[str, str]]) -> Tuple[{{method.input.ident}}, Sequence[Tuple[str, str]]]:
86+
"""Pre-rpc interceptor for {{ method.name|snake_case }}
87+
88+
Override in a subclass to manipulate the request or metadata
89+
before they are sent to the {{ service.name }} server.
90+
"""
91+
return request, metadata
92+
93+
{% if not method.void %}
94+
def post_{{ method.name|snake_case }}(self, response: {{method.output.ident}}) -> {{method.output.ident}}:
95+
"""Post-rpc interceptor for {{ method.name|snake_case }}
96+
97+
Override in a subclass to manipulate the response
98+
after it is returned by the {{ service.name }} server but before
99+
it is returned to user code.
100+
"""
101+
return response
102+
{% endif %}
103+
104+
{% endfor %}
105+
106+
52107
@dataclasses.dataclass
53108
class {{service.name}}RestStub:
54109
_session: AuthorizedSession
55110
_host: str
111+
_interceptor: {{ service.name }}RestInterceptor
112+
56113

57114
class {{service.name}}RestTransport({{service.name}}Transport):
58115
"""REST backend transport for {{ service.name }}.
@@ -80,6 +137,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
80137
client_info: gapic_v1.client_info.ClientInfo=DEFAULT_CLIENT_INFO,
81138
always_use_jwt_access: Optional[bool]=False,
82139
url_scheme: str='https',
140+
interceptor: Optional[{{ service.name }}RestInterceptor] = None,
83141
) -> None:
84142
"""Instantiate the transport.
85143

@@ -130,6 +188,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
130188
{% endif %}
131189
if client_cert_source_for_mtls:
132190
self._session.configure_mtls_channel(client_cert_source_for_mtls)
191+
self._interceptor = interceptor or {{ service.name }}RestInterceptor()
133192
self._prep_wrapped_messages(client_info)
134193

135194
{% if service.has_lro %}
@@ -233,7 +292,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
233292
},
234293
{% endfor %}{# rule in method.http_options #}
235294
]
236-
295+
request, metadata = self._interceptor.pre_{{ method.name|snake_case }}(request, metadata)
237296
request_kwargs = {{method.input.ident}}.to_dict(request)
238297
transcoded_request = path_template.transcode(
239298
http_options, **request_kwargs)
@@ -288,16 +347,16 @@ class {{service.name}}RestTransport({{service.name}}Transport):
288347
{% if not method.void %}
289348
# Return the response
290349
{% if method.lro %}
291-
return_op = operations_pb2.Operation()
292-
json_format.Parse(response.content, return_op, ignore_unknown_fields=True)
293-
return return_op
350+
resp = operations_pb2.Operation()
351+
json_format.Parse(response.content, resp, ignore_unknown_fields=True)
294352
{% else %}
295-
return {{method.output.ident}}.from_json(
353+
resp = {{method.output.ident}}.from_json(
296354
response.content,
297355
ignore_unknown_fields=True
298356
)
299-
300357
{% endif %}{# method.lro #}
358+
resp = self._interceptor.post_{{ method.name|snake_case }}(resp)
359+
return resp
301360
{% endif %}{# method.void #}
302361
{% else %}{# method.http_options and not (method.server_streaming or method.client_streaming) #}
303362
{% if not method.http_options %}
@@ -323,7 +382,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
323382
{{method.output.ident}}]:
324383
stub = self._STUBS.get("{{method.name | snake_case}}")
325384
if not stub:
326-
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host)
385+
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)
327386

328387
return stub
329388

gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ from google.api_core import grpc_helpers
3535
from google.api_core import path_template
3636
{% if service.has_lro %}
3737
from google.api_core import future
38+
from google.api_core import operation
3839
from google.api_core import operations_v1
3940
from google.longrunning import operations_pb2
4041
{% if "rest" in opts.transport %}
@@ -1113,6 +1114,55 @@ def test_{{ method_name }}_rest_unset_required_fields():
11131114

11141115
{% endif %}{# required_fields #}
11151116

1117+
{% if not (method.server_streaming or method.client_streaming) %}
1118+
@pytest.mark.parametrize("null_interceptor", [True, False])
1119+
def test_{{ method_name }}_rest_interceptors(null_interceptor):
1120+
transport = transports.{{ service.name }}RestTransport(
1121+
credentials=ga_credentials.AnonymousCredentials(),
1122+
interceptor=None if null_interceptor else transports.{{ service.name}}RestInterceptor(),
1123+
)
1124+
client = {{ service.client_name }}(transport=transport)
1125+
with mock.patch.object(type(client.transport._session), "request") as req, \
1126+
mock.patch.object(path_template, "transcode") as transcode, \
1127+
{% if method.lro %}
1128+
mock.patch.object(operation.Operation, "_set_result_from_operation"), \
1129+
{% endif %}
1130+
{% if not method.void %}
1131+
mock.patch.object(transports.{{ service.name }}RestInterceptor, "post_{{method.name|snake_case}}") as post, \
1132+
{% endif %}
1133+
mock.patch.object(transports.{{ service.name }}RestInterceptor, "pre_{{ method.name|snake_case }}") as pre:
1134+
pre.assert_not_called()
1135+
{% if not method.void %}
1136+
post.assert_not_called()
1137+
{% endif %}
1138+
1139+
transcode.return_value = {"method": "post", "uri": "my_uri", "body": None, "query_params": {},}
1140+
1141+
req.return_value = Response()
1142+
req.return_value.status_code = 200
1143+
req.return_value.request = PreparedRequest()
1144+
{% if not method.void %}
1145+
req.return_value._content = {% if method.output.ident.package == method.ident.package %}{{ method.output.ident }}.to_json({{ method.output.ident }}()){% else %}json_format.MessageToJson({{ method.output.ident }}()){% endif %}
1146+
{% endif %}
1147+
1148+
request = {{ method.input.ident }}()
1149+
metadata =[
1150+
("key", "val"),
1151+
("cephalopod", "squid"),
1152+
]
1153+
pre.return_value = request, metadata
1154+
{% if not method.void %}
1155+
post.return_value = {{ method.output.ident }}
1156+
{% endif %}
1157+
1158+
client.{{ method_name }}(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
1159+
1160+
pre.assert_called_once()
1161+
{% if not method.void %}
1162+
post.assert_called_once()
1163+
{% endif %}
1164+
{% endif %}{# streaming #}
1165+
11161166

11171167
def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_type={{ method.input.ident }}):
11181168
client = {{ service.client_name }}(

gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/__init__.py.j2

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ from .grpc_asyncio import {{ service.name }}GrpcAsyncIOTransport
1212
{% endif %}
1313
{% if 'rest' in opts.transport %}
1414
from .rest import {{ service.name }}RestTransport
15+
from .rest import {{ service.name }}RestInterceptor
1516
{% endif %}
1617

1718

@@ -34,6 +35,7 @@ __all__ = (
3435
{% endif %}
3536
{% if 'rest' in opts.transport %}
3637
'{{ service.name }}RestTransport',
38+
'{{ service.name }}RestInterceptor',
3739
{% endif %}
3840
)
3941
{% endblock %}

gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,67 @@ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
4949
rest_version=requests_version,
5050
)
5151

52+
53+
class {{ service.name }}RestInterceptor:
54+
"""Interceptor for {{ service.name }}.
55+
56+
Interceptors are used to manipulate requests, request metadata, and responses
57+
in arbitrary ways.
58+
Example use cases include:
59+
* Logging
60+
* Verifying requests according to service or custom semantics
61+
* Stripping extraneous information from responses
62+
63+
These use cases and more can be enabled by injecting an
64+
instance of a custom subclass when constructing the {{ service.name }}RestTransport.
65+
66+
.. code-block:
67+
class MyCustom{{ service.name }}Interceptor({{ service.name }}RestInterceptor):
68+
{% for _, method in service.methods|dictsort if not (method.server_streaming or method.client_streaming) %}
69+
def pre_{{ method.name|snake_case }}(request, metadata):
70+
logging.log(f"Received request: {request}")
71+
return request, metadata
72+
73+
{% if not method.void %}
74+
def post_{{ method.name|snake_case }}(response):
75+
logging.log(f"Received response: {response}")
76+
{% endif %}
77+
78+
{% endfor %}
79+
transport = {{ service.name }}RestTransport(interceptor=MyCustom{{ service.name }}Interceptor())
80+
client = {{ service.client_name }}(transport=transport)
81+
82+
83+
"""
84+
{% for method in service.methods.values()|sort(attribute="name") if not (method.server_streaming or method.client_streaming) %}
85+
def pre_{{ method.name|snake_case }}(self, request: {{method.input.ident}}, metadata: Sequence[Tuple[str, str]]) -> Tuple[{{method.input.ident}}, Sequence[Tuple[str, str]]]:
86+
"""Pre-rpc interceptor for {{ method.name|snake_case }}
87+
88+
Override in a subclass to manipulate the request or metadata
89+
before they are sent to the {{ service.name }} server.
90+
"""
91+
return request, metadata
92+
93+
{% if not method.void %}
94+
def post_{{ method.name|snake_case }}(self, response: {{method.output.ident}}) -> {{method.output.ident}}:
95+
"""Post-rpc interceptor for {{ method.name|snake_case }}
96+
97+
Override in a subclass to manipulate the response
98+
after it is returned by the {{ service.name }} server but before
99+
it is returned to user code.
100+
"""
101+
return response
102+
{% endif %}
103+
104+
{% endfor %}
105+
106+
52107
@dataclasses.dataclass
53108
class {{service.name}}RestStub:
54109
_session: AuthorizedSession
55110
_host: str
111+
_interceptor: {{ service.name }}RestInterceptor
112+
56113

57114
class {{service.name}}RestTransport({{service.name}}Transport):
58115
"""REST backend transport for {{ service.name }}.
@@ -80,6 +137,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
80137
client_info: gapic_v1.client_info.ClientInfo=DEFAULT_CLIENT_INFO,
81138
always_use_jwt_access: Optional[bool]=False,
82139
url_scheme: str='https',
140+
interceptor: Optional[{{ service.name }}RestInterceptor] = None,
83141
) -> None:
84142
"""Instantiate the transport.
85143

@@ -130,6 +188,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
130188
{% endif %}
131189
if client_cert_source_for_mtls:
132190
self._session.configure_mtls_channel(client_cert_source_for_mtls)
191+
self._interceptor = interceptor or {{ service.name }}RestInterceptor()
133192
self._prep_wrapped_messages(client_info)
134193

135194
{% if service.has_lro %}
@@ -233,7 +292,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
233292
},
234293
{% endfor %}{# rule in method.http_options #}
235294
]
236-
295+
request, metadata = self._interceptor.pre_{{ method.name|snake_case }}(request, metadata)
237296
request_kwargs = {{method.input.ident}}.to_dict(request)
238297
transcoded_request = path_template.transcode(
239298
http_options, **request_kwargs)
@@ -288,16 +347,16 @@ class {{service.name}}RestTransport({{service.name}}Transport):
288347
{% if not method.void %}
289348
# Return the response
290349
{% if method.lro %}
291-
return_op = operations_pb2.Operation()
292-
json_format.Parse(response.content, return_op, ignore_unknown_fields=True)
293-
return return_op
350+
resp = operations_pb2.Operation()
351+
json_format.Parse(response.content, resp, ignore_unknown_fields=True)
294352
{% else %}
295-
return {{method.output.ident}}.from_json(
353+
resp = {{method.output.ident}}.from_json(
296354
response.content,
297355
ignore_unknown_fields=True
298356
)
299-
300357
{% endif %}{# method.lro #}
358+
resp = self._interceptor.post_{{ method.name|snake_case }}(resp)
359+
return resp
301360
{% endif %}{# method.void #}
302361
{% else %}{# method.http_options and not (method.server_streaming or method.client_streaming) #}
303362
{% if not method.http_options %}
@@ -323,7 +382,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
323382
{{method.output.ident}}]:
324383
stub = self._STUBS.get("{{method.name | snake_case}}")
325384
if not stub:
326-
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host)
385+
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)
327386

328387
return stub
329388

0 commit comments

Comments
 (0)