Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions ci/test_wif.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@ run_tests_and_set_result() {
local host="$2"
local snowflake_host="$3"
local rsa_key_path="$4"
local snowflake_user="$5"
local impersonation_path="$6"
local snowflake_user_for_impersonation="$7"

ssh -i "$rsa_key_path" -o IdentitiesOnly=yes -p 443 "$host" env BRANCH="$BRANCH" SNOWFLAKE_TEST_WIF_HOST="$snowflake_host" SNOWFLAKE_TEST_WIF_PROVIDER="$provider" SNOWFLAKE_TEST_WIF_ACCOUNT="$SNOWFLAKE_TEST_WIF_ACCOUNT" bash << EOF
ssh -i "$rsa_key_path" -o IdentitiesOnly=yes -p 443 "$host" env BRANCH="$BRANCH" SNOWFLAKE_TEST_WIF_HOST="$snowflake_host" SNOWFLAKE_TEST_WIF_PROVIDER="$provider" SNOWFLAKE_TEST_WIF_ACCOUNT="$SNOWFLAKE_TEST_WIF_ACCOUNT" SNOWFLAKE_TEST_WIF_USERNAME="$snowflake_user" SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH="$impersonation_path" SNOWFLAKE_TEST_WIF_USERNAME_IMPERSONATION="$snowflake_user_for_impersonation" bash << EOF
set -e
set -o pipefail
docker run \
Expand All @@ -24,6 +27,9 @@ run_tests_and_set_result() {
-e SNOWFLAKE_TEST_WIF_PROVIDER \
-e SNOWFLAKE_TEST_WIF_HOST \
-e SNOWFLAKE_TEST_WIF_ACCOUNT \
-e SNOWFLAKE_TEST_WIF_USERNAME \
-e SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH \
-e SNOWFLAKE_TEST_WIF_USERNAME_IMPERSONATION \
snowflakedb/client-python-test:1 \
bash -c "
echo 'Running tests on branch: \$BRANCH'
Expand Down Expand Up @@ -77,9 +83,9 @@ setup_parameters
# Run tests for all cloud providers
EXIT_STATUS=0
set +e # Don't exit on first failure
run_tests_and_set_result "AZURE" "$HOST_AZURE" "$SNOWFLAKE_TEST_WIF_HOST_AZURE" "$RSA_KEY_PATH_AWS_AZURE"
run_tests_and_set_result "AWS" "$HOST_AWS" "$SNOWFLAKE_TEST_WIF_HOST_AWS" "$RSA_KEY_PATH_AWS_AZURE"
run_tests_and_set_result "GCP" "$HOST_GCP" "$SNOWFLAKE_TEST_WIF_HOST_GCP" "$RSA_KEY_PATH_GCP"
run_tests_and_set_result "AZURE" "$HOST_AZURE" "$SNOWFLAKE_TEST_WIF_HOST_AZURE" "$RSA_KEY_PATH_AWS_AZURE" "$SNOWFLAKE_TEST_WIF_USERNAME_AZURE" "$SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH_AZURE" "$SNOWFLAKE_TEST_WIF_USERNAME_AZURE_IMPERSONATION"
run_tests_and_set_result "AWS" "$HOST_AWS" "$SNOWFLAKE_TEST_WIF_HOST_AWS" "$RSA_KEY_PATH_AWS_AZURE" "$SNOWFLAKE_TEST_WIF_USERNAME_AWS" "$SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH_AWS" "$SNOWFLAKE_TEST_WIF_USERNAME_AWS_IMPERSONATION"
run_tests_and_set_result "GCP" "$HOST_GCP" "$SNOWFLAKE_TEST_WIF_HOST_GCP" "$RSA_KEY_PATH_GCP" "$SNOWFLAKE_TEST_WIF_USERNAME_GCP" "$SNOWFLAKE_TEST_WIF_IMPERSONATION_PATH_GCP" "$SNOWFLAKE_TEST_WIF_USERNAME_GCP_IMPERSONATION"
set -e # Re-enable exit on error
echo "Exit status: $EXIT_STATUS"
exit $EXIT_STATUS
Binary file modified ci/wif/parameters/parameters_wif.json.gpg
Binary file not shown.
Binary file modified ci/wif/parameters/rsa_wif_aws_azure.gpg
Binary file not shown.
Binary file modified ci/wif/parameters/rsa_wif_gcp.gpg
Binary file not shown.
14 changes: 11 additions & 3 deletions ci/wif/test_wif.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ set -o pipefail
export SF_OCSP_TEST_MODE=true
export RUN_WIF_TESTS=true

/opt/python/cp39-cp39/bin/python -m pip install --break-system-packages -e '.[aio]'
/opt/python/cp39-cp39/bin/python -m pip install --break-system-packages pytest
/opt/python/cp39-cp39/bin/python -m pytest test/wif/*
# setup pytest
/opt/python/cp312-cp312/bin/python -m pip install --break-system-packages pytest pytest-asyncio

# test WIF without asyncio installed
/opt/python/cp312-cp312/bin/python -m pip install --break-system-packages -e .
/opt/python/cp312-cp312/bin/python -m pytest test/wif/ --ignore test/wif/test_wif_async.py

# test WIF with asyncio installed
# /opt/python/cp312-cp312/bin/python -m pip install --break-system-packages -e '.[aio]'
# run all tests to see whether installation does not break anything
# /opt/python/cp312-cp312/bin/python -m pytest test/wif/
18 changes: 18 additions & 0 deletions src/snowflake/connector/aio/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,10 +414,28 @@ async def __open_connection(self):
"errno": ER_INVALID_WIF_SETTINGS,
},
)
if (
self._workload_identity_impersonation_path
and self._workload_identity_provider
not in (
AttestationProvider.GCP,
AttestationProvider.AWS,
)
):
Error.errorhandler_wrapper(
self,
None,
ProgrammingError,
{
"msg": "workload_identity_impersonation_path is currently only supported for GCP and AWS.",
"errno": ER_INVALID_WIF_SETTINGS,
},
)
self.auth_class = AuthByWorkloadIdentity(
provider=self._workload_identity_provider,
token=self._token,
entra_resource=self._workload_identity_entra_resource,
impersonation_path=self._workload_identity_impersonation_path,
)
else:
# okta URL, e.g., https://<account>.okta.com/
Expand Down
6 changes: 5 additions & 1 deletion src/snowflake/connector/aio/_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,10 @@ class SessionManager(_RequestVerbsUsingSessionMixin, SessionManagerSync):
"""

def __init__(
self, config: AioHttpConfig | None = None, **http_config_kwargs
self,
config: AioHttpConfig | None = None,
max_retries: int | None = None, # TODO: remove this after rebase
**http_config_kwargs,
) -> None:
"""Create a new async SessionManager."""
if config is None:
Expand Down Expand Up @@ -435,6 +438,7 @@ def clone(
*,
use_pooling: bool | None = None,
connector_factory: ConnectorFactory | None = None,
**kwargs, # TODO: remove this after rebase
) -> SessionManager:
"""Return a new async SessionManager sharing this instance's config."""
overrides: dict[str, Any] = {}
Expand Down
132 changes: 120 additions & 12 deletions src/snowflake/connector/aio/_wif_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@

logger = logging.getLogger(__name__)

GCP_METADATA_SERVICE_ACCOUNT_BASE_URL = (
"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default"
)


async def get_aws_region() -> str:
"""Get the current AWS workload's region."""
Expand All @@ -41,12 +45,37 @@ async def get_aws_region() -> str:
return region


async def create_aws_attestation() -> WorkloadIdentityAttestation:
async def get_aws_session(impersonation_path: list[str] | None = None):
"""Creates an aioboto3 session with the appropriate credentials.

If impersonation_path is provided, this uses the role at the end of the path. Otherwise, this uses the role attached to the current workload.
"""
session = aioboto3.Session()

impersonation_path = impersonation_path or []
for arn in impersonation_path:
async with session.client("sts") as sts_client:
response = await sts_client.assume_role(
RoleArn=arn, RoleSessionName="identity-federation-session"
)
creds = response["Credentials"]
session = aioboto3.Session(
aws_access_key_id=creds["AccessKeyId"],
aws_secret_access_key=creds["SecretAccessKey"],
aws_session_token=creds["SessionToken"],
)
return session


async def create_aws_attestation(
impersonation_path: list[str] | None = None,
) -> WorkloadIdentityAttestation:
"""Tries to create a workload identity attestation for AWS.

If the application isn't running on AWS or no credentials were found, raises an error.
"""
session = aioboto3.Session()
session = await get_aws_session(impersonation_path)

aws_creds = await session.get_credentials()
if not aws_creds:
raise ProgrammingError(
Expand Down Expand Up @@ -81,30 +110,108 @@ async def create_aws_attestation() -> WorkloadIdentityAttestation:
)


async def create_gcp_attestation(
session_manager: SessionManager | None = None,
) -> WorkloadIdentityAttestation:
"""Tries to create a workload identity attestation for GCP.
async def get_gcp_access_token(session_manager: SessionManager) -> str:
"""Gets a GCP access token from the metadata server.

If the application isn't running on GCP or no credentials were found, raises an error.
"""
try:
res = await session_manager.request(
method="GET",
url=f"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/identity?audience={SNOWFLAKE_AUDIENCE}",
url=f"{GCP_METADATA_SERVICE_ACCOUNT_BASE_URL}/token",
headers={
"Metadata-Flavor": "Google",
},
)

content = await res.content.read()
jwt_str = content.decode("utf-8")
response_text = content.decode("utf-8")
return json.loads(response_text)["access_token"]
except Exception as e:
raise ProgrammingError(
msg=f"Error fetching GCP metadata: {e}. Ensure the application is running on GCP.",
msg=f"Error fetching GCP access token: {e}. Ensure the application is running on GCP.",
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
)


async def get_gcp_identity_token_via_impersonation(
impersonation_path: list[str], session_manager: SessionManager
) -> str:
"""Gets a GCP identity token from the metadata server.

If the application isn't running on GCP or no credentials were found, raises an error.
"""
if not impersonation_path:
raise ProgrammingError(
msg="Error: impersonation_path cannot be empty.",
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
)

current_sa_token = await get_gcp_access_token(session_manager)
impersonation_path = [
f"projects/-/serviceAccounts/{client_id}" for client_id in impersonation_path
]
try:
res = await session_manager.post(
url=f"https://iamcredentials.googleapis.com/v1/{impersonation_path[-1]}:generateIdToken",
headers={
"Authorization": f"Bearer {current_sa_token}",
"Content-Type": "application/json",
},
json={
"delegates": impersonation_path[:-1],
"audience": SNOWFLAKE_AUDIENCE,
},
)

content = await res.content.read()
response_text = content.decode("utf-8")
return json.loads(response_text)["token"]
except Exception as e:
raise ProgrammingError(
msg=f"Error fetching GCP identity token for impersonated GCP service account '{impersonation_path[-1]}': {e}. Ensure the application is running on GCP.",
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
)


async def get_gcp_identity_token(session_manager: SessionManager) -> str:
"""Gets a GCP identity token from the metadata server.

If the application isn't running on GCP or no credentials were found, raises an error.
"""
try:
res = await session_manager.request(
method="GET",
url=f"{GCP_METADATA_SERVICE_ACCOUNT_BASE_URL}/identity?audience={SNOWFLAKE_AUDIENCE}",
headers={
"Metadata-Flavor": "Google",
},
)

content = await res.content.read()
return content.decode("utf-8")
except Exception as e:
raise ProgrammingError(
msg=f"Error fetching GCP identity token: {e}. Ensure the application is running on GCP.",
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
)


async def create_gcp_attestation(
session_manager: SessionManager,
impersonation_path: list[str] | None = None,
) -> WorkloadIdentityAttestation:
"""Tries to create a workload identity attestation for GCP.

If the application isn't running on GCP or no credentials were found, raises an error.
"""
if impersonation_path:
jwt_str = await get_gcp_identity_token_via_impersonation(
impersonation_path, session_manager
)
else:
jwt_str = await get_gcp_identity_token(session_manager)

_, subject = extract_iss_and_sub_without_signature_verification(jwt_str)
return WorkloadIdentityAttestation(
AttestationProvider.GCP, jwt_str, {"sub": subject}
Expand Down Expand Up @@ -179,6 +286,7 @@ async def create_attestation(
provider: AttestationProvider | None,
entra_resource: str | None = None,
token: str | None = None,
impersonation_path: list[str] | None = None,
session_manager: SessionManager | None = None,
) -> WorkloadIdentityAttestation:
"""Entry point to create an attestation using the given provider.
Expand All @@ -189,15 +297,15 @@ async def create_attestation(
session_manager = (
session_manager.clone()
if session_manager
else SessionManagerFactory.get_manager(use_pooling=True)
else SessionManager(use_pooling=True, max_retries=0)
)

if provider == AttestationProvider.AWS:
return await create_aws_attestation()
return await create_aws_attestation(impersonation_path)
elif provider == AttestationProvider.AZURE:
return await create_azure_attestation(entra_resource, session_manager)
elif provider == AttestationProvider.GCP:
return await create_gcp_attestation(session_manager)
return await create_gcp_attestation(session_manager, impersonation_path)
elif provider == AttestationProvider.OIDC:
return create_oidc_attestation(token)
else:
Expand Down
7 changes: 6 additions & 1 deletion src/snowflake/connector/aio/auth/_workload_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
provider: AttestationProvider,
token: str | None = None,
entra_resource: str | None = None,
impersonation_path: list[str] | None = None,
**kwargs,
) -> None:
"""Initializes an instance with workload identity authentication."""
Expand All @@ -30,6 +31,7 @@ def __init__(
provider=provider,
token=token,
entra_resource=entra_resource,
impersonation_path=impersonation_path,
**kwargs,
)

Expand All @@ -44,7 +46,10 @@ async def prepare(
self.provider,
self.entra_resource,
self.token,
session_manager=conn._session_manager.clone() if conn else None,
self.impersonation_path,
session_manager=(
conn._session_manager.clone(max_retries=0) if conn else None
),
)

async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]:
Expand Down
10 changes: 9 additions & 1 deletion src/snowflake/connector/auth/workload_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,14 @@ def __init__(
provider: AttestationProvider,
token: str | None = None,
entra_resource: str | None = None,
impersonation_path: list[str] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.provider = provider
self.token = token
self.entra_resource = entra_resource
self.impersonation_path = impersonation_path

self.attestation: WorkloadIdentityAttestation | None = None

Expand All @@ -76,6 +78,9 @@ def update_body(self, body: dict[typing.Any, typing.Any]) -> None:
self.attestation
).value
body["data"]["TOKEN"] = self.attestation.credential
body["data"].setdefault("CLIENT_ENVIRONMENT", {})[
"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH"
] = len(self.impersonation_path or [])

def prepare(
self, *, conn: SnowflakeConnection | None, **kwargs: typing.Any
Expand All @@ -85,7 +90,10 @@ def prepare(
self.provider,
self.entra_resource,
self.token,
session_manager=conn._session_manager.clone() if conn else None,
self.impersonation_path,
session_manager=(
conn._session_manager.clone(max_retries=0) if conn else None
),
)

def reauthenticate(self, **kwargs: typing.Any) -> dict[str, bool]:
Expand Down
Loading
Loading