Skip to content

Commit 70d0546

Browse files
mdesmetabrisan
authored andcommitted
Add support for client tags in Python API
Co-authored-by: abrisan <[email protected]>
1 parent ee8f1a5 commit 70d0546

File tree

6 files changed

+111
-2
lines changed

6 files changed

+111
-2
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import pytest
1717
import pytz
18+
import requests
1819

1920
import trino
2021
from tests.integration.conftest import trino_version
@@ -867,3 +868,46 @@ def test_info_uri(trino_connection):
867868
cur.fetchall()
868869
assert cur.info_uri is not None
869870
assert cur._query.query_id in cur.info_uri
871+
872+
873+
def test_client_tags_single_tag(run_trino):
874+
client_tags = ["foo"]
875+
query_client_tags = retrieve_client_tags_from_query(run_trino, client_tags)
876+
assert query_client_tags == client_tags
877+
878+
879+
def test_client_tags_multiple_tags(run_trino):
880+
client_tags = ["foo", "bar"]
881+
query_client_tags = retrieve_client_tags_from_query(run_trino, client_tags)
882+
assert query_client_tags == client_tags
883+
884+
885+
def test_client_tags_special_characters(run_trino):
886+
client_tags = ["foo %20", "bar=test"]
887+
query_client_tags = retrieve_client_tags_from_query(run_trino, client_tags)
888+
assert query_client_tags == client_tags
889+
890+
891+
def retrieve_client_tags_from_query(run_trino, client_tags):
892+
_, host, port = run_trino
893+
894+
trino_connection = trino.dbapi.Connection(
895+
host=host,
896+
port=port,
897+
user="test",
898+
client_tags=client_tags,
899+
)
900+
901+
cur = trino_connection.cursor()
902+
cur.execute('SELECT 1')
903+
cur.fetchall()
904+
905+
api_url = "http://" + trino_connection.host + ":" + str(trino_connection.port)
906+
query_info = requests.post(api_url + "/ui/login", data={
907+
"username": "admin",
908+
"password": "",
909+
"redirectPath": api_url + '/ui/api/query/' + cur._query.query_id
910+
}).json()
911+
912+
query_client_tags = query_info['session']['clientTags']
913+
return query_client_tags

tests/unit/test_client.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,49 @@ def test_request_invalid_http_headers():
172172
assert str(value_error.value).startswith("cannot override reserved HTTP header")
173173

174174

175+
def test_request_client_tags_headers(mock_get_and_post):
176+
get, post = mock_get_and_post
177+
178+
req = TrinoRequest(
179+
host="coordinator",
180+
port=8080,
181+
user="test_user",
182+
client_tags=["tag1", "tag2"]
183+
)
184+
185+
def assert_headers(headers):
186+
assert headers[constants.HEADER_CLIENT_TAGS] == "tag1,tag2"
187+
188+
req.post("URL")
189+
_, post_kwargs = post.call_args
190+
assert_headers(post_kwargs["headers"])
191+
192+
req.get("URL")
193+
_, get_kwargs = get.call_args
194+
assert_headers(get_kwargs["headers"])
195+
196+
197+
def test_request_client_tags_headers_no_client_tags(mock_get_and_post):
198+
get, post = mock_get_and_post
199+
200+
req = TrinoRequest(
201+
host="coordinator",
202+
port=8080,
203+
user="test_user"
204+
)
205+
206+
def assert_headers(headers):
207+
assert constants.HEADER_CLIENT_TAGS not in headers
208+
209+
req.post("URL")
210+
_, post_kwargs = post.call_args
211+
assert_headers(post_kwargs["headers"])
212+
213+
req.get("URL")
214+
_, get_kwargs = get.call_args
215+
assert_headers(get_kwargs["headers"])
216+
217+
175218
def test_enabling_https_automatically_when_using_port_443(mock_get_and_post):
176219
_, post = mock_get_and_post
177220

