Skip to content

Commit 7ef858f

Browse files
committed
move unit test to test_auth and test all AuthByPlugin derived classes
1 parent 395928b commit 7ef858f

File tree

2 files changed

+98
-59
lines changed

2 files changed

+98
-59
lines changed

test/unit/test_auth.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
11
#!/usr/bin/env python
22
from __future__ import annotations
33

4+
import copy
45
import inspect
56
import sys
67
import time
8+
from typing import Optional, get_type_hints
79
from unittest.mock import Mock, PropertyMock
810

911
import pytest
1012

1113
import snowflake.connector.errors
1214
from snowflake.connector.compat import IS_WINDOWS
15+
from snowflake.connector.connection import SnowflakeConnection
1316
from snowflake.connector.constants import OCSPMode
1417
from snowflake.connector.description import CLIENT_NAME, CLIENT_VERSION
1518
from snowflake.connector.network import SnowflakeRestful
19+
from snowflake.connector.wif_util import (
20+
AttestationProvider,
21+
WorkloadIdentityAttestation,
22+
)
1623

1724
from .mock_utils import mock_connection
1825

@@ -140,6 +147,47 @@ def _mock_auth_mfa_rest_response_timeout(url, headers, body, **kwargs):
140147
return ret
141148

142149

150+
def _get_most_derived_subclasses(cls):
151+
subclasses = cls.__subclasses__()
152+
if not subclasses:
153+
return [cls]
154+
most_derived = []
155+
for subclass in subclasses:
156+
most_derived.extend(_get_most_derived_subclasses(subclass))
157+
return most_derived
158+
159+
160+
def _get_default_args_for_class(cls):
161+
def _get_default_arg_for_type(t, name):
162+
if getattr(t, "__origin__", None) is Optional:
163+
return None
164+
if t is str:
165+
if "url" in name or "uri" in name:
166+
return "https://example.com"
167+
return name
168+
if t is int:
169+
return 0
170+
if t is bool:
171+
return False
172+
if t is float:
173+
return 0.0
174+
if t is AttestationProvider:
175+
return AttestationProvider.GCP
176+
return None
177+
178+
sig = inspect.signature(cls.__init__)
179+
type_hints = get_type_hints(
180+
cls.__init__, localns={"SnowflakeConnection": SnowflakeConnection}
181+
)
182+
183+
args = {}
184+
for param in sig.parameters.values():
185+
if param.name != "self":
186+
param_type = type_hints.get(param.name, str)
187+
args[param.name] = _get_default_arg_for_type(param_type, param.name)
188+
return args
189+
190+
143191
@pytest.mark.skipif(
144192
IS_WINDOWS,
145193
reason="There are consistent race condition issues with the global mock_cnt used for this test on windows",
@@ -337,3 +385,51 @@ def test_authbyplugin_abc_api():
337385
'password': <Parameter "password: 'str'">, \
338386
'kwargs': <Parameter "**kwargs: 'Any'">})"""
339387
)
388+
389+
390+
@pytest.mark.skipif(
391+
sys.version_info < (3, 10),
392+
reason="Typing using '|' requires python 3.10 or higher (PEP 604)",
393+
)
394+
@pytest.mark.parametrize("auth_method", _get_most_derived_subclasses(AuthByPlugin))
395+
def test_auth_prepare_body_does_not_overwrite_fields(auth_method):
396+
ocsp_mode = Mock()
397+
ocsp_mode.name = "ocsp_mode"
398+
session_manager = Mock()
399+
session_manager.clone = lambda max_retries: "session_manager"
400+
401+
req_body_before = Auth.base_auth_data(
402+
"user",
403+
"account",
404+
"application",
405+
"internal_application_name",
406+
"internal_application_version",
407+
ocsp_mode,
408+
login_timeout=60 * 60,
409+
network_timeout=60 * 60,
410+
socket_timeout=60 * 60,
411+
platform_detection_timeout_seconds=0.2,
412+
session_manager=session_manager,
413+
)
414+
req_body_after = copy.deepcopy(req_body_before)
415+
additional_args = _get_default_args_for_class(auth_method)
416+
auth_class = auth_method(**additional_args)
417+
auth_class.attestation = WorkloadIdentityAttestation(
418+
provider=AttestationProvider.GCP,
419+
credential=None,
420+
user_identifier_components=None,
421+
)
422+
auth_class.update_body(req_body_after)
423+
424+
# Check that the values in the body before are a strict subset of the values in the body after.
425+
# Must use all() for this comparison because lists are not hashable
426+
assert all(
427+
[
428+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
429+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
430+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
431+
]
432+
)
433+
req_body_before["data"].pop("CLIENT_ENVIRONMENT")
434+
req_body_after["data"].pop("CLIENT_ENVIRONMENT")
435+
assert set(req_body_before["data"].items()) <= set(req_body_after["data"].items())

test/unit/test_auth_workload_identity.py

Lines changed: 2 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import copy
21
import json
32
import logging
43
from base64 import b64decode
@@ -8,18 +7,14 @@
87
import jwt
98
import pytest
109

11-
from snowflake.connector.auth import Auth, AuthByWorkloadIdentity
10+
from snowflake.connector.auth import AuthByWorkloadIdentity
1211
from snowflake.connector.errors import ProgrammingError
1312
from snowflake.connector.vendored.requests.exceptions import (
1413
ConnectTimeout,
1514
HTTPError,
1615
Timeout,
1716
)
18-
from snowflake.connector.wif_util import (
19-
AttestationProvider,
20-
WorkloadIdentityAttestation,
21-
get_aws_sts_hostname,
22-
)
17+
from snowflake.connector.wif_util import AttestationProvider, get_aws_sts_hostname
2318

2419
from ..csp_helpers import (
2520
FakeAwsEnvironment,
@@ -127,58 +122,6 @@ def test_wif_authenticator_is_case_insensitive(
127122
assert isinstance(connection.auth_class, AuthByWorkloadIdentity)
128123

129124

130-
@pytest.mark.parametrize(
131-
"provider,additional_args",
132-
[
133-
(AttestationProvider.AWS, {}),
134-
(AttestationProvider.GCP, {}),
135-
(AttestationProvider.AZURE, {}),
136-
(
137-
AttestationProvider.OIDC,
138-
{"token": gen_dummy_id_token(sub="service-1", iss="issuer-1")},
139-
),
140-
],
141-
)
142-
def test_wif_prepare_body_does_not_overwrite_fields(provider, additional_args):
143-
ocsp_mode = mock.Mock()
144-
ocsp_mode.name = "ocsp_mode"
145-
session_manager = mock.Mock()
146-
session_manager.clone = lambda max_retries: "session_manager"
147-
148-
req_body_before = Auth.base_auth_data(
149-
"user",
150-
"account",
151-
"application",
152-
"internal_application_name",
153-
"internal_application_version",
154-
ocsp_mode,
155-
login_timeout=60 * 60,
156-
network_timeout=60 * 60,
157-
socket_timeout=60 * 60,
158-
platform_detection_timeout_seconds=0.2,
159-
session_manager=session_manager,
160-
)
161-
req_body_after = copy.deepcopy(req_body_before)
162-
auth_class = AuthByWorkloadIdentity(provider=provider, **additional_args)
163-
auth_class.attestation = WorkloadIdentityAttestation(
164-
provider=provider, credential=None, user_identifier_components=None
165-
)
166-
auth_class.update_body(req_body_after)
167-
168-
# Check that the values in the body before are a strict subset of the values in the body after.
169-
# Must use all() for this comparison because lists are not hashable
170-
assert all(
171-
[
172-
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
173-
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
174-
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
175-
]
176-
)
177-
req_body_before["data"].pop("CLIENT_ENVIRONMENT")
178-
req_body_after["data"].pop("CLIENT_ENVIRONMENT")
179-
assert set(req_body_before["data"].items()) <= set(req_body_after["data"].items())
180-
181-
182125
# -- OIDC Tests --
183126

184127

0 commit comments

Comments
 (0)