Skip to content

Commit beb34fd

Browse files
[async] WIF impersonation for GCP #2496
1 parent 2713c74 commit beb34fd

File tree

6 files changed

+230
-11
lines changed

6 files changed

+230
-11
lines changed

src/snowflake/connector/aio/_connection.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,10 +418,24 @@ async def __open_connection(self):
418418
"errno": ER_INVALID_WIF_SETTINGS,
419419
},
420420
)
421+
if (
422+
self._workload_identity_impersonation_path
423+
and self._workload_identity_provider != AttestationProvider.GCP
424+
):
425+
Error.errorhandler_wrapper(
426+
self,
427+
None,
428+
ProgrammingError,
429+
{
430+
"msg": "workload_identity_impersonation_path is currently only supported for GCP.",
431+
"errno": ER_INVALID_WIF_SETTINGS,
432+
},
433+
)
421434
self.auth_class = AuthByWorkloadIdentity(
422435
provider=self._workload_identity_provider,
423436
token=self._token,
424437
entra_resource=self._workload_identity_entra_resource,
438+
impersonation_path=self._workload_identity_impersonation_path,
425439
)
426440
else:
427441
# okta URL, e.g., https://<account>.okta.com/

src/snowflake/connector/aio/_wif_util.py

Lines changed: 91 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525

2626
logger = logging.getLogger(__name__)
2727

28+
GCP_METADATA_SERVICE_ACCOUNT_BASE_URL = (
29+
"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default"
30+
)
31+
2832

2933
async def get_aws_region() -> str:
3034
"""Get the current AWS workload's region."""
@@ -81,30 +85,108 @@ async def create_aws_attestation() -> WorkloadIdentityAttestation:
8185
)
8286

8387

