Skip to content

Commit d528223

Browse files
fix: refactor mtls logic to standalone method (googleapis#1123)
* fix: refactor mtls logic to standalone method * chore: update tests * chore: fix unit tests * chore: update unit test * chore: update integration tests * chore: update async code
1 parent 5750c55 commit d528223

File tree

21 files changed

+1042
-266
lines changed

21 files changed

+1042
-266
lines changed

gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections import OrderedDict
66
import functools
77
import re
8-
from typing import Dict, {% if service.any_server_streaming %}AsyncIterable, Awaitable, {% endif %}{% if service.any_client_streaming %}AsyncIterator, {% endif %}Sequence, Tuple, Type, Union
8+
from typing import Dict, Optional, {% if service.any_server_streaming %}AsyncIterable, Awaitable, {% endif %}{% if service.any_client_streaming %}AsyncIterator, {% endif %}Sequence, Tuple, Type, Union
99
import pkg_resources
1010
{% if service.any_deprecated %}
1111
import warnings
@@ -90,6 +90,40 @@ class {{ service.async_client_name }}:
9090

9191
from_service_account_json = from_service_account_file
9292

93+
@classmethod
94+
def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[ClientOptions] = None):
95+
"""Return the API endpoint and client cert source for mutual TLS.
96+
97+
The client cert source is determined in the following order:
98+
(1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
99+
client cert source is None.
100+
(2) if `client_options.client_cert_source` is provided, use the provided one; if the
101+
default client cert source exists, use the default one; otherwise the client cert
102+
source is None.
103+
104+
The API endpoint is determined in the following order:
105+
(1) if `client_options.api_endpoint` if provided, use the provided one.
106+
(2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the
107+
default mTLS endpoint; if the environment variabel is "never", use the default API
108+
endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
109+
use the default API endpoint.
110+
111+
More details can be found at https://google.aip.dev/auth/4114.
112+
113+
Args:
114+
client_options (google.api_core.client_options.ClientOptions): Custom options for the
115+
client. Only the `api_endpoint` and `client_cert_source` properties may be used
116+
in this method.
117+
118+
Returns:
119+
Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
120+
client cert source to use.
121+
122+
Raises:
123+
google.auth.exceptions.MutualTLSChannelError: If any errors happen.
124+
"""
125+
return {{ service.client_name }}.get_mtls_endpoint_and_cert_source(client_options) # type: ignore
126+
93127
@property
94128
def transport(self) -> {{ service.name }}Transport:
95129
"""Returns the transport used by the client instance.

gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2

Lines changed: 60 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,65 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
200200

201201
{% endfor %}{# common resources #}
202202

203+
@classmethod
204+
def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_options_lib.ClientOptions] = None):
205+
"""Return the API endpoint and client cert source for mutual TLS.
206+
207+
The client cert source is determined in the following order:
208+
(1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
209+
client cert source is None.
210+
(2) if `client_options.client_cert_source` is provided, use the provided one; if the
211+
default client cert source exists, use the default one; otherwise the client cert
212+
source is None.
213+
214+
The API endpoint is determined in the following order:
215+
(1) if `client_options.api_endpoint` if provided, use the provided one.
216+
(2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the
217+
default mTLS endpoint; if the environment variabel is "never", use the default API
218+
endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
219+
use the default API endpoint.
220+
221+
More details can be found at https://google.aip.dev/auth/4114.
222+
223+
Args:
224+
client_options (google.api_core.client_options.ClientOptions): Custom options for the
225+
client. Only the `api_endpoint` and `client_cert_source` properties may be used
226+
in this method.
227+
228+
Returns:
229+
Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
230+
client cert source to use.
231+
232+
Raises:
233+
google.auth.exceptions.MutualTLSChannelError: If any errors happen.
234+
"""
235+
if client_options is None:
236+
client_options = client_options_lib.ClientOptions()
237+
use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
238+
use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
239+
if use_client_cert not in ("true", "false"):
240+
raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`")
241+
if use_mtls_endpoint not in ("auto", "never", "always"):
242+
raise MutualTLSChannelError("Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`")
243+
244+
# Figure out the client cert source to use.
245+
client_cert_source = None
246+
if use_client_cert == "true":
247+
if client_options.client_cert_source:
248+
client_cert_source = client_options.client_cert_source
249+
elif mtls.has_default_client_cert_source():
250+
client_cert_source = mtls.default_client_cert_source()
251+
252+
# Figure out which api endpoint to use.
253+
if client_options.api_endpoint is not None:
254+
api_endpoint = client_options.api_endpoint
255+
elif use_mtls_endpoint == "always" or (use_mtls_endpoint == "auto" and client_cert_source):
256+
api_endpoint = cls.DEFAULT_MTLS_ENDPOINT
257+
else:
258+
api_endpoint = cls.DEFAULT_ENDPOINT
259+
260+
return api_endpoint, client_cert_source
261+
203262
def __init__(self, *,
204263
credentials: Optional[ga_credentials.Credentials] = None,
205264
transport: Union[str, {{ service.name }}Transport, None] = None,
@@ -248,43 +307,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
248307
if client_options is None:
249308
client_options = client_options_lib.ClientOptions()
250309

251-
# Create SSL credentials for mutual TLS if needed.
252-
if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ("true", "false"):
253-
raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`")
254-
use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true"
255-
256-
client_cert_source_func = None
257-
is_mtls = False
258-
if use_client_cert:
259-
if client_options.client_cert_source:
260-
is_mtls = True
261-
client_cert_source_func = client_options.client_cert_source
262-
else:
263-
is_mtls = mtls.has_default_client_cert_source()
264-
if is_mtls:
265-
client_cert_source_func = mtls.default_client_cert_source()
266-
else:
267-
client_cert_source_func = None
268-
269-
# Figure out which api endpoint to use.
270-
if client_options.api_endpoint is not None:
271-
api_endpoint = client_options.api_endpoint
272-
else:
273-
use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
274-
if use_mtls_env == "never":
275-
api_endpoint = self.DEFAULT_ENDPOINT
276-
elif use_mtls_env == "always":
277-
api_endpoint = self.DEFAULT_MTLS_ENDPOINT
278-
elif use_mtls_env == "auto":
279-
if is_mtls:
280-
api_endpoint = self.DEFAULT_MTLS_ENDPOINT
281-
else:
282-
api_endpoint = self.DEFAULT_ENDPOINT
283-
else:
284-
raise MutualTLSChannelError(
285-
"Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted "
286-
"values: never, auto, always"
287-
)
310+
api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(client_options)
288311

289312
# Save or instantiate the transport.
290313
# Ordinarily, we provide the transport, but allowing a custom transport

gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,65 @@ def test_{{ service.client_name|snake_case }}_mtls_env_auto(client_class, transp
359359
)
360360

361361

362+
@pytest.mark.parametrize("client_class", [
363+
{% if 'grpc' in opts.transport %}
364+
{{ service.client_name }}, {{ service.async_client_name }}
365+
{% elif 'rest' in opts.transport %}
366+
{{ service.client_name }}
367+
{% endif %}
368+
])
369+
@mock.patch.object({{ service.client_name }}, "DEFAULT_ENDPOINT", modify_default_endpoint({{ service.client_name }}))
370+
{% if 'grpc' in opts.transport %}
371+
@mock.patch.object({{ service.async_client_name }}, "DEFAULT_ENDPOINT", modify_default_endpoint({{ service.async_client_name }}))
372+
{% endif %}
373+
def test_{{ service.client_name|snake_case }}_get_mtls_endpoint_and_cert_source(client_class):
374+
mock_client_cert_source = mock.Mock()
375+
376+
# Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true".
377+
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}):
378+
mock_api_endpoint = "foo"
379+
options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint)
380+
api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options)
381+
assert api_endpoint == mock_api_endpoint
382+
assert cert_source == mock_client_cert_source
383+
384+
# Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false".
385+
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}):
386+
mock_client_cert_source = mock.Mock()
387+
mock_api_endpoint = "foo"
388+
options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint)
389+
api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options)
390+
assert api_endpoint == mock_api_endpoint
391+
assert cert_source is None
392+
393+
# Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never".
394+
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
395+
api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source()
396+
assert api_endpoint == client_class.DEFAULT_ENDPOINT
397+
assert cert_source is None
398+
399+
# Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always".
400+
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
401+
api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source()
402+
assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT
403+
assert cert_source is None
404+
405+
# Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist.
406+
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}):
407+
with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False):
408+
api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source()
409+
assert api_endpoint == client_class.DEFAULT_ENDPOINT
410+
assert cert_source is None
411+
412+
# Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists.
413+
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}):
414+
with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True):
415+
with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=mock_client_cert_source):
416+
api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source()
417+
assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT
418+
assert cert_source == mock_client_cert_source
419+
420+
362421
@pytest.mark.parametrize("client_class,transport_class,transport_name", [
363422
{% if 'grpc' in opts.transport %}
364423
({{ service.client_name }}, transports.{{ service.grpc_transport_name }}, "grpc"),

tests/integration/goldens/asset/google/cloud/asset_v1/services/asset_service/async_client.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from collections import OrderedDict
1717
import functools
1818
import re
19-
from typing import Dict, Sequence, Tuple, Type, Union
19+
from typing import Dict, Optional, Sequence, Tuple, Type, Union
2020
import pkg_resources
2121

2222
from google.api_core.client_options import ClientOptions
@@ -98,6 +98,40 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
9898

9999
from_service_account_json = from_service_account_file
100100

101+
@classmethod
102+
def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[ClientOptions] = None):
103+
"""Return the API endpoint and client cert source for mutual TLS.
104+
105+
The client cert source is determined in the following order:
106+
(1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
107+
client cert source is None.
108+
(2) if `client_options.client_cert_source` is provided, use the provided one; if the
109+
default client cert source exists, use the default one; otherwise the client cert
110+
source is None.
111+
112+
The API endpoint is determined in the following order:
113+
(1) if `client_options.api_endpoint` if provided, use the provided one.
114+
(2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the
115+
default mTLS endpoint; if the environment variabel is "never", use the default API
116+
endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
117+
use the default API endpoint.
118+
119+
More details can be found at https://google.aip.dev/auth/4114.
120+
121+
Args:
122+
client_options (google.api_core.client_options.ClientOptions): Custom options for the
123+
client. Only the `api_endpoint` and `client_cert_source` properties may be used
124+
in this method.
125+
126+
Returns:
127+
Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
128+
client cert source to use.
129+
130+
Raises:
131+
google.auth.exceptions.MutualTLSChannelError: If any errors happen.
132+
"""
133+
return AssetServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore
134+
101135
@property
102136
def transport(self) -> AssetServiceTransport:
103137
"""Returns the transport used by the client instance.

tests/integration/goldens/asset/google/cloud/asset_v1/services/asset_service/client.py

Lines changed: 60 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,65 @@ def parse_common_location_path(path: str) -> Dict[str,str]:
240240
m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
241241
return m.groupdict() if m else {}
242242

243+
@classmethod
244+
def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_options_lib.ClientOptions] = None):
245+
"""Return the API endpoint and client cert source for mutual TLS.
246+
247+
The client cert source is determined in the following order:
248+
(1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
249+
client cert source is None.
250+
(2) if `client_options.client_cert_source` is provided, use the provided one; if the
251+
default client cert source exists, use the default one; otherwise the client cert
252+
source is None.
253+
254+
The API endpoint is determined in the following order:
255+
(1) if `client_options.api_endpoint` if provided, use the provided one.
256+
(2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the
257+
default mTLS endpoint; if the environment variabel is "never", use the default API
258+
endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
259+
use the default API endpoint.
260+
261+
More details can be found at https://google.aip.dev/auth/4114.
262+
263+
Args:
264+
client_options (google.api_core.client_options.ClientOptions): Custom options for the
265+
client. Only the `api_endpoint` and `client_cert_source` properties may be used
266+
in this method.
267+
268+
Returns:
269+
Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
270+
client cert source to use.
271+
272+
Raises:
273+
google.auth.exceptions.MutualTLSChannelError: If any errors happen.
274+
"""
275+
if client_options is None:
276+
client_options = client_options_lib.ClientOptions()
277+
use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
278+
use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
279+
if use_client_cert not in ("true", "false"):
280+
raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`")
281+
if use_mtls_endpoint not in ("auto", "never", "always"):
282+
raise MutualTLSChannelError("Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`")
283+
284+
# Figure out the client cert source to use.
285+
client_cert_source = None
286+
if use_client_cert == "true":
287+
if client_options.client_cert_source:
288+
client_cert_source = client_options.client_cert_source
289+
elif mtls.has_default_client_cert_source():
290+
client_cert_source = mtls.default_client_cert_source()
291+
292+
# Figure out which api endpoint to use.
293+
if client_options.api_endpoint is not None:
294+
api_endpoint = client_options.api_endpoint
295+
elif use_mtls_endpoint == "always" or (use_mtls_endpoint == "auto" and client_cert_source):
296+
api_endpoint = cls.DEFAULT_MTLS_ENDPOINT
297+
else:
298+
api_endpoint = cls.DEFAULT_ENDPOINT
299+
300+
return api_endpoint, client_cert_source
301+
243302
def __init__(self, *,
244303
credentials: Optional[ga_credentials.Credentials] = None,
245304
transport: Union[str, AssetServiceTransport, None] = None,
@@ -288,43 +347,7 @@ def __init__(self, *,
288347
if client_options is None:
289348
client_options = client_options_lib.ClientOptions()
290349

291-
# Create SSL credentials for mutual TLS if needed.
292-
if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") not in ("true", "false"):
293-
raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`")
294-
use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true"
295-
296-
client_cert_source_func = None
297-
is_mtls = False
298-
if use_client_cert:
299-
if client_options.client_cert_source:
300-
is_mtls = True
301-
client_cert_source_func = client_options.client_cert_source
302-
else:
303-
is_mtls = mtls.has_default_client_cert_source()
304-
if is_mtls:
305-
client_cert_source_func = mtls.default_client_cert_source()
306-
else:
307-
client_cert_source_func = None
308-
309-
# Figure out which api endpoint to use.
310-
if client_options.api_endpoint is not None:
311-
api_endpoint = client_options.api_endpoint
312-
else:
313-
use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
314-
if use_mtls_env == "never":
315-
api_endpoint = self.DEFAULT_ENDPOINT
316-
elif use_mtls_env == "always":
317-
api_endpoint = self.DEFAULT_MTLS_ENDPOINT
318-
elif use_mtls_env == "auto":
319-
if is_mtls:
320-
api_endpoint = self.DEFAULT_MTLS_ENDPOINT
321-
else:
322-
api_endpoint = self.DEFAULT_ENDPOINT
323-
else:
324-
raise MutualTLSChannelError(
325-
"Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted "
326-
"values: never, auto, always"
327-
)
350+
api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(client_options)
328351

329352
# Save or instantiate the transport.
330353
# Ordinarily, we provide the transport, but allowing a custom transport

0 commit comments

Comments
 (0)