Skip to content

Commit 856d8e9

Browse files
huw0hashhar
authored andcommitted
Support object as value in extra_credential
1 parent 670d5f7 commit 856d8e9

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

tests/unit/test_client.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -867,6 +867,41 @@ def test_extra_credential_value_encoding(mock_get_and_post):
867867
assert headers[constants.HEADER_EXTRA_CREDENTIAL] == "foo=bar+%E7%9A%84"
868868

869869

870+
def test_extra_credential_value_object(mock_get_and_post):
871+
_, post = mock_get_and_post
872+
873+
class TestCredential(object):
874+
value = "initial"
875+
876+
def __str__(self):
877+
return self.value
878+
879+
credential = TestCredential()
880+
881+
req = TrinoRequest(
882+
host="coordinator",
883+
port=constants.DEFAULT_TLS_PORT,
884+
client_session=ClientSession(
885+
user="test",
886+
extra_credential=[("foo", credential)]
887+
)
888+
)
889+
890+
req.post("SELECT 1")
891+
_, post_kwargs = post.call_args
892+
headers = post_kwargs["headers"]
893+
assert constants.HEADER_EXTRA_CREDENTIAL in headers
894+
assert headers[constants.HEADER_EXTRA_CREDENTIAL] == "foo=initial"
895+
896+
# Make a second request, assert that credential has changed
897+
credential.value = "changed"
898+
req.post("SELECT 1")
899+
_, post_kwargs = post.call_args
900+
headers = post_kwargs["headers"]
901+
assert constants.HEADER_EXTRA_CREDENTIAL in headers
902+
assert headers[constants.HEADER_EXTRA_CREDENTIAL] == "foo=changed"
903+
904+
870905
class MockGssapiCredentials:
871906
def __init__(self, name: gssapi.Name, usage: str):
872907
self.name = name

trino/client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,8 @@ def http_headers(self) -> Dict[str, str]:
486486
# extra credential value is encoded per spec (application/x-www-form-urlencoded MIME format)
487487
headers[constants.HEADER_EXTRA_CREDENTIAL] = \
488488
", ".join(
489-
[f"{tup[0]}={urllib.parse.quote_plus(tup[1])}" for tup in self._client_session.extra_credential])
489+
[f"{tup[0]}={urllib.parse.quote_plus(str(tup[1]))}"
490+
for tup in self._client_session.extra_credential])
490491

491492
return headers
492493

0 commit comments

Comments
 (0)