55
66from __future__ import annotations
77
8- import copy
98import logging
109import os
1110import re
1211import sys
1312import uuid
1413import warnings
1514import weakref
15+ from concurrent .futures import as_completed
16+ from concurrent .futures .thread import ThreadPoolExecutor
1617from difflib import get_close_matches
1718from functools import partial
1819from io import StringIO
@@ -267,8 +268,8 @@ def __init__(self, **kwargs):
267268 self ._errorhandler = Error .default_errorhandler
268269 self ._lock_converter = Lock ()
269270 self .messages = []
270- self ._async_sfqids = set ()
271- self ._done_async_sfqids = set ()
271+ self ._async_sfqids : dict [ str , None ] = {}
272+ self ._done_async_sfqids : dict [ str , None ] = {}
272273 self .telemetry_enabled = False
273274 self ._session_parameters : dict [str , str | int | bool ] = {}
274275 logger .info (
@@ -1457,11 +1458,15 @@ def _get_query_status(self, sf_qid: str) -> tuple[QueryStatus, dict[str, Any]]:
14571458 if len (queries ) > 0 :
14581459 status = queries [0 ]["status" ]
14591460 status_ret = QueryStatus [status ]
1461+ return status_ret , status_resp
1462+
1463+ def _cache_query_status (self , sf_qid : str , status_ret : QueryStatus ) -> None :
14601464 # If query was started by us and it has finished let's cache this info
14611465 if sf_qid in self ._async_sfqids and not self .is_still_running (status_ret ):
1462- self ._async_sfqids .remove (sf_qid )
1463- self ._done_async_sfqids .add (sf_qid )
1464- return status_ret , status_resp
1466+ self ._async_sfqids .pop (
1467+ sf_qid , None
1468+ ) # Prevent KeyError when multiple threads try to remove the same query id
1469+ self ._done_async_sfqids [sf_qid ] = None
14651470
14661471 def get_query_status (self , sf_qid : str ) -> QueryStatus :
14671472 """Retrieves the status of query with sf_qid.
@@ -1475,6 +1480,7 @@ def get_query_status(self, sf_qid: str) -> QueryStatus:
14751480 ValueError: if sf_qid is not a valid UUID string.
14761481 """
14771482 status , _ = self ._get_query_status (sf_qid )
1483+ self ._cache_query_status (sf_qid , status )
14781484 return status
14791485
14801486 def get_query_status_throw_if_error (self , sf_qid : str ) -> QueryStatus :
@@ -1489,10 +1495,11 @@ def get_query_status_throw_if_error(self, sf_qid: str) -> QueryStatus:
14891495 ValueError: if sf_qid is not a valid UUID string.
14901496 """
14911497 status , status_resp = self ._get_query_status (sf_qid )
1498+ self ._cache_query_status (sf_qid , status )
14921499 queries = status_resp ["data" ]["queries" ]
14931500 if self .is_an_error (status ):
14941501 if sf_qid in self ._async_sfqids :
1495- self ._async_sfqids .remove (sf_qid )
1502+ self ._async_sfqids .pop (sf_qid , None )
14961503 message = status_resp .get ("message" )
14971504 if message is None :
14981505 message = ""
@@ -1541,13 +1548,39 @@ def is_an_error(status: QueryStatus) -> bool:
15411548
15421549 def _all_async_queries_finished (self ) -> bool :
15431550 """Checks whether all async queries started by this Connection have finished executing."""
1544- queries = copy .copy (
1545- self ._async_sfqids
1546- ) # get_query_status might update _async_sfqids, let's copy the list
1547- finished_async_queries = (
1548- not self .is_still_running (self .get_query_status (q )) for q in queries
1549- )
1550- return all (finished_async_queries )
1551+
1552+ if not self ._async_sfqids :
1553+ return True
1554+
1555+ if sys .version_info >= (3 , 8 ):
1556+ queries = list (reversed (self ._async_sfqids .keys ()))
1557+ else :
1558+ queries = list (reversed (list (self ._async_sfqids .keys ())))
1559+
1560+ num_workers = min (self .client_prefetch_threads , len (queries ))
1561+ found_unfinished_query = False
1562+
1563+ def async_query_check_helper (
1564+ sfq_id : str ,
1565+ ) -> bool :
1566+ nonlocal found_unfinished_query
1567+ return found_unfinished_query or self .is_still_running (
1568+ self .get_query_status (sfq_id )
1569+ )
1570+
1571+ with ThreadPoolExecutor (
1572+ max_workers = num_workers , thread_name_prefix = "async_query_check_"
1573+ ) as tpe : # We should upgrade to using cancel_futures=True once supporting 3.9+
1574+
1575+ futures = (tpe .submit (async_query_check_helper , sfqid ) for sfqid in queries )
1576+ for f in as_completed (futures ):
1577+ if f .result ():
1578+ found_unfinished_query = True
1579+ break
1580+ for f in futures :
1581+ f .cancel ()
1582+
1583+ return not found_unfinished_query
15511584
15521585 def _log_telemetry_imported_packages (self ) -> None :
15531586 if self ._log_imported_packages_in_telemetry :
0 commit comments