Skip to content

Commit 6b30265

Browse files
[async] Applied #2489 to async code
1 parent 796cb8e commit 6b30265

File tree

3 files changed

+210
-20
lines changed

3 files changed

+210
-20
lines changed

src/snowflake/connector/aio/_session_manager.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from ..session_manager import BaseHttpConfig
3535
from ..session_manager import SessionManager as SessionManagerSync
3636
from ..session_manager import SessionPool as SessionPoolSync
37+
from ..session_manager import _ConfigDirectAccessMixin
3738

3839
logger = logging.getLogger(__name__)
3940

@@ -328,7 +329,29 @@ async def delete(
328329
)
329330

330331

331-
class SessionManager(_RequestVerbsUsingSessionMixin, SessionManagerSync):
332+
class _AsyncHttpConfigDirectAccessMixin(_ConfigDirectAccessMixin, abc.ABC):
333+
@property
334+
@abc.abstractmethod
335+
def config(self) -> AioHttpConfig: ...
336+
337+
@config.setter
338+
@abc.abstractmethod
339+
def config(self, value) -> AioHttpConfig: ...
340+
341+
@property
342+
def connector_factory(self) -> Callable[..., aiohttp.BaseConnector]:
343+
return self.config.connector_factory
344+
345+
@connector_factory.setter
346+
def connector_factory(self, value: Callable[..., aiohttp.BaseConnector]) -> None:
347+
self.config: AioHttpConfig = self.config.copy_with(connector_factory=value)
348+
349+
350+
class SessionManager(
351+
_RequestVerbsUsingSessionMixin,
352+
SessionManagerSync,
353+
_AsyncHttpConfigDirectAccessMixin,
354+
):
332355
"""
333356
Async HTTP session manager for aiohttp.ClientSession instances.
334357
@@ -363,14 +386,6 @@ def from_config(cls, cfg: AioHttpConfig, **overrides: Any) -> SessionManager:
363386
cfg = cfg.copy_with(**overrides)
364387
return cls(config=cfg)
365388

366-
@property
367-
def connector_factory(self) -> Callable[..., aiohttp.BaseConnector]:
368-
return self._cfg.connector_factory
369-
370-
@connector_factory.setter
371-
def connector_factory(self, value: Callable[..., aiohttp.BaseConnector]) -> None:
372-
self._cfg: AioHttpConfig = self._cfg.copy_with(connector_factory=value)
373-
374389
def make_session(self) -> aiohttp.ClientSession:
375390
"""Create a new aiohttp.ClientSession with configured connector."""
376391
connector = self._cfg.get_connector(
@@ -432,18 +447,18 @@ async def close(self):
432447

433448
def clone(
434449
self,
435-
*,
436-
use_pooling: bool | None = None,
437-
connector_factory: ConnectorFactory | None = None,
450+
**http_config_overrides,
438451
) -> SessionManager:
439-
"""Return a new async SessionManager sharing this instance's config."""
440-
overrides: dict[str, Any] = {}
441-
if use_pooling is not None:
442-
overrides["use_pooling"] = use_pooling
443-
if connector_factory is not None:
444-
overrides["connector_factory"] = connector_factory
445-
446-
return self.from_config(self._cfg, **overrides)
452+
"""Return a new *stateless* SessionManager sharing this instance’s config.
453+
454+
"Shallow clone" - the configuration object (HttpConfig) is reused as-is,
455+
while *stateful* aspects such as the per-host SessionPool mapping are
456+
reset, so the two managers do not share live `requests.Session`
457+
objects.
458+
Optional kwargs (e.g. *use_pooling* / *adapter_factory* / max_retries etc.) - overrides to create a modified
459+
copy of the HttpConfig before instantiation.
460+
"""
461+
return self.from_config(self._cfg, **http_config_overrides)
447462

448463

