Skip to content

Commit 395928b

Browse files
committed
add unit test and rename data field key
1 parent 9f139a0 commit 395928b

File tree

2 files changed

+67
-8
lines changed

2 files changed

+67
-8
lines changed

src/snowflake/connector/auth/workload_identity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def update_body(self, body: dict[typing.Any, typing.Any]) -> None:
7979
).value
8080
body["data"]["TOKEN"] = self.attestation.credential
8181
body["data"].setdefault("CLIENT_ENVIRONMENT", {})[
82-
"IMPERSONATION_PATH_LENGTH"
82+
"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH"
8383
] = len(self.impersonation_path or [])
8484

8585
def prepare(

test/unit/test_auth_workload_identity.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import json
23
import logging
34
from base64 import b64decode
@@ -7,14 +8,18 @@
78
import jwt
89
import pytest
910

10-
from snowflake.connector.auth import AuthByWorkloadIdentity
11+
from snowflake.connector.auth import Auth, AuthByWorkloadIdentity
1112
from snowflake.connector.errors import ProgrammingError
1213
from snowflake.connector.vendored.requests.exceptions import (
1314
ConnectTimeout,
1415
HTTPError,
1516
Timeout,
1617
)
17-
from snowflake.connector.wif_util import AttestationProvider, get_aws_sts_hostname
18+
from snowflake.connector.wif_util import (
19+
AttestationProvider,
20+
WorkloadIdentityAttestation,
21+
get_aws_sts_hostname,
22+
)
1823

1924
from ..csp_helpers import (
2025
FakeAwsEnvironment,
@@ -122,6 +127,58 @@ def test_wif_authenticator_is_case_insensitive(
122127
assert isinstance(connection.auth_class, AuthByWorkloadIdentity)
123128

124129

130+
@pytest.mark.parametrize(
131+
"provider,additional_args",
132+
[
133+
(AttestationProvider.AWS, {}),
134+
(AttestationProvider.GCP, {}),
135+
(AttestationProvider.AZURE, {}),
136+
(
137+
AttestationProvider.OIDC,
138+
{"token": gen_dummy_id_token(sub="service-1", iss="issuer-1")},
139+
),
140+
],
141+
)
142+
def test_wif_prepare_body_does_not_overwrite_fields(provider, additional_args):
143+
ocsp_mode = mock.Mock()
144+
ocsp_mode.name = "ocsp_mode"
145+
session_manager = mock.Mock()
146+
session_manager.clone = lambda max_retries: "session_manager"
147+
148+
req_body_before = Auth.base_auth_data(
149+
"user",
150+
"account",
151+
"application",
152+
"internal_application_name",
153+
"internal_application_version",
154+
ocsp_mode,
155+
login_timeout=60 * 60,
156+
network_timeout=60 * 60,
157+
socket_timeout=60 * 60,
158+
platform_detection_timeout_seconds=0.2,
159+
session_manager=session_manager,
160+
)
161+
req_body_after = copy.deepcopy(req_body_before)
162+
auth_class = AuthByWorkloadIdentity(provider=provider, **additional_args)
163+
auth_class.attestation = WorkloadIdentityAttestation(
164+
provider=provider, credential=None, user_identifier_components=None
165+
)
166+
auth_class.update_body(req_body_after)
167+
168+
# Check that the values in the body before are a strict subset of the values in the body after.
169+
# Must use all() for this comparison because lists are not hashable
170+
assert all(
171+
[
172+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
173+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
174+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
175+
]
176+
)
177+
req_body_before["data"].pop("CLIENT_ENVIRONMENT")
178+
req_body_after["data"].pop("CLIENT_ENVIRONMENT")
179+
assert set(req_body_before["data"].items()) <= set(req_body_after["data"].items())
180+
181+
125182
# -- OIDC Tests --
126183

127184

@@ -136,7 +193,7 @@ def test_explicit_oidc_valid_inline_token_plumbed_to_api():
136193
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
137194
"PROVIDER": "OIDC",
138195
"TOKEN": dummy_token,
139-
"CLIENT_ENVIRONMENT": {"IMPERSONATION_PATH_LENGTH": 0},
196+
"CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0},
140197
}
141198

142199

@@ -192,7 +249,9 @@ def test_explicit_aws_encodes_audience_host_signature_to_api(
192249
data = extract_api_data(auth_class)
193250
assert data["AUTHENTICATOR"] == "WORKLOAD_IDENTITY"
194251
assert data["PROVIDER"] == "AWS"
195-
assert data["CLIENT_ENVIRONMENT"]["IMPERSONATION_PATH_LENGTH"] == 0
252+
assert (
253+
data["CLIENT_ENVIRONMENT"]["WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH"] == 0
254+
)
196255
verify_aws_token(data["TOKEN"], fake_aws_environment.region)
197256

198257

@@ -326,7 +385,7 @@ def test_explicit_gcp_plumbs_token_to_api(
326385
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
327386
"PROVIDER": "GCP",
328387
"TOKEN": fake_gce_metadata_service.token,
329-
"CLIENT_ENVIRONMENT": {"IMPERSONATION_PATH_LENGTH": 0},
388+
"CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0},
330389
}
331390

332391

@@ -376,7 +435,7 @@ def test_gcp_calls_correct_apis_and_populates_auth_data_for_final_sa(
376435
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
377436
"PROVIDER": "GCP",
378437
"TOKEN": sa3_id_token,
379-
"CLIENT_ENVIRONMENT": {"IMPERSONATION_PATH_LENGTH": 2},
438+
"CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 2},
380439
}
381440

382441

@@ -429,7 +488,7 @@ def test_explicit_azure_plumbs_token_to_api(fake_azure_metadata_service):
429488
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
430489
"PROVIDER": "AZURE",
431490
"TOKEN": fake_azure_metadata_service.token,
432-
"CLIENT_ENVIRONMENT": {"IMPERSONATION_PATH_LENGTH": 0},
491+
"CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0},
433492
}
434493

435494

0 commit comments

Comments
 (0)