Skip to content

Commit 547c4b7

Browse files
authored
Add WIF impersonation path length as data sent to Snowflake backend (#2521)
1 parent ce4790b commit 547c4b7

File tree

11 files changed

+231
-1
lines changed

11 files changed

+231
-1
lines changed

src/snowflake/connector/auth/workload_identity.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ def update_body(self, body: dict[typing.Any, typing.Any]) -> None:
7878
self.attestation
7979
).value
8080
body["data"]["TOKEN"] = self.attestation.credential
81+
body["data"].setdefault("CLIENT_ENVIRONMENT", {})[
82+
"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH"
83+
] = len(self.impersonation_path or [])
8184

8285
def prepare(
8386
self, *, conn: SnowflakeConnection | None, **kwargs: typing.Any

test/helpers.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import annotations
33

44
import base64
5+
import copy
56
import math
67
import os
78
import random
@@ -12,6 +13,7 @@
1213

1314
import pytest
1415

16+
from snowflake.connector.auth._auth import Auth
1517
from snowflake.connector.compat import OK
1618

1719
if TYPE_CHECKING:
@@ -260,3 +262,30 @@ def _arrow_error_stream_random_input_test(use_table_iterator):
260262
# error instance users get should be the same
261263
assert len(exception_result)
262264
assert len(result_array) == 0
265+
266+
267+
def create_mock_auth_body():
268+
ocsp_mode = Mock()
269+
ocsp_mode.name = "ocsp_mode"
270+
session_manager = Mock()
271+
session_manager.clone = lambda max_retries: "session_manager"
272+
273+
return Auth.base_auth_data(
274+
"user",
275+
"account",
276+
"application",
277+
"internal_application_name",
278+
"internal_application_version",
279+
ocsp_mode,
280+
login_timeout=60 * 60,
281+
network_timeout=60 * 60,
282+
socket_timeout=60 * 60,
283+
platform_detection_timeout_seconds=0.2,
284+
session_manager=session_manager,
285+
)
286+
287+
288+
def apply_auth_class_update_body(auth_class, req_body_before):
289+
req_body_after = copy.deepcopy(req_body_before)
290+
auth_class.update_body(req_body_after)
291+
return req_body_after

test/unit/test_auth.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import inspect
55
import sys
66
import time
7+
from test.helpers import apply_auth_class_update_body, create_mock_auth_body
78
from unittest.mock import Mock, PropertyMock
89

910
import pytest
@@ -337,3 +338,19 @@ def test_authbyplugin_abc_api():
337338
'password': <Parameter "password: 'str'">, \
338339
'kwargs': <Parameter "**kwargs: 'Any'">})"""
339340
)
341+
342+
343+
def test_auth_by_default_prepare_body_does_not_overwrite_client_environment_fields():
344+
password = "testpassword"
345+
auth_class = AuthByDefault(password)
346+
347+
req_body_before = create_mock_auth_body()
348+
req_body_after = apply_auth_class_update_body(auth_class, req_body_before)
349+
350+
assert all(
351+
[
352+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
353+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
354+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
355+
]
356+
)

test/unit/test_auth_keypair.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python
22
from __future__ import annotations
33

4+
from test.helpers import apply_auth_class_update_body, create_mock_auth_body
45
from unittest.mock import Mock, PropertyMock, patch
56

67
import pytest
@@ -63,6 +64,22 @@ def test_auth_keypair(authenticator):
6364
assert rest.master_token == "MASTER_TOKEN"
6465

6566

67+
def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
68+
private_key_der, _ = generate_key_pair(2048)
69+
auth_class = AuthByKeyPair(private_key=private_key_der)
70+
71+
req_body_before = create_mock_auth_body()
72+
req_body_after = apply_auth_class_update_body(auth_class, req_body_before)
73+
74+
assert all(
75+
[
76+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
77+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
78+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
79+
]
80+
)
81+
82+
6683
def test_auth_keypair_abc():
6784
"""Simple Key Pair test using abstraction layer."""
6885
private_key_der, public_key_der_encoded = generate_key_pair(2048)

test/unit/test_auth_oauth.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#!/usr/bin/env python
22
from __future__ import annotations
33

4+
from test.helpers import apply_auth_class_update_body, create_mock_auth_body
5+
46
try: # pragma: no cover
57
from snowflake.connector.auth import AuthByOAuth
68
except ImportError:
@@ -18,6 +20,22 @@ def test_auth_oauth():
1820
assert body["data"]["AUTHENTICATOR"] == "OAUTH", body
1921

2022

23+
def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
24+
token = "oAuthToken"
25+
auth_class = AuthByOAuth(token)
26+
27+
req_body_before = create_mock_auth_body()
28+
req_body_after = apply_auth_class_update_body(auth_class, req_body_before)
29+
30+
assert all(
31+
[
32+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
33+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
34+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
35+
]
36+
)
37+
38+
2139
@pytest.mark.parametrize("authenticator", ["oauth", "OAUTH"])
2240
def test_oauth_authenticator_is_case_insensitive(monkeypatch, authenticator):
2341
"""Test that oauth authenticator is case insensitive."""

test/unit/test_auth_oauth_auth_code.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#
55

66
import unittest.mock as mock
7+
from test.helpers import apply_auth_class_update_body, create_mock_auth_body
78
from unittest.mock import patch
89

910
import pytest
@@ -44,6 +45,32 @@ def test_auth_oauth_auth_code_oauth_type(omit_oauth_urls_check):
4445
)
4546

4647

48+
def test_auth_prepare_body_does_not_overwrite_client_environment_fields(
49+
omit_oauth_urls_check,
50+
):
51+
auth_class = AuthByOauthCode(
52+
"app",
53+
"clientId",
54+
"clientSecret",
55+
"auth_url",
56+
"tokenRequestUrl",
57+
"redirectUri:{port}",
58+
"scope",
59+
"host",
60+
)
61+
62+
req_body_before = create_mock_auth_body()
63+
req_body_after = apply_auth_class_update_body(auth_class, req_body_before)
64+
65+
assert all(
66+
[
67+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
68+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
69+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
70+
]
71+
)
72+
73+
4774
@pytest.mark.parametrize("rtr_enabled", [True, False])
4875
def test_auth_oauth_auth_code_single_use_refresh_tokens(
4976
rtr_enabled: bool, omit_oauth_urls_check

test/unit/test_auth_oauth_credentials.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#
55

66

7+
from test.helpers import apply_auth_class_update_body, create_mock_auth_body
8+
79
import pytest
810

911
from snowflake.connector.auth import AuthByOauthCredentials
@@ -26,6 +28,27 @@ def test_auth_oauth_credentials_oauth_type():
2628
)
2729

2830

31+
def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
32+
auth_class = AuthByOauthCredentials(
33+
"app",
34+
"clientId",
35+
"clientSecret",
36+
"https://example.com/oauth/token",
37+
"scope",
38+
)
39+
40+
req_body_before = create_mock_auth_body()
41+
req_body_after = apply_auth_class_update_body(auth_class, req_body_before)
42+
43+
assert all(
44+
[
45+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
46+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
47+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
48+
]
49+
)
50+
51+
2952
@pytest.mark.parametrize(
3053
"authenticator, oauth_credentials_in_body",
3154
[

test/unit/test_auth_okta.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import annotations
33

44
import logging
5+
from test.helpers import apply_auth_class_update_body, create_mock_auth_body
56
from unittest.mock import Mock, PropertyMock, patch
67

78
import pytest
@@ -19,6 +20,22 @@
1920
from snowflake.connector.auth_okta import AuthByOkta
2021

2122

23+
def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
24+
application = "testapplication"
25+
auth_class = AuthByOkta(application)
26+
27+
req_body_before = create_mock_auth_body()
28+
req_body_after = apply_auth_class_update_body(auth_class, req_body_before)
29+
30+
assert all(
31+
[
32+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
33+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
34+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
35+
]
36+
)
37+
38+
2239
def test_auth_okta():
2340
"""Authentication by OKTA positive test case."""
2441
authenticator = "https://testsso.snowflake.net/"

test/unit/test_auth_pat.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#
55
from __future__ import annotations
66

7+
from test.helpers import apply_auth_class_update_body, create_mock_auth_body
8+
79
import pytest
810

911
from snowflake.connector.auth import AuthByPAT, AuthNoAuth
@@ -26,6 +28,22 @@ def test_auth_pat():
2628
assert auth.assertion_content is None
2729

2830

31+
def test_pat_prepare_body_does_not_overwrite_client_environment_fields():
32+
token = "patToken"
33+
auth_class = AuthByPAT(token)
34+
35+
req_body_before = create_mock_auth_body()
36+
req_body_after = apply_auth_class_update_body(auth_class, req_body_before)
37+
38+
assert all(
39+
[
40+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
41+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
42+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
43+
]
44+
)
45+
46+
2947
def test_auth_pat_reauthenticate():
3048
"""Test PAT reauthenticate."""
3149
token = "patToken"

test/unit/test_auth_webbrowser.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import base64
55
import socket
6+
from test.helpers import apply_auth_class_update_body, create_mock_auth_body
67
from unittest import mock
78
from unittest.mock import MagicMock, Mock, PropertyMock, patch
89

@@ -792,3 +793,17 @@ def mock_webbrowser_auth_prepare(
792793
assert isinstance(conn.auth_class, AuthByWebBrowser)
793794

794795
conn.close()
796+
797+
798+
def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
799+
auth_class = AuthByWebBrowser(application=APPLICATION)
800+
req_body_before = create_mock_auth_body()
801+
req_body_after = apply_auth_class_update_body(auth_class, req_body_before)
802+
803+
assert all(
804+
[
805+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
806+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
807+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
808+
]
809+
)

0 commit comments

Comments
 (0)