Skip to content

Commit 2713c74

Browse files
sfc-gh-eqinsfc-gh-pmansour
authored andcommitted
Support WIF Impersonation on GCP workloads (#2496)
Co-authored-by: Peter Mansour <[email protected]>
1 parent 8fb58f3 commit 2713c74

File tree

7 files changed

+240
-12
lines changed

7 files changed

+240
-12
lines changed

src/snowflake/connector/auth/workload_identity.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,14 @@ def __init__(
5555
provider: AttestationProvider,
5656
token: str | None = None,
5757
entra_resource: str | None = None,
58+
impersonation_path: list[str] | None = None,
5859
**kwargs,
5960
) -> None:
6061
super().__init__(**kwargs)
6162
self.provider = provider
6263
self.token = token
6364
self.entra_resource = entra_resource
65+
self.impersonation_path = impersonation_path
6466

6567
self.attestation: WorkloadIdentityAttestation | None = None
6668

@@ -85,7 +87,10 @@ def prepare(
8587
self.provider,
8688
self.entra_resource,
8789
self.token,
88-
session_manager=conn._session_manager.clone() if conn else None,
90+
self.impersonation_path,
91+
session_manager=(
92+
conn._session_manager.clone(max_retries=0) if conn else None
93+
),
8994
)
9095

9196
def reauthenticate(self, **kwargs: typing.Any) -> dict[str, bool]:

src/snowflake/connector/connection.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def _get_private_bytes_from_file(
214214
"authenticator": (DEFAULT_AUTHENTICATOR, (type(None), str)),
215215
"workload_identity_provider": (None, (type(None), AttestationProvider)),
216216
"workload_identity_entra_resource": (None, (type(None), str)),
217+
"workload_identity_impersonation_path": (None, (type(None), list[str])),
217218
"mfa_callback": (None, (type(None), Callable)),
218219
"password_callback": (None, (type(None), Callable)),
219220
"auth_class": (None, (type(None), AuthByPlugin)),
@@ -1355,10 +1356,24 @@ def __open_connection(self):
13551356
"errno": ER_INVALID_WIF_SETTINGS,
13561357
},
13571358
)
1359+
if (
1360+
self._workload_identity_impersonation_path
1361+
and self._workload_identity_provider != AttestationProvider.GCP
1362+
):
1363+
Error.errorhandler_wrapper(
1364+
self,
1365+
None,
1366+
ProgrammingError,
1367+
{
1368+
"msg": "workload_identity_impersonation_path is currently only supported for GCP.",
1369+
"errno": ER_INVALID_WIF_SETTINGS,
1370+
},
1371+
)
13581372
self.auth_class = AuthByWorkloadIdentity(
13591373
provider=self._workload_identity_provider,
13601374
token=self._token,
13611375
entra_resource=self._workload_identity_entra_resource,
1376+
impersonation_path=self._workload_identity_impersonation_path,
13621377
)
13631378
else:
13641379
# okta URL, e.g., https://<account>.okta.com/
@@ -1531,6 +1546,7 @@ def __config(self, **kwargs):
15311546
workload_identity_dependent_options = [
15321547
"workload_identity_provider",
15331548
"workload_identity_entra_resource",
1549+
"workload_identity_impersonation_path",
15341550
]
15351551
for dependent_option in workload_identity_dependent_options:
15361552
if (

src/snowflake/connector/wif_util.py

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
logger = logging.getLogger(__name__)
2121
SNOWFLAKE_AUDIENCE = "snowflakecomputing.com"
2222
DEFAULT_ENTRA_SNOWFLAKE_RESOURCE = "api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad"
23+
GCP_METADATA_SERVICE_ACCOUNT_BASE_URL = (
24+
"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default"
25+
)
2326

2427

2528
@unique
@@ -184,29 +187,103 @@ def create_aws_attestation(
184187
)
185188

186189

187-
def create_gcp_attestation(
188-
session_manager: SessionManager | None = None,
189-
) -> WorkloadIdentityAttestation:
190-
"""Tries to create a workload identity attestation for GCP.
190+
def get_gcp_access_token(session_manager: SessionManager) -> str:
191+
"""Gets a GCP access token from the metadata server.
192+
193+
If the application isn't running on GCP or no credentials were found, raises an error.
194+
"""
195+
try:
196+
res = session_manager.request(
197+
method="GET",
198+
url=f"{GCP_METADATA_SERVICE_ACCOUNT_BASE_URL}/token",
199+
headers={
200+
"Metadata-Flavor": "Google",
201+
},
202+
)
203+
res.raise_for_status()
204+
return res.json()["access_token"]
205+
except Exception as e:
206+
raise ProgrammingError(
207+
msg=f"Error fetching GCP access token: {e}. Ensure the application is running on GCP.",
208+
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
209+
)
210+
211+
212+
def get_gcp_identity_token_via_impersonation(
213+
impersonation_path: list[str], session_manager: SessionManager
214+
) -> str:
215+
"""Gets a GCP identity token from the metadata server.
216+
217+
If the application isn't running on GCP or no credentials were found, raises an error.
218+
"""
219+
if not impersonation_path:
220+
raise ProgrammingError(
221+
msg="Error: impersonation_path cannot be empty.",
222+
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
223+
)
224+
225+
current_sa_token = get_gcp_access_token(session_manager)
226+
impersonation_path = [
227+
f"projects/-/serviceAccounts/{client_id}" for client_id in impersonation_path
228+
]
229+
try:
230+
res = session_manager.post(
231+
url=f"https://iamcredentials.googleapis.com/v1/{impersonation_path[-1]}:generateIdToken",
232+
headers={
233+
"Authorization": f"Bearer {current_sa_token}",
234+
"Content-Type": "application/json",
235+
},
236+
json={
237+
"delegates": impersonation_path[:-1],
238+
"audience": SNOWFLAKE_AUDIENCE,
239+
},
240+
)
241+
res.raise_for_status()
242+
return res.json()["token"]
243+
except Exception as e:
244+
raise ProgrammingError(
245+
msg=f"Error fetching GCP identity token for impersonated GCP service account '{impersonation_path[-1]}': {e}. Ensure the application is running on GCP.",
246+
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
247+
)
248+
249+
250+
def get_gcp_identity_token(session_manager: SessionManager) -> str:
251+
"""Gets a GCP identity token from the metadata server.
191252
192253
If the application isn't running on GCP or no credentials were found, raises an error.
193254
"""
194255
try:
195256
res = session_manager.request(
196257
method="GET",
197-
url=f"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/identity?audience={SNOWFLAKE_AUDIENCE}",
258+
url=f"{GCP_METADATA_SERVICE_ACCOUNT_BASE_URL}/identity?audience={SNOWFLAKE_AUDIENCE}",
198259
headers={
199260
"Metadata-Flavor": "Google",
200261
},
201262
)
202263
res.raise_for_status()
264+
return res.content.decode("utf-8")
203265
except Exception as e:
204266
raise ProgrammingError(
205-
msg=f"Error fetching GCP metadata: {e}. Ensure the application is running on GCP.",
267+
msg=f"Error fetching GCP identity token: {e}. Ensure the application is running on GCP.",
206268
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
207269
)
208270

209-
jwt_str = res.content.decode("utf-8")
271+
272+
def create_gcp_attestation(
273+
session_manager: SessionManager,
274+
impersonation_path: list[str] | None = None,
275+
) -> WorkloadIdentityAttestation:
276+
"""Tries to create a workload identity attestation for GCP.
277+
278+
If the application isn't running on GCP or no credentials were found, raises an error.
279+
"""
280+
if impersonation_path:
281+
jwt_str = get_gcp_identity_token_via_impersonation(
282+
impersonation_path, session_manager
283+
)
284+
else:
285+
jwt_str = get_gcp_identity_token(session_manager)
286+
210287
_, subject = extract_iss_and_sub_without_signature_verification(jwt_str)
211288
return WorkloadIdentityAttestation(
212289
AttestationProvider.GCP, jwt_str, {"sub": subject}
@@ -295,6 +372,7 @@ def create_attestation(
295372
provider: AttestationProvider,
296373
entra_resource: str | None = None,
297374
token: str | None = None,
375+
impersonation_path: list[str] | None = None,
298376
session_manager: SessionManager | None = None,
299377
) -> WorkloadIdentityAttestation:
300378
"""Entry point to create an attestation using the given provider.
@@ -311,7 +389,7 @@ def create_attestation(
311389
elif provider == AttestationProvider.AZURE:
312390
return create_azure_attestation(entra_resource, session_manager)
313391
elif provider == AttestationProvider.GCP:
314-
return create_gcp_attestation(session_manager)
392+
return create_gcp_attestation(session_manager, impersonation_path)
315393
elif provider == AttestationProvider.OIDC:
316394
return create_oidc_attestation(token)
317395
else:

test/csp_helpers.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ def gen_dummy_id_token(
4040
)
4141

4242

43+
def gen_dummy_access_token(sub="test-subject") -> str:
44+
"""Generates a dummy access token using the given subject."""
45+
key = "secret"
46+
logger.debug(f"Generating dummy access token for subject {sub}")
47+
return (sub + key).encode("utf-8").hex()
48+
49+
4350
def build_response(content: bytes, status_code: int = 200, headers=None) -> Response:
4451
"""Builds a requests.Response object with the given status code and content."""
4552
response = Response()
@@ -285,6 +292,19 @@ def handle_request(self, method, parsed_url, headers, timeout):
285292
audience = query_string["audience"][0]
286293
self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=audience)
287294
return build_response(self.token.encode("utf-8"))
295+
elif (
296+
method == "GET"
297+
and parsed_url.path
298+
== "/computeMetadata/v1/instance/service-accounts/default/token"
299+
and headers.get("Metadata-Flavor") == "Google"
300+
):
301+
self.token = gen_dummy_access_token(sub=self.sub)
302+
ret = {
303+
"access_token": self.token,
304+
"expires_in": 3599,
305+
"token_type": "Bearer",
306+
}
307+
return build_response(json.dumps(ret).encode("utf-8"))
288308
else:
289309
# Reject malformed requests.
290310
raise HTTPError()

test/unit/test_auth_workload_identity.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@
1717
)
1818
from snowflake.connector.wif_util import AttestationProvider, get_aws_sts_hostname
1919

