|
1 | 1 | #!/usr/bin/env python |
2 | 2 | from __future__ import annotations |
3 | 3 |
|
| 4 | +import copy |
4 | 5 | import inspect |
5 | 6 | import sys |
6 | 7 | import time |
| 8 | +from typing import Optional, get_type_hints |
7 | 9 | from unittest.mock import Mock, PropertyMock |
8 | 10 |
|
9 | 11 | import pytest |
10 | 12 |
|
11 | 13 | import snowflake.connector.errors |
12 | 14 | from snowflake.connector.compat import IS_WINDOWS |
| 15 | +from snowflake.connector.connection import SnowflakeConnection |
13 | 16 | from snowflake.connector.constants import OCSPMode |
14 | 17 | from snowflake.connector.description import CLIENT_NAME, CLIENT_VERSION |
15 | 18 | from snowflake.connector.network import SnowflakeRestful |
| 19 | +from snowflake.connector.wif_util import ( |
| 20 | + AttestationProvider, |
| 21 | + WorkloadIdentityAttestation, |
| 22 | +) |
16 | 23 |
|
17 | 24 | from .mock_utils import mock_connection |
18 | 25 |
|
@@ -140,6 +147,47 @@ def _mock_auth_mfa_rest_response_timeout(url, headers, body, **kwargs): |
140 | 147 | return ret |
141 | 148 |
|
142 | 149 |
|
| 150 | +def _get_most_derived_subclasses(cls): |
| 151 | + subclasses = cls.__subclasses__() |
| 152 | + if not subclasses: |
| 153 | + return [cls] |
| 154 | + most_derived = [] |
| 155 | + for subclass in subclasses: |
| 156 | + most_derived.extend(_get_most_derived_subclasses(subclass)) |
| 157 | + return most_derived |
| 158 | + |
| 159 | + |
| 160 | +def _get_default_args_for_class(cls): |
| 161 | + def _get_default_arg_for_type(t, name): |
| 162 | + if getattr(t, "__origin__", None) is Optional: |
| 163 | + return None |
| 164 | + if t is str: |
| 165 | + if "url" in name or "uri" in name: |
| 166 | + return "https://example.com" |
| 167 | + return name |
| 168 | + if t is int: |
| 169 | + return 0 |
| 170 | + if t is bool: |
| 171 | + return False |
| 172 | + if t is float: |
| 173 | + return 0.0 |
| 174 | + if t is AttestationProvider: |
| 175 | + return AttestationProvider.GCP |
| 176 | + return None |
| 177 | + |
| 178 | + sig = inspect.signature(cls.__init__) |
| 179 | + type_hints = get_type_hints( |
| 180 | + cls.__init__, localns={"SnowflakeConnection": SnowflakeConnection} |
| 181 | + ) |
| 182 | + |
| 183 | + args = {} |
| 184 | + for param in sig.parameters.values(): |
| 185 | + if param.name != "self": |
| 186 | + param_type = type_hints.get(param.name, str) |
| 187 | + args[param.name] = _get_default_arg_for_type(param_type, param.name) |
| 188 | + return args |
| 189 | + |
| 190 | + |
143 | 191 | @pytest.mark.skipif( |
144 | 192 | IS_WINDOWS, |
145 | 193 | reason="There are consistent race condition issues with the global mock_cnt used for this test on windows", |
@@ -337,3 +385,51 @@ def test_authbyplugin_abc_api(): |
337 | 385 | 'password': <Parameter "password: 'str'">, \ |
338 | 386 | 'kwargs': <Parameter "**kwargs: 'Any'">})""" |
339 | 387 | ) |
| 388 | + |
| 389 | + |
| 390 | +@pytest.mark.skipif( |
| 391 | + sys.version_info < (3, 10), |
| 392 | + reason="Typing using '|' requires python 3.10 or higher (PEP 604)", |
| 393 | +) |
| 394 | +@pytest.mark.parametrize("auth_method", _get_most_derived_subclasses(AuthByPlugin)) |
| 395 | +def test_auth_prepare_body_does_not_overwrite_fields(auth_method): |
| 396 | + ocsp_mode = Mock() |
| 397 | + ocsp_mode.name = "ocsp_mode" |
| 398 | + session_manager = Mock() |
| 399 | + session_manager.clone = lambda max_retries: "session_manager" |
| 400 | + |
| 401 | + req_body_before = Auth.base_auth_data( |
| 402 | + "user", |
| 403 | + "account", |
| 404 | + "application", |
| 405 | + "internal_application_name", |
| 406 | + "internal_application_version", |
| 407 | + ocsp_mode, |
| 408 | + login_timeout=60 * 60, |
| 409 | + network_timeout=60 * 60, |
| 410 | + socket_timeout=60 * 60, |
| 411 | + platform_detection_timeout_seconds=0.2, |
| 412 | + session_manager=session_manager, |
| 413 | + ) |
| 414 | + req_body_after = copy.deepcopy(req_body_before) |
| 415 | + additional_args = _get_default_args_for_class(auth_method) |
| 416 | + auth_class = auth_method(**additional_args) |
| 417 | + auth_class.attestation = WorkloadIdentityAttestation( |
| 418 | + provider=AttestationProvider.GCP, |
| 419 | + credential=None, |
| 420 | + user_identifier_components=None, |
| 421 | + ) |
| 422 | + auth_class.update_body(req_body_after) |
| 423 | + |
| 424 | + # Check that the values in the body before are a strict subset of the values in the body after. |
| 425 | + # Must use all() for this comparison because lists are not hashable |
| 426 | + assert all( |
| 427 | + [ |
| 428 | + req_body_before["data"]["CLIENT_ENVIRONMENT"][k] |
| 429 | + == req_body_after["data"]["CLIENT_ENVIRONMENT"][k] |
| 430 | + for k in req_body_before["data"]["CLIENT_ENVIRONMENT"] |
| 431 | + ] |
| 432 | + ) |
| 433 | + req_body_before["data"].pop("CLIENT_ENVIRONMENT") |
| 434 | + req_body_after["data"].pop("CLIENT_ENVIRONMENT") |
| 435 | + assert set(req_body_before["data"].items()) <= set(req_body_after["data"].items()) |
0 commit comments