Skip to content

Commit a803c48

Browse files
split csp_helpers into sync and async
1 parent 75470f9 commit a803c48

File tree

4 files changed

+279
-121
lines changed

4 files changed

+279
-121
lines changed

test/csp_helpers.py

Lines changed: 1 addition & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -92,51 +92,6 @@ def __call__(self, method, url, headers, timeout):
9292

9393
return self.handle_request(method, parsed_url, headers, timeout)
9494

95-
def _async_request(self, method, url, headers=None, timeout=None):
96-
"""Entry point for the aiohttp mock."""
97-
logger.debug(f"Received async request: {method} {url} {str(headers)}")
98-
parsed_url = urlparse(url)
99-
100-
# Create async context manager for aiohttp response
101-
class AsyncResponseContextManager:
102-
def __init__(self, response):
103-
self.response = response
104-
105-
async def __aenter__(self):
106-
return self.response
107-
108-
async def __aexit__(self, exc_type, exc_val, exc_tb):
109-
pass
110-
111-
# Create aiohttp-compatible response mock
112-
class AsyncResponse:
113-
def __init__(self, requests_response):
114-
self.ok = requests_response.ok
115-
self.status = requests_response.status_code
116-
self._content = requests_response.content
117-
118-
async def read(self):
119-
return self._content
120-
121-
if not parsed_url.hostname == self.expected_hostname:
122-
logger.debug(
123-
f"Received async request to unexpected hostname {parsed_url.hostname}"
124-
)
125-
import aiohttp
126-
127-
raise aiohttp.ClientError()
128-
129-
# Get the response from the subclass handler, catch exceptions and convert them
130-
try:
131-
sync_response = self.handle_request(method, parsed_url, headers, timeout)
132-
async_response = AsyncResponse(sync_response)
133-
return AsyncResponseContextManager(async_response)
134-
except (HTTPError, ConnectTimeout) as e:
135-
import aiohttp
136-
137-
# Convert requests exceptions to aiohttp exceptions so they get caught properly
138-
raise aiohttp.ClientError() from e
139-
14095
def __enter__(self):
14196
"""Patches the relevant HTTP calls when entering as a context manager."""
14297
self.reset_defaults()
@@ -148,10 +103,7 @@ def __enter__(self):
148103
"snowflake.connector.vendored.requests.request", side_effect=self
149104
)
150105
)
151-
# Mock aiohttp for async requests
152-
self.patchers.append(
153-
mock.patch("aiohttp.ClientSession.request", side_effect=self._async_request)
154-
)
106+
155107
# HTTPConnection.request is used by the AWS boto libraries. We're not mocking those calls here, so we
156108
# simply raise a ConnectTimeout to avoid making real network calls.
157109
self.patchers.append(
@@ -357,64 +309,6 @@ def __enter__(self):
357309
)
358310
)
359311

360-
# Patch async aioboto3 calls (for when aioboto3 is used directly)
361-
async def async_get_credentials():
362-
return self.credentials
363-
364-
async def async_get_caller_identity():
365-
return {"Arn": self.arn}
366-
367-
async def async_get_region():
368-
return self.get_region()
369-
370-
# Mock aioboto3.Session.get_credentials (IS async)
371-
self.patchers.append(
372-
mock.patch(
373-
"snowflake.connector.aio._wif_util.aioboto3.Session.get_credentials",
374-
side_effect=async_get_credentials,
375-
)
376-
)
377-
378-
# Mock the async AWS region and ARN functions
379-
self.patchers.append(
380-
mock.patch(
381-
"snowflake.connector.aio._wif_util.get_aws_region",
382-
side_effect=async_get_region,
383-
)
384-
)
385-
386-
async def async_get_arn():
387-
return self.get_arn()
388-
389-
self.patchers.append(
390-
mock.patch(
391-
"snowflake.connector.aio._wif_util.get_aws_arn",
392-
side_effect=async_get_arn,
393-
)
394-
)
395-
396-
# Mock the async STS client for direct aioboto3 usage
397-
class MockStsClient:
398-
async def __aenter__(self):
399-
return self
400-
401-
async def __aexit__(self, exc_type, exc_val, exc_tb):
402-
pass
403-
404-
async def get_caller_identity(self):
405-
return await async_get_caller_identity()
406-
407-
def mock_session_client(service_name):
408-
if service_name == "sts":
409-
return MockStsClient()
410-
return None
411-
412-
self.patchers.append(
413-
mock.patch(
414-
"snowflake.connector.aio._wif_util.aioboto3.Session.client",
415-
side_effect=mock_session_client,
416-
)
417-
)
418312
for patcher in self.patchers:
419313
patcher.__enter__()
420314
return self

