Skip to content

Commit 4e37d63

Browse files
[async] Approach 1 - Idempotent connection.__aenter__ through checking if closed
1 parent 50372ac commit 4e37d63

File tree

2 files changed

+7
-13
lines changed

2 files changed

+7
-13
lines changed

src/snowflake/connector/aio/__init__.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,11 @@ def __iter__(self) -> Generator[Any, None, SnowflakeConnection]:
4949
"""Make the wrapper iterable like a coroutine."""
5050
return self.__await__()
5151

52-
# TODO: below is okay if we make idempotent __aenter__ of SnowflakeConnection class - so check if connected and do not repeat connecting
53-
# async def __aenter__(self) -> SnowflakeConnection:
54-
# """Enable async with connect(...) as conn:"""
55-
# self._conn = await self._coro
56-
# return await self._conn.__aenter__()
57-
52+
# This approach requires idempotent __aenter__ of SnowflakeConnection class - so check if connected and do not repeat connecting
5853
async def __aenter__(self) -> SnowflakeConnection:
5954
"""Enable async with connect(...) as conn:"""
6055
self._conn = await self._coro
61-
# Connection is already connected by the coroutine
62-
return self._conn
56+
return await self._conn.__aenter__()
6357

6458
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
6559
"""Exit async context manager."""

src/snowflake/connector/aio/_connection.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
168168

169169
async def __aenter__(self) -> SnowflakeConnection:
170170
"""Context manager."""
171-
# Idempotent __Aenter__
172-
# if self.is_closed():
173-
# await self.connect()
174-
# return self
175-
await self.connect()
171+
# Idempotent __aenter__ - required to be able to use both:
172+
# - with snowflake.connector.aio.SnowflakeConnection(**k)
173+
# - with snowflake.connector.aio.connect(**k)
174+
if self.is_closed():
175+
await self.connect()
176176
return self
177177

178178
async def __aexit__(

0 commit comments

Comments
 (0)