Skip to content
93 changes: 89 additions & 4 deletions src/snowflake/connector/aio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

from functools import wraps
from typing import Any, Coroutine, Generator

from ._connection import SnowflakeConnection
from ._cursor import DictCursor, SnowflakeCursor

Expand All @@ -10,7 +13,89 @@
]


async def connect(**kwargs) -> SnowflakeConnection:
conn = SnowflakeConnection(**kwargs)
await conn.connect()
return conn
class _AsyncConnectContextManager:
"""Hybrid wrapper that enables both awaiting and async context manager usage.
Allows both patterns:
- conn = await connect(...)
- async with connect(...) as conn:
Implements the full coroutine protocol for maximum compatibility.
"""

__slots__ = ("_coro", "_conn")

def __init__(self, coro: Coroutine[Any, Any, SnowflakeConnection]) -> None:
self._coro = coro
self._conn: SnowflakeConnection | None = None

def send(self, arg: Any) -> Any:
"""Send a value into the wrapped coroutine."""
return self._coro.send(arg)

def throw(self, *args: Any, **kwargs: Any) -> Any:
"""Throw an exception into the wrapped coroutine."""
return self._coro.throw(*args, **kwargs)

def close(self) -> None:
"""Close the wrapped coroutine."""
return self._coro.close()
Comment on lines +32 to +42
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are those methods used somewhere?


def __await__(self) -> Generator[Any, None, SnowflakeConnection]:
"""Enable await connect(...)"""
return self._coro.__await__()

def __iter__(self) -> Generator[Any, None, SnowflakeConnection]:
"""Make the wrapper iterable like a coroutine."""
return self.__await__()

# This approach requires idempotent __aenter__ of SnowflakeConnection class - so check if connected and do not repeat connecting
async def __aenter__(self) -> SnowflakeConnection:
"""Enable async with connect(...) as conn:"""
self._conn = await self._coro
return await self._conn.__aenter__()

async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
"""Exit async context manager."""
if self._conn is not None:
return await self._conn.__aexit__(exc_type, exc, tb)
else:
return None


class _AsyncConnectWrapper:
"""Preserves SnowflakeConnection.__init__ metadata for async connect function.
This wrapper enables introspection tools and IDEs to see the same signature
as the synchronous snowflake.connector.connect function.
"""

def __init__(self) -> None:
self.__wrapped__ = SnowflakeConnection.__init__
self.__name__ = "connect"
self.__doc__ = SnowflakeConnection.__init__.__doc__
self.__module__ = __name__
self.__qualname__ = "connect"
self.__annotations__ = getattr(
SnowflakeConnection.__init__, "__annotations__", {}
)

@wraps(SnowflakeConnection.__init__)
def __call__(self, **kwargs: Any) -> _AsyncConnectContextManager:
"""Create and connect to a Snowflake connection asynchronously.
Returns an awaitable that can also be used as an async context manager.
Supports both patterns:
- conn = await connect(...)
- async with connect(...) as conn:
"""

async def _connect_coro() -> SnowflakeConnection:
conn = SnowflakeConnection(**kwargs)
await conn.connect()
return conn

return _AsyncConnectContextManager(_connect_coro())


connect = _AsyncConnectWrapper()
21 changes: 13 additions & 8 deletions src/snowflake/connector/aio/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
Error,
OperationalError,
ProgrammingError,
proxy,
)

from .._query_context_cache import QueryContextCache
Expand Down Expand Up @@ -80,6 +79,7 @@
from ._session_manager import (
AioHttpConfig,
SessionManager,
SessionManagerFactory,
SnowflakeSSLConnectorFactory,
)
from ._telemetry import TelemetryClient
Expand Down Expand Up @@ -128,6 +128,7 @@ def __init__(
if "platform_detection_timeout_seconds" not in kwargs:
self._platform_detection_timeout_seconds = 0.0

# TODO: why we have it here if never changed
self._connected = False
self.expired = False
# check SNOW-1218851 for long term improvement plan to refactor ocsp code
Expand Down Expand Up @@ -167,7 +168,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):

