|
32 | 32 | MultithreadedTokenServer,
|
33 | 33 | PostStatementCallback,
|
34 | 34 | RedirectHandler,
|
| 35 | + RedirectHandlerWithException, |
35 | 36 | _get_token_requests,
|
36 | 37 | _post_statement_requests,
|
37 | 38 | )
|
@@ -384,6 +385,48 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data):
|
384 | 385 | assert len(_get_token_requests(challenge_id)) == attempts
|
385 | 386 |
|
386 | 387 |
|
| 388 | +@httprettified |
| 389 | +def test_oauth2_refresh_token_flow(sample_post_response_data): |
| 390 | + token = str(uuid.uuid4()) |
| 391 | + challenge_id = str(uuid.uuid4()) |
| 392 | + |
| 393 | + token_server = f"{TOKEN_RESOURCE}/{challenge_id}" |
| 394 | + |
| 395 | + post_statement_callback = PostStatementCallback(None, token_server, [token], sample_post_response_data) |
| 396 | + |
| 397 | + # bind post statement |
| 398 | + httpretty.register_uri( |
| 399 | + method=httpretty.POST, |
| 400 | + uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", |
| 401 | + body=post_statement_callback) |
| 402 | + |
| 403 | + # bind get token |
| 404 | + get_token_callback = GetTokenCallback(token_server, token) |
| 405 | + httpretty.register_uri( |
| 406 | + method=httpretty.GET, |
| 407 | + uri=token_server, |
| 408 | + body=get_token_callback) |
| 409 | + |
| 410 | + redirect_handler = RedirectHandlerWithException( |
| 411 | + trino.exceptions.TrinoAuthError( |
| 412 | + "Do not use redirect handler when there is no redirect_uri in the response")) |
| 413 | + |
| 414 | + request = TrinoRequest( |
| 415 | + host="coordinator", |
| 416 | + port=constants.DEFAULT_TLS_PORT, |
| 417 | + client_session=ClientSession( |
| 418 | + user="test", |
| 419 | + ), |
| 420 | + http_scheme=constants.HTTPS, |
| 421 | + auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler)) |
| 422 | + |
| 423 | + response = request.post("select 1") |
| 424 | + |
| 425 | + assert response.request.headers['Authorization'] == f"Bearer {token}" |
| 426 | + assert get_token_callback.attempts == 0 |
| 427 | + assert len(_post_statement_requests()) == 2 |
| 428 | + |
| 429 | + |
387 | 430 | @pytest.mark.parametrize("attempts", [6, 10])
|
388 | 431 | @httprettified
|
389 | 432 | def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
|
@@ -430,10 +473,9 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
|
430 | 473 |
|
431 | 474 | @pytest.mark.parametrize("header,error", [
|
432 | 475 | ("", "Error: header WWW-Authenticate not available in the response."),
|
433 |
| - ('Bearer"', 'Error: header info didn\'t have x_redirect_server'), |
| 476 | + ('Bearer"', 'Error: header info didn\'t have x_token_server'), |
434 | 477 | ('x_redirect_server="redirect_server", x_token_server="token_server"', 'Error: header info didn\'t match x_redirect_server="redirect_server", x_token_server="token_server"'), # noqa: E501
|
435 | 478 | ('Bearer x_redirect_server="redirect_server"', 'Error: header info didn\'t have x_token_server'),
|
436 |
| - ('Bearer x_token_server="token_server"', 'Error: header info didn\'t have x_redirect_server'), |
437 | 479 | ])
|
438 | 480 | @httprettified
|
439 | 481 | def test_oauth2_authentication_missing_headers(header, error):
|
|
0 commit comments