Skip to content

Commit 20c78ae

Browse files
[async] Add WIF impersonation path length as data sent to Snowflake backend (#2521)
1 parent 0d2bdd0 commit 20c78ae

10 files changed

+222
-1
lines changed

test/helpers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,3 +340,9 @@ def apply_auth_class_update_body(auth_class, req_body_before):
340340
req_body_after = copy.deepcopy(req_body_before)
341341
auth_class.update_body(req_body_after)
342342
return req_body_after
343+
344+
345+
async def apply_auth_class_update_body_async(auth_class, req_body_before):
346+
req_body_after = copy.deepcopy(req_body_before)
347+
await auth_class.update_body(req_body_after)
348+
return req_body_after

test/unit/aio/test_auth_async.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import asyncio
99
import inspect
1010
import sys
11+
from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body
1112
from test.unit.aio.mock_utils import mock_connection
1213
from unittest.mock import Mock, PropertyMock
1314

@@ -340,3 +341,21 @@ def test_mro():
340341
assert AuthByDefault.mro().index(AuthByPluginAsync) < AuthByDefault.mro().index(
341342
AuthByPluginSync
342343
)
344+
345+
346+
async def test_auth_by_default_prepare_body_does_not_overwrite_client_environment_fields():
347+
password = "testpassword"
348+
auth_class = AuthByDefault(password)
349+
350+
req_body_before = create_mock_auth_body()
351+
req_body_after = await apply_auth_class_update_body_async(
352+
auth_class, req_body_before
353+
)
354+
355+
assert all(
356+
[
357+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
358+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
359+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
360+
]
361+
)

test/unit/aio/test_auth_keypair_async.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from __future__ import annotations
77

8+
from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body
89
from test.unit.aio.mock_utils import mock_connection
910
from unittest.mock import Mock, PropertyMock, patch
1011

@@ -61,6 +62,24 @@ async def test_auth_keypair(authenticator):
6162
assert rest.master_token == "MASTER_TOKEN"
6263

6364

65+
async def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
66+
private_key_der, _ = generate_key_pair(2048)
67+
auth_class = AuthByKeyPair(private_key=private_key_der)
68+
69+
req_body_before = create_mock_auth_body()
70+
req_body_after = await apply_auth_class_update_body_async(
71+
auth_class, req_body_before
72+
)
73+
74+
assert all(
75+
[
76+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
77+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
78+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
79+
]
80+
)
81+
82+
6483
async def test_auth_keypair_abc():
6584
"""Simple Key Pair test using abstraction layer."""
6685
private_key_der, public_key_der_encoded = generate_key_pair(2048)

test/unit/aio/test_auth_oauth_async.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from __future__ import annotations
77

8+
from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body
9+
810
import pytest
911

1012
from snowflake.connector.aio.auth import AuthByOAuth
@@ -20,6 +22,24 @@ async def test_auth_oauth():
2022
assert body["data"]["AUTHENTICATOR"] == "OAUTH", body
2123

2224

25+
async def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
26+
token = "oAuthToken"
27+
auth_class = AuthByOAuth(token)
28+
29+
req_body_before = create_mock_auth_body()
30+
req_body_after = await apply_auth_class_update_body_async(
31+
auth_class, req_body_before
32+
)
33+
34+
assert all(
35+
[
36+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
37+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
38+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
39+
]
40+
)
41+
42+
2343
@pytest.mark.parametrize("authenticator", ["oauth", "OAUTH"])
2444
async def test_oauth_authenticator_is_case_insensitive(monkeypatch, authenticator):
2545
"""Test that oauth authenticator is case insensitive."""

test/unit/aio/test_auth_oauth_auth_code_async.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#
55

66
import unittest.mock as mock
7+
from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body
78
from unittest.mock import patch
89

910
import pytest
@@ -44,6 +45,34 @@ async def test_auth_oauth_auth_code_oauth_type(omit_oauth_urls_check):
4445
)
4546

4647