20-
from ..csp_helpers import FakeAwsEnvironment, FakeGceMetadataService, gen_dummy_id_token
20+
from ..csp_helpers import (
21+
FakeAwsEnvironment,
22+
FakeGceMetadataService,
23+
build_response,
24+
gen_dummy_access_token,
25+
gen_dummy_id_token,
26+
)
2127

2228
logger = logging.getLogger(__name__)
2329

@@ -289,7 +295,7 @@ def test_explicit_gcp_metadata_server_error_bubbles_up(exception):
289295
with pytest.raises(ProgrammingError) as excinfo:
290296
auth_class.prepare(conn=None)
291297

292-
assert "Error fetching GCP metadata:" in str(excinfo.value)
298+
assert "Error fetching GCP identity token:" in str(excinfo.value)
293299
assert "Ensure the application is running on GCP." in str(excinfo.value)
294300

295301

@@ -317,6 +323,44 @@ def test_explicit_gcp_generates_unique_assertion_content(
317323
assert auth_class.assertion_content == '{"_provider":"GCP","sub":"123456"}'
318324

319325

326+
@mock.patch("snowflake.connector.session_manager.SessionManager.post")
327+
def test_gcp_calls_correct_apis_and_populates_auth_data_for_final_sa(
328+
mock_post_request, fake_gce_metadata_service: FakeGceMetadataService
329+
):
330+
fake_gce_metadata_service.sub = "sa1"
331+
impersonation_path = ["sa2", "sa3"]
332+
sa1_access_token = gen_dummy_access_token("sa1")
333+
sa3_id_token = gen_dummy_id_token("sa3")
334+
335+
mock_post_request.return_value = build_response(
336+
json.dumps({"token": sa3_id_token}).encode("utf-8")
337+
)
338+
339+
auth_class = AuthByWorkloadIdentity(
340+
provider=AttestationProvider.GCP, impersonation_path=impersonation_path
341+
)
342+
auth_class.prepare(conn=None)
343+
344+
mock_post_request.assert_called_once_with(
345+
url="https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/sa3:generateIdToken",
346+
headers={
347+
"Authorization": f"Bearer {sa1_access_token}",
348+
"Content-Type": "application/json",
349+
},
350+
json={
351+
"delegates": ["projects/-/serviceAccounts/sa2"],
352+
"audience": "snowflakecomputing.com",
353+
},
354+
)
355+
356+
assert auth_class.assertion_content == '{"_provider":"GCP","sub":"sa3"}'
357+
assert extract_api_data(auth_class) == {
358+
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
359+
"PROVIDER": "GCP",
360+
"TOKEN": sa3_id_token,
361+
}
362+
363+
320364
# -- Azure Tests --
321365

322366

test/unit/test_connection.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,7 @@ def test_otel_error_message(caplog, mock_post_requests):
631631
"workload_identity_entra_resource",
632632
"api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b",
633633
),
634+
("workload_identity_impersonation_path", ["subject-b", "subject-c"]),
634635
],
635636
)
636637
def test_cannot_set_dependent_params_without_wlid_authenticator(
@@ -679,6 +680,71 @@ def test_workload_identity_provider_is_required_for_wif_authenticator(
679680
assert expected_error_msg in str(excinfo.value)
680681

681682

683+
@pytest.mark.parametrize(
684+
"provider_param",
685+
[
686+
# Strongly-typed values.
687+
AttestationProvider.AWS,
688+
AttestationProvider.AZURE,
689+
AttestationProvider.OIDC,
690+
# String values.
691+
"AWS",
692+
"AZURE",
693+
"OIDC",
694+
],
695+
)
696+
def test_workload_identity_impersonation_path_unsupported_for_non_gcp_providers(
697+
monkeypatch, provider_param
698+
):
699+
with monkeypatch.context() as m:
700+
m.setattr(
701+
"snowflake.connector.SnowflakeConnection._authenticate", lambda *_: None
702+
)
703+
704+
with pytest.raises(ProgrammingError) as excinfo:
705+
snowflake.connector.connect(
706+
account="account",
707+
authenticator="WORKLOAD_IDENTITY",
708+
workload_identity_provider=provider_param,
709+
workload_identity_impersonation_path=[
710+
711+
],
712+
)
713+
assert (
714+
"workload_identity_impersonation_path is currently only supported for GCP."
715+
in str(excinfo.value)
716+
)
717+
718+
719+
@pytest.mark.parametrize(
720+
"provider_param",
721+
[
722+
AttestationProvider.GCP,
723+
"GCP",
724+
],
725+
)
726+
def test_workload_identity_impersonation_path_supported_for_gcp_provider(
727+
monkeypatch, provider_param
728+
):
729+
with monkeypatch.context() as m:
730+
m.setattr(
731+
"snowflake.connector.SnowflakeConnection._authenticate", lambda *_: None
732+
)
733+
734+
conn = snowflake.connector.connect(
735+
account="account",
736+
authenticator="WORKLOAD_IDENTITY",
737+
workload_identity_provider=provider_param,
738+
workload_identity_impersonation_path=[
739+
740+
],
741+
)
742+
assert conn.auth_class.provider == AttestationProvider.GCP
743+
assert conn.auth_class.impersonation_path == [
744+
745+
]
746+
747+
682748
@pytest.mark.parametrize(
683749
"provider_param, parsed_provider",
684750
[

test/wif/test_wif.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def test_should_authenticate_using_oidc():
5959

6060

6161
@pytest.mark.wif
62-
@pytest.mark.skip("Impersonation is still being developed")
6362
def test_should_authenticate_with_impersonation():
6463
if not isinstance(IMPERSONATION_PATH, str) or not IMPERSONATION_PATH:
6564
pytest.skip("Skipping test - IMPERSONATION_PATH is not set")

0 commit comments

Comments
 (0)