diff --git a/src/snowflake/connector/auth/workload_identity.py b/src/snowflake/connector/auth/workload_identity.py index 42fd8ff1a4..50d4ca376c 100644 --- a/src/snowflake/connector/auth/workload_identity.py +++ b/src/snowflake/connector/auth/workload_identity.py @@ -78,6 +78,9 @@ def update_body(self, body: dict[typing.Any, typing.Any]) -> None: self.attestation ).value body["data"]["TOKEN"] = self.attestation.credential + body["data"].setdefault("CLIENT_ENVIRONMENT", {})[ + "WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH" + ] = len(self.impersonation_path or []) def prepare( self, *, conn: SnowflakeConnection | None, **kwargs: typing.Any diff --git a/test/helpers.py b/test/helpers.py index 8eb6a03a5f..444c52c755 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -2,6 +2,7 @@ from __future__ import annotations import base64 +import copy import math import os import random @@ -12,6 +13,7 @@ import pytest +from snowflake.connector.auth._auth import Auth from snowflake.connector.compat import OK if TYPE_CHECKING: @@ -260,3 +262,30 @@ def _arrow_error_stream_random_input_test(use_table_iterator): # error instance users get should be the same assert len(exception_result) assert len(result_array) == 0 + + +def create_mock_auth_body(): + ocsp_mode = Mock() + ocsp_mode.name = "ocsp_mode" + session_manager = Mock() + session_manager.clone = lambda max_retries: "session_manager" + + return Auth.base_auth_data( + "user", + "account", + "application", + "internal_application_name", + "internal_application_version", + ocsp_mode, + login_timeout=60 * 60, + network_timeout=60 * 60, + socket_timeout=60 * 60, + platform_detection_timeout_seconds=0.2, + session_manager=session_manager, + ) + + +def apply_auth_class_update_body(auth_class, req_body_before): + req_body_after = copy.deepcopy(req_body_before) + auth_class.update_body(req_body_after) + return req_body_after diff --git a/test/unit/test_auth.py b/test/unit/test_auth.py index 595528601e..cfae32f8c3 100644 --- a/test/unit/test_auth.py +++ b/test/unit/test_auth.py @@ -4,6 +4,7 @@ import inspect import sys import time +from test.helpers import apply_auth_class_update_body, create_mock_auth_body from unittest.mock import Mock, PropertyMock import pytest @@ -337,3 +338,19 @@ def test_authbyplugin_abc_api(): 'password': , \ 'kwargs': })""" ) + + +def test_auth_by_default_prepare_body_does_not_overwrite_client_environment_fields(): + password = "testpassword" + auth_class = AuthByDefault(password) + + req_body_before = create_mock_auth_body() + req_body_after = apply_auth_class_update_body(auth_class, req_body_before) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) diff --git a/test/unit/test_auth_keypair.py b/test/unit/test_auth_keypair.py index c2c875aec1..80c27e9602 100644 --- a/test/unit/test_auth_keypair.py +++ b/test/unit/test_auth_keypair.py @@ -1,6 +1,7 @@ #!/usr/bin/env python from __future__ import annotations +from test.helpers import apply_auth_class_update_body, create_mock_auth_body from unittest.mock import Mock, PropertyMock, patch import pytest @@ -63,6 +64,22 @@ def test_auth_keypair(authenticator): assert rest.master_token == "MASTER_TOKEN" +def test_auth_prepare_body_does_not_overwrite_client_environment_fields(): + private_key_der, _ = generate_key_pair(2048) + auth_class = AuthByKeyPair(private_key=private_key_der) + + req_body_before = create_mock_auth_body() + req_body_after = apply_auth_class_update_body(auth_class, req_body_before) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + def test_auth_keypair_abc(): """Simple Key Pair test using abstraction layer.""" private_key_der, public_key_der_encoded = generate_key_pair(2048) diff --git a/test/unit/test_auth_oauth.py b/test/unit/test_auth_oauth.py index 87870bda8e..7e7a913f24 100644 --- a/test/unit/test_auth_oauth.py +++ b/test/unit/test_auth_oauth.py @@ -1,6 +1,8 @@ #!/usr/bin/env python from __future__ import annotations +from test.helpers import apply_auth_class_update_body, create_mock_auth_body + try: # pragma: no cover from snowflake.connector.auth import AuthByOAuth except ImportError: @@ -18,6 +20,22 @@ def test_auth_oauth(): assert body["data"]["AUTHENTICATOR"] == "OAUTH", body +def test_auth_prepare_body_does_not_overwrite_client_environment_fields(): + token = "oAuthToken" + auth_class = AuthByOAuth(token) + + req_body_before = create_mock_auth_body() + req_body_after = apply_auth_class_update_body(auth_class, req_body_before) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + @pytest.mark.parametrize("authenticator", ["oauth", "OAUTH"]) def test_oauth_authenticator_is_case_insensitive(monkeypatch, authenticator): """Test that oauth authenticator is case insensitive.""" diff --git a/test/unit/test_auth_oauth_auth_code.py b/test/unit/test_auth_oauth_auth_code.py index 8ede51facd..76894791cc 100644 --- a/test/unit/test_auth_oauth_auth_code.py +++ b/test/unit/test_auth_oauth_auth_code.py @@ -4,6 +4,7 @@ # import unittest.mock as mock +from test.helpers import apply_auth_class_update_body, create_mock_auth_body from unittest.mock import patch import pytest @@ -44,6 +45,32 @@ def test_auth_oauth_auth_code_oauth_type(omit_oauth_urls_check): ) +def test_auth_prepare_body_does_not_overwrite_client_environment_fields( + omit_oauth_urls_check, +): + auth_class = AuthByOauthCode( + "app", + "clientId", + "clientSecret", + "auth_url", + "tokenRequestUrl", + "redirectUri:{port}", + "scope", + "host", + ) + + req_body_before = create_mock_auth_body() + req_body_after = apply_auth_class_update_body(auth_class, req_body_before) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + @pytest.mark.parametrize("rtr_enabled", [True, False]) def test_auth_oauth_auth_code_single_use_refresh_tokens( rtr_enabled: bool, omit_oauth_urls_check diff --git a/test/unit/test_auth_oauth_credentials.py b/test/unit/test_auth_oauth_credentials.py index 7539cdbb97..75b3cbd1ed 100644 --- a/test/unit/test_auth_oauth_credentials.py +++ b/test/unit/test_auth_oauth_credentials.py @@ -4,6 +4,8 @@ # +from test.helpers import apply_auth_class_update_body, create_mock_auth_body + import pytest from snowflake.connector.auth import AuthByOauthCredentials @@ -26,6 +28,27 @@ def test_auth_oauth_credentials_oauth_type(): ) +def test_auth_prepare_body_does_not_overwrite_client_environment_fields(): + auth_class = AuthByOauthCredentials( + "app", + "clientId", + "clientSecret", + "https://example.com/oauth/token", + "scope", + ) + + req_body_before = create_mock_auth_body() + req_body_after = apply_auth_class_update_body(auth_class, req_body_before) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + @pytest.mark.parametrize( "authenticator", ["OAUTH_CLIENT_CREDENTIALS", "oauth_client_credentials"] ) diff --git a/test/unit/test_auth_okta.py b/test/unit/test_auth_okta.py index a623b5ae71..206f630969 100644 --- a/test/unit/test_auth_okta.py +++ b/test/unit/test_auth_okta.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging +from test.helpers import apply_auth_class_update_body, create_mock_auth_body from unittest.mock import Mock, PropertyMock, patch import pytest @@ -19,6 +20,22 @@ from snowflake.connector.auth_okta import AuthByOkta +def test_auth_prepare_body_does_not_overwrite_client_environment_fields(): + application = "testapplication" + auth_class = AuthByOkta(application) + + req_body_before = create_mock_auth_body() + req_body_after = apply_auth_class_update_body(auth_class, req_body_before) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + def test_auth_okta(): """Authentication by OKTA positive test case.""" authenticator = "https://testsso.snowflake.net/" diff --git a/test/unit/test_auth_pat.py b/test/unit/test_auth_pat.py index a8a250d8b9..eecef0088c 100644 --- a/test/unit/test_auth_pat.py +++ b/test/unit/test_auth_pat.py @@ -4,6 +4,8 @@ # from __future__ import annotations +from test.helpers import apply_auth_class_update_body, create_mock_auth_body + import pytest from snowflake.connector.auth import AuthByPAT, AuthNoAuth @@ -26,6 +28,22 @@ def test_auth_pat(): assert auth.assertion_content is None +def test_pat_prepare_body_does_not_overwrite_client_environment_fields(): + token = "patToken" + auth_class = AuthByPAT(token) + + req_body_before = create_mock_auth_body() + req_body_after = apply_auth_class_update_body(auth_class, req_body_before) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + def test_auth_pat_reauthenticate(): """Test PAT reauthenticate.""" token = "patToken" diff --git a/test/unit/test_auth_webbrowser.py b/test/unit/test_auth_webbrowser.py index db97f58bb7..f649050734 100644 --- a/test/unit/test_auth_webbrowser.py +++ b/test/unit/test_auth_webbrowser.py @@ -3,6 +3,7 @@ import base64 import socket +from test.helpers import apply_auth_class_update_body, create_mock_auth_body from unittest import mock from unittest.mock import MagicMock, Mock, PropertyMock, patch @@ -792,3 +793,17 @@ def mock_webbrowser_auth_prepare( assert isinstance(conn.auth_class, AuthByWebBrowser) conn.close() + + +def test_auth_prepare_body_does_not_overwrite_client_environment_fields(): + auth_class = AuthByWebBrowser(application=APPLICATION) + req_body_before = create_mock_auth_body() + req_body_after = apply_auth_class_update_body(auth_class, req_body_before) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) diff --git a/test/unit/test_auth_workload_identity.py b/test/unit/test_auth_workload_identity.py index 4497263e4a..7c9ff4a03e 100644 --- a/test/unit/test_auth_workload_identity.py +++ b/test/unit/test_auth_workload_identity.py @@ -1,6 +1,7 @@ import json import logging from base64 import b64decode +from test.helpers import apply_auth_class_update_body, create_mock_auth_body from unittest import mock from urllib.parse import parse_qs, urlparse @@ -14,7 +15,11 @@ HTTPError, Timeout, ) -from snowflake.connector.wif_util import AttestationProvider, get_aws_sts_hostname +from snowflake.connector.wif_util import ( + AttestationProvider, + WorkloadIdentityAttestation, + get_aws_sts_hostname, +) from ..csp_helpers import ( FakeAwsEnvironment, @@ -122,6 +127,40 @@ def test_wif_authenticator_is_case_insensitive( assert isinstance(connection.auth_class, AuthByWorkloadIdentity) +@pytest.mark.parametrize( + "provider,additional_args", + [ + (AttestationProvider.AWS, {}), + (AttestationProvider.GCP, {}), + (AttestationProvider.AZURE, {}), + ( + AttestationProvider.OIDC, + {"token": gen_dummy_id_token(sub="service-1", iss="issuer-1")}, + ), + ], +) +def test_auth_prepare_body_does_not_overwrite_client_environment_fields( + provider, additional_args +): + auth_class = AuthByWorkloadIdentity(provider=provider, **additional_args) + auth_class.attestation = WorkloadIdentityAttestation( + provider=AttestationProvider.GCP, + credential=None, + user_identifier_components=None, + ) + + req_body_before = create_mock_auth_body() + req_body_after = apply_auth_class_update_body(auth_class, req_body_before) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + # -- OIDC Tests -- @@ -136,6 +175,7 @@ def test_explicit_oidc_valid_inline_token_plumbed_to_api(): "AUTHENTICATOR": "WORKLOAD_IDENTITY", "PROVIDER": "OIDC", "TOKEN": dummy_token, + "CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0}, } @@ -191,6 +231,9 @@ def test_explicit_aws_encodes_audience_host_signature_to_api( data = extract_api_data(auth_class) assert data["AUTHENTICATOR"] == "WORKLOAD_IDENTITY" assert data["PROVIDER"] == "AWS" + assert ( + data["CLIENT_ENVIRONMENT"]["WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH"] == 0 + ) verify_aws_token(data["TOKEN"], fake_aws_environment.region) @@ -324,6 +367,7 @@ def test_explicit_gcp_plumbs_token_to_api( "AUTHENTICATOR": "WORKLOAD_IDENTITY", "PROVIDER": "GCP", "TOKEN": fake_gce_metadata_service.token, + "CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0}, } @@ -373,6 +417,7 @@ def test_gcp_calls_correct_apis_and_populates_auth_data_for_final_sa( "AUTHENTICATOR": "WORKLOAD_IDENTITY", "PROVIDER": "GCP", "TOKEN": sa3_id_token, + "CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 2}, } @@ -425,6 +470,7 @@ def test_explicit_azure_plumbs_token_to_api(fake_azure_metadata_service): "AUTHENTICATOR": "WORKLOAD_IDENTITY", "PROVIDER": "AZURE", "TOKEN": fake_azure_metadata_service.token, + "CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0}, }