48+
async def test_auth_prepare_body_does_not_overwrite_client_environment_fields(
49+
omit_oauth_urls_check,
50+
):
51+
auth_class = AuthByOauthCode(
52+
"app",
53+
"clientId",
54+
"clientSecret",
55+
"auth_url",
56+
"tokenRequestUrl",
57+
"redirectUri:{port}",
58+
"scope",
59+
"host",
60+
)
61+
62+
req_body_before = create_mock_auth_body()
63+
req_body_after = await apply_auth_class_update_body_async(
64+
auth_class, req_body_before
65+
)
66+
67+
assert all(
68+
[
69+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
70+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
71+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
72+
]
73+
)
74+
75+
4776
@pytest.mark.parametrize("rtr_enabled", [True, False])
4877
async def test_auth_oauth_auth_code_single_use_refresh_tokens(
4978
rtr_enabled: bool, omit_oauth_urls_check

test/unit/aio/test_auth_oauth_credentials_async.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from __future__ import annotations
77

8+
from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body
9+
810
import pytest
911

1012
from snowflake.connector.aio.auth import AuthByOauthCredentials
@@ -27,6 +29,29 @@ async def test_auth_oauth_credentials_oauth_type():
2729
)
2830

2931

32+
async def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
33+
auth_class = AuthByOauthCredentials(
34+
"app",
35+
"clientId",
36+
"clientSecret",
37+
"https://example.com/oauth/token",
38+
"scope",
39+
)
40+
41+
req_body_before = create_mock_auth_body()
42+
req_body_after = await apply_auth_class_update_body_async(
43+
auth_class, req_body_before
44+
)
45+
46+
assert all(
47+
[
48+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
49+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
50+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
51+
]
52+
)
53+
54+
3055
@pytest.mark.parametrize(
3156
"authenticator", ["OAUTH_CLIENT_CREDENTIALS", "oauth_client_credentials"]
3257
)

test/unit/aio/test_auth_okta_async.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from __future__ import annotations
77

88
import logging
9+
from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body
910
from test.unit.aio.mock_utils import mock_connection
1011
from unittest.mock import MagicMock, Mock, PropertyMock, patch
1112

@@ -18,6 +19,24 @@
1819
from snowflake.connector.description import CLIENT_NAME, CLIENT_VERSION
1920

2021

22+
async def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
23+
application = "testapplication"
24+
auth_class = AuthByOkta(application)
25+
26+
req_body_before = create_mock_auth_body()
27+
req_body_after = await apply_auth_class_update_body_async(
28+
auth_class, req_body_before
29+
)
30+
31+
assert all(
32+
[
33+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
34+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
35+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
36+
]
37+
)
38+
39+
2140
async def test_auth_okta():
2241
"""Authentication by OKTA positive test case."""
2342
authenticator = "https://testsso.snowflake.net/"

test/unit/aio/test_auth_pat_async.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from __future__ import annotations
77

8+
from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body
9+
810
import pytest
911

1012
from snowflake.connector.aio.auth import AuthByPAT
@@ -27,6 +29,24 @@ async def test_auth_pat():
2729
assert auth.assertion_content is None
2830

2931

32+
async def test_pat_prepare_body_does_not_overwrite_client_environment_fields():
33+
token = "patToken"
34+
auth_class = AuthByPAT(token)
35+
36+
req_body_before = create_mock_auth_body()
37+
req_body_after = await apply_auth_class_update_body_async(
38+
auth_class, req_body_before
39+
)
40+
41+
assert all(
42+
[
43+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
44+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
45+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
46+
]
47+
)
48+
49+
3050
async def test_auth_pat_reauthenticate():
3151
"""Test PAT reauthenticate."""
3252
token = "patToken"

