Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/snowflake/connector/auth/workload_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def update_body(self, body: dict[typing.Any, typing.Any]) -> None:
self.attestation
).value
body["data"]["TOKEN"] = self.attestation.credential
body["data"].setdefault("CLIENT_ENVIRONMENT", {})[
"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH"
] = len(self.impersonation_path or [])

def prepare(
self, *, conn: SnowflakeConnection | None, **kwargs: typing.Any
Expand Down
29 changes: 29 additions & 0 deletions test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import base64
import copy
import math
import os
import random
Expand All @@ -12,6 +13,7 @@

import pytest

from snowflake.connector.auth._auth import Auth
from snowflake.connector.compat import OK

if TYPE_CHECKING:
Expand Down Expand Up @@ -260,3 +262,30 @@ def _arrow_error_stream_random_input_test(use_table_iterator):
# error instance users get should be the same
assert len(exception_result)
assert len(result_array) == 0


def create_mock_auth_body():
ocsp_mode = Mock()
ocsp_mode.name = "ocsp_mode"
session_manager = Mock()
session_manager.clone = lambda max_retries: "session_manager"

return Auth.base_auth_data(
"user",
"account",
"application",
"internal_application_name",
"internal_application_version",
ocsp_mode,
login_timeout=60 * 60,
network_timeout=60 * 60,
socket_timeout=60 * 60,
platform_detection_timeout_seconds=0.2,
session_manager=session_manager,
)


def apply_auth_class_update_body(auth_class, req_body_before):
req_body_after = copy.deepcopy(req_body_before)
auth_class.update_body(req_body_after)
return req_body_after
17 changes: 17 additions & 0 deletions test/unit/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import inspect
import sys
import time
from test.helpers import apply_auth_class_update_body, create_mock_auth_body
from unittest.mock import Mock, PropertyMock

import pytest
Expand Down Expand Up @@ -337,3 +338,19 @@ def test_authbyplugin_abc_api():
'password': <Parameter "password: 'str'">, \
'kwargs': <Parameter "**kwargs: 'Any'">})"""
)


def test_auth_by_default_prepare_body_does_not_overwrite_client_environment_fields():
password = "testpassword"
auth_class = AuthByDefault(password)

req_body_before = create_mock_auth_body()
req_body_after = apply_auth_class_update_body(auth_class, req_body_before)

assert all(
[
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
]
)
17 changes: 17 additions & 0 deletions test/unit/test_auth_keypair.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python
from __future__ import annotations

from test.helpers import apply_auth_class_update_body, create_mock_auth_body
from unittest.mock import Mock, PropertyMock, patch

import pytest
Expand Down Expand Up @@ -63,6 +64,22 @@ def test_auth_keypair(authenticator):
assert rest.master_token == "MASTER_TOKEN"


def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
private_key_der, _ = generate_key_pair(2048)
auth_class = AuthByKeyPair(private_key=private_key_der)

req_body_before = create_mock_auth_body()
req_body_after = apply_auth_class_update_body(auth_class, req_body_before)

assert all(
[
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
]
)


def test_auth_keypair_abc():
"""Simple Key Pair test using abstraction layer."""
private_key_der, public_key_der_encoded = generate_key_pair(2048)
Expand Down
18 changes: 18 additions & 0 deletions test/unit/test_auth_oauth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#!/usr/bin/env python
from __future__ import annotations

from test.helpers import apply_auth_class_update_body, create_mock_auth_body

try: # pragma: no cover
from snowflake.connector.auth import AuthByOAuth
except ImportError:
Expand All @@ -18,6 +20,22 @@ def test_auth_oauth():
assert body["data"]["AUTHENTICATOR"] == "OAUTH", body


def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
token = "oAuthToken"
auth_class = AuthByOAuth(token)

req_body_before = create_mock_auth_body()
req_body_after = apply_auth_class_update_body(auth_class, req_body_before)

assert all(
[
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
]
)


@pytest.mark.parametrize("authenticator", ["oauth", "OAUTH"])
def test_oauth_authenticator_is_case_insensitive(monkeypatch, authenticator):
"""Test that oauth authenticator is case insensitive."""
Expand Down
27 changes: 27 additions & 0 deletions test/unit/test_auth_oauth_auth_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#

import unittest.mock as mock
from test.helpers import apply_auth_class_update_body, create_mock_auth_body
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -44,6 +45,32 @@ def test_auth_oauth_auth_code_oauth_type(omit_oauth_urls_check):
)


def test_auth_prepare_body_does_not_overwrite_client_environment_fields(
omit_oauth_urls_check,
):
auth_class = AuthByOauthCode(
"app",
"clientId",
"clientSecret",
"auth_url",
"tokenRequestUrl",
"redirectUri:{port}",
"scope",
"host",
)

req_body_before = create_mock_auth_body()
req_body_after = apply_auth_class_update_body(auth_class, req_body_before)

assert all(
[
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
]
)


@pytest.mark.parametrize("rtr_enabled", [True, False])
def test_auth_oauth_auth_code_single_use_refresh_tokens(
rtr_enabled: bool, omit_oauth_urls_check
Expand Down
23 changes: 23 additions & 0 deletions test/unit/test_auth_oauth_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#


from test.helpers import apply_auth_class_update_body, create_mock_auth_body

import pytest

from snowflake.connector.auth import AuthByOauthCredentials
Expand All @@ -26,6 +28,27 @@ def test_auth_oauth_credentials_oauth_type():
)


def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
auth_class = AuthByOauthCredentials(
"app",
"clientId",
"clientSecret",
"https://example.com/oauth/token",
"scope",
)

req_body_before = create_mock_auth_body()
req_body_after = apply_auth_class_update_body(auth_class, req_body_before)

assert all(
[
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
]
)


@pytest.mark.parametrize(
"authenticator", ["OAUTH_CLIENT_CREDENTIALS", "oauth_client_credentials"]
)
Expand Down
17 changes: 17 additions & 0 deletions test/unit/test_auth_okta.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import logging
from test.helpers import apply_auth_class_update_body, create_mock_auth_body
from unittest.mock import Mock, PropertyMock, patch

import pytest
Expand All @@ -19,6 +20,22 @@
from snowflake.connector.auth_okta import AuthByOkta


def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
application = "testapplication"
auth_class = AuthByOkta(application)

req_body_before = create_mock_auth_body()
req_body_after = apply_auth_class_update_body(auth_class, req_body_before)

assert all(
[
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
]
)


def test_auth_okta():
"""Authentication by OKTA positive test case."""
authenticator = "https://testsso.snowflake.net/"
Expand Down
18 changes: 18 additions & 0 deletions test/unit/test_auth_pat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#
from __future__ import annotations

from test.helpers import apply_auth_class_update_body, create_mock_auth_body

import pytest

from snowflake.connector.auth import AuthByPAT, AuthNoAuth
Expand All @@ -26,6 +28,22 @@ def test_auth_pat():
assert auth.assertion_content is None


def test_pat_prepare_body_does_not_overwrite_client_environment_fields():
token = "patToken"
auth_class = AuthByPAT(token)

req_body_before = create_mock_auth_body()
req_body_after = apply_auth_class_update_body(auth_class, req_body_before)

assert all(
[
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
]
)


def test_auth_pat_reauthenticate():
"""Test PAT reauthenticate."""
token = "patToken"
Expand Down
15 changes: 15 additions & 0 deletions test/unit/test_auth_webbrowser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import base64
import socket
from test.helpers import apply_auth_class_update_body, create_mock_auth_body
from unittest import mock
from unittest.mock import MagicMock, Mock, PropertyMock, patch

Expand Down Expand Up @@ -792,3 +793,17 @@ def mock_webbrowser_auth_prepare(
assert isinstance(conn.auth_class, AuthByWebBrowser)

conn.close()


def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
auth_class = AuthByWebBrowser(application=APPLICATION)
req_body_before = create_mock_auth_body()
req_body_after = apply_auth_class_update_body(auth_class, req_body_before)

assert all(
[
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
]
)
Loading
Loading