Skip to content

Commit 74b1b87

Browse files
committed
SNOW-1572294: connection async api coverage (#2057)
1 parent 10927e6 commit 74b1b87

File tree

11 files changed

+2248
-80
lines changed

11 files changed

+2248
-80
lines changed

.github/workflows/build_test.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ jobs:
351351
python-version: ["3.10", "3.11", "3.12"]
352352
cloud-provider: [aws, azure, gcp]
353353
steps:
354-
- uses: actions/checkout@v3
354+
- uses: actions/checkout@v4
355355
- name: Set up Python
356356
uses: actions/setup-python@v4
357357
with:
@@ -366,7 +366,7 @@ jobs:
366366
gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \
367367
.github/workflows/parameters/public/parameters_${{ matrix.cloud-provider }}.py.gpg > test/parameters.py
368368
- name: Download wheel(s)
369-
uses: actions/download-artifact@v3
369+
uses: actions/download-artifact@v4
370370
with:
371371
name: ${{ matrix.os.download_name }}_py${{ matrix.python-version }}
372372
path: dist
@@ -388,7 +388,7 @@ jobs:
388388
- name: Combine coverages
389389
run: python -m tox run -e coverage --skip-missing-interpreters false
390390
shell: bash
391-
- uses: actions/upload-artifact@v3
391+
- uses: actions/upload-artifact@v4
392392
with:
393393
name: coverage_aio_${{ matrix.os.download_name }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}
394394
path: |

src/snowflake/connector/aio/_connection.py

Lines changed: 60 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
from ..util_text import split_statements
6161
from ._cursor import SnowflakeCursor
6262
from ._network import SnowflakeRestful
63+
from ._time_util import HeartBeatTimer
6364
from .auth import Auth, AuthByDefault, AuthByPlugin
6465

6566
logger = getLogger(__name__)
@@ -87,7 +88,19 @@ def __init__(
8788
# get the imported modules from sys.modules
8889
# self._log_telemetry_imported_packages() # TODO: async telemetry support
8990
# check SNOW-1218851 for long term improvement plan to refactor ocsp code
90-
# atexit.register(self._close_at_exit) # TODO: async atexit support/test
91+
atexit.register(self._close_at_exit)
92+
93+
def __enter__(self):
94+
# async connection does not support sync context manager
95+
raise TypeError(
96+
"'SnowflakeConnection' object does not support the context manager protocol"
97+
)
98+
99+
def __exit__(self, exc_type, exc_val, exc_tb):
100+
# async connection does not support sync context manager
101+
raise TypeError(
102+
"'SnowflakeConnection' object does not support the context manager protocol"
103+
)
91104

92105
async def __aenter__(self) -> SnowflakeConnection:
93106
"""Context manager."""
@@ -135,7 +148,9 @@ async def __open_connection(self):
135148
)
136149

137150
if ".privatelink.snowflakecomputing." in self.host:
138-
SnowflakeConnection.setup_ocsp_privatelink(self.application, self.host)
151+
await SnowflakeConnection.setup_ocsp_privatelink(
152+
self.application, self.host
153+
)
139154
else:
140155
if "SF_OCSP_RESPONSE_CACHE_SERVER_URL" in os.environ:
141156
del os.environ["SF_OCSP_RESPONSE_CACHE_SERVER_URL"]
@@ -164,11 +179,10 @@ async def __open_connection(self):
164179
PARAMETER_CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY
165180
] = self._validate_client_session_keep_alive_heartbeat_frequency()
166181

167-
# TODO: client_prefetch_threads support
168-
# if self.client_prefetch_threads:
169-
# self._session_parameters[PARAMETER_CLIENT_PREFETCH_THREADS] = (
170-
# self._validate_client_prefetch_threads()
171-
# )
182+
if self.client_prefetch_threads:
183+
self._session_parameters[PARAMETER_CLIENT_PREFETCH_THREADS] = (
184+
self._validate_client_prefetch_threads()
185+
)
172186