test/unit/aio/test_auth_webbrowser_async.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import asyncio
99
import base64
1010
import socket
11+
from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body
1112
from test.unit.aio.mock_utils import mock_connection
1213
from unittest import mock
1314
from unittest.mock import MagicMock, Mock, PropertyMock, patch
@@ -918,6 +919,22 @@ async def mock_webbrowser_auth_prepare(
918919
await conn.close()
919920

920921

922+
async def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
923+
auth_class = AuthByWebBrowser(application=APPLICATION)
924+
req_body_before = create_mock_auth_body()
925+
req_body_after = await apply_auth_class_update_body_async(
926+
auth_class, req_body_before
927+
)
928+
929+
assert all(
930+
[
931+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
932+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
933+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
934+
]
935+
)
936+
937+
921938
def test_mro():
922939
"""Ensure that methods from AuthByPluginAsync override those from AuthByPlugin."""
923940
from snowflake.connector.aio.auth import AuthByPlugin as AuthByPluginAsync

test/unit/aio/test_auth_workload_identity_async.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,15 @@
1515
import jwt
1616
import pytest
1717

18-
from snowflake.connector.aio._wif_util import AttestationProvider
18+
from snowflake.connector.aio._wif_util import (
19+
AttestationProvider,
20+
WorkloadIdentityAttestation,
21+
)
1922
from snowflake.connector.aio.auth import AuthByWorkloadIdentity
2023
from snowflake.connector.errors import ProgrammingError
2124

2225
from ...csp_helpers import gen_dummy_access_token, gen_dummy_id_token
26+
from ...helpers import apply_auth_class_update_body_async, create_mock_auth_body
2327
from .csp_helpers_async import FakeAwsEnvironmentAsync, FakeGceMetadataServiceAsync
2428

2529
logger = logging.getLogger(__name__)
@@ -138,6 +142,42 @@ async def mock_post(*args, **kwargs):
138142
await connection.close()
139143

140144

145+
@pytest.mark.parametrize(
146+
"provider,additional_args",
147+
[
148+
(AttestationProvider.AWS, {}),
149+
(AttestationProvider.GCP, {}),
150+
(AttestationProvider.AZURE, {}),
151+
(
152+
AttestationProvider.OIDC,
153+
{"token": gen_dummy_id_token(sub="service-1", iss="issuer-1")},
154+
),
155+
],
156+
)
157+
async def test_auth_prepare_body_does_not_overwrite_client_environment_fields(
158+
provider, additional_args
159+
):
160+
auth_class = AuthByWorkloadIdentity(provider=provider, **additional_args)
161+
auth_class.attestation = WorkloadIdentityAttestation(
162+
provider=AttestationProvider.GCP,
163+
credential=None,
164+
user_identifier_components=None,
165+
)
166+
167+
req_body_before = create_mock_auth_body()
168+
req_body_after = await apply_auth_class_update_body_async(
169+
auth_class, req_body_before
170+
)
171+
172+
assert all(
173+
[
174+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
175+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
176+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
177+
]
178+
)
179+
180+
141181
# -- OIDC Tests --
142182

143183

@@ -152,6 +192,7 @@ async def test_explicit_oidc_valid_inline_token_plumbed_to_api():
152192
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
153193
"PROVIDER": "OIDC",
154194
"TOKEN": dummy_token,
195+
"CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0},
155196
}
156197

157198

@@ -209,6 +250,9 @@ async def test_explicit_aws_encodes_audience_host_signature_to_api(
209250
data = await extract_api_data(auth_class)
210251
assert data["AUTHENTICATOR"] == "WORKLOAD_IDENTITY"
211252
assert data["PROVIDER"] == "AWS"
253+
assert (
254+
data["CLIENT_ENVIRONMENT"]["WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH"] == 0
255+
)
212256
verify_aws_token(data["TOKEN"], fake_aws_environment.region)
213257

214258

@@ -310,6 +354,7 @@ async def test_explicit_gcp_plumbs_token_to_api(
310354
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
311355
"PROVIDER": "GCP",
312356
"TOKEN": fake_gce_metadata_service.token,
357+
"CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0},
313358
}
314359

315360

@@ -366,6 +411,7 @@ def __init__(self, content):
366411
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
367412
"PROVIDER": "GCP",
368413
"TOKEN": sa3_id_token,
414+
"CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 2},
369415
}
370416

371417

@@ -420,6 +466,7 @@ async def test_explicit_azure_plumbs_token_to_api(fake_azure_metadata_service):
420466
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
421467
"PROVIDER": "AZURE",
422468
"TOKEN": fake_azure_metadata_service.token,
469+
"CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0},
423470
}
424471

425472

0 commit comments

Comments
 (0)