|
18 | 18 | from .mock_utils import mock_connection |
19 | 19 |
|
20 | 20 | try: # pragma: no cover |
21 | | - from snowflake.connector.auth import Auth, AuthByDefault, AuthByPlugin |
| 21 | + from snowflake.connector.auth import ( |
| 22 | + Auth, |
| 23 | + AuthByDefault, |
| 24 | + AuthByOAuth, |
| 25 | + AuthByOauthCode, |
| 26 | + AuthByOauthCredentials, |
| 27 | + AuthByPlugin, |
| 28 | + ) |
22 | 29 | except ImportError: |
23 | 30 | from snowflake.connector.auth import Auth |
24 | 31 | from snowflake.connector.auth_by_plugin import AuthByPlugin |
25 | 32 | from snowflake.connector.auth_default import AuthByDefault |
| 33 | + from snowflake.connector.auth_oauth import AuthByOAuth |
| 34 | + from snowflake.connector.auth_oauth_code import AuthByOauthCode |
| 35 | + from snowflake.connector.auth_oauth_credentials import AuthByOauthCredentials |
| 36 | + |
| 37 | +from snowflake.connector.errors import DatabaseError |
| 38 | +from snowflake.connector.network import ReauthenticationRequest |
26 | 39 |
|
27 | 40 |
|
28 | 41 | def _init_rest(application, post_requset): |
@@ -354,3 +367,65 @@ def test_auth_by_default_prepare_body_does_not_overwrite_client_environment_fiel |
354 | 367 | for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] |
355 | 368 | ] |
356 | 369 | ) |
| 370 | + |
| 371 | + |
| 372 | +def _mock_oauth_token_expired_rest_response(url, headers, body, **kwargs): |
| 373 | + """Mock rest response for OAuth access token expired error.""" |
| 374 | + from snowflake.connector.network import OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE |
| 375 | + |
| 376 | + return { |
| 377 | + "success": False, |
| 378 | + "message": "OAuth access token expired", |
| 379 | + "code": OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE, |
| 380 | + "data": {}, |
| 381 | + } |
| 382 | + |
| 383 | + |
| 384 | +@pytest.mark.skipolddriver |
| 385 | +@pytest.mark.parametrize( |
| 386 | + "auth_instance, expected_exc_type", |
| 387 | + [ |
| 388 | + (AuthByOAuth("test_oauth_token"), DatabaseError), |
| 389 | + ( |
| 390 | + AuthByOauthCode( |
| 391 | + application="testapp", |
| 392 | + client_id="test_client_id", |
| 393 | + client_secret="test_client_secret", |
| 394 | + authentication_url="https://auth.example.com", |
| 395 | + token_request_url="https://token.example.com", |
| 396 | + redirect_uri="http://localhost:8080", |
| 397 | + scope="session:role-any", |
| 398 | + host="testaccount.snowflakecomputing.com", |
| 399 | + ), |
| 400 | + ReauthenticationRequest, |
| 401 | + ), |
| 402 | + ( |
| 403 | + AuthByOauthCredentials( |
| 404 | + application="testapp", |
| 405 | + client_id="test_client_id", |
| 406 | + client_secret="test_client_secret", |
| 407 | + token_request_url="https://token.example.com", |
| 408 | + scope="session:role-any", |
| 409 | + ), |
| 410 | + ReauthenticationRequest, |
| 411 | + ), |
| 412 | + ], |
| 413 | +) |
| 414 | +def test_oauth_token_expired_error_handling(auth_instance, expected_exc_type): |
| 415 | + """Test that OAuth authenticators handle token expiry errors differently. |
| 416 | +
|
| 417 | + - AuthByOAuth should raise DatabaseError (falls through to general error handling) |
| 418 | + - AuthByOauthCode and AuthByOauthCredentials should raise ProgrammingError (via ReauthenticationRequest) |
| 419 | + """ |
| 420 | + |
| 421 | + def mock_errorhandler_always_raise(connection, cursor, error_class, error_value): |
| 422 | + raise error_class(**error_value) |
| 423 | + |
| 424 | + application = "testapplication" |
| 425 | + account = "testaccount" |
| 426 | + user = "testuser" |
| 427 | + rest = _init_rest(application, _mock_oauth_token_expired_rest_response) |
| 428 | + rest._connection.errorhandler = mock_errorhandler_always_raise |
| 429 | + auth = Auth(rest) |
| 430 | + with pytest.raises(expected_exc_type): |
| 431 | + auth.authenticate(auth_instance, account, user) |
0 commit comments