From 5146b17c765b8a46e5588cf179dc01d662fa4bdc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Paw=C5=82owski?= Date: Sat, 9 Aug 2025 16:17:38 +0200 Subject: [PATCH 1/5] [Async] Apply #2364 to async code --- test/integ/{aio => aio_it}/__init__.py | 0 test/integ/{aio => aio_it}/conftest.py | 0 test/integ/{aio => aio_it}/lambda/__init__.py | 0 test/integ/{aio => aio_it}/lambda/test_basic_query_async.py | 0 test/integ/{aio => aio_it}/pandas/__init__.py | 0 .../{aio => aio_it}/pandas/test_arrow_chunk_iterator_async.py | 0 test/integ/{aio => aio_it}/pandas/test_arrow_pandas_async.py | 0 test/integ/{aio => aio_it}/pandas/test_logging_async.py | 0 test/integ/{aio => aio_it}/sso/__init__.py | 0 test/integ/{aio => aio_it}/sso/test_connection_manual_async.py | 0 test/integ/{aio => aio_it}/sso/test_unit_mfa_cache_async.py | 0 test/integ/{aio => aio_it}/test_arrow_result_async.py | 0 test/integ/{aio => aio_it}/test_async_async.py | 0 test/integ/{aio => aio_it}/test_autocommit_async.py | 0 test/integ/{aio => aio_it}/test_bindings_async.py | 0 test/integ/{aio => aio_it}/test_boolean_async.py | 0 .../{aio => aio_it}/test_client_session_keep_alive_async.py | 0 .../{aio => aio_it}/test_concurrent_create_objects_async.py | 0 test/integ/{aio => aio_it}/test_concurrent_insert_async.py | 0 test/integ/{aio => aio_it}/test_connection_async.py | 2 +- test/integ/{aio => aio_it}/test_converter_async.py | 0 .../{aio => aio_it}/test_converter_more_timestamp_async.py | 0 test/integ/{aio => aio_it}/test_converter_null_async.py | 0 test/integ/{aio => aio_it}/test_cursor_async.py | 0 test/integ/{aio => aio_it}/test_cursor_binding_async.py | 0 test/integ/{aio => aio_it}/test_cursor_context_manager_async.py | 0 test/integ/{aio => aio_it}/test_dataintegrity_async.py | 0 test/integ/{aio => aio_it}/test_daylight_savings_async.py | 0 test/integ/{aio => aio_it}/test_dbapi_async.py | 0 test/integ/{aio => aio_it}/test_decfloat_async.py | 0 .../{aio => aio_it}/test_direct_file_operation_utils_async.py | 0 test/integ/{aio => aio_it}/test_errors_async.py | 0 .../{aio => aio_it}/test_execute_multi_statements_async.py | 0 .../integ/{aio => aio_it}/test_key_pair_authentication_async.py | 0 test/integ/{aio => aio_it}/test_large_put_async.py | 0 test/integ/{aio => aio_it}/test_large_result_set_async.py | 0 test/integ/{aio => aio_it}/test_load_unload_async.py | 0 test/integ/{aio => aio_it}/test_multi_statement_async.py | 0 test/integ/{aio => aio_it}/test_network_async.py | 0 test/integ/{aio => aio_it}/test_numpy_binding_async.py | 0 test/integ/{aio => aio_it}/test_pickle_timestamp_tz_async.py | 0 test/integ/{aio => aio_it}/test_put_get_async.py | 0 test/integ/{aio => aio_it}/test_put_get_compress_enc_async.py | 0 test/integ/{aio => aio_it}/test_put_get_medium_async.py | 0 test/integ/{aio => aio_it}/test_put_get_snow_4525_async.py | 0 test/integ/{aio => aio_it}/test_put_get_user_stage_async.py | 0 test/integ/{aio => aio_it}/test_put_get_with_aws_token_async.py | 0 .../{aio => aio_it}/test_put_get_with_azure_token_async.py | 0 .../{aio => aio_it}/test_put_get_with_gcp_account_async.py | 0 test/integ/{aio => aio_it}/test_put_windows_path_async.py | 0 test/integ/{aio => aio_it}/test_qmark_async.py | 0 test/integ/{aio => aio_it}/test_query_cancelling_async.py | 0 test/integ/{aio => aio_it}/test_results_async.py | 0 test/integ/{aio => aio_it}/test_reuse_cursor_async.py | 0 test/integ/{aio => aio_it}/test_session_parameters_async.py | 0 .../{aio => aio_it}/test_statement_parameter_binding_async.py | 0 test/integ/{aio => aio_it}/test_structured_types_async.py | 0 test/integ/{aio => aio_it}/test_transaction_async.py | 0 58 files changed, 1 insertion(+), 1 deletion(-) rename test/integ/{aio => aio_it}/__init__.py (100%) rename test/integ/{aio => aio_it}/conftest.py (100%) rename test/integ/{aio => aio_it}/lambda/__init__.py (100%) rename test/integ/{aio => aio_it}/lambda/test_basic_query_async.py (100%) rename test/integ/{aio => aio_it}/pandas/__init__.py (100%) rename test/integ/{aio => aio_it}/pandas/test_arrow_chunk_iterator_async.py (100%) rename test/integ/{aio => aio_it}/pandas/test_arrow_pandas_async.py (100%) rename test/integ/{aio => aio_it}/pandas/test_logging_async.py (100%) rename test/integ/{aio => aio_it}/sso/__init__.py (100%) rename test/integ/{aio => aio_it}/sso/test_connection_manual_async.py (100%) rename test/integ/{aio => aio_it}/sso/test_unit_mfa_cache_async.py (100%) rename test/integ/{aio => aio_it}/test_arrow_result_async.py (100%) rename test/integ/{aio => aio_it}/test_async_async.py (100%) rename test/integ/{aio => aio_it}/test_autocommit_async.py (100%) rename test/integ/{aio => aio_it}/test_bindings_async.py (100%) rename test/integ/{aio => aio_it}/test_boolean_async.py (100%) rename test/integ/{aio => aio_it}/test_client_session_keep_alive_async.py (100%) rename test/integ/{aio => aio_it}/test_concurrent_create_objects_async.py (100%) rename test/integ/{aio => aio_it}/test_concurrent_insert_async.py (100%) rename test/integ/{aio => aio_it}/test_connection_async.py (99%) rename test/integ/{aio => aio_it}/test_converter_async.py (100%) rename test/integ/{aio => aio_it}/test_converter_more_timestamp_async.py (100%) rename test/integ/{aio => aio_it}/test_converter_null_async.py (100%) rename test/integ/{aio => aio_it}/test_cursor_async.py (100%) rename test/integ/{aio => aio_it}/test_cursor_binding_async.py (100%) rename test/integ/{aio => aio_it}/test_cursor_context_manager_async.py (100%) rename test/integ/{aio => aio_it}/test_dataintegrity_async.py (100%) rename test/integ/{aio => aio_it}/test_daylight_savings_async.py (100%) rename test/integ/{aio => aio_it}/test_dbapi_async.py (100%) rename test/integ/{aio => aio_it}/test_decfloat_async.py (100%) rename test/integ/{aio => aio_it}/test_direct_file_operation_utils_async.py (100%) rename test/integ/{aio => aio_it}/test_errors_async.py (100%) rename test/integ/{aio => aio_it}/test_execute_multi_statements_async.py (100%) rename test/integ/{aio => aio_it}/test_key_pair_authentication_async.py (100%) rename test/integ/{aio => aio_it}/test_large_put_async.py (100%) rename test/integ/{aio => aio_it}/test_large_result_set_async.py (100%) rename test/integ/{aio => aio_it}/test_load_unload_async.py (100%) rename test/integ/{aio => aio_it}/test_multi_statement_async.py (100%) rename test/integ/{aio => aio_it}/test_network_async.py (100%) rename test/integ/{aio => aio_it}/test_numpy_binding_async.py (100%) rename test/integ/{aio => aio_it}/test_pickle_timestamp_tz_async.py (100%) rename test/integ/{aio => aio_it}/test_put_get_async.py (100%) rename test/integ/{aio => aio_it}/test_put_get_compress_enc_async.py (100%) rename test/integ/{aio => aio_it}/test_put_get_medium_async.py (100%) rename test/integ/{aio => aio_it}/test_put_get_snow_4525_async.py (100%) rename test/integ/{aio => aio_it}/test_put_get_user_stage_async.py (100%) rename test/integ/{aio => aio_it}/test_put_get_with_aws_token_async.py (100%) rename test/integ/{aio => aio_it}/test_put_get_with_azure_token_async.py (100%) rename test/integ/{aio => aio_it}/test_put_get_with_gcp_account_async.py (100%) rename test/integ/{aio => aio_it}/test_put_windows_path_async.py (100%) rename test/integ/{aio => aio_it}/test_qmark_async.py (100%) rename test/integ/{aio => aio_it}/test_query_cancelling_async.py (100%) rename test/integ/{aio => aio_it}/test_results_async.py (100%) rename test/integ/{aio => aio_it}/test_reuse_cursor_async.py (100%) rename test/integ/{aio => aio_it}/test_session_parameters_async.py (100%) rename test/integ/{aio => aio_it}/test_statement_parameter_binding_async.py (100%) rename test/integ/{aio => aio_it}/test_structured_types_async.py (100%) rename test/integ/{aio => aio_it}/test_transaction_async.py (100%) diff --git a/test/integ/aio/__init__.py b/test/integ/aio_it/__init__.py similarity index 100% rename from test/integ/aio/__init__.py rename to test/integ/aio_it/__init__.py diff --git a/test/integ/aio/conftest.py b/test/integ/aio_it/conftest.py similarity index 100% rename from test/integ/aio/conftest.py rename to test/integ/aio_it/conftest.py diff --git a/test/integ/aio/lambda/__init__.py b/test/integ/aio_it/lambda/__init__.py similarity index 100% rename from test/integ/aio/lambda/__init__.py rename to test/integ/aio_it/lambda/__init__.py diff --git a/test/integ/aio/lambda/test_basic_query_async.py b/test/integ/aio_it/lambda/test_basic_query_async.py similarity index 100% rename from test/integ/aio/lambda/test_basic_query_async.py rename to test/integ/aio_it/lambda/test_basic_query_async.py diff --git a/test/integ/aio/pandas/__init__.py b/test/integ/aio_it/pandas/__init__.py similarity index 100% rename from test/integ/aio/pandas/__init__.py rename to test/integ/aio_it/pandas/__init__.py diff --git a/test/integ/aio/pandas/test_arrow_chunk_iterator_async.py b/test/integ/aio_it/pandas/test_arrow_chunk_iterator_async.py similarity index 100% rename from test/integ/aio/pandas/test_arrow_chunk_iterator_async.py rename to test/integ/aio_it/pandas/test_arrow_chunk_iterator_async.py diff --git a/test/integ/aio/pandas/test_arrow_pandas_async.py b/test/integ/aio_it/pandas/test_arrow_pandas_async.py similarity index 100% rename from test/integ/aio/pandas/test_arrow_pandas_async.py rename to test/integ/aio_it/pandas/test_arrow_pandas_async.py diff --git a/test/integ/aio/pandas/test_logging_async.py b/test/integ/aio_it/pandas/test_logging_async.py similarity index 100% rename from test/integ/aio/pandas/test_logging_async.py rename to test/integ/aio_it/pandas/test_logging_async.py diff --git a/test/integ/aio/sso/__init__.py b/test/integ/aio_it/sso/__init__.py similarity index 100% rename from test/integ/aio/sso/__init__.py rename to test/integ/aio_it/sso/__init__.py diff --git a/test/integ/aio/sso/test_connection_manual_async.py b/test/integ/aio_it/sso/test_connection_manual_async.py similarity index 100% rename from test/integ/aio/sso/test_connection_manual_async.py rename to test/integ/aio_it/sso/test_connection_manual_async.py diff --git a/test/integ/aio/sso/test_unit_mfa_cache_async.py b/test/integ/aio_it/sso/test_unit_mfa_cache_async.py similarity index 100% rename from test/integ/aio/sso/test_unit_mfa_cache_async.py rename to test/integ/aio_it/sso/test_unit_mfa_cache_async.py diff --git a/test/integ/aio/test_arrow_result_async.py b/test/integ/aio_it/test_arrow_result_async.py similarity index 100% rename from test/integ/aio/test_arrow_result_async.py rename to test/integ/aio_it/test_arrow_result_async.py diff --git a/test/integ/aio/test_async_async.py b/test/integ/aio_it/test_async_async.py similarity index 100% rename from test/integ/aio/test_async_async.py rename to test/integ/aio_it/test_async_async.py diff --git a/test/integ/aio/test_autocommit_async.py b/test/integ/aio_it/test_autocommit_async.py similarity index 100% rename from test/integ/aio/test_autocommit_async.py rename to test/integ/aio_it/test_autocommit_async.py diff --git a/test/integ/aio/test_bindings_async.py b/test/integ/aio_it/test_bindings_async.py similarity index 100% rename from test/integ/aio/test_bindings_async.py rename to test/integ/aio_it/test_bindings_async.py diff --git a/test/integ/aio/test_boolean_async.py b/test/integ/aio_it/test_boolean_async.py similarity index 100% rename from test/integ/aio/test_boolean_async.py rename to test/integ/aio_it/test_boolean_async.py diff --git a/test/integ/aio/test_client_session_keep_alive_async.py b/test/integ/aio_it/test_client_session_keep_alive_async.py similarity index 100% rename from test/integ/aio/test_client_session_keep_alive_async.py rename to test/integ/aio_it/test_client_session_keep_alive_async.py diff --git a/test/integ/aio/test_concurrent_create_objects_async.py b/test/integ/aio_it/test_concurrent_create_objects_async.py similarity index 100% rename from test/integ/aio/test_concurrent_create_objects_async.py rename to test/integ/aio_it/test_concurrent_create_objects_async.py diff --git a/test/integ/aio/test_concurrent_insert_async.py b/test/integ/aio_it/test_concurrent_insert_async.py similarity index 100% rename from test/integ/aio/test_concurrent_insert_async.py rename to test/integ/aio_it/test_concurrent_insert_async.py diff --git a/test/integ/aio/test_connection_async.py b/test/integ/aio_it/test_connection_async.py similarity index 99% rename from test/integ/aio/test_connection_async.py rename to test/integ/aio_it/test_connection_async.py index e0c771664a..df76fa1df4 100644 --- a/test/integ/aio/test_connection_async.py +++ b/test/integ/aio_it/test_connection_async.py @@ -1674,7 +1674,7 @@ async def test_is_valid(conn_cnx): async def test_no_auth_connection_negative_case(): # AuthNoAuth does not exist in old drivers, so we import at test level to # skip importing it for old driver tests. - from test.integ.aio.conftest import create_connection + from test.integ.aio_it.conftest import create_connection from snowflake.connector.aio.auth._no_auth import AuthNoAuth diff --git a/test/integ/aio/test_converter_async.py b/test/integ/aio_it/test_converter_async.py similarity index 100% rename from test/integ/aio/test_converter_async.py rename to test/integ/aio_it/test_converter_async.py diff --git a/test/integ/aio/test_converter_more_timestamp_async.py b/test/integ/aio_it/test_converter_more_timestamp_async.py similarity index 100% rename from test/integ/aio/test_converter_more_timestamp_async.py rename to test/integ/aio_it/test_converter_more_timestamp_async.py diff --git a/test/integ/aio/test_converter_null_async.py b/test/integ/aio_it/test_converter_null_async.py similarity index 100% rename from test/integ/aio/test_converter_null_async.py rename to test/integ/aio_it/test_converter_null_async.py diff --git a/test/integ/aio/test_cursor_async.py b/test/integ/aio_it/test_cursor_async.py similarity index 100% rename from test/integ/aio/test_cursor_async.py rename to test/integ/aio_it/test_cursor_async.py diff --git a/test/integ/aio/test_cursor_binding_async.py b/test/integ/aio_it/test_cursor_binding_async.py similarity index 100% rename from test/integ/aio/test_cursor_binding_async.py rename to test/integ/aio_it/test_cursor_binding_async.py diff --git a/test/integ/aio/test_cursor_context_manager_async.py b/test/integ/aio_it/test_cursor_context_manager_async.py similarity index 100% rename from test/integ/aio/test_cursor_context_manager_async.py rename to test/integ/aio_it/test_cursor_context_manager_async.py diff --git a/test/integ/aio/test_dataintegrity_async.py b/test/integ/aio_it/test_dataintegrity_async.py similarity index 100% rename from test/integ/aio/test_dataintegrity_async.py rename to test/integ/aio_it/test_dataintegrity_async.py diff --git a/test/integ/aio/test_daylight_savings_async.py b/test/integ/aio_it/test_daylight_savings_async.py similarity index 100% rename from test/integ/aio/test_daylight_savings_async.py rename to test/integ/aio_it/test_daylight_savings_async.py diff --git a/test/integ/aio/test_dbapi_async.py b/test/integ/aio_it/test_dbapi_async.py similarity index 100% rename from test/integ/aio/test_dbapi_async.py rename to test/integ/aio_it/test_dbapi_async.py diff --git a/test/integ/aio/test_decfloat_async.py b/test/integ/aio_it/test_decfloat_async.py similarity index 100% rename from test/integ/aio/test_decfloat_async.py rename to test/integ/aio_it/test_decfloat_async.py diff --git a/test/integ/aio/test_direct_file_operation_utils_async.py b/test/integ/aio_it/test_direct_file_operation_utils_async.py similarity index 100% rename from test/integ/aio/test_direct_file_operation_utils_async.py rename to test/integ/aio_it/test_direct_file_operation_utils_async.py diff --git a/test/integ/aio/test_errors_async.py b/test/integ/aio_it/test_errors_async.py similarity index 100% rename from test/integ/aio/test_errors_async.py rename to test/integ/aio_it/test_errors_async.py diff --git a/test/integ/aio/test_execute_multi_statements_async.py b/test/integ/aio_it/test_execute_multi_statements_async.py similarity index 100% rename from test/integ/aio/test_execute_multi_statements_async.py rename to test/integ/aio_it/test_execute_multi_statements_async.py diff --git a/test/integ/aio/test_key_pair_authentication_async.py b/test/integ/aio_it/test_key_pair_authentication_async.py similarity index 100% rename from test/integ/aio/test_key_pair_authentication_async.py rename to test/integ/aio_it/test_key_pair_authentication_async.py diff --git a/test/integ/aio/test_large_put_async.py b/test/integ/aio_it/test_large_put_async.py similarity index 100% rename from test/integ/aio/test_large_put_async.py rename to test/integ/aio_it/test_large_put_async.py diff --git a/test/integ/aio/test_large_result_set_async.py b/test/integ/aio_it/test_large_result_set_async.py similarity index 100% rename from test/integ/aio/test_large_result_set_async.py rename to test/integ/aio_it/test_large_result_set_async.py diff --git a/test/integ/aio/test_load_unload_async.py b/test/integ/aio_it/test_load_unload_async.py similarity index 100% rename from test/integ/aio/test_load_unload_async.py rename to test/integ/aio_it/test_load_unload_async.py diff --git a/test/integ/aio/test_multi_statement_async.py b/test/integ/aio_it/test_multi_statement_async.py similarity index 100% rename from test/integ/aio/test_multi_statement_async.py rename to test/integ/aio_it/test_multi_statement_async.py diff --git a/test/integ/aio/test_network_async.py b/test/integ/aio_it/test_network_async.py similarity index 100% rename from test/integ/aio/test_network_async.py rename to test/integ/aio_it/test_network_async.py diff --git a/test/integ/aio/test_numpy_binding_async.py b/test/integ/aio_it/test_numpy_binding_async.py similarity index 100% rename from test/integ/aio/test_numpy_binding_async.py rename to test/integ/aio_it/test_numpy_binding_async.py diff --git a/test/integ/aio/test_pickle_timestamp_tz_async.py b/test/integ/aio_it/test_pickle_timestamp_tz_async.py similarity index 100% rename from test/integ/aio/test_pickle_timestamp_tz_async.py rename to test/integ/aio_it/test_pickle_timestamp_tz_async.py diff --git a/test/integ/aio/test_put_get_async.py b/test/integ/aio_it/test_put_get_async.py similarity index 100% rename from test/integ/aio/test_put_get_async.py rename to test/integ/aio_it/test_put_get_async.py diff --git a/test/integ/aio/test_put_get_compress_enc_async.py b/test/integ/aio_it/test_put_get_compress_enc_async.py similarity index 100% rename from test/integ/aio/test_put_get_compress_enc_async.py rename to test/integ/aio_it/test_put_get_compress_enc_async.py diff --git a/test/integ/aio/test_put_get_medium_async.py b/test/integ/aio_it/test_put_get_medium_async.py similarity index 100% rename from test/integ/aio/test_put_get_medium_async.py rename to test/integ/aio_it/test_put_get_medium_async.py diff --git a/test/integ/aio/test_put_get_snow_4525_async.py b/test/integ/aio_it/test_put_get_snow_4525_async.py similarity index 100% rename from test/integ/aio/test_put_get_snow_4525_async.py rename to test/integ/aio_it/test_put_get_snow_4525_async.py diff --git a/test/integ/aio/test_put_get_user_stage_async.py b/test/integ/aio_it/test_put_get_user_stage_async.py similarity index 100% rename from test/integ/aio/test_put_get_user_stage_async.py rename to test/integ/aio_it/test_put_get_user_stage_async.py diff --git a/test/integ/aio/test_put_get_with_aws_token_async.py b/test/integ/aio_it/test_put_get_with_aws_token_async.py similarity index 100% rename from test/integ/aio/test_put_get_with_aws_token_async.py rename to test/integ/aio_it/test_put_get_with_aws_token_async.py diff --git a/test/integ/aio/test_put_get_with_azure_token_async.py b/test/integ/aio_it/test_put_get_with_azure_token_async.py similarity index 100% rename from test/integ/aio/test_put_get_with_azure_token_async.py rename to test/integ/aio_it/test_put_get_with_azure_token_async.py diff --git a/test/integ/aio/test_put_get_with_gcp_account_async.py b/test/integ/aio_it/test_put_get_with_gcp_account_async.py similarity index 100% rename from test/integ/aio/test_put_get_with_gcp_account_async.py rename to test/integ/aio_it/test_put_get_with_gcp_account_async.py diff --git a/test/integ/aio/test_put_windows_path_async.py b/test/integ/aio_it/test_put_windows_path_async.py similarity index 100% rename from test/integ/aio/test_put_windows_path_async.py rename to test/integ/aio_it/test_put_windows_path_async.py diff --git a/test/integ/aio/test_qmark_async.py b/test/integ/aio_it/test_qmark_async.py similarity index 100% rename from test/integ/aio/test_qmark_async.py rename to test/integ/aio_it/test_qmark_async.py diff --git a/test/integ/aio/test_query_cancelling_async.py b/test/integ/aio_it/test_query_cancelling_async.py similarity index 100% rename from test/integ/aio/test_query_cancelling_async.py rename to test/integ/aio_it/test_query_cancelling_async.py diff --git a/test/integ/aio/test_results_async.py b/test/integ/aio_it/test_results_async.py similarity index 100% rename from test/integ/aio/test_results_async.py rename to test/integ/aio_it/test_results_async.py diff --git a/test/integ/aio/test_reuse_cursor_async.py b/test/integ/aio_it/test_reuse_cursor_async.py similarity index 100% rename from test/integ/aio/test_reuse_cursor_async.py rename to test/integ/aio_it/test_reuse_cursor_async.py diff --git a/test/integ/aio/test_session_parameters_async.py b/test/integ/aio_it/test_session_parameters_async.py similarity index 100% rename from test/integ/aio/test_session_parameters_async.py rename to test/integ/aio_it/test_session_parameters_async.py diff --git a/test/integ/aio/test_statement_parameter_binding_async.py b/test/integ/aio_it/test_statement_parameter_binding_async.py similarity index 100% rename from test/integ/aio/test_statement_parameter_binding_async.py rename to test/integ/aio_it/test_statement_parameter_binding_async.py diff --git a/test/integ/aio/test_structured_types_async.py b/test/integ/aio_it/test_structured_types_async.py similarity index 100% rename from test/integ/aio/test_structured_types_async.py rename to test/integ/aio_it/test_structured_types_async.py diff --git a/test/integ/aio/test_transaction_async.py b/test/integ/aio_it/test_transaction_async.py similarity index 100% rename from test/integ/aio/test_transaction_async.py rename to test/integ/aio_it/test_transaction_async.py From 607db2e1facc15389a17d43ceb39736aeee3e1d6 Mon Sep 17 00:00:00 2001 From: Mark Keller Date: Mon, 14 Apr 2025 08:44:04 -0700 Subject: [PATCH 2/5] SNOW-1825495 OAuth flows implementation (#2135) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Michał Hofman Co-authored-by: Piotr Bulawa Co-authored-by: Maxim Mishchenko Co-authored-by: Mikołaj Kubik Co-authored-by: Yijun Xie Co-authored-by: Zexin Yao <103003040+sfc-gh-zyao@users.noreply.github.com> Co-authored-by: Jakub Szczerbiński Co-authored-by: Patryk Cyrek --- .../parameters_aws_auth_tests.json.gpg | Bin 0 -> 934 bytes .../private/rsa_keys/rsa_key.p8.gpg | Bin 0 -> 1401 bytes .../private/rsa_keys/rsa_key_invalid.p8.gpg | Bin 0 -> 1409 bytes DESCRIPTION.md | 1 + Jenkinsfile | 57 +- ci/container/test_authentication.sh | 24 + ci/test_authentication.sh | 27 + src/snowflake/connector/auth/__init__.py | 6 + src/snowflake/connector/auth/_auth.py | 28 +- src/snowflake/connector/auth/_http_server.py | 220 ++++++ src/snowflake/connector/auth/_oauth_base.py | 367 +++++++++ src/snowflake/connector/auth/oauth_code.py | 383 +++++++++ .../connector/auth/oauth_credentials.py | 64 ++ src/snowflake/connector/auth/webbrowser.py | 1 + src/snowflake/connector/connection.py | 176 ++++- src/snowflake/connector/constants.py | 6 +- src/snowflake/connector/errorcode.py | 9 +- src/snowflake/connector/file_lock.py | 72 ++ src/snowflake/connector/network.py | 3 + src/snowflake/connector/token_cache.py | 482 +++++++----- .../connector/vendored/requests/__init__.py | 1 - .../connector/vendored/requests/adapters.py | 1 - .../connector/vendored/requests/exceptions.py | 1 - .../connector/vendored/requests/help.py | 2 +- .../connector/vendored/requests/models.py | 1 - .../connector/vendored/requests/utils.py | 1 - test/auth/__init__.py | 0 test/auth/authorization_parameters.py | 218 ++++++ test/auth/authorization_test_helper.py | 144 ++++ test/auth/test_external_browser.py | 90 +++ test/auth/test_key_pair.py | 39 + test/auth/test_oauth.py | 59 ++ test/auth/test_okta.py | 58 ++ test/auth/test_okta_authorization_code.py | 96 +++ test/auth/test_okta_client_credentials.py | 57 ++ test/auth/test_pat.py | 82 ++ .../auth/test_snowflake_authorization_code.py | 122 +++ ..._snowflake_authorization_code_wildcards.py | 121 +++ test/conftest.py | 4 + .../browser_timeout_authorization_error.json | 15 + .../external_idp_custom_urls.json | 77 ++ .../invalid_scope_error.json | 17 + .../invalid_state_error.json | 17 + .../new_tokens_after_failed_refresh.json | 34 + .../successful_auth_after_failed_refresh.json | 37 + .../authorization_code/successful_flow.json | 77 ++ .../token_request_error.json | 67 ++ .../successful_auth_after_failed_refresh.json | 35 + .../client_credentials/successful_flow.json | 39 + .../token_request_error.json | 29 + .../oauth/refresh_token/refresh_failed.json | 28 + .../refresh_token/refresh_successful.json | 30 + .../generic/snowflake_login_failed.json | 48 ++ .../generic/snowflake_login_successful.json | 64 ++ test/unit/test_auth_callback_server.py | 63 ++ test/unit/test_auth_oauth_auth_code.py | 22 + test/unit/test_connection.py | 6 +- test/unit/test_linux_local_file_cache.py | 197 ++++- test/unit/test_oauth_token.py | 729 ++++++++++++++++++ test/unit/test_wiremock_client.py | 1 + test/wiremock/wiremock_utils.py | 13 +- 61 files changed, 4393 insertions(+), 275 deletions(-) create mode 100644 .github/workflows/parameters/private/parameters_aws_auth_tests.json.gpg create mode 100644 .github/workflows/parameters/private/rsa_keys/rsa_key.p8.gpg create mode 100644 .github/workflows/parameters/private/rsa_keys/rsa_key_invalid.p8.gpg create mode 100755 ci/container/test_authentication.sh create mode 100755 ci/test_authentication.sh create mode 100644 src/snowflake/connector/auth/_http_server.py create mode 100644 src/snowflake/connector/auth/_oauth_base.py create mode 100644 src/snowflake/connector/auth/oauth_code.py create mode 100644 src/snowflake/connector/auth/oauth_credentials.py create mode 100644 src/snowflake/connector/file_lock.py create mode 100644 test/auth/__init__.py create mode 100644 test/auth/authorization_parameters.py create mode 100644 test/auth/authorization_test_helper.py create mode 100644 test/auth/test_external_browser.py create mode 100644 test/auth/test_key_pair.py create mode 100644 test/auth/test_oauth.py create mode 100644 test/auth/test_okta.py create mode 100644 test/auth/test_okta_authorization_code.py create mode 100644 test/auth/test_okta_client_credentials.py create mode 100644 test/auth/test_pat.py create mode 100644 test/auth/test_snowflake_authorization_code.py create mode 100644 test/auth/test_snowflake_authorization_code_wildcards.py create mode 100644 test/data/wiremock/mappings/auth/oauth/authorization_code/browser_timeout_authorization_error.json create mode 100644 test/data/wiremock/mappings/auth/oauth/authorization_code/external_idp_custom_urls.json create mode 100644 test/data/wiremock/mappings/auth/oauth/authorization_code/invalid_scope_error.json create mode 100644 test/data/wiremock/mappings/auth/oauth/authorization_code/invalid_state_error.json create mode 100644 test/data/wiremock/mappings/auth/oauth/authorization_code/new_tokens_after_failed_refresh.json create mode 100644 test/data/wiremock/mappings/auth/oauth/authorization_code/successful_auth_after_failed_refresh.json create mode 100644 test/data/wiremock/mappings/auth/oauth/authorization_code/successful_flow.json create mode 100644 test/data/wiremock/mappings/auth/oauth/authorization_code/token_request_error.json create mode 100644 test/data/wiremock/mappings/auth/oauth/client_credentials/successful_auth_after_failed_refresh.json create mode 100644 test/data/wiremock/mappings/auth/oauth/client_credentials/successful_flow.json create mode 100644 test/data/wiremock/mappings/auth/oauth/client_credentials/token_request_error.json create mode 100644 test/data/wiremock/mappings/auth/oauth/refresh_token/refresh_failed.json create mode 100644 test/data/wiremock/mappings/auth/oauth/refresh_token/refresh_successful.json create mode 100644 test/data/wiremock/mappings/generic/snowflake_login_failed.json create mode 100644 test/data/wiremock/mappings/generic/snowflake_login_successful.json create mode 100644 test/unit/test_auth_callback_server.py create mode 100644 test/unit/test_auth_oauth_auth_code.py create mode 100644 test/unit/test_oauth_token.py diff --git a/.github/workflows/parameters/private/parameters_aws_auth_tests.json.gpg b/.github/workflows/parameters/private/parameters_aws_auth_tests.json.gpg new file mode 100644 index 0000000000000000000000000000000000000000..4cdd2a880eff59ba16c5560d41c38805f3c07e14 GIT binary patch literal 934 zcmV;X16llx4Fm}T2tuxL^wL?b05`{5w_zj!W)cbZRclI)z9Fs#NEqK~#&6 zUz>n~St4sM%+}u3ESS1n9Veq`ukvAb!k-22DtJTfkR)1zPARN5h_J7h{ia{nq`tL- zqix22I+4?zKjZ|we9I?|h(qF$@Jk-E+12@NFwOfYhxn0YHqx&kyUBaKlkbLkSp0+j zI%~IeYT)F0!BfO*r;=5vVt-tw58!`6I(pNb z$C+q=Co*qE1L9M$ik_dpOg85&OQ(@l5v!t}Q2yO`-fRGpE8wZ%9k?HCe8SQjtr1j* zx!voNVEw^|w#jEg?%7bc#Zhqk8Eb8NfKO5{N;ks$@m@uYuB34O^lAJ1Tw~%0rO;c< z@5u_Co>V7+Fhn_^ZhVh6P=EO_muf{Evw0F|s`~82>Kql%xxXiN)ag&*k|#_9M1NQ+p3%fLbi*^nBqdJy4EfI~Mgo9l!6ujS7|wrDz|*4_^icpBQ0lgSEf(FweX(*c2gs1kPT_})I!!-{H&juK z#z(55*9yr#(d60PCE0stht|U4r@i&rXDO#f(J{Ao;Qk9hsjbv0yGzgN&M_}WQOvrn z-T;R;|5{=#eroUQnY^$-SuNo_{&v6CihiBdYz}$Szpy+M=t&bwZ?z5H4vk|CMhLBD zcXg$yrU=LR2vvQbrjS4VD$EAs-&kev0H| z2hEfLlTp(NT1kuE;0j}(zV}&Jat8Ia6}sJ4tGL7Sk7J$(<#%7qU0Vy~wR2 z(64GI$}njH74s$yefa4s9xci&D&O=V8ZJZ9Ia$(+&oNTThk`C|J4Fm}T2sW}O5FyLDc>mJs0h2d8Frh+x^{u3Pa0SMoLHH6?e=*-mjvgH> z2vLe85YdPlUO(Q!9!pSB6A>H&jweedm{s1V#ne)B?*Ew@rwU_b#TnM9;UhFx3To6t z_P6?=sU*gyF;QLA_#Mo$(kPi1(JXYTssT|^+$&x*L$Eevm!oS`|43p5;KmqU%^l%$ zf+r)?+eT*09Xu(o5iEK_joJ$}Dsa`pjGsmZ`mI2lp`mFRNV1sTd0Ci*okFH%AmE-; zd+4|J!QjPk6^BSXKqE0mf|TlZ2zaMlTbpXsC}8o<%fLM^&iVvao*?ml+)$0U$bpAs zLH3TUT={fDCfq?VMddD_#HP=Ix7ckQ9C~_R0%&T!)P+`x2!}dM`3Kl20N3Op6C{~Z zk%{y}8JDX^xOP(EJHGc%(w2EuKHp1HMsLfzY%s9N>vQ-ncyy|Qjq?>IM#&zzy$VRR zr%PD_0L5TZ=Fq+8=!nISW!w4d)Xi~Qk&z;<)M=*IY0|V2k9kxLJ6xnHG=GW{Bg392 zo?6Jztg?5>fX#J^fhUfm$i9b1;ZAon)7!0dn2$1FkNQp#N_7b@%US|Ee+O>>6c8F4 z8a;2=rr9|#bXujzitOi+FSk`y)kNpN4i;$@tv4=b7o%a>uN;EdP^E?n_Q-{wPYw~N z|47S7IEO?Ln4ASaGSR=-{Pc#}G1?1n7$VDvPU~-ssYO&VbeM{lQ6Tp~cqGaD@u7^Q zI=D!g8j<4D=f7>QH~i4K*5V~kiz6_G@_&}0$@Ie*VgJ|q0mU-bVrDgK z_UU*pcX31M|Du$yhr*09q5N?I+Obt52dlq~F=T{M%>_PfJG&tD;yh$2qBNa0W78(P z7EH4YTCLzV3LEsaoI0=+vSlhsLYl05tRpzAMgiTn$G;mh*}Oz(beRb=+xUafENnD=V(wydV*G}ytLhk*&7QzvwqgUs=o-nJZ-IS zyIN22BCd+3IHMThQ_=?Ln=Z32HzU$-b^bf9-<7+E^Oe1i{7Dw>UUYPFfpKM zSJm^uuFIae{#^76r5G*dmsT@A!OKNo+nngu#=P3_}|z{JF#59A&4vatGc8TE>m`%l%m`jfVV^uj=t}2 zJrjj9LIj(iz@|1G{v+ne-Lu2Ma4?|)pY)`d7DVN&Mdu+ns=?BFgXcUpL>bfKjT)e^ z5#Oe^lh5|RCake@4!5fn+xoS3alDz$xCmS=-X%uTwpF6O@}3q#d~ql~M| z{kB*tpL=a+_a;5=UL!Dn=AI8#{FjG}b6lG2D_3jMpT61#hB#NJmASwqAG~!GxUJKz z4_kZU4Uh;r7}q4lRX{b!1gC|6Hy!yLx=(TabS-?4vr6ah6S(+)C!Sa9Ta?{g4vRQ{ zO8cfZ?^Cr!d+yUXE#jo78AFzI7^KK>Q|5XOhg20-DxIyW%1Va#`!v#!;>>o_8w@+y@ HBt6`SeZ8+W literal 0 HcmV?d00001 diff --git a/.github/workflows/parameters/private/rsa_keys/rsa_key_invalid.p8.gpg b/.github/workflows/parameters/private/rsa_keys/rsa_key_invalid.p8.gpg new file mode 100644 index 0000000000000000000000000000000000000000..3d2442a7c8c319f53ac8d6af8b12bcff78a364da GIT binary patch literal 1409 zcmV-{1%CRB4Fm}T2oOE^p^$#mNB`350dCH$DrP24&;&c@7=1}8SgJ`t2Jhbld~Tl6 z0G^#)CvI%Mkv<0lD%2OXL2~H0fW4z>z(2$S(~kytf&DxtVS*9dICkcS+ffG~Z})uY zB`zl9fytg}f2tY#aS$7$`dU!eSG^#6UQrisepTgBF1~k(CjXE)AYQUTzEc--YbK}u zfsLh_J)DoV(9go0v$9Lt*l~ZQa4WUnT`63VpmmGnc8nyJjA-hQ*mB*#+OG1a7ASWB ze9Raxfi?A`6~)p!PMM?gM{9IagV-L({Z#oVi}JuWJi|)kx!>=yOrS9Gq5B-*7@l)) z{1{^`_X~Fv&Doaj))DjJ-L^W*9SS1i*gV(ca; zCE#)|nKO-5uH?~G%8AdGged-2(o6$lg?(fJ4+X01x(E6BVUjyzwz#mG9;gAgohtx= z+&#TCKz~*>R9}QeY_}#N+573zc^{U_Y4gN;>sP>M4jpt_~jH>I~>|f z@HAD^fpJJ@X(hwEc(W@+EO*05}n$URL!)yDLn$0&x z^Y^TAlbtX(ef+HM*}-W=_w*kkq{JQnh7gX=9dA%?Uj;erdH&Br>=oUw^F_@8# zLANF=N8ti7r*?em#%jnoCdOAhX5?oaWg<8t>FGIOrHxIeM@3$j&RnDKL$ZO^m($@tnS=DJfQX8u0o%kqgA|Ez(3Xdev20sf`K49<(%)+s=$Lqn7cLK zcYr+C2QO{oaBUN&oPOqMtoXcn*^>5AfVW*-OjbrWWwCs<;gyMKQ_0~e9`Pd2^zd>R zPX&|_DioP6LIMMh`;dZ*njth+Ku%Gg29O|p4;>}oQFCY>4qILxtFqk;)Bc%ZF3f)+`?qFwQwiKph#32+tK`; z=hK)V8LHpTW2$|D8@l``@kxvrP5`g*-?MM*%hfFnknxU z8cM#TPhCW@difY~FYFCSv86g8s5|^8##L%J7Ja|v+O~qJJ-y}x5Zh}k_BaO;v=2G` z+ny&dA%H9;I6Ct#0(NHo*n(D4L>KHT*TeodHh;rXV#}gki)mI%l-}(jmh@%C=yj&+ zS&Cj=Jp0|>K~z)V(Z4UOueoNTy4huPrFV~!yPZp;MUU5kNFE>N@{W&7(!v#LnSkIr zRv~<=iZTb7zF~)46dwn2zL%K#j?aMpW8ffUfnkv<{;fKOMqmyh5>jTxsWit;I&THD zL{Eh#W~~%SQQ8F3URzklBVsset&u0^faeR4!Z66t3mnorN#pDIQGk~3E%IowD9NC0 Plk!d)KNyEuj^Ur;bOOha literal 0 HcmV?d00001 diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 3f8686eea4..916812e99c 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -13,6 +13,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne - Dropped support for Python 3.8. - Basic decimal floating-point type support. - Added handling of PAT provided in `password` field. + - Added experimental support for OAuth authorization code and client credentials flows. - Improved error message for client-side query cancellations due to timeouts. - Added support of GCS regional endpoints. - Added `gcs_use_virtual_endpoints` connection property that forces the usage of the virtual GCS usage. Thanks to this it should be possible to set up private DNS entry for the GCS endpoint. See more: https://cloud.google.com/storage/docs/request-endpoints#xml-api diff --git a/Jenkinsfile b/Jenkinsfile index 699a514970..00374eaf9a 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -35,29 +35,46 @@ timestamps { string(name: 'parent_job', value: env.JOB_NAME), string(name: 'parent_build_number', value: env.BUILD_NUMBER) ] - stage('Test') { - try { - def commit_hash = "main" // default which we want to override - def bptp_tag = "bptp-stable" - def response = authenticatedGithubCall("https://api.github.com/repos/snowflakedb/snowflake/git/ref/tags/${bptp_tag}") - commit_hash = response.object.sha - // Append the bptp-stable commit sha to params - params += [string(name: 'svn_revision', value: commit_hash)] - } catch(Exception e) { - println("Exception computing commit hash from: ${response}") + parallel( + 'Test': { + stage('Test') { + try { + def commit_hash = "main" // default which we want to override + def bptp_tag = "bptp-stable" + def response = authenticatedGithubCall("https://api.github.com/repos/snowflakedb/snowflake/git/ref/tags/${bptp_tag}") + commit_hash = response.object.sha + // Append the bptp-stable commit sha to params + params += [string(name: 'svn_revision', value: commit_hash)] + } catch(Exception e) { + println("Exception computing commit hash from: ${response}") + } + parallel ( + 'Test Python 39': { build job: 'RT-PyConnector39-PC',parameters: params}, + 'Test Python 310': { build job: 'RT-PyConnector310-PC',parameters: params}, + 'Test Python 311': { build job: 'RT-PyConnector311-PC',parameters: params}, + 'Test Python 312': { build job: 'RT-PyConnector312-PC',parameters: params}, + 'Test Python 313': { build job: 'RT-PyConnector313-PC',parameters: params}, + 'Test Python 39 OldDriver': { build job: 'RT-PyConnector39-OldDriver-PC',parameters: params}, + 'Test Python 39 FIPS': { build job: 'RT-FIPS-PyConnector39',parameters: params}, + ) + } + }, + 'Test Authentication': { + stage('Test Authentication') { + withCredentials([ + string(credentialsId: 'a791118f-a1ea-46cd-b876-56da1b9bc71c', variable: 'NEXUS_PASSWORD'), + string(credentialsId: 'sfctest0-parameters-secret', variable: 'PARAMETERS_SECRET') + ]) { + sh '''\ + |#!/bin/bash -e + |$WORKSPACE/ci/test_authentication.sh + '''.stripMargin() } - parallel ( - 'Test Python 39': { build job: 'RT-PyConnector39-PC',parameters: params}, - 'Test Python 310': { build job: 'RT-PyConnector310-PC',parameters: params}, - 'Test Python 311': { build job: 'RT-PyConnector311-PC',parameters: params}, - 'Test Python 312': { build job: 'RT-PyConnector312-PC',parameters: params}, - 'Test Python 313': { build job: 'RT-PyConnector313-PC',parameters: params}, - 'Test Python 39 OldDriver': { build job: 'RT-PyConnector39-OldDriver-PC',parameters: params}, - 'Test Python 39 FIPS': { build job: 'RT-FIPS-PyConnector39',parameters: params}, - ) } } - } + ) + } +} pipeline { diff --git a/ci/container/test_authentication.sh b/ci/container/test_authentication.sh new file mode 100755 index 0000000000..d65c7627eb --- /dev/null +++ b/ci/container/test_authentication.sh @@ -0,0 +1,24 @@ +#!/bin/bash -e + +set -o pipefail + + +export WORKSPACE=${WORKSPACE:-/mnt/workspace} +export SOURCE_ROOT=${SOURCE_ROOT:-/mnt/host} + +MVNW_EXE=$SOURCE_ROOT/mvnw +AUTH_PARAMETER_FILE=./.github/workflows/parameters/private/parameters_aws_auth_tests.json +eval $(jq -r '.authtestparams | to_entries | map("export \(.key)=\(.value|tostring)")|.[]' $AUTH_PARAMETER_FILE) + +export SNOWFLAKE_AUTH_TEST_PRIVATE_KEY_PATH=./.github/workflows/parameters/private/rsa_keys/rsa_key.p8 +export SNOWFLAKE_AUTH_TEST_INVALID_PRIVATE_KEY_PATH=./.github/workflows/parameters/private/rsa_keys/rsa_key_invalid.p8 + +export SF_OCSP_TEST_MODE=true +export SF_ENABLE_EXPERIMENTAL_AUTHENTICATION=true +export RUN_AUTH_TESTS=true +export AUTHENTICATION_TESTS_ENV="docker" +export PYTHONPATH=$SOURCE_ROOT + +python3 -m pip install --break-system-packages -e . + +python3 -m pytest test/auth/* diff --git a/ci/test_authentication.sh b/ci/test_authentication.sh new file mode 100755 index 0000000000..dbf78c83e8 --- /dev/null +++ b/ci/test_authentication.sh @@ -0,0 +1,27 @@ +#!/bin/bash -e + +set -o pipefail + + +export THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +export WORKSPACE=${WORKSPACE:-/tmp} + +CI_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +if [[ -n "$JENKINS_HOME" ]]; then + ROOT_DIR="$(cd "${CI_DIR}/.." && pwd)" + export WORKSPACE=${WORKSPACE:-/tmp} + echo "Use /sbin/ip" + IP_ADDR=$(/sbin/ip -4 addr show scope global dev eth0 | grep inet | awk '{print $2}' | cut -d / -f 1) + +fi + +gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" --output $THIS_DIR/../.github/workflows/parameters/private/parameters_aws_auth_tests.json "$THIS_DIR/../.github/workflows/parameters/private/parameters_aws_auth_tests.json.gpg" +gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" --output $THIS_DIR/../.github/workflows/parameters/private/rsa_keys/rsa_key.p8 "$THIS_DIR/../.github/workflows/parameters/private/rsa_keys/rsa_key.p8.gpg" +gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" --output $THIS_DIR/../.github/workflows/parameters/private/rsa_keys/rsa_key_invalid.p8 "$THIS_DIR/../.github/workflows/parameters/private/rsa_keys/rsa_key_invalid.p8.gpg" + +docker run \ + -v $(cd $THIS_DIR/.. && pwd):/mnt/host \ + -v $WORKSPACE:/mnt/workspace \ + --rm \ + nexus.int.snowflakecomputing.com:8086/docker/snowdrivers-test-external-browser-python:1 \ + "/mnt/host/ci/container/test_authentication.sh" diff --git a/src/snowflake/connector/auth/__init__.py b/src/snowflake/connector/auth/__init__.py index 0874b35ca7..cb25f7d364 100644 --- a/src/snowflake/connector/auth/__init__.py +++ b/src/snowflake/connector/auth/__init__.py @@ -7,6 +7,8 @@ from .keypair import AuthByKeyPair from .no_auth import AuthNoAuth from .oauth import AuthByOAuth +from .oauth_code import AuthByOauthCode +from .oauth_credentials import AuthByOauthCredentials from .okta import AuthByOkta from .pat import AuthByPAT from .usrpwdmfa import AuthByUsrPwdMfa @@ -18,6 +20,8 @@ AuthByDefault, AuthByKeyPair, AuthByOAuth, + AuthByOauthCode, + AuthByOauthCredentials, AuthByOkta, AuthByUsrPwdMfa, AuthByWebBrowser, @@ -34,6 +38,8 @@ "AuthByKeyPair", "AuthByPAT", "AuthByOAuth", + "AuthByOauthCode", + "AuthByOauthCredentials", "AuthByOkta", "AuthByUsrPwdMfa", "AuthByWebBrowser", diff --git a/src/snowflake/connector/auth/_auth.py b/src/snowflake/connector/auth/_auth.py index cf3b6b6297..527bd5cf9b 100644 --- a/src/snowflake/connector/auth/_auth.py +++ b/src/snowflake/connector/auth/_auth.py @@ -47,6 +47,7 @@ ACCEPT_TYPE_APPLICATION_SNOWFLAKE, CONTENT_TYPE_APPLICATION_JSON, ID_TOKEN_INVALID_LOGIN_REQUEST_GS_CODE, + OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE, PYTHON_CONNECTOR_USER_AGENT, ReauthenticationRequest, ) @@ -86,7 +87,7 @@ class Auth: def __init__(self, rest) -> None: self._rest = rest - self.token_cache = TokenCache.make() + self._token_cache: TokenCache | None = None @staticmethod def base_auth_data( @@ -350,7 +351,7 @@ def post_request_wrapper(self, url, headers, body) -> None: # clear stored id_token if failed to connect because of id_token # raise an exception for reauth without id_token self._rest.id_token = None - self.delete_temporary_credential( + self._delete_temporary_credential( self._rest._host, user, TokenType.ID_TOKEN ) raise ReauthenticationRequest( @@ -360,6 +361,14 @@ def post_request_wrapper(self, url, headers, body) -> None: sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, ) ) + elif errno == OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE: + raise ReauthenticationRequest( + ProgrammingError( + msg=ret["message"], + errno=int(errno), + sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, + ) + ) from . import AuthByKeyPair @@ -374,7 +383,7 @@ def post_request_wrapper(self, url, headers, body) -> None: from . import AuthByUsrPwdMfa if isinstance(auth_instance, AuthByUsrPwdMfa): - self.delete_temporary_credential( + self._delete_temporary_credential( self._rest._host, user, TokenType.MFA_TOKEN ) Error.errorhandler_wrapper( @@ -466,7 +475,7 @@ def _read_temporary_credential( user: str, cred_type: TokenType, ) -> str | None: - return self.token_cache.retrieve(TokenKey(host, user, cred_type)) + return self.get_token_cache().retrieve(TokenKey(host, user, cred_type)) def read_temporary_credentials( self, @@ -500,7 +509,7 @@ def _write_temporary_credential( "no credential is given when try to store temporary credential" ) return - self.token_cache.store(TokenKey(host, user, cred_type), cred) + self.get_token_cache().store(TokenKey(host, user, cred_type), cred) def write_temporary_credentials( self, @@ -524,10 +533,15 @@ def write_temporary_credentials( host, user, TokenType.MFA_TOKEN, response["data"].get("mfaToken") ) - def delete_temporary_credential( + def _delete_temporary_credential( self, host: str, user: str, cred_type: TokenType ) -> None: - self.token_cache.remove(TokenKey(host, user, cred_type)) + self.get_token_cache().remove(TokenKey(host, user, cred_type)) + + def get_token_cache(self) -> TokenCache: + if self._token_cache is None: + self._token_cache = TokenCache.make() + return self._token_cache def get_token_from_private_key( diff --git a/src/snowflake/connector/auth/_http_server.py b/src/snowflake/connector/auth/_http_server.py new file mode 100644 index 0000000000..a11662f25b --- /dev/null +++ b/src/snowflake/connector/auth/_http_server.py @@ -0,0 +1,220 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import logging +import os +import select +import socket +import time +import urllib.parse +from collections.abc import Callable +from types import TracebackType + +from typing_extensions import Self + +from ..compat import IS_WINDOWS + +logger = logging.getLogger(__name__) + + +def _use_msg_dont_wait() -> bool: + if os.getenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", "false").lower() != "true": + return False + if IS_WINDOWS: + logger.warning( + "Configuration SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT is not available in Windows. Ignoring." + ) + return False + return True + + +def _wrap_socket_recv() -> Callable[[socket.socket, int], bytes]: + dont_wait = _use_msg_dont_wait() + if dont_wait: + # WSL containerized environment sometimes causes socket_client.recv to hang indefinetly + # To avoid this, passing the socket.MSG_DONTWAIT flag which raises BlockingIOError if + # operation would block + logger.debug( + "Will call socket.recv with MSG_DONTWAIT flag due to SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT env var" + ) + socket_recv = ( + (lambda sock, buf_size: socket.socket.recv(sock, buf_size, socket.MSG_DONTWAIT)) + if dont_wait + else (lambda sock, buf_size: socket.socket.recv(sock, buf_size)) + ) + + def socket_recv_checked(sock: socket.socket, buf_size: int) -> bytes: + raw = socket_recv(sock, buf_size) + # when running in a containerized environment, socket_client.recv occasionally returns an empty byte array + # an immediate successive call to socket_client.recv gets the actual data + if len(raw) == 0: + raw = socket_recv(sock, buf_size) + return raw + + return socket_recv_checked + + +class AuthHttpServer: + """Simple HTTP server to receive callbacks through for auth purposes.""" + + DEFAULT_MAX_ATTEMPTS = 15 + DEFAULT_TIMEOUT = 30.0 + + PORT_BIND_MAX_ATTEMPTS = 10 + PORT_BIND_TIMEOUT = 20.0 + + def __init__( + self, + uri: str, + buf_size: int = 16384, + ) -> None: + parsed_uri = urllib.parse.urlparse(uri) + self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.buf_size = buf_size + if os.getenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "False").lower() == "true": + if IS_WINDOWS: + logger.warning( + "Configuration SNOWFLAKE_AUTH_SOCKET_REUSE_PORT is not available in Windows. Ignoring." + ) + else: + self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + + port = parsed_uri.port or 0 + for attempt in range(1, self.DEFAULT_MAX_ATTEMPTS + 1): + try: + self._socket.bind( + ( + parsed_uri.hostname, + port, + ) + ) + break + except socket.gaierror as ex: + logger.error( + f"Failed to bind authorization callback server to port {port}: {ex}" + ) + raise + except OSError as ex: + if attempt == self.DEFAULT_MAX_ATTEMPTS: + logger.error( + f"Failed to bind authorization callback server to port {port}: {ex}" + ) + raise + logger.warning( + f"Attempt {attempt}/{self.DEFAULT_MAX_ATTEMPTS}. " + f"Failed to bind authorization callback server to port {port}: {ex}" + ) + time.sleep(self.PORT_BIND_TIMEOUT / self.PORT_BIND_MAX_ATTEMPTS) + try: + self._socket.listen(0) # no backlog + except Exception as ex: + logger.error(f"Failed to start listening for auth callback: {ex}") + self.close() + raise + port = self._socket.getsockname()[1] + self._uri = urllib.parse.ParseResult( + scheme=parsed_uri.scheme, + netloc=parsed_uri.hostname + ":" + str(port), + path=parsed_uri.path, + params=parsed_uri.params, + query=parsed_uri.query, + fragment=parsed_uri.fragment, + ) + + @property + def url(self) -> str: + return self._uri.geturl() + + @property + def port(self) -> int: + return self._uri.port + + @property + def hostname(self) -> str: + return self._uri.hostname + + def _try_poll( + self, attempts: int, attempt_timeout: float | None + ) -> (socket.socket | None, int): + for attempt in range(attempts): + read_sockets = select.select([self._socket], [], [], attempt_timeout)[0] + if read_sockets and read_sockets[0] is not None: + return self._socket.accept()[0], attempt + return None, attempts + + def _try_receive_block( + self, client_socket: socket.socket, attempts: int, attempt_timeout: float | None + ) -> bytes | None: + if attempt_timeout is not None: + client_socket.settimeout(attempt_timeout) + recv = _wrap_socket_recv() + for attempt in range(attempts): + try: + return recv(client_socket, self.buf_size) + except BlockingIOError: + if attempt < attempts - 1: + cooldown = min(attempt_timeout, 0.25) if attempt_timeout else 0.25 + logger.debug( + f"BlockingIOError raised from socket.recv on {1 + attempt}/{attempts} attempt." + f"Waiting for {cooldown} seconds before trying again" + ) + time.sleep(cooldown) + except socket.timeout: + logger.debug( + f"socket.recv timed out on {1 + attempt}/{attempts} attempt." + ) + return None + + def receive_block( + self, + max_attempts: int = None, + timeout: float | int | None = None, + ) -> (list[str] | None, socket.socket | None): + if max_attempts is None: + max_attempts = self.DEFAULT_MAX_ATTEMPTS + if timeout is None: + timeout = self.DEFAULT_TIMEOUT + """Receive a message with a maximum attempt count and a timeout in seconds, blocking.""" + if not self._socket: + raise RuntimeError( + "Operation is not supported, server was already shut down." + ) + attempt_timeout = timeout / max_attempts if timeout else None + client_socket, poll_attempts = self._try_poll(max_attempts, attempt_timeout) + if client_socket is None: + return None, None + raw_block = self._try_receive_block( + client_socket, max_attempts - poll_attempts, attempt_timeout + ) + if raw_block: + return raw_block.decode("utf-8").split("\r\n"), client_socket + try: + client_socket.shutdown(socket.SHUT_RDWR) + except OSError: + pass + client_socket.close() + return None, None + + def close(self) -> None: + """Closes the underlying socket. + After having close() being called the server object cannot be reused. + """ + if self._socket: + self._socket.close() + self._socket = None + + def __enter__(self) -> Self: + """Context manager.""" + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Context manager with disposing underlying networking objects.""" + self.close() diff --git a/src/snowflake/connector/auth/_oauth_base.py b/src/snowflake/connector/auth/_oauth_base.py new file mode 100644 index 0000000000..ec77b22735 --- /dev/null +++ b/src/snowflake/connector/auth/_oauth_base.py @@ -0,0 +1,367 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import base64 +import json +import logging +import urllib.parse +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any +from urllib.error import HTTPError, URLError + +from ..errorcode import ER_FAILED_TO_REQUEST, ER_IDP_CONNECTION_ERROR +from ..network import OAUTH_AUTHENTICATOR +from ..secret_detector import SecretDetector +from ..token_cache import TokenCache, TokenKey, TokenType +from ..vendored import urllib3 +from .by_plugin import AuthByPlugin, AuthType + +if TYPE_CHECKING: + from .. import SnowflakeConnection + +logger = logging.getLogger(__name__) + + +class _OAuthTokensMixin: + def __init__( + self, + token_cache: TokenCache | None, + refresh_token_enabled: bool, + idp_host: str, + ) -> None: + self._access_token = None + self._refresh_token_enabled = refresh_token_enabled + if self._refresh_token_enabled: + self._refresh_token = None + self._token_cache = token_cache + if self._token_cache: + logger.debug("token cache is going to be used if needed") + self._idp_host = idp_host + self._access_token_key: TokenKey | None = None + if self._refresh_token_enabled: + self._refresh_token_key: TokenKey | None = None + + def _update_cache_keys(self, user: str) -> None: + if self._token_cache: + self._user = user + + def _get_access_token_cache_key(self) -> TokenKey | None: + return ( + TokenKey(self._user, self._idp_host, TokenType.OAUTH_ACCESS_TOKEN) + if self._token_cache and self._user + else None + ) + + def _get_refresh_token_cache_key(self) -> TokenKey | None: + return ( + TokenKey(self._user, self._idp_host, TokenType.OAUTH_REFRESH_TOKEN) + if self._refresh_token_enabled and self._token_cache and self._user + else None + ) + + def _pop_cached_token(self, key: TokenKey | None) -> str | None: + if self._token_cache is None or key is None: + return None + return self._token_cache.retrieve(key) + + def _pop_cached_access_token(self) -> bool: + """Retrieves OAuth access token from the token cache if enabled""" + self._access_token = self._pop_cached_token(self._get_access_token_cache_key()) + return self._access_token is not None + + def _pop_cached_refresh_token(self) -> bool: + """Retrieves OAuth refresh token from the token cache if enabled""" + if self._refresh_token_enabled: + self._refresh_token = self._pop_cached_token( + self._get_refresh_token_cache_key() + ) + return self._refresh_token is not None + return False + + def _reset_cached_token(self, key: TokenKey | None, token: str | None) -> None: + if self._token_cache is None or key is None: + return + if token: + self._token_cache.store(key, token) + else: + self._token_cache.remove(key) + + def _reset_access_token(self, access_token: str | None = None) -> None: + """Updates OAuth access token both in memory and in the token cache if enabled""" + logger.debug( + "resetting access token to %s", + "*" * len(access_token) if access_token else None, + ) + self._access_token = access_token + self._reset_cached_token(self._get_access_token_cache_key(), self._access_token) + + def _reset_refresh_token(self, refresh_token: str | None = None) -> None: + """Updates OAuth refresh token both in memory and in the token cache if necessary""" + if self._refresh_token_enabled: + logger.debug( + "resetting refresh token to %s", + "*" * len(refresh_token) if refresh_token else None, + ) + self._refresh_token = refresh_token + self._reset_cached_token( + self._get_refresh_token_cache_key(), self._refresh_token + ) + + def _reset_temporary_state(self) -> None: + self._access_token = None + if self._refresh_token_enabled: + self._refresh_token = None + if self._token_cache: + self._user = None + + +class AuthByOAuthBase(AuthByPlugin, _OAuthTokensMixin, ABC): + """A base abstract class for OAuth authenticators""" + + def __init__( + self, + client_id: str, + client_secret: str, + token_request_url: str, + scope: str, + token_cache: TokenCache | None, + refresh_token_enabled: bool, + **kwargs, + ) -> None: + super().__init__(**kwargs) + _OAuthTokensMixin.__init__( + self, + token_cache=token_cache, + refresh_token_enabled=refresh_token_enabled, + idp_host=urllib.parse.urlparse(token_request_url).hostname, + ) + self._client_id = client_id + self._client_secret = client_secret + self._token_request_url = token_request_url + self._scope = scope + if refresh_token_enabled: + logger.debug("oauth refresh token is going to be used if needed") + self._scope += (" " if self._scope else "") + "offline_access" + + @abstractmethod + def _request_tokens( + self, + *, + conn: SnowflakeConnection, + authenticator: str, + service_name: str | None, + account: str, + user: str, + password: str | None, + **kwargs: Any, + ) -> (str | None, str | None): + """Request new access and optionally refresh tokens from IdP. + + This function should implement specific tokens querying flow. + """ + raise NotImplementedError + + @abstractmethod + def _get_oauth_type_id(self) -> str: + """Get OAuth specific authenticator id to be passed to Snowflake. + + This function should return a unique OAuth authenticator id. + """ + raise NotImplementedError + + def reset_secrets(self) -> None: + logger.debug("resetting secrets") + self._reset_temporary_state() + + @property + def type_(self) -> AuthType: + return AuthType.OAUTH + + @property + def assertion_content(self) -> str: + """Returns the token.""" + return self._access_token or "" + + def reauthenticate( + self, + *, + conn: SnowflakeConnection, + **kwargs: Any, + ) -> dict[str, bool]: + self._reset_access_token() + if self._pop_cached_refresh_token(): + logger.debug( + "OAuth refresh token is available, try to use it and get a new access token" + ) + self._do_refresh_token(conn=conn) + conn.authenticate_with_retry(self) + return {"success": True} + + def prepare( + self, + *, + conn: SnowflakeConnection, + authenticator: str, + service_name: str | None, + account: str, + user: str, + **kwargs: Any, + ) -> None: + """Web Browser based Authentication.""" + logger.debug("authenticating with OAuth authorization code flow") + self._update_cache_keys(user=user) + if self._pop_cached_access_token(): + logger.info( + "OAuth access token is already available in cache, no need to authenticate." + ) + return + access_token, refresh_token = self._request_tokens( + conn=conn, + authenticator=authenticator, + service_name=service_name, + account=account, + user=user, + **kwargs, + ) + self._reset_access_token(access_token) + self._reset_refresh_token(refresh_token) + + def update_body(self, body: dict[Any, Any]) -> None: + """Used by Auth to update the request that gets sent to /v1/login-request. + + Args: + body: existing request dictionary + """ + body["data"]["AUTHENTICATOR"] = OAUTH_AUTHENTICATOR + body["data"]["TOKEN"] = self._access_token + body["data"]["OAUTH_TYPE"] = self._get_oauth_type_id() + + def _do_refresh_token(self, conn: SnowflakeConnection) -> None: + """If a refresh token is available exchanges it with a new access token. + Updates self as a side-effect. Needs at lest self._refresh_token and client_id set. + """ + if not self._refresh_token_enabled: + logger.debug("refresh_token feature is disabled") + return + + resp = self._get_refresh_token_response(conn) + if not resp: + logger.info( + "failed to exchange the refresh token on a new OAuth access token" + ) + self._reset_refresh_token() + return + + try: + json_resp = json.loads(resp.data.decode()) + self._reset_access_token(json_resp["access_token"]) + if "refresh_token" in json_resp: + self._reset_refresh_token(json_resp["refresh_token"]) + except ( + json.JSONDecodeError, + KeyError, + ): + logger.error( + "refresh token exchange response did not contain 'access_token'" + ) + logger.debug( + "received the following response body when exchanging refresh token: %s", + SecretDetector.mask_secrets(str(resp.data)), + ) + self._reset_refresh_token() + + def _get_refresh_token_response( + self, conn: SnowflakeConnection + ) -> urllib3.BaseHTTPResponse | None: + fields = { + "grant_type": "refresh_token", + "refresh_token": self._refresh_token, + } + if self._scope: + fields["scope"] = self._scope + try: + return urllib3.PoolManager().request_encode_body( + # TODO: use network pool to gain use of proxy settings and so on + "POST", + self._token_request_url, + encode_multipart=False, + headers=self._create_token_request_headers(), + fields=fields, + ) + except HTTPError as e: + self._handle_failure( + conn=conn, + ret={ + "code": ER_FAILED_TO_REQUEST, + "message": f"Failed to request new OAuth access token with a refresh token," + f" url={e.url}, code={e.code}, reason={e.reason}", + }, + ) + except URLError as e: + self._handle_failure( + conn=conn, + ret={ + "code": ER_FAILED_TO_REQUEST, + "message": f"Failed to request new OAuth access token with a refresh token, reason: {e.reason}", + }, + ) + except Exception: + self._handle_failure( + conn=conn, + ret={ + "code": ER_FAILED_TO_REQUEST, + "message": "Failed to request new OAuth access token with a refresh token by unknown reason", + }, + ) + return None + + def _get_request_token_response( + self, + connection: SnowflakeConnection, + fields: dict[str, str], + ) -> (str | None, str | None): + resp = urllib3.PoolManager().request_encode_body( + # TODO: use network pool to gain use of proxy settings and so on + "POST", + self._token_request_url, + headers=self._create_token_request_headers(), + encode_multipart=False, + fields=fields, + ) + try: + logger.debug("OAuth IdP response received, try to parse it") + json_resp: dict = json.loads(resp.data) + access_token = json_resp["access_token"] + refresh_token = json_resp.get("refresh_token") + return access_token, refresh_token + except ( + json.JSONDecodeError, + KeyError, + ): + logger.error("oauth response invalid, does not contain 'access_token'") + logger.debug( + "received the following response body when requesting oauth token: %s", + SecretDetector.mask_secrets(str(resp.data)), + ) + self._handle_failure( + conn=connection, + ret={ + "code": ER_IDP_CONNECTION_ERROR, + "message": "Invalid HTTP request from web browser. Idp " + "authentication could have failed.", + }, + ) + return None, None + + def _create_token_request_headers(self) -> dict[str, str]: + return { + "Authorization": "Basic " + + base64.b64encode( + f"{self._client_id}:{self._client_secret}".encode() + ).decode(), + "Accept": "application/json", + "Content-Type": "application/x-www-form-urlencoded; charset=UTF-8", + } diff --git a/src/snowflake/connector/auth/oauth_code.py b/src/snowflake/connector/auth/oauth_code.py new file mode 100644 index 0000000000..f93562bc3b --- /dev/null +++ b/src/snowflake/connector/auth/oauth_code.py @@ -0,0 +1,383 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import base64 +import hashlib +import json +import logging +import secrets +import socket +import time +import urllib.parse +import webbrowser +from typing import TYPE_CHECKING, Any + +from ..compat import parse_qs, urlparse, urlsplit +from ..constants import OAUTH_TYPE_AUTHORIZATION_CODE +from ..errorcode import ( + ER_OAUTH_CALLBACK_ERROR, + ER_OAUTH_SERVER_TIMEOUT, + ER_OAUTH_STATE_CHANGED, + ER_UNABLE_TO_OPEN_BROWSER, +) +from ..token_cache import TokenCache +from ._http_server import AuthHttpServer +from ._oauth_base import AuthByOAuthBase + +if TYPE_CHECKING: + from .. import SnowflakeConnection + +logger = logging.getLogger(__name__) + +BUF_SIZE = 16384 + + +def _get_query_params( + url: str, +) -> dict[str, list[str]]: + parsed = parse_qs(urlparse(url).query) + return parsed + + +class AuthByOauthCode(AuthByOAuthBase): + """Authenticates user by OAuth code flow.""" + + def __init__( + self, + application: str, + client_id: str, + client_secret: str, + authentication_url: str, + token_request_url: str, + redirect_uri: str, + scope: str, + pkce_enabled: bool = True, + token_cache: TokenCache | None = None, + refresh_token_enabled: bool = False, + external_browser_timeout: int | None = None, + **kwargs, + ) -> None: + super().__init__( + client_id=client_id, + client_secret=client_secret, + token_request_url=token_request_url, + scope=scope, + token_cache=token_cache, + refresh_token_enabled=refresh_token_enabled, + **kwargs, + ) + self._application = application + self._origin: str | None = None + self._authentication_url = authentication_url + self._redirect_uri = redirect_uri + self._state = secrets.token_urlsafe(43) + logger.debug("chose oauth state: %s", "".join("*" for _ in self._state)) + self._protocol = "http" + self._pkce_enabled = pkce_enabled + if pkce_enabled: + logger.debug("oauth pkce is going to be used") + self._verifier: str | None = None + self._external_browser_timeout = external_browser_timeout + + def _get_oauth_type_id(self) -> str: + return OAUTH_TYPE_AUTHORIZATION_CODE + + def _request_tokens( + self, + *, + conn: SnowflakeConnection, + authenticator: str, + service_name: str | None, + account: str, + user: str, + **kwargs: Any, + ) -> (str | None, str | None): + """Web Browser based Authentication.""" + logger.debug("authenticating with OAuth authorization code flow") + with AuthHttpServer(self._redirect_uri) as callback_server: + code = self._do_authorization_request(callback_server, conn) + return self._do_token_request(code, callback_server, conn) + + def _check_post_requested( + self, data: list[str] + ) -> tuple[str, str] | tuple[None, None]: + request_line = None + header_line = None + origin_line = None + for line in data: + if line.startswith("Access-Control-Request-Method:"): + request_line = line + elif line.startswith("Access-Control-Request-Headers:"): + header_line = line + elif line.startswith("Origin:"): + origin_line = line + + if ( + not request_line + or not header_line + or not origin_line + or request_line.split(":")[1].strip() != "POST" + ): + return (None, None) + + return ( + header_line.split(":")[1].strip(), + ":".join(origin_line.split(":")[1:]).strip(), + ) + + def _process_options( + self, data: list[str], socket_client: socket.socket, hostname: str, port: int + ) -> bool: + """Allows JS Ajax access to this endpoint.""" + for line in data: + if line.startswith("OPTIONS "): + break + else: + return False + requested_headers, requested_origin = self._check_post_requested(data) + if requested_headers is None or requested_origin is None: + return False + + if not self._validate_origin(requested_origin, hostname, port): + # validate Origin and fail if not match with the server. + return False + + self._origin = requested_origin + content = [ + "HTTP/1.1 200 OK", + "Date: {}".format( + time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime()) + ), + "Access-Control-Allow-Methods: POST, GET", + f"Access-Control-Allow-Headers: {requested_headers}", + "Access-Control-Max-Age: 86400", + f"Access-Control-Allow-Origin: {self._origin}", + "", + "", + ] + socket_client.sendall("\r\n".join(content).encode("utf-8")) + return True + + def _validate_origin(self, requested_origin: str, hostname: str, port: int) -> bool: + ret = urlsplit(requested_origin) + netloc = ret.netloc.split(":") + host_got = netloc[0] + port_got = ( + netloc[1] if len(netloc) > 1 else (443 if self._protocol == "https" else 80) + ) + + return ( + ret.scheme == self._protocol and host_got == hostname and port_got == port + ) + + def _send_response(self, data: list[str], socket_client: socket.socket) -> None: + if not self._is_request_get(data): + return # error + + response = [ + "HTTP/1.1 200 OK", + "Content-Type: text/html", + ] + if self._origin: + msg = json.dumps({"consent": self.consent_cache_id_token}) + response.append(f"Access-Control-Allow-Origin: {self._origin}") + response.append("Vary: Accept-Encoding, Origin") + else: + msg = f""" + + +OAuth Response for Snowflake + +Your identity was confirmed and propagated to Snowflake {self._application}. +You can close this window now and go back where you started from. +""" + response.append(f"Content-Length: {len(msg)}") + response.append("") + response.append(msg) + + socket_client.sendall("\r\n".join(response).encode("utf-8")) + + @staticmethod + def _has_code(url: str) -> bool: + return "code" in parse_qs(urlparse(url).query) + + @staticmethod + def _is_request_get(data: list[str]) -> bool: + """Whether an HTTP request is a GET.""" + return any(line.startswith("GET ") for line in data) + + def _construct_authorization_request(self, redirect_uri: str) -> str: + params = { + "response_type": "code", + "client_id": self._client_id, + "redirect_uri": redirect_uri, + "state": self._state, + } + if self._scope: + params["scope"] = self._scope + if self._pkce_enabled: + self._verifier = secrets.token_urlsafe(43) + # calculate challenge and verifier + challenge = ( + base64.urlsafe_b64encode( + hashlib.sha256(self._verifier.encode("utf-8")).digest() + ) + .decode("utf-8") + .rstrip("=") + ) + params["code_challenge"] = challenge + params["code_challenge_method"] = "S256" + url_params = urllib.parse.urlencode(params) + url = f"{self._authentication_url}?{url_params}" + return url + + def _do_authorization_request( + self, + callback_server: AuthHttpServer, + connection: SnowflakeConnection, + ) -> str | None: + authorization_request = self._construct_authorization_request( + callback_server.url + ) + logger.debug("step 1: going to open authorization URL") + print( + "Initiating login request with your identity provider. A " + "browser window should have opened for you to complete the " + "login. If you can't see it, check existing browser windows, " + "or your OS settings. Press CTRL+C to abort and try again..." + ) + code, state = ( + self._receive_authorization_callback(callback_server, connection) + if webbrowser.open(authorization_request) + else self._ask_authorization_callback_from_user( + authorization_request, connection + ) + ) + if not code: + self._handle_failure( + conn=connection, + ret={ + "code": ER_UNABLE_TO_OPEN_BROWSER, + "message": ( + "Unable to open a browser in this environment and " + "OAuth URL contained no authorization code." + ), + }, + ) + return None + if state != self._state: + self._handle_failure( + conn=connection, + ret={ + "code": ER_OAUTH_STATE_CHANGED, + "message": "State changed during OAuth process.", + }, + ) + logger.debug( + "received oauth code: %s and state: %s", + "*" * len(code), + "*" * len(state), + ) + return None + return code + + def _do_token_request( + self, + code: str, + callback_server: AuthHttpServer, + connection: SnowflakeConnection, + ) -> (str | None, str | None): + logger.debug("step 2: received OAuth callback, requesting token") + fields = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": callback_server.url, + } + if self._pkce_enabled: + assert self._verifier is not None + fields["code_verifier"] = self._verifier + return self._get_request_token_response(connection, fields) + + def _receive_authorization_callback( + self, + http_server: AuthHttpServer, + connection: SnowflakeConnection, + ) -> (str | None, str | None): + logger.debug("trying to receive authorization redirected uri") + data, socket_connection = http_server.receive_block( + timeout=self._external_browser_timeout + ) + if socket_connection is None: + self._handle_failure( + conn=connection, + ret={ + "code": ER_OAUTH_SERVER_TIMEOUT, + "message": "Unable to receive the OAuth message within a given timeout. Please check the redirect URI and try again.", + }, + ) + return None, None + try: + if not self._process_options( + data, socket_connection, http_server.hostname, http_server.port + ): + self._send_response(data, socket_connection) + socket_connection.shutdown(socket.SHUT_RDWR) + except OSError: + pass + finally: + socket_connection.close() + return self._parse_authorization_redirected_request( + data[0].split(maxsplit=2)[1], + connection, + ) + + def _ask_authorization_callback_from_user( + self, + authorization_request: str, + connection: SnowflakeConnection, + ) -> (str | None, str | None): + logger.debug("requesting authorization redirected url from user") + print( + "We were unable to open a browser window for you, " + "please open the URL manually then paste the " + "URL you are redirected to into the terminal:\n" + f"{authorization_request}" + ) + received_redirected_request = input( + "Enter the URL the OAuth flow redirected you to: " + ) + code, state = self._parse_authorization_redirected_request( + received_redirected_request, + connection, + ) + if not code: + self._handle_failure( + conn=connection, + ret={ + "code": ER_UNABLE_TO_OPEN_BROWSER, + "message": ( + "Unable to open a browser in this environment and " + "OAuth URL contained no code" + ), + }, + ) + return code, state + + def _parse_authorization_redirected_request( + self, + url: str, + conn: SnowflakeConnection, + ) -> (str | None, str | None): + parsed = parse_qs(urlparse(url).query) + if "error" in parsed: + self._handle_failure( + conn=conn, + ret={ + "code": ER_OAUTH_CALLBACK_ERROR, + "message": f"Oauth callback returned an {parsed['error'][0]} error{': ' + parsed['error_description'][0] if 'error_description' in parsed else '.'}", + }, + ) + return parsed.get("code", [None])[0], parsed.get("state", [None])[0] diff --git a/src/snowflake/connector/auth/oauth_credentials.py b/src/snowflake/connector/auth/oauth_credentials.py new file mode 100644 index 0000000000..6061ead023 --- /dev/null +++ b/src/snowflake/connector/auth/oauth_credentials.py @@ -0,0 +1,64 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from ..constants import OAUTH_TYPE_CLIENT_CREDENTIALS +from ..token_cache import TokenCache +from ._oauth_base import AuthByOAuthBase + +if TYPE_CHECKING: + from .. import SnowflakeConnection + +logger = logging.getLogger(__name__) + + +class AuthByOauthCredentials(AuthByOAuthBase): + """Authenticates user by OAuth credentials - a client_id/client_secret pair.""" + + def __init__( + self, + application: str, + client_id: str, + client_secret: str, + token_request_url: str, + scope: str, + token_cache: TokenCache | None = None, + refresh_token_enabled: bool = False, + **kwargs, + ) -> None: + super().__init__( + client_id=client_id, + client_secret=client_secret, + token_request_url=token_request_url, + scope=scope, + token_cache=token_cache, + refresh_token_enabled=refresh_token_enabled, + **kwargs, + ) + self._application = application + self._origin: str | None = None + + def _get_oauth_type_id(self) -> str: + return OAUTH_TYPE_CLIENT_CREDENTIALS + + def _request_tokens( + self, + *, + conn: SnowflakeConnection, + authenticator: str, + service_name: str | None, + account: str, + user: str, + **kwargs: Any, + ) -> (str | None, str | None): + logger.debug("authenticating with OAuth client credentials flow") + fields = { + "grant_type": "client_credentials", + "scope": self._scope, + } + return self._get_request_token_response(conn, fields) diff --git a/src/snowflake/connector/auth/webbrowser.py b/src/snowflake/connector/auth/webbrowser.py index 20b92efb52..f5bddd4fcc 100644 --- a/src/snowflake/connector/auth/webbrowser.py +++ b/src/snowflake/connector/auth/webbrowser.py @@ -112,6 +112,7 @@ def prepare( """Web Browser based Authentication.""" logger.debug("authenticating by Web Browser") + # TODO: switch to the new AuthHttpServer class instead of doing this manually socket_connection = self._socket(socket.AF_INET, socket.SOCK_STREAM) if os.getenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "False").lower() == "true": diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index 84e0052a62..9103710f7a 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -2,6 +2,7 @@ from __future__ import annotations import atexit +import collections.abc import logging import os import pathlib @@ -35,6 +36,8 @@ AuthByDefault, AuthByKeyPair, AuthByOAuth, + AuthByOauthCode, + AuthByOauthCredentials, AuthByOkta, AuthByPAT, AuthByPlugin, @@ -52,6 +55,7 @@ from .constants import ( _CONNECTIVITY_ERR_MSG, _DOMAIN_NAME_MAP, + _OAUTH_DEFAULT_SCOPE, ENV_VAR_EXPERIMENTAL_AUTHENTICATION, ENV_VAR_PARTNER, PARAMETER_AUTOCOMMIT, @@ -81,6 +85,7 @@ from .direct_file_operation_utils import FileOperationParser, StreamDownloader from .errorcode import ( ER_CONNECTION_IS_CLOSED, + ER_EXPERIMENTAL_AUTHENTICATION_NOT_SUPPORTED, ER_FAILED_PROCESSING_PYFORMAT, ER_FAILED_PROCESSING_QMARK, ER_FAILED_TO_CONNECT_TO_DB, @@ -88,6 +93,7 @@ ER_INVALID_VALUE, ER_INVALID_WIF_SETTINGS, ER_NO_ACCOUNT_NAME, + ER_NO_CLIENT_ID, ER_NO_NUMPY, ER_NO_PASSWORD, ER_NO_USER, @@ -101,6 +107,8 @@ KEY_PAIR_AUTHENTICATOR, NO_AUTH_AUTHENTICATOR, OAUTH_AUTHENTICATOR, + OAUTH_AUTHORIZATION_CODE, + OAUTH_CLIENT_CREDENTIALS, PROGRAMMATIC_ACCESS_TOKEN, REQUEST_ID, USR_PWD_MFA_AUTHENTICATOR, @@ -166,13 +174,13 @@ def _get_private_bytes_from_file( "user": ("", str), # standard "password": ("", str), # standard "host": ("127.0.0.1", str), # standard - "port": (8080, (int, str)), # standard + "port": (443, (int, str)), # standard "database": (None, (type(None), str)), # standard "proxy_host": (None, (type(None), str)), # snowflake "proxy_port": (None, (type(None), str)), # snowflake "proxy_user": (None, (type(None), str)), # snowflake "proxy_password": (None, (type(None), str)), # snowflake - "protocol": ("http", str), # snowflake + "protocol": ("https", str), # snowflake "warehouse": (None, (type(None), str)), # snowflake "region": (None, (type(None), str)), # snowflake "account": (None, (type(None), str)), # snowflake @@ -185,6 +193,7 @@ def _get_private_bytes_from_file( (type(None), int), ), # network timeout (infinite by default) "socket_timeout": (None, (type(None), int)), + "external_browser_timeout": (120, int), "backoff_policy": (DEFAULT_BACKOFF_POLICY, Callable), "passcode_in_password": (False, bool), # Snowflake MFA "passcode": (None, (type(None), str)), # Snowflake MFA @@ -315,6 +324,37 @@ def _get_private_bytes_from_file( False, bool, ), # use https://{bucket}.storage.googleapis.com instead of https://storage.googleapis.com/{bucket} + "oauth_client_id": ( + None, + (type(None), str), + # SNOW-1825621: OAUTH implementation + ), + "oauth_client_secret": ( + None, + (type(None), str), + # SNOW-1825621: OAUTH implementation + ), + "oauth_authorization_url": ( + "https://{host}:{port}/oauth/authorize", + str, + # SNOW-1825621: OAUTH implementation + ), + "oauth_token_request_url": ( + "https://{host}:{port}/oauth/token-request", + str, + # SNOW-1825621: OAUTH implementation + ), + "oauth_redirect_uri": ("http://127.0.0.1/", str), + "oauth_scope": ( + "", + str, + # SNOW-1825621: OAUTH implementation + ), + "oauth_security_features": ( + ("pkce",), + collections.abc.Iterable, # of strings + # SNOW-1825621: OAUTH PKCE + ), "check_arrow_conversion_error_on_every_column": ( True, bool, @@ -552,8 +592,8 @@ def host(self) -> str: return self._host @property - def port(self) -> int | str: # TODO: shouldn't be a string - return self._port + def port(self) -> int: + return int(self._port) @property def region(self) -> str | None: @@ -806,6 +846,21 @@ def unsafe_file_write(self) -> bool: def unsafe_file_write(self, value: bool) -> None: self._unsafe_file_write = value + class _OAuthSecurityFeatures(NamedTuple): + pkce_enabled: bool + refresh_token_enabled: bool + + @property + def oauth_security_features(self) -> _OAuthSecurityFeatures: + features = self._oauth_security_features + if isinstance(features, str): + features = features.split(" ") + features = [feat.lower() for feat in features] + return self._OAuthSecurityFeatures( + pkce_enabled="pkce" in features, + refresh_token_enabled="refresh_token" in features, + ) + @property def gcs_use_virtual_endpoints(self) -> bool: return self._gcs_use_virtual_endpoints @@ -1134,7 +1189,7 @@ def __open_connection(self): self.auth_class = AuthByWebBrowser( application=self.application, protocol=self._protocol, - host=self.host, + host=self.host, # TODO: delete this? port=self.port, timeout=self.login_timeout, backoff_generator=self._backoff_generator, @@ -1170,6 +1225,56 @@ def __open_connection(self): timeout=self.login_timeout, backoff_generator=self._backoff_generator, ) + elif self._authenticator == OAUTH_AUTHORIZATION_CODE: + self._check_experimental_authentication_flag() + self._check_oauth_required_parameters() + features = self.oauth_security_features + if self._role and (self._oauth_scope == ""): + # if role is known then let's inject it into scope + self._oauth_scope = _OAUTH_DEFAULT_SCOPE.format(role=self._role) + self.auth_class = AuthByOauthCode( + application=self.application, + client_id=self._oauth_client_id, + client_secret=self._oauth_client_secret, + authentication_url=self._oauth_authorization_url.format( + host=self.host, port=self.port + ), + token_request_url=self._oauth_token_request_url.format( + host=self.host, port=self.port + ), + redirect_uri=self._oauth_redirect_uri, + scope=self._oauth_scope, + pkce_enabled=features.pkce_enabled, + token_cache=( + auth.get_token_cache() + if self._client_store_temporary_credential + else None + ), + refresh_token_enabled=features.refresh_token_enabled, + external_browser_timeout=self._external_browser_timeout, + ) + elif self._authenticator == OAUTH_CLIENT_CREDENTIALS: + self._check_experimental_authentication_flag() + self._check_oauth_required_parameters() + features = self.oauth_security_features + if self._role and (self._oauth_scope == ""): + # if role is known then let's inject it into scope + self._oauth_scope = _OAUTH_DEFAULT_SCOPE.format(role=self._role) + self.auth_class = AuthByOauthCredentials( + application=self.application, + client_id=self._oauth_client_id, + client_secret=self._oauth_client_secret, + token_request_url=self._oauth_token_request_url.format( + host=self.host, port=self.port + ), + scope=self._oauth_scope, + token_cache=( + auth.get_token_cache() + if self._client_store_temporary_credential + else None + ), + refresh_token_enabled=features.refresh_token_enabled, + ) elif self._authenticator == USR_PWD_MFA_AUTHENTICATOR: self._session_parameters[PARAMETER_CLIENT_REQUEST_MFA_TOKEN] = ( self._client_request_mfa_token if IS_LINUX else True @@ -1189,16 +1294,7 @@ def __open_connection(self): elif self._authenticator == PROGRAMMATIC_ACCESS_TOKEN: self.auth_class = AuthByPAT(self._token) elif self._authenticator == WORKLOAD_IDENTITY_AUTHENTICATOR: - if ENV_VAR_EXPERIMENTAL_AUTHENTICATION not in os.environ: - Error.errorhandler_wrapper( - self, - None, - ProgrammingError, - { - "msg": f"Please set the '{ENV_VAR_EXPERIMENTAL_AUTHENTICATION}' environment variable to use the '{WORKLOAD_IDENTITY_AUTHENTICATOR}' authenticator.", - "errno": ER_INVALID_WIF_SETTINGS, - }, - ) + self._check_experimental_authentication_flag() # Standardize the provider enum. if self._workload_identity_provider and isinstance( self._workload_identity_provider, str @@ -1311,10 +1407,6 @@ def __config(self, **kwargs): if "account" in kwargs: if "host" not in kwargs: self._host = construct_hostname(kwargs.get("region"), self._account) - if "port" not in kwargs: - self._port = "443" - if "protocol" not in kwargs: - self._protocol = "https" logger.info( f"Connecting to {_DOMAIN_NAME_MAP.get(extract_top_level_domain_from_hostname(self._host), 'GLOBAL')} Snowflake domain" @@ -1393,6 +1485,8 @@ def __config(self, **kwargs): not in ( EXTERNAL_BROWSER_AUTHENTICATOR, OAUTH_AUTHENTICATOR, + OAUTH_AUTHORIZATION_CODE, + OAUTH_CLIENT_CREDENTIALS, KEY_PAIR_AUTHENTICATOR, PROGRAMMATIC_ACCESS_TOKEN, WORKLOAD_IDENTITY_AUTHENTICATOR, @@ -1542,9 +1636,13 @@ def authenticate_with_retry(self, auth_instance) -> None: except ReauthenticationRequest as ex: # cached id_token expiration error, we have cleaned id_token and try to authenticate again logger.debug("ID token expired. Reauthenticating...: %s", ex) - if isinstance(auth_instance, AuthByIdToken): - # Note: SNOW-733835 IDToken auth needs to authenticate through - # SSO if it has expired + if type(auth_instance) in ( + AuthByIdToken, + AuthByOauthCode, + AuthByOauthCredentials, + ): + # IDToken and OAuth auth need to authenticate through + # SSO if its credential has expired self._reauthenticate() else: self._authenticate(auth_instance) @@ -2146,6 +2244,40 @@ def is_valid(self) -> bool: logger.debug("session could not be validated due to exception: %s", e) return False + def _check_experimental_authentication_flag(self) -> None: + if os.getenv(ENV_VAR_EXPERIMENTAL_AUTHENTICATION, "false").lower() != "true": + Error.errorhandler_wrapper( + self, + None, + ProgrammingError, + { + "msg": f"Please set the '{ENV_VAR_EXPERIMENTAL_AUTHENTICATION}' environment variable true to use the '{self._authenticator}' authenticator.", + "errno": ER_EXPERIMENTAL_AUTHENTICATION_NOT_SUPPORTED, + }, + ) + + def _check_oauth_required_parameters(self) -> None: + if self._oauth_client_id is None: + Error.errorhandler_wrapper( + self, + None, + ProgrammingError, + { + "msg": "Oauth code flow requirement 'client_id' is empty", + "errno": ER_NO_CLIENT_ID, + }, + ) + if self._oauth_client_secret is None: + Error.errorhandler_wrapper( + self, + None, + ProgrammingError, + { + "msg": "Oauth code flow requirement 'client_secret' is empty", + "errno": ER_NO_CLIENT_ID, + }, + ) + @staticmethod def _detect_application() -> None | str: if ENV_VAR_PARTNER in os.environ.keys(): diff --git a/src/snowflake/connector/constants.py b/src/snowflake/connector/constants.py index 085ec7a2b3..739fcd3fcc 100644 --- a/src/snowflake/connector/constants.py +++ b/src/snowflake/connector/constants.py @@ -321,7 +321,7 @@ class FileHeader(NamedTuple): PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL = "CLIENT_STORE_TEMPORARY_CREDENTIAL" PARAMETER_CLIENT_REQUEST_MFA_TOKEN = "CLIENT_REQUEST_MFA_TOKEN" PARAMETER_CLIENT_USE_SECURE_STORAGE_FOR_TEMPORARY_CREDENTIAL = ( - "CLIENT_USE_SECURE_STORAGE_FOR_TEMPORARY_CREDENTAIL" + "CLIENT_USE_SECURE_STORAGE_FOR_TEMPORARY_CREDENTIAL" ) PARAMETER_QUERY_CONTEXT_CACHE_SIZE = "QUERY_CONTEXT_CACHE_SIZE" PARAMETER_TIMEZONE = "TIMEZONE" @@ -436,3 +436,7 @@ class IterUnit(Enum): "\nTo further troubleshoot your connection you may reference the following article: " "https://docs.snowflake.com/en/user-guide/client-connectivity-troubleshooting/overview." ) + +_OAUTH_DEFAULT_SCOPE = "session:role:{role}" +OAUTH_TYPE_AUTHORIZATION_CODE = "authorization_code" +OAUTH_TYPE_CLIENT_CREDENTIALS = "client_credentials" diff --git a/src/snowflake/connector/errorcode.py b/src/snowflake/connector/errorcode.py index 1bc9138df2..0a0dbe0a45 100644 --- a/src/snowflake/connector/errorcode.py +++ b/src/snowflake/connector/errorcode.py @@ -27,8 +27,13 @@ ER_JWT_RETRY_EXPIRED = 251010 ER_CONNECTION_TIMEOUT = 251011 ER_RETRYABLE_CODE = 251012 -ER_INVALID_WIF_SETTINGS = 251013 -ER_WIF_CREDENTIALS_NOT_FOUND = 251014 +ER_NO_CLIENT_ID = 251013 +ER_OAUTH_STATE_CHANGED = 251014 +ER_OAUTH_CALLBACK_ERROR = 251015 +ER_OAUTH_SERVER_TIMEOUT = 251016 +ER_INVALID_WIF_SETTINGS = 251017 +ER_WIF_CREDENTIALS_NOT_FOUND = 251018 +ER_EXPERIMENTAL_AUTHENTICATION_NOT_SUPPORTED = 251019 # cursor ER_FAILED_TO_REWRITE_MULTI_ROW_INSERT = 252001 diff --git a/src/snowflake/connector/file_lock.py b/src/snowflake/connector/file_lock.py new file mode 100644 index 0000000000..dd3bc85ab9 --- /dev/null +++ b/src/snowflake/connector/file_lock.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import logging +import time +from os import stat_result +from pathlib import Path +from time import sleep + +MAX_RETRIES = 5 +INITIAL_BACKOFF_SECONDS = 0.025 +STALE_LOCK_AGE_SECONDS = 1 + + +class FileLockError(Exception): + pass + + +class FileLock: + def __init__(self, path: Path) -> None: + self.path: Path = path + self.locked = False + self.logger = logging.getLogger(__name__) + + def __enter__(self): + statinfo: stat_result | None = None + try: + statinfo = self.path.stat() + except FileNotFoundError: + pass + except OSError as e: + raise FileLockError(f"Failed to stat lock file {self.path} due to {e=}") + + if statinfo and statinfo.st_ctime < time.time() - STALE_LOCK_AGE_SECONDS: + self.logger.debug("Removing stale file lock") + try: + self.path.rmdir() + except FileNotFoundError: + pass + except OSError as e: + raise FileLockError( + f"Failed to remove stale lock file {self.path} due to {e=}" + ) + + backoff_seconds = INITIAL_BACKOFF_SECONDS + for attempt in range(MAX_RETRIES): + self.logger.debug( + f"Trying to acquire file lock after {backoff_seconds} seconds in attempt number {attempt}.", + ) + backoff_seconds = backoff_seconds * 2 + try: + self.path.mkdir(mode=0o700) + self.locked = True + break + except FileExistsError: + sleep(backoff_seconds) + continue + except OSError as e: + raise FileLockError( + f"Failed to acquire lock file {self.path} due to {e=}" + ) + + if not self.locked: + raise FileLockError( + f"Failed to acquire file lock, after {MAX_RETRIES} attempts." + ) + + def __exit__(self, exc_type, exc_val, exc_tbc): + try: + self.path.rmdir() + except FileNotFoundError: + pass + self.locked = False diff --git a/src/snowflake/connector/network.py b/src/snowflake/connector/network.py index adffc4b6b9..acfe14c589 100644 --- a/src/snowflake/connector/network.py +++ b/src/snowflake/connector/network.py @@ -138,6 +138,7 @@ MASTER_TOKEN_INVALD_GS_CODE = "390115" ID_TOKEN_INVALID_LOGIN_REQUEST_GS_CODE = "390195" BAD_REQUEST_GS_CODE = "390400" +OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE = "390318" # other constants CONTENT_TYPE_APPLICATION_JSON = "application/json" @@ -181,6 +182,8 @@ EXTERNAL_BROWSER_AUTHENTICATOR = "EXTERNALBROWSER" KEY_PAIR_AUTHENTICATOR = "SNOWFLAKE_JWT" OAUTH_AUTHENTICATOR = "OAUTH" +OAUTH_AUTHORIZATION_CODE = "OAUTH_AUTHORIZATION_CODE" +OAUTH_CLIENT_CREDENTIALS = "OAUTH_CLIENT_CREDENTIALS" ID_TOKEN_AUTHENTICATOR = "ID_TOKEN" USR_PWD_MFA_AUTHENTICATOR = "USERNAME_PASSWORD_MFA" PROGRAMMATIC_ACCESS_TOKEN = "PROGRAMMATIC_ACCESS_TOKEN" diff --git a/src/snowflake/connector/token_cache.py b/src/snowflake/connector/token_cache.py index 40a55f9e8b..a5ace1f6a8 100644 --- a/src/snowflake/connector/token_cache.py +++ b/src/snowflake/connector/token_cache.py @@ -1,22 +1,24 @@ from __future__ import annotations import codecs +import hashlib import json import logging -import tempfile -import time +import os +import stat +import sys from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from os import getenv, makedirs, mkdir, path, remove, removedirs, rmdir -from os.path import expanduser -from threading import Lock +from pathlib import Path +from typing import Any, TypeVar from .compat import IS_LINUX, IS_MACOS, IS_WINDOWS -from .file_util import owner_rw_opener +from .file_lock import FileLock, FileLockError from .options import installed_keyring, keyring -KEYRING_DRIVER_NAME = "SNOWFLAKE-PYTHON-DRIVER" +logger = logging.getLogger(__name__) +T = TypeVar("T") class TokenType(Enum): @@ -26,46 +28,65 @@ class TokenType(Enum): OAUTH_REFRESH_TOKEN = "OAUTH_REFRESH_TOKEN" +class _InvalidTokenKeyError(Exception): + pass + + @dataclass class TokenKey: user: str host: str tokenType: TokenType + def string_key(self) -> str: + if len(self.host) == 0: + raise _InvalidTokenKeyError("Invalid key, host is empty") + if len(self.user) == 0: + raise _InvalidTokenKeyError("Invalid key, user is empty") + return f"{self.host.upper()}:{self.user.upper()}:{self.tokenType.value}" -class TokenCache(ABC): - def build_temporary_credential_name( - self, host: str, user: str, cred_type: TokenType - ) -> str: - return "{host}:{user}:{driver}:{cred}".format( - host=host.upper(), - user=user.upper(), - driver=KEYRING_DRIVER_NAME, - cred=cred_type.value, - ) + def hash_key(self) -> str: + m = hashlib.sha256() + m.update(self.string_key().encode(encoding="utf-8")) + return m.hexdigest() + + +def _warn(warning: str) -> None: + logger.warning(warning) + print("Warning: " + warning, file=sys.stderr) + +class TokenCache(ABC): @staticmethod def make() -> TokenCache: if IS_MACOS or IS_WINDOWS: if not installed_keyring: - logging.getLogger(__name__).debug( + _warn( "Dependency 'keyring' is not installed, cannot cache id token. You might experience " - "multiple authentication pop ups while using ExternalBrowser Authenticator. To avoid " - "this please install keyring module using the following command : pip install " - "snowflake-connector-python[secure-local-storage]" + "multiple authentication pop ups while using ExternalBrowser/OAuth/MFA Authenticator. To avoid " + "this please install keyring module using the following command:\n" + " pip install snowflake-connector-python[secure-local-storage]" ) return NoopTokenCache() return KeyringTokenCache() if IS_LINUX: - return FileTokenCache() + cache = FileTokenCache.make() + if cache: + return cache + else: + _warn( + "Failed to initialize file based token cache. You might experience " + "multiple authentication pop ups while using ExternalBrowser/OAuth/MFA Authenticator." + ) + return NoopTokenCache() @abstractmethod def store(self, key: TokenKey, token: str) -> None: pass @abstractmethod - def retrieve(self, key: TokenKey) -> str: + def retrieve(self, key: TokenKey) -> str | None: pass @abstractmethod @@ -73,196 +94,255 @@ def remove(self, key: TokenKey) -> None: pass +class _FileTokenCacheError(Exception): + pass + + +class _OwnershipError(_FileTokenCacheError): + pass + + +class _PermissionsTooWideError(_FileTokenCacheError): + pass + + +class _CacheDirNotFoundError(_FileTokenCacheError): + pass + + +class _InvalidCacheDirError(_FileTokenCacheError): + pass + + +class _MalformedCacheFileError(_FileTokenCacheError): + pass + + +class _CacheFileReadError(_FileTokenCacheError): + pass + + +class _CacheFileWriteError(_FileTokenCacheError): + pass + + class FileTokenCache(TokenCache): + @staticmethod + def make() -> FileTokenCache | None: + cache_dir = FileTokenCache.find_cache_dir() + if cache_dir is None: + logging.getLogger(__name__).debug( + "Failed to find suitable cache directory for token cache. File based token cache initialization failed." + ) + return None + else: + return FileTokenCache(cache_dir) - def __init__(self): + def __init__(self, cache_dir: Path) -> None: self.logger = logging.getLogger(__name__) - self.CACHE_ROOT_DIR = ( - getenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR") - or expanduser("~") - or tempfile.gettempdir() - ) - self.CACHE_DIR = path.join(self.CACHE_ROOT_DIR, ".cache", "snowflake") - - if not path.exists(self.CACHE_DIR): - try: - makedirs(self.CACHE_DIR, mode=0o700) - except Exception as ex: - self.logger.debug( - "cannot create a cache directory: [%s], err=[%s]", - self.CACHE_DIR, - ex, - ) - self.CACHE_DIR = None - self.logger.debug("cache directory: %s", self.CACHE_DIR) - - # temporary credential cache - self.TEMPORARY_CREDENTIAL: dict[str, dict[str, str | None]] = {} - - self.TEMPORARY_CREDENTIAL_LOCK = Lock() - - # temporary credential cache file name - self.TEMPORARY_CREDENTIAL_FILE = "temporary_credential.json" - self.TEMPORARY_CREDENTIAL_FILE = ( - path.join(self.CACHE_DIR, self.TEMPORARY_CREDENTIAL_FILE) - if self.CACHE_DIR - else "" - ) - - # temporary credential cache lock directory name - self.TEMPORARY_CREDENTIAL_FILE_LOCK = self.TEMPORARY_CREDENTIAL_FILE + ".lck" - - def flush_temporary_credentials(self) -> None: - """Flush temporary credentials in memory into disk. Need to hold TEMPORARY_CREDENTIAL_LOCK.""" - for _ in range(10): - if self.lock_temporary_credential_file(): - break - time.sleep(1) - else: - self.logger.debug( - "The lock file still persists after the maximum wait time." - "Will ignore it and write temporary credential file: %s", - self.TEMPORARY_CREDENTIAL_FILE, - ) + self.cache_dir: Path = cache_dir + + def store(self, key: TokenKey, token: str) -> None: try: - with open( - self.TEMPORARY_CREDENTIAL_FILE, - "w", - encoding="utf-8", - errors="ignore", - opener=owner_rw_opener, - ) as f: - json.dump(self.TEMPORARY_CREDENTIAL, f) - except Exception as ex: - self.logger.debug( - "Failed to write a credential file: " "file=[%s], err=[%s]", - self.TEMPORARY_CREDENTIAL_FILE, - ex, + FileTokenCache.validate_cache_dir(self.cache_dir) + with FileLock(self.lock_file()): + cache = self._read_cache_file() + cache["tokens"][key.hash_key()] = token + self._write_cache_file(cache) + except _FileTokenCacheError as e: + self.logger.error(f"Failed to store token: {e=}") + except FileLockError as e: + self.logger.error(f"Unable to lock file lock: {e=}") + except _InvalidTokenKeyError as e: + self.logger.error(f"Failed to produce token key {e=}") + + def retrieve(self, key: TokenKey) -> str | None: + try: + FileTokenCache.validate_cache_dir(self.cache_dir) + with FileLock(self.lock_file()): + cache = self._read_cache_file() + token = cache["tokens"].get(key.hash_key(), None) + if isinstance(token, str): + return token + else: + return None + except _FileTokenCacheError as e: + self.logger.error(f"Failed to retrieve token: {e=}") + return None + except FileLockError as e: + self.logger.error(f"Unable to lock file lock: {e=}") + return None + except _InvalidTokenKeyError as e: + self.logger.error(f"Failed to produce token key {e=}") + return None + + def remove(self, key: TokenKey) -> None: + try: + FileTokenCache.validate_cache_dir(self.cache_dir) + with FileLock(self.lock_file()): + cache = self._read_cache_file() + cache["tokens"].pop(key.hash_key(), None) + self._write_cache_file(cache) + except _FileTokenCacheError as e: + self.logger.error(f"Failed to remove token: {e=}") + except FileLockError as e: + self.logger.error(f"Unable to lock file lock: {e=}") + except _InvalidTokenKeyError as e: + self.logger.error(f"Failed to produce token key {e=}") + + def cache_file(self) -> Path: + return self.cache_dir / "credential_cache_v1.json" + + def lock_file(self) -> Path: + return self.cache_dir / "credential_cache_v1.json.lck" + + def _read_cache_file(self) -> dict[str, dict[str, Any]]: + fd = -1 + json_data = {"tokens": {}} + try: + fd = os.open(self.cache_file(), os.O_RDONLY) + self._ensure_permissions(fd, 0o600) + size = os.lseek(fd, 0, os.SEEK_END) + os.lseek(fd, 0, os.SEEK_SET) + data = os.read(fd, size) + json_data = json.loads(codecs.decode(data, "utf-8")) + except FileNotFoundError: + self.logger.debug(f"{self.cache_file()} not found") + except json.decoder.JSONDecodeError as e: + self.logger.warning( + f"Failed to decode json read from cache file {self.cache_file()}: {e.__class__.__name__}" + ) + except UnicodeError as e: + self.logger.warning( + f"Failed to decode utf-8 read from cache file {self.cache_file()}: {e.__class__.__name__}" ) + except OSError as e: + self.logger.warning(f"Failed to read cache file {self.cache_file()}: {e}") finally: - self.unlock_temporary_credential_file() + if fd > 0: + os.close(fd) - def lock_temporary_credential_file(self) -> bool: + if "tokens" not in json_data or not isinstance(json_data["tokens"], dict): + json_data["tokens"] = {} + + return json_data + + def _write_cache_file(self, json_data: dict): + fd = -1 + self.logger.debug(f"Writing cache file {self.cache_file()}") try: - mkdir(self.TEMPORARY_CREDENTIAL_FILE_LOCK) - return True - except OSError: - self.logger.debug( - "Temporary cache file lock already exists. Other " - "process may be updating the temporary " + fd = os.open( + self.cache_file(), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600 ) - return False + self._ensure_permissions(fd, 0o600) + os.write(fd, codecs.encode(json.dumps(json_data), "utf-8")) + return json_data + except OSError as e: + raise _CacheFileWriteError("Failed to write cache file", e) + finally: + if fd > 0: + os.close(fd) - def unlock_temporary_credential_file(self) -> bool: - try: - rmdir(self.TEMPORARY_CREDENTIAL_FILE_LOCK) - return True - except OSError: - self.logger.debug("Temporary cache file lock no longer exists.") - return False - - def write_temporary_credential_file( - self, host: str, cred_name: str, cred: str - ) -> None: - """Writes temporary credential file when OS is Linux.""" - if not self.CACHE_DIR: - # no cache is enabled - return - with self.TEMPORARY_CREDENTIAL_LOCK: - # update the cache - host_data = self.TEMPORARY_CREDENTIAL.get(host.upper(), {}) - host_data[cred_name.upper()] = cred - self.TEMPORARY_CREDENTIAL[host.upper()] = host_data - self.flush_temporary_credentials() - - def read_temporary_credential_file(self): - """Reads temporary credential file when OS is Linux.""" - if not self.CACHE_DIR: - # no cache is enabled - return - - with self.TEMPORARY_CREDENTIAL_LOCK: - for _ in range(10): - if self.lock_temporary_credential_file(): - break - time.sleep(1) - else: - self.logger.debug( - "The lock file still persists. Will ignore and " - "write the temporary credential file: %s", - self.TEMPORARY_CREDENTIAL_FILE, + @staticmethod + def find_cache_dir() -> Path | None: + def lookup_env_dir(env_var: str, subpath_segments: list[str]) -> Path | None: + env_val = os.getenv(env_var) + if env_val is None: + logger.debug( + f"Environment variable {env_var} not set. Skipping it in cache directory lookup." ) + return None + + directory = Path(env_val) + + if len(subpath_segments) > 0: + if not directory.exists(): + logger.debug( + f"Path {str(directory)} does not exist. Skipping it in cache directory lookup." + ) + return None + + if not directory.is_dir(): + logger.debug( + f"Path {str(directory)} is not a directory. Skipping it in cache directory lookup." + ) + return None + + for subpath in subpath_segments[:-1]: + directory = directory / subpath + directory.mkdir(exist_ok=True, mode=0o755) + + directory = directory / subpath_segments[-1] + directory.mkdir(exist_ok=True, mode=0o700) + try: - with codecs.open( - self.TEMPORARY_CREDENTIAL_FILE, - "r", - encoding="utf-8", - errors="ignore", - ) as f: - self.TEMPORARY_CREDENTIAL = json.load(f) - return self.TEMPORARY_CREDENTIAL - except Exception as ex: - self.logger.debug( - "Failed to read a credential file. The file may not" - "exists: file=[%s], err=[%s]", - self.TEMPORARY_CREDENTIAL_FILE, - ex, + FileTokenCache.validate_cache_dir(directory) + return directory + except _FileTokenCacheError as e: + logger.debug( + f"Cache directory validation failed for {str(directory)} due to error '{e}'. Skipping it in cache directory lookup." ) - finally: - self.unlock_temporary_credential_file() - - def temporary_credential_file_delete_password( - self, host: str, user: str, cred_type: TokenType - ) -> None: - """Remove credential from temporary credential file when OS is Linux.""" - if not self.CACHE_DIR: - # no cache is enabled - return - with self.TEMPORARY_CREDENTIAL_LOCK: - # update the cache - host_data = self.TEMPORARY_CREDENTIAL.get(host.upper(), {}) - host_data.pop( - self.build_temporary_credential_name(host, user, cred_type), None - ) - if not host_data: - self.TEMPORARY_CREDENTIAL.pop(host.upper(), None) - else: - self.TEMPORARY_CREDENTIAL[host.upper()] = host_data - self.flush_temporary_credentials() + return None + + lookup_functions = [ + lambda: lookup_env_dir("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", []), + lambda: lookup_env_dir("XDG_CACHE_HOME", ["snowflake"]), + lambda: lookup_env_dir("HOME", [".cache", "snowflake"]), + ] - def delete_temporary_credential_file(self) -> None: - """Deletes temporary credential file and its lock file.""" + for lf in lookup_functions: + cache_dir = lf() + if cache_dir: + return cache_dir + + return None + + @staticmethod + def validate_cache_dir(cache_dir: Path | None) -> None: try: - remove(self.TEMPORARY_CREDENTIAL_FILE) - except Exception as ex: - self.logger.debug( - "Failed to delete a credential file: " "file=[%s], err=[%s]", - self.TEMPORARY_CREDENTIAL_FILE, - ex, + statinfo = cache_dir.stat() + + if cache_dir is None: + raise _CacheDirNotFoundError("Cache dir was not found") + + if not stat.S_ISDIR(statinfo.st_mode): + raise _InvalidCacheDirError(f"Cache dir {cache_dir} is not a directory") + + permissions = stat.S_IMODE(statinfo.st_mode) + if permissions != 0o700: + raise _PermissionsTooWideError( + f"Cache dir {cache_dir} has incorrect permissions. {permissions:o} != 0700" + ) + + euid = os.geteuid() + if statinfo.st_uid != euid: + raise _OwnershipError( + f"Cache dir {cache_dir} has incorrect owner. {euid} != {statinfo.st_uid}" + ) + + except FileNotFoundError: + raise _CacheDirNotFoundError( + f"Cache dir {cache_dir} was not found. Failed to stat." ) + + def _ensure_permissions(self, fd: int, permissions: int) -> None: try: - removedirs(self.TEMPORARY_CREDENTIAL_FILE_LOCK) - except Exception as ex: - self.logger.debug("Failed to delete credential lock file: err=[%s]", ex) + statinfo = os.fstat(fd) + actual_permissions = stat.S_IMODE(statinfo.st_mode) - def store(self, key: TokenKey, token: str) -> None: - return self.write_temporary_credential_file( - key.host, - self.build_temporary_credential_name(key.host, key.user, key.tokenType), - token, - ) - - def retrieve(self, key: TokenKey) -> str: - self.read_temporary_credential_file() - token = self.TEMPORARY_CREDENTIAL.get(key.host.upper(), {}).get( - self.build_temporary_credential_name(key.host, key.user, key.tokenType) - ) - return token + if actual_permissions != permissions: + raise _PermissionsTooWideError( + f"Cache file {self.cache_file()} has incorrect permissions. {permissions:o} != {actual_permissions:o}" + ) - def remove(self, key: TokenKey) -> None: - return self.temporary_credential_file_delete_password( - key.host, key.user, key.tokenType - ) + euid = os.geteuid() + if statinfo.st_uid != euid: + raise _OwnershipError( + f"Cache file {self.cache_file()} has incorrect owner. {euid} != {statinfo.st_uid}" + ) + + except FileNotFoundError: + pass class KeyringTokenCache(TokenCache): @@ -272,17 +352,19 @@ def __init__(self) -> None: def store(self, key: TokenKey, token: str) -> None: try: keyring.set_password( - self.build_temporary_credential_name(key.host, key.user, key.tokenType), + key.string_key(), key.user.upper(), token, ) + except _InvalidTokenKeyError as e: + self.logger.error(f"Could not retrieve {key.tokenType} from keyring, {e=}") except keyring.errors.KeyringError as ke: self.logger.error("Could not store id_token to keyring, %s", str(ke)) - def retrieve(self, key: TokenKey) -> str: + def retrieve(self, key: TokenKey) -> str | None: try: return keyring.get_password( - self.build_temporary_credential_name(key.host, key.user, key.tokenType), + key.string_key(), key.user.upper(), ) except keyring.errors.KeyringError as ke: @@ -291,13 +373,17 @@ def retrieve(self, key: TokenKey) -> str: key.tokenType.value, str(ke) ) ) + except _InvalidTokenKeyError as e: + self.logger.error(f"Could not retrieve {key.tokenType} from keyring, {e=}") def remove(self, key: TokenKey) -> None: try: keyring.delete_password( - self.build_temporary_credential_name(key.host, key.user, key.tokenType), + key.string_key(), key.user.upper(), ) + except _InvalidTokenKeyError as e: + self.logger.error(f"Could not retrieve {key.tokenType} from keyring, {e=}") except Exception as ex: self.logger.error( "Failed to delete credential in the keyring: err=[%s]", ex diff --git a/src/snowflake/connector/vendored/requests/__init__.py b/src/snowflake/connector/vendored/requests/__init__.py index 03c3f69d31..f3d57da6de 100644 --- a/src/snowflake/connector/vendored/requests/__init__.py +++ b/src/snowflake/connector/vendored/requests/__init__.py @@ -41,7 +41,6 @@ import warnings from .. import urllib3 - from .exceptions import RequestsDependencyWarning try: diff --git a/src/snowflake/connector/vendored/requests/adapters.py b/src/snowflake/connector/vendored/requests/adapters.py index ab92194fb5..0c14ac32fd 100644 --- a/src/snowflake/connector/vendored/requests/adapters.py +++ b/src/snowflake/connector/vendored/requests/adapters.py @@ -25,7 +25,6 @@ from ..urllib3.util import Timeout as TimeoutSauce from ..urllib3.util import parse_url from ..urllib3.util.retry import Retry - from .auth import _basic_auth_str from .compat import basestring, urlparse from .cookies import extract_cookies_to_jar diff --git a/src/snowflake/connector/vendored/requests/exceptions.py b/src/snowflake/connector/vendored/requests/exceptions.py index 5efb9c99e1..2ee5d1cfcd 100644 --- a/src/snowflake/connector/vendored/requests/exceptions.py +++ b/src/snowflake/connector/vendored/requests/exceptions.py @@ -5,7 +5,6 @@ This module contains the set of Requests' exceptions. """ from ..urllib3.exceptions import HTTPError as BaseHTTPError - from .compat import JSONDecodeError as CompatJSONDecodeError diff --git a/src/snowflake/connector/vendored/requests/help.py b/src/snowflake/connector/vendored/requests/help.py index fc3d1daef5..85f091e3b0 100644 --- a/src/snowflake/connector/vendored/requests/help.py +++ b/src/snowflake/connector/vendored/requests/help.py @@ -6,8 +6,8 @@ import sys import idna -from .. import urllib3 +from .. import urllib3 from . import __version__ as requests_version try: diff --git a/src/snowflake/connector/vendored/requests/models.py b/src/snowflake/connector/vendored/requests/models.py index bc73aabc52..e88d2a1904 100644 --- a/src/snowflake/connector/vendored/requests/models.py +++ b/src/snowflake/connector/vendored/requests/models.py @@ -23,7 +23,6 @@ from ..urllib3.fields import RequestField from ..urllib3.filepost import encode_multipart_formdata from ..urllib3.util import parse_url - from ._internal_utils import to_native_string, unicode_is_ascii from .auth import HTTPBasicAuth from .compat import ( diff --git a/src/snowflake/connector/vendored/requests/utils.py b/src/snowflake/connector/vendored/requests/utils.py index 1da5e1c34a..e90f96cc81 100644 --- a/src/snowflake/connector/vendored/requests/utils.py +++ b/src/snowflake/connector/vendored/requests/utils.py @@ -20,7 +20,6 @@ from collections import OrderedDict from ..urllib3.util import make_headers, parse_url - from . import certs from .__version__ import __version__ diff --git a/test/auth/__init__.py b/test/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/auth/authorization_parameters.py b/test/auth/authorization_parameters.py new file mode 100644 index 0000000000..fe33ee8ea5 --- /dev/null +++ b/test/auth/authorization_parameters.py @@ -0,0 +1,218 @@ +import os +import sys +from typing import Union + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization + +sys.path.append(os.path.abspath(os.path.dirname(__file__))) + + +def get_oauth_token_parameters() -> dict[str, str]: + return { + "auth_url": _get_env_variable("SNOWFLAKE_AUTH_TEST_OAUTH_URL"), + "oauth_client_id": _get_env_variable("SNOWFLAKE_AUTH_TEST_OAUTH_CLIENT_ID"), + "oauth_client_secret": _get_env_variable( + "SNOWFLAKE_AUTH_TEST_OAUTH_CLIENT_SECRET" + ), + "okta_user": _get_env_variable("SNOWFLAKE_AUTH_TEST_OKTA_USER"), + "okta_pass": _get_env_variable("SNOWFLAKE_AUTH_TEST_OKTA_PASS"), + "role": (_get_env_variable("SNOWFLAKE_AUTH_TEST_ROLE")).lower(), + } + + +def _get_env_variable(name: str, required: bool = True) -> str: + value = os.getenv(name) + if required and value is None: + raise OSError(f"Environment variable {name} is not set") + return value + + +def get_okta_login_credentials() -> dict[str, str]: + return { + "login": _get_env_variable("SNOWFLAKE_AUTH_TEST_OKTA_USER"), + "password": _get_env_variable("SNOWFLAKE_AUTH_TEST_OKTA_PASS"), + } + + +def get_soteria_okta_login_credentials() -> dict[str, str]: + return { + "login": _get_env_variable("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID"), + "password": _get_env_variable( + "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_USER_PASSWORD" + ), + } + + +def get_rsa_private_key_for_key_pair( + key_path: str, +) -> serialization.load_pem_private_key: + with open(_get_env_variable(key_path), "rb") as key_file: + private_key = serialization.load_pem_private_key( + key_file.read(), password=None, backend=default_backend() + ) + return private_key + + +def get_pat_setup_command_variables() -> dict[str, Union[str, bool, int]]: + return { + "snowflake_user": _get_env_variable("SNOWFLAKE_AUTH_TEST_SNOWFLAKE_USER"), + "role": _get_env_variable("SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_ROLE"), + } + + +class AuthConnectionParameters: + def __init__(self): + self.basic_config = { + "host": _get_env_variable("SNOWFLAKE_AUTH_TEST_HOST"), + "port": _get_env_variable("SNOWFLAKE_AUTH_TEST_PORT"), + "role": _get_env_variable("SNOWFLAKE_AUTH_TEST_ROLE"), + "account": _get_env_variable("SNOWFLAKE_AUTH_TEST_ACCOUNT"), + "db": _get_env_variable("SNOWFLAKE_AUTH_TEST_DATABASE"), + "schema": _get_env_variable("SNOWFLAKE_AUTH_TEST_SCHEMA"), + "warehouse": _get_env_variable("SNOWFLAKE_AUTH_TEST_WAREHOUSE"), + "CLIENT_STORE_TEMPORARY_CREDENTIAL": False, + } + + def get_base_connection_parameters(self) -> dict[str, Union[str, bool, int]]: + return self.basic_config + + def get_key_pair_connection_parameters(self): + config = self.basic_config.copy() + config["authenticator"] = "KEY_PAIR_AUTHENTICATOR" + config["user"] = _get_env_variable("SNOWFLAKE_AUTH_TEST_BROWSER_USER") + + return config + + def get_external_browser_connection_parameters(self) -> dict[str, str]: + config = self.basic_config.copy() + + config["user"] = _get_env_variable("SNOWFLAKE_AUTH_TEST_BROWSER_USER") + config["authenticator"] = "externalbrowser" + + return config + + def get_store_id_token_connection_parameters(self) -> dict[str, str]: + config = self.get_external_browser_connection_parameters() + + config["CLIENT_STORE_TEMPORARY_CREDENTIAL"] = _get_env_variable( + "SNOWFLAKE_AUTH_TEST_STORE_ID_TOKEN_USER" + ) + + return config + + def get_okta_connection_parameters(self) -> dict[str, str]: + config = self.basic_config.copy() + + config["user"] = _get_env_variable("SNOWFLAKE_AUTH_TEST_BROWSER_USER") + config["password"] = _get_env_variable("SNOWFLAKE_AUTH_TEST_OKTA_PASS") + config["authenticator"] = _get_env_variable("SNOWFLAKE_AUTH_TEST_OAUTH_URL") + + return config + + def get_oauth_connection_parameters(self, token: str) -> dict[str, str]: + config = self.basic_config.copy() + + config["user"] = _get_env_variable("SNOWFLAKE_AUTH_TEST_BROWSER_USER") + config["authenticator"] = "OAUTH" + config["token"] = token + return config + + def get_oauth_external_authorization_code_connection_parameters( + self, + ) -> dict[str, Union[str, bool, int]]: + config = self.basic_config.copy() + + config["authenticator"] = "OAUTH_AUTHORIZATION_CODE" + config["oauth_client_id"] = _get_env_variable( + "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID" + ) + config["oauth_client_secret"] = _get_env_variable( + "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_SECRET" + ) + config["oauth_redirect_uri"] = _get_env_variable( + "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_REDIRECT_URI" + ) + config["oauth_authorization_url"] = _get_env_variable( + "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_AUTH_URL" + ) + config["oauth_token_request_url"] = _get_env_variable( + "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_TOKEN" + ) + config["user"] = _get_env_variable("SNOWFLAKE_AUTH_TEST_BROWSER_USER") + + return config + + def get_snowflake_authorization_code_connection_parameters( + self, + ) -> dict[str, Union[str, bool, int]]: + config = self.basic_config.copy() + + config["authenticator"] = "OAUTH_AUTHORIZATION_CODE" + config["oauth_client_id"] = _get_env_variable( + "SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_CLIENT_ID" + ) + config["oauth_client_secret"] = _get_env_variable( + "SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_CLIENT_SECRET" + ) + config["oauth_redirect_uri"] = _get_env_variable( + "SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_REDIRECT_URI" + ) + config["role"] = _get_env_variable( + "SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_ROLE" + ) + config["user"] = _get_env_variable( + "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID" + ) + + return config + + def get_snowflake_wildcard_external_authorization_code_connection_parameters( + self, + ) -> dict[str, Union[str, bool, int]]: + config = self.basic_config.copy() + + config["authenticator"] = "OAUTH_AUTHORIZATION_CODE" + config["oauth_client_id"] = _get_env_variable( + "SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_WILDCARDS_CLIENT_ID" + ) + config["oauth_client_secret"] = _get_env_variable( + "SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_WILDCARDS_CLIENT_SECRET" + ) + config["role"] = _get_env_variable( + "SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_ROLE" + ) + config["user"] = _get_env_variable( + "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID" + ) + + return config + + def get_oauth_external_client_credential_connection_parameters( + self, + ) -> dict[str, str]: + config = self.basic_config.copy() + + config["authenticator"] = "OAUTH_CLIENT_CREDENTIALS" + config["oauth_client_id"] = _get_env_variable( + "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID" + ) + config["oauth_client_secret"] = _get_env_variable( + "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_SECRET" + ) + config["oauth_token_request_url"] = _get_env_variable( + "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_TOKEN" + ) + config["user"] = _get_env_variable( + "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID" + ) + + return config + + def get_pat_connection_parameters(self) -> dict[str, str]: + config = self.basic_config.copy() + + config["authenticator"] = "PROGRAMMATIC_ACCESS_TOKEN" + config["user"] = _get_env_variable("SNOWFLAKE_AUTH_TEST_BROWSER_USER") + + return config diff --git a/test/auth/authorization_test_helper.py b/test/auth/authorization_test_helper.py new file mode 100644 index 0000000000..0d3148be0d --- /dev/null +++ b/test/auth/authorization_test_helper.py @@ -0,0 +1,144 @@ +import logging.config +import os +import subprocess +import threading +import webbrowser +from enum import Enum +from typing import Union + +import requests + +import snowflake.connector + +try: + from src.snowflake.connector.vendored.requests.auth import HTTPBasicAuth +except ImportError: + pass + +logger = logging.getLogger(__name__) + +logger.setLevel(logging.INFO) + + +class Scenario(Enum): + SUCCESS = "success" + FAIL = "fail" + TIMEOUT = "timeout" + EXTERNAL_OAUTH_OKTA_SUCCESS = "externalOauthOktaSuccess" + INTERNAL_OAUTH_SNOWFLAKE_SUCCESS = "internalOauthSnowflakeSuccess" + + +def get_access_token_oauth(cfg): + auth_url = cfg["auth_url"] + + data = { + "username": cfg["okta_user"], + "password": cfg["okta_pass"], + "grant_type": "password", + "scope": f"session:role:{cfg['role']}", + } + + headers = {"Content-Type": "application/x-www-form-urlencoded;charset=UTF-8"} + + auth_credentials = HTTPBasicAuth(cfg["oauth_client_id"], cfg["oauth_client_secret"]) + try: + response = requests.post( + url=auth_url, data=data, headers=headers, auth=auth_credentials + ) + response.raise_for_status() + return response.json()["access_token"] + + except requests.exceptions.HTTPError as http_err: + logger.error(f"HTTP error occurred: {http_err}") + raise + + +def clean_browser_processes(): + if os.getenv("AUTHENTICATION_TESTS_ENV") == "docker": + try: + clean_browser_processes_path = "/externalbrowser/cleanBrowserProcesses.js" + process = subprocess.run(["node", clean_browser_processes_path], timeout=15) + logger.debug(f"OUTPUT: {process.stdout}, ERRORS: {process.stderr}") + except Exception as e: + raise RuntimeError(e) + + +class AuthorizationTestHelper: + def __init__(self, configuration: dict): + self.auth_test_env = os.getenv("AUTHENTICATION_TESTS_ENV") + self.configuration = configuration + self.error_msg = "" + + def update_config(self, configuration): + self.configuration = configuration + + def connect_and_provide_credentials( + self, scenario: Scenario, login: str, password: str + ): + try: + connect = threading.Thread(target=self.connect_and_execute_simple_query) + connect.start() + if self.auth_test_env == "docker": + browser = threading.Thread( + target=self._provide_credentials, args=(scenario, login, password) + ) + browser.start() + browser.join() + connect.join() + + except Exception as e: + self.error_msg = e + logger.error(e) + + def get_error_msg(self) -> str: + return str(self.error_msg) + + def connect_and_execute_simple_query(self): + try: + logger.info("Trying to connect to Snowflake") + with snowflake.connector.connect(**self.configuration) as con: + result = con.cursor().execute("select 1;") + logger.debug(result.fetchall()) + logger.info("Successfully connected to Snowflake") + return True + except Exception as e: + self.error_msg = e + logger.error(e) + return False + + def _provide_credentials(self, scenario: Scenario, login: str, password: str): + try: + webbrowser.register("xdg-open", None, webbrowser.GenericBrowser("xdg-open")) + provide_browser_credentials_path = ( + "/externalbrowser/provideBrowserCredentials.js" + ) + process = subprocess.run( + [ + "node", + provide_browser_credentials_path, + scenario.value, + login, + password, + ], + timeout=15, + ) + logger.debug(f"OUTPUT: {process.stdout}, ERRORS: {process.stderr}") + except Exception as e: + self.error_msg = e + raise RuntimeError(e) + + def connect_using_okta_connection_and_execute_custom_command( + self, command: str, return_token: bool = False + ) -> Union[bool, str]: + try: + logger.info("Setup PAT") + with snowflake.connector.connect(**self.configuration) as con: + result = con.cursor().execute(command) + token = result.fetchall()[0][1] + except Exception as e: + self.error_msg = e + logger.error(e) + return False + if return_token: + return token + return False diff --git a/test/auth/test_external_browser.py b/test/auth/test_external_browser.py new file mode 100644 index 0000000000..0658bb2c7c --- /dev/null +++ b/test/auth/test_external_browser.py @@ -0,0 +1,90 @@ +import logging +from test.auth.authorization_parameters import ( + AuthConnectionParameters, + get_okta_login_credentials, +) + +import pytest +from authorization_test_helper import ( + AuthorizationTestHelper, + Scenario, + clean_browser_processes, +) + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + logging.info("Cleanup before test") + clean_browser_processes() + + yield + + logging.info("Teardown: Performing specific actions after the test") + clean_browser_processes() + + +@pytest.mark.auth +def test_external_browser_successful(): + connection_parameters = ( + AuthConnectionParameters().get_external_browser_connection_parameters() + ) + test_helper = AuthorizationTestHelper(connection_parameters) + browser_login, browser_password = get_okta_login_credentials().values() + test_helper.connect_and_provide_credentials( + Scenario.SUCCESS, browser_login, browser_password + ) + assert test_helper.error_msg == "", "Error message should be empty" + + +@pytest.mark.auth +def test_external_browser_mismatched_user(): + connection_parameters = ( + AuthConnectionParameters().get_external_browser_connection_parameters() + ) + connection_parameters["user"] = "differentUsername" + browser_login, browser_password = get_okta_login_credentials().values() + + test_helper = AuthorizationTestHelper(connection_parameters) + test_helper.connect_and_provide_credentials( + Scenario.SUCCESS, browser_login, browser_password + ) + assert ( + "The user you were trying to authenticate as differs from the user" + in test_helper.get_error_msg() + ) + + +@pytest.mark.auth +@pytest.mark.skip(reason="SNOW-2007651 Adding custom browser timeout") +def test_external_browser_wrong_credentials(): + connection_parameters = ( + AuthConnectionParameters().get_external_browser_connection_parameters() + ) + browser_login, browser_password = "invalidUser", "invalidPassword" + connection_parameters["external_browser_timeout"] = 10 + test_helper = AuthorizationTestHelper(connection_parameters) + test_helper.connect_and_provide_credentials( + Scenario.FAIL, browser_login, browser_password + ) + + assert ( + "Unable to receive the OAuth message within a given timeout" + in test_helper.get_error_msg() + ) + + +@pytest.mark.auth +@pytest.mark.skip(reason="SNOW-2007651 Adding custom browser timeout") +def test_external_browser_timeout(): + connection_parameters = ( + AuthConnectionParameters().get_external_browser_connection_parameters() + ) + test_helper = AuthorizationTestHelper(connection_parameters) + connection_parameters["external_browser_timeout"] = 1 + assert ( + not test_helper.connect_and_execute_simple_query() + ), "Connection should not be established" + assert ( + "Unable to receive the OAuth message within a given timeout" + in test_helper.get_error_msg() + ) diff --git a/test/auth/test_key_pair.py b/test/auth/test_key_pair.py new file mode 100644 index 0000000000..21b46c5738 --- /dev/null +++ b/test/auth/test_key_pair.py @@ -0,0 +1,39 @@ +from test.auth.authorization_parameters import ( + AuthConnectionParameters, + get_rsa_private_key_for_key_pair, +) +from test.auth.authorization_test_helper import AuthorizationTestHelper + +import pytest + + +@pytest.mark.auth +def test_key_pair_successful(): + connection_parameters = ( + AuthConnectionParameters().get_key_pair_connection_parameters() + ) + connection_parameters["private_key"] = get_rsa_private_key_for_key_pair( + "SNOWFLAKE_AUTH_TEST_PRIVATE_KEY_PATH" + ) + + test_helper = AuthorizationTestHelper(connection_parameters) + assert ( + test_helper.connect_and_execute_simple_query() + ), "Failed to connect with Snowflake" + assert test_helper.error_msg == "", "Error message should be empty" + + +@pytest.mark.auth +def test_key_pair_invalid_key(): + connection_parameters = ( + AuthConnectionParameters().get_key_pair_connection_parameters() + ) + connection_parameters["private_key"] = get_rsa_private_key_for_key_pair( + "SNOWFLAKE_AUTH_TEST_INVALID_PRIVATE_KEY_PATH" + ) + + test_helper = AuthorizationTestHelper(connection_parameters) + assert ( + not test_helper.connect_and_execute_simple_query() + ), "Connection to Snowflake should not be established" + assert "JWT token is invalid" in test_helper.get_error_msg() diff --git a/test/auth/test_oauth.py b/test/auth/test_oauth.py new file mode 100644 index 0000000000..de977fc92d --- /dev/null +++ b/test/auth/test_oauth.py @@ -0,0 +1,59 @@ +from test.auth.authorization_parameters import ( + AuthConnectionParameters, + get_oauth_token_parameters, +) +from test.auth.authorization_test_helper import ( + AuthorizationTestHelper, + get_access_token_oauth, +) + +import pytest + + +@pytest.mark.auth +def test_oauth_successful(): + token = get_oauth_token() + connection_parameters = AuthConnectionParameters().get_oauth_connection_parameters( + token + ) + test_helper = AuthorizationTestHelper(connection_parameters) + assert ( + test_helper.connect_and_execute_simple_query() + ), "Failed to connect with OAuth token" + assert test_helper.error_msg == "", "Error message should be empty" + + +@pytest.mark.auth +def test_oauth_mismatched_user(): + token = get_oauth_token() + connection_parameters = AuthConnectionParameters().get_oauth_connection_parameters( + token + ) + connection_parameters["user"] = "differentUsername" + test_helper = AuthorizationTestHelper(connection_parameters) + assert ( + test_helper.connect_and_execute_simple_query() is False + ), "Connection should not be established" + assert ( + "The user you were trying to authenticate as differs from the user" + in test_helper.get_error_msg() + ) + + +@pytest.mark.auth +def test_oauth_invalid_token(): + token = "invalidToken" + connection_parameters = AuthConnectionParameters().get_oauth_connection_parameters( + token + ) + test_helper = AuthorizationTestHelper(connection_parameters) + assert ( + test_helper.connect_and_execute_simple_query() is False + ), "Connection should not be established" + assert "Invalid OAuth access token" in test_helper.get_error_msg() + + +def get_oauth_token(): + oauth_config = get_oauth_token_parameters() + token = get_access_token_oauth(oauth_config) + return token diff --git a/test/auth/test_okta.py b/test/auth/test_okta.py new file mode 100644 index 0000000000..adfffd31df --- /dev/null +++ b/test/auth/test_okta.py @@ -0,0 +1,58 @@ +from test.auth.authorization_parameters import AuthConnectionParameters +from test.auth.authorization_test_helper import AuthorizationTestHelper + +import pytest + + +@pytest.mark.auth +def test_okta_successful(): + connection_parameters = AuthConnectionParameters().get_okta_connection_parameters() + test_helper = AuthorizationTestHelper(connection_parameters) + + assert ( + test_helper.connect_and_execute_simple_query() + ), "Failed to connect with Snowflake" + assert test_helper.error_msg == "", "Error message should be empty" + + +@pytest.mark.auth +def test_okta_with_wrong_okta_username(): + connection_parameters = AuthConnectionParameters().get_okta_connection_parameters() + connection_parameters["user"] = "differentUsername" + + test_helper = AuthorizationTestHelper(connection_parameters) + assert ( + not test_helper.connect_and_execute_simple_query() + ), "Connection to Snowflake should not be established" + assert "Failed to get authentication by OKTA" in test_helper.get_error_msg() + + +@pytest.mark.auth +def test_okta_wrong_url(): + connection_parameters = AuthConnectionParameters().get_okta_connection_parameters() + + connection_parameters["authenticator"] = "https://invalid.okta.com/" + test_helper = AuthorizationTestHelper(connection_parameters) + assert ( + not test_helper.connect_and_execute_simple_query() + ), "Connection to Snowflake should not be established" + assert ( + "The specified authenticator is not accepted by your Snowflake account configuration" + in test_helper.get_error_msg() + ) + + +@pytest.mark.auth +@pytest.mark.skip(reason="SNOW-1852279 implement error handling for invalid URL") +def test_okta_wrong_url_2(): + connection_parameters = AuthConnectionParameters().get_okta_connection_parameters() + + connection_parameters["authenticator"] = "https://invalid.abc.com/" + test_helper = AuthorizationTestHelper(connection_parameters) + assert ( + not test_helper.connect_and_execute_simple_query() + ), "Connection to Snowflake should not be established" + assert ( + "The specified authenticator is not accepted by your Snowflake account configuration" + in test_helper.get_error_msg() + ) diff --git a/test/auth/test_okta_authorization_code.py b/test/auth/test_okta_authorization_code.py new file mode 100644 index 0000000000..db4f16dd34 --- /dev/null +++ b/test/auth/test_okta_authorization_code.py @@ -0,0 +1,96 @@ +import logging +from test.auth.authorization_parameters import ( + AuthConnectionParameters, + get_okta_login_credentials, +) + +import pytest +from authorization_test_helper import ( + AuthorizationTestHelper, + Scenario, + clean_browser_processes, +) + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + logging.info("Cleanup before test") + clean_browser_processes() + + yield + + logging.info("Teardown: Performing specific actions after the test") + clean_browser_processes() + + +@pytest.mark.auth +def test_okta_authorization_code_successful(): + connection_parameters = ( + AuthConnectionParameters().get_oauth_external_authorization_code_connection_parameters() + ) + test_helper = AuthorizationTestHelper(connection_parameters) + browser_login, browser_password = get_okta_login_credentials().values() + test_helper.connect_and_provide_credentials( + Scenario.SUCCESS, browser_login, browser_password + ) + + assert test_helper.error_msg == "", "Error message should be empty" + + +@pytest.mark.auth +def test_okta_authorization_code_mismatched_user(): + connection_parameters = ( + AuthConnectionParameters().get_oauth_external_authorization_code_connection_parameters() + ) + connection_parameters["user"] = "differentUsername" + browser_login, browser_password = get_okta_login_credentials().values() + + test_helper = AuthorizationTestHelper(connection_parameters) + test_helper.connect_and_provide_credentials( + Scenario.SUCCESS, browser_login, browser_password + ) + + assert ( + "The user you were trying to authenticate as differs from the user" + in test_helper.get_error_msg() + ) + + +@pytest.mark.auth +def test_okta_authorization_code_timeout(): + connection_parameters = ( + AuthConnectionParameters().get_oauth_external_authorization_code_connection_parameters() + ) + test_helper = AuthorizationTestHelper(connection_parameters) + connection_parameters["external_browser_timeout"] = 1 + + assert ( + test_helper.connect_and_execute_simple_query() is False + ), "Connection should not be established" + assert ( + "Unable to receive the OAuth message within a given timeout" + in test_helper.get_error_msg() + ) + + +@pytest.mark.auth +def test_okta_authorization_code_with_token_cache(): + connection_parameters = ( + AuthConnectionParameters().get_oauth_external_authorization_code_connection_parameters() + ) + connection_parameters["client_store_temporary_credential"] = True + connection_parameters["external_browser_timeout"] = 10 + + test_helper = AuthorizationTestHelper(connection_parameters) + browser_login, browser_password = get_okta_login_credentials().values() + + test_helper.connect_and_provide_credentials( + Scenario.SUCCESS, browser_login, browser_password + ) + + clean_browser_processes() + + assert ( + test_helper.connect_and_execute_simple_query() is True + ), "Connection should be established" + assert test_helper.error_msg == "", "Error message should be empty" diff --git a/test/auth/test_okta_client_credentials.py b/test/auth/test_okta_client_credentials.py new file mode 100644 index 0000000000..063e22d786 --- /dev/null +++ b/test/auth/test_okta_client_credentials.py @@ -0,0 +1,57 @@ +import logging +from test.auth.authorization_parameters import AuthConnectionParameters + +import pytest +from authorization_test_helper import AuthorizationTestHelper, clean_browser_processes + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + logging.info("Cleanup before test") + clean_browser_processes() + + yield + + logging.info("Teardown: Performing specific actions after the test") + clean_browser_processes() + + +@pytest.mark.auth +def test_okta_client_credentials_successful(): + connection_parameters = ( + AuthConnectionParameters().get_oauth_external_client_credential_connection_parameters() + ) + test_helper = AuthorizationTestHelper(connection_parameters) + + test_helper.connect_and_execute_simple_query() + + assert test_helper.error_msg == "", "Error message should be empty" + + +@pytest.mark.auth +def test_okta_client_credentials_mismatched_user(): + connection_parameters = ( + AuthConnectionParameters().get_oauth_external_client_credential_connection_parameters() + ) + connection_parameters["user"] = "differentUsername" + test_helper = AuthorizationTestHelper(connection_parameters) + + test_helper.connect_and_execute_simple_query() + + assert ( + "The user you were trying to authenticate as differs from the user" + in test_helper.get_error_msg() + ) + + +@pytest.mark.auth +def test_okta_client_credentials_unauthorized(): + connection_parameters = ( + AuthConnectionParameters().get_oauth_external_client_credential_connection_parameters() + ) + connection_parameters["oauth_client_id"] = "invalidClientID" + test_helper = AuthorizationTestHelper(connection_parameters) + + test_helper.connect_and_execute_simple_query() + + assert "Invalid HTTP request from web browser" in test_helper.get_error_msg() diff --git a/test/auth/test_pat.py b/test/auth/test_pat.py new file mode 100644 index 0000000000..5db79967f2 --- /dev/null +++ b/test/auth/test_pat.py @@ -0,0 +1,82 @@ +from datetime import datetime +from test.auth.authorization_parameters import ( + AuthConnectionParameters, + get_pat_setup_command_variables, +) +from typing import Union + +import pytest +from authorization_test_helper import AuthorizationTestHelper + + +@pytest.mark.auth +def test_authenticate_with_pat_successful() -> None: + pat_command_variables = get_pat_setup_command_variables() + connection_parameters = AuthConnectionParameters().get_pat_connection_parameters() + test_helper = AuthorizationTestHelper(connection_parameters) + try: + pat_command_variables = get_pat_token(pat_command_variables) + connection_parameters["token"] = pat_command_variables["token"] + test_helper.connect_and_execute_simple_query() + finally: + remove_pat_token(pat_command_variables) + assert test_helper.get_error_msg() == "", "Error message should be empty" + + +@pytest.mark.auth +def test_authenticate_with_pat_mismatched_user() -> None: + pat_command_variables = get_pat_setup_command_variables() + connection_parameters = AuthConnectionParameters().get_pat_connection_parameters() + connection_parameters["user"] = "differentUsername" + test_helper = AuthorizationTestHelper(connection_parameters) + try: + pat_command_variables = get_pat_token(pat_command_variables) + connection_parameters["token"] = pat_command_variables["token"] + test_helper.connect_and_execute_simple_query() + finally: + remove_pat_token(pat_command_variables) + + assert "Programmatic access token is invalid" in test_helper.get_error_msg() + + +@pytest.mark.auth +def test_authenticate_with_pat_invalid_token() -> None: + connection_parameters = AuthConnectionParameters().get_pat_connection_parameters() + connection_parameters["token"] = "invalidToken" + test_helper = AuthorizationTestHelper(connection_parameters) + test_helper.connect_and_execute_simple_query() + assert "Programmatic access token is invalid" in test_helper.get_error_msg() + + +def get_pat_token(pat_command_variables) -> dict[str, Union[str, bool]]: + okta_connection_parameters = ( + AuthConnectionParameters().get_okta_connection_parameters() + ) + + pat_name = "PAT_PYTHON_" + generate_random_suffix() + pat_command_variables["pat_name"] = pat_name + command = ( + f"alter user {pat_command_variables['snowflake_user']} add programmatic access token {pat_name} " + f"ROLE_RESTRICTION = '{pat_command_variables['role']}' DAYS_TO_EXPIRY=1;" + ) + test_helper = AuthorizationTestHelper(okta_connection_parameters) + pat_command_variables["token"] = ( + test_helper.connect_using_okta_connection_and_execute_custom_command( + command, True + ) + ) + return pat_command_variables + + +def remove_pat_token(pat_command_variables: dict[str, Union[str, bool]]) -> None: + okta_connection_parameters = ( + AuthConnectionParameters().get_okta_connection_parameters() + ) + + command = f"alter user {pat_command_variables['snowflake_user']} remove programmatic access token {pat_command_variables['pat_name']};" + test_helper = AuthorizationTestHelper(okta_connection_parameters) + test_helper.connect_using_okta_connection_and_execute_custom_command(command) + + +def generate_random_suffix() -> str: + return datetime.now().strftime("%Y%m%d%H%M%S%f") diff --git a/test/auth/test_snowflake_authorization_code.py b/test/auth/test_snowflake_authorization_code.py new file mode 100644 index 0000000000..9116c9008e --- /dev/null +++ b/test/auth/test_snowflake_authorization_code.py @@ -0,0 +1,122 @@ +import logging +from test.auth.authorization_parameters import ( + AuthConnectionParameters, + get_soteria_okta_login_credentials, +) + +import pytest +from authorization_test_helper import ( + AuthorizationTestHelper, + Scenario, + clean_browser_processes, +) + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + logging.info("Cleanup before test") + clean_browser_processes() + + yield + + logging.info("Teardown: Performing specific actions after the test") + clean_browser_processes() + + +@pytest.mark.auth +def test_snowflake_authorization_code_successful(): + connection_parameters = ( + AuthConnectionParameters().get_snowflake_authorization_code_connection_parameters() + ) + test_helper = AuthorizationTestHelper(connection_parameters) + browser_login, browser_password = get_soteria_okta_login_credentials().values() + + test_helper.connect_and_provide_credentials( + Scenario.INTERNAL_OAUTH_SNOWFLAKE_SUCCESS, browser_login, browser_password + ) + + assert test_helper.error_msg == "", "Error message should be empty" + + +@pytest.mark.auth +def test_snowflake_authorization_code_mismatched_user(): + connection_parameters = ( + AuthConnectionParameters().get_snowflake_authorization_code_connection_parameters() + ) + connection_parameters["user"] = "differentUsername" + browser_login, browser_password = get_soteria_okta_login_credentials().values() + test_helper = AuthorizationTestHelper(connection_parameters) + + test_helper.connect_and_provide_credentials( + Scenario.INTERNAL_OAUTH_SNOWFLAKE_SUCCESS, browser_login, browser_password + ) + + assert ( + "The user you were trying to authenticate as differs from the user" + in test_helper.get_error_msg() + ) + + +@pytest.mark.auth +def test_snowflake_authorization_code_timeout(): + connection_parameters = ( + AuthConnectionParameters().get_snowflake_authorization_code_connection_parameters() + ) + test_helper = AuthorizationTestHelper(connection_parameters) + connection_parameters["external_browser_timeout"] = 1 + + assert ( + test_helper.connect_and_execute_simple_query() is False + ), "Connection should not be established" + assert ( + "Unable to receive the OAuth message within a given timeout" + in test_helper.get_error_msg() + ) + + +@pytest.mark.auth +def test_snowflake_authorization_code_with_token_cache(): + connection_parameters = ( + AuthConnectionParameters().get_snowflake_authorization_code_connection_parameters() + ) + connection_parameters["external_browser_timeout"] = 15 + connection_parameters["client_store_temporary_credential"] = True + test_helper = AuthorizationTestHelper(connection_parameters) + browser_login, browser_password = get_soteria_okta_login_credentials().values() + + test_helper.connect_and_provide_credentials( + Scenario.INTERNAL_OAUTH_SNOWFLAKE_SUCCESS, browser_login, browser_password + ) + + clean_browser_processes() + + assert ( + test_helper.connect_and_execute_simple_query() is True + ), "Connection should be established" + assert test_helper.get_error_msg() == "", "Error message should be empty" + + +@pytest.mark.auth +def test_snowflake_authorization_code_without_token_cache(): + connection_parameters = ( + AuthConnectionParameters().get_snowflake_authorization_code_connection_parameters() + ) + connection_parameters["client_store_temporary_credential"] = False + connection_parameters["external_browser_timeout"] = 15 + test_helper = AuthorizationTestHelper(connection_parameters) + browser_login, browser_password = get_soteria_okta_login_credentials().values() + + test_helper.connect_and_provide_credentials( + Scenario.INTERNAL_OAUTH_SNOWFLAKE_SUCCESS, browser_login, browser_password + ) + + clean_browser_processes() + + assert ( + test_helper.connect_and_execute_simple_query() is False + ), "Connection should be established" + + assert ( + "Unable to receive the OAuth message within a given timeout" + in test_helper.get_error_msg() + ), "Error message should contain timeout" diff --git a/test/auth/test_snowflake_authorization_code_wildcards.py b/test/auth/test_snowflake_authorization_code_wildcards.py new file mode 100644 index 0000000000..f38db07bdf --- /dev/null +++ b/test/auth/test_snowflake_authorization_code_wildcards.py @@ -0,0 +1,121 @@ +import logging +from test.auth.authorization_parameters import ( + AuthConnectionParameters, + get_soteria_okta_login_credentials, +) + +import pytest +from authorization_test_helper import ( + AuthorizationTestHelper, + Scenario, + clean_browser_processes, +) + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + logging.info("Cleanup before test") + clean_browser_processes() + + yield + + logging.info("Teardown: Performing specific actions after the test") + clean_browser_processes() + + +@pytest.mark.auth +def test_snowflake_authorization_code_wildcards_successful(): + connection_parameters = ( + AuthConnectionParameters().get_snowflake_wildcard_external_authorization_code_connection_parameters() + ) + test_helper = AuthorizationTestHelper(connection_parameters) + browser_login, browser_password = get_soteria_okta_login_credentials().values() + + test_helper.connect_and_provide_credentials( + Scenario.INTERNAL_OAUTH_SNOWFLAKE_SUCCESS, browser_login, browser_password + ) + + assert test_helper.error_msg == "", "Error message should be empty" + + +@pytest.mark.auth +def test_snowflake_authorization_code_wildcards_mismatched_user(): + connection_parameters = ( + AuthConnectionParameters().get_snowflake_wildcard_external_authorization_code_connection_parameters() + ) + connection_parameters["user"] = "differentUsername" + browser_login, browser_password = get_soteria_okta_login_credentials().values() + test_helper = AuthorizationTestHelper(connection_parameters) + + test_helper.connect_and_provide_credentials( + Scenario.INTERNAL_OAUTH_SNOWFLAKE_SUCCESS, browser_login, browser_password + ) + + assert ( + "The user you were trying to authenticate as differs from the user" + in test_helper.get_error_msg() + ) + + +@pytest.mark.auth +def test_snowflake_authorization_code_wildcards_timeout(): + connection_parameters = ( + AuthConnectionParameters().get_snowflake_wildcard_external_authorization_code_connection_parameters() + ) + test_helper = AuthorizationTestHelper(connection_parameters) + connection_parameters["external_browser_timeout"] = 1 + + assert ( + test_helper.connect_and_execute_simple_query() is False + ), "Connection should not be established" + assert ( + "Unable to receive the OAuth message within a given timeout" + in test_helper.get_error_msg() + ) + + +@pytest.mark.auth +def test_snowflake_authorization_code_wildcards_with_token_cache(): + connection_parameters = ( + AuthConnectionParameters().get_snowflake_wildcard_external_authorization_code_connection_parameters() + ) + connection_parameters["external_browser_timeout"] = 15 + connection_parameters["client_store_temporary_credential"] = True + test_helper = AuthorizationTestHelper(connection_parameters) + browser_login, browser_password = get_soteria_okta_login_credentials().values() + + test_helper.connect_and_provide_credentials( + Scenario.INTERNAL_OAUTH_SNOWFLAKE_SUCCESS, browser_login, browser_password + ) + + clean_browser_processes() + + assert ( + test_helper.connect_and_execute_simple_query() is True + ), "Connection should be established" + assert test_helper.get_error_msg() == "", "Error message should be empty" + + +@pytest.mark.auth +def test_snowflake_authorization_code_wildcards_without_token_cache(): + connection_parameters = ( + AuthConnectionParameters().get_snowflake_wildcard_external_authorization_code_connection_parameters() + ) + connection_parameters["client_store_temporary_credential"] = False + connection_parameters["external_browser_timeout"] = 15 + test_helper = AuthorizationTestHelper(connection_parameters) + browser_login, browser_password = get_soteria_okta_login_credentials().values() + + test_helper.connect_and_provide_credentials( + Scenario.INTERNAL_OAUTH_SNOWFLAKE_SUCCESS, browser_login, browser_password + ) + + clean_browser_processes() + + assert ( + test_helper.connect_and_execute_simple_query() is False + ), "Connection should be established" + assert ( + "Unable to receive the OAuth message within a given timeout" + in test_helper.get_error_msg() + ), "Error message should contain timeout" diff --git a/test/conftest.py b/test/conftest.py index 9f0fcbc7c8..5cdc714216 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -144,3 +144,7 @@ def pytest_runtest_setup(item) -> None: pytest.skip("cannot run this test on public Snowflake deployment") elif INTERNAL_SKIP_TAGS.intersection(test_tags) and not running_on_public_ci(): pytest.skip("cannot run this test on private Snowflake deployment") + + if "auth" in test_tags: + if os.getenv("RUN_AUTH_TESTS") != "true": + pytest.skip("Skipping auth test in current environment") diff --git a/test/data/wiremock/mappings/auth/oauth/authorization_code/browser_timeout_authorization_error.json b/test/data/wiremock/mappings/auth/oauth/authorization_code/browser_timeout_authorization_error.json new file mode 100644 index 0000000000..b14718c2ba --- /dev/null +++ b/test/data/wiremock/mappings/auth/oauth/authorization_code/browser_timeout_authorization_error.json @@ -0,0 +1,15 @@ +{ + "mappings": [ + { + "scenarioName": "Browser Authorization timeout", + "request": { + "urlPathPattern": "/oauth/authorize.*", + "method": "GET" + }, + "response": { + "status": 200, + "fixedDelayMilliseconds": 5000 + } + } + ] +} diff --git a/test/data/wiremock/mappings/auth/oauth/authorization_code/external_idp_custom_urls.json b/test/data/wiremock/mappings/auth/oauth/authorization_code/external_idp_custom_urls.json new file mode 100644 index 0000000000..0cee97115f --- /dev/null +++ b/test/data/wiremock/mappings/auth/oauth/authorization_code/external_idp_custom_urls.json @@ -0,0 +1,77 @@ +{ + "mappings": [ + { + "scenarioName": "Custom urls OAuth authorization code flow", + "requiredScenarioState": "Started", + "newScenarioState": "Authorized", + "request": { + "urlPathPattern": "/authorization", + "method": "GET", + "queryParameters": { + "response_type": { + "equalTo": "code" + }, + "scope": { + "equalTo": "session:role:ANALYST" + }, + "code_challenge_method": { + "equalTo": "S256" + }, + "redirect_uri": { + "equalTo": "http://localhost:8009/snowflake/oauth-redirect" + }, + "code_challenge": { + "matches": ".*" + }, + "state": { + "matches": ".*" + }, + "client_id": { + "equalTo": "123" + } + } + }, + "response": { + "status": 302, + "headers": { + "Location": "http://localhost:8009/snowflake/oauth-redirect?code=123&state=abc123" + } + } + }, + { + "scenarioName": "Custom urls OAuth authorization code flow", + "requiredScenarioState": "Authorized", + "newScenarioState": "Acquired access token", + "request": { + "urlPathPattern": "/tokenrequest.*", + "method": "POST", + "headers": { + "Authorization": { + "contains": "Basic" + }, + "Content-Type": { + "contains": "application/x-www-form-urlencoded; charset=UTF-8" + } + }, + "bodyPatterns": [ + { + "contains": "grant_type=authorization_code&code=123&redirect_uri=http%3A%2F%2Flocalhost%3A8009%2Fsnowflake%2Foauth-redirect&code_verifier=" + } + ] + }, + "response": { + "status": 200, + "jsonBody": { + "access_token": "access-token-123", + "refresh_token": "123", + "token_type": "Bearer", + "username": "user", + "scope": "refresh_token session:role:ANALYST", + "expires_in": 600, + "refresh_token_expires_in": 86399, + "idpInitiated": false + } + } + } + ] +} diff --git a/test/data/wiremock/mappings/auth/oauth/authorization_code/invalid_scope_error.json b/test/data/wiremock/mappings/auth/oauth/authorization_code/invalid_scope_error.json new file mode 100644 index 0000000000..fc495213e1 --- /dev/null +++ b/test/data/wiremock/mappings/auth/oauth/authorization_code/invalid_scope_error.json @@ -0,0 +1,17 @@ +{ + "mappings": [ + { + "scenarioName": "Invalid scope authorization error", + "request": { + "urlPathPattern": "/oauth/authorize.*", + "method": "GET" + }, + "response": { + "status": 302, + "headers": { + "Location": "http://localhost:8009/snowflake/oauth-redirect?error=invalid_scope&error_description=One+or+more+scopes+are+not+configured+for+the+authorization+server+resource." + } + } + } + ] +} diff --git a/test/data/wiremock/mappings/auth/oauth/authorization_code/invalid_state_error.json b/test/data/wiremock/mappings/auth/oauth/authorization_code/invalid_state_error.json new file mode 100644 index 0000000000..23799a655c --- /dev/null +++ b/test/data/wiremock/mappings/auth/oauth/authorization_code/invalid_state_error.json @@ -0,0 +1,17 @@ +{ + "mappings": [ + { + "scenarioName": "Invalid scope authorization error", + "request": { + "urlPathPattern": "/oauth/authorize.*", + "method": "GET" + }, + "response": { + "status": 302, + "headers": { + "Location": "http://localhost:8009/snowflake/oauth-redirect?code=123&state=invalidstate" + } + } + } + ] +} diff --git a/test/data/wiremock/mappings/auth/oauth/authorization_code/new_tokens_after_failed_refresh.json b/test/data/wiremock/mappings/auth/oauth/authorization_code/new_tokens_after_failed_refresh.json new file mode 100644 index 0000000000..e6cfb44085 --- /dev/null +++ b/test/data/wiremock/mappings/auth/oauth/authorization_code/new_tokens_after_failed_refresh.json @@ -0,0 +1,34 @@ +{ + "requiredScenarioState": "Authorized", + "newScenarioState": "Acquired access token", + "request": { + "urlPathPattern": "/oauth/token-request.*", + "method": "POST", + "headers": { + "Authorization": { + "contains": "Basic" + }, + "Content-Type": { + "contains": "application/x-www-form-urlencoded; charset=UTF-8" + } + }, + "bodyPatterns": [ + { + "matches": "^grant_type=authorization_code&code=123&redirect_uri=http%3A%2F%2Flocalhost%3A([0-9]+)%2Fsnowflake%2Foauth-redirect&code_verifier=abc123$" + } + ] + }, + "response": { + "status": 200, + "jsonBody": { + "access_token": "access-token-123", + "refresh_token": "refresh-token-123", + "token_type": "Bearer", + "username": "user", + "scope": "refresh_token session:role:ANALYST", + "expires_in": 600, + "refresh_token_expires_in": 86399, + "idpInitiated": false + } + } +} diff --git a/test/data/wiremock/mappings/auth/oauth/authorization_code/successful_auth_after_failed_refresh.json b/test/data/wiremock/mappings/auth/oauth/authorization_code/successful_auth_after_failed_refresh.json new file mode 100644 index 0000000000..f61d618011 --- /dev/null +++ b/test/data/wiremock/mappings/auth/oauth/authorization_code/successful_auth_after_failed_refresh.json @@ -0,0 +1,37 @@ +{ + "requiredScenarioState": "Failed refresh token attempt", + "newScenarioState": "Authorized", + "request": { + "urlPathPattern": "/oauth/authorize", + "queryParameters": { + "response_type": { + "equalTo": "code" + }, + "scope": { + "equalTo": "session:role:ANALYST offline_access" + }, + "code_challenge_method": { + "equalTo": "S256" + }, + "redirect_uri": { + "equalTo": "http://localhost:8009/snowflake/oauth-redirect" + }, + "code_challenge": { + "matches": ".*" + }, + "state": { + "matches": ".*" + }, + "client_id": { + "equalTo": "123" + } + }, + "method": "GET" + }, + "response": { + "status": 302, + "headers": { + "Location": "http://localhost:8009/snowflake/oauth-redirect?code=123&state=abc123" + } + } +} diff --git a/test/data/wiremock/mappings/auth/oauth/authorization_code/successful_flow.json b/test/data/wiremock/mappings/auth/oauth/authorization_code/successful_flow.json new file mode 100644 index 0000000000..6bb82d855f --- /dev/null +++ b/test/data/wiremock/mappings/auth/oauth/authorization_code/successful_flow.json @@ -0,0 +1,77 @@ +{ + "mappings": [ + { + "scenarioName": "Successful OAuth authorization code flow", + "requiredScenarioState": "Started", + "newScenarioState": "Authorized", + "request": { + "urlPathPattern": "/oauth/authorize", + "queryParameters": { + "response_type": { + "equalTo": "code" + }, + "scope": { + "equalTo": "session:role:ANALYST" + }, + "code_challenge_method": { + "equalTo": "S256" + }, + "redirect_uri": { + "equalTo": "http://localhost:8009/snowflake/oauth-redirect" + }, + "code_challenge": { + "matches": ".*" + }, + "state": { + "matches": ".*" + }, + "client_id": { + "equalTo": "123" + } + }, + "method": "GET" + }, + "response": { + "status": 302, + "headers": { + "Location": "http://localhost:8009/snowflake/oauth-redirect?code=123&state=abc123" + } + } + }, + { + "scenarioName": "Successful OAuth authorization code flow", + "requiredScenarioState": "Authorized", + "newScenarioState": "Acquired access token", + "request": { + "urlPathPattern": "/oauth/token-request.*", + "method": "POST", + "headers": { + "Authorization": { + "contains": "Basic" + }, + "Content-Type": { + "contains": "application/x-www-form-urlencoded; charset=UTF-8" + } + }, + "bodyPatterns": [ + { + "matches": "^grant_type=authorization_code&code=123&redirect_uri=http%3A%2F%2Flocalhost%3A([0-9]+)%2Fsnowflake%2Foauth-redirect&code_verifier=abc123$" + } + ] + }, + "response": { + "status": 200, + "jsonBody": { + "access_token": "access-token-123", + "refresh_token": "123", + "token_type": "Bearer", + "username": "user", + "scope": "refresh_token session:role:ANALYST", + "expires_in": 600, + "refresh_token_expires_in": 86399, + "idpInitiated": false + } + } + } + ] +} diff --git a/test/data/wiremock/mappings/auth/oauth/authorization_code/token_request_error.json b/test/data/wiremock/mappings/auth/oauth/authorization_code/token_request_error.json new file mode 100644 index 0000000000..ca925266be --- /dev/null +++ b/test/data/wiremock/mappings/auth/oauth/authorization_code/token_request_error.json @@ -0,0 +1,67 @@ +{ + "mappings": [ + { + "scenarioName": "OAuth token request error", + "requiredScenarioState": "Started", + "newScenarioState": "Authorized", + "request": { + "urlPathPattern": "/oauth/authorize", + "queryParameters": { + "response_type": { + "equalTo": "code" + }, + "scope": { + "equalTo": "session:role:ANALYST" + }, + "code_challenge_method": { + "equalTo": "S256" + }, + "redirect_uri": { + "equalTo": "http://localhost:8009/snowflake/oauth-redirect" + }, + "code_challenge": { + "matches": ".*" + }, + "state": { + "matches": ".*" + }, + "client_id": { + "equalTo": "123" + } + }, + "method": "GET" + }, + "response": { + "status": 302, + "headers": { + "Location": "http://localhost:8009/snowflake/oauth-redirect?code=123&state=abc123" + } + } + }, + { + "scenarioName": "OAuth token request error", + "requiredScenarioState": "Authorized", + "newScenarioState": "Token request error", + "request": { + "urlPathPattern": "/oauth/token-request.*", + "method": "POST", + "headers": { + "Authorization": { + "contains": "Basic" + }, + "Content-Type": { + "contains": "application/x-www-form-urlencoded; charset=UTF-8" + } + }, + "bodyPatterns": [ + { + "contains": "grant_type=authorization_code&code=123&redirect_uri=http%3A%2F%2Flocalhost%3A8009%2Fsnowflake%2Foauth-redirect&code_verifier=" + } + ] + }, + "response": { + "status": 400 + } + } + ] +} diff --git a/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_auth_after_failed_refresh.json b/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_auth_after_failed_refresh.json new file mode 100644 index 0000000000..f6f6a9d4a8 --- /dev/null +++ b/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_auth_after_failed_refresh.json @@ -0,0 +1,35 @@ +{ + "scenarioName": "Successful OAuth client credentials flow", + "requiredScenarioState": "Started", + "newScenarioState": "Acquired access token", + "request": { + "urlPathPattern": "/oauth/token-request.*", + "method": "POST", + "headers": { + "Authorization": { + "contains": "Basic" + }, + "Content-Type": { + "contains": "application/x-www-form-urlencoded; charset=UTF-8" + } + }, + "bodyPatterns": [ + { + "contains": "grant_type=client_credentials&scope=session%3Arole%3AANALYST" + } + ] + }, + "response": { + "status": 200, + "jsonBody": { + "access_token": "access-token-123", + "refresh_token": "refresh-token-123", + "token_type": "Bearer", + "username": "user", + "scope": "refresh_token session:role:ANALYST", + "expires_in": 600, + "refresh_token_expires_in": 86399, + "idpInitiated": false + } + } +} diff --git a/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_flow.json b/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_flow.json new file mode 100644 index 0000000000..10ed78c84c --- /dev/null +++ b/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_flow.json @@ -0,0 +1,39 @@ +{ + "mappings": [ + { + "scenarioName": "Successful OAuth client credentials flow", + "requiredScenarioState": "Started", + "newScenarioState": "Acquired access token", + "request": { + "urlPathPattern": "/oauth/token-request.*", + "method": "POST", + "headers": { + "Authorization": { + "contains": "Basic" + }, + "Content-Type": { + "contains": "application/x-www-form-urlencoded; charset=UTF-8" + } + }, + "bodyPatterns": [ + { + "contains": "grant_type=client_credentials&scope=session%3Arole%3AANALYST" + } + ] + }, + "response": { + "status": 200, + "jsonBody": { + "access_token": "access-token-123", + "refresh_token": "123", + "token_type": "Bearer", + "username": "user", + "scope": "refresh_token session:role:ANALYST", + "expires_in": 600, + "refresh_token_expires_in": 86399, + "idpInitiated": false + } + } + } + ] +} diff --git a/test/data/wiremock/mappings/auth/oauth/client_credentials/token_request_error.json b/test/data/wiremock/mappings/auth/oauth/client_credentials/token_request_error.json new file mode 100644 index 0000000000..b30b6056bf --- /dev/null +++ b/test/data/wiremock/mappings/auth/oauth/client_credentials/token_request_error.json @@ -0,0 +1,29 @@ +{ + "mappings": [ + { + "scenarioName": "OAuth client credentials flow with token request error", + "requiredScenarioState": "Started", + "newScenarioState": "Acquired access token", + "request": { + "urlPathPattern": "/oauth/token-request.*", + "method": "POST", + "headers": { + "Authorization": { + "contains": "Basic" + }, + "Content-Type": { + "contains": "application/x-www-form-urlencoded; charset=UTF-8" + } + }, + "bodyPatterns": [ + { + "contains": "grant_type=client_credentials&scope=session%3Arole%3AANALYST" + } + ] + }, + "response": { + "status": 400 + } + } + ] +} diff --git a/test/data/wiremock/mappings/auth/oauth/refresh_token/refresh_failed.json b/test/data/wiremock/mappings/auth/oauth/refresh_token/refresh_failed.json new file mode 100644 index 0000000000..5529590b4b --- /dev/null +++ b/test/data/wiremock/mappings/auth/oauth/refresh_token/refresh_failed.json @@ -0,0 +1,28 @@ +{ + "requiredScenarioState": "Expired access token", + "newScenarioState": "Failed refresh token attempt", + "request": { + "urlPathPattern": "/oauth/token-request.*", + "method": "POST", + "headers": { + "Authorization": { + "contains": "Basic" + }, + "Content-Type": { + "contains": "application/x-www-form-urlencoded; charset=UTF-8" + } + }, + "bodyPatterns": [ + { + "contains": "grant_type=refresh_token&refresh_token=expired-refresh-token-123&scope=session%3Arole%3AANALYST+offline_access" + } + ] + }, + "response": { + "status": 400, + "jsonBody": { + "error": "invalid_grant", + "error_description": "Unknown or invalid refresh token." + } + } +} diff --git a/test/data/wiremock/mappings/auth/oauth/refresh_token/refresh_successful.json b/test/data/wiremock/mappings/auth/oauth/refresh_token/refresh_successful.json new file mode 100644 index 0000000000..be816ed1b7 --- /dev/null +++ b/test/data/wiremock/mappings/auth/oauth/refresh_token/refresh_successful.json @@ -0,0 +1,30 @@ +{ + "requiredScenarioState": "Expired access token", + "newScenarioState": "Acquired access token", + "request": { + "urlPathPattern": "/oauth/token-request.*", + "method": "POST", + "headers": { + "Authorization": { + "contains": "Basic" + }, + "Content-Type": { + "contains": "application/x-www-form-urlencoded; charset=UTF-8" + } + }, + "bodyPatterns": [ + { + "contains": "grant_type=refresh_token&refresh_token=refresh-token-123&scope=session%3Arole%3AANALYST+offline_access" + } + ] + }, + "response": { + "status": 200, + "jsonBody": { + "access_token": "access-token-123", + "token_type": "Bearer", + "expires_in": 599, + "idpInitiated": false + } + } +} diff --git a/test/data/wiremock/mappings/generic/snowflake_login_failed.json b/test/data/wiremock/mappings/generic/snowflake_login_failed.json new file mode 100644 index 0000000000..a9afa16a51 --- /dev/null +++ b/test/data/wiremock/mappings/generic/snowflake_login_failed.json @@ -0,0 +1,48 @@ +{ + "mappings": [ + { + "scenarioName": "Refresh expired access token", + "requiredScenarioState": "Started", + "newScenarioState": "Expired access token", + "request": { + "urlPathPattern": "/session/v1/login-request", + "method": "POST", + "queryParameters": { + "request_id": { + "matches": ".*" + }, + "roleName": { + "equalTo": "ANALYST" + } + }, + "headers": { + "Content-Type": { + "contains": "application/json" + } + }, + "bodyPatterns": [ + { + "matchesJsonPath": "$.data" + }, + { + "matchesJsonPath": "$[?(@.data.TOKEN==\"expired-access-token-123\")]" + } + ] + }, + "response": { + "status": 200, + "jsonBody": { + "data": { + "nextAction": "RETRY_LOGIN", + "authnMethod": "OAUTH", + "signInOptions": {} + }, + "code": "390318", + "message": "OAuth access token expired. [1172527951366]", + "success": false, + "headers": null + } + } + } + ] +} diff --git a/test/data/wiremock/mappings/generic/snowflake_login_successful.json b/test/data/wiremock/mappings/generic/snowflake_login_successful.json new file mode 100644 index 0000000000..8e6297152c --- /dev/null +++ b/test/data/wiremock/mappings/generic/snowflake_login_successful.json @@ -0,0 +1,64 @@ +{ + "requiredScenarioState": "Acquired access token", + "newScenarioState": "Connected", + "request": { + "urlPathPattern": "/session/v1/login-request", + "method": "POST", + "queryParameters": { + "request_id": { + "matches": ".*" + }, + "roleName": { + "equalTo": "ANALYST" + } + }, + "headers": { + "Content-Type": { + "contains": "application/json" + } + }, + "bodyPatterns": [ + { + "matchesJsonPath": "$.data" + }, + { + "matchesJsonPath": "$[?(@.data.TOKEN==\"access-token-123\")]" + } + ] + }, + "response": { + "status": 200, + "fixedDelayMilliseconds": "1000", + "jsonBody": { + "data": { + "masterToken": "token-m1", + "token": "token-t1", + "validityInSeconds": 3599, + "masterValidityInSeconds": 14400, + "displayUserName": "***", + "serverVersion": "***", + "firstLogin": false, + "remMeToken": null, + "remMeValidityInSeconds": 0, + "healthCheckInterval": 45, + "newClientForUpgrade": null, + "sessionId": 1313, + "parameters": [], + "sessionInfo": { + "databaseName": null, + "schemaName": null, + "warehouseName": "TEST", + "roleName": "ACCOUNTADMIN" + }, + "idToken": null, + "idTokenValidityInSeconds": 0, + "responseData": null, + "mfaToken": null, + "mfaTokenValidityInSeconds": 0 + }, + "code": null, + "message": null, + "success": true + } + } +} diff --git a/test/unit/test_auth_callback_server.py b/test/unit/test_auth_callback_server.py new file mode 100644 index 0000000000..bf03a8d5f6 --- /dev/null +++ b/test/unit/test_auth_callback_server.py @@ -0,0 +1,63 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import socket +import time +from threading import Thread + +import pytest + +from snowflake.connector.auth._http_server import AuthHttpServer +from snowflake.connector.vendored import requests + + +@pytest.mark.parametrize( + "dontwait", + ["false", "true"], +) +@pytest.mark.parametrize("timeout", [None, 0.05]) +@pytest.mark.parametrize("reuse_port", ["true"]) +def test_auth_callback_success(monkeypatch, dontwait, timeout, reuse_port) -> None: + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port) + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait) + test_response: requests.Response | None = None + with AuthHttpServer("http://127.0.0.1/test_request") as callback_server: + + def request_callback(): + nonlocal test_response + if timeout: + time.sleep(timeout / 5) + test_response = requests.get( + f"http://{callback_server.hostname}:{callback_server.port}/test_request" + ) + + request_callback_thread = Thread(target=request_callback) + request_callback_thread.start() + block, client_socket = callback_server.receive_block(timeout=timeout) + test_callback_request = block[0] + response = ["HTTP/1.1 200 OK", "Content-Type: text/html", "", "test_response"] + client_socket.sendall("\r\n".join(response).encode("utf-8")) + client_socket.shutdown(socket.SHUT_RDWR) + client_socket.close() + request_callback_thread.join() + assert test_response.ok + assert test_response.text == "test_response" + assert test_callback_request == "GET /test_request HTTP/1.1" + + +@pytest.mark.parametrize( + "dontwait", + ["false", "true"], +) +@pytest.mark.parametrize("timeout", [0.05]) +@pytest.mark.parametrize("reuse_port", ["true"]) +def test_auth_callback_timeout(monkeypatch, dontwait, timeout, reuse_port) -> None: + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port) + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait) + with AuthHttpServer("http://127.0.0.1/test_request") as callback_server: + block, client_socket = callback_server.receive_block(timeout=timeout) + assert block is None + assert client_socket is None diff --git a/test/unit/test_auth_oauth_auth_code.py b/test/unit/test_auth_oauth_auth_code.py new file mode 100644 index 0000000000..6a01bb014f --- /dev/null +++ b/test/unit/test_auth_oauth_auth_code.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from snowflake.connector.auth import AuthByOauthCode + + +def test_auth_oauth_auth_code_oauth_type(): + """Simple OAuth Auth Code oauth type test.""" + auth = AuthByOauthCode( + "app", + "clientId", + "clientSecret", + "auth_url", + "tokenRequestUrl", + "redirectUri:{port}", + "scope", + ) + body = {"data": {}} + auth.update_body(body) + assert body["data"]["OAUTH_TYPE"] == "authorization_code" diff --git a/test/unit/test_connection.py b/test/unit/test_connection.py index 8e229b751f..a29babc2c4 100644 --- a/test/unit/test_connection.py +++ b/test/unit/test_connection.py @@ -637,7 +637,7 @@ def test_cannot_set_wlid_authenticator_without_env_variable(mock_post_requests): account="account", authenticator="WORKLOAD_IDENTITY" ) assert ( - "Please set the 'SF_ENABLE_EXPERIMENTAL_AUTHENTICATION' environment variable to use the 'WORKLOAD_IDENTITY' authenticator" + "Please set the 'SF_ENABLE_EXPERIMENTAL_AUTHENTICATION' environment variable true to use the 'WORKLOAD_IDENTITY' authenticator" in str(excinfo.value) ) @@ -647,7 +647,7 @@ def test_connection_params_are_plumbed_into_authbyworkloadidentity(monkeypatch): m.setattr( "snowflake.connector.SnowflakeConnection._authenticate", lambda *_: None ) - m.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "") # Can be set to anything. + m.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") conn = snowflake.connector.connect( account="my_account_1", @@ -689,7 +689,7 @@ def test_toml_connection_params_are_plumbed_into_authbyworkloadidentity( m.setattr( "snowflake.connector.SnowflakeConnection._authenticate", lambda *_: None ) - m.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "") + m.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") conn = snowflake.connector.connect(connections_file_path=connections_file) assert conn.auth_class.provider == AttestationProvider.OIDC diff --git a/test/unit/test_linux_local_file_cache.py b/test/unit/test_linux_local_file_cache.py index 51617f6094..2cf7c6348f 100644 --- a/test/unit/test_linux_local_file_cache.py +++ b/test/unit/test_linux_local_file_cache.py @@ -1,12 +1,15 @@ #!/usr/bin/env python from __future__ import annotations -import os +import time import pytest +from _pytest import pathlib from snowflake.connector.compat import IS_LINUX +pytestmark = pytest.mark.skipif(not IS_LINUX, reason="Testing on linux only") + try: from snowflake.connector.token_cache import FileTokenCache, TokenKey, TokenType @@ -23,13 +26,13 @@ CRED_1 = "cred_1" -@pytest.mark.skipif(not IS_LINUX, reason="The test is only for Linux platform") @pytest.mark.skipolddriver -def test_basic_store(tmpdir): - os.environ["SF_TEMPORARY_CREDENTIAL_CACHE_DIR"] = str(tmpdir) - - cache = FileTokenCache() - cache.delete_temporary_credential_file() +def test_basic_store(tmpdir, monkeypatch): + monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(tmpdir)) + cache = FileTokenCache.make() + assert cache + assert cache.cache_dir == pathlib.Path(tmpdir) + cache.cache_file().unlink(missing_ok=True) cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0) cache.store(TokenKey(HOST_1, USER_1, CRED_TYPE_1), CRED_1) @@ -39,13 +42,15 @@ def test_basic_store(tmpdir): assert cache.retrieve(TokenKey(HOST_1, USER_1, CRED_TYPE_1)) == CRED_1 assert cache.retrieve(TokenKey(HOST_0, USER_1, CRED_TYPE_1)) == CRED_1 - cache.delete_temporary_credential_file() + cache.cache_file().unlink(missing_ok=True) -def test_delete_specific_item(): - """The old behavior of delete cache is deleting the whole cache file. Now we change it to partially deletion.""" - cache = FileTokenCache() - cache.delete_temporary_credential_file() +@pytest.mark.skipif(not IS_LINUX, reason="The test is only for Linux platform") +def test_delete_specific_item(tmpdir, monkeypatch): + monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(tmpdir)) + cache = FileTokenCache.make() + assert cache + cache.cache_file().unlink(missing_ok=True) cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0) cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_1), CRED_1) @@ -55,4 +60,170 @@ def test_delete_specific_item(): cache.remove(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) assert not cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_1)) == CRED_1 - cache.delete_temporary_credential_file() + cache.cache_file().unlink(missing_ok=True) + + +def test_malformed_json_cache(tmpdir, monkeypatch): + monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(tmpdir)) + cache = FileTokenCache.make() + assert cache + cache.cache_file().unlink(missing_ok=True) + cache.cache_file().touch(0o600) + invalid_json = "[}" + cache.cache_file().write_text(invalid_json) + assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) is None + cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0) + assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) == CRED_0 + + +def test_malformed_utf_cache(tmpdir, monkeypatch): + monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(tmpdir)) + cache = FileTokenCache.make() + assert cache + cache.cache_file().unlink(missing_ok=True) + cache.cache_file().touch(0o600) + invalid_utf_sequence = bytes.fromhex("c0af") + cache.cache_file().write_bytes(invalid_utf_sequence) + assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) is None + cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0) + assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) == CRED_0 + + +def test_cache_dir_is_not_a_directory(tmpdir, monkeypatch): + file = pathlib.Path(str(tmpdir)) / "file" + file.touch() + monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(file)) + monkeypatch.delenv("XDG_CACHE_HOME", raising=False) + monkeypatch.delenv("HOME", raising=False) + cache_dir = FileTokenCache.find_cache_dir() + assert cache_dir is None + file.unlink() + + +def test_cache_dir_does_not_exist(tmpdir, monkeypatch): + directory = pathlib.Path(str(tmpdir)) / "dir" + directory.unlink(missing_ok=True) + monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(directory)) + monkeypatch.delenv("XDG_CACHE_HOME", raising=False) + monkeypatch.delenv("HOME", raising=False) + cache_dir = FileTokenCache.find_cache_dir() + assert cache_dir is None + + +def test_cache_dir_incorrect_permissions(tmpdir, monkeypatch): + directory = pathlib.Path(str(tmpdir)) / "dir" + directory.unlink(missing_ok=True) + directory.touch(0o777) + monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(directory)) + monkeypatch.delenv("XDG_CACHE_HOME", raising=False) + monkeypatch.delenv("HOME", raising=False) + cache_dir = FileTokenCache.find_cache_dir() + assert cache_dir is None + directory.unlink() + + +def test_cache_file_incorrect_permissions(tmpdir, monkeypatch): + monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(tmpdir)) + cache = FileTokenCache.make() + assert cache + cache.cache_file().unlink(missing_ok=True) + cache.cache_file().touch(0o777) + assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) is None + cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0) + assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) is None + assert len(cache.cache_file().read_text("utf-8")) == 0 + cache.cache_file().unlink() + + +def test_cache_dir_xdg_cache_home(tmpdir, monkeypatch): + monkeypatch.delenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", raising=False) + monkeypatch.setenv("XDG_CACHE_HOME", str(tmpdir)) + cache = FileTokenCache.make() + assert cache + cache.cache_file().unlink(missing_ok=True) + assert cache.cache_dir == pathlib.Path(str(tmpdir)) / "snowflake" + assert ( + cache.cache_file() + == pathlib.Path(str(tmpdir)) / "snowflake" / "credential_cache_v1.json" + ) + assert ( + cache.lock_file() + == pathlib.Path(str(tmpdir)) / "snowflake" / "credential_cache_v1.json.lck" + ) + cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0) + assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) == CRED_0 + cache.cache_file().unlink() + + +def test_cache_dir_home(tmpdir, monkeypatch): + monkeypatch.delenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", raising=False) + monkeypatch.delenv("XDG_CACHE_HOME", raising=False) + monkeypatch.setenv("HOME", str(tmpdir)) + cache = FileTokenCache.make() + assert cache + cache.cache_file().unlink(missing_ok=True) + assert cache.cache_dir == pathlib.Path(str(tmpdir)) / ".cache" / "snowflake" + assert ( + cache.cache_file() + == pathlib.Path(str(tmpdir)) + / ".cache" + / "snowflake" + / "credential_cache_v1.json" + ) + assert ( + cache.lock_file() + == pathlib.Path(str(tmpdir)) + / ".cache" + / "snowflake" + / "credential_cache_v1.json.lck" + ) + cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0) + assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) == CRED_0 + + +def test_file_lock(tmpdir, monkeypatch): + monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(tmpdir)) + cache = FileTokenCache.make() + assert cache + cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0) + assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) == CRED_0 + cache.lock_file().mkdir(0o700) + assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) is None + assert cache.lock_file().exists() + cache.lock_file().rmdir() + + +def test_file_lock_stale(tmpdir, monkeypatch): + monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(tmpdir)) + cache = FileTokenCache.make() + assert cache + cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0) + assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) == CRED_0 + cache.lock_file().mkdir(0o700) + time.sleep(1) + assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) == CRED_0 + assert not cache.lock_file().exists() + + +def test_file_missing_tokens_field(tmpdir, monkeypatch): + monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(tmpdir)) + cache = FileTokenCache.make() + assert cache + cache.cache_file().touch(0o600) + cache.cache_file().write_text("{}") + assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) is None + cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0) + assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) == CRED_0 + cache.cache_file().unlink() + + +def test_file_tokens_is_not_dict(tmpdir, monkeypatch): + monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(tmpdir)) + cache = FileTokenCache.make() + assert cache + cache.cache_file().touch(0o600) + cache.cache_file().write_text('{ "tokens": [] }') + assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) is None + cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0) + assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) == CRED_0 + cache.cache_file().unlink() diff --git a/test/unit/test_oauth_token.py b/test/unit/test_oauth_token.py new file mode 100644 index 0000000000..9152f39c8c --- /dev/null +++ b/test/unit/test_oauth_token.py @@ -0,0 +1,729 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import logging +import pathlib +from threading import Thread +from typing import Any, Generator, Union +from unittest import mock +from unittest.mock import Mock, patch + +import pytest +import requests + +import snowflake.connector +from snowflake.connector.auth import AuthByOauthCredentials +from snowflake.connector.token_cache import TokenCache, TokenKey, TokenType + +from ..wiremock.wiremock_utils import WiremockClient + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="session") +def wiremock_client() -> Generator[Union[WiremockClient, Any], Any, None]: + with WiremockClient() as client: + yield client + + +@pytest.fixture(scope="session") +def wiremock_oauth_authorization_code_dir() -> pathlib.Path: + return ( + pathlib.Path(__file__).parent.parent + / "data" + / "wiremock" + / "mappings" + / "auth" + / "oauth" + / "authorization_code" + ) + + +@pytest.fixture(scope="session") +def wiremock_oauth_client_creds_dir() -> pathlib.Path: + return ( + pathlib.Path(__file__).parent.parent + / "data" + / "wiremock" + / "mappings" + / "auth" + / "oauth" + / "client_credentials" + ) + + +@pytest.fixture(scope="session") +def wiremock_generic_mappings_dir() -> pathlib.Path: + return ( + pathlib.Path(__file__).parent.parent + / "data" + / "wiremock" + / "mappings" + / "generic" + ) + + +@pytest.fixture(scope="session") +def wiremock_oauth_refresh_token_dir() -> pathlib.Path: + return ( + pathlib.Path(__file__).parent.parent + / "data" + / "wiremock" + / "mappings" + / "auth" + / "oauth" + / "refresh_token" + ) + + +def _call_auth_server(url: str): + requests.get(url, allow_redirects=True, timeout=6) + + +def _webbrowser_redirect(*args): + assert len(args) == 1, "Invalid number of arguments passed to webbrowser open" + + thread = Thread(target=_call_auth_server, args=(args[0],)) + thread.start() + + return thread.is_alive() + + +@pytest.fixture(scope="session") +def webbrowser_mock() -> Mock: + webbrowser_mock = Mock() + webbrowser_mock.open = _webbrowser_redirect + return webbrowser_mock + + +@pytest.fixture() +def temp_cache(): + class TemporaryCache(TokenCache): + def __init__(self): + self._cache = {} + + def store(self, key: TokenKey, token: str) -> None: + self._cache[(key.user, key.host, key.tokenType)] = token + + def retrieve(self, key: TokenKey) -> str: + return self._cache.get((key.user, key.host, key.tokenType)) + + def remove(self, key: TokenKey) -> None: + self._cache.pop((key.user, key.host, key.tokenType)) + + tmp_cache = TemporaryCache() + with mock.patch( + "snowflake.connector.auth._auth.Auth.get_token_cache", return_value=tmp_cache + ): + yield tmp_cache + + +@pytest.mark.skipolddriver +@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30) +def test_oauth_code_successful_flow( + wiremock_client: WiremockClient, + wiremock_oauth_authorization_code_dir, + wiremock_generic_mappings_dir, + webbrowser_mock, + monkeypatch, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + + wiremock_client.import_mapping( + wiremock_oauth_authorization_code_dir / "successful_flow.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_login_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + + with mock.patch("webbrowser.open", new=webbrowser_mock.open): + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + cnx = snowflake.connector.connect( + user="testUser", + authenticator="OAUTH_AUTHORIZATION_CODE", + oauth_client_id="123", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_client_secret="testClientSecret", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize", + oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + ) + + assert cnx, "invalid cnx" + cnx.close() + + +@pytest.mark.skipolddriver +@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30) +def test_oauth_code_invalid_state( + wiremock_client: WiremockClient, + wiremock_oauth_authorization_code_dir, + webbrowser_mock, + monkeypatch, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + + wiremock_client.import_mapping( + wiremock_oauth_authorization_code_dir / "invalid_state_error.json" + ) + + with pytest.raises(snowflake.connector.DatabaseError) as execinfo: + with mock.patch("webbrowser.open", new=webbrowser_mock.open): + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + snowflake.connector.connect( + user="testUser", + authenticator="OAUTH_AUTHORIZATION_CODE", + oauth_client_id="123", + oauth_client_secret="testClientSecret", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize", + oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + ) + + assert str(execinfo.value).endswith("State changed during OAuth process.") + + +@pytest.mark.skipolddriver +@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30) +def test_oauth_code_scope_error( + wiremock_client: WiremockClient, + wiremock_oauth_authorization_code_dir, + webbrowser_mock, + monkeypatch, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + + wiremock_client.import_mapping( + wiremock_oauth_authorization_code_dir / "invalid_scope_error.json" + ) + + with pytest.raises(snowflake.connector.DatabaseError) as execinfo: + with mock.patch("webbrowser.open", new=webbrowser_mock.open): + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + snowflake.connector.connect( + user="testUser", + authenticator="OAUTH_AUTHORIZATION_CODE", + oauth_client_id="123", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize", + oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + ) + + assert str(execinfo.value).endswith( + "Oauth callback returned an invalid_scope error: One or more scopes are not configured for the authorization server resource." + ) + + +@pytest.mark.skipolddriver +@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30) +def test_oauth_code_token_request_error( + wiremock_oauth_authorization_code_dir, + webbrowser_mock, + monkeypatch, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + + with WiremockClient() as wiremock_client: + wiremock_client.import_mapping( + wiremock_oauth_authorization_code_dir / "token_request_error.json" + ) + + with pytest.raises(snowflake.connector.DatabaseError) as execinfo: + with mock.patch("webbrowser.open", new=webbrowser_mock.open): + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + snowflake.connector.connect( + user="testUser", + authenticator="OAUTH_AUTHORIZATION_CODE", + oauth_client_id="123", + oauth_client_secret="testClientSecret", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize", + oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + ) + + assert str(execinfo.value).endswith( + "Invalid HTTP request from web browser. Idp authentication could have failed." + ) + + +@pytest.mark.skipolddriver +def test_oauth_code_browser_timeout( + wiremock_client: WiremockClient, + wiremock_oauth_authorization_code_dir, + webbrowser_mock, + monkeypatch, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + + wiremock_client.import_mapping( + wiremock_oauth_authorization_code_dir + / "browser_timeout_authorization_error.json" + ) + + with pytest.raises(snowflake.connector.DatabaseError) as execinfo: + with mock.patch("webbrowser.open", new=webbrowser_mock.open): + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + snowflake.connector.connect( + user="testUser", + authenticator="OAUTH_AUTHORIZATION_CODE", + oauth_client_id="123", + oauth_client_secret="testClientSecret", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize", + oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + external_browser_timeout=2, + ) + + assert str(execinfo.value).endswith( + "Unable to receive the OAuth message within a given timeout. Please check the redirect URI and try again." + ) + + +@pytest.mark.skipolddriver +@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30) +def test_oauth_code_custom_urls( + wiremock_client: WiremockClient, + wiremock_oauth_authorization_code_dir, + wiremock_generic_mappings_dir, + webbrowser_mock, + monkeypatch, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + + wiremock_client.import_mapping( + wiremock_oauth_authorization_code_dir / "external_idp_custom_urls.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_login_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + + with mock.patch("webbrowser.open", new=webbrowser_mock.open): + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + cnx = snowflake.connector.connect( + user="testUser", + authenticator="OAUTH_AUTHORIZATION_CODE", + oauth_client_id="123", + oauth_client_secret="testClientSecret", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/tokenrequest", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/authorization", + oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + ) + + assert cnx, "invalid cnx" + cnx.close() + + +@pytest.mark.skipolddriver +@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30) +def test_oauth_code_successful_refresh_token_flow( + wiremock_client: WiremockClient, + wiremock_oauth_refresh_token_dir, + wiremock_generic_mappings_dir, + monkeypatch, + temp_cache, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + + wiremock_client.import_mapping( + wiremock_generic_mappings_dir / "snowflake_login_failed.json" + ) + wiremock_client.add_mapping( + wiremock_oauth_refresh_token_dir / "refresh_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_login_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + user = "testUser" + access_token_key = TokenKey( + user, wiremock_client.wiremock_host, TokenType.OAUTH_ACCESS_TOKEN + ) + refresh_token_key = TokenKey( + user, wiremock_client.wiremock_host, TokenType.OAUTH_REFRESH_TOKEN + ) + temp_cache.store(access_token_key, "expired-access-token-123") + temp_cache.store(refresh_token_key, "refresh-token-123") + cnx = snowflake.connector.connect( + user=user, + authenticator="OAUTH_AUTHORIZATION_CODE", + oauth_client_id="123", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_client_secret="testClientSecret", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize", + oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + oauth_security_features=("pkce", "refresh_token"), + client_store_temporary_credential=True, + ) + assert cnx, "invalid cnx" + cnx.close() + new_access_token = temp_cache.retrieve(access_token_key) + new_refresh_token = temp_cache.retrieve(refresh_token_key) + + assert new_access_token == "access-token-123" + assert new_refresh_token == "refresh-token-123" + + +@pytest.mark.skipolddriver +@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30) +def test_oauth_code_expired_refresh_token_flow( + wiremock_client: WiremockClient, + wiremock_oauth_refresh_token_dir, + wiremock_oauth_authorization_code_dir, + wiremock_generic_mappings_dir, + webbrowser_mock, + monkeypatch, + temp_cache, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + + wiremock_client.import_mapping( + wiremock_generic_mappings_dir / "snowflake_login_failed.json" + ) + wiremock_client.add_mapping( + wiremock_oauth_refresh_token_dir / "refresh_failed.json" + ) + wiremock_client.add_mapping( + wiremock_oauth_authorization_code_dir + / "successful_auth_after_failed_refresh.json" + ) + wiremock_client.add_mapping( + wiremock_oauth_authorization_code_dir / "new_tokens_after_failed_refresh.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_login_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + + user = "testUser" + access_token_key = TokenKey( + user, wiremock_client.wiremock_host, TokenType.OAUTH_ACCESS_TOKEN + ) + refresh_token_key = TokenKey( + user, wiremock_client.wiremock_host, TokenType.OAUTH_REFRESH_TOKEN + ) + temp_cache.store(access_token_key, "expired-access-token-123") + temp_cache.store(refresh_token_key, "expired-refresh-token-123") + with mock.patch("webbrowser.open", new=webbrowser_mock.open): + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + cnx = snowflake.connector.connect( + user=user, + authenticator="OAUTH_AUTHORIZATION_CODE", + oauth_client_id="123", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_client_secret="testClientSecret", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize", + oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + oauth_security_features=("pkce", "refresh_token"), + client_store_temporary_credential=True, + ) + assert cnx, "invalid cnx" + cnx.close() + + new_access_token = temp_cache.retrieve(access_token_key) + new_refresh_token = temp_cache.retrieve(refresh_token_key) + assert new_access_token == "access-token-123" + assert new_refresh_token == "refresh-token-123" + + +@pytest.mark.skipolddriver +def test_client_creds_oauth_type(): + """Simple OAuth Client credentials type test.""" + auth = AuthByOauthCredentials( + "app", + "clientId", + "clientSecret", + "auth_url", + "tokenRequestUrl", + "scope", + ) + body = {"data": {}} + auth.update_body(body) + assert body["data"]["OAUTH_TYPE"] == "client_credentials" + + +@pytest.mark.skipolddriver +def test_client_creds_successful_flow( + wiremock_client: WiremockClient, + wiremock_oauth_client_creds_dir, + wiremock_generic_mappings_dir, + monkeypatch, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + wiremock_client.import_mapping( + wiremock_oauth_client_creds_dir / "successful_flow.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_login_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + cnx = snowflake.connector.connect( + user="testUser", + authenticator="OAUTH_CLIENT_CREDENTIALS", + oauth_client_id="123", + oauth_client_secret="123", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + ) + + assert cnx, "invalid cnx" + cnx.close() + + +@pytest.mark.skipolddriver +def test_client_creds_token_request_error( + wiremock_client: WiremockClient, + wiremock_oauth_client_creds_dir, + wiremock_generic_mappings_dir, + monkeypatch, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + wiremock_client.import_mapping( + wiremock_oauth_client_creds_dir / "token_request_error.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_login_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + + with pytest.raises(snowflake.connector.DatabaseError) as execinfo: + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + snowflake.connector.connect( + user="testUser", + authenticator="OAUTH_CLIENT_CREDENTIALS", + oauth_client_id="123", + oauth_client_secret="123", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + ) + + assert str(execinfo.value).endswith( + "Invalid HTTP request from web browser. Idp authentication could have failed." + ) + + +@pytest.mark.skipolddriver +def test_client_creds_successful_refresh_token_flow( + wiremock_client: WiremockClient, + wiremock_oauth_refresh_token_dir, + wiremock_generic_mappings_dir, + monkeypatch, + temp_cache, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + + wiremock_client.import_mapping( + wiremock_generic_mappings_dir / "snowflake_login_failed.json" + ) + wiremock_client.add_mapping( + wiremock_oauth_refresh_token_dir / "refresh_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_login_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + user = "testUser" + access_token_key = TokenKey( + user, wiremock_client.wiremock_host, TokenType.OAUTH_ACCESS_TOKEN + ) + refresh_token_key = TokenKey( + user, wiremock_client.wiremock_host, TokenType.OAUTH_REFRESH_TOKEN + ) + temp_cache.store(access_token_key, "expired-access-token-123") + temp_cache.store(refresh_token_key, "refresh-token-123") + cnx = snowflake.connector.connect( + user=user, + authenticator="OAUTH_CLIENT_CREDENTIALS", + oauth_client_id="123", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_client_secret="testClientSecret", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + oauth_security_features=("refresh_token",), + client_store_temporary_credential=True, + ) + assert cnx, "invalid cnx" + cnx.close() + + new_access_token = temp_cache.retrieve(access_token_key) + new_refresh_token = temp_cache.retrieve(refresh_token_key) + assert new_access_token == "access-token-123" + assert new_refresh_token == "refresh-token-123" + + +@pytest.mark.skipolddriver +def test_client_creds_expired_refresh_token_flow( + wiremock_client: WiremockClient, + wiremock_oauth_refresh_token_dir, + wiremock_oauth_client_creds_dir, + wiremock_generic_mappings_dir, + webbrowser_mock, + monkeypatch, + temp_cache, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + + wiremock_client.import_mapping( + wiremock_generic_mappings_dir / "snowflake_login_failed.json" + ) + wiremock_client.add_mapping( + wiremock_oauth_refresh_token_dir / "refresh_failed.json" + ) + wiremock_client.add_mapping( + wiremock_oauth_client_creds_dir / "successful_auth_after_failed_refresh.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_login_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + + user = "testUser" + access_token_key = TokenKey( + user, wiremock_client.wiremock_host, TokenType.OAUTH_ACCESS_TOKEN + ) + refresh_token_key = TokenKey( + user, wiremock_client.wiremock_host, TokenType.OAUTH_REFRESH_TOKEN + ) + temp_cache.store(access_token_key, "expired-access-token-123") + temp_cache.store(refresh_token_key, "expired-refresh-token-123") + cnx = snowflake.connector.connect( + user=user, + authenticator="OAUTH_CLIENT_CREDENTIALS", + oauth_client_id="123", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_client_secret="testClientSecret", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + oauth_security_features=("refresh_token",), + client_store_temporary_credential=True, + ) + assert cnx, "invalid cnx" + cnx.close() + + new_access_token = temp_cache.retrieve(access_token_key) + new_refresh_token = temp_cache.retrieve(refresh_token_key) + assert new_access_token == "access-token-123" + assert new_refresh_token == "refresh-token-123" + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize( + "authenticator", ["OAUTH_AUTHORIZATION_CODE", "OAUTH_CLIENT_CREDENTIALS"] +) +def test_auth_is_experimental( + authenticator, + monkeypatch, +) -> None: + monkeypatch.delenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", False) + with pytest.raises( + snowflake.connector.ProgrammingError, + match=r"SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", + ): + snowflake.connector.connect( + user="testUser", + account="testAccount", + authenticator=authenticator, + ) + + +@pytest.mark.skipolddriver +@pytest.mark.skipolddriver +@pytest.mark.parametrize( + "authenticator", ["OAUTH_AUTHORIZATION_CODE", "OAUTH_CLIENT_CREDENTIALS"] +) +def test_auth_experimental_when_variable_set_to_false( + authenticator, + monkeypatch, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "false") + with pytest.raises( + snowflake.connector.ProgrammingError, + match=r"SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", + ): + snowflake.connector.connect( + user="testUser", + account="testAccount", + authenticator="OAUTH_CLIENT_CREDENTIALS", + ) diff --git a/test/unit/test_wiremock_client.py b/test/unit/test_wiremock_client.py index df4cacd2da..b471f39df7 100644 --- a/test/unit/test_wiremock_client.py +++ b/test/unit/test_wiremock_client.py @@ -12,6 +12,7 @@ from ..wiremock.wiremock_utils import WiremockClient +@pytest.mark.skipolddriver @pytest.fixture(scope="session") def wiremock_client() -> Generator[WiremockClient, Any, None]: with WiremockClient() as client: diff --git a/test/wiremock/wiremock_utils.py b/test/wiremock/wiremock_utils.py index 95b7374c1e..1d036a8023 100644 --- a/test/wiremock/wiremock_utils.py +++ b/test/wiremock/wiremock_utils.py @@ -31,11 +31,12 @@ def _get_mapping_str(mapping: Union[str, dict, pathlib.Path]) -> str: class WiremockClient: - def __init__(self): + def __init__(self, forbidden_ports: Optional[List[int]] = None) -> None: self.wiremock_filename = "wiremock-standalone.jar" self.wiremock_host = "localhost" self.wiremock_http_port = None self.wiremock_https_port = None + self.forbidden_ports = forbidden_ports if forbidden_ports is not None else [] self.wiremock_dir = pathlib.Path(__file__).parent.parent.parent / ".wiremock" assert self.wiremock_dir.exists(), f"{self.wiremock_dir} does not exist" @@ -46,9 +47,11 @@ def __init__(self): ), f"{self.wiremock_jar_path} does not exist" def _start_wiremock(self): - self.wiremock_http_port = self._find_free_port() + self.wiremock_http_port = self._find_free_port( + forbidden_ports=self.forbidden_ports, + ) self.wiremock_https_port = self._find_free_port( - forbidden_ports=[self.wiremock_http_port] + forbidden_ports=self.forbidden_ports + [self.wiremock_http_port] ) self.wiremock_process = subprocess.Popen( [ @@ -119,6 +122,10 @@ def _health_check(self): return True def _reset_wiremock(self): + clean_journal_endpoint = ( + f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/requests" + ) + requests.delete(clean_journal_endpoint) reset_endpoint = ( f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/reset" ) From c6be86c9e98acbd8632ee19fb5f16f49113c3e6f Mon Sep 17 00:00:00 2001 From: Patryk Czajka Date: Mon, 11 Aug 2025 16:17:32 +0200 Subject: [PATCH 3/5] Link sync implementation of Oauth to async code --- src/snowflake/connector/aio/_connection.py | 91 +++++++++++++++++++ src/snowflake/connector/aio/auth/__init__.py | 6 ++ .../connector/aio/auth/_oauth_code.py | 63 +++++++++++++ .../connector/aio/auth/_oauth_credentials.py | 57 ++++++++++++ 4 files changed, 217 insertions(+) create mode 100644 src/snowflake/connector/aio/auth/_oauth_code.py create mode 100644 src/snowflake/connector/aio/auth/_oauth_credentials.py diff --git a/src/snowflake/connector/aio/_connection.py b/src/snowflake/connector/aio/_connection.py index c7a2add13d..019239cc17 100644 --- a/src/snowflake/connector/aio/_connection.py +++ b/src/snowflake/connector/aio/_connection.py @@ -32,6 +32,7 @@ from ..connection import _get_private_bytes_from_file from ..constants import ( _CONNECTIVITY_ERR_MSG, + _OAUTH_DEFAULT_SCOPE, ENV_VAR_EXPERIMENTAL_AUTHENTICATION, ENV_VAR_PARTNER, PARAMETER_AUTOCOMMIT, @@ -51,15 +52,19 @@ from ..description import PLATFORM, PYTHON_VERSION, SNOWFLAKE_CONNECTOR_VERSION from ..errorcode import ( ER_CONNECTION_IS_CLOSED, + ER_EXPERIMENTAL_AUTHENTICATION_NOT_SUPPORTED, ER_FAILED_TO_CONNECT_TO_DB, ER_INVALID_VALUE, ER_INVALID_WIF_SETTINGS, + ER_NO_CLIENT_ID, ) from ..network import ( DEFAULT_AUTHENTICATOR, EXTERNAL_BROWSER_AUTHENTICATOR, KEY_PAIR_AUTHENTICATOR, OAUTH_AUTHENTICATOR, + OAUTH_AUTHORIZATION_CODE, + OAUTH_CLIENT_CREDENTIALS, PROGRAMMATIC_ACCESS_TOKEN, REQUEST_ID, USR_PWD_MFA_AUTHENTICATOR, @@ -84,6 +89,8 @@ AuthByIdToken, AuthByKeyPair, AuthByOAuth, + AuthByOauthCode, + AuthByOauthCredentials, AuthByOkta, AuthByPAT, AuthByPlugin, @@ -307,6 +314,56 @@ async def __open_connection(self): timeout=self.login_timeout, backoff_generator=self._backoff_generator, ) + elif self._authenticator == OAUTH_AUTHORIZATION_CODE: + self._check_experimental_authentication_flag() + self._check_oauth_required_parameters() + features = self.oauth_security_features + if self._role and (self._oauth_scope == ""): + # if role is known then let's inject it into scope + self._oauth_scope = _OAUTH_DEFAULT_SCOPE.format(role=self._role) + self.auth_class = AuthByOauthCode( + application=self.application, + client_id=self._oauth_client_id, + client_secret=self._oauth_client_secret, + authentication_url=self._oauth_authorization_url.format( + host=self.host, port=self.port + ), + token_request_url=self._oauth_token_request_url.format( + host=self.host, port=self.port + ), + redirect_uri=self._oauth_redirect_uri, + scope=self._oauth_scope, + pkce_enabled=features.pkce_enabled, + token_cache=( + auth.get_token_cache() + if self._client_store_temporary_credential + else None + ), + refresh_token_enabled=features.refresh_token_enabled, + external_browser_timeout=self._external_browser_timeout, + ) + elif self._authenticator == OAUTH_CLIENT_CREDENTIALS: + self._check_experimental_authentication_flag() + self._check_oauth_required_parameters() + features = self.oauth_security_features + if self._role and (self._oauth_scope == ""): + # if role is known then let's inject it into scope + self._oauth_scope = _OAUTH_DEFAULT_SCOPE.format(role=self._role) + self.auth_class = AuthByOauthCredentials( + application=self.application, + client_id=self._oauth_client_id, + client_secret=self._oauth_client_secret, + token_request_url=self._oauth_token_request_url.format( + host=self.host, port=self.port + ), + scope=self._oauth_scope, + token_cache=( + auth.get_token_cache() + if self._client_store_temporary_credential + else None + ), + refresh_token_enabled=features.refresh_token_enabled, + ) elif self._authenticator == PROGRAMMATIC_ACCESS_TOKEN: self.auth_class = AuthByPAT(self._token) elif self._authenticator == USR_PWD_MFA_AUTHENTICATOR: @@ -1052,3 +1109,37 @@ async def is_valid(self) -> bool: except Exception as e: logger.debug("session could not be validated due to exception: %s", e) return False + + def _check_experimental_authentication_flag(self) -> None: + if os.getenv(ENV_VAR_EXPERIMENTAL_AUTHENTICATION, "false").lower() != "true": + Error.errorhandler_wrapper( + self, + None, + ProgrammingError, + { + "msg": f"Please set the '{ENV_VAR_EXPERIMENTAL_AUTHENTICATION}' environment variable true to use the '{self._authenticator}' authenticator.", + "errno": ER_EXPERIMENTAL_AUTHENTICATION_NOT_SUPPORTED, + }, + ) + + def _check_oauth_required_parameters(self) -> None: + if self._oauth_client_id is None: + Error.errorhandler_wrapper( + self, + None, + ProgrammingError, + { + "msg": "Oauth code flow requirement 'client_id' is empty", + "errno": ER_NO_CLIENT_ID, + }, + ) + if self._oauth_client_secret is None: + Error.errorhandler_wrapper( + self, + None, + ProgrammingError, + { + "msg": "Oauth code flow requirement 'client_secret' is empty", + "errno": ER_NO_CLIENT_ID, + }, + ) diff --git a/src/snowflake/connector/aio/auth/__init__.py b/src/snowflake/connector/aio/auth/__init__.py index 4091bcf06b..3caf65c6a7 100644 --- a/src/snowflake/connector/aio/auth/__init__.py +++ b/src/snowflake/connector/aio/auth/__init__.py @@ -8,6 +8,8 @@ from ._keypair import AuthByKeyPair from ._no_auth import AuthNoAuth from ._oauth import AuthByOAuth +from ._oauth_code import AuthByOauthCode +from ._oauth_credentials import AuthByOauthCredentials from ._okta import AuthByOkta from ._pat import AuthByPAT from ._usrpwdmfa import AuthByUsrPwdMfa @@ -19,6 +21,8 @@ AuthByDefault, AuthByKeyPair, AuthByOAuth, + AuthByOauthCode, + AuthByOauthCredentials, AuthByOkta, AuthByUsrPwdMfa, AuthByWebBrowser, @@ -35,6 +39,8 @@ "AuthByKeyPair", "AuthByPAT", "AuthByOAuth", + "AuthByOauthCode", + "AuthByOauthCredentials", "AuthByOkta", "AuthByUsrPwdMfa", "AuthByWebBrowser", diff --git a/src/snowflake/connector/aio/auth/_oauth_code.py b/src/snowflake/connector/aio/auth/_oauth_code.py new file mode 100644 index 0000000000..16a21b2e80 --- /dev/null +++ b/src/snowflake/connector/aio/auth/_oauth_code.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python + +from __future__ import annotations + +import logging +from typing import Any + +from ...auth.oauth_code import AuthByOauthCode as AuthByOauthCodeSync +from ...token_cache import TokenCache +from ._by_plugin import AuthByPlugin as AuthByPluginAsync + +logger = logging.getLogger(__name__) + + +class AuthByOauthCode(AuthByPluginAsync, AuthByOauthCodeSync): + """Async version of OAuth authorization code authenticator.""" + + def __init__( + self, + application: str, + client_id: str, + client_secret: str, + authentication_url: str, + token_request_url: str, + redirect_uri: str, + scope: str, + pkce_enabled: bool = True, + token_cache: TokenCache | None = None, + refresh_token_enabled: bool = False, + external_browser_timeout: int | None = None, + **kwargs, + ) -> None: + """Initializes an instance with OAuth authorization code parameters.""" + logger.debug( + "OAuth authentication is not supported in async version - falling back to sync implementation" + ) + AuthByOauthCodeSync.__init__( + self, + application=application, + client_id=client_id, + client_secret=client_secret, + authentication_url=authentication_url, + token_request_url=token_request_url, + redirect_uri=redirect_uri, + scope=scope, + pkce_enabled=pkce_enabled, + token_cache=token_cache, + refresh_token_enabled=refresh_token_enabled, + external_browser_timeout=external_browser_timeout, + **kwargs, + ) + + async def reset_secrets(self) -> None: + AuthByOauthCodeSync.reset_secrets(self) + + async def prepare(self, **kwargs: Any) -> None: + AuthByOauthCodeSync.prepare(self, **kwargs) + + async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]: + return AuthByOauthCodeSync.reauthenticate(self, **kwargs) + + async def update_body(self, body: dict[Any, Any]) -> None: + AuthByOauthCodeSync.update_body(self, body) diff --git a/src/snowflake/connector/aio/auth/_oauth_credentials.py b/src/snowflake/connector/aio/auth/_oauth_credentials.py new file mode 100644 index 0000000000..1557e734a6 --- /dev/null +++ b/src/snowflake/connector/aio/auth/_oauth_credentials.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python + +from __future__ import annotations + +import logging +from typing import Any + +from ...auth.oauth_credentials import ( + AuthByOauthCredentials as AuthByOauthCredentialsSync, +) +from ...token_cache import TokenCache +from ._by_plugin import AuthByPlugin as AuthByPluginAsync + +logger = logging.getLogger(__name__) + + +class AuthByOauthCredentials(AuthByPluginAsync, AuthByOauthCredentialsSync): + """Async version of OAuth client credentials authenticator.""" + + def __init__( + self, + application: str, + client_id: str, + client_secret: str, + token_request_url: str, + scope: str, + token_cache: TokenCache | None = None, + refresh_token_enabled: bool = False, + **kwargs, + ) -> None: + """Initializes an instance with OAuth client credentials parameters.""" + logger.debug( + "OAuth authentication is not supported in async version - falling back to sync implementation" + ) + AuthByOauthCredentialsSync.__init__( + self, + application=application, + client_id=client_id, + client_secret=client_secret, + token_request_url=token_request_url, + scope=scope, + token_cache=token_cache, + refresh_token_enabled=refresh_token_enabled, + **kwargs, + ) + + async def reset_secrets(self) -> None: + AuthByOauthCredentialsSync.reset_secrets(self) + + async def prepare(self, **kwargs: Any) -> None: + AuthByOauthCredentialsSync.prepare(self, **kwargs) + + async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]: + return AuthByOauthCredentialsSync.reauthenticate(self, **kwargs) + + async def update_body(self, body: dict[Any, Any]) -> None: + AuthByOauthCredentialsSync.update_body(self, body) From bd380f9bbea0d0545c0de590572f4cf05420f952 Mon Sep 17 00:00:00 2001 From: Patryk Czajka Date: Mon, 11 Aug 2025 17:02:48 +0200 Subject: [PATCH 4/5] Add Content-type header to Wiremock scenarios --- .../authorization_code/external_idp_custom_urls.json | 3 +++ .../new_tokens_after_failed_refresh.json | 3 +++ .../auth/oauth/authorization_code/successful_flow.json | 3 +++ .../successful_auth_after_failed_refresh.json | 3 +++ .../auth/oauth/client_credentials/successful_flow.json | 3 +++ .../auth/oauth/refresh_token/refresh_successful.json | 3 +++ .../mappings/generic/snowflake_login_failed.json | 9 ++++++--- .../mappings/generic/snowflake_login_successful.json | 3 +++ 8 files changed, 27 insertions(+), 3 deletions(-) diff --git a/test/data/wiremock/mappings/auth/oauth/authorization_code/external_idp_custom_urls.json b/test/data/wiremock/mappings/auth/oauth/authorization_code/external_idp_custom_urls.json index 0cee97115f..327c779c70 100644 --- a/test/data/wiremock/mappings/auth/oauth/authorization_code/external_idp_custom_urls.json +++ b/test/data/wiremock/mappings/auth/oauth/authorization_code/external_idp_custom_urls.json @@ -61,6 +61,9 @@ }, "response": { "status": 200, + "headers": { + "Content-Type": "application/json" + }, "jsonBody": { "access_token": "access-token-123", "refresh_token": "123", diff --git a/test/data/wiremock/mappings/auth/oauth/authorization_code/new_tokens_after_failed_refresh.json b/test/data/wiremock/mappings/auth/oauth/authorization_code/new_tokens_after_failed_refresh.json index e6cfb44085..55d60fe066 100644 --- a/test/data/wiremock/mappings/auth/oauth/authorization_code/new_tokens_after_failed_refresh.json +++ b/test/data/wiremock/mappings/auth/oauth/authorization_code/new_tokens_after_failed_refresh.json @@ -20,6 +20,9 @@ }, "response": { "status": 200, + "headers": { + "Content-Type": "application/json" + }, "jsonBody": { "access_token": "access-token-123", "refresh_token": "refresh-token-123", diff --git a/test/data/wiremock/mappings/auth/oauth/authorization_code/successful_flow.json b/test/data/wiremock/mappings/auth/oauth/authorization_code/successful_flow.json index 6bb82d855f..5ca87b98c8 100644 --- a/test/data/wiremock/mappings/auth/oauth/authorization_code/successful_flow.json +++ b/test/data/wiremock/mappings/auth/oauth/authorization_code/successful_flow.json @@ -61,6 +61,9 @@ }, "response": { "status": 200, + "headers": { + "Content-Type": "application/json" + }, "jsonBody": { "access_token": "access-token-123", "refresh_token": "123", diff --git a/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_auth_after_failed_refresh.json b/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_auth_after_failed_refresh.json index f6f6a9d4a8..6b8e9699f5 100644 --- a/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_auth_after_failed_refresh.json +++ b/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_auth_after_failed_refresh.json @@ -21,6 +21,9 @@ }, "response": { "status": 200, + "headers": { + "Content-Type": "application/json" + }, "jsonBody": { "access_token": "access-token-123", "refresh_token": "refresh-token-123", diff --git a/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_flow.json b/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_flow.json index 10ed78c84c..5e6137bd0e 100644 --- a/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_flow.json +++ b/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_flow.json @@ -23,6 +23,9 @@ }, "response": { "status": 200, + "headers": { + "Content-Type": "application/json" + }, "jsonBody": { "access_token": "access-token-123", "refresh_token": "123", diff --git a/test/data/wiremock/mappings/auth/oauth/refresh_token/refresh_successful.json b/test/data/wiremock/mappings/auth/oauth/refresh_token/refresh_successful.json index be816ed1b7..6a1ec8cf56 100644 --- a/test/data/wiremock/mappings/auth/oauth/refresh_token/refresh_successful.json +++ b/test/data/wiremock/mappings/auth/oauth/refresh_token/refresh_successful.json @@ -20,6 +20,9 @@ }, "response": { "status": 200, + "headers": { + "Content-Type": "application/json" + }, "jsonBody": { "access_token": "access-token-123", "token_type": "Bearer", diff --git a/test/data/wiremock/mappings/generic/snowflake_login_failed.json b/test/data/wiremock/mappings/generic/snowflake_login_failed.json index a9afa16a51..bf848d16b3 100644 --- a/test/data/wiremock/mappings/generic/snowflake_login_failed.json +++ b/test/data/wiremock/mappings/generic/snowflake_login_failed.json @@ -7,10 +7,10 @@ "request": { "urlPathPattern": "/session/v1/login-request", "method": "POST", - "queryParameters": { - "request_id": { + "queryParameters": { + "request_id": { "matches": ".*" - }, + }, "roleName": { "equalTo": "ANALYST" } @@ -31,6 +31,9 @@ }, "response": { "status": 200, + "headers": { + "Content-Type": "application/json" + }, "jsonBody": { "data": { "nextAction": "RETRY_LOGIN", diff --git a/test/data/wiremock/mappings/generic/snowflake_login_successful.json b/test/data/wiremock/mappings/generic/snowflake_login_successful.json index 8e6297152c..940ffad2e6 100644 --- a/test/data/wiremock/mappings/generic/snowflake_login_successful.json +++ b/test/data/wiremock/mappings/generic/snowflake_login_successful.json @@ -29,6 +29,9 @@ "response": { "status": 200, "fixedDelayMilliseconds": "1000", + "headers": { + "Content-Type": "application/json" + }, "jsonBody": { "data": { "masterToken": "token-m1", From be0b5ac1219089f588aca4899b1518c0916f5d94 Mon Sep 17 00:00:00 2001 From: Patryk Czajka Date: Tue, 12 Aug 2025 10:26:00 +0200 Subject: [PATCH 5/5] Add tests; maybe fix? --- src/snowflake/connector/aio/_connection.py | 6 +- src/snowflake/connector/aio/auth/_auth.py | 10 + .../connector/aio/auth/_oauth_code.py | 47 +- .../connector/aio/auth/_oauth_credentials.py | 47 +- test/unit/aio/test_auth_oauth_code_async.py | 49 ++ .../aio/test_auth_oauth_credentials_async.py | 46 ++ test/unit/aio/test_oauth_token_async.py | 760 ++++++++++++++++++ 7 files changed, 960 insertions(+), 5 deletions(-) create mode 100644 test/unit/aio/test_auth_oauth_code_async.py create mode 100644 test/unit/aio/test_auth_oauth_credentials_async.py create mode 100644 test/unit/aio/test_oauth_token_async.py diff --git a/src/snowflake/connector/aio/_connection.py b/src/snowflake/connector/aio/_connection.py index 019239cc17..ce0ddd8220 100644 --- a/src/snowflake/connector/aio/_connection.py +++ b/src/snowflake/connector/aio/_connection.py @@ -804,7 +804,11 @@ async def authenticate_with_retry(self, auth_instance) -> None: # SSO if it has expired await self._reauthenticate() else: - await self._authenticate(auth_instance) + # TODO pczajka: check if this is correct + # For OAuth and other auth types, call their reauthenticate method + await auth_instance.reauthenticate(conn=self) + # The reauthenticate method will call authenticate_with_retry internally, + # so we don't need to call _authenticate again here async def autocommit(self, mode) -> None: """Sets autocommit mode to True, or False. Defaults to True.""" diff --git a/src/snowflake/connector/aio/auth/_auth.py b/src/snowflake/connector/aio/auth/_auth.py index 8dbb86f963..462e107ae1 100644 --- a/src/snowflake/connector/aio/auth/_auth.py +++ b/src/snowflake/connector/aio/auth/_auth.py @@ -30,6 +30,7 @@ ACCEPT_TYPE_APPLICATION_SNOWFLAKE, CONTENT_TYPE_APPLICATION_JSON, ID_TOKEN_INVALID_LOGIN_REQUEST_GS_CODE, + OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE, PYTHON_CONNECTOR_USER_AGENT, ReauthenticationRequest, ) @@ -282,6 +283,15 @@ async def post_request_wrapper(self, url, headers, body) -> None: sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, ) ) + elif errno == OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE: + raise ReauthenticationRequest( + ProgrammingError( + msg=ret["message"], + errno=int(errno), + sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, + ) + ) + from . import AuthByKeyPair if isinstance(auth_instance, AuthByKeyPair): diff --git a/src/snowflake/connector/aio/auth/_oauth_code.py b/src/snowflake/connector/aio/auth/_oauth_code.py index 16a21b2e80..a4b3f35ae7 100644 --- a/src/snowflake/connector/aio/auth/_oauth_code.py +++ b/src/snowflake/connector/aio/auth/_oauth_code.py @@ -3,12 +3,15 @@ from __future__ import annotations import logging -from typing import Any +from typing import TYPE_CHECKING, Any from ...auth.oauth_code import AuthByOauthCode as AuthByOauthCodeSync from ...token_cache import TokenCache from ._by_plugin import AuthByPlugin as AuthByPluginAsync +if TYPE_CHECKING: + from .. import SnowflakeConnection + logger = logging.getLogger(__name__) @@ -57,7 +60,47 @@ async def prepare(self, **kwargs: Any) -> None: AuthByOauthCodeSync.prepare(self, **kwargs) async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]: - return AuthByOauthCodeSync.reauthenticate(self, **kwargs) + """Override to use async connection properly.""" + # TODO pczajka: check if this is correct + + # Call the sync reset logic but handle the connection retry ourselves + self._reset_access_token() + if self._pop_cached_refresh_token(): + logger.debug( + "OAuth refresh token is available, try to use it and get a new access token" + ) + self._do_refresh_token(conn=kwargs.get("conn")) + # Use async authenticate_with_retry + if "conn" in kwargs: + await kwargs["conn"].authenticate_with_retry(self) + return {"success": True} async def update_body(self, body: dict[Any, Any]) -> None: AuthByOauthCodeSync.update_body(self, body) + + def _handle_failure( + self, + *, + conn: SnowflakeConnection, + ret: dict[Any, Any], + **kwargs: Any, + ) -> None: + """Override to ensure proper error handling in async context.""" + # Use sync error handling directly to avoid async/sync mismatch + from ...errors import DatabaseError, Error + from ...sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED + + Error.errorhandler_wrapper( + conn, + None, + DatabaseError, + { + "msg": "Failed to connect to DB: {host}:{port}, {message}".format( + host=conn._rest._host, + port=conn._rest._port, + message=ret["message"], + ), + "errno": int(ret.get("code", -1)), + "sqlstate": SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, + }, + ) diff --git a/src/snowflake/connector/aio/auth/_oauth_credentials.py b/src/snowflake/connector/aio/auth/_oauth_credentials.py index 1557e734a6..855296e372 100644 --- a/src/snowflake/connector/aio/auth/_oauth_credentials.py +++ b/src/snowflake/connector/aio/auth/_oauth_credentials.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import Any +from typing import TYPE_CHECKING, Any from ...auth.oauth_credentials import ( AuthByOauthCredentials as AuthByOauthCredentialsSync, @@ -11,6 +11,9 @@ from ...token_cache import TokenCache from ._by_plugin import AuthByPlugin as AuthByPluginAsync +if TYPE_CHECKING: + from .. import SnowflakeConnection + logger = logging.getLogger(__name__) @@ -51,7 +54,47 @@ async def prepare(self, **kwargs: Any) -> None: AuthByOauthCredentialsSync.prepare(self, **kwargs) async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]: - return AuthByOauthCredentialsSync.reauthenticate(self, **kwargs) + """Override to use async connection properly.""" + # TODO pczajka: check if this is correct + + # Call the sync reset logic but handle the connection retry ourselves + self._reset_access_token() + if self._pop_cached_refresh_token(): + logger.debug( + "OAuth refresh token is available, try to use it and get a new access token" + ) + self._do_refresh_token(conn=kwargs.get("conn")) + # Use async authenticate_with_retry + if "conn" in kwargs: + await kwargs["conn"].authenticate_with_retry(self) + return {"success": True} async def update_body(self, body: dict[Any, Any]) -> None: AuthByOauthCredentialsSync.update_body(self, body) + + def _handle_failure( + self, + *, + conn: SnowflakeConnection, + ret: dict[Any, Any], + **kwargs: Any, + ) -> None: + """Override to ensure proper error handling in async context.""" + # Use sync error handling directly to avoid async/sync mismatch + from ...errors import DatabaseError, Error + from ...sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED + + Error.errorhandler_wrapper( + conn, + None, + DatabaseError, + { + "msg": "Failed to connect to DB: {host}:{port}, {message}".format( + host=conn._rest._host, + port=conn._rest._port, + message=ret["message"], + ), + "errno": int(ret.get("code", -1)), + "sqlstate": SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, + }, + ) diff --git a/test/unit/aio/test_auth_oauth_code_async.py b/test/unit/aio/test_auth_oauth_code_async.py new file mode 100644 index 0000000000..646c2df7d3 --- /dev/null +++ b/test/unit/aio/test_auth_oauth_code_async.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import os + +from snowflake.connector.aio.auth import AuthByOauthCode + + +async def test_auth_oauth_code(): + """Simple OAuth Code test.""" + # Set experimental auth flag for the test + os.environ["SF_ENABLE_EXPERIMENTAL_AUTHENTICATION"] = "true" + + auth = AuthByOauthCode( + application="test_app", + client_id="test_client_id", + client_secret="test_client_secret", + authentication_url="https://example.com/auth", + token_request_url="https://example.com/token", + redirect_uri="http://localhost:8080/callback", + scope="session:role:test_role", + pkce_enabled=True, + refresh_token_enabled=False, + ) + + body = {"data": {}} + await auth.update_body(body) + + # Check that OAuth authenticator is set + assert body["data"]["AUTHENTICATOR"] == "OAUTH", body + # OAuth type should be set to authorization_code + assert body["data"]["OAUTH_TYPE"] == "authorization_code", body + + # Clean up environment variable + del os.environ["SF_ENABLE_EXPERIMENTAL_AUTHENTICATION"] + + +def test_mro(): + """Ensure that methods from AuthByPluginAsync override those from AuthByPlugin.""" + from snowflake.connector.aio.auth import AuthByPlugin as AuthByPluginAsync + from snowflake.connector.auth import AuthByPlugin as AuthByPluginSync + + assert AuthByOauthCode.mro().index(AuthByPluginAsync) < AuthByOauthCode.mro().index( + AuthByPluginSync + ) diff --git a/test/unit/aio/test_auth_oauth_credentials_async.py b/test/unit/aio/test_auth_oauth_credentials_async.py new file mode 100644 index 0000000000..297614bd48 --- /dev/null +++ b/test/unit/aio/test_auth_oauth_credentials_async.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import os + +from snowflake.connector.aio.auth import AuthByOauthCredentials + + +async def test_auth_oauth_credentials(): + """Simple OAuth Credentials test.""" + # Set experimental auth flag for the test + os.environ["SF_ENABLE_EXPERIMENTAL_AUTHENTICATION"] = "true" + + auth = AuthByOauthCredentials( + application="test_app", + client_id="test_client_id", + client_secret="test_client_secret", + token_request_url="https://example.com/token", + scope="session:role:test_role", + refresh_token_enabled=False, + ) + + body = {"data": {}} + await auth.update_body(body) + + # Check that OAuth authenticator is set + assert body["data"]["AUTHENTICATOR"] == "OAUTH", body + # OAuth type should be set to client_credentials + assert body["data"]["OAUTH_TYPE"] == "client_credentials", body + + # Clean up environment variable + del os.environ["SF_ENABLE_EXPERIMENTAL_AUTHENTICATION"] + + +def test_mro(): + """Ensure that methods from AuthByPluginAsync override those from AuthByPlugin.""" + from snowflake.connector.aio.auth import AuthByPlugin as AuthByPluginAsync + from snowflake.connector.auth import AuthByPlugin as AuthByPluginSync + + assert AuthByOauthCredentials.mro().index( + AuthByPluginAsync + ) < AuthByOauthCredentials.mro().index(AuthByPluginSync) diff --git a/test/unit/aio/test_oauth_token_async.py b/test/unit/aio/test_oauth_token_async.py new file mode 100644 index 0000000000..3d89af5186 --- /dev/null +++ b/test/unit/aio/test_oauth_token_async.py @@ -0,0 +1,760 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import logging +import pathlib +from typing import Any, Generator, Union +from unittest import mock +from unittest.mock import Mock, patch + +import pytest + +try: + from snowflake.connector.aio import SnowflakeConnection + from snowflake.connector.aio.auth import AuthByOauthCredentials +except ImportError: + pass + +import snowflake.connector.errors +from snowflake.connector.token_cache import TokenCache, TokenKey, TokenType + +from ...wiremock.wiremock_utils import WiremockClient + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="session") +def wiremock_client() -> Generator[Union[WiremockClient, Any], Any, None]: + with WiremockClient() as client: + yield client + + +@pytest.fixture(scope="session") +def wiremock_oauth_authorization_code_dir() -> pathlib.Path: + return ( + pathlib.Path(__file__).parent.parent.parent + / "data" + / "wiremock" + / "mappings" + / "auth" + / "oauth" + / "authorization_code" + ) + + +@pytest.fixture(scope="session") +def wiremock_oauth_client_creds_dir() -> pathlib.Path: + return ( + pathlib.Path(__file__).parent.parent.parent + / "data" + / "wiremock" + / "mappings" + / "auth" + / "oauth" + / "client_credentials" + ) + + +@pytest.fixture(scope="session") +def wiremock_generic_mappings_dir() -> pathlib.Path: + return ( + pathlib.Path(__file__).parent.parent.parent + / "data" + / "wiremock" + / "mappings" + / "generic" + ) + + +@pytest.fixture(scope="session") +def wiremock_oauth_refresh_token_dir() -> pathlib.Path: + return ( + pathlib.Path(__file__).parent.parent.parent + / "data" + / "wiremock" + / "mappings" + / "auth" + / "oauth" + / "refresh_token" + ) + + +def _call_auth_server_sync(url: str): + """Sync version of auth server call for OAuth redirect simulation. + + Since async classes call sync methods, we need to use sync requests. + """ + import requests + + # Use sync requests since the OAuth implementation uses sync urllib3 + requests.get(url, allow_redirects=True, timeout=6) + + +def _webbrowser_redirect_sync(*args): + """Sync version of webbrowser redirect simulation. + + Since async OAuth classes use sync webbrowser.open(), we need sync simulation. + """ + assert len(args) == 1, "Invalid number of arguments passed to webbrowser open" + + from threading import Thread + + # Use threading to avoid blocking since sync OAuth expects this pattern + thread = Thread(target=_call_auth_server_sync, args=(args[0],)) + thread.start() + + return thread.is_alive() + + +@pytest.fixture(scope="session") +def webbrowser_mock_sync() -> Mock: + """Mock for sync webbrowser since async OAuth classes use sync webbrowser.open().""" + webbrowser_mock = Mock() + webbrowser_mock.open = _webbrowser_redirect_sync + return webbrowser_mock + + +@pytest.fixture() +def temp_cache_async(): + """Async-compatible temporary cache.""" + + class TemporaryCache(TokenCache): + def __init__(self): + self._cache = {} + + def store(self, key: TokenKey, token: str) -> None: + self._cache[(key.user, key.host, key.tokenType)] = token + + def retrieve(self, key: TokenKey) -> str: + return self._cache.get((key.user, key.host, key.tokenType)) + + def remove(self, key: TokenKey) -> None: + self._cache.pop((key.user, key.host, key.tokenType)) + + tmp_cache = TemporaryCache() + # Patch both sync and async versions to be safe since async Auth inherits from sync Auth + # but the actual Auth instance used is async + with mock.patch( + "snowflake.connector.aio.auth._auth.Auth.get_token_cache", + return_value=tmp_cache, + ), mock.patch( + "snowflake.connector.auth._auth.Auth.get_token_cache", + return_value=tmp_cache, + ): + yield tmp_cache + + +@pytest.mark.skipolddriver +@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30) +async def test_oauth_code_successful_flow_async( + wiremock_client: WiremockClient, + wiremock_oauth_authorization_code_dir, + wiremock_generic_mappings_dir, + webbrowser_mock_sync, + monkeypatch, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + + wiremock_client.import_mapping( + wiremock_oauth_authorization_code_dir / "successful_flow.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_login_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + + with mock.patch("webbrowser.open", new=webbrowser_mock_sync.open): + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + cnx = SnowflakeConnection( + user="testUser", + authenticator="OAUTH_AUTHORIZATION_CODE", + oauth_client_id="123", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_client_secret="testClientSecret", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize", + oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + ) + + await cnx.connect() + await cnx.close() + + +@pytest.mark.skipolddriver +@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30) +async def test_oauth_code_invalid_state_async( + wiremock_client: WiremockClient, + wiremock_oauth_authorization_code_dir, + webbrowser_mock_sync, + monkeypatch, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + + wiremock_client.import_mapping( + wiremock_oauth_authorization_code_dir / "invalid_state_error.json" + ) + + with pytest.raises(snowflake.connector.errors.DatabaseError) as execinfo: + with mock.patch("webbrowser.open", new=webbrowser_mock_sync.open): + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + cnx = SnowflakeConnection( + user="testUser", + authenticator="OAUTH_AUTHORIZATION_CODE", + oauth_client_id="123", + oauth_client_secret="testClientSecret", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize", + oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + ) + await cnx.connect() + + assert str(execinfo.value).endswith("State changed during OAuth process.") + + +@pytest.mark.skipolddriver +@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30) +async def test_oauth_code_scope_error_async( + wiremock_client: WiremockClient, + wiremock_oauth_authorization_code_dir, + webbrowser_mock_sync, + monkeypatch, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + + wiremock_client.import_mapping( + wiremock_oauth_authorization_code_dir / "invalid_scope_error.json" + ) + + with pytest.raises(snowflake.connector.errors.DatabaseError) as execinfo: + with mock.patch("webbrowser.open", new=webbrowser_mock_sync.open): + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + cnx = SnowflakeConnection( + user="testUser", + authenticator="OAUTH_AUTHORIZATION_CODE", + oauth_client_id="123", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize", + oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + ) + await cnx.connect() + + assert str(execinfo.value).endswith( + "Oauth callback returned an invalid_scope error: One or more scopes are not configured for the authorization server resource." + ) + + +@pytest.mark.skipolddriver +@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30) +async def test_oauth_code_token_request_error_async( + wiremock_oauth_authorization_code_dir, + webbrowser_mock_sync, + monkeypatch, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + + with WiremockClient() as wiremock_client: + wiremock_client.import_mapping( + wiremock_oauth_authorization_code_dir / "token_request_error.json" + ) + + with pytest.raises(snowflake.connector.errors.DatabaseError) as execinfo: + with mock.patch("webbrowser.open", new=webbrowser_mock_sync.open): + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + cnx = SnowflakeConnection( + user="testUser", + authenticator="OAUTH_AUTHORIZATION_CODE", + oauth_client_id="123", + oauth_client_secret="testClientSecret", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize", + oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + ) + await cnx.connect() + + assert str(execinfo.value).endswith( + "Invalid HTTP request from web browser. Idp authentication could have failed." + ) + + +@pytest.mark.skipolddriver +async def test_oauth_code_browser_timeout_async( + wiremock_client: WiremockClient, + wiremock_oauth_authorization_code_dir, + webbrowser_mock_sync, + monkeypatch, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + + wiremock_client.import_mapping( + wiremock_oauth_authorization_code_dir + / "browser_timeout_authorization_error.json" + ) + + with pytest.raises(snowflake.connector.errors.DatabaseError) as execinfo: + with mock.patch("webbrowser.open", new=webbrowser_mock_sync.open): + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + cnx = SnowflakeConnection( + user="testUser", + authenticator="OAUTH_AUTHORIZATION_CODE", + oauth_client_id="123", + oauth_client_secret="testClientSecret", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize", + oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + external_browser_timeout=2, + ) + await cnx.connect() + + assert str(execinfo.value).endswith( + "Unable to receive the OAuth message within a given timeout. Please check the redirect URI and try again." + ) + + +@pytest.mark.skipolddriver +@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30) +async def test_oauth_code_custom_urls_async( + wiremock_client: WiremockClient, + wiremock_oauth_authorization_code_dir, + wiremock_generic_mappings_dir, + webbrowser_mock_sync, + monkeypatch, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + + wiremock_client.import_mapping( + wiremock_oauth_authorization_code_dir / "external_idp_custom_urls.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_login_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + + with mock.patch("webbrowser.open", new=webbrowser_mock_sync.open): + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + cnx = SnowflakeConnection( + user="testUser", + authenticator="OAUTH_AUTHORIZATION_CODE", + oauth_client_id="123", + oauth_client_secret="testClientSecret", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/tokenrequest", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/authorization", + oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + ) + + await cnx.connect() + await cnx.close() + + +@pytest.mark.skipolddriver +@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30) +async def test_oauth_code_successful_refresh_token_flow_async( + wiremock_client: WiremockClient, + wiremock_oauth_refresh_token_dir, + wiremock_generic_mappings_dir, + monkeypatch, + temp_cache_async, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + + wiremock_client.import_mapping( + wiremock_generic_mappings_dir / "snowflake_login_failed.json" + ) + wiremock_client.add_mapping( + wiremock_oauth_refresh_token_dir / "refresh_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_login_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + user = "testUser" + access_token_key = TokenKey( + user, wiremock_client.wiremock_host, TokenType.OAUTH_ACCESS_TOKEN + ) + refresh_token_key = TokenKey( + user, wiremock_client.wiremock_host, TokenType.OAUTH_REFRESH_TOKEN + ) + temp_cache_async.store(access_token_key, "expired-access-token-123") + temp_cache_async.store(refresh_token_key, "refresh-token-123") + cnx = SnowflakeConnection( + user=user, + authenticator="OAUTH_AUTHORIZATION_CODE", + oauth_client_id="123", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_client_secret="testClientSecret", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize", + oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + oauth_security_features=("pkce", "refresh_token"), + client_store_temporary_credential=True, + ) + await cnx.connect() + await cnx.close() + new_access_token = temp_cache_async.retrieve(access_token_key) + new_refresh_token = temp_cache_async.retrieve(refresh_token_key) + + assert new_access_token == "access-token-123" + assert new_refresh_token == "refresh-token-123" + + +@pytest.mark.skipolddriver +@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30) +async def test_oauth_code_expired_refresh_token_flow_async( + wiremock_client: WiremockClient, + wiremock_oauth_refresh_token_dir, + wiremock_oauth_authorization_code_dir, + wiremock_generic_mappings_dir, + webbrowser_mock_sync, + monkeypatch, + temp_cache_async, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + + wiremock_client.import_mapping( + wiremock_generic_mappings_dir / "snowflake_login_failed.json" + ) + wiremock_client.add_mapping( + wiremock_oauth_refresh_token_dir / "refresh_failed.json" + ) + wiremock_client.add_mapping( + wiremock_oauth_authorization_code_dir + / "successful_auth_after_failed_refresh.json" + ) + wiremock_client.add_mapping( + wiremock_oauth_authorization_code_dir / "new_tokens_after_failed_refresh.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_login_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + + user = "testUser" + access_token_key = TokenKey( + user, wiremock_client.wiremock_host, TokenType.OAUTH_ACCESS_TOKEN + ) + refresh_token_key = TokenKey( + user, wiremock_client.wiremock_host, TokenType.OAUTH_REFRESH_TOKEN + ) + temp_cache_async.store(access_token_key, "expired-access-token-123") + temp_cache_async.store(refresh_token_key, "expired-refresh-token-123") + with mock.patch("webbrowser.open", new=webbrowser_mock_sync.open): + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + cnx = SnowflakeConnection( + user=user, + authenticator="OAUTH_AUTHORIZATION_CODE", + oauth_client_id="123", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_client_secret="testClientSecret", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize", + oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + oauth_security_features=("pkce", "refresh_token"), + client_store_temporary_credential=True, + ) + await cnx.connect() + await cnx.close() + + new_access_token = temp_cache_async.retrieve(access_token_key) + new_refresh_token = temp_cache_async.retrieve(refresh_token_key) + assert new_access_token == "access-token-123" + assert new_refresh_token == "refresh-token-123" + + +@pytest.mark.skipolddriver +async def test_client_creds_oauth_type_async(): + """Simple OAuth Client credentials type test for async.""" + auth = AuthByOauthCredentials( + "app", + "clientId", + "clientSecret", + "tokenRequestUrl", + "scope", + ) + body = {"data": {}} + await auth.update_body(body) + assert body["data"]["OAUTH_TYPE"] == "client_credentials" + + +@pytest.mark.skipolddriver +async def test_client_creds_successful_flow_async( + wiremock_client: WiremockClient, + wiremock_oauth_client_creds_dir, + wiremock_generic_mappings_dir, + monkeypatch, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + wiremock_client.import_mapping( + wiremock_oauth_client_creds_dir / "successful_flow.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_login_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + cnx = SnowflakeConnection( + user="testUser", + authenticator="OAUTH_CLIENT_CREDENTIALS", + oauth_client_id="123", + oauth_client_secret="123", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + ) + + await cnx.connect() + await cnx.close() + + +@pytest.mark.skipolddriver +async def test_client_creds_token_request_error_async( + wiremock_client: WiremockClient, + wiremock_oauth_client_creds_dir, + wiremock_generic_mappings_dir, + monkeypatch, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + wiremock_client.import_mapping( + wiremock_oauth_client_creds_dir / "token_request_error.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_login_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + + with pytest.raises(snowflake.connector.errors.DatabaseError) as execinfo: + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + cnx = SnowflakeConnection( + user="testUser", + authenticator="OAUTH_CLIENT_CREDENTIALS", + oauth_client_id="123", + oauth_client_secret="123", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + ) + await cnx.connect() + + assert str(execinfo.value).endswith( + "Invalid HTTP request from web browser. Idp authentication could have failed." + ) + + +@pytest.mark.skipolddriver +async def test_client_creds_successful_refresh_token_flow_async( + wiremock_client: WiremockClient, + wiremock_oauth_refresh_token_dir, + wiremock_generic_mappings_dir, + monkeypatch, + temp_cache_async, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + + wiremock_client.import_mapping( + wiremock_generic_mappings_dir / "snowflake_login_failed.json" + ) + wiremock_client.add_mapping( + wiremock_oauth_refresh_token_dir / "refresh_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_login_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + user = "testUser" + access_token_key = TokenKey( + user, wiremock_client.wiremock_host, TokenType.OAUTH_ACCESS_TOKEN + ) + refresh_token_key = TokenKey( + user, wiremock_client.wiremock_host, TokenType.OAUTH_REFRESH_TOKEN + ) + temp_cache_async.store(access_token_key, "expired-access-token-123") + temp_cache_async.store(refresh_token_key, "refresh-token-123") + cnx = SnowflakeConnection( + user=user, + authenticator="OAUTH_CLIENT_CREDENTIALS", + oauth_client_id="123", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_client_secret="testClientSecret", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + oauth_security_features=("refresh_token",), + client_store_temporary_credential=True, + ) + await cnx.connect() + await cnx.close() + + new_access_token = temp_cache_async.retrieve(access_token_key) + new_refresh_token = temp_cache_async.retrieve(refresh_token_key) + assert new_access_token == "access-token-123" + assert new_refresh_token == "refresh-token-123" + + +@pytest.mark.skipolddriver +async def test_client_creds_expired_refresh_token_flow_async( + wiremock_client: WiremockClient, + wiremock_oauth_refresh_token_dir, + wiremock_oauth_client_creds_dir, + wiremock_generic_mappings_dir, + webbrowser_mock_sync, + monkeypatch, + temp_cache_async, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") + + wiremock_client.import_mapping( + wiremock_generic_mappings_dir / "snowflake_login_failed.json" + ) + wiremock_client.add_mapping( + wiremock_oauth_refresh_token_dir / "refresh_failed.json" + ) + wiremock_client.add_mapping( + wiremock_oauth_client_creds_dir / "successful_auth_after_failed_refresh.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_login_successful.json" + ) + wiremock_client.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + + user = "testUser" + access_token_key = TokenKey( + user, wiremock_client.wiremock_host, TokenType.OAUTH_ACCESS_TOKEN + ) + refresh_token_key = TokenKey( + user, wiremock_client.wiremock_host, TokenType.OAUTH_REFRESH_TOKEN + ) + temp_cache_async.store(access_token_key, "expired-access-token-123") + temp_cache_async.store(refresh_token_key, "expired-refresh-token-123") + cnx = SnowflakeConnection( + user=user, + authenticator="OAUTH_CLIENT_CREDENTIALS", + oauth_client_id="123", + account="testAccount", + protocol="http", + role="ANALYST", + oauth_client_secret="testClientSecret", + oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request", + host=wiremock_client.wiremock_host, + port=wiremock_client.wiremock_http_port, + oauth_security_features=("refresh_token",), + client_store_temporary_credential=True, + ) + await cnx.connect() + await cnx.close() + + new_access_token = temp_cache_async.retrieve(access_token_key) + new_refresh_token = temp_cache_async.retrieve(refresh_token_key) + assert new_access_token == "access-token-123" + assert new_refresh_token == "refresh-token-123" + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize( + "authenticator", ["OAUTH_AUTHORIZATION_CODE", "OAUTH_CLIENT_CREDENTIALS"] +) +async def test_auth_is_experimental_async( + authenticator, + monkeypatch, +) -> None: + monkeypatch.delenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", False) + with pytest.raises( + snowflake.connector.errors.ProgrammingError, + match=r"SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", + ): + cnx = SnowflakeConnection( + user="testUser", + account="testAccount", + authenticator=authenticator, + ) + await cnx.connect() + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize( + "authenticator", ["OAUTH_AUTHORIZATION_CODE", "OAUTH_CLIENT_CREDENTIALS"] +) +async def test_auth_experimental_when_variable_set_to_false_async( + authenticator, + monkeypatch, +) -> None: + monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "false") + with pytest.raises( + snowflake.connector.errors.ProgrammingError, + match=r"SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", + ): + cnx = SnowflakeConnection( + user="testUser", + account="testAccount", + authenticator="OAUTH_CLIENT_CREDENTIALS", + ) + await cnx.connect()