Skip to content
94 changes: 90 additions & 4 deletions src/snowflake/connector/aio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

from functools import wraps
from typing import Any, Coroutine, Generator

from ._connection import SnowflakeConnection
from ._cursor import DictCursor, SnowflakeCursor

Expand All @@ -10,7 +13,90 @@
]


async def connect(**kwargs) -> SnowflakeConnection:
conn = SnowflakeConnection(**kwargs)
await conn.connect()
return conn
class _AsyncConnectContextManager:
"""Hybrid wrapper that enables both awaiting and async context manager usage.
Allows both patterns:
- conn = await connect(...)
- async with connect(...) as conn:
Implements the full coroutine protocol for maximum compatibility.
"""

__slots__ = ("_coro", "_conn")

def __init__(self, coro: Coroutine[Any, Any, SnowflakeConnection]) -> None:
self._coro = coro
self._conn: SnowflakeConnection | None = None

def send(self, arg: Any) -> Any:
"""Send a value into the wrapped coroutine."""
return self._coro.send(arg)

def throw(self, *args: Any, **kwargs: Any) -> Any:
"""Throw an exception into the wrapped coroutine."""
return self._coro.throw(*args, **kwargs)

def close(self) -> None:
"""Close the wrapped coroutine."""
return self._coro.close()

def __await__(self) -> Generator[Any, None, SnowflakeConnection]:
"""Enable await connect(...)"""
return self._coro.__await__()

def __iter__(self) -> Generator[Any, None, SnowflakeConnection]:
"""Make the wrapper iterable like a coroutine."""
return self.__await__()

async def __aenter__(self) -> SnowflakeConnection:
"""Enable async with connect(...) as conn:"""
self._conn = await self._coro
# Connection is already connected by the coroutine
self._conn._prepare_aenter()
return self._conn

async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
"""Exit async context manager."""
if self._conn is not None:
return await self._conn.__aexit__(exc_type, exc, tb)
else:
return None


class _AsyncConnectWrapper:
"""Preserves SnowflakeConnection.__init__ metadata for async connect function.
This wrapper enables introspection tools and IDEs to see the same signature
as the synchronous snowflake.connector.connect function.
"""

def __init__(self) -> None:
self.__wrapped__ = SnowflakeConnection.__init__
self.__name__ = "connect"
self.__doc__ = SnowflakeConnection.__init__.__doc__
self.__module__ = __name__
self.__qualname__ = "connect"
self.__annotations__ = getattr(
SnowflakeConnection.__init__, "__annotations__", {}
)
Comment on lines +75 to +82
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of those are by default reassigned when using @wraps (in update_wrapper)


@wraps(SnowflakeConnection.__init__)
def __call__(self, **kwargs: Any) -> _AsyncConnectContextManager:
"""Create and connect to a Snowflake connection asynchronously.
Returns an awaitable that can also be used as an async context manager.
Supports both patterns:
- conn = await connect(...)
- async with connect(...) as conn:
"""

async def _connect_coro() -> SnowflakeConnection:
conn = SnowflakeConnection(**kwargs)
await conn.connect()
return conn

return _AsyncConnectContextManager(_connect_coro())


connect = _AsyncConnectWrapper()
29 changes: 21 additions & 8 deletions src/snowflake/connector/aio/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
Error,
OperationalError,
ProgrammingError,
proxy,
)

from .._query_context_cache import QueryContextCache
Expand Down Expand Up @@ -80,6 +79,7 @@
from ._session_manager import (
AioHttpConfig,
SessionManager,
SessionManagerFactory,
SnowflakeSSLConnectorFactory,
)
from ._telemetry import TelemetryClient
Expand Down Expand Up @@ -128,6 +128,7 @@ def __init__(
if "platform_detection_timeout_seconds" not in kwargs:
self._platform_detection_timeout_seconds = 0.0

# TODO: why we have it here if never changed
self._connected = False
self.expired = False
# check SNOW-1218851 for long term improvement plan to refactor ocsp code
Expand Down Expand Up @@ -165,8 +166,20 @@ def __exit__(self, exc_type, exc_val, exc_tb):
"'SnowflakeConnection' object does not support the context manager protocol"
)

def _prepare_aenter(self) -> None:
"""
All connection changes done before entering connection context have to be done here, as we expose the same api through snowflake.connector.aio.connect() and call this function there at __aenter__ as well.
"""
pass
Comment on lines +169 to +173
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it needed?


