Skip to content

Commit 560556b

Browse files
hovaescohashhar
authored andcommitted
Shard long JWT token on Windows
1 parent d4bf4b8 commit 560556b

File tree

5 files changed

+120
-7
lines changed

5 files changed

+120
-7
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
"pre-commit",
4747
"black",
4848
"isort",
49+
"keyring"
4950
]
5051

5152
setup(

tests/unit/test_client.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616
import uuid
1717
from contextlib import nullcontext as does_not_raise
1818
from typing import Any, Dict, Optional
19-
from unittest import mock
19+
from unittest import TestCase, mock
2020
from urllib.parse import urlparse
2121

2222
import gssapi
2323
import httpretty
24+
import keyring
2425
import pytest
2526
import requests
2627
from httpretty import httprettified
@@ -42,7 +43,12 @@
4243
_post_statement_requests,
4344
)
4445
from trino import __version__, constants
45-
from trino.auth import GSSAPIAuthentication, KerberosAuthentication, _OAuth2TokenBearer
46+
from trino.auth import (
47+
GSSAPIAuthentication,
48+
KerberosAuthentication,
49+
_OAuth2KeyRingTokenCache,
50+
_OAuth2TokenBearer,
51+
)
4652
from trino.client import (
4753
ClientSession,
4854
TrinoQuery,
@@ -1343,3 +1349,76 @@ def test_request_with_invalid_timezone(mock_get_and_post):
13431349
),
13441350
)
13451351
assert str(zinfo_error.value).startswith("'No time zone found with key")
1352+
1353+
1354+
class TestShardedPassword(TestCase):
1355+
def test_store_short_password(self):
1356+
# set the keyring to mock class
1357+
keyring.set_keyring(MockKeyring())
1358+
1359+
host = "trino.com"
1360+
short_password = "x" * 10
1361+
1362+
cache = _OAuth2KeyRingTokenCache()
1363+
cache.store_token_to_cache(host, short_password)
1364+
1365+
retrieved_password = cache.get_token_from_cache(host)
1366+
self.assertEqual(short_password, retrieved_password)
1367+
1368+
def test_store_long_password(self):
1369+
# set the keyring to mock class
1370+
keyring.set_keyring(MockKeyring())
1371+
1372+
host = "trino.com"
1373+
long_password = "x" * 3000
1374+
1375+
cache = _OAuth2KeyRingTokenCache()
1376+
cache.store_token_to_cache(host, long_password)
1377+
1378+
retrieved_password = cache.get_token_from_cache(host)
1379+
self.assertEqual(long_password, retrieved_password)
1380+
1381+
1382+
class MockKeyring(keyring.backend.KeyringBackend):
1383+
def __init__(self):
1384+
self.file_location = self._generate_test_root_dir()
1385+
1386+
@staticmethod
1387+
def _generate_test_root_dir():
1388+
import tempfile
1389+
1390+
return tempfile.mkdtemp(prefix="trino-python-client-unit-test-")
1391+
1392+
def file_path(self, servicename, username):
1393+
from os.path import join
1394+
1395+
file_location = self.file_location
1396+
file_name = f"{servicename}_{username}.txt"
1397+
return join(file_location, file_name)
1398+
1399+
def set_password(self, servicename, username, password):
1400+
file_path = self.file_path(servicename, username)
1401+
1402+
with open(file_path, "w") as file:
1403+
file.write(password)
1404+
1405+
def get_password(self, servicename, username):
1406+
import os
1407+
1408+
file_path = self.file_path(servicename, username)
1409+
if not os.path.exists(file_path):
1410+
return None
1411+
1412+
with open(file_path, "r") as file:
1413+
password = file.read()
1414+
1415+
return password
1416+
1417+
def delete_password(self, servicename, username):
1418+
import os
1419+
1420+
file_path = self.file_path(servicename, username)
1421+
if not os.path.exists(file_path):
1422+
return None
1423+
1424+
os.remove(file_path)

tests/unit/test_dbapi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sampl
119119
conn2.cursor().execute("SELECT 2")
120120
conn2.cursor().execute("SELECT 3")
121121

122-
assert len(_get_token_requests(challenge_id)) == 2
122+
assert len(_get_token_requests(challenge_id)) == 1
123123

124124

125125
@httprettified

trino/auth.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
import trino.logging
2727
from trino.client import exceptions
28-
from trino.constants import HEADER_USER
28+
from trino.constants import HEADER_USER, MAX_NT_PASSWORD_SIZE
2929

3030
logger = trino.logging.get_logger(__name__)
3131

@@ -347,17 +347,49 @@ def is_keyring_available(self) -> bool:
347347
and not isinstance(self._keyring.get_keyring(), self._keyring.backends.fail.Keyring)
348348

349349
def get_token_from_cache(self, key: Optional[str]) -> Optional[str]:
350+
password = self._keyring.get_password(key, "token")
351+
350352
try:
351-
return self._keyring.get_password(key, "token")
353+
password_as_dict = json.loads(str(password))
354+
if password_as_dict.get("sharded_password"):
355+
# if password was stored shared, reconstruct it
356+
shard_count = int(password_as_dict.get("shard_count"))
357+
358+
password = ""
359+
for i in range(shard_count):
360+
password += str(self._keyring.get_password(key, f"token__{i}"))
361+
352362
except self._keyring.errors.NoKeyringError as e:
353363
raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been "
354364
"detected, check https://pypi.org/project/keyring/ for more "
355365
"information.") from e
366+
except ValueError:
367+
pass
368+
369+
return password
356370

357371
def store_token_to_cache(self, key: Optional[str], token: str) -> None:
372+
# keyring is installed, so we can store the token for reuse within multiple threads
358373
try:
359-
# keyring is installed, so we can store the token for reuse within multiple threads
360-
self._keyring.set_password(key, "token", token)
374+
# if not Windows or "small" password, stick to the default
375+
if os.name != "nt" or len(token) < MAX_NT_PASSWORD_SIZE:
376+
self._keyring.set_password(key, "token", token)
377+
else:
378+
logger.debug(f"password is {len(token)} characters, sharding it.")
379+
380+
password_shards = [
381+
token[i: i + MAX_NT_PASSWORD_SIZE] for i in range(0, len(token), MAX_NT_PASSWORD_SIZE)
382+
]
383+
shard_info = {
384+
"sharded_password": True,
385+
"shard_count": len(password_shards),
386+
}
387+
388+
# store the "shard info" as the "base" password
389+
self._keyring.set_password(key, "token", json.dumps(shard_info))
390+
# then store all shards with the shard number as postfix
391+
for i, s in enumerate(password_shards):
392+
self._keyring.set_password(key, f"token__{i}", s)
361393
except self._keyring.errors.NoKeyringError as e:
362394
raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been "
363395
"detected, check https://pypi.org/project/keyring/ for more "

trino/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
DEFAULT_AUTH: Optional[Any] = None
2121
DEFAULT_MAX_ATTEMPTS = 3
2222
DEFAULT_REQUEST_TIMEOUT: float = 30.0
23+
MAX_NT_PASSWORD_SIZE: int = 1280
2324

2425
HTTP = "http"
2526
HTTPS = "https"

0 commit comments

Comments
 (0)