Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions test/unit/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
#!/usr/bin/env python
from __future__ import annotations

import copy
import inspect
import sys
import time
from typing import Optional, get_type_hints
from unittest.mock import Mock, PropertyMock

import pytest

import snowflake.connector.errors
from snowflake.connector.compat import IS_WINDOWS
from snowflake.connector.connection import SnowflakeConnection
from snowflake.connector.constants import OCSPMode
from snowflake.connector.description import CLIENT_NAME, CLIENT_VERSION
from snowflake.connector.network import SnowflakeRestful
from snowflake.connector.wif_util import (
AttestationProvider,
WorkloadIdentityAttestation,
)

from .mock_utils import mock_connection

Expand Down Expand Up @@ -140,6 +147,47 @@ def _mock_auth_mfa_rest_response_timeout(url, headers, body, **kwargs):
return ret


def _get_most_derived_subclasses(cls):
subclasses = cls.__subclasses__()
if not subclasses:
return [cls]
most_derived = []
for subclass in subclasses:
most_derived.extend(_get_most_derived_subclasses(subclass))
return most_derived


def _get_default_args_for_class(cls):
def _get_default_arg_for_type(t, name):
if getattr(t, "__origin__", None) is Optional:
return None
if t is str:
if "url" in name or "uri" in name:
return "https://example.com"
return name
if t is int:
return 0
if t is bool:
return False
if t is float:
return 0.0
if t is AttestationProvider:
return AttestationProvider.GCP
return None

sig = inspect.signature(cls.__init__)
type_hints = get_type_hints(
cls.__init__, localns={"SnowflakeConnection": SnowflakeConnection}
)

args = {}
for param in sig.parameters.values():
if param.name != "self":
param_type = type_hints.get(param.name, str)
args[param.name] = _get_default_arg_for_type(param_type, param.name)
return args


@pytest.mark.skipif(
IS_WINDOWS,
reason="There are consistent race condition issues with the global mock_cnt used for this test on windows",
Expand Down Expand Up @@ -337,3 +385,51 @@ def test_authbyplugin_abc_api():
'password': <Parameter "password: 'str'">, \
'kwargs': <Parameter "**kwargs: 'Any'">})"""
)


@pytest.mark.skipif(
sys.version_info < (3, 10),
reason="Typing using '|' requires python 3.10 or higher (PEP 604)",
)
@pytest.mark.parametrize("auth_method", _get_most_derived_subclasses(AuthByPlugin))
def test_auth_prepare_body_does_not_overwrite_fields(auth_method):
ocsp_mode = Mock()
ocsp_mode.name = "ocsp_mode"
session_manager = Mock()
session_manager.clone = lambda max_retries: "session_manager"

req_body_before = Auth.base_auth_data(
"user",
"account",
"application",
"internal_application_name",
"internal_application_version",
ocsp_mode,
login_timeout=60 * 60,
network_timeout=60 * 60,
socket_timeout=60 * 60,
platform_detection_timeout_seconds=0.2,
session_manager=session_manager,
)
req_body_after = copy.deepcopy(req_body_before)
additional_args = _get_default_args_for_class(auth_method)
auth_class = auth_method(**additional_args)
auth_class.attestation = WorkloadIdentityAttestation(
provider=AttestationProvider.GCP,
credential=None,
user_identifier_components=None,
)
auth_class.update_body(req_body_after)

# Check that the values in the body before are a strict subset of the values in the body after.
# Must use all() for this comparison because lists are not hashable
assert all(
[
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
]
)
req_body_before["data"].pop("CLIENT_ENVIRONMENT")
req_body_after["data"].pop("CLIENT_ENVIRONMENT")
assert set(req_body_before["data"].items()) <= set(req_body_after["data"].items())
61 changes: 2 additions & 59 deletions test/unit/test_auth_workload_identity.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import json
import logging
from base64 import b64decode
Expand All @@ -8,18 +7,14 @@
import jwt
import pytest

from snowflake.connector.auth import Auth, AuthByWorkloadIdentity
from snowflake.connector.auth import AuthByWorkloadIdentity
from snowflake.connector.errors import ProgrammingError
from snowflake.connector.vendored.requests.exceptions import (
ConnectTimeout,
HTTPError,
Timeout,
)
from snowflake.connector.wif_util import (
AttestationProvider,
WorkloadIdentityAttestation,
get_aws_sts_hostname,
)
from snowflake.connector.wif_util import AttestationProvider, get_aws_sts_hostname

from ..csp_helpers import (
FakeAwsEnvironment,
Expand Down Expand Up @@ -127,58 +122,6 @@ def test_wif_authenticator_is_case_insensitive(
assert isinstance(connection.auth_class, AuthByWorkloadIdentity)


@pytest.mark.parametrize(
"provider,additional_args",
[
(AttestationProvider.AWS, {}),
(AttestationProvider.GCP, {}),
(AttestationProvider.AZURE, {}),
(
AttestationProvider.OIDC,
{"token": gen_dummy_id_token(sub="service-1", iss="issuer-1")},
),
],
)
def test_wif_prepare_body_does_not_overwrite_fields(provider, additional_args):
ocsp_mode = mock.Mock()
ocsp_mode.name = "ocsp_mode"
session_manager = mock.Mock()
session_manager.clone = lambda max_retries: "session_manager"

req_body_before = Auth.base_auth_data(
"user",
"account",
"application",
"internal_application_name",
"internal_application_version",
ocsp_mode,
login_timeout=60 * 60,
network_timeout=60 * 60,
socket_timeout=60 * 60,
platform_detection_timeout_seconds=0.2,
session_manager=session_manager,
)
req_body_after = copy.deepcopy(req_body_before)
auth_class = AuthByWorkloadIdentity(provider=provider, **additional_args)
auth_class.attestation = WorkloadIdentityAttestation(
provider=provider, credential=None, user_identifier_components=None
)
auth_class.update_body(req_body_after)

# Check that the values in the body before are a strict subset of the values in the body after.
# Must use all() for this comparison because lists are not hashable
assert all(
[
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
]
)
req_body_before["data"].pop("CLIENT_ENVIRONMENT")
req_body_after["data"].pop("CLIENT_ENVIRONMENT")
assert set(req_body_before["data"].items()) <= set(req_body_after["data"].items())


# -- OIDC Tests --


Expand Down
Loading