Skip to content

Commit 8dcb594

Browse files
sinkuladishashhar
authored andcommitted
Fix refresh tokens flow
Signed-off-by: sinkuladis <[email protected]>
1 parent bcd4039 commit 8dcb594

File tree

3 files changed

+58
-7
lines changed

3 files changed

+58
-7
lines changed

tests/unit/oauth_test_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ def __call__(self, url):
3434
self.redirect_server += url
3535

3636

37+
class RedirectHandlerWithException:
38+
def __init__(self, exception):
39+
self.exception = exception
40+
41+
def __call__(self, url):
42+
raise self.exception
43+
44+
3745
class PostStatementCallback:
3846
def __init__(self, redirect_server, token_server, tokens, sample_post_response_data):
3947
self.redirect_server = redirect_server
@@ -45,6 +53,9 @@ def __call__(self, request, uri, response_headers):
4553
authorization = request.headers.get("Authorization")
4654
if authorization and authorization.replace("Bearer ", "") in self.tokens:
4755
return [200, response_headers, json.dumps(self.sample_post_response_data)]
56+
elif self.redirect_server is None and self.token_server is not None:
57+
return [401, {'Www-Authenticate': f'Bearer x_token_server="{self.token_server}"',
58+
'Basic realm': '"Trino"'}, ""]
4859
return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{self.redirect_server}", '
4960
f'x_token_server="{self.token_server}"',
5061
'Basic realm': '"Trino"'}, ""]

tests/unit/test_client.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
MultithreadedTokenServer,
3333
PostStatementCallback,
3434
RedirectHandler,
35+
RedirectHandlerWithException,
3536
_get_token_requests,
3637
_post_statement_requests,
3738
)
@@ -384,6 +385,48 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data):
384385
assert len(_get_token_requests(challenge_id)) == attempts
385386

386387

388+
@httprettified
389+
def test_oauth2_refresh_token_flow(sample_post_response_data):
390+
token = str(uuid.uuid4())
391+
challenge_id = str(uuid.uuid4())
392+
393+
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
394+
395+
post_statement_callback = PostStatementCallback(None, token_server, [token], sample_post_response_data)
396+
397+
# bind post statement
398+
httpretty.register_uri(
399+
method=httpretty.POST,
400+
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
401+
body=post_statement_callback)
402+
403+
# bind get token
404+
get_token_callback = GetTokenCallback(token_server, token)
405+
httpretty.register_uri(
406+
method=httpretty.GET,
407+
uri=token_server,
408+
body=get_token_callback)
409+
410+
redirect_handler = RedirectHandlerWithException(
411+
trino.exceptions.TrinoAuthError(
412+
"Do not use redirect handler when there is no redirect_uri in the response"))
413+
414+
request = TrinoRequest(
415+
host="coordinator",
416+
port=constants.DEFAULT_TLS_PORT,
417+
client_session=ClientSession(
418+
user="test",
419+
),
420+
http_scheme=constants.HTTPS,
421+
auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler))
422+
423+
response = request.post("select 1")
424+
425+
assert response.request.headers['Authorization'] == f"Bearer {token}"
426+
assert get_token_callback.attempts == 0
427+
assert len(_post_statement_requests()) == 2
428+
429+
387430
@pytest.mark.parametrize("attempts", [6, 10])
388431
@httprettified
389432
def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
@@ -430,10 +473,9 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
430473

431474
@pytest.mark.parametrize("header,error", [
432475
("", "Error: header WWW-Authenticate not available in the response."),
433-
('Bearer"', 'Error: header info didn\'t have x_redirect_server'),
476+
('Bearer"', 'Error: header info didn\'t have x_token_server'),
434477
('x_redirect_server="redirect_server", x_token_server="token_server"', 'Error: header info didn\'t match x_redirect_server="redirect_server", x_token_server="token_server"'), # noqa: E501
435478
('Bearer x_redirect_server="redirect_server"', 'Error: header info didn\'t have x_token_server'),
436-
('Bearer x_token_server="token_server"', 'Error: header info didn\'t have x_redirect_server'),
437479
])
438480
@httprettified
439481
def test_oauth2_authentication_missing_headers(header, error):

trino/auth.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -322,15 +322,13 @@ def _attempt_oauth(self, response, **kwargs):
322322
auth_info_headers = parse_dict_header(_OAuth2TokenBearer._BEARER_PREFIX.sub("", auth_info, count=1))
323323

324324
auth_server = auth_info_headers.get('x_redirect_server')
325-
if auth_server is None:
326-
raise exceptions.TrinoAuthError("Error: header info didn't have x_redirect_server")
327-
328325
token_server = auth_info_headers.get('x_token_server')
329326
if token_server is None:
330327
raise exceptions.TrinoAuthError("Error: header info didn't have x_token_server")
331328

332-
# tell app that use this url to proceed with the authentication
333-
self._redirect_auth_url(auth_server)
329+
if auth_server is not None:
330+
# tell app that use this url to proceed with the authentication
331+
self._redirect_auth_url(auth_server)
334332

335333
# Consume content and release the original connection
336334
# to allow our new request to reuse the same one.

0 commit comments

Comments
 (0)