173187
# Setup authenticator
174188
auth = Auth(self.rest)
@@ -203,7 +217,7 @@ async def __open_connection(self):
203217
elif self._authenticator == DEFAULT_AUTHENTICATOR:
204218
self.auth_class = AuthByDefault(
205219
password=self._password,
206-
timeout=self._login_timeout,
220+
timeout=self.login_timeout,
207221
backoff_generator=self._backoff_generator,
208222
)
209223
else:
@@ -222,10 +236,21 @@ async def __open_connection(self):
222236
# This will be called after the heartbeat frequency has actually been set.
223237
# By this point it should have been decided if the heartbeat has to be enabled
224238
# and what would the heartbeat frequency be
225-
# TODO: implement asyncio heartbeat/timer
226-
raise NotImplementedError(
227-
"asyncio client_session_keep_alive is not supported"
239+
await self._add_heartbeat()
240+
241+
async def _add_heartbeat(self) -> None:
242+
if not self._heartbeat_task:
243+
self._heartbeat_task = HeartBeatTimer(
244+
self.client_session_keep_alive_heartbeat_frequency, self._heartbeat_tick
228245
)
246+
await self._heartbeat_task.start()
247+
logger.debug("started heartbeat")
248+
249+
async def _heartbeat_tick(self) -> None:
250+
"""Execute a heartbeat if connection isn't closed yet."""
251+
if not self.is_closed():
252+
logger.debug("heartbeating!")
253+
await self.rest._heartbeat()
229254

230255
async def _all_async_queries_finished(self) -> bool:
231256
"""Checks whether all async queries started by this Connection have finished executing."""
@@ -322,6 +347,13 @@ async def _authenticate(self, auth_instance: AuthByPlugin):
322347
continue
323348
break
324349

350+
async def _cancel_heartbeat(self) -> None:
351+
"""Cancel the heartbeat task."""
352+
if self._heartbeat_task:
353+
await self._heartbeat_task.stop()
354+
self._heartbeat_task = None
355+
logger.debug("stopped heartbeat")
356+
325357
def _init_connection_parameters(
326358
self,
327359
connection_init_kwargs: dict,
@@ -353,7 +385,7 @@ def _init_connection_parameters(
353385
for name, (value, _) in DEFAULT_CONFIGURATION.items():
354386
setattr(self, f"_{name}", value)
355387

356-
self.heartbeat_thread = None
388+
self._heartbeat_task = None
357389
is_kwargs_empty = not connection_init_kwargs
358390

359391
if "application" not in connection_init_kwargs:
@@ -403,7 +435,7 @@ async def _cancel_query(
403435

404436
def _close_at_exit(self):
405437
with suppress(Exception):
406-
asyncio.get_event_loop().run_until_complete(self.close(retry=False))
438+
asyncio.run(self.close(retry=False))
407439

408440
async def _get_query_status(
409441
self, sf_qid: str
@@ -587,8 +619,7 @@ async def close(self, retry: bool = True) -> None:
587619
# will hang if the application doesn't close the connection and
588620
# CLIENT_SESSION_KEEP_ALIVE is set, because the heartbeat runs on
589621
# a separate thread.
590-
# TODO: async heartbeat support
591-
# self._cancel_heartbeat()
622+
await self._cancel_heartbeat()
592623

593624
# close telemetry first, since it needs rest to send remaining data
594625
logger.info("closed")
@@ -600,7 +631,12 @@ async def close(self, retry: bool = True) -> None:
600631
and not self._server_session_keep_alive
601632
):
602633
logger.info("No async queries seem to be running, deleting session")
603-
await self.rest.delete_session(retry=retry)
634+
try:
635+
await self.rest.delete_session(retry=retry)
636+
except Exception as e:
637+
logger.debug(
638+
"Exception encountered in deleting session. ignoring...: %s", e
639+
)
604640
else:
605641
logger.info(
606642
"There are {} async queries still running, not deleting session".format(
@@ -837,33 +873,17 @@ async def get_query_status_throw_if_error(self, sf_qid: str) -> QueryStatus:
837873
"""
838874
status, status_resp = await self._get_query_status(sf_qid)
839875
self._cache_query_status(sf_qid, status)
840-
queries = status_resp["data"]["queries"]
841876
if self.is_an_error(status):
842-
if sf_qid in self._async_sfqids:
843-
self._async_sfqids.pop(sf_qid, None)
844-
message = status_resp.get("message")
845-
if message is None:
846-
message = ""
847-
code = queries[0].get("errorCode", -1)
848-
sql_state = None
849-
if "data" in status_resp:
850-
message += (
851-
queries[0].get("errorMessage", "") if len(queries) > 0 else ""
852-
)
853-
sql_state = status_resp["data"].get("sqlState")
854-
Error.errorhandler_wrapper(
855-
self,
856-
None,
857-
ProgrammingError,
858-
{
859-
"msg": message,
860-
"errno": int(code),
861-
"sqlstate": sql_state,
862-
"sfqid": sf_qid,
863-
},
864-
)
877+
self._process_error_query_status(sf_qid, status_resp)
865878
return status
866879

880+
@staticmethod
881+
async def setup_ocsp_privatelink(app, hostname) -> None:
882+
async with SnowflakeConnection.OCSP_ENV_LOCK:
883+
ocsp_cache_server = f"http://ocsp.{hostname}/ocsp_response_cache.json"
884+
os.environ["SF_OCSP_RESPONSE_CACHE_SERVER_URL"] = ocsp_cache_server
885+
logger.debug("OCSP Cache Server is updated: %s", ocsp_cache_server)
886+
867887
async def rollback(self) -> None:
868888
"""Rolls back the current transaction."""
869889
await self.cursor().execute("ROLLBACK")

src/snowflake/connector/aio/_cursor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ def __init__(
7070
def __aiter__(self):
7171
return self
7272

73+
def __iter__(self):
74+
raise TypeError(
75+
"'snowflake.connector.aio.SnowflakeCursor' only supports async iteration."
76+
)
77+
7378
async def __anext__(self):
7479
while True:
7580
_next = await self.fetchone()

src/snowflake/connector/aio/_result_batch.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -191,14 +191,27 @@ async def create_iter(
191191

192192
async def _download(
193193
self, connection: SnowflakeConnection | None = None, **kwargs
194-
) -> aiohttp.ClientResponse:
194+
) -> tuple[bytes, str]:
195195
"""Downloads the data that the ``ResultBatch`` is pointing at."""
196196
sleep_timer = 1
197197
backoff = (
198198
connection._backoff_generator
199199
if connection is not None
200200
else exponential_backoff()()
201201
)
202+
203+
async def download_chunk(http_session):
204+
response, content, encoding = None, None, None
205+
logger.debug(
206+
f"downloading result batch id: {self.id} with existing session {http_session}"
207+
)
208+
response = await http_session.get(**request_data)
209+
if response.status == OK:
210+
logger.debug(f"successfully downloaded result batch id: {self.id}")
211+
content, encoding = await response.read(), response.get_encoding()
212+
return response, content, encoding
213+
214+
content, encoding = None, None
202215
for retry in range(MAX_DOWNLOAD_RETRY):
203216
try:
204217
# TODO: feature parity with download timeout setting, in sync it's set to 7s
@@ -218,20 +231,16 @@ async def _download(
218231
logger.debug(
219232
f"downloading result batch id: {self.id} with existing session {session}"
220233
)
221-
response = await session.request("get", **request_data)
234+
response, content, encoding = await download_chunk(session)
222235
else:
223-
logger.debug(
224-
f"downloading result batch id: {self.id} with new session"
225-
)
226236
async with aiohttp.ClientSession() as session:
227-
response = await session.get(**request_data)
237+
logger.debug(
238+
f"downloading result batch id: {self.id} with new session"
239+
)
240+
response, content, encoding = await download_chunk(session)
228241

229242
if response.status == OK:
230-
logger.debug(
231-
f"successfully downloaded result batch id: {self.id}"
232-
)
233243
break
234-
235244
# Raise error here to correctly go in to exception clause
236245
if is_retryable_http_code(response.status):
237246
# retryable server exceptions
@@ -259,7 +268,7 @@ async def _download(
259268
self._metrics[DownloadMetrics.download.value] = (
260269
download_metric.get_timing_millis()
261270
)
262-
return response
271+
return content, encoding
263272

264273

265274
class JSONResultBatch(ResultBatch, JSONResultBatchSync):
@@ -268,11 +277,11 @@ async def create_iter(
268277
) -> Iterator[dict | Exception] | Iterator[tuple | Exception]:
269278
if self._local:
270279
return iter(self._data)
271-
response = await self._download(connection=connection)
280+
content, encoding = await self._download(connection=connection)
272281
# Load data into an intermediate form
273282
logger.debug(f"started loading result batch id: {self.id}")
274283
async with TimerContextManager() as load_metric:
275-
downloaded_data = await self._load(response)
284+
downloaded_data = await self._load(content, encoding)
276285
logger.debug(f"finished loading result batch id: {self.id}")
277286
self._metrics[DownloadMetrics.load.value] = load_metric.get_timing_millis()
278287
# Process downloaded data
@@ -281,7 +290,7 @@ async def create_iter(
281290
self._metrics[DownloadMetrics.parse.value] = parse_metric.get_timing_millis()
282291
return iter(parsed_data)
283292

284-
async def _load(self, response: aiohttp.ClientResponse) -> list:
293+
async def _load(self, content: bytes, encoding: str) -> list:
285294
"""This function loads a compressed JSON file into memory.
286295
287296
Returns:
@@ -292,29 +301,29 @@ async def _load(self, response: aiohttp.ClientResponse) -> list:
292301
# if users specify how to decode the data, we decode the bytes using the specified encoding
293302
if self._json_result_force_utf8_decoding:
294303
try:
295-
read_data = str(await response.read(), "utf-8", errors="strict")
304+
read_data = str(content, "utf-8", errors="strict")
296305
except Exception as exc:
297306
err_msg = f"failed to decode json result content due to error {exc!r}"
298307
logger.error(err_msg)
299308
raise Error(msg=err_msg)
300309
else:
301310
# note: SNOW-787480 response.apparent_encoding is unreliable, chardet.detect can be wrong which is used by
302311
# response.text to decode content, check issue: https://github.com/chardet/chardet/issues/148
303-
read_data = await response.text()
312+
read_data = content.decode(encoding, "strict")
304313
return json.loads("".join(["[", read_data, "]"]))
305314

306315

307316
class ArrowResultBatch(ResultBatch, ArrowResultBatchSync):
308317
async def _load(
309-
self, response: aiohttp.ClientResponse, row_unit: IterUnit
318+
self, content, row_unit: IterUnit
310319
) -> Iterator[dict | Exception] | Iterator[tuple | Exception]:
311320
"""Creates a ``PyArrowIterator`` from a response.
312321
313322
This is used to iterate through results in different ways depending on which
314323
mode that ``PyArrowIterator`` is in.
315324
"""
316325
return _create_nanoarrow_iterator(
317-
await response.read(),
326+
content,
318327
self._context,
319328
self._use_dict_result,
320329
self._numpy,
@@ -334,14 +343,14 @@ async def _create_iter(
334343
if connection and getattr(connection, "_debug_arrow_chunk", False):
335344
logger.debug(f"arrow data can not be parsed: {self._data}")
336345
raise
337-
response = await self._download(connection=connection)
346+
content, _ = await self._download(connection=connection)
338347
logger.debug(f"started loading result batch id: {self.id}")
339348
async with TimerContextManager() as load_metric:
340349
try:
341-
loaded_data = await self._load(response, iter_unit)
350+
loaded_data = await self._load(content, iter_unit)
342351
except Exception:
343352
if connection and getattr(connection, "_debug_arrow_chunk", False):
344-
logger.debug(f"arrow data can not be parsed: {response}")
353+
logger.debug(f"arrow data can not be parsed: {content}")
345354
raise
346355
logger.debug(f"finished loading result batch id: {self.id}")
347356
self._metrics[DownloadMetrics.load.value] = load_metric.get_timing_millis()

0 commit comments

Comments
 (0)