diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 82e19a0f..a522cde0 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -543,13 +543,14 @@ def test_oauth2_authentication_missing_headers(header, error): 'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", Basic realm="Trino"', 'Basic realm="Trino", Bearer realm="Trino", token_type="JWT", Bearer x_redirect_server="{redirect_server}", ' 'x_token_server="{token_server}"' + 'Bearer x_redirect_server="{redirect_server}",x_token_server="{token_server}",additional_challenge', ]) @httprettified def test_oauth2_header_parsing(header, sample_post_response_data): token = str(uuid.uuid4()) challenge_id = str(uuid.uuid4()) - redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}" + redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}?role=test" token_server = f"{TOKEN_RESOURCE}/{challenge_id}" # noinspection PyUnusedLocal diff --git a/trino/auth.py b/trino/auth.py index 8a24ecd7..3cad6c50 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -546,17 +546,17 @@ def _construct_cache_key(host: Optional[str], user: Optional[str]) -> Optional[s @staticmethod def _parse_authenticate_header(header: str) -> Dict[str, str]: - split_challenge = header.split(" ", 1) - trimmed_challenge = split_challenge[1] if len(split_challenge) > 1 else "" + logger.debug(f"Authentication header: {header}") + components = header.split(",") auth_info_headers = {} - for item in trimmed_challenge.split(","): - comps = item.split("=") - if len(comps) == 2: - key = comps[0].strip(' "') - value = comps[1].strip(' "') - if key: - auth_info_headers[key.lower()] = value + for component in components: + component = component.strip() + if "=" in component: + key, value = component.split("=", 1) + if value[0] == '"' and value[-1] == '"': + value = value[1:-1] + auth_info_headers[key.lower()] = value return auth_info_headers