tests/unit/test_dbapi.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,3 +222,15 @@ def run(self) -> None:
222222
thread.join()
223223

224224
assert len(_get_token_requests(challenge_id)) == 1
225+
226+
227+
@patch("trino.dbapi.trino.client")
228+
def test_tags_are_set_when_specified(mock_client):
229+
# WHEN
230+
client_tags = ["TAG1", "TAG2"]
231+
with connect("sample_trino_cluster:443", client_tags=client_tags) as conn:
232+
conn.cursor().execute("SOME FAKE QUERY")
233+
234+
# THEN
235+
_, passed_client_tags = mock_client.TrinoRequest.call_args
236+
assert passed_client_tags["client_tags"] == client_tags

trino/client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __init__(
7777
headers=None,
7878
transaction_id=None,
7979
extra_credential=None,
80+
client_tags=None
8081
):
8182
self.catalog = catalog
8283
self.schema = schema
@@ -88,6 +89,7 @@ def __init__(
8889
self._headers = headers or {}
8990
self.transaction_id = transaction_id
9091
self.extra_credential = extra_credential
92+
self.client_tags = client_tags
9193

9294
@property
9395
def properties(self):
@@ -226,7 +228,8 @@ def __init__(
226228
max_attempts: int = MAX_ATTEMPTS,
227229
request_timeout: Union[float, Tuple[float, float]] = constants.DEFAULT_REQUEST_TIMEOUT,
228230
handle_retry=exceptions.RetryWithExponentialBackoff(),
229-
verify: bool = True
231+
verify: bool = True,
232+
client_tags: Optional[List[str]] = None
230233
) -> None:
231234
self._client_session = ClientSession(
232235
catalog,
@@ -237,6 +240,7 @@ def __init__(
237240
http_headers,
238241
transaction_id,
239242
extra_credential,
243+
client_tags
240244
)
241245

242246
self._host = host
@@ -286,6 +290,8 @@ def http_headers(self) -> Dict[str, str]:
286290
headers[constants.HEADER_SCHEMA] = self._client_session.schema
287291
headers[constants.HEADER_SOURCE] = self._client_session.source
288292
headers[constants.HEADER_USER] = self._client_session.user
293+
if self._client_session.client_tags is not None and len(self._client_session.client_tags) > 0:
294+
headers[constants.HEADER_CLIENT_TAGS] = ",".join(self._client_session.client_tags)
289295

290296
headers[constants.HEADER_SESSION] = ",".join(
291297
# ``name`` must not contain ``=``

trino/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
HEADER_SOURCE = "X-Trino-Source"
3333
HEADER_USER = "X-Trino-User"
3434
HEADER_CLIENT_INFO = "X-Trino-Client-Info"
35+
HEADER_CLIENT_TAGS = "X-Trino-Client-Tags"
3536
HEADER_EXTRA_CREDENTIAL = "X-Trino-Extra-Credential"
3637

3738
HEADER_SESSION = "X-Trino-Session"

trino/dbapi.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ def __init__(
108108
request_timeout=constants.DEFAULT_REQUEST_TIMEOUT,
109109
isolation_level=IsolationLevel.AUTOCOMMIT,
110110
verify=True,
111-
http_session=None
111+
http_session=None,
112+
client_tags=None
112113
):
113114
self.host = host
114115
self.port = port
@@ -130,6 +131,7 @@ def __init__(
130131
self.redirect_handler = redirect_handler
131132
self.max_attempts = max_attempts
132133
self.request_timeout = request_timeout
134+
self.client_tags = client_tags
133135

134136
self._isolation_level = isolation_level
135137
self._request = None
@@ -194,6 +196,7 @@ def _create_request(self):
194196
self.redirect_handler,
195197
self.max_attempts,
196198
self.request_timeout,
199+
client_tags=self.client_tags
197200
)
198201

199202
def cursor(self, experimental_python_types=False):

0 commit comments

Comments
 (0)