Skip to content

Commit 8342c69

Browse files
Ke Zhuebyhr
authored andcommitted
Add parameter extra_credential
Signed-off-by: Ke Zhu <[email protected]>
1 parent 0fb9d8f commit 8342c69

File tree

5 files changed

+116
-0
lines changed

5 files changed

+116
-0
lines changed

README.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,24 @@ The transaction is created when the first SQL statement is executed.
125125
exits the *with* context and the queries succeed, otherwise
126126
`trino.dbapi.Connection.rollback()` will be called.
127127

128+
# Extra Credential
129+
130+
Send [`extra credentials`](https://trino.io/docs/current/develop/client-protocol.html#client-request-headers):
131+
132+
```python
133+
import trino
134+
conn = trino.dbapi.connect(
135+
host='localhost',
136+
port=443,
137+
user='the-user',
138+
extra_credential=[('a.username', 'bar'), ('a.password', 'foo')],
139+
)
140+
141+
cur = conn.cursor()
142+
cur.execute('SELECT * FROM system.runtime.nodes')
143+
rows = cur.fetchall()
144+
```
145+
128146
# Development
129147

130148
## Getting Started With Development

tests/unit/test_client.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,64 @@ def test_trino_connection_error(monkeypatch, error_code, error_type, error_messa
713713
assert error_message in str(error)
714714

715715

716+
def test_extra_credential(mock_get_and_post):
717+
_, post = mock_get_and_post
718+
719+
req = TrinoRequest(
720+
host="coordinator",
721+
port=constants.DEFAULT_TLS_PORT,
722+
user="test",
723+
extra_credential=[("a.username", "foo"), ("b.password", "bar")],
724+
)
725+
726+
req.post("SELECT 1")
727+
_, post_kwargs = post.call_args
728+
headers = post_kwargs["headers"]
729+
assert constants.HEADER_EXTRA_CREDENTIAL in headers
730+
assert headers[constants.HEADER_EXTRA_CREDENTIAL] == "a.username=foo, b.password=bar"
731+
732+
733+
def test_extra_credential_key_with_illegal_chars():
734+
with pytest.raises(ValueError) as e_info:
735+
TrinoRequest(
736+
host="coordinator",
737+
port=constants.DEFAULT_TLS_PORT,
738+
user="test",
739+
extra_credential=[("a=b", "")],
740+
)
741+
742+
assert str(e_info.value) == "whitespace or '=' are disallowed in extra credential 'a=b'"
743+
744+
745+
def test_extra_credential_key_non_ascii():
746+
with pytest.raises(ValueError) as e_info:
747+
TrinoRequest(
748+
host="coordinator",
749+
port=constants.DEFAULT_TLS_PORT,
750+
user="test",
751+
extra_credential=[("的", "")],
752+
)
753+
754+
assert str(e_info.value) == "only ASCII characters are allowed in extra credential '的'"
755+
756+
757+
def test_extra_credential_value_encoding(mock_get_and_post):
758+
_, post = mock_get_and_post
759+
760+
req = TrinoRequest(
761+
host="coordinator",
762+
port=constants.DEFAULT_TLS_PORT,
763+
user="test",
764+
extra_credential=[("foo", "bar 的")],
765+
)
766+
767+
req.post("SELECT 1")
768+
_, post_kwargs = post.call_args
769+
headers = post_kwargs["headers"]
770+
assert constants.HEADER_EXTRA_CREDENTIAL in headers
771+
assert headers[constants.HEADER_EXTRA_CREDENTIAL] == "foo=bar+%E7%9A%84"
772+
773+
716774
class RetryRecorder(object):
717775
def __init__(self, error=None, result=None):
718776
self.__name__ = "RetryRecorder"

trino/client.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
import copy
3737
import os
38+
import re
3839
from typing import Any, Dict, List, Optional, Tuple, Union
3940
import urllib.parse
4041

@@ -55,6 +56,8 @@
5556
else:
5657
PROXIES = {}
5758

59+
_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r'^\S[^\s=]*$')
60+
5861

5962
class ClientSession(object):
6063
def __init__(
@@ -66,6 +69,7 @@ def __init__(
6669
properties=None,
6770
headers=None,
6871
transaction_id=None,
72+
extra_credential=None,
6973
):
7074
self.catalog = catalog
7175
self.schema = schema
@@ -76,6 +80,7 @@ def __init__(
7680
self._properties = properties
7781
self._headers = headers or {}
7882
self.transaction_id = transaction_id
83+
self.extra_credential = extra_credential
7984

8085
@property
8186
def properties(self):
@@ -153,6 +158,8 @@ class TrinoRequest(object):
153158
:param http_scheme: "http" or "https"
154159
:param auth: class that manages user authentication. ``None`` means no
155160
authentication.
161+
:param extra_credential: extra credentials. as list of ``(key, value)``
162+
tuples.
156163
:max_attempts: maximum number of attempts when sending HTTP requests. An
157164
attempt is an HTTP request. 5 attempts means 4 retries.
158165
:request_timeout: How long (in seconds) to wait for the server to send
@@ -206,6 +213,7 @@ def __init__(
206213
transaction_id: Optional[str] = NO_TRANSACTION,
207214
http_scheme: str = None,
208215
auth: Optional[Any] = constants.DEFAULT_AUTH,
216+
extra_credential: Optional[List[Tuple[str, str]]] = None,
209217
redirect_handler: Any = None,
210218
max_attempts: int = MAX_ATTEMPTS,
211219
request_timeout: Union[float, Tuple[float, float]] = constants.DEFAULT_REQUEST_TIMEOUT,
@@ -220,6 +228,7 @@ def __init__(
220228
session_properties,
221229
http_headers,
222230
transaction_id,
231+
extra_credential,
223232
)
224233

225234
self._host = host
@@ -285,6 +294,19 @@ def http_headers(self) -> Dict[str, str]:
285294
transaction_id = self._client_session.transaction_id
286295
headers[constants.HEADER_TRANSACTION] = transaction_id
287296

297+
if self._client_session.extra_credential is not None and \
298+
len(self._client_session.extra_credential) > 0:
299+
300+
for tup in self._client_session.extra_credential:
301+
self._verify_extra_credential(tup)
302+
303+
# HTTP 1.1 section 4.2 combine multiple extra credentials into a
304+
# comma-separated value
305+
# extra credential value is encoded per spec (application/x-www-form-urlencoded MIME format)
306+
headers[constants.HEADER_EXTRA_CREDENTIAL] = \
307+
", ".join(
308+
[f"{tup[0]}={urllib.parse.quote_plus(tup[1])}" for tup in self._client_session.extra_credential])
309+
288310
return headers
289311

290312
@property
@@ -427,6 +449,20 @@ def process(self, http_response) -> TrinoStatus:
427449
columns=response.get("columns"),
428450
)
429451

452+
def _verify_extra_credential(self, header):
453+
"""
454+
Verifies that key has ASCII only and non-whitespace characters.
455+
"""
456+
key = header[0]
457+
458+
if not _HEADER_EXTRA_CREDENTIAL_KEY_REGEX.match(key):
459+
raise ValueError(f"whitespace or '=' are disallowed in extra credential '{key}'")
460+
461+
try:
462+
key.encode().decode('ascii')
463+
except UnicodeDecodeError:
464+
raise ValueError(f"only ASCII characters are allowed in extra credential '{key}'")
465+
430466

431467
class TrinoResult(object):
432468
"""

trino/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
HEADER_SOURCE = "X-Trino-Source"
3333
HEADER_USER = "X-Trino-User"
3434
HEADER_CLIENT_INFO = "X-Trino-Client-Info"
35+
HEADER_EXTRA_CREDENTIAL = "X-Trino-Extra-Credential"
3536

3637
HEADER_SESSION = "X-Trino-Session"
3738
HEADER_SET_SESSION = "X-Trino-Set-Session"

trino/dbapi.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def __init__(
102102
http_headers=None,
103103
http_scheme=constants.HTTP,
104104
auth=constants.DEFAULT_AUTH,
105+
extra_credential=None,
105106
redirect_handler=None,
106107
max_attempts=constants.DEFAULT_MAX_ATTEMPTS,
107108
request_timeout=constants.DEFAULT_REQUEST_TIMEOUT,
@@ -125,6 +126,7 @@ def __init__(
125126
self.http_headers = http_headers
126127
self.http_scheme = http_scheme
127128
self.auth = auth
129+
self.extra_credential = extra_credential
128130
self.redirect_handler = redirect_handler
129131
self.max_attempts = max_attempts
130132
self.request_timeout = request_timeout
@@ -188,6 +190,7 @@ def _create_request(self):
188190
NO_TRANSACTION,
189191
self.http_scheme,
190192
self.auth,
193+
self.extra_credential,
191194
self.redirect_handler,
192195
self.max_attempts,
193196
self.request_timeout,

0 commit comments

Comments
 (0)