Skip to content

Commit d8b3b68

Browse files
hovaescohashhar
authored andcommitted
Cache OAuth access token per host and user pair
1 parent e90ba3a commit d8b3b68

File tree

2 files changed

+30
-8
lines changed

2 files changed

+30
-8
lines changed

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,10 @@ the [OAuth2 authentication type](https://trino.io/docs/current/security/oauth2.h
222222

223223
A callback to handle the redirect url can be provided via param `redirect_auth_url_handler` of the `trino.auth.OAuth2Authentication` class. By default, it will try to launch a web browser (`trino.auth.WebBrowserRedirectHandler`) to go through the authentication flow and output the redirect url to stdout (`trino.auth.ConsoleRedirectHandler`). Multiple redirect handlers are combined using the `trino.auth.CompositeRedirectHandler` class.
224224

225-
The OAuth2 token will be cached either per `trino.auth.OAuth2Authentication` instance or, when keyring is installed, it will be cached within a secure backend (MacOS keychain, Windows credential locker, etc) under a key including host of the Trino connection. Keyring can be installed using `pip install 'trino[external-authentication-token-cache]'`.
225+
The OAuth2 token will be cached either per `trino.auth.OAuth2Authentication` instance and username or, when keyring is installed, it will be cached within a secure backend (MacOS keychain, Windows credential locker, etc) under a key including host of the Trino connection. Keyring can be installed using `pip install 'trino[external-authentication-token-cache]'`.
226+
227+
> [!WARNING]
228+
> If username is not specified then the OAuth2 token cache is shared and stored per host.
226229

227230
- DBAPI
228231

trino/auth.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import re
1818
import threading
1919
import webbrowser
20-
from typing import Any, Callable, Dict, List, Optional, Tuple
20+
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
2121
from urllib.parse import urlparse
2222

2323
from requests import PreparedRequest, Request, Response, Session
@@ -26,6 +26,7 @@
2626

2727
import trino.logging
2828
from trino.client import exceptions
29+
from trino.constants import HEADER_USER
2930

3031
logger = trino.logging.get_logger(__name__)
3132

@@ -218,7 +219,8 @@ def store_token_to_cache(self, key: Optional[str], token: str) -> None:
218219

219220
class _OAuth2TokenInMemoryCache(_OAuth2TokenCache):
220221
"""
221-
In-memory token cache implementation. The token is stored per host, so multiple clients can share the same cache.
222+
Multiple clients can share the same cache only if each connection explicitly specifies
223+
a user otherwise the first cached token will be used to authenticate all other users.
222224
"""
223225

224226
def __init__(self) -> None:
@@ -233,7 +235,7 @@ def store_token_to_cache(self, key: Optional[str], token: str) -> None:
233235

234236
class _OAuth2KeyRingTokenCache(_OAuth2TokenCache):
235237
"""
236-
Keyring Token Cache implementation
238+
Keyring token cache implementation
237239
"""
238240

239241
def __init__(self) -> None:
@@ -268,7 +270,7 @@ def store_token_to_cache(self, key: Optional[str], token: str) -> None:
268270

269271
class _OAuth2TokenBearer(AuthBase):
270272
"""
271-
Custom implementation of Trino Oauth2 based authorization to get the token
273+
Custom implementation of Trino OAuth2 based authentication to get the token
272274
"""
273275
MAX_OAUTH_ATTEMPTS = 5
274276
_BEARER_PREFIX = re.compile(r"bearer", flags=re.IGNORECASE)
@@ -283,7 +285,9 @@ def __init__(self, redirect_auth_url_handler: Callable[[str], None]):
283285

284286
def __call__(self, r: PreparedRequest) -> PreparedRequest:
285287
host = self._determine_host(r.url)
286-
token = self._get_token_from_cache(host)
288+
user = self._determine_user(r.headers)
289+
key = self._construct_cache_key(host, user)
290+
token = self._get_token_from_cache(key)
287291

288292
if token is not None:
289293
r.headers['Authorization'] = "Bearer " + token
@@ -341,15 +345,19 @@ def _attempt_oauth(self, response: Response, **kwargs: Any) -> None:
341345

342346
request = response.request
343347
host = self._determine_host(request.url)
344-
self._store_token_to_cache(host, token)
348+
user = self._determine_user(request.headers)
349+
key = self._construct_cache_key(host, user)
350+
self._store_token_to_cache(key, token)
345351

346352
def _retry_request(self, response: Response, **kwargs: Any) -> Optional[Response]:
347353
request = response.request.copy()
348354
extract_cookies_to_jar(request._cookies, response.request, response.raw) # type: ignore
349355
request.prepare_cookies(request._cookies) # type: ignore
350356

351357
host = self._determine_host(response.request.url)
352-
token = self._get_token_from_cache(host)
358+
user = self._determine_user(request.headers)
359+
key = self._construct_cache_key(host, user)
360+
token = self._get_token_from_cache(key)
353361
if token is not None:
354362
request.headers['Authorization'] = "Bearer " + token
355363
retry_response = response.connection.send(request, **kwargs) # type: ignore
@@ -394,6 +402,17 @@ def _store_token_to_cache(self, key: Optional[str], token: str) -> None:
394402
def _determine_host(url: Optional[str]) -> Any:
395403
return urlparse(url).hostname
396404

405+
@staticmethod
406+
def _determine_user(headers: Mapping[Any, Any]) -> Optional[Any]:
407+
return headers.get(HEADER_USER)
408+
409+
@staticmethod
410+
def _construct_cache_key(host: Optional[str], user: Optional[str]) -> Optional[str]:
411+
if user is None:
412+
return host
413+
else:
414+
return f"{host}@{user}"
415+
397416

398417
class OAuth2Authentication(Authentication):
399418
def __init__(self, redirect_auth_url_handler: CompositeRedirectHandler = CompositeRedirectHandler([

0 commit comments

Comments
 (0)