diff --git a/ci/test_wif.sh b/ci/test_wif.sh index 741948764d..4e01cab8ae 100755 --- a/ci/test_wif.sh +++ b/ci/test_wif.sh @@ -12,8 +12,11 @@ run_tests_and_set_result() { local host="$2" local snowflake_host="$3" local rsa_key_path="$4" + local snowflake_user="$5" + local impersonation_path="$6" + local snowflake_user_for_impersonation="$7" - ssh -i "$rsa_key_path" -o IdentitiesOnly=yes -p 443 "$host" env BRANCH="$BRANCH" SNOWFLAKE_TEST_WIF_HOST="$snowflake_host" SNOWFLAKE_TEST_WIF_PROVIDER="$provider" SNOWFLAKE_TEST_WIF_ACCOUNT="$SNOWFLAKE_TEST_WIF_ACCOUNT" bash << EOF + ssh -i "$rsa_key_path" -o IdentitiesOnly=yes -p 443 "$host" env BRANCH="$BRANCH" SNOWFLAKE_TEST_WIF_HOST="$snowflake_host" SNOWFLAKE_TEST_WIF_PROVIDER="$provider" SNOWFLAKE_TEST_WIF_ACCOUNT="$SNOWFLAKE_TEST_WIF_ACCOUNT" SNOWFLAKE_TEST_WIF_USERNAME="$snowflake_user" SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH="$impersonation_path" SNOWFLAKE_TEST_WIF_USERNAME_IMPERSONATION="$snowflake_user_for_impersonation" bash << EOF set -e set -o pipefail docker run \ @@ -24,6 +27,9 @@ run_tests_and_set_result() { -e SNOWFLAKE_TEST_WIF_PROVIDER \ -e SNOWFLAKE_TEST_WIF_HOST \ -e SNOWFLAKE_TEST_WIF_ACCOUNT \ + -e SNOWFLAKE_TEST_WIF_USERNAME \ + -e SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH \ + -e SNOWFLAKE_TEST_WIF_USERNAME_IMPERSONATION \ snowflakedb/client-python-test:1 \ bash -c " echo 'Running tests on branch: \$BRANCH' @@ -77,9 +83,9 @@ setup_parameters # Run tests for all cloud providers EXIT_STATUS=0 set +e # Don't exit on first failure -run_tests_and_set_result "AZURE" "$HOST_AZURE" "$SNOWFLAKE_TEST_WIF_HOST_AZURE" "$RSA_KEY_PATH_AWS_AZURE" -run_tests_and_set_result "AWS" "$HOST_AWS" "$SNOWFLAKE_TEST_WIF_HOST_AWS" "$RSA_KEY_PATH_AWS_AZURE" -run_tests_and_set_result "GCP" "$HOST_GCP" "$SNOWFLAKE_TEST_WIF_HOST_GCP" "$RSA_KEY_PATH_GCP" +run_tests_and_set_result "AZURE" "$HOST_AZURE" "$SNOWFLAKE_TEST_WIF_HOST_AZURE" "$RSA_KEY_PATH_AWS_AZURE" "$SNOWFLAKE_TEST_WIF_USERNAME_AZURE" "$SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH_AZURE" "$SNOWFLAKE_TEST_WIF_USERNAME_AZURE_IMPERSONATION" +run_tests_and_set_result "AWS" "$HOST_AWS" "$SNOWFLAKE_TEST_WIF_HOST_AWS" "$RSA_KEY_PATH_AWS_AZURE" "$SNOWFLAKE_TEST_WIF_USERNAME_AWS" "$SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH_AWS" "$SNOWFLAKE_TEST_WIF_USERNAME_AWS_IMPERSONATION" +run_tests_and_set_result "GCP" "$HOST_GCP" "$SNOWFLAKE_TEST_WIF_HOST_GCP" "$RSA_KEY_PATH_GCP" "$SNOWFLAKE_TEST_WIF_USERNAME_GCP" "$SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH_GCP" "$SNOWFLAKE_TEST_WIF_USERNAME_GCP_IMPERSONATION" set -e # Re-enable exit on error echo "Exit status: $EXIT_STATUS" exit $EXIT_STATUS diff --git a/ci/wif/parameters/parameters_wif.json.gpg b/ci/wif/parameters/parameters_wif.json.gpg index 591938e357..a9f2503585 100644 Binary files a/ci/wif/parameters/parameters_wif.json.gpg and b/ci/wif/parameters/parameters_wif.json.gpg differ diff --git a/ci/wif/parameters/rsa_wif_aws_azure.gpg b/ci/wif/parameters/rsa_wif_aws_azure.gpg index 94975ad9c2..a766980d75 100644 Binary files a/ci/wif/parameters/rsa_wif_aws_azure.gpg and b/ci/wif/parameters/rsa_wif_aws_azure.gpg differ diff --git a/ci/wif/parameters/rsa_wif_gcp.gpg b/ci/wif/parameters/rsa_wif_gcp.gpg index 4c283c06e6..a766980d75 100644 Binary files a/ci/wif/parameters/rsa_wif_gcp.gpg and b/ci/wif/parameters/rsa_wif_gcp.gpg differ diff --git a/ci/wif/test_wif.sh b/ci/wif/test_wif.sh index 3053d6dcf3..b8ce428838 100755 --- a/ci/wif/test_wif.sh +++ b/ci/wif/test_wif.sh @@ -5,6 +5,14 @@ set -o pipefail export SF_OCSP_TEST_MODE=true export RUN_WIF_TESTS=true -/opt/python/cp39-cp39/bin/python -m pip install --break-system-packages -e '.[aio]' -/opt/python/cp39-cp39/bin/python -m pip install --break-system-packages pytest -/opt/python/cp39-cp39/bin/python -m pytest test/wif/* +# setup pytest +/opt/python/cp312-cp312/bin/python -m pip install --break-system-packages pytest pytest-asyncio + +# test WIF without asyncio installed +/opt/python/cp312-cp312/bin/python -m pip install --break-system-packages -e . +/opt/python/cp312-cp312/bin/python -m pytest test/wif/ --ignore test/wif/test_wif_async.py + +# test WIF with asyncio installed +# /opt/python/cp312-cp312/bin/python -m pip install --break-system-packages -e '.[aio]' +# run all tests to see whether installation does not break anything +# /opt/python/cp312-cp312/bin/python -m pytest test/wif/ diff --git a/src/snowflake/connector/aio/_connection.py b/src/snowflake/connector/aio/_connection.py index db6e7eae95..843517a5ba 100644 --- a/src/snowflake/connector/aio/_connection.py +++ b/src/snowflake/connector/aio/_connection.py @@ -414,10 +414,28 @@ async def __open_connection(self): "errno": ER_INVALID_WIF_SETTINGS, }, ) + if ( + self._workload_identity_impersonation_path + and self._workload_identity_provider + not in ( + AttestationProvider.GCP, + AttestationProvider.AWS, + ) + ): + Error.errorhandler_wrapper( + self, + None, + ProgrammingError, + { + "msg": "workload_identity_impersonation_path is currently only supported for GCP and AWS.", + "errno": ER_INVALID_WIF_SETTINGS, + }, + ) self.auth_class = AuthByWorkloadIdentity( provider=self._workload_identity_provider, token=self._token, entra_resource=self._workload_identity_entra_resource, + impersonation_path=self._workload_identity_impersonation_path, ) else: # okta URL, e.g., https://.okta.com/ diff --git a/src/snowflake/connector/aio/_session_manager.py b/src/snowflake/connector/aio/_session_manager.py index aba3e0b840..6974df159f 100644 --- a/src/snowflake/connector/aio/_session_manager.py +++ b/src/snowflake/connector/aio/_session_manager.py @@ -336,7 +336,10 @@ class SessionManager(_RequestVerbsUsingSessionMixin, SessionManagerSync): """ def __init__( - self, config: AioHttpConfig | None = None, **http_config_kwargs + self, + config: AioHttpConfig | None = None, + max_retries: int | None = None, # TODO: remove this after rebase + **http_config_kwargs, ) -> None: """Create a new async SessionManager.""" if config is None: @@ -435,6 +438,7 @@ def clone( *, use_pooling: bool | None = None, connector_factory: ConnectorFactory | None = None, + **kwargs, # TODO: remove this after rebase ) -> SessionManager: """Return a new async SessionManager sharing this instance's config.""" overrides: dict[str, Any] = {} diff --git a/src/snowflake/connector/aio/_wif_util.py b/src/snowflake/connector/aio/_wif_util.py index 1f2a62ff5c..c29a7e8b2d 100644 --- a/src/snowflake/connector/aio/_wif_util.py +++ b/src/snowflake/connector/aio/_wif_util.py @@ -25,6 +25,10 @@ logger = logging.getLogger(__name__) +GCP_METADATA_SERVICE_ACCOUNT_BASE_URL = ( + "http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default" +) + async def get_aws_region() -> str: """Get the current AWS workload's region.""" @@ -41,12 +45,37 @@ async def get_aws_region() -> str: return region -async def create_aws_attestation() -> WorkloadIdentityAttestation: +async def get_aws_session(impersonation_path: list[str] | None = None): + """Creates an aioboto3 session with the appropriate credentials. + + If impersonation_path is provided, this uses the role at the end of the path. Otherwise, this uses the role attached to the current workload. + """ + session = aioboto3.Session() + + impersonation_path = impersonation_path or [] + for arn in impersonation_path: + async with session.client("sts") as sts_client: + response = await sts_client.assume_role( + RoleArn=arn, RoleSessionName="identity-federation-session" + ) + creds = response["Credentials"] + session = aioboto3.Session( + aws_access_key_id=creds["AccessKeyId"], + aws_secret_access_key=creds["SecretAccessKey"], + aws_session_token=creds["SessionToken"], + ) + return session + + +async def create_aws_attestation( + impersonation_path: list[str] | None = None, +) -> WorkloadIdentityAttestation: """Tries to create a workload identity attestation for AWS. If the application isn't running on AWS or no credentials were found, raises an error. """ - session = aioboto3.Session() + session = await get_aws_session(impersonation_path) + aws_creds = await session.get_credentials() if not aws_creds: raise ProgrammingError( @@ -81,30 +110,108 @@ async def create_aws_attestation() -> WorkloadIdentityAttestation: ) -async def create_gcp_attestation( - session_manager: SessionManager | None = None, -) -> WorkloadIdentityAttestation: - """Tries to create a workload identity attestation for GCP. +async def get_gcp_access_token(session_manager: SessionManager) -> str: + """Gets a GCP access token from the metadata server. If the application isn't running on GCP or no credentials were found, raises an error. """ try: res = await session_manager.request( method="GET", - url=f"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/identity?audience={SNOWFLAKE_AUDIENCE}", + url=f"{GCP_METADATA_SERVICE_ACCOUNT_BASE_URL}/token", headers={ "Metadata-Flavor": "Google", }, ) content = await res.content.read() - jwt_str = content.decode("utf-8") + response_text = content.decode("utf-8") + return json.loads(response_text)["access_token"] except Exception as e: raise ProgrammingError( - msg=f"Error fetching GCP metadata: {e}. Ensure the application is running on GCP.", + msg=f"Error fetching GCP access token: {e}. Ensure the application is running on GCP.", + errno=ER_WIF_CREDENTIALS_NOT_FOUND, + ) + + +async def get_gcp_identity_token_via_impersonation( + impersonation_path: list[str], session_manager: SessionManager +) -> str: + """Gets a GCP identity token from the metadata server. + + If the application isn't running on GCP or no credentials were found, raises an error. + """ + if not impersonation_path: + raise ProgrammingError( + msg="Error: impersonation_path cannot be empty.", errno=ER_WIF_CREDENTIALS_NOT_FOUND, ) + current_sa_token = await get_gcp_access_token(session_manager) + impersonation_path = [ + f"projects/-/serviceAccounts/{client_id}" for client_id in impersonation_path + ] + try: + res = await session_manager.post( + url=f"https://iamcredentials.googleapis.com/v1/{impersonation_path[-1]}:generateIdToken", + headers={ + "Authorization": f"Bearer {current_sa_token}", + "Content-Type": "application/json", + }, + json={ + "delegates": impersonation_path[:-1], + "audience": SNOWFLAKE_AUDIENCE, + }, + ) + + content = await res.content.read() + response_text = content.decode("utf-8") + return json.loads(response_text)["token"] + except Exception as e: + raise ProgrammingError( + msg=f"Error fetching GCP identity token for impersonated GCP service account '{impersonation_path[-1]}': {e}. Ensure the application is running on GCP.", + errno=ER_WIF_CREDENTIALS_NOT_FOUND, + ) + + +async def get_gcp_identity_token(session_manager: SessionManager) -> str: + """Gets a GCP identity token from the metadata server. + + If the application isn't running on GCP or no credentials were found, raises an error. + """ + try: + res = await session_manager.request( + method="GET", + url=f"{GCP_METADATA_SERVICE_ACCOUNT_BASE_URL}/identity?audience={SNOWFLAKE_AUDIENCE}", + headers={ + "Metadata-Flavor": "Google", + }, + ) + + content = await res.content.read() + return content.decode("utf-8") + except Exception as e: + raise ProgrammingError( + msg=f"Error fetching GCP identity token: {e}. Ensure the application is running on GCP.", + errno=ER_WIF_CREDENTIALS_NOT_FOUND, + ) + + +async def create_gcp_attestation( + session_manager: SessionManager, + impersonation_path: list[str] | None = None, +) -> WorkloadIdentityAttestation: + """Tries to create a workload identity attestation for GCP. + + If the application isn't running on GCP or no credentials were found, raises an error. + """ + if impersonation_path: + jwt_str = await get_gcp_identity_token_via_impersonation( + impersonation_path, session_manager + ) + else: + jwt_str = await get_gcp_identity_token(session_manager) + _, subject = extract_iss_and_sub_without_signature_verification(jwt_str) return WorkloadIdentityAttestation( AttestationProvider.GCP, jwt_str, {"sub": subject} @@ -179,6 +286,7 @@ async def create_attestation( provider: AttestationProvider | None, entra_resource: str | None = None, token: str | None = None, + impersonation_path: list[str] | None = None, session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation: """Entry point to create an attestation using the given provider. @@ -189,15 +297,15 @@ async def create_attestation( session_manager = ( session_manager.clone() if session_manager - else SessionManagerFactory.get_manager(use_pooling=True) + else SessionManager(use_pooling=True, max_retries=0) ) if provider == AttestationProvider.AWS: - return await create_aws_attestation() + return await create_aws_attestation(impersonation_path) elif provider == AttestationProvider.AZURE: return await create_azure_attestation(entra_resource, session_manager) elif provider == AttestationProvider.GCP: - return await create_gcp_attestation(session_manager) + return await create_gcp_attestation(session_manager, impersonation_path) elif provider == AttestationProvider.OIDC: return create_oidc_attestation(token) else: diff --git a/src/snowflake/connector/aio/auth/_workload_identity.py b/src/snowflake/connector/aio/auth/_workload_identity.py index 7f13b5afd9..eb0d43533d 100644 --- a/src/snowflake/connector/aio/auth/_workload_identity.py +++ b/src/snowflake/connector/aio/auth/_workload_identity.py @@ -22,6 +22,7 @@ def __init__( provider: AttestationProvider, token: str | None = None, entra_resource: str | None = None, + impersonation_path: list[str] | None = None, **kwargs, ) -> None: """Initializes an instance with workload identity authentication.""" @@ -30,6 +31,7 @@ def __init__( provider=provider, token=token, entra_resource=entra_resource, + impersonation_path=impersonation_path, **kwargs, ) @@ -44,7 +46,10 @@ async def prepare( self.provider, self.entra_resource, self.token, - session_manager=conn._session_manager.clone() if conn else None, + self.impersonation_path, + session_manager=( + conn._session_manager.clone(max_retries=0) if conn else None + ), ) async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]: diff --git a/src/snowflake/connector/auth/workload_identity.py b/src/snowflake/connector/auth/workload_identity.py index c4c0b8457b..2531ff4412 100644 --- a/src/snowflake/connector/auth/workload_identity.py +++ b/src/snowflake/connector/auth/workload_identity.py @@ -55,12 +55,14 @@ def __init__( provider: AttestationProvider, token: str | None = None, entra_resource: str | None = None, + impersonation_path: list[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) self.provider = provider self.token = token self.entra_resource = entra_resource + self.impersonation_path = impersonation_path self.attestation: WorkloadIdentityAttestation | None = None @@ -76,6 +78,9 @@ def update_body(self, body: dict[typing.Any, typing.Any]) -> None: self.attestation ).value body["data"]["TOKEN"] = self.attestation.credential + body["data"].setdefault("CLIENT_ENVIRONMENT", {})[ + "WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH" + ] = len(self.impersonation_path or []) def prepare( self, *, conn: SnowflakeConnection | None, **kwargs: typing.Any @@ -85,7 +90,10 @@ def prepare( self.provider, self.entra_resource, self.token, - session_manager=conn._session_manager.clone() if conn else None, + self.impersonation_path, + session_manager=( + conn._session_manager.clone(max_retries=0) if conn else None + ), ) def reauthenticate(self, **kwargs: typing.Any) -> dict[str, bool]: diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index 38f4e5301d..11e0d17fef 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -214,6 +214,7 @@ def _get_private_bytes_from_file( "authenticator": (DEFAULT_AUTHENTICATOR, (type(None), str)), "workload_identity_provider": (None, (type(None), AttestationProvider)), "workload_identity_entra_resource": (None, (type(None), str)), + "workload_identity_impersonation_path": (None, (type(None), list[str])), "mfa_callback": (None, (type(None), Callable)), "password_callback": (None, (type(None), Callable)), "auth_class": (None, (type(None), AuthByPlugin)), @@ -1355,10 +1356,28 @@ def __open_connection(self): "errno": ER_INVALID_WIF_SETTINGS, }, ) + if ( + self._workload_identity_impersonation_path + and self._workload_identity_provider + not in ( + AttestationProvider.GCP, + AttestationProvider.AWS, + ) + ): + Error.errorhandler_wrapper( + self, + None, + ProgrammingError, + { + "msg": "workload_identity_impersonation_path is currently only supported for GCP and AWS.", + "errno": ER_INVALID_WIF_SETTINGS, + }, + ) self.auth_class = AuthByWorkloadIdentity( provider=self._workload_identity_provider, token=self._token, entra_resource=self._workload_identity_entra_resource, + impersonation_path=self._workload_identity_impersonation_path, ) else: # okta URL, e.g., https://.okta.com/ @@ -1531,6 +1550,7 @@ def __config(self, **kwargs): workload_identity_dependent_options = [ "workload_identity_provider", "workload_identity_entra_resource", + "workload_identity_impersonation_path", ] for dependent_option in workload_identity_dependent_options: if ( diff --git a/src/snowflake/connector/session_manager.py b/src/snowflake/connector/session_manager.py index 43eeb87ee4..04add6c420 100644 --- a/src/snowflake/connector/session_manager.py +++ b/src/snowflake/connector/session_manager.py @@ -350,7 +350,12 @@ class SessionManager(_RequestVerbsUsingSessionMixin): direct HTTP library calls. """ - def __init__(self, config: HttpConfig | None = None, **http_config_kwargs) -> None: + def __init__( + self, + config: HttpConfig | None = None, + max_retries: int | None = None, # TODO: remove this after rebase + **http_config_kwargs, + ) -> None: """ Create a new SessionManager. """ @@ -508,6 +513,7 @@ def clone( *, use_pooling: bool | None = None, adapter_factory: AdapterFactory | None = None, + **kwargs, # TODO: remove this after rebase ) -> SessionManager: """Return a new *stateless* SessionManager sharing this instance’s config. diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py index 406ee12725..80f7d73d9a 100644 --- a/src/snowflake/connector/wif_util.py +++ b/src/snowflake/connector/wif_util.py @@ -20,6 +20,9 @@ logger = logging.getLogger(__name__) SNOWFLAKE_AUDIENCE = "snowflakecomputing.com" DEFAULT_ENTRA_SNOWFLAKE_RESOURCE = "api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad" +GCP_METADATA_SERVICE_ACCOUNT_BASE_URL = ( + "http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default" +) @unique @@ -142,15 +145,37 @@ def get_aws_sts_hostname(region: str, partition: str) -> str: ) +def get_aws_session(impersonation_path: list[str] | None = None): + """Creates a boto3 session with the appropriate credentials. + + If impersonation_path is provided, this uses the role at the end of the path. Otherwise, this uses the role attached to the current workload. + """ + session = boto3.session.Session() + + impersonation_path = impersonation_path or [] + for arn in impersonation_path: + response = session.client("sts").assume_role( + RoleArn=arn, RoleSessionName="identity-federation-session" + ) + creds = response["Credentials"] + session = boto3.session.Session( + aws_access_key_id=creds["AccessKeyId"], + aws_secret_access_key=creds["SecretAccessKey"], + aws_session_token=creds["SessionToken"], + ) + return session + + def create_aws_attestation( - session_manager: SessionManager | None = None, + impersonation_path: list[str] | None = None, ) -> WorkloadIdentityAttestation: """Tries to create a workload identity attestation for AWS. If the application isn't running on AWS or no credentials were found, raises an error. """ # TODO: SNOW-2223669 Investigate if our adapters - containing settings of http traffic - should be passed here as boto urllib3session. Those requests go to local servers, so they do not need Proxy setup or Headers customization in theory. But we may want to have all the traffic going through one class (e.g. Adapter or mixin). - session = boto3.session.Session() + session = get_aws_session(impersonation_path) + aws_creds = session.get_credentials() if not aws_creds: raise ProgrammingError( @@ -184,29 +209,103 @@ def create_aws_attestation( ) -def create_gcp_attestation( - session_manager: SessionManager | None = None, -) -> WorkloadIdentityAttestation: - """Tries to create a workload identity attestation for GCP. +def get_gcp_access_token(session_manager: SessionManager) -> str: + """Gets a GCP access token from the metadata server. If the application isn't running on GCP or no credentials were found, raises an error. """ try: res = session_manager.request( method="GET", - url=f"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/identity?audience={SNOWFLAKE_AUDIENCE}", + url=f"{GCP_METADATA_SERVICE_ACCOUNT_BASE_URL}/token", headers={ "Metadata-Flavor": "Google", }, ) res.raise_for_status() + return res.json()["access_token"] except Exception as e: raise ProgrammingError( - msg=f"Error fetching GCP metadata: {e}. Ensure the application is running on GCP.", + msg=f"Error fetching GCP access token: {e}. Ensure the application is running on GCP.", errno=ER_WIF_CREDENTIALS_NOT_FOUND, ) - jwt_str = res.content.decode("utf-8") + +def get_gcp_identity_token_via_impersonation( + impersonation_path: list[str], session_manager: SessionManager +) -> str: + """Gets a GCP identity token from the metadata server. + + If the application isn't running on GCP or no credentials were found, raises an error. + """ + if not impersonation_path: + raise ProgrammingError( + msg="Error: impersonation_path cannot be empty.", + errno=ER_WIF_CREDENTIALS_NOT_FOUND, + ) + + current_sa_token = get_gcp_access_token(session_manager) + impersonation_path = [ + f"projects/-/serviceAccounts/{client_id}" for client_id in impersonation_path + ] + try: + res = session_manager.post( + url=f"https://iamcredentials.googleapis.com/v1/{impersonation_path[-1]}:generateIdToken", + headers={ + "Authorization": f"Bearer {current_sa_token}", + "Content-Type": "application/json", + }, + json={ + "delegates": impersonation_path[:-1], + "audience": SNOWFLAKE_AUDIENCE, + }, + ) + res.raise_for_status() + return res.json()["token"] + except Exception as e: + raise ProgrammingError( + msg=f"Error fetching GCP identity token for impersonated GCP service account '{impersonation_path[-1]}': {e}. Ensure the application is running on GCP.", + errno=ER_WIF_CREDENTIALS_NOT_FOUND, + ) + + +def get_gcp_identity_token(session_manager: SessionManager) -> str: + """Gets a GCP identity token from the metadata server. + + If the application isn't running on GCP or no credentials were found, raises an error. + """ + try: + res = session_manager.request( + method="GET", + url=f"{GCP_METADATA_SERVICE_ACCOUNT_BASE_URL}/identity?audience={SNOWFLAKE_AUDIENCE}", + headers={ + "Metadata-Flavor": "Google", + }, + ) + res.raise_for_status() + return res.content.decode("utf-8") + except Exception as e: + raise ProgrammingError( + msg=f"Error fetching GCP identity token: {e}. Ensure the application is running on GCP.", + errno=ER_WIF_CREDENTIALS_NOT_FOUND, + ) + + +def create_gcp_attestation( + session_manager: SessionManager, + impersonation_path: list[str] | None = None, +) -> WorkloadIdentityAttestation: + """Tries to create a workload identity attestation for GCP. + + If the application isn't running on GCP or no credentials were found, raises an error. + """ + if impersonation_path: + jwt_str = get_gcp_identity_token_via_impersonation( + impersonation_path, session_manager + ) + else: + jwt_str = get_gcp_identity_token(session_manager) + _, subject = extract_iss_and_sub_without_signature_verification(jwt_str) return WorkloadIdentityAttestation( AttestationProvider.GCP, jwt_str, {"sub": subject} @@ -295,6 +394,7 @@ def create_attestation( provider: AttestationProvider, entra_resource: str | None = None, token: str | None = None, + impersonation_path: list[str] | None = None, session_manager: SessionManager | None = None, ) -> WorkloadIdentityAttestation: """Entry point to create an attestation using the given provider. @@ -303,15 +403,17 @@ def create_attestation( """ entra_resource = entra_resource or DEFAULT_ENTRA_SNOWFLAKE_RESOURCE session_manager = ( - session_manager.clone() if session_manager else SessionManager(use_pooling=True) + session_manager.clone() + if session_manager + else SessionManager(use_pooling=True, max_retries=0) ) if provider == AttestationProvider.AWS: - return create_aws_attestation(session_manager) + return create_aws_attestation(impersonation_path) elif provider == AttestationProvider.AZURE: return create_azure_attestation(entra_resource, session_manager) elif provider == AttestationProvider.GCP: - return create_gcp_attestation(session_manager) + return create_gcp_attestation(session_manager, impersonation_path) elif provider == AttestationProvider.OIDC: return create_oidc_attestation(token) else: diff --git a/test/csp_helpers.py b/test/csp_helpers.py index 77237ef031..534151d057 100644 --- a/test/csp_helpers.py +++ b/test/csp_helpers.py @@ -40,6 +40,12 @@ def gen_dummy_id_token( ) +def gen_dummy_access_token(sub="test-subject", key="secret") -> str: + """Generates a dummy access token using the given subject.""" + logger.debug(f"Generating dummy access token for subject {sub}") + return (sub + key).encode("utf-8").hex() + + def build_response(content: bytes, status_code: int = 200, headers=None) -> Response: """Builds a requests.Response object with the given status code and content.""" response = Response() @@ -285,6 +291,19 @@ def handle_request(self, method, parsed_url, headers, timeout): audience = query_string["audience"][0] self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=audience) return build_response(self.token.encode("utf-8")) + elif ( + method == "GET" + and parsed_url.path + == "/computeMetadata/v1/instance/service-accounts/default/token" + and headers.get("Metadata-Flavor") == "Google" + ): + self.token = gen_dummy_access_token(sub=self.sub) + ret = { + "access_token": self.token, + "expires_in": 3599, + "token_type": "Bearer", + } + return build_response(json.dumps(ret).encode("utf-8")) else: # Reject malformed requests. raise HTTPError() @@ -348,6 +367,11 @@ class FakeAwsEnvironment: def __init__(self): # Defaults used for generating a token. Can be overriden in individual tests. self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-34afe100cad287fab" + # Path of roles that can be assumed. Empty if no impersonation is allowed. + # Can be overriden in individual tests. + self.assumption_path = [] + self.assume_role_call_count = 0 + self.caller_identity = {"Arn": self.arn} self.region = "us-east-1" self.credentials = Credentials(access_key="ak", secret_key="sk") @@ -356,6 +380,25 @@ def __init__(self): ) self.metadata_token = "test-token" + def assume_role(self, **kwargs): + if ( + self.assumption_path + and kwargs["RoleArn"] == self.assumption_path[self.assume_role_call_count] + ): + arn = self.assumption_path[self.assume_role_call_count] + self.assume_role_call_count += 1 + return { + "Credentials": { + "AccessKeyId": "access_key", + "SecretAccessKey": "secret_key", + "SessionToken": "session_token", + "Expiration": int(time()) + 60 * 60, + }, + "AssumedRoleUser": {"AssumedRoleId": hash(arn), "Arn": arn}, + "ResponseMetadata": {}, + } + return {} + def get_region(self): return self.region @@ -381,6 +424,7 @@ def fetcher_fetch_metadata_token(self): def boto3_client(self, *args, **kwargs): mock_client = mock.Mock() mock_client.get_caller_identity.return_value = self.caller_identity + mock_client.assume_role = self.assume_role return mock_client def __enter__(self): @@ -423,6 +467,9 @@ def __enter__(self): side_effect=self.boto3_client, ) ) + self.patchers.append( + mock.patch("boto3.session.Session.client", side_effect=self.boto3_client) + ) for patcher in self.patchers: patcher.__enter__() return self diff --git a/test/helpers.py b/test/helpers.py index 2ce88286a0..6c335c930e 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -3,6 +3,7 @@ import asyncio import base64 +import copy import functools import math import os @@ -14,6 +15,7 @@ import pytest +from snowflake.connector.auth._auth import Auth from snowflake.connector.compat import OK if TYPE_CHECKING: @@ -311,3 +313,36 @@ def _arrow_error_stream_random_input_test(use_table_iterator): # error instance users get should be the same assert len(exception_result) assert len(result_array) == 0 + + +def create_mock_auth_body(): + ocsp_mode = Mock() + ocsp_mode.name = "ocsp_mode" + session_manager = Mock() + session_manager.clone = lambda max_retries: "session_manager" + + return Auth.base_auth_data( + "user", + "account", + "application", + "internal_application_name", + "internal_application_version", + ocsp_mode, + login_timeout=60 * 60, + network_timeout=60 * 60, + socket_timeout=60 * 60, + platform_detection_timeout_seconds=0.2, + session_manager=session_manager, + ) + + +def apply_auth_class_update_body(auth_class, req_body_before): + req_body_after = copy.deepcopy(req_body_before) + auth_class.update_body(req_body_after) + return req_body_after + + +async def apply_auth_class_update_body_async(auth_class, req_body_before): + req_body_after = copy.deepcopy(req_body_before) + await auth_class.update_body(req_body_after) + return req_body_after diff --git a/test/unit/aio/csp_helpers_async.py b/test/unit/aio/csp_helpers_async.py index 2a6cf6d267..fab005be65 100644 --- a/test/unit/aio/csp_helpers_async.py +++ b/test/unit/aio/csp_helpers_async.py @@ -202,6 +202,8 @@ async def async_get_arn(): ) # Mock the async STS client for direct aioboto3 usage + fake_aws_self = self + class MockStsClient: async def __aenter__(self): return self @@ -212,6 +214,9 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): async def get_caller_identity(self): return await async_get_caller_identity() + async def assume_role(self, **kwargs): + return fake_aws_self.assume_role(**kwargs) + def mock_session_client(service_name): if service_name == "sts": return MockStsClient() diff --git a/test/unit/aio/test_auth_async.py b/test/unit/aio/test_auth_async.py index ca871d3cb5..e92f3be556 100644 --- a/test/unit/aio/test_auth_async.py +++ b/test/unit/aio/test_auth_async.py @@ -8,6 +8,7 @@ import asyncio import inspect import sys +from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body from test.unit.aio.mock_utils import mock_connection from unittest.mock import Mock, PropertyMock @@ -340,3 +341,21 @@ def test_mro(): assert AuthByDefault.mro().index(AuthByPluginAsync) < AuthByDefault.mro().index( AuthByPluginSync ) + + +async def test_auth_by_default_prepare_body_does_not_overwrite_client_environment_fields(): + password = "testpassword" + auth_class = AuthByDefault(password) + + req_body_before = create_mock_auth_body() + req_body_after = await apply_auth_class_update_body_async( + auth_class, req_body_before + ) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) diff --git a/test/unit/aio/test_auth_keypair_async.py b/test/unit/aio/test_auth_keypair_async.py index 746c149baf..e802a3d1cc 100644 --- a/test/unit/aio/test_auth_keypair_async.py +++ b/test/unit/aio/test_auth_keypair_async.py @@ -5,6 +5,7 @@ from __future__ import annotations +from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body from test.unit.aio.mock_utils import mock_connection from unittest.mock import Mock, PropertyMock, patch @@ -61,6 +62,24 @@ async def test_auth_keypair(authenticator): assert rest.master_token == "MASTER_TOKEN" +async def test_auth_prepare_body_does_not_overwrite_client_environment_fields(): + private_key_der, _ = generate_key_pair(2048) + auth_class = AuthByKeyPair(private_key=private_key_der) + + req_body_before = create_mock_auth_body() + req_body_after = await apply_auth_class_update_body_async( + auth_class, req_body_before + ) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + async def test_auth_keypair_abc(): """Simple Key Pair test using abstraction layer.""" private_key_der, public_key_der_encoded = generate_key_pair(2048) diff --git a/test/unit/aio/test_auth_oauth_async.py b/test/unit/aio/test_auth_oauth_async.py index e873ec3a67..cef33781b6 100644 --- a/test/unit/aio/test_auth_oauth_async.py +++ b/test/unit/aio/test_auth_oauth_async.py @@ -5,6 +5,8 @@ from __future__ import annotations +from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body + import pytest from snowflake.connector.aio.auth import AuthByOAuth @@ -20,6 +22,24 @@ async def test_auth_oauth(): assert body["data"]["AUTHENTICATOR"] == "OAUTH", body +async def test_auth_prepare_body_does_not_overwrite_client_environment_fields(): + token = "oAuthToken" + auth_class = AuthByOAuth(token) + + req_body_before = create_mock_auth_body() + req_body_after = await apply_auth_class_update_body_async( + auth_class, req_body_before + ) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + @pytest.mark.parametrize("authenticator", ["oauth", "OAUTH"]) async def test_oauth_authenticator_is_case_insensitive(monkeypatch, authenticator): """Test that oauth authenticator is case insensitive.""" diff --git a/test/unit/aio/test_auth_oauth_auth_code_async.py b/test/unit/aio/test_auth_oauth_auth_code_async.py index b13d8f9970..091e0aa097 100644 --- a/test/unit/aio/test_auth_oauth_auth_code_async.py +++ b/test/unit/aio/test_auth_oauth_auth_code_async.py @@ -4,6 +4,7 @@ # import unittest.mock as mock +from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body from unittest.mock import patch import pytest @@ -44,6 +45,34 @@ async def test_auth_oauth_auth_code_oauth_type(omit_oauth_urls_check): ) +async def test_auth_prepare_body_does_not_overwrite_client_environment_fields( + omit_oauth_urls_check, +): + auth_class = AuthByOauthCode( + "app", + "clientId", + "clientSecret", + "auth_url", + "tokenRequestUrl", + "redirectUri:{port}", + "scope", + "host", + ) + + req_body_before = create_mock_auth_body() + req_body_after = await apply_auth_class_update_body_async( + auth_class, req_body_before + ) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + @pytest.mark.parametrize("rtr_enabled", [True, False]) async def test_auth_oauth_auth_code_single_use_refresh_tokens( rtr_enabled: bool, omit_oauth_urls_check diff --git a/test/unit/aio/test_auth_oauth_credentials_async.py b/test/unit/aio/test_auth_oauth_credentials_async.py index 258cfa0c4f..90cfc0b858 100644 --- a/test/unit/aio/test_auth_oauth_credentials_async.py +++ b/test/unit/aio/test_auth_oauth_credentials_async.py @@ -5,6 +5,8 @@ from __future__ import annotations +from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body + import pytest from snowflake.connector.aio.auth import AuthByOauthCredentials @@ -27,6 +29,29 @@ async def test_auth_oauth_credentials_oauth_type(): ) +async def test_auth_prepare_body_does_not_overwrite_client_environment_fields(): + auth_class = AuthByOauthCredentials( + "app", + "clientId", + "clientSecret", + "https://example.com/oauth/token", + "scope", + ) + + req_body_before = create_mock_auth_body() + req_body_after = await apply_auth_class_update_body_async( + auth_class, req_body_before + ) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + @pytest.mark.parametrize( "authenticator", ["OAUTH_CLIENT_CREDENTIALS", "oauth_client_credentials"] ) diff --git a/test/unit/aio/test_auth_okta_async.py b/test/unit/aio/test_auth_okta_async.py index 855ee535b3..1a2a8d0298 100644 --- a/test/unit/aio/test_auth_okta_async.py +++ b/test/unit/aio/test_auth_okta_async.py @@ -6,6 +6,7 @@ from __future__ import annotations import logging +from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body from test.unit.aio.mock_utils import mock_connection from unittest.mock import MagicMock, Mock, PropertyMock, patch @@ -18,6 +19,24 @@ from snowflake.connector.description import CLIENT_NAME, CLIENT_VERSION +async def test_auth_prepare_body_does_not_overwrite_client_environment_fields(): + application = "testapplication" + auth_class = AuthByOkta(application) + + req_body_before = create_mock_auth_body() + req_body_after = await apply_auth_class_update_body_async( + auth_class, req_body_before + ) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + async def test_auth_okta(): """Authentication by OKTA positive test case.""" authenticator = "https://testsso.snowflake.net/" diff --git a/test/unit/aio/test_auth_pat_async.py b/test/unit/aio/test_auth_pat_async.py index 5086f3a96f..618d3e775f 100644 --- a/test/unit/aio/test_auth_pat_async.py +++ b/test/unit/aio/test_auth_pat_async.py @@ -5,6 +5,8 @@ from __future__ import annotations +from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body + import pytest from snowflake.connector.aio.auth import AuthByPAT @@ -27,6 +29,24 @@ async def test_auth_pat(): assert auth.assertion_content is None +async def test_pat_prepare_body_does_not_overwrite_client_environment_fields(): + token = "patToken" + auth_class = AuthByPAT(token) + + req_body_before = create_mock_auth_body() + req_body_after = await apply_auth_class_update_body_async( + auth_class, req_body_before + ) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + async def test_auth_pat_reauthenticate(): """Test PAT reauthenticate.""" token = "patToken" diff --git a/test/unit/aio/test_auth_webbrowser_async.py b/test/unit/aio/test_auth_webbrowser_async.py index 8f7b6b988a..8f555c6a9a 100644 --- a/test/unit/aio/test_auth_webbrowser_async.py +++ b/test/unit/aio/test_auth_webbrowser_async.py @@ -8,6 +8,7 @@ import asyncio import base64 import socket +from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body from test.unit.aio.mock_utils import mock_connection from unittest import mock from unittest.mock import MagicMock, Mock, PropertyMock, patch @@ -918,6 +919,22 @@ async def mock_webbrowser_auth_prepare( await conn.close() +async def test_auth_prepare_body_does_not_overwrite_client_environment_fields(): + auth_class = AuthByWebBrowser(application=APPLICATION) + req_body_before = create_mock_auth_body() + req_body_after = await apply_auth_class_update_body_async( + auth_class, req_body_before + ) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + def test_mro(): """Ensure that methods from AuthByPluginAsync override those from AuthByPlugin.""" from snowflake.connector.aio.auth import AuthByPlugin as AuthByPluginAsync diff --git a/test/unit/aio/test_auth_workload_identity_async.py b/test/unit/aio/test_auth_workload_identity_async.py index bb563d6591..54d9e4b466 100644 --- a/test/unit/aio/test_auth_workload_identity_async.py +++ b/test/unit/aio/test_auth_workload_identity_async.py @@ -8,17 +8,22 @@ import os from base64 import b64decode from unittest import mock +from unittest.mock import AsyncMock from urllib.parse import parse_qs, urlparse import aiohttp import jwt import pytest -from snowflake.connector.aio._wif_util import AttestationProvider +from snowflake.connector.aio._wif_util import ( + AttestationProvider, + WorkloadIdentityAttestation, +) from snowflake.connector.aio.auth import AuthByWorkloadIdentity from snowflake.connector.errors import ProgrammingError -from ...csp_helpers import gen_dummy_id_token +from ...csp_helpers import gen_dummy_access_token, gen_dummy_id_token +from ...helpers import apply_auth_class_update_body_async, create_mock_auth_body from .csp_helpers_async import FakeAwsEnvironmentAsync, FakeGceMetadataServiceAsync logger = logging.getLogger(__name__) @@ -137,6 +142,42 @@ async def mock_post(*args, **kwargs): await connection.close() +@pytest.mark.parametrize( + "provider,additional_args", + [ + (AttestationProvider.AWS, {}), + (AttestationProvider.GCP, {}), + (AttestationProvider.AZURE, {}), + ( + AttestationProvider.OIDC, + {"token": gen_dummy_id_token(sub="service-1", iss="issuer-1")}, + ), + ], +) +async def test_auth_prepare_body_does_not_overwrite_client_environment_fields( + provider, additional_args +): + auth_class = AuthByWorkloadIdentity(provider=provider, **additional_args) + auth_class.attestation = WorkloadIdentityAttestation( + provider=AttestationProvider.GCP, + credential=None, + user_identifier_components=None, + ) + + req_body_before = create_mock_auth_body() + req_body_after = await apply_auth_class_update_body_async( + auth_class, req_body_before + ) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + # -- OIDC Tests -- @@ -151,6 +192,7 @@ async def test_explicit_oidc_valid_inline_token_plumbed_to_api(): "AUTHENTICATOR": "WORKLOAD_IDENTITY", "PROVIDER": "OIDC", "TOKEN": dummy_token, + "CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0}, } @@ -208,6 +250,9 @@ async def test_explicit_aws_encodes_audience_host_signature_to_api( data = await extract_api_data(auth_class) assert data["AUTHENTICATOR"] == "WORKLOAD_IDENTITY" assert data["PROVIDER"] == "AWS" + assert ( + data["CLIENT_ENVIRONMENT"]["WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH"] == 0 + ) verify_aws_token(data["TOKEN"], fake_aws_environment.region) @@ -252,6 +297,22 @@ async def test_explicit_aws_generates_unique_assertion_content( ) +async def test_aws_impersonation_calls_correct_apis_for_each_role_in_impersonation_path( + fake_aws_environment: FakeAwsEnvironmentAsync, +): + impersonation_path = [ + "arn:aws:iam::123456789:role/role2", + "arn:aws:iam::123456789:role/role3", + ] + fake_aws_environment.assumption_path = impersonation_path + auth_class = AuthByWorkloadIdentity( + provider=AttestationProvider.AWS, impersonation_path=impersonation_path + ) + await auth_class.prepare(conn=None) + + assert fake_aws_environment.assume_role_call_count == 2 + + # -- GCP Tests -- @@ -279,7 +340,7 @@ async def test_explicit_gcp_metadata_server_error_bubbles_up(exception): with pytest.raises(ProgrammingError) as excinfo: await auth_class.prepare(conn=None) - assert "Error fetching GCP metadata:" in str(excinfo.value) + assert "Error fetching GCP identity token:" in str(excinfo.value) assert "Ensure the application is running on GCP." in str(excinfo.value) @@ -293,6 +354,7 @@ async def test_explicit_gcp_plumbs_token_to_api( "AUTHENTICATOR": "WORKLOAD_IDENTITY", "PROVIDER": "GCP", "TOKEN": fake_gce_metadata_service.token, + "CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0}, } @@ -307,6 +369,52 @@ async def test_explicit_gcp_generates_unique_assertion_content( assert auth_class.assertion_content == '{"_provider":"GCP","sub":"123456"}' +@mock.patch("snowflake.connector.aio._session_manager.SessionManager.post") +async def test_gcp_calls_correct_apis_and_populates_auth_data_for_final_sa( + mock_post_request, fake_gce_metadata_service: FakeGceMetadataServiceAsync +): + fake_gce_metadata_service.sub = "sa1" + impersonation_path = ["sa2", "sa3"] + sa1_access_token = gen_dummy_access_token("sa1") + sa3_id_token = gen_dummy_id_token("sa3") + + # Mock the POST request response + class AsyncResponse: + def __init__(self, content): + self._content = content + self.content = mock.Mock() + self.content.read = AsyncMock(return_value=content) + + mock_post_request.return_value = AsyncResponse( + json.dumps({"token": sa3_id_token}).encode("utf-8") + ) + + auth_class = AuthByWorkloadIdentity( + provider=AttestationProvider.GCP, impersonation_path=impersonation_path + ) + await auth_class.prepare(conn=None) + + mock_post_request.assert_called_once_with( + url="https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/sa3:generateIdToken", + headers={ + "Authorization": f"Bearer {sa1_access_token}", + "Content-Type": "application/json", + }, + json={ + "delegates": ["projects/-/serviceAccounts/sa2"], + "audience": "snowflakecomputing.com", + }, + ) + + assert auth_class.assertion_content == '{"_provider":"GCP","sub":"sa3"}' + assert await extract_api_data(auth_class) == { + "AUTHENTICATOR": "WORKLOAD_IDENTITY", + "PROVIDER": "GCP", + "TOKEN": sa3_id_token, + "CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 2}, + } + + # -- Azure Tests -- @@ -358,6 +466,7 @@ async def test_explicit_azure_plumbs_token_to_api(fake_azure_metadata_service): "AUTHENTICATOR": "WORKLOAD_IDENTITY", "PROVIDER": "AZURE", "TOKEN": fake_azure_metadata_service.token, + "CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0}, } diff --git a/test/unit/aio/test_connection_async_unit.py b/test/unit/aio/test_connection_async_unit.py index 590a85711b..45399d79d1 100644 --- a/test/unit/aio/test_connection_async_unit.py +++ b/test/unit/aio/test_connection_async_unit.py @@ -606,6 +606,7 @@ async def test_otel_error_message_async(caplog, mock_post_requests): "workload_identity_entra_resource", "api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b", ), + ("workload_identity_impersonation_path", ["subject-b", "subject-c"]), ], ) async def test_cannot_set_dependent_params_without_wlid_authenticator( @@ -655,6 +656,74 @@ async def test_workload_identity_provider_is_required_for_wif_authenticator( assert expected_error_msg in str(excinfo.value) +@pytest.mark.parametrize( + "provider_param", + [ + # Strongly-typed values. + AttestationProvider.AZURE, + AttestationProvider.OIDC, + # String values. + "AZURE", + "OIDC", + ], +) +async def test_workload_identity_impersonation_path_errors_for_unsupported_providers( + monkeypatch, provider_param +): + async def mock_authenticate(*_): + pass + + with monkeypatch.context() as m: + m.setattr( + "snowflake.connector.aio._connection.SnowflakeConnection._authenticate", + mock_authenticate, + ) + + with pytest.raises(ProgrammingError) as excinfo: + await snowflake.connector.aio.connect( + account="account", + authenticator="WORKLOAD_IDENTITY", + workload_identity_provider=provider_param, + workload_identity_impersonation_path=[ + "sa2@project.iam.gserviceaccount.com" + ], + ) + assert ( + "workload_identity_impersonation_path is currently only supported for GCP and AWS." + in str(excinfo.value) + ) + + +@pytest.mark.parametrize( + "provider_param,impersonation_path", + [ + (AttestationProvider.GCP, ["sa2@project.iam.gserviceaccount.com"]), + (AttestationProvider.AWS, ["arn:aws:iam::1234567890:role/role2"]), + ("GCP", ["sa2@project.iam.gserviceaccount.com"]), + ("AWS", ["arn:aws:iam::1234567890:role/role2"]), + ], +) +async def test_workload_identity_impersonation_path_populates_auth_class_for_supported_provider( + monkeypatch, provider_param, impersonation_path +): + async def mock_authenticate(*_): + pass + + with monkeypatch.context() as m: + m.setattr( + "snowflake.connector.aio._connection.SnowflakeConnection._authenticate", + mock_authenticate, + ) + + conn = await snowflake.connector.aio.connect( + account="account", + authenticator="WORKLOAD_IDENTITY", + workload_identity_provider=provider_param, + workload_identity_impersonation_path=impersonation_path, + ) + assert conn.auth_class.impersonation_path == impersonation_path + + @pytest.mark.parametrize( "provider_param, parsed_provider", [ diff --git a/test/unit/test_auth.py b/test/unit/test_auth.py index 595528601e..cfae32f8c3 100644 --- a/test/unit/test_auth.py +++ b/test/unit/test_auth.py @@ -4,6 +4,7 @@ import inspect import sys import time +from test.helpers import apply_auth_class_update_body, create_mock_auth_body from unittest.mock import Mock, PropertyMock import pytest @@ -337,3 +338,19 @@ def test_authbyplugin_abc_api(): 'password': , \ 'kwargs': })""" ) + + +def test_auth_by_default_prepare_body_does_not_overwrite_client_environment_fields(): + password = "testpassword" + auth_class = AuthByDefault(password) + + req_body_before = create_mock_auth_body() + req_body_after = apply_auth_class_update_body(auth_class, req_body_before) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) diff --git a/test/unit/test_auth_keypair.py b/test/unit/test_auth_keypair.py index c2c875aec1..80c27e9602 100644 --- a/test/unit/test_auth_keypair.py +++ b/test/unit/test_auth_keypair.py @@ -1,6 +1,7 @@ #!/usr/bin/env python from __future__ import annotations +from test.helpers import apply_auth_class_update_body, create_mock_auth_body from unittest.mock import Mock, PropertyMock, patch import pytest @@ -63,6 +64,22 @@ def test_auth_keypair(authenticator): assert rest.master_token == "MASTER_TOKEN" +def test_auth_prepare_body_does_not_overwrite_client_environment_fields(): + private_key_der, _ = generate_key_pair(2048) + auth_class = AuthByKeyPair(private_key=private_key_der) + + req_body_before = create_mock_auth_body() + req_body_after = apply_auth_class_update_body(auth_class, req_body_before) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + def test_auth_keypair_abc(): """Simple Key Pair test using abstraction layer.""" private_key_der, public_key_der_encoded = generate_key_pair(2048) diff --git a/test/unit/test_auth_oauth.py b/test/unit/test_auth_oauth.py index 87870bda8e..7e7a913f24 100644 --- a/test/unit/test_auth_oauth.py +++ b/test/unit/test_auth_oauth.py @@ -1,6 +1,8 @@ #!/usr/bin/env python from __future__ import annotations +from test.helpers import apply_auth_class_update_body, create_mock_auth_body + try: # pragma: no cover from snowflake.connector.auth import AuthByOAuth except ImportError: @@ -18,6 +20,22 @@ def test_auth_oauth(): assert body["data"]["AUTHENTICATOR"] == "OAUTH", body +def test_auth_prepare_body_does_not_overwrite_client_environment_fields(): + token = "oAuthToken" + auth_class = AuthByOAuth(token) + + req_body_before = create_mock_auth_body() + req_body_after = apply_auth_class_update_body(auth_class, req_body_before) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + @pytest.mark.parametrize("authenticator", ["oauth", "OAUTH"]) def test_oauth_authenticator_is_case_insensitive(monkeypatch, authenticator): """Test that oauth authenticator is case insensitive.""" diff --git a/test/unit/test_auth_oauth_auth_code.py b/test/unit/test_auth_oauth_auth_code.py index 8ede51facd..76894791cc 100644 --- a/test/unit/test_auth_oauth_auth_code.py +++ b/test/unit/test_auth_oauth_auth_code.py @@ -4,6 +4,7 @@ # import unittest.mock as mock +from test.helpers import apply_auth_class_update_body, create_mock_auth_body from unittest.mock import patch import pytest @@ -44,6 +45,32 @@ def test_auth_oauth_auth_code_oauth_type(omit_oauth_urls_check): ) +def test_auth_prepare_body_does_not_overwrite_client_environment_fields( + omit_oauth_urls_check, +): + auth_class = AuthByOauthCode( + "app", + "clientId", + "clientSecret", + "auth_url", + "tokenRequestUrl", + "redirectUri:{port}", + "scope", + "host", + ) + + req_body_before = create_mock_auth_body() + req_body_after = apply_auth_class_update_body(auth_class, req_body_before) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + @pytest.mark.parametrize("rtr_enabled", [True, False]) def test_auth_oauth_auth_code_single_use_refresh_tokens( rtr_enabled: bool, omit_oauth_urls_check diff --git a/test/unit/test_auth_oauth_credentials.py b/test/unit/test_auth_oauth_credentials.py index 7539cdbb97..75b3cbd1ed 100644 --- a/test/unit/test_auth_oauth_credentials.py +++ b/test/unit/test_auth_oauth_credentials.py @@ -4,6 +4,8 @@ # +from test.helpers import apply_auth_class_update_body, create_mock_auth_body + import pytest from snowflake.connector.auth import AuthByOauthCredentials @@ -26,6 +28,27 @@ def test_auth_oauth_credentials_oauth_type(): ) +def test_auth_prepare_body_does_not_overwrite_client_environment_fields(): + auth_class = AuthByOauthCredentials( + "app", + "clientId", + "clientSecret", + "https://example.com/oauth/token", + "scope", + ) + + req_body_before = create_mock_auth_body() + req_body_after = apply_auth_class_update_body(auth_class, req_body_before) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + @pytest.mark.parametrize( "authenticator", ["OAUTH_CLIENT_CREDENTIALS", "oauth_client_credentials"] ) diff --git a/test/unit/test_auth_okta.py b/test/unit/test_auth_okta.py index a623b5ae71..206f630969 100644 --- a/test/unit/test_auth_okta.py +++ b/test/unit/test_auth_okta.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging +from test.helpers import apply_auth_class_update_body, create_mock_auth_body from unittest.mock import Mock, PropertyMock, patch import pytest @@ -19,6 +20,22 @@ from snowflake.connector.auth_okta import AuthByOkta +def test_auth_prepare_body_does_not_overwrite_client_environment_fields(): + application = "testapplication" + auth_class = AuthByOkta(application) + + req_body_before = create_mock_auth_body() + req_body_after = apply_auth_class_update_body(auth_class, req_body_before) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + def test_auth_okta(): """Authentication by OKTA positive test case.""" authenticator = "https://testsso.snowflake.net/" diff --git a/test/unit/test_auth_pat.py b/test/unit/test_auth_pat.py index f4734cd040..ca7dccd1cc 100644 --- a/test/unit/test_auth_pat.py +++ b/test/unit/test_auth_pat.py @@ -4,6 +4,8 @@ # from __future__ import annotations +from test.helpers import apply_auth_class_update_body, create_mock_auth_body + import pytest from snowflake.connector.auth import AuthByPAT, AuthNoAuth @@ -26,6 +28,22 @@ def test_auth_pat(): assert auth.assertion_content is None +def test_pat_prepare_body_does_not_overwrite_client_environment_fields(): + token = "patToken" + auth_class = AuthByPAT(token) + + req_body_before = create_mock_auth_body() + req_body_after = apply_auth_class_update_body(auth_class, req_body_before) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + def test_auth_pat_reauthenticate(): """Test PAT reauthenticate.""" token = "patToken" diff --git a/test/unit/test_auth_webbrowser.py b/test/unit/test_auth_webbrowser.py index db97f58bb7..f649050734 100644 --- a/test/unit/test_auth_webbrowser.py +++ b/test/unit/test_auth_webbrowser.py @@ -3,6 +3,7 @@ import base64 import socket +from test.helpers import apply_auth_class_update_body, create_mock_auth_body from unittest import mock from unittest.mock import MagicMock, Mock, PropertyMock, patch @@ -792,3 +793,17 @@ def mock_webbrowser_auth_prepare( assert isinstance(conn.auth_class, AuthByWebBrowser) conn.close() + + +def test_auth_prepare_body_does_not_overwrite_client_environment_fields(): + auth_class = AuthByWebBrowser(application=APPLICATION) + req_body_before = create_mock_auth_body() + req_body_after = apply_auth_class_update_body(auth_class, req_body_before) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) diff --git a/test/unit/test_auth_workload_identity.py b/test/unit/test_auth_workload_identity.py index 1880d1b7d1..bfe00b6a96 100644 --- a/test/unit/test_auth_workload_identity.py +++ b/test/unit/test_auth_workload_identity.py @@ -2,6 +2,7 @@ import logging import os from base64 import b64decode +from test.helpers import apply_auth_class_update_body, create_mock_auth_body from unittest import mock from urllib.parse import parse_qs, urlparse @@ -15,9 +16,19 @@ HTTPError, Timeout, ) -from snowflake.connector.wif_util import AttestationProvider, get_aws_sts_hostname +from snowflake.connector.wif_util import ( + AttestationProvider, + WorkloadIdentityAttestation, + get_aws_sts_hostname, +) -from ..csp_helpers import FakeAwsEnvironment, FakeGceMetadataService, gen_dummy_id_token +from ..csp_helpers import ( + FakeAwsEnvironment, + FakeGceMetadataService, + build_response, + gen_dummy_access_token, + gen_dummy_id_token, +) logger = logging.getLogger(__name__) @@ -117,6 +128,40 @@ def test_wif_authenticator_is_case_insensitive( assert isinstance(connection.auth_class, AuthByWorkloadIdentity) +@pytest.mark.parametrize( + "provider,additional_args", + [ + (AttestationProvider.AWS, {}), + (AttestationProvider.GCP, {}), + (AttestationProvider.AZURE, {}), + ( + AttestationProvider.OIDC, + {"token": gen_dummy_id_token(sub="service-1", iss="issuer-1")}, + ), + ], +) +def test_auth_prepare_body_does_not_overwrite_client_environment_fields( + provider, additional_args +): + auth_class = AuthByWorkloadIdentity(provider=provider, **additional_args) + auth_class.attestation = WorkloadIdentityAttestation( + provider=AttestationProvider.GCP, + credential=None, + user_identifier_components=None, + ) + + req_body_before = create_mock_auth_body() + req_body_after = apply_auth_class_update_body(auth_class, req_body_before) + + assert all( + [ + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] + ] + ) + + # -- OIDC Tests -- @@ -131,6 +176,7 @@ def test_explicit_oidc_valid_inline_token_plumbed_to_api(): "AUTHENTICATOR": "WORKLOAD_IDENTITY", "PROVIDER": "OIDC", "TOKEN": dummy_token, + "CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0}, } @@ -186,6 +232,9 @@ def test_explicit_aws_encodes_audience_host_signature_to_api( data = extract_api_data(auth_class) assert data["AUTHENTICATOR"] == "WORKLOAD_IDENTITY" assert data["PROVIDER"] == "AWS" + assert ( + data["CLIENT_ENVIRONMENT"]["WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH"] == 0 + ) verify_aws_token(data["TOKEN"], fake_aws_environment.region) @@ -269,6 +318,22 @@ def test_get_aws_sts_hostname_invalid_inputs(region, partition): assert "Invalid AWS partition" in str(excinfo.value) +def test_aws_impersonation_calls_correct_apis_for_each_role_in_impersonation_path( + fake_aws_environment: FakeAwsEnvironment, +): + impersonation_path = [ + "arn:aws:iam::123456789:role/role2", + "arn:aws:iam::123456789:role/role3", + ] + fake_aws_environment.assumption_path = impersonation_path + auth_class = AuthByWorkloadIdentity( + provider=AttestationProvider.AWS, impersonation_path=impersonation_path + ) + auth_class.prepare(conn=None) + + assert fake_aws_environment.assume_role_call_count == 2 + + # -- GCP Tests -- @@ -289,7 +354,7 @@ def test_explicit_gcp_metadata_server_error_bubbles_up(exception): with pytest.raises(ProgrammingError) as excinfo: auth_class.prepare(conn=None) - assert "Error fetching GCP metadata:" in str(excinfo.value) + assert "Error fetching GCP identity token:" in str(excinfo.value) assert "Ensure the application is running on GCP." in str(excinfo.value) @@ -303,6 +368,7 @@ def test_explicit_gcp_plumbs_token_to_api( "AUTHENTICATOR": "WORKLOAD_IDENTITY", "PROVIDER": "GCP", "TOKEN": fake_gce_metadata_service.token, + "CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0}, } @@ -317,6 +383,45 @@ def test_explicit_gcp_generates_unique_assertion_content( assert auth_class.assertion_content == '{"_provider":"GCP","sub":"123456"}' +@mock.patch("snowflake.connector.session_manager.SessionManager.post") +def test_gcp_calls_correct_apis_and_populates_auth_data_for_final_sa( + mock_post_request, fake_gce_metadata_service: FakeGceMetadataService +): + fake_gce_metadata_service.sub = "sa1" + impersonation_path = ["sa2", "sa3"] + sa1_access_token = gen_dummy_access_token("sa1") + sa3_id_token = gen_dummy_id_token("sa3") + + mock_post_request.return_value = build_response( + json.dumps({"token": sa3_id_token}).encode("utf-8") + ) + + auth_class = AuthByWorkloadIdentity( + provider=AttestationProvider.GCP, impersonation_path=impersonation_path + ) + auth_class.prepare(conn=None) + + mock_post_request.assert_called_once_with( + url="https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/sa3:generateIdToken", + headers={ + "Authorization": f"Bearer {sa1_access_token}", + "Content-Type": "application/json", + }, + json={ + "delegates": ["projects/-/serviceAccounts/sa2"], + "audience": "snowflakecomputing.com", + }, + ) + + assert auth_class.assertion_content == '{"_provider":"GCP","sub":"sa3"}' + assert extract_api_data(auth_class) == { + "AUTHENTICATOR": "WORKLOAD_IDENTITY", + "PROVIDER": "GCP", + "TOKEN": sa3_id_token, + "CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 2}, + } + + # -- Azure Tests -- @@ -366,6 +471,7 @@ def test_explicit_azure_plumbs_token_to_api(fake_azure_metadata_service): "AUTHENTICATOR": "WORKLOAD_IDENTITY", "PROVIDER": "AZURE", "TOKEN": fake_azure_metadata_service.token, + "CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0}, } diff --git a/test/unit/test_connection.py b/test/unit/test_connection.py index 3ef2fd6e36..bf9049cf30 100644 --- a/test/unit/test_connection.py +++ b/test/unit/test_connection.py @@ -632,6 +632,7 @@ def test_otel_error_message(caplog, mock_post_requests): "workload_identity_entra_resource", "api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b", ), + ("workload_identity_impersonation_path", ["subject-b", "subject-c"]), ], ) def test_cannot_set_dependent_params_without_wlid_authenticator( @@ -680,6 +681,66 @@ def test_workload_identity_provider_is_required_for_wif_authenticator( assert expected_error_msg in str(excinfo.value) +@pytest.mark.parametrize( + "provider_param", + [ + # Strongly-typed values. + AttestationProvider.AZURE, + AttestationProvider.OIDC, + # String values. + "AZURE", + "OIDC", + ], +) +def test_workload_identity_impersonation_path_errors_for_unsupported_providers( + monkeypatch, provider_param +): + with monkeypatch.context() as m: + m.setattr( + "snowflake.connector.SnowflakeConnection._authenticate", lambda *_: None + ) + + with pytest.raises(ProgrammingError) as excinfo: + snowflake.connector.connect( + account="account", + authenticator="WORKLOAD_IDENTITY", + workload_identity_provider=provider_param, + workload_identity_impersonation_path=[ + "sa2@project.iam.gserviceaccount.com" + ], + ) + assert ( + "workload_identity_impersonation_path is currently only supported for GCP and AWS." + in str(excinfo.value) + ) + + +@pytest.mark.parametrize( + "provider_param,impersonation_path", + [ + (AttestationProvider.GCP, ["sa2@project.iam.gserviceaccount.com"]), + (AttestationProvider.AWS, ["arn:aws:iam::1234567890:role/role2"]), + ("GCP", ["sa2@project.iam.gserviceaccount.com"]), + ("AWS", ["arn:aws:iam::1234567890:role/role2"]), + ], +) +def test_workload_identity_impersonation_path_populates_auth_class_for_supported_provider( + monkeypatch, provider_param, impersonation_path +): + with monkeypatch.context() as m: + m.setattr( + "snowflake.connector.SnowflakeConnection._authenticate", lambda *_: None + ) + + conn = snowflake.connector.connect( + account="account", + authenticator="WORKLOAD_IDENTITY", + workload_identity_provider=provider_param, + workload_identity_impersonation_path=impersonation_path, + ) + assert conn.auth_class.impersonation_path == impersonation_path + + @pytest.mark.parametrize( "provider_param, parsed_provider", [ diff --git a/test/wif/test_wif.py b/test/wif/test_wif.py index c544578d8c..4b57aa0d76 100644 --- a/test/wif/test_wif.py +++ b/test/wif/test_wif.py @@ -22,6 +22,9 @@ ACCOUNT = os.getenv("SNOWFLAKE_TEST_WIF_ACCOUNT") HOST = os.getenv("SNOWFLAKE_TEST_WIF_HOST") PROVIDER = os.getenv("SNOWFLAKE_TEST_WIF_PROVIDER") +EXPECTED_USERNAME = os.getenv("SNOWFLAKE_TEST_WIF_USERNAME") +IMPERSONATION_PATH = os.getenv("SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH") +EXPECTED_USERNAME_IMPERSONATION = os.getenv("SNOWFLAKE_TEST_WIF_USERNAME_IMPERSONATION") @pytest.mark.wif @@ -33,8 +36,8 @@ def test_wif_defined_provider(): "workload_identity_provider": PROVIDER, } assert connect_and_execute_simple_query( - connection_params - ), "Failed to connect with using WIF - automatic provider detection" + connection_params, EXPECTED_USERNAME + ), f"Failed to connect with using WIF using provider {PROVIDER}" @pytest.mark.wif @@ -51,21 +54,47 @@ def test_should_authenticate_using_oidc(): } assert connect_and_execute_simple_query( - connection_params + connection_params, expected_user=None ), "Failed to connect using WIF with OIDC provider" +@pytest.mark.wif +def test_should_authenticate_with_impersonation(): + if not isinstance(IMPERSONATION_PATH, str) or not IMPERSONATION_PATH: + pytest.skip("Skipping test - IMPERSONATION_PATH is not set") + + logger.debug(f"Using impersonation path: {IMPERSONATION_PATH}") + impersonation_path_list = IMPERSONATION_PATH.split(",") + + connection_params = { + "host": HOST, + "account": ACCOUNT, + "authenticator": "WORKLOAD_IDENTITY", + "workload_identity_provider": PROVIDER, + "workload_identity_impersonation_path": impersonation_path_list, + } + + assert connect_and_execute_simple_query( + connection_params, EXPECTED_USERNAME_IMPERSONATION + ), f"Failed to connect using WIF with provider {PROVIDER}" + + def is_provider_gcp() -> bool: return PROVIDER == "GCP" -def connect_and_execute_simple_query(connection_params) -> bool: +def connect_and_execute_simple_query(connection_params, expected_user=None) -> bool: try: logger.info("Trying to connect to Snowflake") with snowflake.connector.connect(**connection_params) as con: - result = con.cursor().execute("select 1;") - logger.debug(result.fetchall()) - logger.info("Successfully connected to Snowflake") + result = con.cursor().execute("select current_user();") + (user,) = result.fetchone() + logger.debug(user) + if expected_user: + assert ( + expected_user == user + ), f"Expected user '{expected_user}', got user '{user}'" + logger.info(f"Successfully connected to Snowflake as {user}") return True except Exception as e: logger.error(e) diff --git a/test/wif/test_wif_async.py b/test/wif/test_wif_async.py index 9db0301cc3..d4675c1d52 100644 --- a/test/wif/test_wif_async.py +++ b/test/wif/test_wif_async.py @@ -22,6 +22,9 @@ ACCOUNT = os.getenv("SNOWFLAKE_TEST_WIF_ACCOUNT") HOST = os.getenv("SNOWFLAKE_TEST_WIF_HOST") PROVIDER = os.getenv("SNOWFLAKE_TEST_WIF_PROVIDER") +EXPECTED_USERNAME = os.getenv("SNOWFLAKE_TEST_WIF_USERNAME") +IMPERSONATION_PATH = os.getenv("SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH") +EXPECTED_USERNAME_IMPERSONATION = os.getenv("SNOWFLAKE_TEST_WIF_USERNAME_IMPERSONATION") @pytest.mark.wif @@ -34,8 +37,8 @@ async def test_wif_defined_provider_async(): "workload_identity_provider": PROVIDER, } assert await connect_and_execute_simple_query_async( - connection_params - ), "Failed to connect with using WIF - automatic provider detection" + connection_params, EXPECTED_USERNAME + ), f"Failed to connect with using WIF using provider {PROVIDER}" @pytest.mark.wif @@ -53,17 +56,48 @@ async def test_should_authenticate_using_oidc_async(): } assert await connect_and_execute_simple_query_async( - connection_params + connection_params, expected_user=None ), "Failed to connect using WIF with OIDC provider" -async def connect_and_execute_simple_query_async(connection_params) -> bool: +@pytest.mark.wif +@pytest.mark.aio +async def test_should_authenticate_with_impersonation_async(): + if not isinstance(IMPERSONATION_PATH, str) or not IMPERSONATION_PATH: + pytest.skip("Skipping test - IMPERSONATION_PATH is not set") + + logger.debug(f"Using impersonation path: {IMPERSONATION_PATH}") + impersonation_path_list = IMPERSONATION_PATH.split(",") + + connection_params = { + "host": HOST, + "account": ACCOUNT, + "authenticator": "WORKLOAD_IDENTITY", + "workload_identity_provider": PROVIDER, + "workload_identity_impersonation_path": impersonation_path_list, + } + + assert await connect_and_execute_simple_query_async( + connection_params, EXPECTED_USERNAME_IMPERSONATION + ), f"Failed to connect using WIF with provider {PROVIDER}" + + +async def connect_and_execute_simple_query_async( + connection_params, expected_user=None +) -> bool: try: logger.info("Trying to connect to Snowflake") - async with snowflake.connector.aio.connect(**connection_params) as con: - result = await con.cursor().execute("select 1;") - logger.debug(await result.fetchall()) - logger.info("Successfully connected to Snowflake") + async with snowflake.connector.aio.SnowflakeConnection( + **connection_params + ) as con: + result = await con.cursor().execute("select current_user();") + (user,) = await result.fetchone() + logger.debug(user) + if expected_user: + assert ( + expected_user == user + ), f"Expected user '{expected_user}', got user '{user}'" + logger.info(f"Successfully connected to Snowflake as {user}") return True except Exception as e: logger.error(e)