84-
async def create_gcp_attestation(
85-
session_manager: SessionManager | None = None,
86-
) -> WorkloadIdentityAttestation:
87-
"""Tries to create a workload identity attestation for GCP.
88+
async def get_gcp_access_token(session_manager: SessionManager) -> str:
89+
"""Gets a GCP access token from the metadata server.
90+
91+
If the application isn't running on GCP or no credentials were found, raises an error.
92+
"""
93+
try:
94+
res = await session_manager.request(
95+
method="GET",
96+
url=f"{GCP_METADATA_SERVICE_ACCOUNT_BASE_URL}/token",
97+
headers={
98+
"Metadata-Flavor": "Google",
99+
},
100+
)
101+
102+
content = await res.content.read()
103+
response_text = content.decode("utf-8")
104+
return json.loads(response_text)["access_token"]
105+
except Exception as e:
106+
raise ProgrammingError(
107+
msg=f"Error fetching GCP access token: {e}. Ensure the application is running on GCP.",
108+
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
109+
)
110+
111+
112+
async def get_gcp_identity_token_via_impersonation(
113+
impersonation_path: list[str], session_manager: SessionManager
114+
) -> str:
115+
"""Gets a GCP identity token from the metadata server.
116+
117+
If the application isn't running on GCP or no credentials were found, raises an error.
118+
"""
119+
if not impersonation_path:
120+
raise ProgrammingError(
121+
msg="Error: impersonation_path cannot be empty.",
122+
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
123+
)
124+
125+
current_sa_token = await get_gcp_access_token(session_manager)
126+
impersonation_path = [
127+
f"projects/-/serviceAccounts/{client_id}" for client_id in impersonation_path
128+
]
129+
try:
130+
res = await session_manager.post(
131+
url=f"https://iamcredentials.googleapis.com/v1/{impersonation_path[-1]}:generateIdToken",
132+
headers={
133+
"Authorization": f"Bearer {current_sa_token}",
134+
"Content-Type": "application/json",
135+
},
136+
json={
137+
"delegates": impersonation_path[:-1],
138+
"audience": SNOWFLAKE_AUDIENCE,
139+
},
140+
)
141+
142+
content = await res.content.read()
143+
response_text = content.decode("utf-8")
144+
return json.loads(response_text)["token"]
145+
except Exception as e:
146+
raise ProgrammingError(
147+
msg=f"Error fetching GCP identity token for impersonated GCP service account '{impersonation_path[-1]}': {e}. Ensure the application is running on GCP.",
148+
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
149+
)
150+
151+
152+
async def get_gcp_identity_token(session_manager: SessionManager) -> str:
153+
"""Gets a GCP identity token from the metadata server.
88154
89155
If the application isn't running on GCP or no credentials were found, raises an error.
90156
"""
91157
try:
92158
res = await session_manager.request(
93159
method="GET",
94-
url=f"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/identity?audience={SNOWFLAKE_AUDIENCE}",
160+
url=f"{GCP_METADATA_SERVICE_ACCOUNT_BASE_URL}/identity?audience={SNOWFLAKE_AUDIENCE}",
95161
headers={
96162
"Metadata-Flavor": "Google",
97163
},
98164
)
99165

100166
content = await res.content.read()
101-
jwt_str = content.decode("utf-8")
167+
return content.decode("utf-8")
102168
except Exception as e:
103169
raise ProgrammingError(
104-
msg=f"Error fetching GCP metadata: {e}. Ensure the application is running on GCP.",
170+
msg=f"Error fetching GCP identity token: {e}. Ensure the application is running on GCP.",
105171
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
106172
)
107173

174+
175+
async def create_gcp_attestation(
176+
session_manager: SessionManager,
177+
impersonation_path: list[str] | None = None,
178+
) -> WorkloadIdentityAttestation:
179+
"""Tries to create a workload identity attestation for GCP.
180+
181+
If the application isn't running on GCP or no credentials were found, raises an error.
182+
"""
183+
if impersonation_path:
184+
jwt_str = await get_gcp_identity_token_via_impersonation(
185+
impersonation_path, session_manager
186+
)
187+
else:
188+
jwt_str = await get_gcp_identity_token(session_manager)
189+
108190
_, subject = extract_iss_and_sub_without_signature_verification(jwt_str)
109191
return WorkloadIdentityAttestation(
110192
AttestationProvider.GCP, jwt_str, {"sub": subject}
@@ -179,6 +261,7 @@ async def create_attestation(
179261
provider: AttestationProvider | None,
180262
entra_resource: str | None = None,
181263
token: str | None = None,
264+
impersonation_path: list[str] | None = None,
182265
session_manager: SessionManager | None = None,
183266
) -> WorkloadIdentityAttestation:
184267
"""Entry point to create an attestation using the given provider.
@@ -195,7 +278,7 @@ async def create_attestation(
195278
elif provider == AttestationProvider.AZURE:
196279
return await create_azure_attestation(entra_resource, session_manager)
197280
elif provider == AttestationProvider.GCP:
198-
return await create_gcp_attestation(session_manager)
281+
return await create_gcp_attestation(session_manager, impersonation_path)
199282
elif provider == AttestationProvider.OIDC:
200283
return create_oidc_attestation(token)
201284
else:

src/snowflake/connector/aio/auth/_workload_identity.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __init__(
2222
provider: AttestationProvider,
2323
token: str | None = None,
2424
entra_resource: str | None = None,
25+
impersonation_path: list[str] | None = None,
2526
**kwargs,
2627
) -> None:
2728
"""Initializes an instance with workload identity authentication."""
@@ -30,6 +31,7 @@ def __init__(
3031
provider=provider,
3132
token=token,
3233
entra_resource=entra_resource,
34+
impersonation_path=impersonation_path,
3335
**kwargs,
3436
)
3537

@@ -44,6 +46,7 @@ async def prepare(
4446
self.provider,
4547
self.entra_resource,
4648
self.token,
49+
self.impersonation_path,
4750
session_manager=conn._session_manager.clone() if conn else None,
4851
)
4952

test/unit/aio/test_auth_workload_identity_async.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import os
99
from base64 import b64decode
1010
from unittest import mock
11+
from unittest.mock import AsyncMock
1112
from urllib.parse import parse_qs, urlparse
1213

1314
import aiohttp
@@ -18,7 +19,7 @@
1819
from snowflake.connector.aio.auth import AuthByWorkloadIdentity
1920
from snowflake.connector.errors import ProgrammingError
2021

21-
from ...csp_helpers import gen_dummy_id_token
22+
from ...csp_helpers import gen_dummy_access_token, gen_dummy_id_token
2223
from .csp_helpers_async import FakeAwsEnvironmentAsync, FakeGceMetadataServiceAsync
2324

2425
logger = logging.getLogger(__name__)
@@ -279,7 +280,7 @@ async def test_explicit_gcp_metadata_server_error_bubbles_up(exception):
279280
with pytest.raises(ProgrammingError) as excinfo:
280281
await auth_class.prepare(conn=None)
281282

282-
assert "Error fetching GCP metadata:" in str(excinfo.value)
283+
assert "Error fetching GCP identity token:" in str(excinfo.value)
283284
assert "Ensure the application is running on GCP." in str(excinfo.value)
284285

285286

@@ -307,6 +308,51 @@ async def test_explicit_gcp_generates_unique_assertion_content(
307308
assert auth_class.assertion_content == '{"_provider":"GCP","sub":"123456"}'
308309

309310

311+
@mock.patch("snowflake.connector.aio._session_manager.SessionManager.post")
312+
async def test_gcp_calls_correct_apis_and_populates_auth_data_for_final_sa(
313+
mock_post_request, fake_gce_metadata_service: FakeGceMetadataServiceAsync
314+
):
315+
fake_gce_metadata_service.sub = "sa1"
316+
impersonation_path = ["sa2", "sa3"]
317+
sa1_access_token = gen_dummy_access_token("sa1")
318+
sa3_id_token = gen_dummy_id_token("sa3")
319+
320+
# Mock the POST request response
321+
class AsyncResponse:
322+
def __init__(self, content):
323+
self._content = content
324+
self.content = mock.Mock()
325+
self.content.read = AsyncMock(return_value=content)
326+
327+
mock_post_request.return_value = AsyncResponse(
328+
json.dumps({"token": sa3_id_token}).encode("utf-8")
329+
)
330+
331+
auth_class = AuthByWorkloadIdentity(
332+
provider=AttestationProvider.GCP, impersonation_path=impersonation_path
333+
)
334+
await auth_class.prepare(conn=None)
335+
336+
mock_post_request.assert_called_once_with(
337+
url="https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/sa3:generateIdToken",
338+
headers={
339+
"Authorization": f"Bearer {sa1_access_token}",
340+
"Content-Type": "application/json",
341+
},
342+
json={
343+
"delegates": ["projects/-/serviceAccounts/sa2"],
344+
"audience": "snowflakecomputing.com",
345+
},
346+
)
347+
348+
assert auth_class.assertion_content == '{"_provider":"GCP","sub":"sa3"}'
349+
assert await extract_api_data(auth_class) == {
350+
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
351+
"PROVIDER": "GCP",
352+
"TOKEN": sa3_id_token,
353+
}
354+
355+
310356
# -- Azure Tests --
311357

312358

test/unit/aio/test_connection_async_unit.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,7 @@ async def test_otel_error_message_async(caplog, mock_post_requests):
605605
"workload_identity_entra_resource",
606606
"api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b",
607607
),
608+
("workload_identity_impersonation_path", ["subject-b", "subject-c"]),
608609
],
609610
)
610611
async def test_cannot_set_dependent_params_without_wlid_authenticator(
@@ -654,6 +655,79 @@ async def test_workload_identity_provider_is_required_for_wif_authenticator(
654655
assert expected_error_msg in str(excinfo.value)
655656

656657

658+
@pytest.mark.parametrize(
659+
"provider_param",
660+
[
661+
# Strongly-typed values.
662+
AttestationProvider.AWS,
663+
AttestationProvider.AZURE,
664+
AttestationProvider.OIDC,
665+
# String values.
666+
"AWS",
667+
"AZURE",
668+
"OIDC",
669+
],
670+
)
671+
async def test_workload_identity_impersonation_path_unsupported_for_non_gcp_providers(
672+
monkeypatch, provider_param
673+
):
674+
async def mock_authenticate(*_):
675+
pass
676+
677+
with monkeypatch.context() as m:
678+
m.setattr(
679+
"snowflake.connector.aio._connection.SnowflakeConnection._authenticate",
680+
mock_authenticate,
681+
)
682+
683+
with pytest.raises(ProgrammingError) as excinfo:
684+
await snowflake.connector.aio.connect(
685+
account="account",
686+
authenticator="WORKLOAD_IDENTITY",
687+
workload_identity_provider=provider_param,
688+
workload_identity_impersonation_path=[
689+
690+
],
691+
)
692+
assert (
693+
"workload_identity_impersonation_path is currently only supported for GCP."
694+
in str(excinfo.value)
695+
)
696+
697+
698+
@pytest.mark.parametrize(
699+
"provider_param",
700+
[
701+
AttestationProvider.GCP,
702+
"GCP",
703+
],
704+
)
705+
async def test_workload_identity_impersonation_path_supported_for_gcp_provider(
706+
monkeypatch, provider_param
707+
):
708+
async def mock_authenticate(*_):
709+
pass
710+
711+
with monkeypatch.context() as m:
712+
m.setattr(
713+
"snowflake.connector.aio._connection.SnowflakeConnection._authenticate",
714+
mock_authenticate,
715+
)
716+
717+
conn = await snowflake.connector.aio.connect(
718+
account="account",
719+
authenticator="WORKLOAD_IDENTITY",
720+
workload_identity_provider=provider_param,
721+
workload_identity_impersonation_path=[
722+
723+
],
724+
)
725+
assert conn.auth_class.provider == AttestationProvider.GCP
726+
assert conn.auth_class.impersonation_path == [
727+
728+
]
729+
730+
657731
@pytest.mark.parametrize(
658732
"provider_param, parsed_provider",
659733
[

test/wif/test_wif_async.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ async def test_should_authenticate_using_oidc_async():
6262

6363
@pytest.mark.wif
6464
@pytest.mark.aio
65-
@pytest.mark.skip("Impersonation is still being developed")
6665
async def test_should_authenticate_with_impersonation_async():
6766
if not isinstance(IMPERSONATION_PATH, str) or not IMPERSONATION_PATH:
6867
pytest.skip("Skipping test - IMPERSONATION_PATH is not set")

0 commit comments

Comments
 (0)