Skip to content

Commit 30685d5

Browse files
properly mock tests
1 parent 5e5fb5e commit 30685d5

File tree

1 file changed

+31
-80
lines changed

1 file changed

+31
-80
lines changed

test/unit/aio/test_auth_workload_identity_async.py

Lines changed: 31 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,20 @@
22
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
33
#
44

5+
import asyncio
56
import json
67
import logging
78
from base64 import b64decode
89
from unittest import mock
910
from urllib.parse import parse_qs, urlparse
1011

12+
import aiohttp
1113
import jwt
1214
import pytest
1315

1416
from snowflake.connector.aio._wif_util import AttestationProvider
1517
from snowflake.connector.aio.auth import AuthByWorkloadIdentity
1618
from snowflake.connector.errors import ProgrammingError
17-
from snowflake.connector.network import WORKLOAD_IDENTITY_AUTHENTICATOR
18-
from snowflake.connector.vendored.requests.exceptions import (
19-
ConnectTimeout,
20-
HTTPError,
21-
Timeout,
22-
)
2319

2420
from ...csp_helpers import (
2521
FakeAwsEnvironment,
@@ -170,19 +166,36 @@ async def test_explicit_aws_generates_unique_assertion_content(
170166
# -- GCP Tests --
171167

172168

169+
def _mock_aiohttp_exception(exception):
170+
class MockResponse:
171+
def __init__(self, exception):
172+
self.exception = exception
173+
174+
async def __aenter__(self):
175+
raise self.exception
176+
177+
async def __aexit__(self, exc_type, exc_val, exc_tb):
178+
pass
179+
180+
def mock_request(*args, **kwargs):
181+
return MockResponse(exception)
182+
183+
return mock_request
184+
185+
173186
@pytest.mark.parametrize(
174187
"exception",
175188
[
176-
HTTPError(),
177-
Timeout(),
178-
ConnectTimeout(),
189+
aiohttp.ClientError(),
190+
asyncio.TimeoutError(),
179191
],
180192
)
181193
async def test_explicit_gcp_metadata_server_error_raises_auth_error(exception):
182194
auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.GCP)
183-
with mock.patch(
184-
"snowflake.connector.vendored.requests.request", side_effect=exception
185-
):
195+
196+
mock_request = _mock_aiohttp_exception(exception)
197+
198+
with mock.patch("aiohttp.ClientSession.request", side_effect=mock_request):
186199
with pytest.raises(ProgrammingError) as excinfo:
187200
await auth_class.prepare()
188201
assert "No workload identity credential was found for 'GCP'" in str(
@@ -231,16 +244,16 @@ async def test_explicit_gcp_generates_unique_assertion_content(
231244
@pytest.mark.parametrize(
232245
"exception",
233246
[
234-
HTTPError(),
235-
Timeout(),
236-
ConnectTimeout(),
247+
aiohttp.ClientError(),
248+
asyncio.TimeoutError(),
237249
],
238250
)
239251
async def test_explicit_azure_metadata_server_error_raises_auth_error(exception):
240252
auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE)
241-
with mock.patch(
242-
"snowflake.connector.vendored.requests.request", side_effect=exception
243-
):
253+
254+
mock_request = _mock_aiohttp_exception(exception)
255+
256+
with mock.patch("aiohttp.ClientSession.request", side_effect=mock_request):
244257
with pytest.raises(ProgrammingError) as excinfo:
245258
await auth_class.prepare()
246259
assert "No workload identity credential was found for 'AZURE'" in str(
@@ -367,65 +380,3 @@ async def test_autodetect_no_provider_raises_error(no_metadata_service):
367380
assert "No workload identity credential was found for 'auto-detect" in str(
368381
excinfo.value
369382
)
370-
371-
372-
async def test_workload_identity_authenticator_creates_auth_by_workload_identity(
373-
monkeypatch,
374-
):
375-
"""Test that using WORKLOAD_IDENTITY authenticator creates AuthByWorkloadIdentity instance."""
376-
import snowflake.connector.aio
377-
from snowflake.connector.aio._network import SnowflakeRestful
378-
379-
# Mock the network request - this prevents actual network calls and connection errors
380-
async def mock_post_request(request, url, headers, json_body, **kwargs):
381-
return {
382-
"success": True,
383-
"message": None,
384-
"data": {
385-
"token": "TOKEN",
386-
"masterToken": "MASTER_TOKEN",
387-
"idToken": None,
388-
"parameters": [{"name": "SERVICE_NAME", "value": "FAKE_SERVICE_NAME"}],
389-
},
390-
}
391-
392-
# Apply the mock using monkeypatch
393-
monkeypatch.setattr(SnowflakeRestful, "_post_request", mock_post_request)
394-
395-
# Set the experimental authentication environment variable
396-
monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
397-
398-
# Mock the workload identity preparation to avoid actual credential fetching
399-
async def mock_prepare(self, **kwargs):
400-
# Create a mock attestation to avoid None errors
401-
from snowflake.connector.wif_util import WorkloadIdentityAttestation
402-
403-
self.attestation = WorkloadIdentityAttestation(
404-
provider=AttestationProvider.AWS,
405-
credential="mock_credential",
406-
user_identifier_components={"arn": "mock_arn"},
407-
)
408-
409-
async def mock_update_body(self, body):
410-
# Simple mock that just adds the basic fields to avoid actual token processing
411-
body["data"]["AUTHENTICATOR"] = "WORKLOAD_IDENTITY"
412-
body["data"]["PROVIDER"] = "AWS"
413-
body["data"]["TOKEN"] = "mock_token"
414-
415-
monkeypatch.setattr(AuthByWorkloadIdentity, "prepare", mock_prepare)
416-
monkeypatch.setattr(AuthByWorkloadIdentity, "update_body", mock_update_body)
417-
418-
# Create connection with WORKLOAD_IDENTITY authenticator
419-
conn = snowflake.connector.aio.SnowflakeConnection(
420-
account="account",
421-
authenticator=WORKLOAD_IDENTITY_AUTHENTICATOR,
422-
workload_identity_provider=AttestationProvider.AWS,
423-
token="test_token",
424-
)
425-
426-
await conn.connect()
427-
428-
# Verify that the auth_class is an instance of AuthByWorkloadIdentity
429-
assert isinstance(conn.auth_class, AuthByWorkloadIdentity)
430-
431-
await conn.close()

0 commit comments

Comments
 (0)