async def __aenter__(self) -> SnowflakeConnection:
"""Context manager."""
"""
Context manager.

All connection changes done before entering connection context have to be done in the _prepare_aenter() method only.
We expose the same api through snowflake.connector.aio.connect() and call that method there at its __aenter__ as well, so there cannot be any logic executed here, but not there. We cannot just call conn.__aenter__() there as it contains already connected connection.
"""
self._prepare_aenter()
await self.connect()
return self

Expand All @@ -191,10 +204,6 @@ async def __open_connection(self):
use_numpy=self._numpy, support_negative_year=self._support_negative_year
)

proxy.set_proxies(
self.proxy_host, self.proxy_port, self.proxy_user, self.proxy_password
)

self._rest = SnowflakeRestful(
host=self.host,
port=self.port,
Expand Down Expand Up @@ -1014,13 +1023,17 @@ async def connect(self, **kwargs) -> None:
else:
self.__config(**self._conn_parameters)

self._http_config = AioHttpConfig(
self._http_config: AioHttpConfig = AioHttpConfig(
connector_factory=SnowflakeSSLConnectorFactory(),
use_pooling=not self.disable_request_pooling,
proxy_host=self.proxy_host,
proxy_port=self.proxy_port,
proxy_user=self.proxy_user,
proxy_password=self.proxy_password,
snowflake_ocsp_mode=self._ocsp_mode(),
trust_env=True, # Required for proxy support via environment variables
)
self._session_manager = SessionManager(self._http_config)
self._session_manager = SessionManagerFactory.get_manager(self._http_config)

if self.enable_connection_diag:
raise NotImplementedError(
Expand Down
17 changes: 8 additions & 9 deletions src/snowflake/connector/aio/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import TYPE_CHECKING, Any, AsyncGenerator

import OpenSSL.SSL
from urllib3.util.url import parse_url

from ..compat import FORBIDDEN, OK, UNAUTHORIZED, urlencode, urlparse, urlsplit
from ..constants import (
Expand Down Expand Up @@ -79,7 +78,11 @@
)
from ..time_util import TimeoutBackoffCtx
from ._description import CLIENT_NAME
from ._session_manager import SessionManager, SnowflakeSSLConnectorFactory
from ._session_manager import (
SessionManager,
SessionManagerFactory,
SnowflakeSSLConnectorFactory,
)

if TYPE_CHECKING:
from snowflake.connector.aio import SnowflakeConnection
Expand Down Expand Up @@ -145,15 +148,12 @@ def __init__(
session_manager = (
connection._session_manager
if (connection and connection._session_manager)
else SessionManager(connector_factory=SnowflakeSSLConnectorFactory())
else SessionManagerFactory.get_manager(
connector_factory=SnowflakeSSLConnectorFactory()
)
)
self._session_manager = session_manager

if self._connection and self._connection.proxy_host:
self._get_proxy_headers = lambda url: {"Host": parse_url(url).hostname}
else:
self._get_proxy_headers = lambda _: None

async def close(self) -> None:
if hasattr(self, "_token"):
del self._token
Expand Down Expand Up @@ -737,7 +737,6 @@ async def _request_exec(
headers=headers,
data=input_data,
timeout=aiohttp.ClientTimeout(socket_timeout),
proxy_headers=self._get_proxy_headers(full_url),
)
try:
if raw_ret.status == OK:
Expand Down
6 changes: 4 additions & 2 deletions src/snowflake/connector/aio/_result_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
raise_failed_request_error,
raise_okta_unauthorized_error,
)
from snowflake.connector.aio._session_manager import SessionManager
from snowflake.connector.aio._session_manager import SessionManagerFactory
from snowflake.connector.aio._time_util import TimerContextManager
from snowflake.connector.arrow_context import ArrowConverterContext
from snowflake.connector.backoff_policies import exponential_backoff
Expand Down Expand Up @@ -261,7 +261,9 @@ async def download_chunk(http_session):
logger.debug(
f"downloading result batch id: {self.id} with new session through local session manager"
)
local_session_manager = SessionManager(use_pooling=False)
local_session_manager = SessionManagerFactory.get_manager(
use_pooling=False
)
async with local_session_manager.use_session() as session:
response, content, encoding = await download_chunk(session)

Expand Down
Loading
Loading