|
22 | 22 |
|
23 | 23 | from requests import PreparedRequest, Request, Response, Session
|
24 | 24 | from requests.auth import AuthBase, extract_cookies_to_jar
|
25 |
| -from requests.utils import parse_dict_header |
26 | 25 |
|
27 | 26 | import trino.logging
|
28 | 27 | from trino.client import exceptions
|
@@ -421,10 +420,13 @@ def _attempt_oauth(self, response: Response, **kwargs: Any) -> None:
|
421 | 420 | if not _OAuth2TokenBearer._BEARER_PREFIX.search(auth_info):
|
422 | 421 | raise exceptions.TrinoAuthError(f"Error: header info didn't match {auth_info}")
|
423 | 422 |
|
424 |
| - auth_info_headers = parse_dict_header( |
425 |
| - _OAuth2TokenBearer._BEARER_PREFIX.sub("", auth_info, count=1)) # type: ignore |
| 423 | + # Example www-authenticate header value: |
| 424 | + # 'Basic realm="Trino", Bearer realm="Trino", token_type="JWT", |
| 425 | + # Bearer x_redirect_server="https://trino.com/oauth2/token/uuid4", |
| 426 | + # x_token_server="https://trino.com/oauth2/token/uuid4"' |
| 427 | + auth_info_headers = self._parse_authenticate_header(auth_info) |
426 | 428 |
|
427 |
| - auth_server = auth_info_headers.get('x_redirect_server') |
| 429 | + auth_server = auth_info_headers.get('bearer x_redirect_server', auth_info_headers.get('x_redirect_server')) |
428 | 430 | token_server = auth_info_headers.get('x_token_server')
|
429 | 431 | if token_server is None:
|
430 | 432 | raise exceptions.TrinoAuthError("Error: header info didn't have x_token_server")
|
@@ -510,6 +512,21 @@ def _construct_cache_key(host: Optional[str], user: Optional[str]) -> Optional[s
|
510 | 512 | else:
|
511 | 513 | return f"{host}@{user}"
|
512 | 514 |
|
| 515 | + @staticmethod |
| 516 | + def _parse_authenticate_header(header: str) -> Dict[str, str]: |
| 517 | + split_challenge = header.split(" ", 1) |
| 518 | + trimmed_challenge = split_challenge[1] if len(split_challenge) > 1 else "" |
| 519 | + auth_info_headers = {} |
| 520 | + |
| 521 | + for item in trimmed_challenge.split(","): |
| 522 | + comps = item.split("=") |
| 523 | + if len(comps) == 2: |
| 524 | + key = comps[0].strip(' "') |
| 525 | + value = comps[1].strip(' "') |
| 526 | + if key: |
| 527 | + auth_info_headers[key.lower()] = value |
| 528 | + return auth_info_headers |
| 529 | + |
513 | 530 |
|
514 | 531 | class OAuth2Authentication(Authentication):
|
515 | 532 | def __init__(self, redirect_auth_url_handler: CompositeRedirectHandler = CompositeRedirectHandler([
|
|
0 commit comments