test/unit/aio/conftest.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#
2+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
from __future__ import annotations
6+
7+
import pytest
8+
9+
from .csp_helpers_async import (
10+
FakeAwsEnvironmentAsync,
11+
FakeAzureFunctionMetadataServiceAsync,
12+
FakeAzureVmMetadataServiceAsync,
13+
FakeGceMetadataServiceAsync,
14+
NoMetadataServiceAsync,
15+
)
16+
17+
18+
@pytest.fixture
19+
def no_metadata_service():
20+
"""Emulates an environment without any metadata service."""
21+
with NoMetadataServiceAsync() as server:
22+
yield server
23+
24+
25+
@pytest.fixture
26+
def fake_aws_environment():
27+
with FakeAwsEnvironmentAsync() as env:
28+
yield env
29+
30+
31+
@pytest.fixture(
32+
params=[FakeAzureFunctionMetadataServiceAsync(), FakeAzureVmMetadataServiceAsync()],
33+
ids=["azure_function", "azure_vm"],
34+
)
35+
def fake_azure_metadata_service(request):
36+
"""Parameterized fixture that emulates both the Azure VM and Azure Functions metadata services."""
37+
with request.param as server:
38+
yield server
39+
40+
41+
@pytest.fixture
42+
def fake_gce_metadata_service():
43+
"""Emulates the GCE metadata service, returning a dummy token."""
44+
with FakeGceMetadataServiceAsync() as server:
45+
yield server

