1+ import copy
12import json
23import logging
34from base64 import b64decode
78import jwt
89import pytest
910
10- from snowflake .connector .auth import AuthByWorkloadIdentity
11+ from snowflake .connector .auth import Auth , AuthByWorkloadIdentity
1112from snowflake .connector .errors import ProgrammingError
1213from snowflake .connector .vendored .requests .exceptions import (
1314 ConnectTimeout ,
1415 HTTPError ,
1516 Timeout ,
1617)
17- from snowflake .connector .wif_util import AttestationProvider , get_aws_sts_hostname
18+ from snowflake .connector .wif_util import (
19+ AttestationProvider ,
20+ WorkloadIdentityAttestation ,
21+ get_aws_sts_hostname ,
22+ )
1823
1924from ..csp_helpers import (
2025 FakeAwsEnvironment ,
@@ -122,6 +127,58 @@ def test_wif_authenticator_is_case_insensitive(
122127 assert isinstance (connection .auth_class , AuthByWorkloadIdentity )
123128
124129
130+ @pytest .mark .parametrize (
131+ "provider,additional_args" ,
132+ [
133+ (AttestationProvider .AWS , {}),
134+ (AttestationProvider .GCP , {}),
135+ (AttestationProvider .AZURE , {}),
136+ (
137+ AttestationProvider .OIDC ,
138+ {"token" : gen_dummy_id_token (sub = "service-1" , iss = "issuer-1" )},
139+ ),
140+ ],
141+ )
142+ def test_wif_prepare_body_does_not_overwrite_fields (provider , additional_args ):
143+ ocsp_mode = mock .Mock ()
144+ ocsp_mode .name = "ocsp_mode"
145+ session_manager = mock .Mock ()
146+ session_manager .clone = lambda max_retries : "session_manager"
147+
148+ req_body_before = Auth .base_auth_data (
149+ "user" ,
150+ "account" ,
151+ "application" ,
152+ "internal_application_name" ,
153+ "internal_application_version" ,
154+ ocsp_mode ,
155+ login_timeout = 60 * 60 ,
156+ network_timeout = 60 * 60 ,
157+ socket_timeout = 60 * 60 ,
158+ platform_detection_timeout_seconds = 0.2 ,
159+ session_manager = session_manager ,
160+ )
161+ req_body_after = copy .deepcopy (req_body_before )
162+ auth_class = AuthByWorkloadIdentity (provider = provider , ** additional_args )
163+ auth_class .attestation = WorkloadIdentityAttestation (
164+ provider = provider , credential = None , user_identifier_components = None
165+ )
166+ auth_class .update_body (req_body_after )
167+
168+ # Check that the values in the body before are a strict subset of the values in the body after.
169+ # Must use all() for this comparison because lists are not hashable
170+ assert all (
171+ [
172+ req_body_before ["data" ]["CLIENT_ENVIRONMENT" ][k ]
173+ == req_body_after ["data" ]["CLIENT_ENVIRONMENT" ][k ]
174+ for k in req_body_before ["data" ]["CLIENT_ENVIRONMENT" ]
175+ ]
176+ )
177+ req_body_before ["data" ].pop ("CLIENT_ENVIRONMENT" )
178+ req_body_after ["data" ].pop ("CLIENT_ENVIRONMENT" )
179+ assert set (req_body_before ["data" ].items ()) <= set (req_body_after ["data" ].items ())
180+
181+
125182# -- OIDC Tests --
126183
127184
@@ -136,7 +193,7 @@ def test_explicit_oidc_valid_inline_token_plumbed_to_api():
136193 "AUTHENTICATOR" : "WORKLOAD_IDENTITY" ,
137194 "PROVIDER" : "OIDC" ,
138195 "TOKEN" : dummy_token ,
139- "CLIENT_ENVIRONMENT" : {"IMPERSONATION_PATH_LENGTH " : 0 },
196+ "CLIENT_ENVIRONMENT" : {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH " : 0 },
140197 }
141198
142199
@@ -192,7 +249,9 @@ def test_explicit_aws_encodes_audience_host_signature_to_api(
192249 data = extract_api_data (auth_class )
193250 assert data ["AUTHENTICATOR" ] == "WORKLOAD_IDENTITY"
194251 assert data ["PROVIDER" ] == "AWS"
195- assert data ["CLIENT_ENVIRONMENT" ]["IMPERSONATION_PATH_LENGTH" ] == 0
252+ assert (
253+ data ["CLIENT_ENVIRONMENT" ]["WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH" ] == 0
254+ )
196255 verify_aws_token (data ["TOKEN" ], fake_aws_environment .region )
197256
198257
@@ -326,7 +385,7 @@ def test_explicit_gcp_plumbs_token_to_api(
326385 "AUTHENTICATOR" : "WORKLOAD_IDENTITY" ,
327386 "PROVIDER" : "GCP" ,
328387 "TOKEN" : fake_gce_metadata_service .token ,
329- "CLIENT_ENVIRONMENT" : {"IMPERSONATION_PATH_LENGTH " : 0 },
388+ "CLIENT_ENVIRONMENT" : {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH " : 0 },
330389 }
331390
332391
@@ -376,7 +435,7 @@ def test_gcp_calls_correct_apis_and_populates_auth_data_for_final_sa(
376435 "AUTHENTICATOR" : "WORKLOAD_IDENTITY" ,
377436 "PROVIDER" : "GCP" ,
378437 "TOKEN" : sa3_id_token ,
379- "CLIENT_ENVIRONMENT" : {"IMPERSONATION_PATH_LENGTH " : 2 },
438+ "CLIENT_ENVIRONMENT" : {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH " : 2 },
380439 }
381440
382441
@@ -429,7 +488,7 @@ def test_explicit_azure_plumbs_token_to_api(fake_azure_metadata_service):
429488 "AUTHENTICATOR" : "WORKLOAD_IDENTITY" ,
430489 "PROVIDER" : "AZURE" ,
431490 "TOKEN" : fake_azure_metadata_service .token ,
432- "CLIENT_ENVIRONMENT" : {"IMPERSONATION_PATH_LENGTH " : 0 },
491+ "CLIENT_ENVIRONMENT" : {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH " : 0 },
433492 }
434493
435494
0 commit comments