Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/snowflake/connector/auth/workload_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def update_body(self, body: dict[typing.Any, typing.Any]) -> None:
self.attestation
).value
body["data"]["TOKEN"] = self.attestation.credential
body["data"].setdefault("CLIENT_ENVIRONMENT", {})[
"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH"
] = len(self.impersonation_path or [])

def prepare(
self, *, conn: SnowflakeConnection | None, **kwargs: typing.Any
Expand Down
68 changes: 66 additions & 2 deletions test/unit/test_auth_workload_identity.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import json
import logging
from base64 import b64decode
Expand All @@ -7,14 +8,18 @@
import jwt
import pytest

from snowflake.connector.auth import AuthByWorkloadIdentity
from snowflake.connector.auth import Auth, AuthByWorkloadIdentity
from snowflake.connector.errors import ProgrammingError
from snowflake.connector.vendored.requests.exceptions import (
ConnectTimeout,
HTTPError,
Timeout,
)
from snowflake.connector.wif_util import AttestationProvider, get_aws_sts_hostname
from snowflake.connector.wif_util import (
AttestationProvider,
WorkloadIdentityAttestation,
get_aws_sts_hostname,
)

from ..csp_helpers import (
FakeAwsEnvironment,
Expand Down Expand Up @@ -122,6 +127,58 @@ def test_wif_authenticator_is_case_insensitive(
assert isinstance(connection.auth_class, AuthByWorkloadIdentity)


@pytest.mark.parametrize(
"provider,additional_args",
[
(AttestationProvider.AWS, {}),
(AttestationProvider.GCP, {}),
(AttestationProvider.AZURE, {}),
(
AttestationProvider.OIDC,
{"token": gen_dummy_id_token(sub="service-1", iss="issuer-1")},
),
],
)
def test_wif_prepare_body_does_not_overwrite_fields(provider, additional_args):
ocsp_mode = mock.Mock()
ocsp_mode.name = "ocsp_mode"
session_manager = mock.Mock()
session_manager.clone = lambda max_retries: "session_manager"

req_body_before = Auth.base_auth_data(
"user",
"account",
"application",
"internal_application_name",
"internal_application_version",
ocsp_mode,
login_timeout=60 * 60,
network_timeout=60 * 60,
socket_timeout=60 * 60,
platform_detection_timeout_seconds=0.2,
session_manager=session_manager,
)
req_body_after = copy.deepcopy(req_body_before)
auth_class = AuthByWorkloadIdentity(provider=provider, **additional_args)
auth_class.attestation = WorkloadIdentityAttestation(
provider=provider, credential=None, user_identifier_components=None
)
auth_class.update_body(req_body_after)

# Check that the values in the body before are a strict subset of the values in the body after.
# Must use all() for this comparison because lists are not hashable
assert all(
[
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
]
)
req_body_before["data"].pop("CLIENT_ENVIRONMENT")
req_body_after["data"].pop("CLIENT_ENVIRONMENT")
assert set(req_body_before["data"].items()) <= set(req_body_after["data"].items())


# -- OIDC Tests --


Expand All @@ -136,6 +193,7 @@ def test_explicit_oidc_valid_inline_token_plumbed_to_api():
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
"PROVIDER": "OIDC",
"TOKEN": dummy_token,
"CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0},
}


Expand Down Expand Up @@ -191,6 +249,9 @@ def test_explicit_aws_encodes_audience_host_signature_to_api(
data = extract_api_data(auth_class)
assert data["AUTHENTICATOR"] == "WORKLOAD_IDENTITY"
assert data["PROVIDER"] == "AWS"
assert (
data["CLIENT_ENVIRONMENT"]["WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH"] == 0
)
verify_aws_token(data["TOKEN"], fake_aws_environment.region)


Expand Down Expand Up @@ -324,6 +385,7 @@ def test_explicit_gcp_plumbs_token_to_api(
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
"PROVIDER": "GCP",
"TOKEN": fake_gce_metadata_service.token,
"CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0},
}


Expand Down Expand Up @@ -373,6 +435,7 @@ def test_gcp_calls_correct_apis_and_populates_auth_data_for_final_sa(
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
"PROVIDER": "GCP",
"TOKEN": sa3_id_token,
"CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 2},
}


Expand Down Expand Up @@ -425,6 +488,7 @@ def test_explicit_azure_plumbs_token_to_api(fake_azure_metadata_service):
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
"PROVIDER": "AZURE",
"TOKEN": fake_azure_metadata_service.token,
"CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0},
}


Expand Down
Loading