Skip to content

Commit a753daa

Browse files
committed
SNOW-1572300: async cursor coverage (#2062)
1 parent bb3dc64 commit a753daa

20 files changed

+2851
-86
lines changed

.github/workflows/build_test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ jobs:
378378
- name: Install tox
379379
run: python -m pip install tox>=4
380380
- name: Run tests
381-
run: python -m tox run -e `echo py${PYTHON_VERSION/\./}-aio-ci`
381+
run: python -m tox run -e aio
382382
env:
383383
PYTHON_VERSION: ${{ matrix.python-version }}
384384
cloud_provider: ${{ matrix.cloud-provider }}

src/snowflake/connector/aio/_connection.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,10 +265,13 @@ async def _all_async_queries_finished(self) -> bool:
265265
async def async_query_check_helper(
266266
sfq_id: str,
267267
) -> bool:
268-
nonlocal found_unfinished_query
269-
return found_unfinished_query or self.is_still_running(
270-
await self.get_query_status(sfq_id)
271-
)
268+
try:
269+
nonlocal found_unfinished_query
270+
return found_unfinished_query or self.is_still_running(
271+
await self.get_query_status(sfq_id)
272+
)
273+
except asyncio.CancelledError:
274+
pass
272275

273276
tasks = [
274277
asyncio.create_task(async_query_check_helper(sfqid)) for sfqid in queries
@@ -279,6 +282,7 @@ async def async_query_check_helper(
279282
break
280283
for task in tasks:
281284
task.cancel()
285+
await asyncio.gather(*tasks)
282286
return not found_unfinished_query
283287

284288
async def _authenticate(self, auth_instance: AuthByPlugin):

src/snowflake/connector/aio/_cursor.py

Lines changed: 146 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66

77
import asyncio
88
import collections
9+
import logging
910
import re
1011
import signal
1112
import sys
13+
import typing
1214
import uuid
1315
from logging import getLogger
1416
from types import TracebackType
@@ -30,8 +32,15 @@
3032
create_batches_from_response,
3133
)
3234
from snowflake.connector.aio._result_set import ResultSet, ResultSetIterator
33-
from snowflake.connector.constants import PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT
34-
from snowflake.connector.cursor import DESC_TABLE_RE
35+
from snowflake.connector.constants import (
36+
PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT,
37+
QueryStatus,
38+
)
39+
from snowflake.connector.cursor import (
40+
ASYNC_NO_DATA_MAX_RETRY,
41+
ASYNC_RETRY_PATTERN,
42+
DESC_TABLE_RE,
43+
)
3544
from snowflake.connector.cursor import DictCursor as DictCursorSync
3645
from snowflake.connector.cursor import ResultMetadata, ResultMetadataV2, ResultState
3746
from snowflake.connector.cursor import SnowflakeCursor as SnowflakeCursorSync
@@ -43,7 +52,7 @@
4352
ER_INVALID_VALUE,
4453
ER_NOT_POSITIVE_SIZE,
4554
)
46-
from snowflake.connector.errors import BindUploadError
55+
from snowflake.connector.errors import BindUploadError, DatabaseError
4756
from snowflake.connector.file_transfer_agent import SnowflakeProgressPercentage
4857
from snowflake.connector.telemetry import TelemetryField
4958
from snowflake.connector.time_util import get_time_millis
@@ -65,9 +74,11 @@ def __init__(
6574
):
6675
super().__init__(connection, use_dict_result)
6776
# the following fixes type hint
68-
self._connection: SnowflakeConnection = connection
77+
self._connection = typing.cast("SnowflakeConnection", self._connection)
78+
self._inner_cursor = typing.cast(SnowflakeCursor, self._inner_cursor)
6979
self._lock_canceling = asyncio.Lock()
7080
self._timebomb: asyncio.Task | None = None
81+
self._prefetch_hook: typing.Callable[[], typing.Awaitable] | None = None
7182

7283
def __aiter__(self):
7384
return self
@@ -87,6 +98,18 @@ async def __anext__(self):
8798
async def __aenter__(self):
8899
return self
89100

101+
def __enter__(self):
102+
# async cursor does not support sync context manager
103+
raise TypeError(
104+
"'SnowflakeCursor' object does not support the context manager protocol"
105+
)
106+
107+
def __exit__(self, exc_type, exc_val, exc_tb):
108+
# async cursor does not support sync context manager
109+
raise TypeError(
110+
"'SnowflakeCursor' object does not support the context manager protocol"
111+
)
112+
90113
def __del__(self):
91114
# do nothing in async, __del__ is unreliable
92115
pass
@@ -337,6 +360,7 @@ async def _init_result_and_meta(self, data: dict[Any, Any]) -> None:
337360
self._total_rowcount += updated_rows
338361

339362
async def _init_multi_statement_results(self, data: dict) -> None:
363+
# TODO: async telemetry SNOW-1572217
340364
# self._log_telemetry_job_data(TelemetryField.MULTI_STATEMENT, TelemetryData.TRUE)
341365
self.multi_statement_savedIds = data["resultIds"].split(",")
342366
self._multi_statement_resultIds = collections.deque(
@@ -357,7 +381,45 @@ async def _init_multi_statement_results(self, data: dict) -> None:
357381
async def _log_telemetry_job_data(
358382
self, telemetry_field: TelemetryField, value: Any
359383
) -> None:
360-
raise NotImplementedError("Telemetry is not supported in async.")
384+
# TODO: async telemetry SNOW-1572217
385+
pass
386+
387+
async def _preprocess_pyformat_query(
388+
self,
389+
command: str,
390+
params: Sequence[Any] | dict[Any, Any] | None = None,
391+
) -> str:
392+
# pyformat/format paramstyle
393+
# client side binding
394+
processed_params = self._connection._process_params_pyformat(params, self)
395+
# SNOW-513061 collect telemetry for empty sequence usage before we make the breaking change announcement
396+
# TODO: async telemetry support
397+
# if params is not None and len(params) == 0:
398+
# await self._log_telemetry_job_data(
399+
# TelemetryField.EMPTY_SEQ_INTERPOLATION,
400+
# (
401+
# TelemetryData.TRUE
402+
# if self.connection._interpolate_empty_sequences
403+
# else TelemetryData.FALSE
404+
# ),
405+
# )
406+
if logger.getEffectiveLevel() <= logging.DEBUG:
407+
logger.debug(
408+
f"binding: [{self._format_query_for_log(command)}] "
409+
f"with input=[{params}], "
410+
f"processed=[{processed_params}]",
411+
)
412+
if (
413+
self.connection._interpolate_empty_sequences
414+
and processed_params is not None
415+
) or (
416+
not self.connection._interpolate_empty_sequences
417+
and len(processed_params) > 0
418+
):
419+
query = command % processed_params
420+
else:
421+
query = command
422+
return query
361423

362424
async def abort_query(self, qid: str) -> bool:
363425
url = f"/queries/{qid}/abort-request"
@@ -387,6 +449,10 @@ async def callproc(self, procname: str, args=tuple()):
387449
await self.execute(command, args)
388450
return args
389451

452+
@property
453+
def connection(self) -> SnowflakeConnection:
454+
return self._connection
455+
390456
async def close(self):
391457
"""Closes the cursor object.
392458
@@ -471,7 +537,7 @@ async def execute(
471537
}
472538

473539
if self._connection.is_pyformat:
474-
query = self._preprocess_pyformat_query(command, params)
540+
query = await self._preprocess_pyformat_query(command, params)
475541
else:
476542
# qmark and numeric paramstyle
477543
query = command
@@ -538,7 +604,7 @@ async def execute(
538604
self._connection.converter.set_parameter(param, value)
539605

540606
if "resultIds" in data:
541-
self._init_multi_statement_results(data)
607+
await self._init_multi_statement_results(data)
542608
return self
543609
else:
544610
self.multi_statement_savedIds = []
@@ -707,7 +773,7 @@ async def executemany(
707773
command = command + "; "
708774
if self._connection.is_pyformat:
709775
processed_queries = [
710-
self._preprocess_pyformat_query(command, params)
776+
await self._preprocess_pyformat_query(command, params)
711777
for params in seqparams
712778
]
713779
query = "".join(processed_queries)
@@ -752,7 +818,7 @@ async def describe(self, *args: Any, **kwargs: Any) -> list[ResultMetadata]:
752818
async def fetchone(self) -> dict | tuple | None:
753819
"""Fetches one row."""
754820
if self._prefetch_hook is not None:
755-
self._prefetch_hook()
821+
await self._prefetch_hook()
756822
if self._result is None and self._result_set is not None:
757823
self._result: ResultSetIterator = await self._result_set._create_iter()
758824
self._result_state = ResultState.VALID
@@ -804,7 +870,7 @@ async def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]:
804870
async def fetchall(self) -> list[tuple] | list[dict]:
805871
"""Fetches all of the results."""
806872
if self._prefetch_hook is not None:
807-
self._prefetch_hook()
873+
await self._prefetch_hook()
808874
if self._result is None and self._result_set is not None:
809875
self._result: ResultSetIterator = await self._result_set._create_iter(
810876
is_fetch_all=True
@@ -822,9 +888,10 @@ async def fetchall(self) -> list[tuple] | list[dict]:
822888
async def fetch_arrow_batches(self) -> AsyncIterator[Table]:
823889
self.check_can_use_arrow_resultset()
824890
if self._prefetch_hook is not None:
825-
self._prefetch_hook()
891+
await self._prefetch_hook()
826892
if self._query_result_format != "arrow":
827893
raise NotSupportedError
894+
# TODO: async telemetry SNOW-1572217
828895
# self._log_telemetry_job_data(
829896
# TelemetryField.ARROW_FETCH_BATCHES, TelemetryData.TRUE
830897
# )
@@ -848,9 +915,10 @@ async def fetch_arrow_all(self, force_return_table: bool = False) -> Table | Non
848915
self.check_can_use_arrow_resultset()
849916

850917
if self._prefetch_hook is not None:
851-
self._prefetch_hook()
918+
await self._prefetch_hook()
852919
if self._query_result_format != "arrow":
853920
raise NotSupportedError
921+
# TODO: async telemetry SNOW-1572217
854922
# self._log_telemetry_job_data(TelemetryField.ARROW_FETCH_ALL, TelemetryData.TRUE)
855923
return await self._result_set._fetch_arrow_all(
856924
force_return_table=force_return_table
@@ -860,7 +928,7 @@ async def fetch_pandas_batches(self, **kwargs: Any) -> AsyncIterator[DataFrame]:
860928
"""Fetches a single Arrow Table."""
861929
self.check_can_use_pandas()
862930
if self._prefetch_hook is not None:
863-
self._prefetch_hook()
931+
await self._prefetch_hook()
864932
if self._query_result_format != "arrow":
865933
raise NotSupportedError
866934
# TODO: async telemetry
@@ -872,7 +940,7 @@ async def fetch_pandas_batches(self, **kwargs: Any) -> AsyncIterator[DataFrame]:
872940
async def fetch_pandas_all(self, **kwargs: Any) -> DataFrame:
873941
self.check_can_use_pandas()
874942
if self._prefetch_hook is not None:
875-
self._prefetch_hook()
943+
await self._prefetch_hook()
876944
if self._query_result_format != "arrow":
877945
raise NotSupportedError
878946
# # TODO: async telemetry
@@ -917,8 +985,70 @@ async def get_result_batches(self) -> list[ResultBatch] | None:
917985
return self._result_set.batches
918986

919987
async def get_results_from_sfqid(self, sfqid: str) -> None:
920-
"""Gets the results from previously ran query."""
921-
raise NotImplementedError("Not implemented in async")
988+
"""Gets the results from previously ran query. This methods differs from ``SnowflakeCursor.query_result``
989+
in that it monitors the ``sfqid`` until it is no longer running, and then retrieves the results.
990+
"""
991+
992+
async def wait_until_ready() -> None:
993+
"""Makes sure query has finished executing and once it has retrieves results."""
994+
no_data_counter = 0
995+
retry_pattern_pos = 0
996+
while True:
997+
status, status_resp = await self.connection._get_query_status(sfqid)
998+
self.connection._cache_query_status(sfqid, status)
999+
if not self.connection.is_still_running(status):
1000+
break
1001+
if status == QueryStatus.NO_DATA: # pragma: no cover
1002+
no_data_counter += 1
1003+
if no_data_counter > ASYNC_NO_DATA_MAX_RETRY:
1004+
raise DatabaseError(
1005+
"Cannot retrieve data on the status of this query. No information returned "
1006+
"from server for query '{}'"
1007+
)
1008+
await asyncio.sleep(
1009+
0.5 * ASYNC_RETRY_PATTERN[retry_pattern_pos]
1010+
) # Same wait as JDBC
1011+
# If we can advance in ASYNC_RETRY_PATTERN then do so
1012+
if retry_pattern_pos < (len(ASYNC_RETRY_PATTERN) - 1):
1013+
retry_pattern_pos += 1
1014+
if status != QueryStatus.SUCCESS:
1015+
logger.info(f"Status of query '{sfqid}' is {status.name}")
1016+
self.connection._process_error_query_status(
1017+
sfqid,
1018+
status_resp,
1019+
error_message=f"Status of query '{sfqid}' is {status.name}, results are unavailable",
1020+
error_cls=DatabaseError,
1021+
)
1022+
await self._inner_cursor.execute(
1023+
f"select * from table(result_scan('{sfqid}'))"
1024+
)
1025+
self._result = self._inner_cursor._result
1026+
self._query_result_format = self._inner_cursor._query_result_format
1027+
self._total_rowcount = self._inner_cursor._total_rowcount
1028+
self._description = self._inner_cursor._description
1029+
self._result_set = self._inner_cursor._result_set
1030+
self._result_state = ResultState.VALID
1031+
self._rownumber = 0
1032+
# Unset this function, so that we don't block anymore
1033+
self._prefetch_hook = None
1034+
1035+
if (
1036+
self._inner_cursor._total_rowcount == 1
1037+
and await self._inner_cursor.fetchall()
1038+
== [("Multiple statements executed successfully.",)]
1039+
):
1040+
url = f"/queries/{sfqid}/result"
1041+
ret = await self._connection.rest.request(url=url, method="get")
1042+
if "data" in ret and "resultIds" in ret["data"]:
1043+
await self._init_multi_statement_results(ret["data"])
1044+
1045+
await self.connection.get_query_status_throw_if_error(
1046+
sfqid
1047+
) # Trigger an exception if query failed
1048+
klass = self.__class__
1049+
self._inner_cursor = klass(self.connection)
1050+
self._sfqid = sfqid
1051+
self._prefetch_hook = wait_until_ready
9221052

9231053
async def query_result(self, qid: str) -> SnowflakeCursor:
9241054
url = f"/queries/{qid}/result"

src/snowflake/connector/aio/_network.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ async def close(self):
136136
"""Closes all active and idle sessions in this session pool."""
137137
if self._active_sessions:
138138
logger.debug(f"Closing {len(self._active_sessions)} active sessions")
139-
for s in itertools.chain(self._active_sessions, self._idle_sessions):
139+
for s in itertools.chain(set(self._active_sessions), set(self._idle_sessions)):
140140
try:
141141
await s.close()
142142
except Exception as e:
@@ -289,7 +289,7 @@ async def _token_request(self, request_type):
289289
token=header_token,
290290
)
291291
if ret.get("success") and ret.get("data", {}).get("sessionToken"):
292-
logger.debug("success: %s", ret)
292+
logger.debug("success: %s", SecretDetector.mask_secrets(str(ret)))
293293
await self.update_tokens(
294294
ret["data"]["sessionToken"],
295295
ret["data"].get("masterToken"),

0 commit comments

Comments
 (0)