Skip to content

Commit 32b459e

Browse files
authored
SNOW-657710 Parallelize async query status check when closing connection (#1273)
* Parallelize closing check
* Use reversed insertion order
* Fix tests
* Fix tests
* Address PR comments and add test
* Change value type to None
* Address PR comments
* Break after first positive and attempt to cancel future execution
1 parent e6cfb38 commit 32b459e

File tree

3 files changed

+74
-15
lines changed

3 files changed

+74
-15
lines changed

src/snowflake/connector/connection.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55

66
from __future__ import annotations
77

8-
import copy
98
import logging
109
import os
1110
import re
1211
import sys
1312
import uuid
1413
import warnings
1514
import weakref
15+
from concurrent.futures import as_completed
16+
from concurrent.futures.thread import ThreadPoolExecutor
1617
from difflib import get_close_matches
1718
from functools import partial
1819
from io import StringIO
@@ -267,8 +268,8 @@ def __init__(self, **kwargs):
267268
self._errorhandler = Error.default_errorhandler
268269
self._lock_converter = Lock()
269270
self.messages = []
270-
self._async_sfqids = set()
271-
self._done_async_sfqids = set()
271+
self._async_sfqids: dict[str, None] = {}
272+
self._done_async_sfqids: dict[str, None] = {}
272273
self.telemetry_enabled = False
273274
self._session_parameters: dict[str, str | int | bool] = {}
274275
logger.info(
@@ -1457,11 +1458,15 @@ def _get_query_status(self, sf_qid: str) -> tuple[QueryStatus, dict[str, Any]]:
14571458
if len(queries) > 0:
14581459
status = queries[0]["status"]
14591460
status_ret = QueryStatus[status]
1461+
return status_ret, status_resp
1462+
1463+
def _cache_query_status(self, sf_qid: str, status_ret: QueryStatus) -> None:
    """Record that an async query started by this connection has finished.

    Does nothing unless ``sf_qid`` is currently tracked in ``_async_sfqids``
    and ``status_ret`` is not a still-running status. Otherwise the id is moved
    from ``_async_sfqids`` to ``_done_async_sfqids``.

    Args:
        sf_qid: The query id whose status was just fetched.
        status_ret: The status that was fetched for ``sf_qid``.
    """
    # Only queries started by us are tracked; ignore everything else.
    if sf_qid not in self._async_sfqids:
        return
    # Still running: nothing to cache yet.
    if self.is_still_running(status_ret):
        return
    # pop() with a default prevents a KeyError when multiple threads try to
    # remove the same query id.
    self._async_sfqids.pop(sf_qid, None)
    self._done_async_sfqids[sf_qid] = None
14651470

14661471
def get_query_status(self, sf_qid: str) -> QueryStatus:
14671472
"""Retrieves the status of query with sf_qid.
@@ -1475,6 +1480,7 @@ def get_query_status(self, sf_qid: str) -> QueryStatus:
14751480
ValueError: if sf_qid is not a valid UUID string.
14761481
"""
14771482
status, _ = self._get_query_status(sf_qid)
1483+
self._cache_query_status(sf_qid, status)
14781484
return status
14791485

14801486
def get_query_status_throw_if_error(self, sf_qid: str) -> QueryStatus:
@@ -1489,10 +1495,11 @@ def get_query_status_throw_if_error(self, sf_qid: str) -> QueryStatus:
14891495
ValueError: if sf_qid is not a valid UUID string.
14901496
"""
14911497
status, status_resp = self._get_query_status(sf_qid)
1498+
self._cache_query_status(sf_qid, status)
14921499
queries = status_resp["data"]["queries"]
14931500
if self.is_an_error(status):
14941501
if sf_qid in self._async_sfqids:
1495-
self._async_sfqids.remove(sf_qid)
1502+
self._async_sfqids.pop(sf_qid, None)
14961503
message = status_resp.get("message")
14971504
if message is None:
14981505
message = ""
@@ -1541,13 +1548,39 @@ def is_an_error(status: QueryStatus) -> bool:
15411548

15421549
def _all_async_queries_finished(self) -> bool:
    """Check whether all async queries started by this Connection have finished executing.

    The status of every id tracked in ``_async_sfqids`` is looked up on a
    thread pool, in reversed insertion order. The check stops at the first
    query reported as still running and cancels every lookup that has not
    started yet.

    Returns:
        True if no tracked query is still running (including when nothing is
        tracked), False as soon as one still-running query is found.
    """
    if not self._async_sfqids:
        return True

    # Snapshot the ids first: get_query_status() may remove finished ids from
    # _async_sfqids while the workers are running. list(dict) followed by
    # reversed() works on every Python version, unlike reversed(dict.keys()).
    # NOTE(review): newest-first presumably surfaces a running query sooner.
    queries = list(reversed(list(self._async_sfqids)))

    # ThreadPoolExecutor raises ValueError for max_workers <= 0, so always
    # keep at least one worker.
    num_workers = max(1, min(self.client_prefetch_threads, len(queries)))
    found_unfinished_query = False

    def async_query_check_helper(sfq_id: str) -> bool:
        # Once an unfinished query is known, the remaining lookups are
        # pointless; short-circuit so already-started workers return quickly.
        return found_unfinished_query or self.is_still_running(
            self.get_query_status(sfq_id)
        )

    with ThreadPoolExecutor(
        max_workers=num_workers, thread_name_prefix="async_query_check_"
    ) as tpe:  # We should upgrade to using cancel_futures=True once supporting 3.9+
        # This has to be a list, not a generator: as_completed() consumes its
        # argument, so a generator would already be exhausted when the
        # cancellation loop runs and nothing would ever be cancelled.
        futures = [tpe.submit(async_query_check_helper, sfqid) for sfqid in queries]
        try:
            for f in as_completed(futures):
                if f.result():
                    found_unfinished_query = True
                    break
        finally:
            # cancel() is a no-op for futures that are running or done, so it
            # is safe to call on all of them; this also runs if a lookup raised.
            for f in futures:
                f.cancel()

    return not found_unfinished_query
15511584

15521585
def _log_telemetry_imported_packages(self) -> None:
15531586
if self._log_imported_packages_in_telemetry:

src/snowflake/connector/cursor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -772,7 +772,7 @@ def execute(
772772
self._connection.converter.set_parameter(param, value)
773773

774774
if _exec_async:
775-
self.connection._async_sfqids.add(self._sfqid)
775+
self.connection._async_sfqids[self._sfqid] = None
776776
if _no_results:
777777
self._total_rowcount = (
778778
ret["data"]["total"]

test/integ/test_async.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,3 +226,29 @@ def test_not_fetching(conn_cnx):
226226
cur.execute("select 2")
227227
assert cur._inner_cursor is None
228228
assert cur._prefetch_hook is None
229+
230+
231+
def test_close_connection_with_running_async_queries(conn_cnx):
    """Submit two generator queries asynchronously and leave the connection.

    While the queries are outstanding, ``_all_async_queries_finished`` must
    report False. After the connection context exits, fewer than two ids may
    be cached as done and ``con.rest`` must be ``None``.
    """
    statements = (
        "select count(*) from table(generator(timeLimit => 10))",
        "select count(*) from table(generator(timeLimit => 1))",
    )
    with conn_cnx() as con:
        with con.cursor() as cur:
            for sql in statements:
                cur.execute_async(sql)
            assert not con._all_async_queries_finished()
    # Checked after leaving the connection context.
    assert len(con._done_async_sfqids) < 2
    assert con.rest is None
238+
239+
240+
def test_close_connection_with_completed_async_queries(conn_cnx):
    """Submit two trivial async queries, wait for both, then leave the connection.

    Once neither query is still running, ``_all_async_queries_finished`` must
    report True. After the connection context exits, both ids must be cached
    as done and ``con.rest`` must be ``None``.
    """
    with conn_cnx() as con:
        with con.cursor() as cur:
            cur.execute_async("select 1")
            qid1 = cur.sfqid
            cur.execute_async("select 2")
            qid2 = cur.sfqid
            for qid in (qid1, qid2):
                # use _get_query_status to avoid caching. It returns a
                # (status, response) tuple, so only the status element may be
                # handed to is_still_running(); passing the whole tuple would
                # make the loop exit immediately without waiting.
                while con.is_still_running(con._get_query_status(qid)[0]):
                    time.sleep(1)
            assert con._all_async_queries_finished()
    # Checked after leaving the connection context.
    assert len(con._done_async_sfqids) == 2 and con.rest is None

0 commit comments

Comments
 (0)