async def __aenter__(self) -> SnowflakeConnection:
"""Context manager."""
await self.connect()
# Idempotent __aenter__ - required to be able to use both:
# - with snowflake.connector.aio.SnowflakeConnection(**k)
# - with snowflake.connector.aio.connect(**k)
if self.is_closed():
await self.connect()
return self

async def __aexit__(
Expand All @@ -191,10 +196,6 @@ async def __open_connection(self):
use_numpy=self._numpy, support_negative_year=self._support_negative_year
)

proxy.set_proxies(
self.proxy_host, self.proxy_port, self.proxy_user, self.proxy_password
)

self._rest = SnowflakeRestful(
host=self.host,
port=self.port,
Expand Down Expand Up @@ -1014,13 +1015,17 @@ async def connect(self, **kwargs) -> None:
else:
self.__config(**self._conn_parameters)

self._http_config = AioHttpConfig(
self._http_config: AioHttpConfig = AioHttpConfig(
connector_factory=SnowflakeSSLConnectorFactory(),
use_pooling=not self.disable_request_pooling,
proxy_host=self.proxy_host,
proxy_port=self.proxy_port,
proxy_user=self.proxy_user,
proxy_password=self.proxy_password,
snowflake_ocsp_mode=self._ocsp_mode(),
trust_env=True, # Required for proxy support via environment variables
)
self._session_manager = SessionManager(self._http_config)
self._session_manager = SessionManagerFactory.get_manager(self._http_config)

if self.enable_connection_diag:
raise NotImplementedError(
Expand Down
17 changes: 8 additions & 9 deletions src/snowflake/connector/aio/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import TYPE_CHECKING, Any, AsyncGenerator

import OpenSSL.SSL
from urllib3.util.url import parse_url

from ..compat import FORBIDDEN, OK, UNAUTHORIZED, urlencode, urlparse, urlsplit
from ..constants import (
Expand Down Expand Up @@ -79,7 +78,11 @@
)
from ..time_util import TimeoutBackoffCtx
from ._description import CLIENT_NAME
from ._session_manager import SessionManager, SnowflakeSSLConnectorFactory
from ._session_manager import (
SessionManager,
SessionManagerFactory,
SnowflakeSSLConnectorFactory,
)

if TYPE_CHECKING:
from snowflake.connector.aio import SnowflakeConnection
Expand Down Expand Up @@ -145,15 +148,12 @@ def __init__(
session_manager = (
connection._session_manager
if (connection and connection._session_manager)
else SessionManager(connector_factory=SnowflakeSSLConnectorFactory())
else SessionManagerFactory.get_manager(
connector_factory=SnowflakeSSLConnectorFactory()
)
)
self._session_manager = session_manager

if self._connection and self._connection.proxy_host:
self._get_proxy_headers = lambda url: {"Host": parse_url(url).hostname}
else:
self._get_proxy_headers = lambda _: None

async def close(self) -> None:
if hasattr(self, "_token"):
del self._token
Expand Down Expand Up @@ -737,7 +737,6 @@ async def _request_exec(
headers=headers,
data=input_data,
timeout=aiohttp.ClientTimeout(socket_timeout),
proxy_headers=self._get_proxy_headers(full_url),
)
try:
if raw_ret.status == OK:
Expand Down
6 changes: 4 additions & 2 deletions src/snowflake/connector/aio/_result_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
raise_failed_request_error,
raise_okta_unauthorized_error,
)
from snowflake.connector.aio._session_manager import SessionManager
from snowflake.connector.aio._session_manager import SessionManagerFactory
from snowflake.connector.aio._time_util import TimerContextManager
from snowflake.connector.arrow_context import ArrowConverterContext
from snowflake.connector.backoff_policies import exponential_backoff
Expand Down Expand Up @@ -261,7 +261,9 @@ async def download_chunk(http_session):
logger.debug(
f"downloading result batch id: {self.id} with new session through local session manager"
)
local_session_manager = SessionManager(use_pooling=False)
local_session_manager = SessionManagerFactory.get_manager(
use_pooling=False
)
async with local_session_manager.use_session() as session:
response, content, encoding = await download_chunk(session)

Expand Down
Loading
Loading