Skip to content

Commit 9c4520b

Browse files
committed
initial changes
1 parent ac6bf6c commit 9c4520b

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

src/snowflake/connector/auth/workload_identity.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,14 @@ def __init__(
5555
provider: AttestationProvider | None = None,
5656
token: str | None = None,
5757
entra_resource: str | None = None,
58+
impersonations: list[str] | None = None,
5859
**kwargs,
5960
) -> None:
6061
super().__init__(**kwargs)
6162
self.provider = provider
6263
self.token = token
6364
self.entra_resource = entra_resource
65+
self.impersonations = impersonations
6466

6567
self.attestation: WorkloadIdentityAttestation | None = None
6668

@@ -85,6 +87,7 @@ def prepare(
8587
self.provider,
8688
self.entra_resource,
8789
self.token,
90+
self.impersonations,
8891
session_manager=(
8992
conn._session_manager.clone(max_retries=0) if conn else None
9093
),

src/snowflake/connector/connection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def _get_private_bytes_from_file(
214214
"authenticator": (DEFAULT_AUTHENTICATOR, (type(None), str)),
215215
"workload_identity_provider": (None, (type(None), AttestationProvider)),
216216
"workload_identity_entra_resource": (None, (type(None), str)),
217+
"workload_identity_impersonations": (None, (type(None), list)),
217218
"mfa_callback": (None, (type(None), Callable)),
218219
"password_callback": (None, (type(None), Callable)),
219220
"auth_class": (None, (type(None), AuthByPlugin)),
@@ -1359,6 +1360,7 @@ def __open_connection(self):
13591360
provider=self._workload_identity_provider,
13601361
token=self._token,
13611362
entra_resource=self._workload_identity_entra_resource,
1363+
impersonations=self._workload_identity_impersonations,
13621364
)
13631365
else:
13641366
# okta URL, e.g., https://<account>.okta.com/
@@ -1531,6 +1533,7 @@ def __config(self, **kwargs):
15311533
workload_identity_dependent_options = [
15321534
"workload_identity_provider",
15331535
"workload_identity_entra_resource",
1536+
"workload_identity_impersonations",
15341537
]
15351538
for dependent_option in workload_identity_dependent_options:
15361539
if (

src/snowflake/connector/wif_util.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from botocore.auth import SigV4Auth
1313
from botocore.awsrequest import AWSRequest
1414
from botocore.utils import InstanceMetadataRegionFetcher
15+
from google.auth import impersonated_credentials
16+
from google.auth.transport.requests import Request
1517

1618
from .errorcode import ER_INVALID_WIF_SETTINGS, ER_WIF_CREDENTIALS_NOT_FOUND
1719
from .errors import ProgrammingError
@@ -185,6 +187,7 @@ def create_aws_attestation(
185187

186188

187189
def create_gcp_attestation(
190+
impersonations: list[str] | None = None,
188191
session_manager: SessionManager | None = None,
189192
) -> WorkloadIdentityAttestation:
190193
"""Tries to create a workload identity attestation for GCP.
@@ -207,6 +210,24 @@ def create_gcp_attestation(
207210
)
208211

209212
jwt_str = res.content.decode("utf-8")
213+
if impersonations:
214+
try:
215+
for impersonation in impersonations:
216+
jwt_str = impersonated_credentials.Credentials(
217+
source_credentials=jwt_str,
218+
target_principal=impersonation,
219+
target_audience=SNOWFLAKE_AUDIENCE,
220+
)
221+
222+
# Refresh the last impersonated credential to get the final token
223+
jwt_str.refresh(Request())
224+
jwt_str = jwt_str.token
225+
except Exception as e:
226+
raise ProgrammingError(
227+
msg=f"Error impersonating GCP service account: {e}. Ensure the service account has the 'Service Account Token Creator' role.",
228+
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
229+
)
230+
210231
_, subject = extract_iss_and_sub_without_signature_verification(jwt_str)
211232
return WorkloadIdentityAttestation(
212233
AttestationProvider.GCP, jwt_str, {"sub": subject}
@@ -295,6 +316,7 @@ def create_attestation(
295316
provider: AttestationProvider,
296317
entra_resource: str | None = None,
297318
token: str | None = None,
319+
impersonations: list[str] | None = None,
298320
session_manager: SessionManager | None = None,
299321
) -> WorkloadIdentityAttestation:
300322
"""Entry point to create an attestation using the given provider.
@@ -313,7 +335,7 @@ def create_attestation(
313335
elif provider == AttestationProvider.AZURE:
314336
return create_azure_attestation(entra_resource, session_manager)
315337
elif provider == AttestationProvider.GCP:
316-
return create_gcp_attestation(session_manager)
338+
return create_gcp_attestation(impersonations, session_manager)
317339
elif provider == AttestationProvider.OIDC:
318340
return create_oidc_attestation(token)
319341
else:

0 commit comments

Comments
 (0)