|
2 | 2 | # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
|
3 | 3 | #
|
4 | 4 |
|
| 5 | +import asyncio |
5 | 6 | import json
|
6 | 7 | import logging
|
7 | 8 | from base64 import b64decode
|
8 | 9 | from unittest import mock
|
9 | 10 | from urllib.parse import parse_qs, urlparse
|
10 | 11 |
|
| 12 | +import aiohttp |
11 | 13 | import jwt
|
12 | 14 | import pytest
|
13 | 15 |
|
14 | 16 | from snowflake.connector.aio._wif_util import AttestationProvider
|
15 | 17 | from snowflake.connector.aio.auth import AuthByWorkloadIdentity
|
16 | 18 | from snowflake.connector.errors import ProgrammingError
|
17 |
| -from snowflake.connector.network import WORKLOAD_IDENTITY_AUTHENTICATOR |
18 |
| -from snowflake.connector.vendored.requests.exceptions import ( |
19 |
| - ConnectTimeout, |
20 |
| - HTTPError, |
21 |
| - Timeout, |
22 |
| -) |
23 | 19 |
|
24 | 20 | from ...csp_helpers import (
|
25 | 21 | FakeAwsEnvironment,
|
@@ -170,19 +166,36 @@ async def test_explicit_aws_generates_unique_assertion_content(
|
170 | 166 | # -- GCP Tests --
|
171 | 167 |
|
172 | 168 |
|
| 169 | +def _mock_aiohttp_exception(exception): |
| 170 | + class MockResponse: |
| 171 | + def __init__(self, exception): |
| 172 | + self.exception = exception |
| 173 | + |
| 174 | + async def __aenter__(self): |
| 175 | + raise self.exception |
| 176 | + |
| 177 | + async def __aexit__(self, exc_type, exc_val, exc_tb): |
| 178 | + pass |
| 179 | + |
| 180 | + def mock_request(*args, **kwargs): |
| 181 | + return MockResponse(exception) |
| 182 | + |
| 183 | + return mock_request |
| 184 | + |
| 185 | + |
173 | 186 | @pytest.mark.parametrize(
|
174 | 187 | "exception",
|
175 | 188 | [
|
176 |
| - HTTPError(), |
177 |
| - Timeout(), |
178 |
| - ConnectTimeout(), |
| 189 | + aiohttp.ClientError(), |
| 190 | + asyncio.TimeoutError(), |
179 | 191 | ],
|
180 | 192 | )
|
181 | 193 | async def test_explicit_gcp_metadata_server_error_raises_auth_error(exception):
|
182 | 194 | auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.GCP)
|
183 |
| - with mock.patch( |
184 |
| - "snowflake.connector.vendored.requests.request", side_effect=exception |
185 |
| - ): |
| 195 | + |
| 196 | + mock_request = _mock_aiohttp_exception(exception) |
| 197 | + |
| 198 | + with mock.patch("aiohttp.ClientSession.request", side_effect=mock_request): |
186 | 199 | with pytest.raises(ProgrammingError) as excinfo:
|
187 | 200 | await auth_class.prepare()
|
188 | 201 | assert "No workload identity credential was found for 'GCP'" in str(
|
@@ -231,16 +244,16 @@ async def test_explicit_gcp_generates_unique_assertion_content(
|
231 | 244 | @pytest.mark.parametrize(
|
232 | 245 | "exception",
|
233 | 246 | [
|
234 |
| - HTTPError(), |
235 |
| - Timeout(), |
236 |
| - ConnectTimeout(), |
| 247 | + aiohttp.ClientError(), |
| 248 | + asyncio.TimeoutError(), |
237 | 249 | ],
|
238 | 250 | )
|
239 | 251 | async def test_explicit_azure_metadata_server_error_raises_auth_error(exception):
|
240 | 252 | auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE)
|
241 |
| - with mock.patch( |
242 |
| - "snowflake.connector.vendored.requests.request", side_effect=exception |
243 |
| - ): |
| 253 | + |
| 254 | + mock_request = _mock_aiohttp_exception(exception) |
| 255 | + |
| 256 | + with mock.patch("aiohttp.ClientSession.request", side_effect=mock_request): |
244 | 257 | with pytest.raises(ProgrammingError) as excinfo:
|
245 | 258 | await auth_class.prepare()
|
246 | 259 | assert "No workload identity credential was found for 'AZURE'" in str(
|
@@ -367,65 +380,3 @@ async def test_autodetect_no_provider_raises_error(no_metadata_service):
|
367 | 380 | assert "No workload identity credential was found for 'auto-detect" in str(
|
368 | 381 | excinfo.value
|
369 | 382 | )
|
370 |
| - |
371 |
| - |
372 |
| -async def test_workload_identity_authenticator_creates_auth_by_workload_identity( |
373 |
| - monkeypatch, |
374 |
| -): |
375 |
| - """Test that using WORKLOAD_IDENTITY authenticator creates AuthByWorkloadIdentity instance.""" |
376 |
| - import snowflake.connector.aio |
377 |
| - from snowflake.connector.aio._network import SnowflakeRestful |
378 |
| - |
379 |
| - # Mock the network request - this prevents actual network calls and connection errors |
380 |
| - async def mock_post_request(request, url, headers, json_body, **kwargs): |
381 |
| - return { |
382 |
| - "success": True, |
383 |
| - "message": None, |
384 |
| - "data": { |
385 |
| - "token": "TOKEN", |
386 |
| - "masterToken": "MASTER_TOKEN", |
387 |
| - "idToken": None, |
388 |
| - "parameters": [{"name": "SERVICE_NAME", "value": "FAKE_SERVICE_NAME"}], |
389 |
| - }, |
390 |
| - } |
391 |
| - |
392 |
| - # Apply the mock using monkeypatch |
393 |
| - monkeypatch.setattr(SnowflakeRestful, "_post_request", mock_post_request) |
394 |
| - |
395 |
| - # Set the experimental authentication environment variable |
396 |
| - monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true") |
397 |
| - |
398 |
| - # Mock the workload identity preparation to avoid actual credential fetching |
399 |
| - async def mock_prepare(self, **kwargs): |
400 |
| - # Create a mock attestation to avoid None errors |
401 |
| - from snowflake.connector.wif_util import WorkloadIdentityAttestation |
402 |
| - |
403 |
| - self.attestation = WorkloadIdentityAttestation( |
404 |
| - provider=AttestationProvider.AWS, |
405 |
| - credential="mock_credential", |
406 |
| - user_identifier_components={"arn": "mock_arn"}, |
407 |
| - ) |
408 |
| - |
409 |
| - async def mock_update_body(self, body): |
410 |
| - # Simple mock that just adds the basic fields to avoid actual token processing |
411 |
| - body["data"]["AUTHENTICATOR"] = "WORKLOAD_IDENTITY" |
412 |
| - body["data"]["PROVIDER"] = "AWS" |
413 |
| - body["data"]["TOKEN"] = "mock_token" |
414 |
| - |
415 |
| - monkeypatch.setattr(AuthByWorkloadIdentity, "prepare", mock_prepare) |
416 |
| - monkeypatch.setattr(AuthByWorkloadIdentity, "update_body", mock_update_body) |
417 |
| - |
418 |
| - # Create connection with WORKLOAD_IDENTITY authenticator |
419 |
| - conn = snowflake.connector.aio.SnowflakeConnection( |
420 |
| - account="account", |
421 |
| - authenticator=WORKLOAD_IDENTITY_AUTHENTICATOR, |
422 |
| - workload_identity_provider=AttestationProvider.AWS, |
423 |
| - token="test_token", |
424 |
| - ) |
425 |
| - |
426 |
| - await conn.connect() |
427 |
| - |
428 |
| - # Verify that the auth_class is an instance of AuthByWorkloadIdentity |
429 |
| - assert isinstance(conn.auth_class, AuthByWorkloadIdentity) |
430 |
| - |
431 |
| - await conn.close() |
0 commit comments