test/unit/aio/csp_helpers_async.py

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
#!/usr/bin/env python
2+
#
3+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
4+
#
5+
6+
import logging
7+
import os
8+
from unittest import mock
9+
from urllib.parse import urlparse
10+
11+
from snowflake.connector.vendored.requests.exceptions import ConnectTimeout, HTTPError
12+
13+
logger = logging.getLogger(__name__)
14+
15+
16+
# Import shared functions
17+
from ...csp_helpers import (
18+
FakeAwsEnvironment,
19+
FakeAzureFunctionMetadataService,
20+
FakeAzureVmMetadataService,
21+
FakeGceMetadataService,
22+
FakeMetadataService,
23+
NoMetadataService,
24+
)
25+
26+
27+
def build_response(content: bytes, status_code: int = 200):
28+
"""Builds an aiohttp-compatible response object with the given status code and content."""
29+
30+
class AsyncResponse:
31+
def __init__(self, content, status_code):
32+
self.ok = status_code < 400
33+
self.status = status_code
34+
self._content = content
35+
36+
async def read(self):
37+
return self._content
38+
39+
return AsyncResponse(content, status_code)
40+
41+
42+
class FakeMetadataServiceAsync(FakeMetadataService):
43+
def _async_request(self, method, url, headers=None, timeout=None):
44+
"""Entry point for the aiohttp mock."""
45+
logger.debug(f"Received async request: {method} {url} {str(headers)}")
46+
parsed_url = urlparse(url)
47+
48+
# Create async context manager for aiohttp response
49+
class AsyncResponseContextManager:
50+
def __init__(self, response):
51+
self.response = response
52+
53+
async def __aenter__(self):
54+
return self.response
55+
56+
async def __aexit__(self, exc_type, exc_val, exc_tb):
57+
pass
58+
59+
# Create aiohttp-compatible response mock
60+
class AsyncResponse:
61+
def __init__(self, requests_response):
62+
self.ok = requests_response.ok
63+
self.status = requests_response.status_code
64+
self._content = requests_response.content
65+
66+
async def read(self):
67+
return self._content
68+
69+
if not parsed_url.hostname == self.expected_hostname:
70+
logger.debug(
71+
f"Received async request to unexpected hostname {parsed_url.hostname}"
72+
)
73+
import aiohttp
74+
75+
raise aiohttp.ClientError()
76+
77+
# Get the response from the subclass handler, catch exceptions and convert them
78+
try:
79+
sync_response = self.handle_request(method, parsed_url, headers, timeout)
80+
async_response = AsyncResponse(sync_response)
81+
return AsyncResponseContextManager(async_response)
82+
except (HTTPError, ConnectTimeout) as e:
83+
import aiohttp
84+
85+
# Convert requests exceptions to aiohttp exceptions so they get caught properly
86+
raise aiohttp.ClientError() from e
87+
88+
def __enter__(self):
89+
self.reset_defaults()
90+
self.patchers = []
91+
# Mock aiohttp for async requests
92+
self.patchers.append(
93+
mock.patch("aiohttp.ClientSession.request", side_effect=self._async_request)
94+
)
95+
for patcher in self.patchers:
96+
patcher.__enter__()
97+
return self
98+
99+
100+
class NoMetadataServiceAsync(FakeMetadataServiceAsync, NoMetadataService):
101+
pass
102+
103+
104+
class FakeAzureVmMetadataServiceAsync(
105+
FakeMetadataServiceAsync, FakeAzureVmMetadataService
106+
):
107+
pass
108+
109+
110+
class FakeAzureFunctionMetadataServiceAsync(
111+
FakeMetadataServiceAsync, FakeAzureFunctionMetadataService
112+
):
113+
def __enter__(self):
114+
# Set environment variables first (like Azure Function service)
115+
os.environ["IDENTITY_ENDPOINT"] = self.identity_endpoint
116+
os.environ["IDENTITY_HEADER"] = self.identity_header
117+
118+
# Then set up the metadata service mocks
119+
FakeMetadataServiceAsync.__enter__(self)
120+
return self
121+
122+
def __exit__(self, *args, **kwargs):
123+
# Clean up async mocks first
124+
FakeMetadataServiceAsync.__exit__(self, *args, **kwargs)
125+
126+
# Then clean up environment variables
127+
os.environ.pop("IDENTITY_ENDPOINT", None)
128+
os.environ.pop("IDENTITY_HEADER", None)
129+
130+
131+
class FakeGceMetadataServiceAsync(FakeMetadataServiceAsync, FakeGceMetadataService):
132+
pass
133+
134+
135+
class FakeAwsEnvironmentAsync(FakeAwsEnvironment):
136+
"""Emulates the AWS environment-specific functions used in async wif_util.py.
137+
138+
Unlike the other metadata services, the HTTP calls made by AWS are deep within boto libaries, so
139+
emulating them here would be complex and fragile. Instead, we emulate the higher-level functions
140+
called by the connector code.
141+
"""
142+
143+
async def get_region(self):
144+
return self.region
145+
146+
async def get_arn(self):
147+
return self.arn
148+
149+
async def get_credentials(self):
150+
return self.credentials
151+
152+
def __enter__(self):
153+
# First call the parent's __enter__ to get base functionality
154+
super().__enter__()
155+
156+
# Then add async-specific patches
157+
async def async_get_credentials():
158+
return self.credentials
159+
160+
async def async_get_caller_identity():
161+
return {"Arn": self.arn}
162+
163+
async def async_get_region():
164+
return await self.get_region()
165+
166+
async def async_get_arn():
167+
return await self.get_arn()
168+
169+
# Mock aioboto3.Session.get_credentials (IS async)
170+
self.patchers.append(
171+
mock.patch(
172+
"snowflake.connector.aio._wif_util.aioboto3.Session.get_credentials",
173+
side_effect=async_get_credentials,
174+
)
175+
)
176+
177+
# Mock the async AWS region and ARN functions
178+
self.patchers.append(
179+
mock.patch(
180+
"snowflake.connector.aio._wif_util.get_aws_region",
181+
side_effect=async_get_region,
182+
)
183+
)
184+
185+
self.patchers.append(
186+
mock.patch(
187+
"snowflake.connector.aio._wif_util.get_aws_arn",
188+
side_effect=async_get_arn,
189+
)
190+
)
191+
192+
# Mock the async STS client for direct aioboto3 usage
193+
class MockStsClient:
194+
async def __aenter__(self):
195+
return self
196+
197+
async def __aexit__(self, exc_type, exc_val, exc_tb):
198+
pass
199+
200+
async def get_caller_identity(self):
201+
return await async_get_caller_identity()
202+
203+
def mock_session_client(service_name):
204+
if service_name == "sts":
205+
return MockStsClient()
206+
return None
207+
208+
self.patchers.append(
209+
mock.patch(
210+
"snowflake.connector.aio._wif_util.aioboto3.Session.client",
211+
side_effect=mock_session_client,
212+
)
213+
)
214+
215+
# Start the additional async patches
216+
for patcher in self.patchers[-4:]: # Only start the new patches we just added
217+
patcher.__enter__()
218+
return self
219+
220+
def __exit__(self, *args, **kwargs):
221+
# Call parent's exit to clean up base patches
222+
super().__exit__(*args, **kwargs)

0 commit comments

Comments
 (0)