449464
async def request(

test/integ/aio_it/test_connection_async.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@
5858
except ImportError:
5959
pass
6060

61+
from test.integ.test_connection import (
62+
_assert_log_bytes_within_tolerance,
63+
_calculate_log_bytes,
64+
_find_matching_patterns,
65+
_log_pattern_analysis,
66+
)
67+
6168

6269
async def test_basic(conn_testaccount):
6370
"""Basic Connection test."""
@@ -1614,3 +1621,85 @@ async def test_snowflake_version():
16141621
assert re.match(
16151622
version_pattern, await conn.snowflake_version
16161623
), f"snowflake_version should match pattern 'x.y.z', but got '{await conn.snowflake_version}'"
1624+
1625+
1626+
@pytest.mark.skipolddriver
1627+
async def test_logs_size_during_basic_query_stays_unchanged(conn_cnx, caplog):
1628+
"""Test that the amount of bytes logged during normal select 1 flow is within acceptable range. Related to: SNOW-2268606"""
1629+
caplog.set_level(logging.INFO, "snowflake.connector")
1630+
caplog.clear()
1631+
1632+
# Test-specific constants
1633+
EXPECTED_BYTES = 145
1634+
ACCEPTABLE_DELTA = 0.6
1635+
EXPECTED_PATTERNS = [
1636+
"Snowflake Connector for Python Version: ", # followed by version info
1637+
"Connecting to GLOBAL Snowflake domain",
1638+
]
1639+
1640+
async with conn_cnx() as conn:
1641+
async with conn.cursor() as cur:
1642+
await (await cur.execute("select 1")).fetchall()
1643+
1644+
actual_messages = [record.getMessage() for record in caplog.records]
1645+
total_log_bytes = _calculate_log_bytes(actual_messages)
1646+
1647+
if total_log_bytes != EXPECTED_BYTES:
1648+
logging.warning(
1649+
f"There was a change in a size of the logs produced by the basic Snowflake query. "
1650+
f"Expected: {EXPECTED_BYTES}, got: {total_log_bytes}. "
1651+
f"We may need to update the test_logs_size_during_basic_query_stays_unchanged - i.e. EXACT_EXPECTED_LOGS_BYTES constant."
1652+
)
1653+
1654+
# Check if patterns match to decide whether to show all messages
1655+
matched_patterns, missing_patterns, unmatched_messages = (
1656+
_find_matching_patterns(actual_messages, EXPECTED_PATTERNS)
1657+
)
1658+
patterns_match_perfectly = (
1659+
len(missing_patterns) == 0 and len(unmatched_messages) == 0
1660+
)
1661+
1662+
_log_pattern_analysis(
1663+
actual_messages,
1664+
EXPECTED_PATTERNS,
1665+
matched_patterns,
1666+
missing_patterns,
1667+
unmatched_messages,
1668+
show_all_messages=patterns_match_perfectly,
1669+
)
1670+
1671+
_assert_log_bytes_within_tolerance(
1672+
total_log_bytes, EXPECTED_BYTES, ACCEPTABLE_DELTA
1673+
)
1674+
1675+
1676+
@pytest.mark.skipolddriver
1677+
async def test_no_new_warnings_or_errors_on_successful_basic_select(conn_cnx, caplog):
1678+
"""Test that the number of warning/error log entries stays the same during successful basic select operations. Related to: SNOW-2268606"""
1679+
caplog.set_level(logging.WARNING, "snowflake.connector")
1680+
baseline_warning_count = 0
1681+
baseline_error_count = 0
1682+
1683+
# Execute basic select operations and check counts remain the same
1684+
caplog.clear()
1685+
async with conn_cnx() as conn:
1686+
async with conn.cursor() as cur:
1687+
# Execute basic select operations
1688+
result1 = await (await cur.execute("select 1")).fetchall()
1689+
assert result1 == [(1,)]
1690+
1691+
# Count warning/error log entries after operations
1692+
test_warning_count = len(
1693+
[r for r in caplog.records if r.levelno >= logging.WARNING]
1694+
)
1695+
test_error_count = len([r for r in caplog.records if r.levelno >= logging.ERROR])
1696+
1697+
# Assert counts stay the same (no new warnings or errors)
1698+
assert test_warning_count == baseline_warning_count, (
1699+
f"Warning count increased from {baseline_warning_count} to {test_warning_count}. "
1700+
f"New warnings: {[r.getMessage() for r in caplog.records if r.levelno == logging.WARNING]}"
1701+
)
1702+
assert test_error_count == baseline_error_count, (
1703+
f"Error count increased from {baseline_error_count} to {test_error_count}. "
1704+
f"New errors: {[r.getMessage() for r in caplog.records if r.levelno >= logging.ERROR]}"
1705+
)

test/unit/aio/test_session_manager_async.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33

44
from unittest import mock
55

6+
import aiohttp
67
import pytest
78

89
from snowflake.connector.aio._session_manager import (
910
AioHttpConfig,
1011
SessionManager,
12+
SnowflakeSSLConnector,
1113
SnowflakeSSLConnectorFactory,
1214
)
1315
from snowflake.connector.constants import OCSPMode
@@ -348,3 +350,87 @@ async def test_pickle_session_manager():
348350

349351
await manager.close()
350352
await unpickled.close()
353+
354+
355+
@pytest.fixture
356+
def mock_connector_with_factory():
357+
"""Fixture providing a mock connector factory and connector."""
358+
mock_connector_factory = mock.MagicMock()
359+
mock_connector = mock.MagicMock()
360+
mock_connector_factory.return_value = mock_connector
361+
return mock_connector, mock_connector_factory
362+
363+
364+
@pytest.mark.parametrize(
365+
"ocsp_mode,extra_kwargs,expected_kwargs",
366+
[
367+
# Test with OCSPMode.FAIL_OPEN + extra kwargs (should all appear)
368+
(
369+
OCSPMode.FAIL_OPEN,
370+
{"timeout": 30, "pool_connections": 10},
371+
{
372+
"timeout": 30,
373+
"pool_connections": 10,
374+
"snowflake_ocsp_mode": OCSPMode.FAIL_OPEN,
375+
},
376+
),
377+
# Test with OCSPMode.FAIL_CLOSED + no extra kwargs
378+
(
379+
OCSPMode.FAIL_CLOSED,
380+
{},
381+
{"snowflake_ocsp_mode": OCSPMode.FAIL_CLOSED},
382+
),
383+
# Checks that None values also cause kwargs name to occur
384+
(
385+
None,
386+
{},
387+
{"snowflake_ocsp_mode": None},
388+
),
389+
# Test override by extra kwargs: config has FAIL_OPEN but extra_kwargs override with FAIL_CLOSED
390+
(
391+
OCSPMode.FAIL_OPEN,
392+
{"snowflake_ocsp_mode": OCSPMode.FAIL_CLOSED},
393+
{"snowflake_ocsp_mode": OCSPMode.FAIL_CLOSED},
394+
),
395+
],
396+
)
397+
async def test_aio_http_config_get_connector_parametrized(
398+
mock_connector_with_factory, ocsp_mode, extra_kwargs, expected_kwargs
399+
):
400+
"""Test that AioHttpConfig.get_connector properly passes kwargs and snowflake_ocsp_mode to connector factory.
401+
402+
This mirrors the sync test behavior where:
403+
- Config attributes are passed to the factory
404+
- Extra kwargs can override config attributes
405+
- All resulting attributes appear in the factory call
406+
"""
407+
mock_connector, mock_connector_factory = mock_connector_with_factory
408+
409+
config = AioHttpConfig(
410+
connector_factory=mock_connector_factory, snowflake_ocsp_mode=ocsp_mode
411+
)
412+
result = config.get_connector(**extra_kwargs)
413+
414+
# Verify the connector factory was called with correct arguments
415+
mock_connector_factory.assert_called_once_with(**expected_kwargs)
416+
assert result is mock_connector
417+
418+
419+
async def test_aio_http_config_get_connector_with_real_connector_factory():
420+
"""Test get_connector with the actual SnowflakeSSLConnectorFactory.
421+
422+
Verifies that with a real factory, we get a real SnowflakeSSLConnector instance
423+
with the snowflake_ocsp_mode properly set.
424+
"""
425+
config = AioHttpConfig(
426+
connector_factory=SnowflakeSSLConnectorFactory(),
427+
snowflake_ocsp_mode=OCSPMode.FAIL_CLOSED,
428+
)
429+
430+
connector = config.get_connector(session_manager=SessionManager())
431+
432+
# Verify we get a real SnowflakeSSLConnector instance
433+
assert isinstance(connector, aiohttp.BaseConnector)
434+
assert isinstance(connector, SnowflakeSSLConnector)
435+
# Verify snowflake_ocsp_mode was set correctly
436+
assert connector._snowflake_ocsp_mode == OCSPMode.FAIL_CLOSED

0 commit comments

Comments
 (0)