1515import jwt
1616import pytest
1717
18- from snowflake .connector .aio ._wif_util import AttestationProvider
18+ from snowflake .connector .aio ._wif_util import (
19+ AttestationProvider ,
20+ WorkloadIdentityAttestation ,
21+ )
1922from snowflake .connector .aio .auth import AuthByWorkloadIdentity
2023from snowflake .connector .errors import ProgrammingError
2124
2225from ...csp_helpers import gen_dummy_access_token , gen_dummy_id_token
26+ from ...helpers import apply_auth_class_update_body_async , create_mock_auth_body
2327from .csp_helpers_async import FakeAwsEnvironmentAsync , FakeGceMetadataServiceAsync
2428
2529logger = logging .getLogger (__name__ )
@@ -138,6 +142,42 @@ async def mock_post(*args, **kwargs):
138142 await connection .close ()
139143
140144
145+ @pytest .mark .parametrize (
146+ "provider,additional_args" ,
147+ [
148+ (AttestationProvider .AWS , {}),
149+ (AttestationProvider .GCP , {}),
150+ (AttestationProvider .AZURE , {}),
151+ (
152+ AttestationProvider .OIDC ,
153+ {"token" : gen_dummy_id_token (sub = "service-1" , iss = "issuer-1" )},
154+ ),
155+ ],
156+ )
157+ async def test_auth_prepare_body_does_not_overwrite_client_environment_fields (
158+ provider , additional_args
159+ ):
160+ auth_class = AuthByWorkloadIdentity (provider = provider , ** additional_args )
161+ auth_class .attestation = WorkloadIdentityAttestation (
162+ provider = AttestationProvider .GCP ,
163+ credential = None ,
164+ user_identifier_components = None ,
165+ )
166+
167+ req_body_before = create_mock_auth_body ()
168+ req_body_after = await apply_auth_class_update_body_async (
169+ auth_class , req_body_before
170+ )
171+
172+ assert all (
173+ [
174+ req_body_before ["data" ]["CLIENT_ENVIRONMENT" ][k ]
175+ == req_body_after ["data" ]["CLIENT_ENVIRONMENT" ][k ]
176+ for k in req_body_before ["data" ]["CLIENT_ENVIRONMENT" ]
177+ ]
178+ )
179+
180+
141181# -- OIDC Tests --
142182
143183
@@ -152,6 +192,7 @@ async def test_explicit_oidc_valid_inline_token_plumbed_to_api():
152192 "AUTHENTICATOR" : "WORKLOAD_IDENTITY" ,
153193 "PROVIDER" : "OIDC" ,
154194 "TOKEN" : dummy_token ,
195+ "CLIENT_ENVIRONMENT" : {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH" : 0 },
155196 }
156197
157198
@@ -209,6 +250,9 @@ async def test_explicit_aws_encodes_audience_host_signature_to_api(
209250 data = await extract_api_data (auth_class )
210251 assert data ["AUTHENTICATOR" ] == "WORKLOAD_IDENTITY"
211252 assert data ["PROVIDER" ] == "AWS"
253+ assert (
254+ data ["CLIENT_ENVIRONMENT" ]["WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH" ] == 0
255+ )
212256 verify_aws_token (data ["TOKEN" ], fake_aws_environment .region )
213257
214258
@@ -310,6 +354,7 @@ async def test_explicit_gcp_plumbs_token_to_api(
310354 "AUTHENTICATOR" : "WORKLOAD_IDENTITY" ,
311355 "PROVIDER" : "GCP" ,
312356 "TOKEN" : fake_gce_metadata_service .token ,
357+ "CLIENT_ENVIRONMENT" : {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH" : 0 },
313358 }
314359
315360
@@ -366,6 +411,7 @@ def __init__(self, content):
366411 "AUTHENTICATOR" : "WORKLOAD_IDENTITY" ,
367412 "PROVIDER" : "GCP" ,
368413 "TOKEN" : sa3_id_token ,
414+ "CLIENT_ENVIRONMENT" : {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH" : 2 },
369415 }
370416
371417
@@ -420,6 +466,7 @@ async def test_explicit_azure_plumbs_token_to_api(fake_azure_metadata_service):
420466 "AUTHENTICATOR" : "WORKLOAD_IDENTITY" ,
421467 "PROVIDER" : "AZURE" ,
422468 "TOKEN" : fake_azure_metadata_service .token ,
469+ "CLIENT_ENVIRONMENT" : {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH" : 0 },
423470 }
424471
425472
0 commit comments